ttnn.avg_pool2d

ttnn.avg_pool2d(input_tensor_a: ttnn.Tensor, batch_size: int, input_h: int, input_w: int, channels: int, kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: int | None, *, memory_config: ttnn.MemoryConfig = None, applied_shard_scheme: ttnn.TensorMemoryLayout = None, deallocate_input: bool = False, reallocate_halo_output: bool = True, dtype: ttnn.DataType = ttnn.bfloat16, output_layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, compute_kernel_config: DeviceComputeKernelConfig = None) → ttnn.Tensor

Applies a 2D average pooling over the input tensor. Each element of the output tensor contains the average of the values in the corresponding kernel window, computed per channel. The input tensor is expected to be in [NHW, C] format and must be on the device. Height, width, and block sharding schemes are supported.

Parameters:
  • input_tensor_a (ttnn.Tensor) – the input tensor to be pooled.

  • batch_size (int) – the number of batches (N in a [N, C, H, W] shaped tensor).

  • input_h (int) – the height of the input tensor (H in a [N, C, H, W] shaped tensor).

  • input_w (int) – the width of the input tensor (W in a [N, C, H, W] shaped tensor).

  • channels (int) – the number of channels (C in a [N, C, H, W] shaped tensor).

  • kernel_size (List[int]) – the (h, w) size of the kernel window.

  • stride (List[int]) – the (h, w) stride of the kernel window.

  • padding (List[int]) – the (h, w) padding of the input tensor.

  • ceil_mode (bool) – When True, uses the ceiling function instead of the floor function when computing the output shape. Default: False.

  • count_include_pad (bool) – When True, includes zero-padding in the avg calculation. Default: True.

  • divisor_override (int, optional) – if specified, it will be used as the divisor in the average computation; otherwise the size of the pooling region is used. Default: None. Not currently supported in ttnn.
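The output spatial dimensions follow the standard pooling formula. A minimal sketch, assuming ttnn follows the PyTorch convention for ceil_mode (including dropping a trailing window that would start entirely inside the padding):

```python
import math

def pool_output_dim(in_size: int, kernel: int, stride: int, pad: int, ceil_mode: bool) -> int:
    # Standard pooling output-size formula:
    #   floor_or_ceil((in_size + 2*pad - kernel) / stride) + 1
    rounding = math.ceil if ceil_mode else math.floor
    out = rounding((in_size + 2 * pad - kernel) / stride) + 1
    # With ceil_mode=True, the last window must still start inside the
    # input or its left/top padding; otherwise drop it.
    if ceil_mode and (out - 1) * stride >= in_size + pad:
        out -= 1
    return out

# For the example below: 40x40 input, 2x2 kernel, stride 1, no padding.
print(pool_output_dim(40, 2, 1, 0, False))  # 39
print(pool_output_dim(40, 3, 2, 0, False))  # 19
print(pool_output_dim(40, 3, 2, 0, True))   # 20
```

The helper name `pool_output_dim` is illustrative only; it is not part of the ttnn API.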

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – the memory configuration for the output tensor. Defaults to None.

  • applied_shard_scheme (ttnn.TensorMemoryLayout, optional) – the sharding scheme to apply to a non-pre-sharded input tensor. Defaults to None, which should be used with pre-sharded input tensors.

  • deallocate_input (bool, optional) – whether to deallocate the input tensor after the operation. Defaults to False.

  • reallocate_halo_output (bool, optional) – whether to reallocate the halo output tensor after the operation, ideally used with deallocate_input=True. Defaults to True.

  • dtype (ttnn.DataType, optional) – the data format for the output tensor. Defaults to ttnn.bfloat16.

  • output_layout (ttnn.Layout, optional) – the layout for the output tensor. Defaults to ttnn.ROW_MAJOR_LAYOUT.

  • compute_kernel_config (DeviceComputeKernelConfig, optional) – the device compute kernel configuration. Defaults to None.

Returns:

ttnn.Tensor – the output tensor after average pooling.
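To clarify the count_include_pad semantics, here is a plain-Python reference implementation of 2D average pooling over a single channel. This is a sketch of the standard pooling semantics described above (assumed to match ttnn's behavior), not the ttnn device implementation:

```python
def avg_pool2d_ref(x, kernel, stride, pad, count_include_pad=True):
    """Reference 2D average pooling over a [H][W] grid (single channel)."""
    kh, kw = kernel
    sh, sw = stride
    ph, pw = pad
    H, W = len(x), len(x[0])
    out_h = (H + 2 * ph - kh) // sh + 1
    out_w = (W + 2 * pw - kw) // sw + 1
    out = []
    for oh in range(out_h):
        row = []
        for ow in range(out_w):
            total, valid = 0.0, 0
            for dh in range(kh):
                for dw in range(kw):
                    h = oh * sh + dh - ph
                    w = ow * sw + dw - pw
                    if 0 <= h < H and 0 <= w < W:  # skip padding cells
                        total += x[h][w]
                        valid += 1
            # count_include_pad=True divides by the full window size,
            # counting the zero padding; False divides by the number of
            # real (non-padding) elements in the window.
            divisor = kh * kw if count_include_pad else valid
            row.append(total / divisor)
        out.append(row)
    return out

# 2x2 input, 2x2 kernel, stride 1, padding 1: the corner window covers
# one real element, so the two modes give different corner values.
x = [[1, 2], [3, 4]]
print(avg_pool2d_ref(x, (2, 2), (1, 1), (1, 1), True)[0][0])   # 0.25
print(avg_pool2d_ref(x, (2, 2), (1, 1), (1, 1), False)[0][0])  # 1.0
```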

Example

import torch
import ttnn
from loguru import logger

device = ttnn.CreateDevice(0, l1_small_size=8192)
# Define input parameters
kernel_h, kernel_w = 2, 2
stride_h, stride_w = 1, 1
pad_h, pad_w = 0, 0
nchw_shape = (4, 256, 40, 40)
in_N, in_C, in_H, in_W = nchw_shape
input_shape = (1, 1, in_N * in_H * in_W, in_C)

# Create a random input tensor
input = torch.randn(nchw_shape, dtype=torch.bfloat16)
input_perm = torch.permute(input, (0, 2, 3, 1))  # this op expects a [N, H, W, C] format
input_reshape = input_perm.reshape(input_shape)  # this op expects [1, 1, NHW, C]
tt_input = ttnn.from_torch(input_reshape, device=device)

# Perform average pooling
tt_output = ttnn.avg_pool2d(
    input_tensor_a=tt_input,
    batch_size=in_N,
    input_h=in_H,
    input_w=in_W,
    channels=in_C,
    kernel_size=[kernel_h, kernel_w],
    stride=[stride_h, stride_w],
    padding=[pad_h, pad_w],
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
    memory_config=None,
    applied_shard_scheme=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
    deallocate_input=False,
    reallocate_halo_output=True,
    dtype=ttnn.bfloat16,
    output_layout=ttnn.ROW_MAJOR_LAYOUT,
)
logger.info(f"Output: {tt_output}")
ttnn.close_device(device)