ttnn.matmul
- ttnn.matmul(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, transpose_a: bool | None = False, transpose_b: bool | None = False, memory_config: ttnn.MemoryConfig | None = None, dtype: ttnn.DataType | None = None, program_config: ttnn.MatmulProgramConfig | None = None, activation: str | None = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig | None = None, core_grid: ttnn.CoreGrid | None = None, output_tile: List[int] | None = None, optional_output_tensor: ttnn.Tensor | None = None) → ttnn.Tensor
-
Returns the matrix product of two tensors.
The input tensors must be in tile layout, and therefore have to be at least 2-dimensional.
If the input tensors have more than two dimensions, the additional leading dimensions may be used for batched matrix multiplication. These leading dimensions are referred to as batch dimensions; e.g. a tensor with dimensions (a x b x c x d) has batch dimensions a and b. The following are the allowed possibilities for batch dimensions. The examples at the end show concrete operations and tensor sizes.
If all batch dimensions are of size 1, then there is no batched operation.
If both inputs have batch dimensions that are not all of size 1, then the batch dimensions of both inputs must be the same. If they differ then, although some combinations may happen to work, in most cases errors will be reported.
If the first input has batch dimensions that are not all of size 1, and the second input has no batch dimensions or has batch dimensions all of size 1, then the second input is broadcasted to align appropriately with the first input.
Matrix multiplication will not work if the first input has batch dimensions that are all of size 1 and the second input has batch dimensions that are not all of size 1.
Note: Dimensions of size 0 are not supported.
Note: In general, the number of dimensions of the two inputs should match. When it does not, and the inputs are not valid per the above criteria, the error messages may be unexpected and refer to non-obvious issues.
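The batch-dimension rules above can be sketched in pure Python. This is an illustrative model only; ttnn's actual validation is internal to the library, and the function name here is hypothetical.

```python
def batch_dims_compatible(shape_a, shape_b):
    """Illustrative check of the batch-dimension rules above.

    The batch dimensions are everything before the final two (matrix) dims.
    """
    batch_a = list(shape_a[:-2])
    batch_b = list(shape_b[:-2])
    a_trivial = all(d == 1 for d in batch_a)  # no real batching on input a
    b_trivial = all(d == 1 for d in batch_b)  # no real batching on input b
    if b_trivial:
        # Either no batched operation at all, or the second input is
        # broadcast to align with the first.
        return True
    if a_trivial:
        # First input has only size-1 batch dims, second is batched:
        # not supported.
        return False
    # Both inputs are batched: batch dimensions must match exactly.
    return batch_a == batch_b

print(batch_dims_compatible((10, 64, 32), (10, 32, 128)))  # True
print(batch_dims_compatible((1, 1, 64, 32), (10, 32, 128)))  # False
```

The first call matches the batched-matmul example later in this page; the second illustrates the unsupported case where only the second input is batched.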
-
Note: There are various combinations of dimensions possible. The behaviour is the same as PyTorch, with two exceptions. These exceptions are the following scenarios related to batch dimensions:
The two batch dimensions are swapped, e.g. the first input has (j x 1) and the second input has (1 x j), or the first input has (1 x j) and the second input has (j x 1).
A batch dimension is implicitly extended such that the two batch dimensions are swapped, e.g. (j x 1) and (j), which is treated as (j x 1) and (1 x j).
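To see why these two scenarios are exceptions, consider standard right-aligned broadcasting as PyTorch applies it to batch dimensions. The sketch below implements that standard rule in pure Python (the function name is illustrative); in PyTorch the swapped shapes broadcast to (j x j), whereas ttnn does not follow that behaviour.

```python
from itertools import zip_longest

def broadcast_batch(batch_a, batch_b):
    """Standard right-aligned broadcasting of two batch-dimension tuples,
    as PyTorch would apply it. Missing leading dims are treated as 1."""
    out = []
    for a, b in zip_longest(reversed(batch_a), reversed(batch_b), fillvalue=1):
        if a != b and 1 not in (a, b):
            raise ValueError("batch dimensions are not broadcastable")
        out.append(max(a, b))
    return tuple(reversed(out))

# Exception 1: swapped batch dims. PyTorch broadcasts (j x 1) with (1 x j)
# to (j x j); ttnn treats this scenario differently.
print(broadcast_batch((3, 1), (1, 3)))  # (3, 3)

# Exception 2: implicit extension makes the dims swapped. (j,) is treated
# as (1 x j), so (j x 1) with (j,) also becomes (j x j) under PyTorch rules.
print(broadcast_batch((3, 1), (3,)))  # (3, 3)
```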
-
To leverage sharded matmul implementations, both input_tensor_a and input_tensor_b can be sharded. The sharding strategy used follows the sharding strategy set on the respective tensor. A sharded 1D matmul can be either HEIGHT or WIDTH sharded; 2D matmuls can be BLOCK sharded.
Note: the broadcasting logic only looks at the batch dimensions when determining whether the inputs are broadcastable, not at the matrix dimensions. For example, if input_tensor_a is a (j x 1 x n_size x m_size) tensor and input_tensor_b is a (k_size x m_size x p) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. The operation will return a (j x k_size x n_size x p) tensor.
Note: there are various additional constraints related to the specific program config chosen. Please read the error messages carefully and fix problems appropriately.
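The resulting output shape can be sketched as follows: batch dimensions broadcast right-aligned, and the matrix dimensions follow the usual (n x m) @ (m x p) → (n x p) rule. This is a pure-Python illustration with a hypothetical function name, not ttnn's actual shape inference.

```python
from itertools import zip_longest

def matmul_output_shape(shape_a, shape_b):
    """Illustrative output-shape rule for the note above.

    Assumes the inputs already satisfy the batch-dimension rules.
    """
    (n, m), (m2, p) = shape_a[-2:], shape_b[-2:]
    if m != m2:
        raise ValueError("inner matrix dimensions must match")
    # Broadcast the batch dims right-aligned, padding missing dims with 1.
    batch = []
    for a, b in zip_longest(reversed(shape_a[:-2]), reversed(shape_b[:-2]),
                            fillvalue=1):
        batch.append(max(a, b))
    return tuple(reversed(batch)) + (n, p)

# (j x 1 x n_size x m_size) @ (k_size x m_size x p)
#   -> (j x k_size x n_size x p), with j=7, k_size=5:
print(matmul_output_shape((7, 1, 64, 32), (5, 32, 128)))  # (7, 5, 64, 128)
```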
Note: If an optional output tensor is specified, then dtype and memory_config are checked as follows: if they are left at their defaults, they are set based on the optional output tensor; if they are not at their defaults, they are compared against the optional output tensor and an error is reported on any difference.
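That resolution logic amounts to the following sketch. The attribute and function names are illustrative, not ttnn's internal API; the point is only the default-vs-mismatch behaviour described above.

```python
def resolve_output_configs(dtype, memory_config, optional_output_tensor):
    """Illustrative resolution of dtype/memory_config against an
    optional output tensor, per the note above."""
    if optional_output_tensor is None:
        return dtype, memory_config
    # Default (None) values are taken from the optional output tensor.
    if dtype is None:
        dtype = optional_output_tensor.dtype
    elif dtype != optional_output_tensor.dtype:
        raise ValueError("dtype differs from optional_output_tensor")
    if memory_config is None:
        memory_config = optional_output_tensor.memory_config
    elif memory_config != optional_output_tensor.memory_config:
        raise ValueError("memory_config differs from optional_output_tensor")
    return dtype, memory_config

# A stand-in object with the two attributes the sketch reads:
from types import SimpleNamespace
out = SimpleNamespace(dtype="bfloat16", memory_config="DRAM")
print(resolve_output_configs(None, None, out))  # ('bfloat16', 'DRAM')
```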
- Parameters:
-
input_tensor_a (ttnn.Tensor) – the first tensor to be multiplied. Needs to be on the device.
input_tensor_b (ttnn.Tensor) – the second tensor to be multiplied. Needs to be on the device.
- Keyword Arguments:
-
transpose_a (bool, optional) – Whether to transpose input_tensor_a. Defaults to False.
transpose_b (bool, optional) – Whether to transpose input_tensor_b. Defaults to False.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration of the output tensor. Defaults to None, which will result in using ttnn.DRAM_MEMORY_CONFIG.
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.
program_config (ttnn.MatmulProgramConfig, optional) – the program configuration for the matmul operation. Defaults to None.
activation (str, optional) – the activation function to be applied. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration for the matmul operation. Defaults to None.
core_grid (ttnn.CoreGrid, optional) – the grid over which to distribute the sharded tensor (writes to the cores' L1s). Defaults to None.
output_tile (List of [int], optional) – Specifies the output tile configuration. Defaults to None.
optional_output_tensor (ttnn.Tensor, optional) – User provided on-device output tensor where the result of matmul is to be written. Defaults to None.
- Returns:
-
ttnn.Tensor – the output tensor.
Example
>>> # matrix x matrix - no batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[64, 64]
>>> # extended matrix x extended matrix - all batch dimensions of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]
>>> # extended matrix x extended matrix - differing number of dimensions, all batch dimensions of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]
>>> # batched matrix x broadcasted matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 64]
>>> # batched matrix x batched matrix - both inputs have batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2  # alternative to ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 128]
>>> # batched matrix x broadcasted extended matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2
>>> print(output.shape)
[1, 10, 64, 128]