ttnn.sort
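ttnn.sort(input_tensor: ttnn.Tensor, *, dim: int = -1, descending: bool = False, stable: bool = False, memory_config: ttnn.MemoryConfig = None, out: tuple[ttnn.Tensor, ttnn.Tensor] = None) -> tuple[ttnn.Tensor, ttnn.Tensor]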
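Sorts the elements of the input tensor along the specified dimension, in ascending order by default, and returns a tuple (values, indices): the sorted values and, for each of them, its position along that dimension in the original tensor. If no dimension is specified, the last dimension of the input tensor is used.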
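The relationship between the two returned tensors can be illustrated with a small pure-Python reference model for a 1-D input. This is a sketch of the semantics only, not the ttnn implementation; the helper name `reference_sort_1d` is hypothetical. The key invariant is `values[k] == input[indices[k]]`.

```python
from typing import List, Sequence, Tuple


def reference_sort_1d(
    values: Sequence[float], descending: bool = False
) -> Tuple[List[float], List[int]]:
    """Return (sorted_values, indices) with sorted_values[k] == values[indices[k]]."""
    # Sort the positions 0..n-1 by the value stored at each position.
    indices = sorted(range(len(values)), key=lambda i: values[i], reverse=descending)
    # Gather the input at those positions to get the sorted values.
    sorted_values = [values[i] for i in indices]
    return sorted_values, indices


vals, idx = reference_sort_1d([3.0, 1.0, 2.0])
print(vals, idx)  # [1.0, 2.0, 3.0] [1, 2, 0]
```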
Parameters:
input_tensor (ttnn.Tensor) – The input tensor to be sorted.
Keyword Arguments:
dim (int, optional) – The dimension along which to sort. Defaults to -1 (last dimension).
descending (bool, optional) – If True, sorts in descending order. Defaults to False.
stable (bool, optional) – If True, ensures the original order of equal elements is preserved. Defaults to False.
memory_config (ttnn.MemoryConfig, optional) – Specifies the memory configuration for the output tensor. Defaults to None.
out (tuple of ttnn.Tensor, optional) – Preallocated output tensors for the sorted values and indices. Defaults to None. The index tensor must be of type uint16 or uint32.
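To make the `dim` and `descending` arguments concrete, here is a pure-Python sketch for a 2-D nested list, assuming PyTorch-style semantics where a negative `dim` counts from the end (so `-1` means the last dimension). `reference_sort_2d` is a hypothetical helper, not part of ttnn.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def reference_sort_2d(
    x: Matrix, dim: int = -1, descending: bool = False
) -> Tuple[Matrix, List[List[int]]]:
    ndim = 2
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D input")
    dim %= ndim  # normalize: -1 -> 1 (last dim), -2 -> 0
    if dim == 0:
        # Sorting along dim 0 sorts each column: transpose, sort rows, transpose back.
        cols = [list(c) for c in zip(*x)]
        vals_t, idx_t = reference_sort_2d(cols, dim=1, descending=descending)
        return [list(r) for r in zip(*vals_t)], [list(r) for r in zip(*idx_t)]
    values: Matrix = []
    indices: List[List[int]] = []
    for row in x:
        # Indices are positions within the row, i.e. along the sorted dimension.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=descending)
        indices.append(order)
        values.append([row[i] for i in order])
    return values, indices


print(reference_sort_2d([[3, 1, 2], [6, 5, 4]], dim=0, descending=True))
# ([[6, 5, 4], [3, 1, 2]], [[1, 1, 1], [0, 0, 0]])
```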
Additional info:
The stable argument is currently not supported.
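Because stable sorting is not supported, it seems safest not to rely on the relative order of indices for equal values. The pure-Python sketch below only illustrates what a stable result would look like (Python's `sorted` is stable, including with `reverse=True`); `stable_argsort` is a hypothetical helper. A non-stable sort could legitimately return a different permutation of the tied indices, for example `[3, 1, 2, 0]` for the ascending case.

```python
from typing import List, Sequence


def stable_argsort(values: Sequence[float], descending: bool = False) -> List[int]:
    # sorted() is stable even with reverse=True, so equal values keep
    # their original relative order in the returned indices.
    return sorted(range(len(values)), key=lambda i: values[i], reverse=descending)


data = [2.0, 1.0, 2.0, 1.0]
print(stable_argsort(data))                   # [1, 3, 0, 2]
print(stable_argsort(data, descending=True))  # [0, 2, 1, 3]
```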
Note

Supported dtypes and layouts for the input (values) tensor:

Dtypes            Layouts
BFLOAT16          TILE
UINT16            TILE

Supported dtypes and layouts for the index tensor:

Dtypes            Layouts
UINT16, UINT32    TILE
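The support tables above can be encoded as a pre-flight check before calling the operation. This is a sketch that uses hypothetical string names in place of real ttnn dtype and layout objects. The UINT16 capacity test is an added assumption based on arithmetic (a 16-bit unsigned index can only hold positions 0 to 65535), not a check documented for ttnn.sort.

```python
from typing import Optional

# Hypothetical string names mirroring the support tables; real code would
# compare ttnn dtype and layout objects rather than strings.
SUPPORTED_VALUE_DTYPES = {"BFLOAT16", "UINT16"}
SUPPORTED_INDEX_DTYPES = {"UINT16", "UINT32"}
SUPPORTED_LAYOUT = "TILE"
UINT16_MAX = 2**16 - 1  # 65535


def check_sort_inputs(
    value_dtype: str,
    layout: str,
    index_dtype: Optional[str] = None,
    sort_dim_size: Optional[int] = None,
) -> None:
    """Raise ValueError if the combination falls outside the documented support."""
    if value_dtype not in SUPPORTED_VALUE_DTYPES:
        raise ValueError(f"unsupported value dtype: {value_dtype}")
    if layout != SUPPORTED_LAYOUT:
        raise ValueError(f"unsupported layout: {layout}")
    if index_dtype is None:
        return
    if index_dtype not in SUPPORTED_INDEX_DTYPES:
        raise ValueError(f"unsupported index dtype: {index_dtype}")
    # Assumption: the largest index (sort_dim_size - 1) must fit in the index dtype.
    if index_dtype == "UINT16" and sort_dim_size is not None and sort_dim_size - 1 > UINT16_MAX:
        raise ValueError(f"UINT16 indices cannot address {sort_dim_size} elements")
```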
Memory Support:
Interleaved: DRAM and L1
Example
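import torch
import ttnn

# Open a device (assumes a single Tenstorrent device with ID 0)
device = ttnn.open_device(device_id=0)

# Create a tensor
input_tensor = torch.tensor([[3, 1, 2], [3, 1, 2]], dtype=torch.bfloat16)

# Convert the tensor to ttnn format
input_tensor_ttnn = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)

# Sort the tensor in ascending order (along the last dimension by default)
sorted_tensor, indices = ttnn.sort(input_tensor_ttnn)

# Sort the tensor in descending order
sorted_tensor_desc, indices_desc = ttnn.sort(input_tensor_ttnn, descending=True)

# Sort along a specific dimension (dim=1 is also the last dimension of this 2-D tensor)
input_tensor_2d = torch.tensor([[3, 1, 2], [6, 5, 4]], dtype=torch.bfloat16)
input_tensor_2d_ttnn = ttnn.from_torch(input_tensor_2d, dtype=ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
sorted_tensor_dim, indices_dim = ttnn.sort(input_tensor_2d_ttnn, dim=1)

# Copy a result back to the host as a torch tensor
print(ttnn.to_torch(sorted_tensor_dim))

ttnn.close_device(device)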
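Since tie order is not guaranteed, a result is best checked with properties rather than by comparing indices exactly. The sketch below validates a last-dimension sort on nested lists, assuming the outputs were copied back with something like `ttnn.to_torch(...).tolist()` and come back in their logical (unpadded) shape. `is_valid_sort_result` is a hypothetical helper. Exact equality works for the example's small integers because they are exactly representable in bfloat16; for general data, compare against the input after the same bfloat16 rounding.

```python
from typing import Sequence


def is_valid_sort_result(
    inp: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    indices: Sequence[Sequence[int]],
    descending: bool = False,
) -> bool:
    """Check a row-wise (last-dim) sort result without assuming any tie order."""
    if not (len(inp) == len(values) == len(indices)):
        return False
    for row_in, row_val, row_idx in zip(inp, values, indices):
        n = len(row_in)
        if len(row_val) != n or len(row_idx) != n:
            return False
        # The indices must be a permutation of 0..n-1.
        if sorted(row_idx) != list(range(n)):
            return False
        # The values must equal the input gathered at those indices.
        if any(row_val[k] != row_in[row_idx[k]] for k in range(n)):
            return False
        # The values must be monotonic in the requested direction.
        pairs = list(zip(row_val, row_val[1:]))
        if descending:
            if any(a < b for a, b in pairs):
                return False
        elif any(a > b for a, b in pairs):
            return False
    return True


print(is_valid_sort_result([[3, 1, 2], [6, 5, 4]], [[1, 2, 3], [4, 5, 6]], [[1, 2, 0], [2, 1, 0]]))  # True
```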