ttnn.sparse_matmul
- ttnn.sparse_matmul(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, sparsity: ttnn.Tensor, program_config: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig, nnz: int | None, is_input_a_sparse: bool = False, is_input_b_sparse: bool = True, memory_config: ttnn.MemoryConfig = None, dtype: ttnn.DataType = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None, core_grid: ttnn.CoreGrid = None, output_tile: List[int] = None, optional_output_tensor: ttnn.Tensor = None) → ttnn.Tensor
-
Returns the matrix product of two tensors. Based on is_input_a_sparse, is_input_b_sparse, and the sparsity tensor, some parts of the output computation are skipped.
The two input tensors must be tiled and each have a rank of 4. The sparsity tensor must be a rank-4 tensor in row-major layout.
The output tensor shape is computed from the input tensor shapes and the is_input_a_sparse and is_input_b_sparse values. See the supported modes table below.
- Parameters:
-
input_tensor_a (ttnn.Tensor) – the first tensor to be multiplied. Needs to be on the device.
input_tensor_b (ttnn.Tensor) – the second tensor to be multiplied. Needs to be on the device.
- Keyword Arguments:
-
sparsity (ttnn.Tensor) – the sparsity tensor containing the mask values. Needs to be on the device. The data type must be bfloat16.
program_config (ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig) – the program configuration for the matmul operation. Only this config type is supported, and mcast_in0 must be set to True.
nnz (int, optional) – the number of non-zero values in the sparsity tensor. If not provided, it is inferred from the sparsity tensor at runtime.
is_input_a_sparse (bool, optional) – boolean indicating whether input_tensor_a is sparse. Defaults to False. Together with is_input_b_sparse, it determines how the sparsity tensor is interpreted. See the supported modes table below.
is_input_b_sparse (bool, optional) – boolean indicating whether input_tensor_b is sparse. Defaults to True. Together with is_input_a_sparse, it determines how the sparsity tensor is interpreted. See the supported modes table below.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration of the output tensor. Defaults to None, in which case ttnn.DRAM_MEMORY_CONFIG is used.
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration for the matmul operation. Defaults to None.
core_grid (ttnn.CoreGrid, optional) – the core grid on which to distribute the sharded tensor (written to the cores' L1 memory). Defaults to None.
output_tile (List[int], optional) – the output tile configuration. Defaults to None.
optional_output_tensor (ttnn.Tensor, optional) – a user-provided on-device output tensor to which the matmul result is written. Defaults to None.
- Returns:
-
ttnn.Tensor – the output tensor with sparse results.
- Supported Modes
-
| is_input_a_sparse | is_input_b_sparse | input_tensor_a shape | input_tensor_b shape | sparsity shape | nnz | output shape |
|---|---|---|---|---|---|---|
| True | True | [1, E, M, K] | [1, E, K, N] | [1, 1, 1, E] | None or 0 ≤ nnz ≤ E | [1, E, M, N] |
| False | True | [A, B, M, K] | [1, E, K, N] | [A, B, 1, E] | None or 0 ≤ nnz ≤ A * B * E | [A, B, 1, E, M, N] |
| True | False | [A, E, M, K] | [1, E, K, N] | [1, 1, A, E] | None or 0 ≤ nnz ≤ A * E | [A, E, M, N] |
| False | False | Invalid | – | – | – | – |
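As a rough mental model, the mode where both is_input_a_sparse and is_input_b_sparse are True behaves like the following plain-PyTorch sketch. This is a host-side reference for the semantics only, not the device implementation, and the function name is ours: expert slices whose sparsity entry is zero are skipped and their output stays zero.

```python
import torch

def sparse_matmul_ref_both_sparse(a, b, sparsity):
    """Reference sketch: a is [1, E, M, K], b is [1, E, K, N],
    sparsity is a [1, 1, 1, E] mask. Output is [1, E, M, N]."""
    _, E, M, _ = a.shape
    N = b.shape[-1]
    out = torch.zeros((1, E, M, N), dtype=a.dtype)
    for e in range(E):
        # Only slices flagged non-zero in the mask are computed;
        # the rest of the output stays zero (the skipped work).
        if sparsity[0, 0, 0, e] != 0:
            out[0, e] = a[0, e] @ b[0, e]
    return out

E, M, K, N = 8, 64, 32, 64
a = torch.randn(1, E, M, K)
b = torch.randn(1, E, K, N)
mask = torch.zeros(1, 1, 1, E)
mask[0, 0, 0, :4] = 1.0  # nnz = 4
out = sparse_matmul_ref_both_sparse(a, b, mask)
print(out.shape)  # torch.Size([1, 8, 64, 64])
```

The other modes broadcast the batch dimensions differently (see the table above) but follow the same skip-on-zero rule.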
Note
The input tensors support the following data types and layouts:
| tensor | dtype | layout |
|---|---|---|
| input_tensor_a | BFLOAT4_B, BFLOAT8_B, BFLOAT16, FLOAT32 | TILE |
| input_tensor_b | BFLOAT4_B, BFLOAT8_B, BFLOAT16, FLOAT32 | TILE |
| sparsity | BFLOAT16 | ROW_MAJOR |
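Because the sparsity mask must be bfloat16 in row-major layout, it is convenient to build it host-side with torch before moving it to the device. A small sketch (shapes and variable names are illustrative, matching the [A, B, 1, E] mask of the is_input_b_sparse-only mode) that also shows the count the runtime infers when nnz is omitted:

```python
import torch

# Build an [A, B, 1, E] bfloat16 mask with exactly nnz randomly placed non-zeros.
A, B, E, nnz = 2, 16, 8, 4
mask = torch.zeros((A, B, 1, E), dtype=torch.bfloat16)
mask.view(-1)[torch.randperm(mask.numel())[:nnz]] = 1.0

# When nnz is not passed to ttnn.sparse_matmul it is inferred at runtime;
# host-side it is simply the number of non-zero entries in the mask.
inferred_nnz = int(mask.count_nonzero().item())
print(inferred_nnz)  # 4
```

The torch tensor would then be converted and moved with ttnn.from_torch and ttnn.to_device, as in the example below.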
- Memory Support:
-
The supported memory configurations for the two input tensors are program config dependent, as described below:
| Config | Input A | Input B |
|---|---|---|
| ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig (mcast_in0=True) | Interleaved (L1/DRAM) | Interleaved (L1/DRAM) |
Example
```python
# Define program configuration
config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
    compute_with_storage_grid_size=ttnn.CoreCoord(1, 2),
    in0_block_w=1,
    out_subblock_h=1,
    out_subblock_w=1,
    out_block_h=1,
    out_block_w=1,
    per_core_M=2,
    per_core_N=1,
    fuse_batch=False,
    fused_activation=None,
    mcast_in0=True,
)

nnz = 4

#
# Case 1: `is_input_a_sparse` is True and `is_input_b_sparse` is True
#
tensor1 = ttnn.rand((1, 8, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Create a sparsity tensor
sparsity_bitmask = torch.zeros((1, 1, 1, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)

output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=True,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([1, 8, 64, 64])

# When nnz is not provided, it is inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=True,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([1, 8, 64, 64])

#
# Case 2: `is_input_a_sparse` is False and `is_input_b_sparse` is True
#
tensor1 = ttnn.rand((2, 16, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Create a sparsity tensor
sparsity_bitmask = torch.zeros((2, 16, 1, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)

output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=False,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([2, 16, 1, 8, 64, 64])

# When nnz is not provided, it is inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=False,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([2, 16, 1, 8, 64, 64])

#
# Case 3: `is_input_a_sparse` is True and `is_input_b_sparse` is False
#
tensor1 = ttnn.rand((4, 8, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Create a sparsity tensor
sparsity_bitmask = torch.zeros((1, 1, 4, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)

output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=True,
    is_input_b_sparse=False,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([4, 8, 64, 64])

# When nnz is not provided, it is inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=True,
    is_input_b_sparse=False,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([4, 8, 64, 64])

#
# Case 4: `is_input_a_sparse` is False and `is_input_b_sparse` is False
#
# This is invalid
```