ttnn.repeat_interleave
- ttnn.repeat_interleave(input_tensor: ttnn.Tensor, repeats: number, dim: number, *, memory_config: ttnn.MemoryConfig | None = None) → ttnn.Tensor
Repeats each element of a tensor repeats times along the given dimension dim.
- Parameters:
input_tensor (ttnn.Tensor) – the input tensor.
repeats (number) – the number of repetitions for each element. repeats is broadcast to fit the shape of the given axis.
dim (number) – the dimension to expand with the repetitions.
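To make the element-wise semantics concrete, here is a minimal pure-Python sketch that mimics repeat_interleave on nested lists with an integer repeats. The helper names (repeat_interleave_list, _ndim, _repeat) are hypothetical and not part of ttnn; the sketch only models which values end up where, not device placement, tile layout, or memory configuration. Note that each element is repeated in place, so [1, 2] becomes [1, 1, 2, 2], not [1, 2, 1, 2] as it would with tiling.

```python
import copy


def _ndim(data):
    """Count the nesting depth of a (non-ragged) nested list."""
    depth = 0
    while isinstance(data, list):
        depth += 1
        data = data[0] if data else None
    return depth


def _repeat(data, repeats, dim):
    if dim == 0:
        # Emit each entry of the target axis `repeats` times in a row.
        # deepcopy avoids aliasing when the entries are themselves lists.
        return [copy.deepcopy(item) for item in data for _ in range(repeats)]
    # Not at the target axis yet: recurse into every sub-list.
    return [_repeat(sub, repeats, dim - 1) for sub in data]


def repeat_interleave_list(data, repeats, dim):
    """Hypothetical reference helper (not part of ttnn).

    Repeats each element of the nested list `data` `repeats` times
    along dimension `dim`, mirroring repeat_interleave with an int repeats.
    """
    if not isinstance(repeats, int) or repeats < 0:
        raise ValueError("repeats must be a non-negative integer")
    ndim = _ndim(data)
    if dim < 0:
        dim += ndim  # allow negative dims, e.g. -1 for the last axis
    if not 0 <= dim < ndim:
        raise IndexError(f"dim out of range for a {ndim}-D input")
    return _repeat(data, repeats, dim)


if __name__ == "__main__":
    x = [[1, 2], [3, 4]]
    print(repeat_interleave_list(x, 2, dim=0))  # rows repeated
    print(repeat_interleave_list(x, 2, dim=1))  # columns repeated
```

Repeating along dim=0 duplicates whole rows, while dim=1 (or dim=-1) duplicates each value within a row.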
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
- Returns:
ttnn.Tensor – the output tensor.
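For an integer repeats, the output keeps the input's rank and only the size of dim changes: it is multiplied by repeats. The sketch below computes that expected shape from a plain list of sizes; repeat_interleave_shape is a hypothetical helper for illustration, not a ttnn function.

```python
def repeat_interleave_shape(shape, repeats, dim):
    """Hypothetical helper: expected output shape of repeat_interleave.

    Only the size of `dim` changes; it is multiplied by the integer `repeats`.
    """
    shape = list(shape)
    if dim < 0:
        dim += len(shape)  # support negative dims such as -1
    if not 0 <= dim < len(shape):
        raise IndexError(f"dim out of range for a {len(shape)}-D shape")
    shape[dim] *= repeats
    return shape


if __name__ == "__main__":
    # Same shapes as the doctest: a (1, 1, 32, 32) input repeated twice on dim 0.
    print(repeat_interleave_shape([1, 1, 32, 32], 2, 0))
```

This matches the shape printed in the example, where a [1, 1, 32, 32] input becomes [2, 1, 32, 32].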
Example:
import torch
import ttnn

torch_input_tensor =
# Reference result computed with PyTorch
torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim)
# Move the input to the device in tile layout and run the ttnn op
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim)

>>> a = ttnn.from_torch(torch.rand(1, 1, 32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> b = ttnn.repeat_interleave(a, 2, dim=0)
>>> print(a.shape, b.shape)
ttnn.Shape([1, 1, 32, 32]) ttnn.Shape([2, 1, 32, 32])