ttnn.scale_causal_mask_hw_dims_softmax_in_place
ttnn.scale_causal_mask_hw_dims_softmax_in_place = Operation(python_fully_qualified_name='ttnn.scale_causal_mask_hw_dims_softmax_in_place', function=<ttnn._ttnn.operations.normalization.scale_causal_mask_hw_dims_softmax_in_place_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)
ttnn.scale_causal_mask_hw_dims_softmax_in_place(input_tensor: ttnn.Tensor, scale: Optional[float] = None, mask: Optional[ttnn.Tensor] = None, program_config: Optional[ttnn.SoftmaxProgramConfig] = None, compute_kernel_config: Optional[ttnn.DeviceComputeKernelConfig] = None, numeric_stable: bool = False) -> ttnn.Tensor
Specialized in-place operation for causal masked softmax with height-width dimension constraints.
This is an optimized version of scale_mask_softmax_in_place designed specifically for transformer attention patterns in which the causal mask only affects the height and width dimensions. For these constraints it provides better performance than the general ttnn.scale_mask_softmax_in_place().

The operation performs, in place:

1. Scale the input: input_tensor *= scale (if scale is provided)
2. Apply the causal mask: input_tensor += mask (broadcast from [1, 1, H, W])
3. Compute softmax: input_tensor = softmax(input_tensor)
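For reference, here is a minimal host-side sketch of the same three steps in plain PyTorch. The function name and the shapes at the bottom are illustrative only and are not part of the ttnn API:

import torch
from typing import Optional

def reference_scale_mask_softmax(x: torch.Tensor, mask: torch.Tensor,
                                 scale: Optional[float] = None) -> torch.Tensor:
    # 1. Scale the input (skipped when no scale is given)
    if scale is not None:
        x = x * scale
    # 2. Add the causal mask; [1, 1, H, W] broadcasts over the leading dimensions
    x = x + mask
    # 3. Softmax over the last (width) dimension
    return torch.softmax(x, dim=-1)

# Illustrative usage: a [1, 1, H, W] additive causal mask (0 on/below the
# diagonal, -inf above) applied to a 4-D score tensor.
H, W = 384, 768
causal_mask = torch.zeros(1, 1, H, W).masked_fill(
    torch.triu(torch.ones(H, W, dtype=torch.bool), diagonal=1), float("-inf"))
scores = torch.randn(8, 8, H, W)
out = reference_scale_mask_softmax(scores, causal_mask, scale=1.0 / 64 ** 0.5)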
Parameters:
  input_tensor (ttnn.Tensor) – The input tensor to process. Must be sharded for optimal performance.
  scale (float, optional) – Scaling factor to multiply the input tensor by (typically 1/√d_k for attention).
  mask (ttnn.Tensor, optional) – Causal attention mask tensor with shape [1, 1, H, W].
Keyword Arguments:
  program_config (SoftmaxProgramConfig, optional) – Program configuration for the operation. Defaults to SoftmaxDefaultProgramConfig().
  compute_kernel_config (DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation.
  numeric_stable (bool, optional) – Whether to use numerically stable softmax computation. Defaults to False.
Returns:
  ttnn.Tensor – The same tensor as the input, with the specialized causal scale-mask-softmax applied in-place.
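Because the operation runs in place, the returned handle refers to the same device buffer as the input. A minimal sketch of a typical attention-style call (the tensor names and head_dim below are illustrative; the tensors are assumed to already satisfy the layout and sharding requirements described further down):

import math
import ttnn

head_dim = 64  # illustrative head dimension
# 'attention_scores' and 'causal_mask' are assumed to be device tensors that
# already meet the dtype, layout, and sharding requirements listed below.
output = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
    attention_scores,
    scale=1.0 / math.sqrt(head_dim),  # the typical 1/sqrt(d_k) attention scaling
    mask=causal_mask,
)
# In-place semantics: 'output' refers to the same device buffer as 'attention_scores'.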
Note
The tensors support the following data types and layouts:

  Tensor                     Dtypes                         Layouts
  Input tensor (sharded)     BFLOAT16, FLOAT32, BFLOAT8_B   TILE
  Mask tensor [1, 1, H, W]   BFLOAT16, BFLOAT8_B            TILE (interleaved)
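For illustration, a sketch of allocating tensors in one of the supported dtype/layout combinations (shapes are arbitrary; the full sharding setup is shown in the Example below, and 'device' is assumed to be an already-open ttnn device):

import ttnn

# Input in BFLOAT16 / TILE layout; it still has to be height-sharded before the
# call (the ShardSpec / MemoryConfig setup is shown in the Example below).
attention_scores = ttnn.rand((8, 8, 384, 768), dtype=ttnn.bfloat16,
                             layout=ttnn.TILE_LAYOUT, device=device)
# Mask in BFLOAT8_B / TILE layout, interleaved, with the required [1, 1, H, W] shape.
causal_mask = ttnn.rand((1, 1, 384, 768), dtype=ttnn.bfloat8_b,
                        layout=ttnn.TILE_LAYOUT, device=device)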
Limitations:
  - This is an experimental/specialized feature optimized for specific transformer attention patterns.
  - Inputs must be on the device.
  - The input tensor must be sharded for optimal performance.
  - The attention mask must be interleaved and have shape [1, 1, H, W] (i.e. hw_dims_only); for more general masks, see the fallback sketch after this list.
  - The mask is treated as a causal mask by design.
  - The scale parameter is typically provided for attention scaling.
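When the mask cannot be expressed as a single [1, 1, H, W] tensor (for example, a per-batch mask), the general ttnn.scale_mask_softmax_in_place() referenced above is the fallback. A sketch, assuming its argument names mirror this operation's (consult its own documentation for the exact signature):

# Hedged sketch: argument names are assumed to mirror this operation's; consult
# the ttnn.scale_mask_softmax_in_place documentation for the exact signature.
output = ttnn.scale_mask_softmax_in_place(
    input_tensor=attention_scores,   # device tensor holding the attention scores
    scale=0.125,
    mask=per_batch_mask,             # hypothetical mask, e.g. shape [B, 1, H, W]
)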
Example

import ttnn

device = ttnn.open_device(device_id=0)

compute_grid_size = device.compute_with_storage_grid_size()
batch = compute_grid_size.x
num_cores_r = compute_grid_size.y
input_shape = (batch, num_cores_r, 384, 768)

attention_mask_t = ttnn.rand((1, 1, 384, 768), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
input_tiled = ttnn.rand(input_shape, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

# We must shard the input tensor in ROW_MAJOR orientation
grid_coord = ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = [384, 768]
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)
input_sharded = ttnn.to_memory_config(input_tiled, sharded_mem_config)

# We must also use the sharded softmax program config
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
    compute_with_storage_grid_size=compute_grid_size,
    subblock_w=8,
    block_h=12,
    block_w=24,
)

tt_output_sharded = ttnn.scale_causal_mask_hw_dims_softmax_in_place(
    input_tensor=input_sharded,
    scale=1.0,
    mask=attention_mask_t,
    program_config=program_config,
)
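As a follow-up sketch (not part of the original example), the result can be pulled back to host with ttnn.to_torch and compared against a plain PyTorch reference. This assumes to_memory_config left the original interleaved input_tiled intact, and it uses a loose tolerance because bfloat8_b is a low-precision block format:

import torch

# Bring the result and the original (still-interleaved) inputs back to host.
tt_result = ttnn.to_torch(tt_output_sharded)
golden_input = ttnn.to_torch(input_tiled)      # untouched by the in-place op
golden_mask = ttnn.to_torch(attention_mask_t)

# Reference: scale (1.0 in the example above), add the mask, softmax over the last dim.
golden = torch.softmax(golden_input * 1.0 + golden_mask, dim=-1)
print(torch.allclose(tt_result.float(), golden.float(), atol=2e-2, rtol=2e-2))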