ttnn.rms_norm_pre_all_gather

ttnn.rms_norm_pre_all_gather(input_tensor, *, dtype=None, residual_input_tensor=None, compute_kernel_config=None, program_config=None, memory_config=None, use_2d_core_grid=None) → ttnn.Tensor

This operation is used in conjunction with ttnn.rms_norm_post_all_gather() to compute RMS norm in a distributed setting, where RMS norm is defined as:

\[\text{RMS\_norm}(x, \gamma, \beta, \epsilon) = \frac{x}{\sqrt{\epsilon+\frac{1}{N}\sum_{i=1}^{N}x_{i}^{2}}} \cdot \gamma + \beta\]
Where:
  • \(\gamma\) and \(\beta\) are optional scale and shift parameters

  • \(\epsilon\) is a small constant

See Root Mean Square Layer Normalization for more details.
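As a point of reference, the formula above can be sketched in NumPy. This illustrates the math only, not the ttnn API; the tensor shapes and epsilon value here are arbitrary choices for the example:

```python
import numpy as np

def rms_norm(x, gamma=None, beta=None, eps=1e-6):
    # Normalize over the last dimension: x / sqrt(eps + mean(x^2)),
    # then apply the optional scale (gamma) and shift (beta).
    rms = np.sqrt(eps + np.mean(np.square(x), axis=-1, keepdims=True))
    out = x / rms
    if gamma is not None:
        out = out * gamma
    if beta is not None:
        out = out + beta
    return out

x = np.random.randn(2, 32).astype(np.float32)
y = rms_norm(x)
```

After normalization, each row of the output has a root-mean-square value of (approximately) 1.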

This operation computes \(\sum x\) and \(\sum x^{2}\) over the last dimension. Its output should be gathered across devices with ttnn.all_gather() and then passed to ttnn.rms_norm_post_all_gather() to compute the final RMS norm.
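The pre/gather/post split works because sums are associative: partial sums of squares computed per device can be added after gathering and still yield the global statistic. A small NumPy sketch of this idea, using a hypothetical two-way split of the last dimension in place of actual devices:

```python
import numpy as np

eps = 1e-6
x = np.random.randn(4, 64).astype(np.float32)

# "Pre all-gather": each of 2 hypothetical devices holds half of the
# last dimension and computes its local sum of squares.
shards = np.split(x, 2, axis=-1)
local_sq_sums = [np.sum(np.square(s), axis=-1, keepdims=True) for s in shards]

# "All-gather" + "post": add the partial sums, then normalize full rows.
total_sq_sum = np.sum(np.stack(local_sq_sums), axis=0)
N = x.shape[-1]
out = x / np.sqrt(eps + total_sq_sum / N)

# Matches a single-device RMS norm over the full last dimension.
ref = x / np.sqrt(eps + np.mean(np.square(x), axis=-1, keepdims=True))
```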

Parameters:

input_tensor (ttnn.Tensor) – the input tensor.

Keyword Arguments:
  • dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to BFLOAT16.

  • residual_input_tensor (ttnn.Tensor, optional) – the residual input tensor. Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration. Defaults to None.

  • program_config (ttnn.ProgramConfig, optional) – the program configuration. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – the memory configuration. Defaults to None.

  • use_2d_core_grid (bool, optional) – whether to use a 2D core grid for the computation. Defaults to None.

Returns:

ttnn.Tensor – the output tensor.

Note

Supported data types and layouts by tensor:

tensor                   dtype                          layout
input_tensor             BFLOAT16, FLOAT32, BFLOAT8_B   TILE
residual_input_tensor    BFLOAT16, FLOAT32, BFLOAT8_B   TILE

The output stats tensor will be in TILE layout and have a dtype of BFLOAT16.

Limitations:
  • All tensors must be on-device.

  • Unsharded inputs must be interleaved.

  • Sharded inputs cannot be height-sharded, and their padded height must equal TILE_HEIGHT (32). If residual_input_tensor is provided, it must match the input's padded shape and sharding.

Example

# Create input tensor
input_tensor = ttnn.rand([1, 1, 32, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.rand([32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Apply pre-all-gather RMS normalization
stats = ttnn.rms_norm_pre_all_gather(input_tensor)
logger.info(f"RMS Norm Pre All Gather result: {stats}")

# In a distributed setup, an all-gather would go here to collect the stats from all devices.
# See the ttnn.all_gather documentation for example usage.

# Now apply the post-all-gather RMS normalization
output = ttnn.rms_norm_post_all_gather(input_tensor, stats, weight=weight)
logger.info(f"RMS Norm Post All Gather result: {output}")

# For reference, this two-step process is equivalent to the following
# output = ttnn.rms_norm(input_tensor, weight=weight)