ttnn.rms_norm_post_all_gather
- ttnn.rms_norm_post_all_gather(input_tensor, stats, *, epsilon=1e-12, weight=None, bias=None, memory_config=None, compute_kernel_config=None, program_config=None, dtype=None, use_2d_core_grid=None) → ttnn.Tensor
-
This operation is used in conjunction with ttnn.rms_norm_pre_all_gather() to compute RMS norm on a distributed setup, where RMS norm is defined as:

\[\text{RMS\_norm}(x, \gamma, \beta, \epsilon) = \frac{x}{\sqrt{\epsilon+\frac{1}{N}\sum_{i=1}^{N}x_i^{2}}} \cdot \gamma + \beta\]

Where:

- \(\gamma\) and \(\beta\) are optional scale and shift parameters
- \(\epsilon\) is a small constant added for numerical stability

See Root Mean Square Layer Normalization for more details.
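As a plain-Python illustration of the formula above (a hypothetical reference sketch, not the ttnn API; the function name and signature here are ours):

```python
import math

def rms_norm(x, gamma=None, beta=None, epsilon=1e-12):
    """Reference RMS norm over a 1-D list, matching the formula above."""
    n = len(x)
    # rms = sqrt(eps + mean of squares)
    rms = math.sqrt(epsilon + sum(v * v for v in x) / n)
    out = [v / rms for v in x]
    if gamma is not None:  # optional scale
        out = [o * g for o, g in zip(out, gamma)]
    if beta is not None:   # optional shift
        out = [o + b for o, b in zip(out, beta)]
    return out
```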
Performs the second part of a distributed RMSNorm operation, using the gathered statistics to compute the root mean square and normalize the input. The input stats tensor should be computed by first using ttnn.rms_norm_pre_all_gather() and then using ttnn.all_gather() to gather the statistics across all devices.

- Parameters:
-
input_tensor (ttnn.Tensor) – the input tensor.
stats (ttnn.Tensor) – the stats tensor.
- Keyword Arguments:
-
epsilon (float, optional) – the epsilon value. Defaults to 1e-12.
weight (ttnn.Tensor, optional) – the weight tensor. Defaults to None.
bias (ttnn.Tensor, optional) – the bias tensor. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration. Defaults to None.
program_config (ttnn.ProgramConfig, optional) – the program configuration. Defaults to None.
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.
use_2d_core_grid (bool, optional) – whether to use a 2D core grid for the computation. Defaults to None.
- Returns:
-
ttnn.Tensor – the output tensor.
Note
Supported data types and layouts:

| tensor | dtype | layout |
| --- | --- | --- |
| input_tensor | BFLOAT16, BFLOAT8_B | TILE |
| stats | BFLOAT16 | TILE |
| weight (gamma) and bias (beta) | BFLOAT16, FLOAT32 | TILE, ROW_MAJOR |
The output tensor will be in TILE layout and have the same dtype as the input_tensor.

- Limitations:
-
All tensors must be on-device.
The last padded dim of stats must be a multiple of TILE_WIDTH, and its first three padded dims must match input_tensor.
If weight (gamma) is provided, bias (beta) must also be provided. Gamma and beta must have the same layout; if that layout is ROW_MAJOR, their last padded dim must be TILE_WIDTH.
Sharded runs: inputs cannot be height-sharded; padded height must equal TILE_HEIGHT (32). When sharded, stats must be sharded across one core.
Example
# Create input tensors
input_tensor = ttnn.rand([1, 1, 32, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.rand([32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Apply pre-all-gather RMS normalization
stats = ttnn.rms_norm_pre_all_gather(input_tensor)
logger.info(f"RMS Norm Pre All Gather result: {stats}")

# On a distributed setup, an all_gather would go here to collect the stats from all the devices
# See documentation for ttnn.all_gather for example usage of all_gather

# Now apply the post-all-gather RMS normalization
output = ttnn.rms_norm_post_all_gather(input_tensor, stats, weight=weight)
logger.info(f"RMS Norm Post All Gather result: {output}")

# For reference, this two-step process is equivalent to the following
# output = ttnn.rms_norm(input_tensor, weight=weight)
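Conceptually, the two-step decomposition can be sketched in plain Python (a hypothetical illustration, not the ttnn API; ttnn's internal stats tensor layout may differ): each device computes a partial sum of squares over its shard of the last dimension, the partial statistics are gathered, and each device then normalizes its shard with the combined statistic.

```python
import math

def pre_all_gather(shard):
    # Per-device statistic: sum of squares over the local shard
    return sum(v * v for v in shard)

def post_all_gather(shard, gathered_stats, total_n, gamma=None, epsilon=1e-12):
    # Combine the gathered per-device statistics, then normalize locally
    rms = math.sqrt(epsilon + sum(gathered_stats) / total_n)
    out = [v / rms for v in shard]
    if gamma is not None:
        out = [o * g for o, g in zip(out, gamma)]
    return out

# Simulate two devices, each holding half of an 8-element row
row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
shards = [row[:4], row[4:]]
stats = [pre_all_gather(s) for s in shards]           # stands in for the all_gather step
out = [post_all_gather(s, stats, len(row)) for s in shards]
merged = out[0] + out[1]
# merged matches a single-device RMS norm of the full row
```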