ttnn.layer_norm_post_all_gather

ttnn.layer_norm_post_all_gather(input_tensor, stats, *, epsilon=1e-12, weight=None, bias=None, memory_config=None, compute_kernel_config=None, program_config=None, dtype=None) → ttnn.Tensor

This operation is used in conjunction with ttnn.layer_norm_pre_all_gather() to compute layer norm in a distributed setup, where layer norm is defined as:

\[\text{layer_norm}(x, \gamma, \beta, \epsilon) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\]
Where:
  • \(\mu\) is the mean of the input tensor. This is computed over the last dimension of the input tensor (W).

  • \(\sigma^2\) is the variance of the input tensor. This is computed over the last dimension of the input tensor (W) and is biased.

  • \(\gamma\) and \(\beta\) are the learnable scale and shift parameters, respectively.

  • \(\epsilon\) is a small constant added for numerical stability.

See Layer Normalization for more details.
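As a host-side reference only (not the ttnn API; the function and variable names here are illustrative), the formula above can be sketched in plain NumPy:

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, eps=1e-12):
    # Mean and biased variance over the last dimension (W), as defined above
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # np.var is biased (ddof=0) by default
    return (x - mu) / np.sqrt(var + eps) * gamma + beta

x = np.random.randn(1, 1, 32, 32).astype(np.float32)
gamma = np.ones(32, dtype=np.float32)   # scale
beta = np.zeros(32, dtype=np.float32)   # shift
out = layer_norm_reference(x, gamma, beta)
```

With unit scale and zero shift, each row of the output has mean approximately 0 and variance approximately 1.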

Performs the second part of a distributed layer norm operation: it uses the gathered statistics to compute the mean and variance, then normalizes the input. The stats tensor should be produced by first calling ttnn.layer_norm_pre_all_gather() and then ttnn.all_gather() to gather the statistics across all devices.

Parameters:
  • input_tensor (ttnn.Tensor) – the input tensor.

  • stats (ttnn.Tensor) – the stats tensor.

Keyword Arguments:
  • epsilon (float, optional) – the epsilon value. Defaults to 1e-12.

  • weight (ttnn.Tensor, optional) – the weight tensor. Defaults to None.

  • bias (ttnn.Tensor, optional) – the bias tensor. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – the memory configuration. Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration. Defaults to None.

  • program_config (ttnn.ProgramConfig, optional) – the program configuration. Defaults to None.

  • dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to None.

Returns:

ttnn.Tensor – the output tensor.

Note

Supported data types and layouts:

  • input_tensor – dtype: BFLOAT16, BFLOAT8_B; layout: TILE

  • stats – dtype: BFLOAT16; layout: TILE

  • weight (gamma) and bias (beta) – dtype: BFLOAT16; layout: ROW_MAJOR

The output tensor will be in TILE layout and have the same dtype as input_tensor.

Limitations:
  • Input tensors must be on-device and rank 4.

  • The last padded dim of stats must be a multiple of TILE_WIDTH.

  • The first three padded dims of stats must match input_tensor.

  • If weight (gamma) is provided, bias (beta) must also be provided; both must share the same layout, and their last padded dim must equal TILE_WIDTH.

  • Sharded runs: inputs cannot be height-sharded, padded height must equal TILE_HEIGHT (32), and stats must be sharded with num_cores=1 and expected tile columns per device.

Example

# Create input tensor
input_tensor = ttnn.rand([1, 1, 32, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)

# Apply pre-all-gather layer normalization
stats = ttnn.layer_norm_pre_all_gather(input_tensor)
logger.info(f"Layer Norm Pre All Gather result: {stats}")

# On a distributed setup, all gather would go here to collect the stats from all devices
# See documentation for ttnn.all_gather for example usage of all_gather

# Now apply the post-all-gather layer normalization
output = ttnn.layer_norm_post_all_gather(input_tensor, stats)
logger.info(f"Layer Norm Post All Gather result: {output}")

# For reference, this two-step process is equivalent to the following
# output = ttnn.layer_norm(input_tensor)
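The exact contents of the stats tensor are an implementation detail of the ttnn ops, but the underlying idea can be illustrated on the host: per-device partial sums of \(x\) and \(x^2\) over each local width shard, once gathered, are sufficient to reconstruct the global mean and biased variance. A NumPy sketch of this, with hypothetical shard names (not the ttnn stats format):

```python
import numpy as np

np.random.seed(0)
full = np.random.randn(1, 1, 32, 128).astype(np.float64)
shards = np.split(full, 4, axis=-1)  # width (W) split across 4 "devices"

# "Pre all-gather" phase: each device computes partial sums over its local W
partials = [(s.sum(axis=-1), (s * s).sum(axis=-1)) for s in shards]

# "All-gather" phase: every device now sees all partial statistics
total_w = full.shape[-1]
sum_x = sum(p[0] for p in partials)
sum_x2 = sum(p[1] for p in partials)

# "Post all-gather" phase: combine into global mean and biased variance
mu = sum_x / total_w
var = sum_x2 / total_w - mu * mu

# Matches statistics computed directly on the unsharded tensor
assert np.allclose(mu, full.mean(axis=-1))
assert np.allclose(var, full.var(axis=-1))
```

Each device can then normalize its local shard with these global statistics, which is why the two-step process matches a single-device ttnn.layer_norm.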