ttnn.layer_norm_post_all_gather
ttnn.layer_norm_post_all_gather() → ttnn.Tensor

This operation is used in conjunction with ttnn.layer_norm_pre_all_gather() to compute layer norm on a distributed setup, where layer norm is defined as:

\[\text{layer\_norm}(x, \gamma, \beta, \epsilon) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\]

Where:
\(\mu\) is the mean of the input tensor. This is computed over the last dimension of the input tensor (W).
\(\sigma^2\) is the variance of the input tensor. This is computed over the last dimension of the input tensor (W) and is biased.
\(\gamma\) and \(\beta\) are the learnable scale and shift parameters, respectively.
\(\epsilon\) is a small constant added for numerical stability.
See Layer Normalization for more details.
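The formula above can be illustrated with a plain-Python sketch (this is a reference for the math only, not TTNN code; on device the mean and variance come from the gathered statistics):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-12):
    # Mean over the last dimension (W)
    mu = sum(x) / len(x)
    # Biased variance: divide by N, not N - 1, as stated above
    var = sum((v - mu) ** 2 for v in x) / len(x)
    # Normalize, then apply the learnable scale and shift
    return [((v - mu) / math.sqrt(var + eps)) * g + b
            for v, g, b in zip(x, gamma, beta)]

row = [1.0, 2.0, 3.0, 4.0]
out = layer_norm(row, gamma=[1.0] * 4, beta=[0.0] * 4)
# The normalized row has zero mean and (for unit gamma, zero beta) unit variance.
```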
Performs the second part of a distributed layernorm operation, using the gathered statistics to compute the mean and variance, and finally normalizing the input. The stats tensor should be computed by first using ttnn.layer_norm_pre_all_gather() and then using ttnn.all_gather() to gather the statistics across all devices.

Parameters:
input_tensor (ttnn.Tensor) – the input tensor.
stats (ttnn.Tensor) – the stats tensor.
Keyword Arguments:
epsilon (float, optional) – the epsilon value. Defaults to 1e-12.
weight (ttnn.Tensor, optional) – the weight tensor. Defaults to None.
bias (ttnn.Tensor, optional) – the bias tensor. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration. Defaults to None.
program_config (ttnn.ProgramConfig, optional) – the program configuration. Defaults to None.
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.
Returns:
ttnn.Tensor – the output tensor.
Note
Supported data types and layouts:
input_tensor: dtype BFLOAT16, BFLOAT8_B; layout TILE
stats: dtype BFLOAT16; layout TILE
weight (gamma) and bias (beta): dtype BFLOAT16; layout ROW_MAJOR
Output tensor will be in TILE layout and have the same dtype as the input_tensor.

Limitations:

Input tensors must be on-device and rank 4.
The last padded dim of stats must be a multiple of TILE_WIDTH.
The first three padded dims of stats must match input_tensor.
If weight (gamma) is provided, bias (beta) must also be provided with matching layouts, with their last padded dim matching TILE_WIDTH.
Sharded runs: inputs cannot be height-sharded, padded height must equal TILE_HEIGHT (32), and stats must be sharded with num_cores=1 and the expected tile columns per device.
Example
# Create input tensor
input_tensor = ttnn.rand([1, 1, 32, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)

# Apply pre-all-gather layer normalization
stats = ttnn.layer_norm_pre_all_gather(input_tensor)
logger.info(f"Layer Norm Pre All Gather result: {stats}")

# On a distributed setup, an all_gather would go here to collect the stats from all devices
# See documentation for ttnn.all_gather for example usage

# Now apply the post-all-gather layer normalization
output = ttnn.layer_norm_post_all_gather(input_tensor, stats)
logger.info(f"Layer Norm Post All Gather result: {output}")

# For reference, this two-step process is equivalent to the following
# output = ttnn.layer_norm(input_tensor)