Tensor and Add Operation

ttnn.Tensor is the central type of ttnn.

It is similar to torch.Tensor in the sense that it represents a multi-dimensional matrix containing elements of a single data type.

There are a few key differences:

  • ttnn.Tensor can be stored in the SRAM or DRAM of Tenstorrent devices

  • ttnn.Tensor doesn’t have a concept of strides; instead, it has a concept of row-major and tile layouts

  • ttnn.Tensor supports data types not available in torch, such as bfp8

  • ttnn.Tensor’s shape stores the padding added to the tensor due to TILE_LAYOUT

Creating a tensor

The recommended way to create a tensor is to use a torch creation function and then call ttnn.from_torch. So, let’s import both torch and ttnn

[1]:
import torch
import ttnn
2024-07-11 18:12:48.818 | DEBUG    | ttnn:<module>:136 - Initial ttnn.CONFIG:
{'cache_path': PosixPath('/home/ubuntu/.cache/ttnn'),
 'comparison_mode_pcc': 0.9999,
 'enable_comparison_mode': False,
 'enable_detailed_buffer_report': False,
 'enable_detailed_tensor_report': False,
 'enable_fast_runtime_mode': True,
 'enable_graph_report': False,
 'enable_logging': False,
 'enable_model_cache': False,
 'model_cache_path': PosixPath('/home/ubuntu/.cache/ttnn/models'),
 'report_name': None,
 'root_report_path': PosixPath('generated/ttnn/reports'),
 'throw_exception_on_fallback': False,
 'tmp_dir': PosixPath('/tmp/ttnn')}
2024-07-11 18:12:48.905 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.logical_xor be migrated to C++?
2024-07-11 18:12:48.906 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.xlogy be migrated to C++?
2024-07-11 18:12:48.906 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.maximum be migrated to C++?
2024-07-11 18:12:48.907 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.minimum be migrated to C++?
2024-07-11 18:12:48.908 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.atan2 be migrated to C++?
2024-07-11 18:12:48.909 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.hypot be migrated to C++?
2024-07-11 18:12:48.910 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.nextafter be migrated to C++?
2024-07-11 18:12:48.911 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.polyval be migrated to C++?
2024-07-11 18:12:48.911 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.isclose be migrated to C++?
2024-07-11 18:12:48.914 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.all_gather be migrated to C++?
2024-07-11 18:12:48.915 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.pearson_correlation_coefficient be migrated to C++?
2024-07-11 18:12:48.919 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.conv2d be migrated to C++?
2024-07-11 18:12:48.920 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.reshape be migrated to C++?
2024-07-11 18:12:48.921 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.unsqueeze_to_4D be migrated to C++?
2024-07-11 18:12:48.922 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.squeeze be migrated to C++?
2024-07-11 18:12:48.923 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.from_torch be migrated to C++?
2024-07-11 18:12:48.923 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.to_torch be migrated to C++?
2024-07-11 18:12:48.924 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.to_device be migrated to C++?
2024-07-11 18:12:48.925 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.from_device be migrated to C++?
2024-07-11 18:12:48.926 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.allocate_tensor_on_device be migrated to C++?
2024-07-11 18:12:48.926 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.copy_host_to_device_tensor be migrated to C++?
2024-07-11 18:12:48.927 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.deallocate be migrated to C++?
2024-07-11 18:12:48.928 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.clone be migrated to C++?
2024-07-11 18:12:48.929 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.reallocate be migrated to C++?
2024-07-11 18:12:48.929 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.load_tensor be migrated to C++?
2024-07-11 18:12:48.930 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.dump_tensor be migrated to C++?
2024-07-11 18:12:48.931 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.as_tensor be migrated to C++?
2024-07-11 18:12:48.934 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.arange be migrated to C++?
2024-07-11 18:12:48.935 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.mse_loss be migrated to C++?
2024-07-11 18:12:48.936 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.l1_loss be migrated to C++?
2024-07-11 18:12:48.937 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.matmul be migrated to C++?
2024-07-11 18:12:48.938 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.linear be migrated to C++?
2024-07-11 18:12:48.941 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.mac be migrated to C++?
2024-07-11 18:12:48.942 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.addcmul be migrated to C++?
2024-07-11 18:12:48.942 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.addcdiv be migrated to C++?
2024-07-11 18:12:48.943 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.lerp be migrated to C++?
2024-07-11 18:12:48.948 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.logit be migrated to C++?
2024-07-11 18:12:48.949 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.polygamma be migrated to C++?
2024-07-11 18:12:48.950 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.hardshrink be migrated to C++?
2024-07-11 18:12:48.950 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.celu be migrated to C++?
2024-07-11 18:12:48.951 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.softshrink be migrated to C++?
2024-07-11 18:12:48.952 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.clip be migrated to C++?
2024-07-11 18:12:48.952 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.threshold be migrated to C++?
2024-07-11 18:12:48.953 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.glu be migrated to C++?
2024-07-11 18:12:48.954 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.reglu be migrated to C++?
2024-07-11 18:12:48.955 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.swiglu be migrated to C++?
2024-07-11 18:12:48.955 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.geglu be migrated to C++?
2024-07-11 18:12:48.958 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.matmul be migrated to C++?
2024-07-11 18:12:48.959 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.linear be migrated to C++?
2024-07-11 18:12:48.960 | WARNING  | ttnn.decorators:operation_decorator:758 - Should ttnn.conv2d be migrated to C++?
[2]:
import os

And now let’s create a torch Tensor and convert it to a ttnn Tensor

[3]:
torch_tensor = torch.rand(3, 4)
ttnn_tensor = ttnn.from_torch(torch_tensor)

print(f"shape: {ttnn_tensor.shape}")
print(f"layout: {ttnn_tensor.layout}")
print(f"dtype: {ttnn_tensor.dtype}")
shape: ttnn.Shape([3, 4])
layout: Layout.ROW_MAJOR
dtype: DataType.FLOAT32

As expected, we get a tensor of shape [3, 4] in row-major layout with a data type of float32.

Host Storage: Borrowed vs Owned

In this particular case, the ttnn Tensor borrows the data of the torch Tensor because the ttnn Tensor is in row-major layout, the torch tensor is contiguous, and their data types match.

Let’s print the current ttnn Tensor, set every element of the torch tensor to 1234, and print the ttnn Tensor again to see borrowed storage in action

[4]:
print(f"Original values:\n{ttnn_tensor}")
torch_tensor[:] = 1234
print(f"New values are all going to be 1234:\n{ttnn_tensor}")
Original values:
ttnn.Tensor([[ 0.98300,  0.11301,  ...,  0.37592,  0.64318],
             [ 0.53437,  0.59434,  ...,  0.69190,  0.04268],
             [ 0.33346,  0.20231,  ...,  0.15127,  0.58303]], shape=Shape([3, 4]), dtype=DataType::FLOAT32, layout=Layout::ROW_MAJOR)
New values are all going to be 1234:
ttnn.Tensor([[1234.00000, 1234.00000,  ..., 1234.00000, 1234.00000],
             [1234.00000, 1234.00000,  ..., 1234.00000, 1234.00000],
             [1234.00000, 1234.00000,  ..., 1234.00000, 1234.00000]], shape=Shape([3, 4]), dtype=DataType::FLOAT32, layout=Layout::ROW_MAJOR)

We try our best to use borrowed storage, but if the torch data type is not supported in ttnn, then we have no choice but to automatically pick a different data type and copy the data

[5]:
torch_tensor = torch.rand(3, 4).to(torch.float16)
ttnn_tensor = ttnn.from_torch(torch_tensor)
print("torch_tensor.dtype:", torch_tensor.dtype)
print("ttnn_tensor.dtype:", ttnn_tensor.dtype)

print(f"Original values:\n{ttnn_tensor}")
torch_tensor[0, 0] = 1234
#print(f"Original values again because the tensor doesn't use borrowed storage:\n{ttnn_tensor}")
torch_tensor.dtype: torch.float16
ttnn_tensor.dtype: DataType.BFLOAT16
Original values:
ttnn.Tensor([[ 0.80078,  0.69531,  ...,  0.71484,  0.33398],
             [ 0.60156,  0.36523,  ...,  0.73047,  0.90625],
             [ 0.59766,  0.83203,  ...,  0.61719,  0.53516]], shape=Shape([3, 4]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

Data Type

The data type of the ttnn tensor can be controlled explicitly when converting from torch.

[6]:
torch_tensor = torch.rand(3, 4).to(torch.float32)
ttnn_tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16)
print(f"torch_tensor.dtype: {torch_tensor.dtype}")
print(f"ttnn_tensor.dtype: {ttnn_tensor.dtype}")
torch_tensor.dtype: torch.float32
ttnn_tensor.dtype: DataType.BFLOAT16
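
ttnn also offers block floating-point data types such as bfp8, which have no torch equivalent. A minimal sketch, assuming ttnn.bfloat8_b is available in your build (block-float tensors generally require tile layout):

# Hedged sketch: request the block floating-point bfloat8_b data type during conversion.
# Assumes ttnn.bfloat8_b exists in this build; block-float data generally requires TILE_LAYOUT.
torch_tensor = torch.rand(3, 4, dtype=torch.float32)
ttnn_bfp8_tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT)
print(f"ttnn_bfp8_tensor.dtype: {ttnn_bfp8_tensor.dtype}")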

Layout

Tenstorrent hardware is most efficiently utilized when tensors use the tile layout. The current tile size is hard-coded to [32, 32], which was determined to be optimal given the compute, memory, and data transfer constraints.

ttnn provides an easy and intuitive way to convert from row-major layout to tile layout and back.

[7]:
torch_tensor = torch.rand(3, 4).to(torch.float16)
ttnn_tensor = ttnn.from_torch(torch_tensor)
print(f"Tensor in row-major layout:\nShape {ttnn_tensor.shape}\nLayout: {ttnn_tensor.layout}\n{ttnn_tensor}")
ttnn_tensor = ttnn.to_layout(ttnn_tensor, ttnn.TILE_LAYOUT)
print(f"Tensor in tile layout:\nShape {ttnn_tensor.shape}\nLayout: {ttnn_tensor.layout}\n{ttnn_tensor}")
ttnn_tensor = ttnn.to_layout(ttnn_tensor, ttnn.ROW_MAJOR_LAYOUT)
print(f"Tensor back in row-major layout:\nShape {ttnn_tensor.shape}\nLayout: {ttnn_tensor.layout}\n{ttnn_tensor}")
Tensor in row-major layout:
Shape ttnn.Shape([3, 4])
Layout: Layout.ROW_MAJOR
ttnn.Tensor([[ 0.21680,  0.24316,  ...,  0.19336,  0.40625],
             [ 0.81641,  0.50781,  ...,  0.09961,  0.54688],
             [ 0.70703,  0.93359,  ...,  0.06787,  0.75781]], shape=Shape([3, 4]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)
Tensor in tile layout:
Shape ttnn.Shape([3[32], 4[32]])
Layout: Layout.TILE
ttnn.Tensor([[ 0.21680,  0.24316,  ...,  0.00000,  0.00000],
             [ 0.70703,  0.93359,  ...,  0.00000,  0.00000],
             ...,
             [ 0.00000,  0.00000,  ...,  0.00000,  0.00000],
             [ 0.00000,  0.00000,  ...,  0.00000,  0.00000]], shape=Shape([3[32], 4[32]]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
Tensor back in row-major layout:
Shape ttnn.Shape([3, 4])
Layout: Layout.ROW_MAJOR
ttnn.Tensor([[ 0.21680,  0.24316,  ...,  0.19336,  0.40625],
             [ 0.81641,  0.50781,  ...,  0.09961,  0.54688],
             [ 0.70703,  0.93359,  ...,  0.06787,  0.75781]], shape=Shape([3, 4]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

Note that padding is automatically inserted to put the tensor into tile layout, and it is automatically removed after the tensor is converted back to row-major layout.
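
To make the padding concrete, each of the last two dimensions is rounded up to the next multiple of the [32, 32] tile size. A small illustrative sketch (plain Python, not a ttnn API):

# Illustrative sketch: the rounding rule implied by the [32, 32] tile size. Not a ttnn API.
TILE_SIZE = 32

def padded_extent(dim: int) -> int:
    # Round dim up to the next multiple of TILE_SIZE
    return ((dim + TILE_SIZE - 1) // TILE_SIZE) * TILE_SIZE

print(padded_extent(3), padded_extent(4))  # 32 32, matching the Shape([3[32], 4[32]]) printed above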

The conversion to tile layout can also be done when calling ttnn.from_torch by passing the layout argument

[8]:
torch_tensor = torch.rand(3, 4).to(torch.float16)
ttnn_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT)
print(f"Tensor in tile layout:\nShape {ttnn_tensor.shape}; Layout: {ttnn_tensor.layout}")
Tensor in tile layout:
Shape ttnn.Shape([3[32], 4[32]]); Layout: Layout.TILE

Note that ttnn.to_torch will always convert to row-major layout.
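
For example, converting the tile-layout tensor from the previous cell back to torch should return a row-major tensor with the tile padding removed (a sketch under that assumption):

# Hedged sketch: ttnn.to_torch is expected to return a row-major torch tensor without the tile padding.
torch_output = ttnn.to_torch(ttnn_tensor)
print(torch_output.shape)  # expected: torch.Size([3, 4])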

Device storage

Finally, in order to actually utilize the tensor, we need to put it on the device so that we can run ttnn operations on it.

Open the device

Use ttnn.open_device to get a handle to the device

[9]:
device_id = 0
device = ttnn.open_device(device_id=device_id)
                 Device | INFO     | Opening user mode device driver

2024-07-11 18:12:49.027 | INFO     | SiliconDriver   - Detected 1 PCI device : {0}
2024-07-11 18:12:49.040 | WARNING  | SiliconDriver   - init_detect_tt_device_numanodes(): Could not determine NumaNodeSet for TT device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:12:49.040 | WARNING  | SiliconDriver   - Could not find NumaNodeSet for TT Device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:12:49.041 | WARNING  | SiliconDriver   - bind_area_memory_nodeset(): Unable to determine TT Device to NumaNode mapping for physical_device_id: 0. Skipping membind.
---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
2024-07-11 18:12:49.082 | INFO     | SiliconDriver   - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 0)
                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   800 MHz

Initialize tensors a and b with random values using torch

To create a tensor that can be used by a ttnn operation:

  1. Create a tensor using torch

  2. Use ttnn.from_torch to convert the tensor from torch.Tensor to ttnn.Tensor, change the layout to ttnn.TILE_LAYOUT, and put the tensor on the device

[10]:
torch.manual_seed(0)

torch_input_tensor_a = torch.rand((32, 32), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((32, 32), dtype=torch.bfloat16)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

Add tensor a and b

ttnn supports operator overloading, so the + operator can be used instead of calling ttnn.add explicitly

[11]:
output_tensor = input_tensor_a + input_tensor_b
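
The same result can be obtained by calling ttnn.add explicitly; a minimal equivalent sketch using the tensors created above:

# Equivalent to the + operator above: call ttnn.add explicitly.
output_tensor = ttnn.add(input_tensor_a, input_tensor_b)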

Inspect the output tensor of the add in ttnn

As can be seen, a tensor of the same shape, layout, and dtype is produced

[12]:
print(f"shape: {output_tensor.shape}")
print(f"dtype: {output_tensor.dtype}")
print(f"layout: {output_tensor.layout}")
shape: ttnn.Shape([32, 32])
dtype: DataType.BFLOAT16
layout: Layout.TILE

In general, we expect the layout and dtype to stay the same when running most operations, unless explicit arguments to modify them are passed in. However, there are obvious exceptions, such as an embedding operation that takes in ttnn.uint32 indices and produces ttnn.bfloat16 activations.
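
As an illustration of such an exception, an embedding lookup takes integer indices and produces a floating-point tensor. A hedged sketch (the exact ttnn.embedding signature and layout requirements may differ between versions):

# Hedged sketch of the embedding exception: uint32 indices in, bfloat16 activations out.
# The exact ttnn.embedding signature may differ between ttnn versions.
torch_indices = torch.randint(0, 1024, (1, 32), dtype=torch.int32)
torch_weights = torch.rand(1024, 128, dtype=torch.bfloat16)
indices = ttnn.from_torch(torch_indices, dtype=ttnn.uint32, device=device)
weights = ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, device=device)
embeddings = ttnn.embedding(indices, weights)  # expected dtype: DataType.BFLOAT16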

Convert to torch and inspect the attributes of the torch tensor

When converting the tensor to torch, ttnn.to_torch will move the tensor from the device, convert it to row-major layout, and figure out the best data type to use on the torch side

[13]:
output_tensor = ttnn.to_torch(output_tensor)
print(f"shape: {output_tensor.shape}")
print(f"dtype: {output_tensor.dtype}")
shape: torch.Size([32, 32])
dtype: torch.bfloat16

Close the device

Close the handle to the device. This is a very important step, as the device can currently hang if not closed properly

[14]:
ttnn.close_device(device)
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
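
A common safeguard (a sketch, not part of the original notebook) is to wrap device work in try/finally so that ttnn.close_device always runs, even if an operation raises:

# Sketch: guard device usage so the device is always closed, even on error.
device = ttnn.open_device(device_id=0)
try:
    a = ttnn.from_torch(torch.rand((32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.rand((32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    result = ttnn.to_torch(a + b)
finally:
    ttnn.close_device(device)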