ttnn.argmax
- ttnn.argmax(input_tensor: ttnn.Tensor, *, dim: int | None = None, memory_config: ttnn.MemoryConfig | None = None, output_tensor: ttnn.Tensor | None = None, queue_id: int | None = 0) → List of ttnn.Tensor
Returns the indices of the maximum value of elements in the input tensor. If no dim is provided, it returns the index of the maximum value among all elements of the given input tensor.
Currently, this op only supports dimension-specific reduction on the last dimension.
Input tensor must have BFLOAT16 data type and ROW_MAJOR layout.
Output tensor will have UINT32 data type.
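The input constraints above can be checked before a call with a small pre-flight helper. This is a hypothetical sketch, not part of ttnn: check_argmax_args is an illustrative name, the dtype and layout are plain strings standing in for the real ttnn enums, and it assumes negative dim values follow the usual PyTorch convention (-1 meaning the last dimension).

```python
def check_argmax_args(shape, dtype, layout, dim=None):
    """Hypothetical pre-flight check mirroring the documented argmax constraints.

    dtype and layout are plain strings standing in for the ttnn BFLOAT16 dtype
    and ROW_MAJOR layout; this helper is NOT part of the ttnn API.
    Returns the normalized (non-negative) dim, or None for a full reduction.
    """
    if dtype != "BFLOAT16":
        raise ValueError(f"argmax expects BFLOAT16 input, got {dtype}")
    if layout != "ROW_MAJOR":
        raise ValueError(f"argmax expects ROW_MAJOR layout, got {layout}")
    if dim is None:
        return None  # no dim: reduce over all elements
    rank = len(shape)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for rank {rank}")
    normalized = dim % rank  # assumption: -1 maps to rank - 1, as in PyTorch
    if normalized != rank - 1:
        raise NotImplementedError("only the last dimension is supported")
    return normalized


print(check_argmax_args((1, 1, 32, 64), "BFLOAT16", "ROW_MAJOR", dim=-1))
```

With a rank-4 shape, dim=-1 and dim=3 both normalize to 3, while dim=2 is rejected because it is not the last dimension.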
Equivalent PyTorch code:
return torch.argmax(input_tensor, dim=dim)
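To make the PyTorch equivalence concrete without a device, the semantics can be modeled in pure Python over nested lists. This is a sketch of the documented behavior only, not the ttnn kernel: reference_argmax and its helpers are hypothetical names, ties resolve to the first maximal value (torch.argmax's documented rule; this page does not state ttnn's tie behavior), and the reduced dimension is dropped as in torch.argmax with keepdim=False (the exact ttnn output shape is not stated here).

```python
def _rank(x):
    """Nesting depth of a list-of-lists tensor stand-in."""
    return 1 + _rank(x[0]) if isinstance(x, (list, tuple)) else 0


def _flatten(x):
    """Row-major flattening, so flat indices match torch.argmax(dim=None)."""
    if isinstance(x, (list, tuple)):
        out = []
        for item in x:
            out.extend(_flatten(item))
        return out
    return [x]


def _argmax_1d(values):
    """Index of the first maximal value in a flat sequence."""
    if not values:
        raise ValueError("argmax of an empty sequence")
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:  # strict '>' keeps the first maximum on ties
            best = i
    return best


def reference_argmax(x, dim=None):
    """Pure-Python model of the documented argmax semantics (hypothetical)."""
    if dim is None:
        # No dim: a single flat index over all elements.
        return _argmax_1d(_flatten(x))
    rank = _rank(x)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for rank {rank}")
    if dim % rank != rank - 1:
        # Mirrors the documented restriction to the last dimension.
        raise NotImplementedError("only the last dimension is supported")

    def reduce_last(node):
        # Recurse to the innermost lists and take argmax along each row.
        if isinstance(node[0], (list, tuple)):
            return [reduce_last(child) for child in node]
        return _argmax_1d(node)

    return reduce_last(x)


x = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]
print(reference_argmax(x))          # flat index over all six elements
print(reference_argmax(x, dim=-1))  # one index per row
```

For this input the full reduction gives 3 (the first 7 in row-major order), and the last-dimension reduction gives [1, 0]; if PyTorch is installed, torch.argmax(torch.tensor(x), dim=-1).tolist() should agree.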
- Parameters:
input_tensor (ttnn.Tensor) – the input tensor.
- Keyword Arguments:
dim (int, optional) – dimension to reduce; currently only the last dimension is supported. If None, the reduction is over all elements. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
queue_id (int, optional) – command queue id. Defaults to 0.
- Returns:
List of ttnn.Tensor – the output tensor.