ttnn.sparse_matmul

ttnn.sparse_matmul(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, sparsity: ttnn.Tensor, program_config: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig, nnz: int | None, is_input_a_sparse: bool = False, is_input_b_sparse: bool = True, memory_config: ttnn.MemoryConfig = None, dtype: ttnn.DataType = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None, core_grid: ttnn.CoreGrid = None, output_tile: List[int] = None, optional_output_tensor: ttnn.Tensor = None) → ttnn.Tensor

Returns the matrix product of two tensors. Depending on is_input_a_sparse, is_input_b_sparse, and the sparsity tensor, parts of the output computation are skipped.

The two input tensors must be tiled and each have a rank of 4. The sparsity tensor must be a rank-4 tensor in row-major layout.

The output tensor shape is computed from the input tensor shapes and the is_input_a_sparse and is_input_b_sparse values. See the supported modes table below.

Parameters:
  • input_tensor_a (ttnn.Tensor) – the first tensor to be multiplied. Needs to be on the device.

  • input_tensor_b (ttnn.Tensor) – the second tensor to be multiplied. Needs to be on the device.

Keyword Arguments:
  • sparsity (ttnn.Tensor) – the sparsity tensor containing the mask values. Needs to be on the device. The data type must be bfloat16.

  • program_config (ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig) – the program configuration for the matmul operation. Only this config type is supported. mcast_in0 must be set to True.

  • nnz (int, optional) – the number of non-zero values in the sparsity tensor. If not provided, it will be inferred from the sparsity tensor at runtime.

  • is_input_a_sparse (bool, optional) – boolean indicating whether input_tensor_a is sparse. Defaults to False. Together with is_input_b_sparse, it determines how the sparsity tensor is interpreted. See the supported modes table below.

  • is_input_b_sparse (bool, optional) – boolean indicating whether input_tensor_b is sparse. Defaults to True. Together with is_input_a_sparse, it determines how the sparsity tensor is interpreted. See the supported modes table below.

  • memory_config (ttnn.MemoryConfig, optional) – the memory configuration of the output tensor. Defaults to None, which will result in using ttnn.DRAM_MEMORY_CONFIG.

  • dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration for the matmul operation. Defaults to None.

  • core_grid (ttnn.CoreGrid, optional) – the grid on which to distribute the sharded tensor (the result is written to the cores' L1 memories). Defaults to None.

  • output_tile (List[int], optional) – specifies the output tile configuration. Defaults to None.

  • optional_output_tensor (ttnn.Tensor, optional) – user-provided on-device output tensor to which the result of the matmul is written. Defaults to None.
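As noted for the nnz keyword above, nnz is simply the count of non-zero entries in the sparsity mask; when omitted, the op scans the mask at runtime to infer it. The count can be sketched in pure Python (count_nnz is a hypothetical helper, not part of the ttnn API):

```python
def count_nnz(mask):
    """Count non-zero entries in a flattened sparsity mask.

    Hypothetical illustration of what the `nnz` keyword represents;
    passing it precomputed spares the op a runtime scan of the mask.
    """
    return sum(1 for v in mask if v != 0.0)

# A [1, 1, 1, 8] mask flattened to a list, with 4 active batches:
mask = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]
print(count_nnz(mask))  # 4
```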

Returns:

ttnn.Tensor – the output tensor with sparse results.

Supported Modes

| is_input_a_sparse | is_input_b_sparse | input_tensor_a shape | input_tensor_b shape | sparsity shape | nnz                      | output shape       |
|-------------------|-------------------|----------------------|----------------------|----------------|--------------------------|--------------------|
| True              | True              | [1, E, M, K]         | [1, E, K, N]         | [1, 1, 1, E]   | None or 0 ≤ nnz ≤ E      | [1, E, M, N]       |
| False             | True              | [A, B, M, K]         | [1, E, K, N]         | [A, B, 1, E]   | None or 0 ≤ nnz ≤ A*B*E  | [A, B, 1, E, M, N] |
| True              | False             | [A, E, M, K]         | [1, E, K, N]         | [1, 1, A, E]   | None or 0 ≤ nnz ≤ A*E    | [A, E, M, N]       |
| False             | False             | Invalid              |                      |                |                          |                    |
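The output-shape rules above can be expressed as a small pure-Python sketch (sparse_matmul_output_shape is a hypothetical helper mirroring the table, not a ttnn function):

```python
def sparse_matmul_output_shape(shape_a, shape_b, a_sparse, b_sparse):
    """Mirror the supported-modes table: compute the output shape
    from the two rank-4 input shapes and the sparsity flags."""
    d0, d1, M, K = shape_a
    _, E, K2, N = shape_b
    assert K == K2, "inner dimensions must match"
    if a_sparse and b_sparse:
        return [1, E, M, N]          # row 1
    if not a_sparse and b_sparse:
        return [d0, d1, 1, E, M, N]  # row 2: [A, B, 1, E, M, N]
    if a_sparse and not b_sparse:
        return [d0, d1, M, N]        # row 3: [A, E, M, N]
    raise ValueError("at least one input must be sparse")

print(sparse_matmul_output_shape([1, 8, 64, 32], [1, 8, 32, 64], True, True))
# [1, 8, 64, 64]
```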

Note

The input tensors support the following data types and layouts:

| Tensor         | dtype                                    | layout    |
|----------------|------------------------------------------|-----------|
| input_tensor_a | BFLOAT4_B, BFLOAT8_B, BFLOAT16, FLOAT32  | TILE      |
| input_tensor_b | BFLOAT4_B, BFLOAT8_B, BFLOAT16, FLOAT32  | TILE      |
| sparsity       | BFLOAT16                                 | ROW_MAJOR |

Memory Support:

The supported memory configurations for the two input tensors depend on the program config, as described below:

Supported Memory Configurations

| Config                                                              | Input A               | Input B               |
|---------------------------------------------------------------------|-----------------------|-----------------------|
| ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig (mcast_in0=True)  | Interleaved (L1/DRAM) | Interleaved (L1/DRAM) |

Example

import torch
import ttnn
from loguru import logger

# Open a device (assumes a single available device)
device = ttnn.open_device(device_id=0)

# Define program configuration
config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
    compute_with_storage_grid_size=ttnn.CoreCoord(1, 2),
    in0_block_w=1,
    out_subblock_h=1,
    out_subblock_w=1,
    out_block_h=1,
    out_block_w=1,
    per_core_M=2,
    per_core_N=1,
    fuse_batch=False,
    fused_activation=None,
    mcast_in0=True,
)
nnz = 4

#
# Case 1: When `is_input_a_sparse` is True and `is_input_b_sparse` is True
#
tensor1 = ttnn.rand((1, 8, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Create a sparsity tensor
sparsity_bitmask = torch.zeros((1, 1, 1, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=True,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([1, 8, 64, 64])

# When nnz is not provided, it will be inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=True,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([1, 8, 64, 64])

#
# Case 2: When `is_input_a_sparse` is False and `is_input_b_sparse` is True
#
tensor1 = ttnn.rand((2, 16, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Create a sparsity tensor
sparsity_bitmask = torch.zeros((2, 16, 1, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=False,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([2, 16, 1, 8, 64, 64])

# When nnz is not provided, it will be inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=False,
    is_input_b_sparse=True,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([2, 16, 1, 8, 64, 64])

#
# Case 3: When `is_input_a_sparse` is True and `is_input_b_sparse` is False
#
tensor1 = ttnn.rand((4, 8, 64, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.rand((1, 8, 32, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Create a sparsity tensor
sparsity_bitmask = torch.zeros((1, 1, 4, 8), dtype=torch.bfloat16)
sparsity_bitmask.view(-1)[torch.randperm(sparsity_bitmask.numel())[:nnz]] = 1.0
sparsity_bitmask = ttnn.to_device(ttnn.from_torch(sparsity_bitmask), device)
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    nnz=nnz,
    is_input_a_sparse=True,
    is_input_b_sparse=False,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([4, 8, 64, 64])
# When nnz is not provided, it will be inferred from the sparsity tensor at runtime
output = ttnn.sparse_matmul(
    tensor1,
    tensor2,
    sparsity=sparsity_bitmask,
    is_input_a_sparse=True,
    is_input_b_sparse=False,
    program_config=config,
)
logger.info(f"Output shape: {output.shape}")  # Output shape: Shape([4, 8, 64, 64])

#
# Case 4: When `is_input_a_sparse` is False and `is_input_b_sparse` is False
#
# This combination is invalid: at least one input must be sparse.
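To make the skipping behavior of Case 1 concrete, here is a pure-Python reference sketch of the dual-sparse semantics (ref_sparse_matmul is a hypothetical helper operating on nested lists, not the device op): batches whose mask entry is zero are skipped and yield zeros.

```python
def ref_sparse_matmul(a, b, mask):
    """Reference semantics for is_input_a_sparse=True, is_input_b_sparse=True:
    per-batch matmul over the E dimension, with zero-masked batches skipped.

    a: E x M x K, b: E x K x N, mask: length-E list (nested Python lists).
    """
    E = len(a)
    out = []
    for e in range(E):
        M, K, N = len(a[e]), len(a[e][0]), len(b[e][0])
        if mask[e] == 0.0:
            # Masked batch: computation is skipped, output stays zero.
            out.append([[0.0] * N for _ in range(M)])
            continue
        out.append([[sum(a[e][i][k] * b[e][k][j] for k in range(K))
                     for j in range(N)] for i in range(M)])
    return out

a = [[[1.0, 2.0]], [[3.0, 4.0]]]              # E=2, M=1, K=2
b = [[[1.0], [1.0]], [[1.0], [1.0]]]          # E=2, K=2, N=1
print(ref_sparse_matmul(a, b, [1.0, 0.0]))    # [[[3.0]], [[0.0]]]
```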