ttnn.experimental.gather
ttnn.experimental.gather(input, dim, index, *, sparse_grad=False, memory_config=None, out=None)
The gather operation collects values from the input tensor along dimension dim, at the positions given by the index tensor. For example, with a 2-D input and dim=1, each output element is out[i][j] = input[i][index[i][j]].
The input tensor and the index tensor must have the same number of dimensions. For every dimension except dim, the size of the index tensor must not exceed the size of the input tensor. The output tensor has the same shape as the index tensor. The input and index tensors do not broadcast against each other.
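To make the indexing rule concrete, here is a minimal pure-Python sketch for the 2-D case. It assumes the operation follows the torch.gather convention, which the worked example on this page is consistent with; `reference_gather_2d` is a hypothetical helper for illustration, not part of ttnn.

```python
from typing import List

Matrix = List[List[int]]


def reference_gather_2d(inp: Matrix, dim: int, index: Matrix) -> Matrix:
    """Hypothetical host-side reference (assumes torch.gather semantics).

    The output has the same shape as `index`:
      dim == 0: out[i][j] = inp[index[i][j]][j]
      dim == 1: out[i][j] = inp[i][index[i][j]]
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D tensor")
    out: Matrix = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # The index value replaces the coordinate along `dim`; the other
            # coordinate is taken directly from the output position (i, j).
            out_row.append(inp[k][j] if dim == 0 else inp[i][k])
        out.append(out_row)
    return out


inp = [[10, 20, 30, 40], [50, 60, 70, 80]]
print(reference_gather_2d(inp, 1, [[3, 0], [2, 1]]))  # [[40, 10], [70, 60]]
print(reference_gather_2d(inp, 0, [[1, 0, 1]]))       # [[50, 20, 70]]
```

The only difference between dim=0 and dim=1 is which coordinate the index value replaces; in both cases the result takes the shape of index, not of input.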
Parameters:
input (ttnn.Tensor) – The source tensor from which values are gathered.
dim (int) – The dimension along which values are gathered.
index (ttnn.Tensor) – A tensor containing the indices of elements to gather, with the same number of dimensions as the input tensor. The index tensor must be of type uint16 or uint32.
Keyword Arguments:
sparse_grad (bool, optional) – If True, the gradient computation will be sparse. Defaults to False.
memory_config (ttnn.MemoryConfig, optional) – Specifies the memory configuration for the output tensor. Defaults to None.
out (ttnn.Tensor, optional) – A preallocated tensor to store the gathered values. Defaults to None.
Additional Information:
Currently, the sparse_grad argument is not supported.
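The argument rules above can be checked on the host before the operation is launched. The following sketch is hypothetical: `check_gather_args` is not a ttnn function, dtypes are passed as plain strings rather than ttnn dtype objects, and only non-negative dim values are accepted because this page does not say whether negative dims are supported.

```python
from typing import Sequence, Tuple

# Index dtypes listed in the parameter description (plain strings in this sketch).
ALLOWED_INDEX_DTYPES = {"uint16", "uint32"}


def check_gather_args(
    input_shape: Sequence[int],
    dim: int,
    index_shape: Sequence[int],
    index_dtype: str,
    sparse_grad: bool = False,
) -> Tuple[int, ...]:
    """Hypothetical validator: raise ValueError if a documented rule is broken,
    otherwise return the expected output shape (equal to the index shape)."""
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dimensions")
    if not 0 <= dim < len(input_shape):
        # Assumption: this sketch only accepts non-negative dims.
        raise ValueError(f"dim {dim} is out of range for a {len(input_shape)}-D input")
    for d, (n_in, n_idx) in enumerate(zip(input_shape, index_shape)):
        # No broadcasting: off the gather axis, index may be smaller than
        # input but never larger. The documented size rule skips `dim` itself.
        if d != dim and n_idx > n_in:
            raise ValueError(f"index size {n_idx} exceeds input size {n_in} at dim {d}")
    if index_dtype not in ALLOWED_INDEX_DTYPES:
        raise ValueError(f"index dtype must be uint16 or uint32, got {index_dtype}")
    if sparse_grad:
        raise ValueError("sparse_grad is not currently supported")
    return tuple(index_shape)


print(check_gather_args((2, 4), 1, (2, 2), "uint16"))  # (2, 2)
```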
Example:

```python
import torch
import ttnn

# Open a device (device 0 is used here)
device = ttnn.open_device(device_id=0)

# Create a 2D input tensor
input_tensor = torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]], dtype=torch.bfloat16)
# Create a 2D index tensor
index_tensor = torch.tensor([[3, 0], [2, 1]])

# Convert tensors to ttnn format on the device
input_tensor_ttnn = ttnn.from_torch(input_tensor, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
index_tensor_ttnn = ttnn.from_torch(index_tensor, ttnn.uint16, layout=ttnn.Layout.TILE, device=device)

# Perform the gather operation along dimension 1
gathered_tensor = ttnn.experimental.gather(input_tensor_ttnn, dim=1, index=index_tensor_ttnn)

# Copy the result back to the host
print(ttnn.to_torch(gathered_tensor))
# Result values: [[40, 10], [70, 60]]

ttnn.close_device(device)
```