ttnn.scatter
ttnn.scatter(input: ttnn.Tensor, dim: int, index: ttnn.Tensor, src: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = None, reduce: ttnn.ScatterReductionType = None, sub_core_grids: ttnn.CoreRangeSet = None) → ttnn.Tensor
Scatters the source tensor’s values along a given dimension according to the index tensor.
Parameters:
input (ttnn.Tensor) – the input tensor to scatter values onto.
dim (int) – the dimension to scatter along.
index (ttnn.Tensor) – the tensor of indices that determines where each source value is written along dim.
src (ttnn.Tensor) – the tensor containing the source values to be scattered onto input.
Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – memory configuration for the output tensor. Defaults to None.
reduce (ttnn.ScatterReductionType, optional) – reduction operation to apply when multiple values are scattered to the same location (e.g., amax, amin, sum). Currently not supported. Defaults to None.
sub_core_grids (ttnn.CoreRangeSet, optional) – specifies which cores scatter should run on. Defaults to None.
Returns:
ttnn.Tensor – the output tensor with scattered values.
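For reference, scatter follows the PyTorch indexing convention: for dim = 0, output[index[i][j]][j] = src[i][j], and for dim = 1, output[i][index[i][j]] = src[i][j]. A minimal pure-Python sketch of this rule for 2-D inputs (the function name scatter_2d is illustrative, not part of the ttnn API):

```python
def scatter_2d(input_rows, dim, index, src):
    # Copy the input, then overwrite positions selected by `index` along `dim`.
    out = [row[:] for row in input_rows]
    for i in range(len(index)):
        for j in range(len(index[i])):
            if dim == 0:
                out[index[i][j]][j] = src[i][j]  # index selects the row
            else:
                out[i][index[i][j]] = src[i][j]  # index selects the column
    return out

# dim = 0: each source value lands in the row named by index, same column.
print(scatter_2d([[0, 0, 0], [0, 0, 0]], 0, [[0, 1, 0]], [[10, 20, 30]]))
# → [[10, 0, 30], [0, 20, 0]]
```

When several indices point at the same location, the last write wins unless a reduction is requested via reduce (which ttnn.scatter does not yet support).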
Note
Input tensors must be interleaved and on device.
Example
import torch
from loguru import logger

import ttnn

# Assumes `device` is an already-open ttnn device handle.

# Create input, index, and source tensors
input_torch = torch.randn([10, 20, 30, 20, 10], dtype=torch.float32)
index_torch = torch.randint(0, 10, [10, 20, 30, 20, 5], dtype=torch.int64)
source_torch = torch.randn([10, 20, 30, 20, 10], dtype=input_torch.dtype)

input_ttnn = ttnn.from_torch(input_torch, dtype=ttnn.bfloat16, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)
index_ttnn = ttnn.from_torch(index_torch, dtype=ttnn.int32, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)
source_ttnn = ttnn.from_torch(source_torch, dtype=ttnn.bfloat16, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)

dim = -1

# Perform scatter operation
output = ttnn.scatter(input_ttnn, dim, index_ttnn, source_ttnn)
logger.info(f"Scatter operation result: {output}")