Basic Convolution

This tutorial shows how to run a 2D convolution layer with TTNN, a basic building block of neural network models.

We will create an example file, ttnn_basic_conv.py.

Import Libraries

[ ]:
import torch
import ttnn
from loguru import logger

Set Manual Seed

Setting a manual seed ensures that results are reproducible by initializing random number generators to a fixed state.

[ ]:
torch.manual_seed(0)
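
The effect of a fixed seed can be seen in a minimal sketch. It uses Python's built-in random module as a stand-in so that it runs without PyTorch; torch.manual_seed plays the same role for PyTorch's random number generators. The helper name draw is hypothetical.

```python
import random


def draw(seed: int, n: int = 3) -> list:
    """Return n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]


first = draw(0)
second = draw(0)  # same seed: the generator starts from the same state
other = draw(1)   # different seed: a different sequence

print(first == second)
print(first == other)
```

The first comparison is true because both generators start from the same state; the second is false because the seeds differ.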

Open the Device

Create the device that will run the program, with a custom L1 memory configuration. The l1_small_size parameter used here reserves on-chip L1 memory for sliding-window operations such as convolutions. 8 kB is enough for simple CNNs; use 32 kB or more for more complex models.

[ ]:
device = ttnn.open_device(device_id=0, l1_small_size=8192)
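
The 8192 passed above is the 8 kB figure expressed in bytes. A minimal sketch of that arithmetic, assuming binary kilobytes (1 kB = 1024 bytes, as the value 8192 suggests); the helper kib_to_bytes and the two constant names are hypothetical and not part of ttnn:

```python
def kib_to_bytes(kib: int) -> int:
    """Convert a size in binary kilobytes (1024 bytes each) to bytes."""
    if kib < 0:
        raise ValueError("size must be non-negative")
    return kib * 1024


SIMPLE_CNN_L1_SMALL = kib_to_bytes(8)     # the value used in this tutorial
LARGER_MODEL_L1_SMALL = kib_to_bytes(32)  # suggested floor for more complex models

print(SIMPLE_CNN_L1_SMALL, LARGER_MODEL_L1_SMALL)
```

Either value could then be passed as l1_small_size when opening the device.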

Create Forward Method

This function runs a convolution on an input tensor using already-initialized weight and bias tensors. The convolution also requires a configuration object, ttnn.Conv2dConfig.

[ ]:
def forward(
    input_tensor: ttnn.Tensor,
    weight_tensor: ttnn.Tensor,
    bias_tensor: ttnn.Tensor,
    out_channels: int,
    kernel_size: tuple,
    device: ttnn.Device,
) -> ttnn.Tensor:
    # Permute input from PyTorch BCHW (batch, channel, height, width)
    # to NHWC (batch, height, width, channel) which TTNN expects
    permuted_input = ttnn.permute(input_tensor, (0, 2, 3, 1))

    # Get shape after permutation
    B, H, W, C = permuted_input.shape

    # Reshape input to a flat image of shape (1, 1, B*H*W, C)
    # This flattens the spatial dimensions and prepares it for TTNN conv2d
    reshaped_input = ttnn.reshape(permuted_input, (1, 1, B * H * W, C))

    # Set up convolution configuration for TTNN conv2d
    conv_config = ttnn.Conv2dConfig(weights_dtype=weight_tensor.dtype)

    # Perform 2D convolution using TTNN
    out = ttnn.conv2d(
        input_tensor=reshaped_input,
        weight_tensor=weight_tensor,
        bias_tensor=bias_tensor,
        in_channels=C,
        out_channels=out_channels,
        device=device,
        kernel_size=kernel_size,
        stride=(1, 1),
        padding=(1, 1),
        batch_size=1,
        input_height=1,
        input_width=B * H * W,
        conv_config=conv_config,
        groups=0,  # No grouped convolution
    )

    # Optionally convert back to torch tensor: out_torch = ttnn.to_torch(out)
    return out
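
The permute and reshape steps in forward only rearrange how elements are indexed. The sketch below reproduces that rearrangement on nested Python lists so it can be inspected without a device. It is an illustration of the index mapping, not of how TTNN stores tensors, and the helper name nchw_to_flat_nhwc is hypothetical.

```python
def nchw_to_flat_nhwc(x):
    """Rearrange a nested list indexed [n][c][h][w] into rows of channel values.

    Returns N*H*W rows of C values each, i.e. the (1, 1, N*H*W, C) layout
    with the two leading singleton dimensions omitted.
    """
    n_dim, c_dim = len(x), len(x[0])
    h_dim, w_dim = len(x[0][0]), len(x[0][0][0])
    rows = []
    for n in range(n_dim):
        for h in range(h_dim):
            for w in range(w_dim):
                # One row per pixel, holding that pixel's value in every channel
                rows.append([x[n][c][h][w] for c in range(c_dim)])
    return rows


# 1 image, 3 channels, 2x2 pixels; each value encodes its (c, h, w) position
x = [[[[100 * c + 10 * h + w for w in range(2)] for h in range(2)] for c in range(3)]]
flat = nchw_to_flat_nhwc(x)

print(len(flat), len(flat[0]))  # 4 rows (pixels), 3 values (channels) per row
print(flat[1])                  # pixel (h=0, w=1) across channels: [1, 101, 201]
```

Row index (n * H + h) * W + w holds the pixel at (n, h, w), which is why the reshape in forward uses B * H * W as the third dimension.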

Set Input and Convolution Parameters

[ ]:
batch = 1
in_channels = 3
out_channels = 4
height = width = 2  # Small dimensions to avoid device memory issues
kernel_size = (3, 3)
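
With these values, the expected output shape can be worked out by hand. The sketch below assumes the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, and assumes the result is returned flattened as (1, 1, batch * out_h * out_w, out_channels), which is what the logged shape at the end of this tutorial suggests. The helper name conv_out_size is hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial dimension of a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


batch, height, width = 1, 2, 2
out_channels = 4
kernel_size = (3, 3)

# forward() passes batch_size=1, input_height=1 and input_width=B*H*W to conv2d
in_h, in_w = 1, batch * height * width

out_h = conv_out_size(in_h, kernel_size[0], stride=1, padding=1)
out_w = conv_out_size(in_w, kernel_size[1], stride=1, padding=1)

# Assumed flattened output layout: (1, 1, batch_size * out_h * out_w, out_channels)
flat_shape = (1, 1, 1 * out_h * out_w, out_channels)
print(flat_shape)  # (1, 1, 4, 4)
```

With a 3x3 kernel, stride 1, and padding 1, each spatial dimension keeps its size, so the 1 x 4 input yields a 1 x 4 output per channel.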

Create Tensors

To create the input tensor, weight tensor, and bias tensor for the convolution operation, see the Tensors Creation Guide.

[ ]:
# Create random input tensor in BCHW format
x = ttnn.rand((batch, in_channels, height, width), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Random weight tensor for convolution: (out_channels, in_channels, kH, kW)
w = ttnn.rand((out_channels, in_channels, *kernel_size), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Bias tensor, broadcastable to the output shape
b = ttnn.zeros((1, 1, 1, out_channels), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

Run the Convolution Operation

[ ]:
# Run the forward conv pass and log the output shape.
# forward() returns a ttnn.Tensor; call ttnn.to_torch(out) if a torch tensor is needed.
out = forward(x, w, b, out_channels, kernel_size, device)
logger.info(f"✅ Success! Conv2D output shape: {out.shape}")

Close the Device

[ ]:
ttnn.close_device(device)
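
If an exception is raised between opening and closing the device, the close call above is never reached. One possible way to guard against that is a small context manager, sketched here with stand-in open and close functions so that it runs without hardware. The names opened, fake_open, and fake_close are hypothetical; this is not a ttnn API.

```python
from contextlib import contextmanager


@contextmanager
def opened(open_fn, close_fn, **kwargs):
    """Open a resource, yield it, and always close it afterwards."""
    resource = open_fn(**kwargs)
    try:
        yield resource
    finally:
        close_fn(resource)


# Stand-ins for the real open/close calls, recording the order of events
events = []


def fake_open(device_id, l1_small_size):
    events.append("open")
    return {"id": device_id, "l1_small_size": l1_small_size}


def fake_close(device):
    events.append("close")


try:
    with opened(fake_open, fake_close, device_id=0, l1_small_size=8192) as device:
        raise RuntimeError("simulated failure during the forward pass")
except RuntimeError:
    pass

print(events)  # ['open', 'close']
```

In the tutorial script, the same helper could presumably be called as opened(ttnn.open_device, ttnn.close_device, device_id=0, l1_small_size=8192), with the forward pass inside the with block.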

Full Example and Output

Let's put everything together in a complete example that can be run directly.

ttnn_basic_conv.py

Running this script generates output like the following:

$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_basic_conv.py
2025-07-07 13:02:09.649 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.651 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.658 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:190)
2025-07-07 13:02:09.658 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.659 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.666 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.667 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.673 | info     |   SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)
2025-07-07 13:02:09.772 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:02:09.817 | info     |   SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)
2025-07-07 13:02:09.828 | info     |   SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)
2025-07-07 13:02:09.915 | info     |           Metal | AI CLK for device 0 is:   1000 MHz (metal_context.cpp:128)
2025-07-07 13:02:10.487 | info     |           Metal | Initializing device 0. Program cache is enabled (device.cpp:428)
2025-07-07 13:02:10.489 | warning  |           Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)
2025-07-07 13:02:13.921 | warning  |              Op | conv2d: Device weights not properly prepared, pulling back to host and trying to reprocess. (conv2d.cpp:563)
2025-07-07 13:02:13.922 | warning  |              Op | conv2d: Device bias not properly prepared, pulling back to host and reprocessing. (conv2d.cpp:582)
2025-07-07 13:02:15.390 | INFO     | __main__:main:78 - ✅ Success! Conv2D output shape: Shape([1, 1, 4, 4])
2025-07-07 13:02:15.390 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)
2025-07-07 13:02:15.391 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)
2025-07-07 13:02:15.391 | info     |           Metal | Closing device 0 (device.cpp:468)
2025-07-07 13:02:15.391 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:783)
2025-07-07 13:02:15.392 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)