ttnn.matmul
- ttnn.matmul = Operation(python_fully_qualified_name='ttnn.matmul', function=<ttnn._ttnn.operations.matmul.matmul_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)
Returns the matrix product of two tensors.
The input tensors need to be tiled and at least 1-dimensional.
If both input tensors are 1-dimensional, then the operation is a dot product.
If the first input tensor is 1-dimensional and the second input tensor is at least 2-dimensional, a batched vector-matrix multiplication is performed.
If the first input tensor is at least 2-dimensional and the second input tensor is 1-dimensional, a batched matrix-vector multiplication is performed.
If both input tensors are at least 2-dimensional, then a batched matrix multiply is performed.
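The four cases above depend only on the ranks of the two inputs. A minimal pure-Python sketch of that dispatch follows; `matmul_case` is a hypothetical helper for illustration, not part of ttnn, and it never touches a device.

```python
def matmul_case(rank_a: int, rank_b: int) -> str:
    """Name the kind of product described above for two tensor ranks.

    Hypothetical helper for illustration only; it is not a ttnn function.
    """
    if rank_a < 1 or rank_b < 1:
        raise ValueError("both inputs must be at least 1-dimensional")
    if rank_a == 1 and rank_b == 1:
        return "dot product"
    if rank_a == 1:
        # 1-D first input, second input at least 2-D
        return "batched vector-matrix multiplication"
    if rank_b == 1:
        # first input at least 2-D, 1-D second input
        return "batched matrix-vector multiplication"
    return "batched matrix multiply"
```

For example, `matmul_case(3, 2)` returns `"batched matrix multiply"`.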
The following are the allowed possibilities for batch dimensions. Examples below show concrete operations and tensor sizes.
If all batch dimensions are of size 1, then there is no batched operation.
If both inputs have batch dimensions that are not all of size 1, then the batch dimensions of both inputs should be the same. If they differ, some combinations may still work, but in most cases errors will be reported.
If the first input has batch dimensions that are not all of size 1, and the second input has no batch dimensions or has batch dimensions all of size 1, then the second input is broadcast to align with the first input.
Matrix multiplication will not work if the first input has batch dimensions that are all of size 1 and the second input has batch dimensions that are not all of size 1.
Note: In general, the number of dimensions between the two inputs should match. There may be cases where they don’t. In that case, if the inputs are not valid based on the above criteria, the error messages may be unexpected and refer to non-obvious issues.
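The batch-dimension rules above can be written as a small classifier over shapes. This is a sketch under stated assumptions: `classify_batch_dims` is a hypothetical name, both inputs are assumed to be at least 2-dimensional, and every dimension before the last two is treated as a batch dimension. It only mirrors the prose rules and does not reproduce the checks ttnn performs internally.

```python
def classify_batch_dims(shape_a, shape_b):
    """Classify two input shapes by the batch-dimension rules listed above.

    Hypothetical helper, not part of ttnn. Assumes both shapes have at
    least two dimensions; everything before the last two is batch.
    """
    batch_a = tuple(shape_a[:-2])
    batch_b = tuple(shape_b[:-2])
    # "Trivial" means no batch dimensions at all, or all of size 1.
    a_trivial = all(d == 1 for d in batch_a)
    b_trivial = all(d == 1 for d in batch_b)

    if a_trivial and b_trivial:
        return "no batched operation"
    if b_trivial:
        return "second input broadcast to first"
    if a_trivial:
        return "not supported"
    if batch_a == batch_b:
        return "matching batch dimensions"
    return "mismatched batch dimensions (may fail)"
```

For example, shapes `(10, 64, 32)` and `(32, 64)` fall under the broadcast rule, while `(1, 64, 32)` and `(10, 32, 64)` fall under the unsupported one.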
Note: There are various combinations of dimensions possible. The behaviour is the same as PyTorch, except for two exceptions. These exceptions are for the following scenarios related to batch dimensions:
The two batch dimensions are swapped. E.g. the first input has (j x 1) and the second input has (1 x j), or the first input has (1 x j) and the second input has (j x 1).
When a batch dimension is implicitly extended, the two batch dimensions are swapped. E.g. (j x 1) and (j), which is treated as (j x 1) and (1 x j).
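One way to flag inputs that fall into these two scenarios before calling the operation is to compare the batch shapes directly. The sketch below is an interpretation of the description above, not ttnn logic: `is_swapped_batch_pair` is a hypothetical helper, and it assumes that a shorter batch shape is implicitly extended by prepending dimensions of size 1.

```python
def is_swapped_batch_pair(batch_a, batch_b):
    """Return True for batch shapes like (j, 1) vs (1, j), or the reverse.

    Hypothetical helper, not part of ttnn. A shorter batch shape is
    left-padded with 1s first (an assumption), so (j, 1) vs (j,) is
    treated as (j, 1) vs (1, j).
    """
    rank = max(len(batch_a), len(batch_b))
    a = (1,) * (rank - len(batch_a)) + tuple(batch_a)
    b = (1,) * (rank - len(batch_b)) + tuple(batch_b)
    if rank != 2:
        return False
    # Swapped means one shape is the reverse of the other, with a
    # size-1 dimension involved and the shapes not simply identical.
    return a == b[::-1] and a != b and 1 in a
```

For example, `is_swapped_batch_pair((4, 1), (4,))` returns `True`, which marks a case where results may differ from PyTorch.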
To leverage sharded matmul implementations, both input_tensor_a and input_tensor_b can be sharded. The sharding strategy used follows the sharding strategy set on the respective tensor. A sharded 1D matmul can be either HEIGHT or WIDTH sharded; 2D matmuls can be BLOCK sharded.
Note: the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input_tensor_a is a (j x 1 x n_size x m_size) tensor and input_tensor_b is a (k_size x m_size x p) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. The operation will return a (j x k_size x n_size x p) tensor.
Note: there are various additional constraints related to specific program configs chosen. Please look at the error messages carefully and fix problems appropriately.
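The (j x 1 x n_size x m_size) by (k_size x m_size x p) example above can be checked with a short shape calculation. The sketch below assumes standard NumPy/PyTorch-style broadcasting over the batch dimensions; `broadcast_matmul_shape` is a hypothetical helper, not a ttnn function, and it ignores the program-config constraints and PyTorch exceptions mentioned elsewhere in this page.

```python
from itertools import zip_longest


def broadcast_matmul_shape(shape_a, shape_b):
    """Output shape of a batched matmul with broadcast batch dimensions.

    Hypothetical helper, not part of ttnn. Only the batch dimensions are
    broadcast; the last two dimensions are treated as the matrices.
    """
    if len(shape_a) < 2 or len(shape_b) < 2:
        raise ValueError("this sketch expects inputs with at least 2 dimensions")
    *batch_a, n, m_a = shape_a
    *batch_b, m_b, p = shape_b
    if m_a != m_b:
        raise ValueError("inner matrix dimensions must match")

    batch = []
    # Align batch dimensions from the right, padding the shorter with 1s.
    for da, db in zip_longest(reversed(batch_a), reversed(batch_b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dimensions {da} and {db} are not broadcastable")
        batch.append(db if da == 1 else da)
    return tuple(reversed(batch)) + (n, p)


# j=5, n_size=64, m_size=32, k_size=7, p=16
print(broadcast_matmul_shape((5, 1, 64, 32), (7, 32, 16)))  # (5, 7, 64, 16)
```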
Note: If an optional output tensor is specified, then dtype and memory config are checked as follows:
- if they are default, they are set based on the optional output tensor
- if they are not default, they are compared against the optional output tensor, and an error is reported if there is a difference
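The check described in this note amounts to a small resolution step per setting. The following is a sketch of that logic only, assuming `None` stands for "default" as in the keyword argument list; `resolve_output_setting` is a hypothetical name and the string values are placeholders rather than real ttnn objects.

```python
def resolve_output_setting(name, requested, from_output_tensor):
    """Reconcile a requested dtype or memory config with the output tensor.

    Hypothetical helper, not part of ttnn. `requested` is None when the
    caller left the argument at its default.
    """
    if requested is None:
        # Default: inherit the setting from the optional output tensor.
        return from_output_tensor
    if requested != from_output_tensor:
        # Explicit value that disagrees with the output tensor: error.
        raise ValueError(
            f"{name} mismatch: requested {requested!r}, "
            f"but the output tensor has {from_output_tensor!r}"
        )
    return requested


# Placeholder strings stand in for ttnn dtype objects.
print(resolve_output_setting("dtype", None, "bfloat16"))  # bfloat16
```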
- Parameters:
input_tensor_a (ttnn.Tensor) – the first tensor to be multiplied. Needs to be on the device.
input_tensor_b (ttnn.Tensor) – the second tensor to be multiplied. Needs to be on the device.
- Keyword Arguments:
transpose_a (bool, optional) – Whether to transpose input_tensor_a. Defaults to False.
transpose_b (bool, optional) – Whether to transpose input_tensor_b. Defaults to False.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration of the output tensor. Defaults to None, which will result in using ttnn.DRAM_MEMORY_CONFIG.
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.
program_config (ttnn.MatmulProgramConfig, optional) – the program configuration for the matmul operation. Defaults to None.
activation (str, optional) – the activation function to be applied. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration for the matmul operation. Defaults to None.
core_grid (ttnn.CoreGrid, optional) – the grid on which to distribute the sharded tensor (writes to the cores' L1 memory). Defaults to None.
output_tile (List of [int], optional) – Specifies the output tile configuration. Defaults to None.
optional_output_tensor (ttnn.Tensor, optional) – User provided on-device output tensor where the result of matmul is to be written. Defaults to None.
- Returns:
ttnn.Tensor – the output tensor.
Example
>>> # matrix x matrix - no batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[64, 64]

>>> # extended matrix x extended matrix - all batch dimensions of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]

>>> # extended matrix x extended matrix - all batch dimensions of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]

>>> # batched matrix x broadcasted matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 64]

>>> # batched matrix x batched matrix - both inputs have batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2  # alternative to ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 128]

>>> # batched matrix x broadcasted extended matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2
>>> print(output.shape)
[1, 10, 64, 128]