ttnn.scale_causal_mask_hw_dims_softmax_in_place

ttnn.scale_causal_mask_hw_dims_softmax_in_place = Operation(python_fully_qualified_name='ttnn.scale_causal_mask_hw_dims_softmax_in_place', function=<ttnn._ttnn.operations.normalization.scale_causal_mask_hw_dims_softmax_in_place_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)

Specialized in-place operation for causal masked softmax with height-width dimension constraints.

This is an optimized version of scale_mask_softmax_in_place designed for transformer attention patterns where the causal mask only spans the height and width dimensions. It offers better performance for use cases that satisfy the following constraints:

Requirements:

  • The input tensor should be sharded for optimal performance.
  • The attention mask must be interleaved and have shape [1, 1, H, W] (hw_dims_only); a construction sketch follows this list.
  • The mask is treated as a causal mask by design.
  • The scale parameter is typically provided for attention scaling.
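For illustration, a minimal host-side sketch of building such a mask is shown below. It assumes torch is installed and a device handle named device is already open; the -1e9 fill value stands in for negative infinity, since block-float formats such as BFLOAT8_B cannot represent -inf exactly.

import torch
import ttnn

H, W = 384, 768

# Additive causal mask: 0 on and below the diagonal, a large negative value above it.
causal_mask = torch.triu(torch.full((1, 1, H, W), -1e9), diagonal=1)

# Interleaved, tile-layout mask tensor of shape [1, 1, H, W] on the device.
mask_tt = ttnn.from_torch(causal_mask, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)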

The operation performs the following steps (a reference sketch follows):

  1. Scale the input: input_tensor *= scale (if scale is provided)
  2. Apply the causal mask: input_tensor += mask (with broadcasting from [1, 1, H, W])
  3. Compute softmax: input_tensor = softmax(input_tensor)
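Conceptually, the same computation expressed in plain PyTorch looks roughly like the sketch below; the device kernel fuses these steps and overwrites the input tensor in place, and the function name here is illustrative only.

import torch

def reference_scale_causal_mask_softmax(x, scale, mask):
    # x: [batch, heads, H, W]; mask: [1, 1, H, W], broadcast over the leading dims.
    if scale is not None:
        x = x * scale
    x = x + mask
    return torch.softmax(x, dim=-1)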

Parameters:
  • input_tensor (ttnn.Tensor) – The input tensor to process. Must be sharded for optimal performance.

  • scale (float, optional) – Scaling factor to multiply with the input tensor (typically 1/√d_k for attention; see the example after this list).

  • mask (ttnn.Tensor, optional) – Causal attention mask tensor with shape [1, 1, H, W].
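As an example of the typical attention scaling mentioned above (the head dimension here is hypothetical):

import math

head_dim = 64                      # hypothetical head dimension d_k
scale = 1.0 / math.sqrt(head_dim)  # 1/sqrt(d_k) = 0.125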

Keyword Arguments:
  • program_config (SoftmaxProgramConfig, optional) – Program configuration for the operation. Defaults to SoftmaxDefaultProgramConfig().

  • compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.

  • numeric_stable (bool, optional) – Whether to use numerically stable softmax computation (sketched below). Defaults to False.
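A host-side sketch of what the numerically stable variant computes, assuming it follows the standard max-subtraction formulation (the device implementation differs in detail):

import torch

def stable_softmax(x):
    # Subtract the row-wise max before exponentiating to avoid overflow.
    x = x - x.max(dim=-1, keepdim=True).values
    e = torch.exp(x)
    return e / e.sum(dim=-1, keepdim=True)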

Returns:

ttnn.Tensor – The same tensor as input with the specialized causal scale-mask-softmax operation applied in-place.

Supported dtypes and layouts:

Input Tensor (Sharded)

  Dtypes: BFLOAT16, FLOAT32, BFLOAT8_B
  Layouts: TILE

Mask Tensor [1, 1, H, W]

  Dtypes: BFLOAT16, BFLOAT8_B
  Layouts: TILE (interleaved)

Note

  • This is a specialized operation optimized for specific transformer attention patterns.

  • Input tensor must be sharded for optimal performance.

  • Mask shape is constrained to [1, 1, H, W] format.

  • Provides better performance than the general scale_mask_softmax_in_place when these constraints are met.

Example

import ttnn

# Assumes a Tenstorrent device is available.
device = ttnn.open_device(device_id=0)

compute_grid_size = device.compute_with_storage_grid_size()
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y

input_shape = (batch, num_cores_r, 384, 768)
# Interleaved mask tensor of shape [1, 1, H, W]; random here for demonstration,
# a real causal attention mask would be used in practice.
attention_mask_t = ttnn.rand((1, 1, 384, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

input_tiled = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# The input must be height-sharded across the full compute grid in ROW_MAJOR shard
# orientation; each core holds one [384, 768] slab.
grid_coord = ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = [384, 768]
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)

input_sharded = ttnn.to_memory_config(input_tiled, sharded_mem_config)

# The sharded softmax program config must match the shard shape, measured in 32x32 tiles:
# block_h = 384 / 32 = 12, block_w = 768 / 32 = 24, and subblock_w must evenly divide block_w.
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
    compute_with_storage_grid_size=compute_grid_size,
    subblock_w=8,
    block_h=12,
    block_w=24,
)

tt_output_sharded = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
    input_tensor=input_sharded,
    scale=1.0,
    mask=attention_mask_t,
    program_config=program_config,
)
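
The returned tt_output_sharded aliases input_sharded, since the operation is in-place. To inspect the result on the host, one approach (a sketch, assuming the standard ttnn conversion helpers) is to move it back to an interleaved memory config and convert it to a torch tensor:

# Move the sharded result back to interleaved DRAM, then copy it to the host.
output_interleaved = ttnn.to_memory_config(tt_output_sharded, ttnn.DRAM_MEMORY_CONFIG)
output_torch = ttnn.to_torch(output_interleaved)
print(output_torch.shape)  # torch.Size([batch, num_cores_r, 384, 768])

ttnn.close_device(device)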