ttnn.scale_mask_softmax_in_place
ttnn.scale_mask_softmax_in_place(input_tensor: ttnn.Tensor, scale: float | None, mask: ttnn.Tensor | None, *, program_config: SoftmaxProgramConfig = SoftmaxDefaultProgramConfig(), is_causal_mask: bool = False, compute_kernel_config: DeviceComputeKernelConfig | None, numeric_stable: bool = False) → ttnn.Tensor
Computes a fused scale-mask-softmax operation along the last dimension in-place.
This operation modifies the input tensor directly and performs the following sequence:
1. Optionally scales the input: input_tensor *= scale (if scale is provided)
2. Optionally applies the mask: input_tensor += mask (if mask is provided, with broadcasting)
3. Computes softmax: input_tensor = softmax(input_tensor)
This in-place fused operation is commonly used in attention mechanisms and is memory-efficient as it reuses the input tensor for output, avoiding additional memory allocation.
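For intuition, the fused sequence can be modeled on the host in plain NumPy. This is a reference sketch of the semantics, not the device implementation; the function name is illustrative:

```python
import numpy as np

def scale_mask_softmax_reference(x, scale=None, mask=None):
    """Host-side reference for the fused op: optional scale, optional
    additive mask (with broadcasting), then softmax on the last dim."""
    if scale is not None:
        x = x * scale
    if mask is not None:
        x = x + mask  # NumPy broadcasting mirrors the op's mask broadcasting
    x = x - x.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

x = np.random.randn(1, 1, 32, 32).astype(np.float32)
mask = np.zeros((1, 1, 32, 32), dtype=np.float32)
out = scale_mask_softmax_reference(x, scale=0.125, mask=mask)
row_sums = out.sum(axis=-1)  # softmax rows sum to ~1.0
```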
Parameters:
input_tensor (ttnn.Tensor) – The input tensor to process. This tensor is modified in-place.
scale (float, optional) – Scaling factor to multiply with input tensor.
mask (ttnn.Tensor, optional) – Attention mask tensor to add to scaled input.
Keyword Arguments:
program_config (SoftmaxProgramConfig, optional) – Program configuration for the operation. Defaults to SoftmaxDefaultProgramConfig().
is_causal_mask (bool, optional) – Whether the mask is a causal mask. Defaults to False.
compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.
numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to False.
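The numeric_stable flag guards against overflow in the exponentials. A minimal NumPy illustration of the difference, assuming the usual stable-softmax technique of subtracting the row maximum before exponentiating:

```python
import numpy as np

def softmax_naive(x):
    # exp() overflows to inf for large inputs, yielding inf/inf = nan
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    # subtracting the row max keeps all exponents <= 0, so exp() never overflows
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(logits)   # contains nan due to overflow
stable = softmax_stable(logits)     # finite probabilities that sum to 1
```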
Returns:
ttnn.Tensor – The same tensor as input with the fused scale-mask-softmax operation applied in-place.
Note
The tensors support the following data types and layouts:
Input Tensor:
  Dtypes: BFLOAT16, FLOAT32, BFLOAT8_B
  Layouts: TILE

Mask Tensor (optional):
  Dtypes: BFLOAT16, BFLOAT8_B
  Layouts: TILE, ROW_MAJOR
  Ranks: 2, 3, 4
The output tensor will be in TILE layout and have the same dtype as the input_tensor.

Limitations:
All tensors must be on-device.
For unsharded ROW_MAJOR masks: intermediate dimensions (except last two) must be 1; last dimension must equal TILE_WIDTH; width must align to input tensor.
For sharded inputs: mask must be TILE layout with identical padded shape to input.
Internal block size constraints may restrict in-place operation for very large width tensors.
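As an illustration of the ROW_MAJOR mask constraint above, a host-side shape check might look like the following. TILE_WIDTH is assumed to be 32 (the usual tile width on Tenstorrent hardware), and the helper name plus the reading of "width must align" as "the input width is a multiple of the mask width" are this sketch's assumptions:

```python
TILE_WIDTH = 32  # assumed tile width; the actual value comes from the device

def row_major_mask_shape_ok(mask_shape, input_shape):
    """Hypothetical check of the documented unsharded ROW_MAJOR mask rules:
    every dim except the last two must be 1, the last dim must equal
    TILE_WIDTH, and the input width must be a multiple of the mask width."""
    *leading, _height, width = mask_shape
    if any(dim != 1 for dim in leading):
        return False
    if width != TILE_WIDTH:
        return False
    return input_shape[-1] % width == 0

ok = row_major_mask_shape_ok((1, 1, 1, 32), (1, 1, 384, 768))   # passes all rules
bad = row_major_mask_shape_ok((1, 2, 1, 32), (1, 1, 384, 768))  # intermediate dim != 1
```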
Example
```python
import ttnn
from loguru import logger

# Device handle (added here so the snippet is self-contained)
device = ttnn.open_device(device_id=0)

# Setup input tensor and mask
input_shape = (1, 1, 32, 32)
attention_mask_t = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Apply in-place scale mask softmax
tt_output = ttnn.scale_mask_softmax_in_place(
    input_tensor=input_tensor,
    scale=1.0,
    mask=attention_mask_t,
)
logger.info(f"Scale Mask Softmax In Place result: {tt_output}")

# Sharded variant: size the tensors to the device compute grid
compute_grid_size = device.compute_with_storage_grid_size()
fuse_head = 2
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y
input_shape = (batch, num_cores_r, fuse_head * 384, 768)
attention_mask_t = ttnn.rand((batch, 1, 384, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# Shard the input tensor
grid_coord = ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = [fuse_head * 384, 768]
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)
input_sharded = ttnn.to_memory_config(input_tensor, sharded_mem_config)

# Create sharded program config
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
    compute_with_storage_grid_size=compute_grid_size,
    subblock_w=8,
    block_h=12 * fuse_head,
    block_w=24,
)
tt_output = ttnn.scale_mask_softmax_in_place(
    input_tensor=input_sharded,
    scale=1.0,
    mask=attention_mask_t,
    program_config=program_config,
)
logger.info(f"Scale Mask Softmax In Place result: {tt_output}")

ttnn.close_device(device)
```