ttnn.layer_norm_pre_all_gather
ttnn.layer_norm_pre_all_gather() → ttnn.Tensor

This operation is used in conjunction with ttnn.layer_norm_post_all_gather() to compute layer norm on a distributed setup, where layer norm is defined as:

\[\text{layer_norm}(x, \gamma, \beta, \epsilon) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\]

Where:
\(\mu\) is the mean of the input tensor. This is computed over the last dimension of the input tensor (W).
\(\sigma^2\) is the variance of the input tensor. This is computed over the last dimension of the input tensor (W) and is biased.
\(\gamma\) and \(\beta\) are the learnable scale and shift parameters, respectively.
\(\epsilon\) is a small constant added to the variance for numerical stability.
See Layer Normalization for more details.
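As a plain-Python illustration of the formula above, the following minimal sketch normalizes a single row over its last dimension using the biased variance. The function name `layer_norm_reference` and the `eps` default of `1e-5` are illustrative assumptions, not part of the ttnn API or its defaults.

```python
import math


def layer_norm_reference(row, gamma, beta, eps=1e-5):
    """Layer norm over one row (the last dimension, W) with biased variance.

    Hypothetical reference helper; eps=1e-5 is an arbitrary illustrative value.
    """
    n = len(row)
    mean = sum(row) / n
    # Biased variance: divide by n rather than n - 1.
    var = sum((v - mean) ** 2 for v in row) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    # Normalize, then apply the per-element scale (gamma) and shift (beta).
    return [(v - mean) * inv_std * g + b for v, g, b in zip(row, gamma, beta)]


row = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_reference(row, gamma=[1.0] * 4, beta=[0.0] * 4)
```

With unit scale and zero shift, the normalized row has a mean of approximately zero and a mean square of approximately one.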
This operation computes \(\sum x\) and \(\sum x^2\) over the last dimension. Its output should be combined across devices with ttnn.all_gather(), and the gathered statistics then passed to ttnn.layer_norm_post_all_gather() to compute layer norm.

Parameters:
input_tensor (ttnn.Tensor) – the input tensor.
Keyword Arguments:
dtype (ttnn.DataType, optional) – the data type of the output tensor. Defaults to BFLOAT16.
residual_input_tensor (ttnn.Tensor, optional) – the residual input tensor. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – the compute kernel configuration. Defaults to None.
program_config (ttnn.ProgramConfig, optional) – the program configuration. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – the memory configuration. Defaults to None.
Returns:
ttnn.Tensor – the output tensor.
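The description above says this operation produces \(\sum x\) and \(\sum x^2\) over the last dimension, which are then gathered across devices. The sketch below shows, in plain Python, why those two partial sums are enough to recover the global mean and biased variance via \(\sigma^2 = E[x^2] - (E[x])^2\). It models only the arithmetic: the helper names `local_stats` and `combine_stats` are hypothetical, the list comprehension is a stand-in for ttnn.all_gather(), and nothing here reflects the actual layout of the stats tensor on device.

```python
def local_stats(shard):
    """Per-device partial statistics: (sum of x, sum of x^2) over the local shard."""
    return sum(shard), sum(v * v for v in shard)


def combine_stats(gathered, total_width):
    """Combine gathered partial sums into the global mean and biased variance."""
    total_sum = sum(s for s, _ in gathered)
    total_sq = sum(sq for _, sq in gathered)
    mean = total_sum / total_width
    var = total_sq / total_width - mean * mean  # E[x^2] - (E[x])^2
    return mean, var


# One logical row of width 8, split across two hypothetical devices.
row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
shards = [row[:4], row[4:]]

# Stand-in for ttnn.all_gather(): every device ends up with every partial result.
gathered = [local_stats(shard) for shard in shards]
mean, var = combine_stats(gathered, total_width=len(row))

# Direct single-device computation for comparison.
ref_mean = sum(row) / len(row)
ref_var = sum((v - ref_mean) ** 2 for v in row) / len(row)
```

The combined `mean` and `var` match `ref_mean` and `ref_var`, which is the property the two-step pre/post all-gather split relies on.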
Note
Supported data types and layouts by tensor:
input_tensor – dtype: BFLOAT16, FLOAT32, BFLOAT8_B; layout: TILE
residual_input_tensor – dtype: BFLOAT16, FLOAT32, BFLOAT8_B; layout: TILE
Output stats tensor will be in TILE layout and have dtype of BFLOAT16.
Limitations:
Input tensors must be on-device and rank 4.
Unsharded runs: input_tensor must be interleaved.
Sharded runs: inputs cannot be height-sharded, and padded height must equal TILE_HEIGHT (32).
When using residual_input_tensor with sharding, it must match the input_tensor padded shape and sharding.
Example
    import ttnn
    from loguru import logger

    # Assumes a single device with ID 0 is available
    device = ttnn.open_device(device_id=0)

    # Create input tensor
    input_tensor = ttnn.rand([1, 1, 32, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)

    # Apply pre-all-gather layer normalization
    stats = ttnn.layer_norm_pre_all_gather(input_tensor)
    logger.info(f"Layer Norm Pre All Gather result: {stats}")

    # On a distributed setup, all gather would go here to collect the stats from all devices
    # See documentation for ttnn.all_gather for example usage of all_gather

    # Now apply the post-all-gather layer normalization
    output = ttnn.layer_norm_post_all_gather(input_tensor, stats)
    logger.info(f"Layer Norm Post All Gather result: {output}")

    # For reference, this two-step process is equivalent to the following
    # output = ttnn.layer_norm(input_tensor)

    ttnn.close_device(device)