ttnn.scale_mask_softmax_in_place

ttnn.scale_mask_softmax_in_place(input_tensor: ttnn.Tensor, scale: float | None, mask: ttnn.Tensor | None, *, program_config: SoftmaxProgramConfig = SoftmaxDefaultProgramConfig(), is_causal_mask: bool = False, compute_kernel_config: DeviceComputeKernelConfig | None = None, numeric_stable: bool = False) → ttnn.Tensor

Computes a fused scale-mask-softmax operation along the last dimension in-place.

This operation modifies the input tensor directly and performs the following sequence:
  1. Optionally scales the input: input_tensor *= scale (if scale is provided)

  2. Optionally applies mask: input_tensor += mask (if mask is provided, with broadcasting)

  3. Computes softmax: input_tensor = softmax(input_tensor)

This in-place fused operation is commonly used in attention mechanisms. It is memory-efficient because it reuses the input tensor for the output, avoiding an additional allocation.
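The three steps above can be modeled out-of-place in NumPy (a reference sketch only; the real op runs fused on-device and mutates its input):

```python
import numpy as np

def scale_mask_softmax_reference(x, scale=None, mask=None):
    """NumPy model of the fused scale -> mask -> softmax sequence."""
    if scale is not None:
        x = x * scale                  # step 1: optional scaling
    if mask is not None:
        x = x + mask                   # step 2: optional additive mask (broadcasts)
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)  # step 3: softmax over the last dim

x = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[0.0, -1e9, 0.0]])    # large negative value suppresses position 1
probs = scale_mask_softmax_reference(x, scale=0.5, mask=mask)
```

Note how the mask is additive: a large negative value drives the corresponding softmax probability to zero, which is how attention masking is typically expressed.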

Parameters:
  • input_tensor (ttnn.Tensor) – The input tensor to process. This tensor is modified in-place.

  • scale (float, optional) – Scaling factor multiplied into the input tensor before the mask is applied.

  • mask (ttnn.Tensor, optional) – Attention mask tensor to add to scaled input.

Keyword Arguments:
  • program_config (SoftmaxProgramConfig, optional) – Program configuration for the operation. Defaults to SoftmaxDefaultProgramConfig().

  • is_causal_mask (bool, optional) – Whether the mask is a causal mask. Defaults to False.

  • compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.

  • numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to False.

Returns:

ttnn.Tensor – The same tensor as input with the fused scale-mask-softmax operation applied in-place.
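To illustrate what the numeric_stable flag guards against, here is a NumPy sketch (my assumption of the underlying technique: the standard max-subtraction trick) contrasting a naive softmax with a numerically stable one:

```python
import numpy as np

def softmax_naive(x):
    # Exponentiates the raw values; overflows for large logits.
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    # Subtracting the row max shifts all inputs to <= 0 before exp,
    # so the result is finite for any input magnitude.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

big = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(big)   # exp(1000) overflows to inf -> nan entries
stable = softmax_stable(big)     # finite, sums to 1
```

The stable variant costs an extra max-reduction and subtraction per row, which is why it is opt-in via numeric_stable rather than the default.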

Note

The tensors support the following data types and layouts:

Input Tensor:

  Dtypes: BFLOAT16, FLOAT32, BFLOAT8_B
  Layouts: TILE

Mask Tensor (optional):

  Dtypes: BFLOAT16, BFLOAT8_B
  Layouts: TILE, ROW_MAJOR
  Ranks: 2, 3, 4

The output tensor will be in TILE layout and have the same dtype as the input_tensor.

Limitations:
  • All tensors must be on-device.

  • For unsharded ROW_MAJOR masks: all intermediate dimensions (except the last two) must be 1; the second-to-last dimension must equal TILE_WIDTH; the last dimension must match the input tensor's width.

  • For sharded inputs: mask must be TILE layout with identical padded shape to input.

  • Internal block size constraints may restrict in-place operation for very large width tensors.
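For reference, when is_causal_mask=True the mask is interpreted as a causal (lower-triangular) attention mask. A NumPy sketch of the additive mask this implies (the construction here is illustrative, not the op's internal representation):

```python
import numpy as np

# Additive causal mask: entry (i, j) is a large negative value when j > i,
# so each query position can attend only to itself and earlier positions.
seq_len = 4
causal_mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)
```

Such a mask would then be tilized and padded to the input's shape before being passed as the mask argument.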

Example

import ttnn
from loguru import logger

device = ttnn.open_device(device_id=0)

# Setup input tensor and mask
input_shape = (1, 1, 32, 32)

attention_mask_t = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Apply in-place scale mask softmax
tt_output = ttnn.scale_mask_softmax_in_place(
    input_tensor=input_tensor,
    scale=1.0,
    mask=attention_mask_t,
)
logger.info(f"Scale Mask Softmax In Place result: {tt_output}")

compute_grid_size = device.compute_with_storage_grid_size()
fuse_head = 2
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y

input_shape = (batch, num_cores_r, fuse_head * 384, 768)

attention_mask_t = ttnn.rand((batch, 1, 384, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Shard the input tensor
grid_coord = ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = [fuse_head * 384, 768]
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)

input_sharded = ttnn.to_memory_config(input_tensor, sharded_mem_config)

# Create sharded program config
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
    compute_with_storage_grid_size=compute_grid_size,
    subblock_w=8,
    block_h=12 * fuse_head,
    block_w=24,
)

tt_output = ttnn.scale_mask_softmax_in_place(
    input_tensor=input_sharded,
    scale=1.0,
    mask=attention_mask_t,
    program_config=program_config,
)
logger.info(f"Scale Mask Softmax In Place result: {tt_output}")