ttnn.scale_causal_mask_hw_dims_softmax_in_place

ttnn.scale_causal_mask_hw_dims_softmax_in_place = Operation(python_fully_qualified_name='ttnn.scale_causal_mask_hw_dims_softmax_in_place', function=<ttnn._ttnn.operations.normalization.scale_causal_mask_hw_dims_softmax_in_place_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)

ttnn.scale_causal_mask_hw_dims_softmax_in_place(input_tensor: ttnn.Tensor, scale: Optional[float] = None, mask: Optional[ttnn.Tensor] = None, program_config: Optional[ttnn.SoftmaxProgramConfig] = None, compute_kernel_config: Optional[ttnn.DeviceComputeKernelConfig] = None, numeric_stable: bool = False) -> ttnn.Tensor

Specialized in-place operation for causal masked softmax with height-width dimension constraints.

This is an optimized version of ttnn.scale_mask_softmax_in_place designed for transformer attention patterns where the causal mask varies only over the height and width dimensions. For workloads that satisfy these constraints, it provides better performance than the general ttnn.scale_mask_softmax_in_place().

The operation performs:
  1. Scales the input: input_tensor *= scale (if scale is provided)
  2. Applies the causal mask: input_tensor += mask (broadcast from [1, 1, H, W])
  3. Computes softmax over the last dimension: input_tensor = softmax(input_tensor)
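
For reference, these steps are equivalent to the following host-side sketch. This is a minimal PyTorch-style illustration of the math rather than the device kernel; the helper name and the d_k = 64 scaling value are assumptions made for illustration.

import math
import torch

def reference_scale_causal_mask_softmax(x, mask, scale=None):
    # Step 1: optional scaling, typically 1/sqrt(d_k) for attention.
    if scale is not None:
        x = x * scale
    # Step 2: the [1, 1, H, W] mask broadcasts across the leading dimensions.
    x = x + mask
    # Step 3: softmax over the last dimension.
    return torch.softmax(x, dim=-1)

# Illustrative attention scale for head dimension d_k = 64.
scale = 1.0 / math.sqrt(64)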

Parameters:
  • input_tensor (ttnn.Tensor) – The input tensor to process. Must be sharded for optimal performance.

  • scale (float, optional) – Scaling factor applied to the input tensor (typically 1/√d_k for attention).

  • mask (ttnn.Tensor, optional) – Causal attention mask tensor with shape [1, 1, H, W].

Keyword Arguments:
  • program_config (SoftmaxProgramConfig, optional) – Program configuration for the operation. Defaults to SoftmaxDefaultProgramConfig().

  • compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation (see the sketch after this list).

  • numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to False.
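
As a hedged sketch, a compute kernel configuration might be constructed as follows, assuming a Wormhole-class device; the specific fidelity and accumulation settings are illustrative choices, not requirements of this operation.

import ttnn

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,  # illustrative; lower fidelities trade accuracy for speed
    math_approx_mode=False,
    fp32_dest_acc_en=False,
    packer_l1_acc=False,
)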

Returns:

ttnn.Tensor – The same tensor as input with the specialized causal scale-mask-softmax operation applied in-place.

Note

The tensors support the following data types and layouts:

Input Tensor (Sharded)
  • Dtypes: BFLOAT16, FLOAT32, BFLOAT8_B
  • Layout: TILE

Mask Tensor [1, 1, H, W]
  • Dtypes: BFLOAT16, BFLOAT8_B
  • Layout: TILE (interleaved)

Limitations:
  • This is an experimental/specialized feature optimized for specific transformer attention patterns.

  • Inputs must be on the device.

  • Input tensor must be sharded for optimal performance.

  • Attention mask must be interleaved and have shape [1, 1, H, W] (i.e., hw_dims_only); a sketch of constructing such a mask follows this list.

  • The mask is treated as a causal mask by design.

  • The scale parameter is typically provided for attention scaling (e.g., 1/√d_k).
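
As a hedged sketch, one way to build a causal [1, 1, H, W] mask on the host and move it to the device as an interleaved TILE tensor is shown below; the use of torch, the -1e9 fill value, and the bfloat16 dtype are illustrative assumptions, and device is assumed to be an already-open ttnn device handle.

import torch
import ttnn

H, W = 384, 768  # illustrative mask dimensions

# Positions above the diagonal receive a large negative value so softmax drives them to ~0.
causal = torch.triu(torch.full((1, 1, H, W), -1e9), diagonal=1)

attention_mask_t = ttnn.from_torch(
    causal,
    dtype=ttnn.bfloat16,       # bfloat8_b is also supported for the mask
    layout=ttnn.TILE_LAYOUT,
    device=device,             # assumed to be an open ttnn device handle
)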

Example

import ttnn

# The example assumes an open device handle.
device = ttnn.open_device(device_id=0)

compute_grid_size = device.compute_with_storage_grid_size()
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y

# One [384, 768] slice per core: batch * num_cores_r cores in total.
input_shape = (batch, num_cores_r, 384, 768)

# The mask stays interleaved (default memory config) with shape [1, 1, H, W].
attention_mask_t = ttnn.rand((1, 1, 384, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

input_tiled = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# We must height-shard the input tensor, using ROW_MAJOR shard orientation
grid_coord = ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = [384, 768]
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)

input_sharded = ttnn.to_memory_config(input_tiled, sharded_mem_config)

# We must also use the sharded softmax program config.
# Block sizes are in 32x32 tiles: block_h = 384 / 32 = 12, block_w = 768 / 32 = 24.
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
    compute_with_storage_grid_size=compute_grid_size,
    subblock_w=8,
    block_h=12,
    block_w=24,
)

tt_output_sharded = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
    input_tensor=input_sharded,
    scale=1.0,
    mask=attention_mask_t,
    program_config=program_config,
)
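
The call returns the same sharded tensor, updated in place. A hedged follow-up, assuming the device handle opened at the start of the example, copies the result back to the host and releases the device:

# Convert the sharded device tensor back to a torch tensor for inspection.
output_torch = ttnn.to_torch(tt_output_sharded)
print(output_torch.shape)  # (batch, num_cores_r, 384, 768)

ttnn.close_device(device)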