ttnn.scale_mask_softmax
ttnn.scale_mask_softmax(input_tensor: ttnn.Tensor, scale: float | None, mask: ttnn.Tensor | None, *, memory_config: ttnn.MemoryConfig | None, is_causal_mask: bool = False, compute_kernel_config: DeviceComputeKernelConfig | None, numeric_stable: bool = True) → ttnn.Tensor
Computes a fused scale-mask-softmax operation along the last dimension of the input tensor.
This operation performs the following sequence:

1. Optionally scales the input: scaled = input_tensor * scale (if scale is provided)
2. Optionally applies the mask: masked = scaled + mask (if mask is provided, with broadcasting)
3. Computes softmax: output = softmax(masked)

This fused operation is commonly used in attention mechanisms, where scaling and masking are applied before the softmax operation for efficiency.
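The three-step sequence above can be sketched host-side with NumPy. This is an illustrative reference computation, not the device implementation; the max-subtraction step corresponds to the numerically stable path selected by numeric_stable=True:

```python
import numpy as np

def scale_mask_softmax_reference(x, scale=None, mask=None):
    """Host-side reference for the fused op: softmax(x * scale + mask)
    along the last dimension."""
    if scale is not None:
        x = x * scale                          # step 1: optional scaling
    if mask is not None:
        x = x + mask                           # step 2: mask broadcasts over x
    x = x - x.max(axis=-1, keepdims=True)      # step 3: stable softmax
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)
```

With a mask of large negative values at disallowed positions (the usual attention pattern), those positions receive near-zero probability after the softmax.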
Parameters:
input_tensor (ttnn.Tensor) – The input tensor to process.
scale (float, optional) – Scaling factor multiplied with the input tensor.
mask (ttnn.Tensor, optional) – Attention mask tensor added to the scaled input.
Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the output tensor. If not provided, it is inherited from the input tensor.
is_causal_mask (bool, optional) – Whether the mask is a causal mask. Defaults to False.
compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.
numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to True.
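For intuition on the numeric_stable flag: a naive softmax overflows for large logits, while the max-subtracted form keeps every exponent at or below zero and stays finite. A minimal host-side NumPy illustration (unrelated to the device kernels themselves):

```python
import numpy as np

def naive_softmax(x):
    # exp() overflows to inf for large inputs, yielding inf/inf = nan
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    # subtracting the row max keeps every exponent <= 0, so exp() never overflows
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([1000.0, 1001.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(logits))   # [nan nan]
print(stable_softmax(logits))      # finite, sums to 1
```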
Returns:
ttnn.Tensor – Output tensor with the fused scale-mask-softmax operation applied.
Note
The tensors support the following data types and layouts:

Input tensor: BFLOAT16, FLOAT32, or BFLOAT8_B dtype; TILE layout.
Mask tensor (optional): BFLOAT16 or BFLOAT8_B dtype; TILE or ROW_MAJOR layout.

The output tensor will be in TILE layout and have the same dtype as the input_tensor.

Limitations:
All tensors must be on-device.
For ROW_MAJOR masks: intermediate dimensions (except last two) must be 1; last dimension must equal TILE_WIDTH; width must align to input tensor’s tile width.
Example
# Set up the input tensor and mask
compute_grid_size = device.compute_with_storage_grid_size()
fuse_head = 2
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y
input_shape = (batch, num_cores_r, fuse_head * 384, 768)
attention_mask_t = ttnn.rand((batch, 1, 1, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Apply scale-mask-softmax
tt_output = ttnn.scale_mask_softmax(
    input_tensor=input_tensor,
    scale=1.0,
    mask=attention_mask_t,
)
logger.info(f"Scale Mask Softmax result: {tt_output}")