Basic Tensor Operations
This example demonstrates how to create tensors and perform basic arithmetic operations with TT-NN, a high-level Python API. The operations covered are element-wise addition, element-wise multiplication, matrix multiplication, and a simulated broadcast of a row vector across a tile.
Let’s create the example file, ttnn_basic_operations.py.
Import Libraries
[ ]:
import torch
import numpy as np
import ttnn
from loguru import logger
Open the Device
Create a device to run our program.
[ ]:
# Open the Device
device = ttnn.open_device(device_id=0)
Host Tensor Creation
Create a test tensor with different values that can demonstrate various operations. Learn more about Tensors here.
[ ]:
logger.info("\n--- TT-NN Tensor Creation with Tiles (32x32) ---")
host_rand = torch.rand((32, 32), dtype=torch.float32)
Host Tensor Conversion and Creation
Tensix cores operate most efficiently on tiled data, performing parallel computations. The host tensor is converted to a TT-NN tiled tensor or is created natively on the device.
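To make the idea of tiled data more concrete, the following host-only NumPy sketch splits a matrix into 32x32 blocks. It only illustrates the logical grouping of elements into tiles; it is not a description of how TT-NN arranges tiles in device memory. The helper name split_into_tiles is hypothetical and introduced here for illustration.

```python
import numpy as np

TILE = 32  # tile edge length used throughout this tutorial


def split_into_tiles(matrix, tile=TILE):
    """Return a view of `matrix` with shape (tile_rows, tile_cols, tile, tile)."""
    rows, cols = matrix.shape
    if rows % tile or cols % tile:
        raise ValueError("both dimensions must be multiples of the tile size")
    # Index order after reshape is [tile_row, row_in_tile, tile_col, col_in_tile];
    # swapping the middle axes groups each tile's elements together.
    return matrix.reshape(rows // tile, tile, cols // tile, tile).swapaxes(1, 2)


matrix = np.arange(64 * 96, dtype=np.float32).reshape(64, 96)
tiles = split_into_tiles(matrix)
print(tiles.shape)  # a 2 x 3 grid of 32x32 tiles
```

A 32x32 tensor such as host_rand corresponds to exactly one such tile, which is why the tutorial uses that shape.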
To convert PyTorch host tensors to TT-NN tiled tensors, use the helper function to_tt_tile(). The cell below defines it and uses it to create a device tensor from the host_rand PyTorch tensor.
[ ]:
# Helper to create a TT-NN tensor from torch with TILE_LAYOUT and bfloat16
def to_tt_tile(torch_tensor):
    return ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

tt_t1 = to_tt_tile(host_rand)
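Because the helper requests bfloat16, each value keeps only 8 bits of significand precision (7 stored fraction bits), which is roughly two to three significant decimal digits. The sketch below emulates that rounding on the host with NumPy bit manipulation. The helper round_to_bfloat16 is hypothetical, and round-to-nearest-even is an assumption made here, not a statement about how TT-NN converts values.

```python
import numpy as np


def round_to_bfloat16(values):
    """Round float32 values to the nearest bfloat16, returned as float32."""
    bits = np.ascontiguousarray(values, dtype=np.float32).view(np.uint32)
    # Round to nearest, ties to even, on the 16 low bits that bfloat16 drops.
    # Assumption: finite inputs of moderate size (no NaN/inf/overflow handling).
    bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    rounded = (bits + bias) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


samples = np.array([0.1, 0.8242, 1.828125, 3.14159], dtype=np.float32)
as_bf16 = round_to_bfloat16(samples)
relative_error = np.abs(as_bf16 - samples) / samples

for original, rounded, err in zip(samples, as_bf16, relative_error):
    print(f"{original:.7f} -> {rounded:.7f} (relative error {err:.2e})")
```

Under this rounding the relative error stays at or below 2**-8, so small differences between device results and float32 host results are expected.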
Alternatively, we can create and initialize tensors directly on the device using TT-NN’s tensor creation functions. Creating tensors directly on the device is more efficient because it avoids the overhead of transferring data from the host to the device.
[ ]:
tt_t2 = ttnn.full(
    shape=(32, 32),
    fill_value=1.0,
    dtype=ttnn.float32,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
tt_t3 = ttnn.zeros(
    shape=(32, 32),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
tt_t4 = ttnn.ones(
    shape=(32, 32),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
t5 = np.array([[5, 6], [7, 8]], dtype=np.float32).repeat(16, axis=0).repeat(16, axis=1)
tt_t5 = ttnn.Tensor(t5, device=device, layout=ttnn.TILE_LAYOUT)
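The NumPy expression used for t5 is compact, so a host-only sketch of what it produces may help. Each entry of the 2x2 seed matrix is expanded into a 16x16 block, giving a 32x32 array made of four constant quadrants. The names seed and quadrants are introduced here for illustration only.

```python
import numpy as np

seed = np.array([[5, 6], [7, 8]], dtype=np.float32)
# Repeat every row 16 times, then every column 16 times.
t5 = seed.repeat(16, axis=0).repeat(16, axis=1)

quadrants = {
    "top-left": t5[:16, :16],
    "top-right": t5[:16, 16:],
    "bottom-left": t5[16:, :16],
    "bottom-right": t5[16:, 16:],
}
for name, block in quadrants.items():
    # Each quadrant holds a single distinct value from the seed matrix.
    print(name, block.shape, np.unique(block))
```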
Tile-Based Arithmetic Operations
The following cell applies element-wise addition, element-wise multiplication, and matrix multiplication to the tiled tensors created above:
[ ]:
logger.info("\n--- TT-NN Tensor Operations on (32x32) Tiles ---")
add_result = ttnn.add(tt_t1, tt_t4)
logger.info(f"Addition:\n{add_result}")
mul_result = ttnn.mul(tt_t1, tt_t5)
logger.info(f"Element-wise Multiplication:\n{mul_result}")
matmul_result = ttnn.matmul(tt_t4, tt_t1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
logger.info(f"Matrix Multiplication:\n{matmul_result}")
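Since tt_t4 is filled with ones, the matrix product in the cell above has a simple host-side interpretation: every row of the result equals the column sums of the other operand. The NumPy sketch below checks that property with a stand-in random matrix. It does not touch the device, the seed is arbitrary, and the tolerance is a loose assumption suited to float32.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((32, 32), dtype=np.float32)  # stand-in for host_rand
ones = np.ones((32, 32), dtype=np.float32)  # stand-in for tt_t4

product = ones @ x
column_sums = x.sum(axis=0)

# Every row of ones @ x should match the column sums of x.
rows_match = np.allclose(product, np.tile(column_sums, (32, 1)), rtol=1e-4)
print("rows equal column sums:", rows_match)
```

This would also account for the identical rows in the printed matrix multiplication result.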
Simulated Broadcasting - Row Vector Expansion
Let’s simulate broadcasting a row vector across a tile. The row vector is repeated 32 times to fill the tile, so every element of a given column contains the same value. This is useful for operations that require expanding a smaller tensor to match the dimensions of a larger one.
[ ]:
logger.info("\n--- Simulated Broadcasting (32x32 + Broadcasted Row Vector) ---")
broadcast_vector = torch.tensor(np.arange(0, 32), dtype=torch.float32).repeat(32, 1)
logger.info(f"Broadcast Row Vector:\n{broadcast_vector}")
broadcast_tt = to_tt_tile(broadcast_vector)
broadcast_add_result = ttnn.add(tt_t1, broadcast_tt)
logger.info(f"Broadcast Add Result (TT-NN):\n{ttnn.to_torch(broadcast_add_result)}")
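The cell above expands the row vector explicitly before adding it. For comparison, this host-only NumPy sketch shows that the explicit expansion gives the same result as NumPy's native broadcasting of a shape (32,) vector against a (32, 32) matrix. np.tile is used here as the NumPy counterpart of the repeat(32, 1) call on the torch tensor; whether a given TT-NN version broadcasts implicitly is not addressed here.

```python
import numpy as np

rng = np.random.default_rng(0)
base = rng.random((32, 32), dtype=np.float32)  # stand-in for host_rand
row = np.arange(32, dtype=np.float32)          # the row vector 0..31

# Explicit expansion: copy the row vector into each of the 32 rows.
expanded = np.tile(row, (32, 1))
explicit_sum = base + expanded

# Native broadcasting: NumPy stretches the (32,) vector across rows itself.
broadcast_sum = base + row

print("same result:", np.array_equal(explicit_sum, broadcast_sum))
```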
Close the Device
[ ]:
ttnn.close_device(device)
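If an operation raises before the close call is reached, the device would stay open. One possible way to guard against that is a small context manager, sketched below with stand-in open and close functions so that it runs without hardware. The name managed_device and the fake functions are hypothetical; with TT-NN, the open and close arguments would presumably wrap ttnn.open_device and ttnn.close_device as used above.

```python
from contextlib import contextmanager


@contextmanager
def managed_device(open_fn, close_fn):
    """Open a device, hand it to the caller, and always close it afterwards."""
    device = open_fn()
    try:
        yield device
    finally:
        close_fn(device)


events = []


def fake_open():
    # Stand-in for something like: ttnn.open_device(device_id=0)
    events.append("open")
    return "fake-device"


def fake_close(device):
    # Stand-in for: ttnn.close_device(device)
    events.append(f"close {device}")


try:
    with managed_device(fake_open, fake_close) as device:
        raise RuntimeError("simulated failure during an operation")
except RuntimeError:
    events.append("error handled")

print(events)
```

The close step runs before the error reaches the caller, so the device is released even when an operation fails.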
Full Example and Output
Let’s put everything together in a complete example that can be run directly.
Running this script will generate the following output:
$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_basic_operations.py
2025-07-07 13:13:04.850 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.852 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.859 | info | Device | Opening user mode device driver (tt_cluster.cpp:190)
2025-07-07 13:13:04.859 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.860 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.866 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.867 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:04.873 | info | SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)
2025-07-07 13:13:04.970 | info | SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:13:05.015 | info | SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)
2025-07-07 13:13:05.025 | info | SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)
2025-07-07 13:13:05.111 | info | Metal | AI CLK for device 0 is: 1000 MHz (metal_context.cpp:128)
2025-07-07 13:13:05.678 | info | Metal | Initializing device 0. Program cache is enabled (device.cpp:428)
2025-07-07 13:13:05.680 | warning | Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)
2025-07-07 13:13:07.537 | INFO | __main__:main:15 -
--- TT-NN Tensor Creation with Tiles (32x32) ---
2025-07-07 13:13:07.564 | INFO | __main__:main:47 -
--- TT-NN Tensor Operations on (32x32) Tiles ---
2025-07-07 13:13:08.072 | INFO | __main__:main:49 - Addition:
ttnn.Tensor([[ 1.82812, 1.04688, ..., 1.32812, 1.00781],
[ 1.39844, 1.03906, ..., 1.14844, 1.24219],
...,
[ 1.65625, 1.32812, ..., 1.31250, 1.21094],
[ 1.21875, 1.33594, ..., 1.37500, 1.62500]], shape=Shape([32, 32]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
2025-07-07 13:13:13.670 | INFO | __main__:main:52 - Element-wise Multiplication:
ttnn.Tensor([[ 4.12500, 0.23438, ..., 1.96875, 0.02600],
[ 1.97656, 0.18164, ..., 0.87891, 1.44531],
...,
[ 4.59375, 2.31250, ..., 2.48438, 1.65625],
[ 1.50781, 2.35938, ..., 2.96875, 4.96875]], shape=Shape([32, 32]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
2025-07-07 13:13:14.229 | INFO | __main__:main:55 - Matrix Multiplication:
ttnn.Tensor([[16.50000, 14.25000, ..., 15.56250, 14.43750],
[16.50000, 14.25000, ..., 15.56250, 14.43750],
...,
[16.50000, 14.25000, ..., 15.56250, 14.43750],
[16.50000, 14.25000, ..., 15.56250, 14.43750]], shape=Shape([32, 32]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
2025-07-07 13:13:14.229 | INFO | __main__:main:57 -
--- Simulated Broadcasting (32x32 + Broadcasted Row Vector) ---
2025-07-07 13:13:14.231 | INFO | __main__:main:59 - Broadcast Row Vector:
tensor([[ 0., 1., 2., ..., 29., 30., 31.],
[ 0., 1., 2., ..., 29., 30., 31.],
[ 0., 1., 2., ..., 29., 30., 31.],
...,
[ 0., 1., 2., ..., 29., 30., 31.],
[ 0., 1., 2., ..., 29., 30., 31.],
[ 0., 1., 2., ..., 29., 30., 31.]])
2025-07-07 13:13:14.233 | INFO | __main__:main:63 - Broadcast Add Result (TT-NN):
tensor([[ 0.8242, 1.0469, 2.2500, ..., 29.0000, 30.3750, 31.0000],
[ 0.3945, 1.0391, 2.5625, ..., 29.1250, 30.1250, 31.2500],
[ 0.2188, 1.8750, 2.4375, ..., 29.7500, 30.8750, 31.6250],
...,
[ 0.7422, 1.1484, 2.9531, ..., 29.1250, 30.5000, 31.1250],
[ 0.6562, 1.3281, 2.0938, ..., 29.3750, 30.3750, 31.2500],
[ 0.2158, 1.3359, 2.8438, ..., 29.2500, 30.3750, 31.6250]],
dtype=torch.bfloat16)
2025-07-07 13:13:14.233 | info | Metal | Closing mesh device 1 (mesh_device.cpp:488)
2025-07-07 13:13:14.234 | info | Metal | Closing mesh device 0 (mesh_device.cpp:488)
2025-07-07 13:13:14.234 | info | Metal | Closing device 0 (device.cpp:468)
2025-07-07 13:13:14.234 | info | Metal | Disabling and clearing program cache on device 0 (device.cpp:783)
2025-07-07 13:13:14.234 | info | Metal | Closing mesh device 1 (mesh_device.cpp:488)