ttnn.group_norm

ttnn.group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, input_mask: ttnn.Tensor = None, weight: ttnn.Tensor = None, bias: ttnn.Tensor = None, memory_config: ttnn.MemoryConfig = None, dtype: ttnn.DataType = None, core_grid: CoreGrid = None, inplace: bool = True, output_layout: ttnn.Layout = None, num_out_blocks: int = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None, negative_mask: ttnn.Tensor = None, use_welford: bool = False, reciprocals: ttnn.Tensor = None) → ttnn.Tensor

Computes group_norm over input_tensor. See Group Normalization for more details.

\[\text{group_norm}(x, \gamma, \beta, \epsilon) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\]
Where:
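  • \(\mu\) and \(\sigma^2\) are the mean and variance of each group, computed over all H*W positions and all channels belonging to that group (separately for each of the N batch elements)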

  • \(\gamma\) and \(\beta\) are the learnable scale and shift parameters, respectively

  • \(\epsilon\) is a small constant added to the variance for numerical stability.
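As a quick worked instance of the formula, take one group holding the values x = (1, 2, 3, 4) with \(\gamma = 1\), \(\beta = 0\), and \(\epsilon\) small enough to ignore. The mean and variance are taken over the group's elements, and the variance divides by the number of elements, as torch.nn.functional.group_norm (the reference used in the example) does:

\[\mu = \tfrac{1}{4}(1 + 2 + 3 + 4) = 2.5, \qquad \sigma^2 = \tfrac{1}{4}\left((-1.5)^2 + (-0.5)^2 + 0.5^2 + 1.5^2\right) = 1.25\]

\[\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \approx \frac{(-1.5,\ -0.5,\ 0.5,\ 1.5)}{1.1180} \approx (-1.342,\ -0.447,\ 0.447,\ 1.342)\]

The normalized group has zero mean and unit variance before \(\gamma\) and \(\beta\) are applied.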

GroupNorm traditionally splits the input tensor's channels into groups and then computes the mean and variance of each group. This implementation does the same, but it expects a channels-last layout and forms the groups along the tensor's last dimension. Concretely, the input tensor is expected to have shape [N, 1, H*W, C], where C is the dimension along which the groups are formed.
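The layout and the math can be checked with a plain-Python sketch, assuming nested lists stand in for tensors. nchw_to_channels_last and group_norm_reference are hypothetical names; they only describe what is computed (including the same rearrangement the example performs with permute(0, 2, 3, 1)), not how the device kernel works.

```python
import math
from typing import List

# Nested lists stand in for 4D tensors in this sketch.
Tensor4D = List[List[List[List[float]]]]


def nchw_to_channels_last(x_nchw: Tensor4D) -> Tensor4D:
    """Rearrange [N, C, H, W] into [N, 1, H*W, C]: (n, c, h, w) lands at (n, 0, h*W + w, c)."""
    n_batch, c = len(x_nchw), len(x_nchw[0])
    h, w = len(x_nchw[0][0]), len(x_nchw[0][0][0])
    return [
        [[[x_nchw[n][ch][i // w][i % w] for ch in range(c)] for i in range(h * w)]]
        for n in range(n_batch)
    ]


def group_norm_reference(
    x: Tensor4D, num_groups: int, gamma: List[float], beta: List[float], eps: float = 1e-12
) -> Tensor4D:
    """Group norm over a [N, 1, H*W, C] tensor, with groups formed along the last dimension."""
    hw, c = len(x[0][0]), len(x[0][0][0])
    if c % num_groups != 0:
        raise ValueError("C must be divisible by num_groups")
    group_size = c // num_groups
    out = [[[[0.0] * c for _ in range(hw)]] for _ in range(len(x))]
    for n in range(len(x)):
        for g in range(num_groups):
            channels = range(g * group_size, (g + 1) * group_size)
            # Statistics cover every H*W position and every channel of the group.
            values = [x[n][0][p][ch] for p in range(hw) for ch in channels]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)  # population variance
            inv_std = 1.0 / math.sqrt(var + eps)
            for p in range(hw):
                for ch in channels:
                    out[n][0][p][ch] = (x[n][0][p][ch] - mean) * inv_std * gamma[ch] + beta[ch]
    return out


# N=1, C=2, H=2, W=1 in NCHW order; each of the 2 groups holds one channel.
x = nchw_to_channels_last([[[[1.0], [2.0]], [[10.0], [20.0]]]])
y = group_norm_reference(x, num_groups=2, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

Here x becomes [[[[1.0, 10.0], [2.0, 20.0]]]], and each group normalizes to (-1, 1) independently of the other group's scale.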

TTNN provides utility functions to help prepare this op’s inputs for different types of input tensors:
  • When using sharded input tensors, ttnn.determine_expected_group_norm_sharded_config_and_grid_size() can provide the appropriate memory configuration and grid size.

  • When using interleaved (DRAM) input tensors, ttnn.determine_expected_group_norm_dram_grid_size() can provide the appropriate grid size.

  • ttnn.dram_group_norm_params_from_torch() is a convenience function that prepares the weight, bias, and input mask from PyTorch tensors for interleaved inputs.

  • ttnn.get_group_norm_cores_across_channel() returns the number of cores that split the channel axis for a given memory layout, grid, and shard orientation. This value must be passed consistently to ttnn.create_group_norm_input_mask() and ttnn.create_group_norm_weight_bias_rm(). For HEIGHT_SHARDED inputs this is always 1; for BLOCK_SHARDED inputs it is core_grid.x when using ROW_MAJOR shard orientation, or core_grid.y when using COL_MAJOR.

  • ttnn.create_group_norm_input_mask() creates the appropriate input mask for a given tensor dimension and group size.

  • ttnn.create_group_norm_weight_bias_rm() converts the weight and bias tensors into appropriately zero-padded inputs, reshaped into rows of 32 values (one tile width).
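The rule stated for ttnn.get_group_norm_cores_across_channel() can be written out as a small plain-Python function. This is a hypothetical stand-in for illustration (string arguments instead of ttnn enums), not the library's code; the real function should be used when preparing inputs.

```python
def cores_across_channel(memory_layout: str, grid_x: int, grid_y: int, orientation: str = "ROW_MAJOR") -> int:
    """Hypothetical mirror of the documented rule for ttnn.get_group_norm_cores_across_channel()."""
    if memory_layout == "HEIGHT_SHARDED":
        # Each core holds whole rows, i.e. all C channels, so no core split happens along C.
        return 1
    if memory_layout == "BLOCK_SHARDED":
        if orientation == "ROW_MAJOR":
            return grid_x
        if orientation == "COL_MAJOR":
            return grid_y
        raise ValueError(f"unknown shard orientation: {orientation}")
    # This sketch only covers the two sharded layouts described above.
    raise ValueError(f"unsupported memory layout: {memory_layout}")


# An 8x4 block-sharded grid (x=8, y=4) with C=256, assuming C splits evenly over those cores.
n = cores_across_channel("BLOCK_SHARDED", grid_x=8, grid_y=4, orientation="ROW_MAJOR")
channels_per_core = 256 // n
```

Whatever value is used, the same one has to be passed to both ttnn.create_group_norm_input_mask() and ttnn.create_group_norm_weight_bias_rm().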

See the sharded example in this document for more details on how to properly prepare the op’s inputs using these functions.

Parameters:

input_tensor (ttnn.Tensor) – the input tensor.

Keyword Arguments:
  • num_groups (int) – Number of groups to split the tensor’s channels into.

  • epsilon (float) – Small constant added to the variance for numerical stability. Defaults to 1e-12.

  • input_mask (ttnn.Tensor, optional) – Defaults to None. Mask that restricts processing to the elements (channels) belonging to the current group.

  • weight (ttnn.Tensor, optional) – Gamma (scale) parameter for the affine transformation. When omitted, no scaling is applied. Defaults to None.

  • bias (ttnn.Tensor, optional) – Beta (shift) parameter for the affine transformation. When omitted, no shift is applied. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • dtype (ttnn.DataType, optional) – Defaults to None.

  • core_grid (CoreGrid, optional) – Must be provided (see limitations). Defaults to None.

  • inplace (bool, optional) – Defaults to True.

  • output_layout (ttnn.Layout, optional) – Defaults to None.

  • num_out_blocks (int, optional) – Allows the output to be processed in multiple smaller chunks, to reduce the amount of L1 required at a time. Should only be used if needed to relieve L1 pressure, as this negatively impacts performance. Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Compute kernel configuration for the op. Defaults to None.

  • negative_mask (ttnn.Tensor, optional) – Defaults to None. Can only be used with row-major, sharded input and output tensors. It reduces the number of circular buffers (CBs) used by the sharded kernel by letting the CBs for the tilized input and output overlap. (The kernel is a row-major variant that internally tilizes the row-major input.)

  • use_welford (bool, optional) – Defaults to False. If True, Welford's algorithm is used to compute the mean and variance.

  • reciprocals (ttnn.Tensor, optional) – Defaults to None. FP32 tensor containing pre-computed reciprocal values. Only valid when use_welford is True. Must be sharded to L1 memory in each core.
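For readers unfamiliar with the option behind use_welford, here is a minimal plain-Python sketch of Welford's algorithm itself. It shows the single-pass running update; how the device kernel applies it per group, and how it consumes reciprocals, is not described on this page, and welford_mean_var is a hypothetical name.

```python
import statistics
from typing import Iterable, Tuple


def welford_mean_var(values: Iterable[float]) -> Tuple[float, float]:
    """Single-pass mean and population variance using Welford's running update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count  # fold the new sample into the running mean
        m2 += delta * (x - mean)  # uses both the old and the updated mean
    if count == 0:
        raise ValueError("welford_mean_var needs at least one value")
    return mean, m2 / count


# A large common offset is where the naive E[x^2] - E[x]^2 formula loses precision;
# Welford's update only works with deviations from the running mean.
data = [1e8 + 1, 1e8 + 2, 1e8 + 3, 1e8 + 4]
mean, var = welford_mean_var(data)
```

The trade-off is one pass with a division per sample instead of two passes over the data.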

Returns:

ttnn.Tensor – the output tensor.

Note

The supported input data types and layouts:

Tensor                           dtype                  layout
input_tensor                     BFLOAT16               TILE, ROW_MAJOR
weight (gamma) and bias (beta)   BFLOAT16               ROW_MAJOR
input_mask                       BFLOAT16, BFLOAT8_B    TILE

The output will be BFLOAT16, and both the layout and the memory configuration will match the input_tensor.

Memory Support:
  • Interleaved: DRAM and L1

  • Sharded (L1): Height and Block sharded

Limitations:
  • input_tensor must be a 4D tensor of shape [N, 1, H*W, C] allocated on the device.

  • For input_tensor, H*W must be a multiple of the tile size (32), and C must be a multiple of the tile size and evenly divisible by num_groups.

  • For input_mask, viewed as [N, C, H, W], C must equal num_groups, H must equal the tile height (32), and W must be a multiple of the tile width (32).

  • core_grid must be provided.

  • inplace=True is not supported for TILE-layout inputs, and it requires the input and output layouts to be identical.

  • When generating inputs (e.g. weight, bias, input mask) for block-sharded tensors, the number of cores across the channel dimension is core_grid.x with ROW_MAJOR shard orientation (core_grid.y with COL_MAJOR), not the number of cores in a column; ttnn.get_group_norm_cores_across_channel() returns this value.

  • When generating inputs (e.g. weight, bias, input mask) for height-sharded tensors, the number of cores across the channel dimension is 1 rather than core_grid.y.

  • Width sharding is not supported.
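The shape limitations can be checked on the host before calling the op. The sketch below is a hypothetical helper (not part of ttnn) that encodes only the shape rules listed above; it does not check dtypes, layouts, or memory configurations.

```python
from typing import Optional, Sequence

TILE = 32  # tile height and width


def check_group_norm_shapes(
    input_shape: Sequence[int], num_groups: int, mask_shape: Optional[Sequence[int]] = None
) -> None:
    """Raise ValueError if the shapes break the documented shape limitations."""
    if len(input_shape) != 4 or input_shape[1] != 1:
        raise ValueError(f"input must have shape [N, 1, H*W, C], got {list(input_shape)}")
    _, _, hw, c = input_shape
    if hw % TILE != 0:
        raise ValueError(f"H*W={hw} must be a multiple of {TILE}")
    if c % TILE != 0:
        raise ValueError(f"C={c} must be a multiple of {TILE}")
    if c % num_groups != 0:
        raise ValueError(f"C={c} must be divisible by num_groups={num_groups}")
    if mask_shape is not None:
        if len(mask_shape) != 4:
            raise ValueError("input_mask must be 4D")
        _, mask_c, mask_h, mask_w = mask_shape
        if mask_c != num_groups or mask_h != TILE or mask_w % TILE != 0:
            raise ValueError(f"input_mask must be [*, {num_groups}, {TILE}, k*{TILE}], got {list(mask_shape)}")


# Shapes taken from the two examples on this page; both pass.
check_group_norm_shapes([1, 1, 32, 64], num_groups=2, mask_shape=[1, 2, 32, 32])
check_group_norm_shapes([1, 1, 64, 480], num_groups=8, mask_shape=[1, 8, 32, 96])
```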

Example

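import math

import torch
import ttnn
from loguru import logger

# These examples assume `device` is an already-opened device, e.g. device = ttnn.open_device(device_id=0)

#
# Sharded Input Tensor Example
#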
N, C, H, W = 1, 64, 32, 1
num_groups = 2

# Prepare random inputs
torch_input_tensor = torch.rand((N, C, H, W), dtype=torch.bfloat16)
torch_weight = torch.rand((C,), dtype=torch.bfloat16)
torch_bias = torch.rand((C,), dtype=torch.bfloat16)

# Compute the reference output with PyTorch
torch_output_tensor = torch.nn.functional.group_norm(
    torch_input_tensor, num_groups, weight=torch_weight, bias=torch_bias
)

# Permute the torch output to match the TTNN format, so they can be compared
torch_output_tensor = torch_output_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C)

# Prepare TTNN input
# Determine how to shard the input tensor - this example uses height sharding
# For interleaved (non-sharded) tensors, use ttnn.determine_expected_group_norm_dram_grid_size instead to determine the grid size.
sharded_mem_config, grid_size = ttnn.determine_expected_group_norm_sharded_config_and_grid_size(
    device=device,
    num_channels=C,
    num_groups=num_groups,
    input_nhw=N * H * W,
    is_height_sharded=True,
    is_row_major=True,
)

input_tensor = ttnn.from_torch(
    torch_input_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C),
    dtype=ttnn.DataType.BFLOAT16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=sharded_mem_config,
)

# To prepare the input tensors, we need to know how many cores split the channel dimension.
# For height sharding (as in this example) this is always 1; for block sharding it depends on the shard orientation.
num_cores_across_channel = ttnn.get_group_norm_cores_across_channel(
    ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED,
    grid_size,
    ttnn.ShardOrientation.ROW_MAJOR,
)

# Create the input mask which helps each group select the correct elements of the input tensor
# In general, it will have dimensions of [1, num_groups, 32, 32*block_wt]

# In this example, C=64 and num_groups=2, so each group is 32 channels (i.e. one tile) wide
# As a result, the input_mask_tensor is a [1, 2, 32, 32] tensor where every value is 1

# If instead num_groups was 4, each group would be 16 channels (i.e. half a tile) wide
# As a result, the input_mask_tensor would be a [1, 4, 32, 32] tensor that selects either the first or second half of the tile
# e.g. the mask at [0][0][:][:] would be a 32x32 tensor where the left half is 1 and the right half is 0,
# while [0][1][:][:] would be a 32x32 tensor where the left half is 0 and the right half is 1
input_mask_tensor = ttnn.create_group_norm_input_mask(
    num_channel=C,
    num_groups=num_groups,
    num_cores_across_channel=num_cores_across_channel,
    data_type=ttnn.bfloat8_b,
)
input_mask_tensor = ttnn.to_device(input_mask_tensor, device)

# Prepare gamma and beta for TTNN. Currently these are just 1D tensors of size [C], which isn't compatible with tile-based processing
# First they are zero-padded to a multiple of 32 if needed (not needed in this example)
# Then they are reshaped to [1, 1, tiles_per_core_total, 32], which in this case is [1, 1, 2, 32]
gamma = ttnn.create_group_norm_weight_bias_rm(
    input_tensor=torch_weight, num_channels=C, num_cores_x=num_cores_across_channel
)
beta = ttnn.create_group_norm_weight_bias_rm(
    input_tensor=torch_bias, num_channels=C, num_cores_x=num_cores_across_channel
)

gamma_t = ttnn.from_torch(
    gamma,
    dtype=ttnn.DataType.BFLOAT16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
beta_t = ttnn.from_torch(
    beta,
    dtype=ttnn.DataType.BFLOAT16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

# Compute the TTNN output (after converting it back to torch, it can be compared with torch_output_tensor)
output_tensor = ttnn.group_norm(
    input_tensor,
    num_groups=num_groups,
    input_mask=input_mask_tensor,
    weight=gamma_t,
    bias=beta_t,
    memory_config=sharded_mem_config,
    core_grid=grid_size,
)

output_tensor = ttnn.to_torch(output_tensor)
logger.info(f"Group Norm result: {output_tensor}")

#
# Base example with tilized input
#
tile_size = 32
N, C, H, W = 1, 480, 1, 64
grid_size = ttnn.CoreGrid(y=1, x=1)
num_out_blocks = 1

num_groups = 8  # This must be a multiple of grid_size.y (1 in this example)

input_tensor_row_major = ttnn.rand(
    [N, 1, H * W, C], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
)
input_tensor_tilized = ttnn.tilize_with_zero_padding(input_tensor_row_major, use_multicore=True)

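# input mask: all zeros here, which only illustrates the expected shape; see the utility functions above for building a real mask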
width_per_group = C // num_groups  # C must be a multiple of num_groups
max_tiles_group_can_span = 1 + math.ceil((width_per_group - 1) / tile_size)
input_mask_tensor = ttnn.zeros(
    [1, num_groups, tile_size, max_tiles_group_can_span * tile_size],
    dtype=ttnn.DataType.BFLOAT8_B,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# gamma/beta
# 480 / 1 = 480. Note that 480 is a multiple of 32, so no padding up to the next tile is needed.
values_per_chunk = C // grid_size.y
values_per_chunk_per_tile = values_per_chunk // tile_size  # 480 / 32 = 15

gamma_beta = ttnn.rand(
    [1, 1, values_per_chunk_per_tile, 32], dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
)

# groupnorm
output_tensor = ttnn.group_norm(
    input_tensor_tilized,
    num_groups=num_groups,
    input_mask=input_mask_tensor,
    weight=gamma_beta,
    bias=gamma_beta,
    output_layout=ttnn.TILE_LAYOUT,
    core_grid=grid_size,
    inplace=False,
    num_out_blocks=num_out_blocks,
)
logger.info(f"Group Norm result: {output_tensor}")
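The comments in the sharded example describe what the input mask and the reshaped gamma/beta look like, and the tilized example sizes the mask with max_tiles_group_can_span. The sketch below mirrors those descriptions in plain Python for the single-core case (one core across the channel dimension). It is an illustration built from those comments, not the implementation of ttnn.create_group_norm_input_mask() or ttnn.create_group_norm_weight_bias_rm(); in particular, placing each group's ones at the group's column offset within its first tile is an assumption that reproduces the num_groups=2 and num_groups=4 cases described there. All function names are hypothetical.

```python
from typing import List

TILE = 32


def max_tile_span(num_channels: int, num_groups: int) -> int:
    """Largest number of 32-wide tiles that any single group touches along C."""
    width = num_channels // num_groups
    spans = []
    for g in range(num_groups):
        start = g * width
        end = start + width - 1  # last channel of the group
        spans.append(end // TILE - start // TILE + 1)
    return max(spans)


def input_mask_sketch(num_channels: int, num_groups: int) -> List[List[List[int]]]:
    """Mask indexed [group][row][col], with shape [num_groups, 32, 32 * max_tile_span].

    Assumption: group g is marked with ones starting at its column offset inside its first tile.
    """
    width = num_channels // num_groups
    mask_width = TILE * max_tile_span(num_channels, num_groups)
    mask = []
    for g in range(num_groups):
        offset = (g * width) % TILE
        row = [1 if offset <= col < offset + width else 0 for col in range(mask_width)]
        mask.append([list(row) for _ in range(TILE)])
    return mask


def weight_bias_rows_sketch(values: List[float]) -> List[List[float]]:
    """Zero-pad a length-C vector to a multiple of 32 and split it into rows of 32."""
    padded = list(values) + [0.0] * (-len(values) % TILE)
    return [padded[i : i + TILE] for i in range(0, len(padded), TILE)]


two_groups = input_mask_sketch(64, 2)   # every value is 1, 32 columns wide
four_groups = input_mask_sketch(64, 4)  # alternating left/right halves of one tile
gamma_rows = weight_bias_rows_sketch([1.0] * 64)  # 2 rows of 32, i.e. [1, 1, 2, 32] once reshaped
```

max_tile_span counts the tiles each group actually touches, whereas the tilized example's 1 + ceil((width_per_group - 1) / 32) is a worst-case bound. Both give 3 for C=480 and num_groups=8, but for C=64 and num_groups=2 the bound gives 2 while the sharded example's comments describe a mask that is one tile wide.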