ttnn.scale_mask_softmax

ttnn.scale_mask_softmax(input_tensor: ttnn.Tensor, scale: float | None, mask: ttnn.Tensor | None, *, memory_config: ttnn.MemoryConfig | None, is_causal_mask: bool = False, compute_kernel_config: DeviceComputeKernelConfig | None, numeric_stable: bool = True) → ttnn.Tensor

Computes a fused scale-mask-softmax operation along the last dimension of the input tensor.

This operation performs the following sequence:
  1. Optionally scales the input: scaled = input_tensor * scale (if scale is provided)

  2. Optionally applies mask: masked = scaled + mask (if mask is provided, with broadcasting)

  3. Computes softmax: output = softmax(masked)

This fused operation is commonly used in attention mechanisms where scaling and masking are applied before the softmax operation for efficiency.
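The three-step sequence above can be expressed as a plain NumPy reference sketch. This is only an illustration of the op's semantics, not the ttnn kernel; the function name and the `-1e9` mask value are illustrative choices:

```python
import numpy as np

def scale_mask_softmax_ref(x, scale=None, mask=None):
    """Reference semantics of the fused op (illustrative, not the ttnn kernel)."""
    if scale is not None:
        x = x * scale                           # step 1: optional scaling
    if mask is not None:
        x = x + mask                            # step 2: optional additive mask (broadcasts)
    x = x - x.max(axis=-1, keepdims=True)       # stable softmax: shift by the row max
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)    # step 3: softmax over the last dimension

x = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[0.0, 0.0, -1e9]])  # large negative value masks out the last position
probs = scale_mask_softmax_ref(x, scale=1.0, mask=mask)
```

Each row of the result sums to 1, and masked positions receive (near-)zero probability, which is exactly the behavior attention mechanisms rely on.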

Parameters:
  • input_tensor (ttnn.Tensor) – The input tensor to process.

  • scale (float, optional) – Scaling factor applied to the input tensor before masking.

  • mask (ttnn.Tensor, optional) – Attention mask tensor added to the scaled input (broadcast as needed).

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the output tensor. If not provided, inherits from input tensor.

  • is_causal_mask (bool, optional) – Whether the mask is a causal mask. Defaults to False.

  • compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.

  • numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to True.
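The `numeric_stable` flag guards against overflow in the exponential. A minimal NumPy illustration of the difference, assuming it corresponds to the standard max-subtraction trick (an assumption; the actual kernel implementation is not shown here):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])  # large logits, as can occur before scaling

# Naive softmax: exp(1000) overflows to inf, and inf/inf yields nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.exp(x).sum()

# Stable softmax: subtracting the row max changes nothing mathematically,
# but keeps every exponent <= 0 so exp() cannot overflow
shifted = x - x.max()
stable = np.exp(shifted) / np.exp(shifted).sum()
```

Because the stable variant costs only one extra max and subtraction per row, it is enabled by default.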

Returns:

ttnn.Tensor – Output tensor with the fused scale-mask-softmax operation applied.

Note

The tensors support the following data types and layouts:

Input Tensor
  • Dtypes: BFLOAT16, FLOAT32, BFLOAT8_B
  • Layouts: TILE

Mask Tensor (optional)
  • Dtypes: BFLOAT16, BFLOAT8_B
  • Layouts: TILE, ROW_MAJOR

The output tensor will be in TILE layout and have the same dtype as the input_tensor.

Limitations:
  • All tensors must be on-device.

  • For ROW_MAJOR masks: all intermediate dimensions (every dimension except the last two) must be 1; the last dimension must equal TILE_WIDTH; and the mask width must align with the input tensor's tile width.

Example

# Setup input tensor and mask
# (assumes `device` is an already-opened ttnn device, e.g. via ttnn.open_device)
from loguru import logger

compute_grid_size = device.compute_with_storage_grid_size()
fuse_head = 2
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y

input_shape = (batch, num_cores_r, fuse_head * 384, 768)
attention_mask_t = ttnn.rand((batch, 1, 1, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Apply scale mask softmax; the mask broadcasts over the head and sequence dimensions
tt_output = ttnn.scale_mask_softmax(
    input_tensor=input_tensor,
    scale=1.0,
    mask=attention_mask_t,
)
logger.info(f"Scale Mask Softmax result: {tt_output}")