ttnn.argmax

ttnn.argmax(input_tensor: ttnn.Tensor, *, dim: int = None, keepdim: bool = False, memory_config: ttnn.MemoryConfig = None, output_tensor: ttnn.Tensor = None) -> ttnn.Tensor

Returns the indices of the maximum values of input_tensor along dimension dim. If dim is not provided, returns the index of the maximum value over all elements of input_tensor.
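To make the two reduction modes concrete, here is a minimal host-side reference model in plain Python. It makes no ttnn calls and is only a sketch under stated assumptions: nested lists stand in for tensors, `argmax_last_dim` and `argmax_all` are hypothetical helper names, ties resolve to the first maximal index (as `torch.argmax` documents; the tie-breaking rule of ttnn.argmax is not stated here), and the all-elements case returns an index into the row-major flattened input.

```python
from typing import Any, List


def _first_max_index(row: List[float]) -> int:
    # Index of the first occurrence of the maximum value (assumed tie-break rule).
    best = 0
    for i, v in enumerate(row):
        if v > row[best]:
            best = i
    return best


def argmax_last_dim(t: List[Any], keepdim: bool = False) -> Any:
    # Reduce over the last dimension (dim=-1), recursing through leading dims.
    # A list whose first element is a scalar is treated as an innermost row.
    if not isinstance(t[0], list):
        idx = _first_max_index(t)
        return [idx] if keepdim else idx
    return [argmax_last_dim(sub, keepdim) for sub in t]


def _flatten(t: Any) -> List[float]:
    # Row-major flattening of an arbitrarily nested list.
    if not isinstance(t, list):
        return [t]
    out: List[float] = []
    for sub in t:
        out.extend(_flatten(sub))
    return out


def argmax_all(t: List[Any]) -> int:
    # dim=None: index of the maximum over all elements, in flattened order (assumption).
    return _first_max_index(_flatten(t))
```

For a `[1, 1, 2, 3]` input `[[[[1.0, 5.0, 2.0], [7.0, 0.0, 7.0]]]]`, `argmax_last_dim(x, keepdim=True)` gives `[[[[1], [0]]]]` and `argmax_all(x)` gives `3`. The real op returns these indices as a UINT32 tensor rather than Python ints.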

Parameters:

input_tensor (ttnn.Tensor) – the input tensor.

Keyword Arguments:
  • dim (int, optional) – dimension to reduce. Defaults to None.

  • keepdim (bool, optional) – whether to keep the reduced dimension. Defaults to False.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.

Returns:

ttnn.Tensor – Output tensor containing the indices of the maximum values.

Note

The input tensor supports the following data types and layouts:

Input Tensor

| dtype    | layout          |
|----------|-----------------|
| FLOAT32  | ROW_MAJOR, TILE |
| BFLOAT16 | ROW_MAJOR, TILE |
| UINT32   | ROW_MAJOR       |
| INT32    | ROW_MAJOR       |
| UINT16   | ROW_MAJOR       |

The output tensor will be of the following data type and layout:

Output Tensor

| dtype  | layout    |
|--------|-----------|
| UINT32 | ROW_MAJOR |
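The dtype and layout tables above can be encoded as a lookup for host-side pre-flight checks. This is a hypothetical helper, not part of ttnn: dtypes and layouts are plain strings mirroring the table entries, and `SUPPORTED_INPUT`, `OUTPUT_SPEC`, and `is_supported_input` are names introduced only for illustration.

```python
from typing import Dict, FrozenSet, Tuple

# Supported input layouts per dtype for ttnn.argmax, copied from the input table.
SUPPORTED_INPUT: Dict[str, FrozenSet[str]] = {
    "FLOAT32": frozenset({"ROW_MAJOR", "TILE"}),
    "BFLOAT16": frozenset({"ROW_MAJOR", "TILE"}),
    "UINT32": frozenset({"ROW_MAJOR"}),
    "INT32": frozenset({"ROW_MAJOR"}),
    "UINT16": frozenset({"ROW_MAJOR"}),
}

# The output is always UINT32 in ROW_MAJOR layout, per the output table.
OUTPUT_SPEC: Tuple[str, str] = ("UINT32", "ROW_MAJOR")


def is_supported_input(dtype: str, layout: str) -> bool:
    # Unknown dtypes map to an empty set, so they are reported as unsupported.
    return layout in SUPPORTED_INPUT.get(dtype, frozenset())
```

For example, `is_supported_input("BFLOAT16", "TILE")` is `True`, while `is_supported_input("INT32", "TILE")` is `False` because integer inputs are listed for ROW_MAJOR only.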

Memory Support:
  • Interleaved: DRAM and L1

Limitations:
  • The input tensor must be on-device.

  • When dim is specified, only reduction over the last dimension (dim=-1) is currently supported.

  • Sharding is not supported for this operation.

  • Reduction over all elements (dim=None) is not supported for input tensors in TILE layout.

  • The optional preallocated output tensor must have ROW_MAJOR layout.
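The limitations above can be checked on the host before dispatching the op. The sketch below is hypothetical: `check_argmax_limits` and its string and boolean arguments are illustrative, not ttnn API, and it only encodes the rules listed here without inspecting a real tensor. It assumes a non-negative dim is normalized against the tensor rank (so dim=3 on a rank-4 tensor counts as the last dimension), and the out-of-range check is an added sanity check that this page does not describe.

```python
from typing import List, Optional


def check_argmax_limits(
    rank: int,
    dim: Optional[int],
    layout: str,
    on_device: bool,
    sharded: bool,
    output_layout: Optional[str] = None,
) -> List[str]:
    """Return the list of violated limitations (empty if the call looks valid)."""
    problems: List[str] = []
    if not on_device:
        problems.append("input tensor must be on-device")
    if sharded:
        problems.append("sharded memory layouts are not supported")
    if dim is not None:
        if not -rank <= dim < rank:
            problems.append(f"dim={dim} is out of range for rank {rank}")
        # Assumption: dim is normalized modulo rank, so -1 and rank-1 are equivalent.
        elif dim % rank != rank - 1:
            problems.append("only the last dimension (dim=-1) can be reduced")
    if dim is None and layout == "TILE":
        problems.append("dim=None is not supported with TILE input layout")
    if output_layout is not None and output_layout != "ROW_MAJOR":
        problems.append("preallocated output tensor must be ROW_MAJOR")
    return problems
```

Returning a list of messages rather than raising on the first failure lets a caller report every violated rule at once.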

Example

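```python
import ttnn
from loguru import logger

# Open a device (device 0 is assumed to be available)
device = ttnn.open_device(device_id=0)

# Create a random ROW_MAJOR tensor on the device
tensor_input = ttnn.rand([1, 1, 32, 64], device=device, layout=ttnn.ROW_MAJOR_LAYOUT)

# Last-dim reduction with keepdim=True yields shape [1, 1, 32, 1]
output_onedim = ttnn.argmax(tensor_input, dim=-1, keepdim=True)
logger.info(f"Argmax onedim result: {output_onedim}")

# Reduction over all elements (dim=None) yields shape []
output_alldim = ttnn.argmax(tensor_input)
logger.info(f"Argmax alldim result: {output_alldim}")

ttnn.close_device(device)
```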
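The output shapes noted in the example follow a simple rule. A minimal sketch of it, assuming the op's shape semantics match those comments and mirror `torch.argmax` when keepdim=False (the helper `argmax_output_shape` is hypothetical, and the dim=None case models only keepdim=False, which is what the example shows):

```python
from typing import List, Optional


def argmax_output_shape(
    shape: List[int], dim: Optional[int] = None, keepdim: bool = False
) -> List[int]:
    # dim=None reduces over every element and yields a 0-d result (shape []).
    if dim is None:
        if keepdim:
            # Not described on this page, so deliberately left unmodeled.
            raise NotImplementedError("keepdim with dim=None is not modeled here")
        return []
    d = dim % len(shape)  # normalize negative dims, e.g. -1 -> last axis
    if keepdim:
        # The reduced axis is kept with size 1.
        return shape[:d] + [1] + shape[d + 1:]
    # The reduced axis is dropped (assumed PyTorch-like behavior).
    return shape[:d] + shape[d + 1:]
```

With the example's input, `argmax_output_shape([1, 1, 32, 64], dim=-1, keepdim=True)` returns `[1, 1, 32, 1]` and `argmax_output_shape([1, 1, 32, 64])` returns `[]`, matching the comments above.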