ttnn.argmax

ttnn.argmax(input_tensor: ttnn.Tensor, *, dim: int | None = None, memory_config: ttnn.MemoryConfig | None = None, output_tensor: ttnn.Tensor | None = None, queue_id: int | None = 0) → List of ttnn.Tensor

Returns the indices of the maximum values of the elements in the input tensor along dim. If no dim is provided, it returns the index of the maximum value over all elements of the input tensor.

Currently, this op supports dimension-specific reduction only on the last dimension.
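The two reduction modes described above can be illustrated with a small pure-Python reference. This is a sketch of the documented semantics only, not the ttnn implementation: nested lists stand in for tensors, the helper name `reference_argmax` is hypothetical, and ties are assumed to resolve to the first maximal index (as `torch.argmax` documents), which this page does not confirm for ttnn.

```python
from typing import Any, List, Optional


def _flatten(values: Any):
    # Yield scalars in row-major order from a nested list
    if isinstance(values, list):
        for v in values:
            yield from _flatten(v)
    else:
        yield values


def _rank(tensor: Any) -> int:
    # Nesting depth of a non-empty nested list (its number of dimensions)
    rank = 0
    while isinstance(tensor, list):
        rank += 1
        tensor = tensor[0]
    return rank


def _argmax_1d(values: List[float]) -> int:
    # Index of the maximum; assumption: ties resolve to the earliest index
    if not values:
        raise ValueError("argmax of an empty sequence")
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best


def reference_argmax(tensor: Any, dim: Optional[int] = None) -> Any:
    """Hypothetical reference for the semantics documented for ttnn.argmax."""
    if dim is None:
        # No dim: a single index over all elements (row-major flattened order)
        return _argmax_1d(list(_flatten(tensor)))
    rank = _rank(tensor)
    if dim not in (-1, rank - 1):
        # Mirrors the documented restriction to the last dimension
        raise NotImplementedError("only the last dimension is supported")

    def reduce_last(t: Any) -> Any:
        # Recurse until reaching the innermost lists, i.e. the last dimension
        if isinstance(t[0], list):
            return [reduce_last(sub) for sub in t]
        return _argmax_1d(t)

    return reduce_last(tensor)
```

For `x = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]`, `reference_argmax(x)` gives `3` (the first 7.0 in flattened order), and `reference_argmax(x, dim=-1)` gives `[1, 0]`, with the second row's tie resolved to its first occurrence.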

Input tensor must have BFLOAT16 data type and ROW_MAJOR layout.

Output tensor will have UINT32 data type.
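The input and output constraints above can be summarized as a pre-flight check. This is a hypothetical helper, not part of ttnn: it takes plain strings (`"BFLOAT16"`, `"ROW_MAJOR"`) instead of real ttnn dtype and layout objects, and it simply encodes the restrictions stated on this page.

```python
from typing import Optional, Sequence


def check_argmax_args(
    shape: Sequence[int],
    dtype: str,
    layout: str,
    dim: Optional[int] = None,
) -> str:
    """Hypothetical check of the documented ttnn.argmax constraints.

    Returns the documented output data type name on success.
    """
    # Documented requirement: BFLOAT16 input
    if dtype != "BFLOAT16":
        raise TypeError(f"input must be BFLOAT16, got {dtype}")
    # Documented requirement: ROW_MAJOR layout
    if layout != "ROW_MAJOR":
        raise ValueError(f"input must use ROW_MAJOR layout, got {layout}")
    # Documented restriction: dim-specific reduction only on the last dimension
    if dim is not None:
        last = len(shape) - 1
        if dim not in (-1, last):
            raise NotImplementedError(
                f"only the last dimension ({last} or -1) is supported, got dim={dim}"
            )
    # Documented output data type
    return "UINT32"
```

For example, `check_argmax_args((1, 1, 32, 64), "BFLOAT16", "ROW_MAJOR", dim=-1)` returns `"UINT32"`, while `dim=2` on the same shape raises `NotImplementedError`.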

Equivalent PyTorch code:

return torch.argmax(input_tensor, dim=dim)

Parameters:

input_tensor (ttnn.Tensor) – the input tensor.

Keyword Arguments:
  • dim (int, optional) – Dimension to reduce. If None, the reduction is over all elements. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.

  • queue_id (int, optional) – Command queue ID. Defaults to 0.

Returns:

List of ttnn.Tensor – the output tensor.