ttnn.sort

ttnn.sort(input_tensor: ttnn.Tensor) → None

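Sorts the elements of the input tensor along the specified dimension. The order is ascending unless descending=True is passed, and the last dimension is used when dim is not specified. As the example on this page shows, the call yields the sorted values together with their indices.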

Parameters:

input_tensor (ttnn.Tensor) – The input tensor to be sorted.

Keyword Arguments:
  • dim (int, optional) – The dimension along which to sort. Defaults to -1 (last dimension).

  • descending (bool, optional) – If True, sorts in descending order. Defaults to False.

  • stable (bool, optional) – If True, ensures the original order of equal elements is preserved. Defaults to False.

  • memory_config (ttnn.MemoryConfig, optional) – Specifies the memory configuration for the output tensor. Defaults to None.

  • out (tuple of ttnn.Tensor, optional) – Preallocated output tensors for the sorted values and indices. Defaults to None. The index tensor must be of type uint16 or uint32.
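To make the roles of dim and descending concrete, here is a minimal pure-Python sketch of the semantics on a 2-D nested list. It does not call ttnn. The helper reference_sort is hypothetical and introduced only for illustration, and it assumes that the returned indices are positions along the sorted dimension of the original input.

```python
def _sort_row(row, descending):
    # Sort positions by the value they point at. Python's sorted() is stable;
    # ttnn.sort does not currently support the stable option, so the order of
    # tied elements on device may differ from this reference.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=descending)
    return [row[i] for i in order], order


def reference_sort(data, dim=-1, descending=False):
    """Hypothetical reference: sort a 2-D nested list along `dim`."""
    rank = 2
    if not -rank <= dim < rank:
        raise ValueError(f"dim must be in [{-rank}, {rank - 1}], got {dim}")
    dim = dim % rank  # normalise negative dims: -1 -> 1, -2 -> 0

    if dim == 0:
        # Sorting along dim 0 means sorting each column: transpose first.
        data = [list(col) for col in zip(*data)]

    pairs = [_sort_row(row, descending) for row in data]
    values = [v for v, _ in pairs]
    indices = [i for _, i in pairs]

    if dim == 0:
        # Transpose back to the original orientation.
        values = [list(col) for col in zip(*values)]
        indices = [list(col) for col in zip(*indices)]
    return values, indices


x = [[3, 1, 2], [6, 5, 4]]
print(reference_sort(x))                   # ([[1, 2, 3], [4, 5, 6]], [[1, 2, 0], [2, 1, 0]])
print(reference_sort(x, descending=True))  # ([[3, 2, 1], [6, 5, 4]], [[0, 2, 1], [0, 1, 2]])
```

For a 2-D input, dim=-1 and dim=1 name the same dimension, so both give identical results in this sketch.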

Additional info:
  • The stable argument is not currently supported.
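For context on what a stable sort guarantees, the following pure-Python sketch (not ttnn code) shows how ties are ordered when stability holds. Python's built-in sorted() is stable, so equal elements keep their original relative order in the index result. Because the stable option is not supported in ttnn.sort, any order of tied indices should be treated as possible on device. The helper name stable_argsort is hypothetical.

```python
def stable_argsort(values, descending=False):
    """Return the indices that sort `values`, keeping ties in input order."""
    # sorted() is guaranteed to be stable, including when reverse=True.
    return sorted(range(len(values)), key=lambda i: values[i], reverse=descending)


row = [2, 1, 2, 1]
print(stable_argsort(row))                   # [1, 3, 0, 2]
print(stable_argsort(row, descending=True))  # [0, 2, 1, 3]
```

A sort without the stability guarantee could equally return [3, 1, 2, 0] for the ascending case: the sorted values would be identical, only the indices of tied elements would differ.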

Note

Supported dtypes and layouts for input tensor values:

  Dtypes      Layouts
  BFLOAT16    TILE
  UINT16      TILE

Supported dtypes and layouts for index tensor values:

  Dtypes            Layouts
  UINT16, UINT32    TILE
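This page does not say when to prefer UINT16 over UINT32 for the index tensor. One plausible rule of thumb, stated here as an assumption and not as documented ttnn behaviour, follows from the integer ranges: an unsigned 16-bit index can only address positions 0 to 65535, so a longer sorted dimension would need UINT32. The helper pick_index_dtype is hypothetical and returns dtype names as plain strings.

```python
UINT16_MAX = 2**16 - 1  # 65535, largest value an unsigned 16-bit index can hold
UINT32_MAX = 2**32 - 1  # 4294967295


def pick_index_dtype(dim_length: int) -> str:
    """Hypothetical helper: smallest index dtype that can address `dim_length` items."""
    if dim_length <= 0:
        raise ValueError("dim_length must be positive")
    largest_index = dim_length - 1
    if largest_index <= UINT16_MAX:
        return "uint16"
    if largest_index <= UINT32_MAX:
        return "uint32"
    raise ValueError("dimension too long for a 32-bit index")


print(pick_index_dtype(1024))    # uint16
print(pick_index_dtype(65536))   # uint16 (indices 0..65535 still fit)
print(pick_index_dtype(65537))   # uint32
```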

Memory Support:
  • Interleaved: DRAM and L1

Example

import torch
import ttnn

# Open a device (assumes device 0 is available)
device = ttnn.open_device(device_id=0)

# Create a tensor
input_tensor = torch.tensor([[3, 1, 2], [3, 1, 2]], dtype=torch.float32)

# Convert tensor to ttnn format
input_tensor_ttnn = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)

# Sort the tensor in ascending order (along the last dimension by default)
sorted_tensor, indices = ttnn.sort(input_tensor_ttnn)

# Sort the tensor in descending order
sorted_tensor_desc, indices_desc = ttnn.sort(input_tensor_ttnn, descending=True)

# Sort along a specific dimension (dim=1 is the last dimension of this 2-D tensor)
input_tensor_2d = torch.tensor([[3, 1, 2], [6, 5, 4]], dtype=torch.float32)
input_tensor_2d_ttnn = ttnn.from_torch(input_tensor_2d, dtype=ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
sorted_tensor_dim, indices_dim = ttnn.sort(input_tensor_2d_ttnn, dim=1)

# Release the device when done
ttnn.close_device(device)
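A result from the calls above can be sanity-checked once it has been brought back to host memory as nested lists. The sketch below is plain Python with a hypothetical helper, check_sort. It assumes the indices refer to positions along the sorted (last) dimension of the input, and it accepts any ordering of tied elements because stable sorting is not supported. The sample values are worked out by hand for the example input, not captured from a device run.

```python
def check_sort(inp, values, indices, descending=False):
    """Return True if (values, indices) is a valid last-dimension sort of `inp`."""
    if not (len(inp) == len(values) == len(indices)):
        return False
    for row, vals, idxs in zip(inp, values, indices):
        # The indices must be a permutation of the positions in the row.
        if sorted(idxs) != list(range(len(row))):
            return False
        # Gathering the input by index must reproduce the sorted values.
        if [row[i] for i in idxs] != list(vals):
            return False
        # The values must be monotonic in the requested direction.
        for a, b in zip(vals, vals[1:]):
            out_of_order = (a < b) if descending else (a > b)
            if out_of_order:
                return False
    return True


inp = [[3, 1, 2], [3, 1, 2]]
print(check_sort(inp, [[1, 2, 3], [1, 2, 3]], [[1, 2, 0], [1, 2, 0]]))                   # True
print(check_sort(inp, [[3, 2, 1], [3, 2, 1]], [[0, 2, 1], [0, 2, 1]], descending=True))  # True
print(check_sort(inp, [[1, 2, 3], [1, 2, 3]], [[0, 1, 2], [0, 1, 2]]))                   # False
```

Checking the gather property instead of comparing against a fixed index list keeps the check valid whichever way ties are ordered.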