Tensor

Overview

The TT Tensor library provides support for creation and manipulation of TT Tensors.

This library is used by TT-Dispatch to represent tensors that can be sent to and received from the TT-Metal platform. Operations in the ttDNN library also use this library: operations take TT Tensors as inputs and return TT Tensors as outputs.

This library only supports tensors of rank 4.
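
Lower-rank data must therefore be reshaped to rank 4 before a TT Tensor is constructed. A minimal sketch (assuming torch is imported):

py_mat = torch.randn((32, 32))           # rank 2
py_rank4 = py_mat.reshape(1, 1, 32, 32)  # rank 4: [W, Z, Y, X] with W = Z = 1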

The TT Tensor library supports different memory layouts for the data stored within a tensor.

ROW_MAJOR layout stores values in memory row by row, starting from the last dimension of the tensor. For a tensor of shape [W, Z, Y, X] to be stored in ROW_MAJOR order on a TT Accelerator device, X must be divisible by 2. A tensor in ROW_MAJOR order with X not divisible by 2 can exist on the host machine, but it cannot be sent to a TT Accelerator device. This means you can neither pass a TT Accelerator device to the TT Tensor constructor for such a tensor nor use ttnn.Tensor.to() to send it to a TT Accelerator device.
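
As a hedged illustration of this restriction (tt_device here is assumed to be an open TT accelerator device): a ROW_MAJOR tensor whose last dimension is odd can be constructed on host, but moving it to device is expected to fail.

py_tensor = torch.randn((1, 1, 4, 3))  # X = 3, not divisible by 2
host_only = ttnn.Tensor(
    py_tensor.reshape(-1).tolist(),
    list(py_tensor.size()),
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)  # valid on host
# host_only.to(tt_device) would be rejected; pad X to an even size first.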

TILE layout stores values in memory tile by tile, starting from the last two dimensions of the tensor. A tile is a (32, 32)-shaped subsection of the tensor. Tiles are stored in memory in row-major order, and values inside each tile are also stored in row-major order. A TT Tensor of shape [W, Z, Y, X] can have TILE layout only if both X and Y are divisible by 32.

#Tensor of shape (2, 64, 64)

#batch=0
[    0,    1,    2, ...,   63,
    64,   65,   66, ...,  127,
    ...
  3968, 3969, 3970, ..., 4031,
  4032, 4033, 4034, ..., 4095 ]

#batch=1
[ 4096, 4097, 4098, ..., 4159,
  4160, 4161, 4162, ..., 4223,
    ...
  8064, 8065, 8066, ..., 8127,
  8128, 8129, 8130, ..., 8191 ]


#Stored in ROW_MAJOR layout
[0, 1, 2, ..., 63, 64, ..., 4095, 4096, 4097, 4098, ..., 4159, 4160, ..., 8191]

#Stored in TILE layout
[  0,    1, ...,   31,   64,   65, ...,   95, ..., 1984, 1985, ..., 2015, # first tile of batch=0
  32,   33, ...,   63,   96,   97, ...,  127, ..., 2016, 2017, ..., 2047, # second tile of batch=0
...
2080, 2081, ..., 2111, 2144, 2145, ..., 2175, ..., 4064, 4065, ..., 4095, # fourth (last) tile of batch=0

4096, ..., 6111,                                                           # first tile of batch=1
...
6176, ..., 8191 ]                                                          # fourth (last) tile of batch=1
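
The ordering in the listing above can be reproduced with a short NumPy sketch (illustrative only, not a ttnn API): tiles are visited in row-major order, and the values inside each tile are emitted in row-major order.

import numpy as np

def to_tile_order(mat, tile=32):
    # Walk (tile x tile) blocks in row-major order; flatten each block row-major.
    h, w = mat.shape
    assert h % tile == 0 and w % tile == 0
    out = []
    for ty in range(0, h, tile):
        for tx in range(0, w, tile):
            out.append(mat[ty:ty + tile, tx:tx + tile].reshape(-1))
    return np.concatenate(out)

batch0 = np.arange(64 * 64).reshape(64, 64)
tiled = to_tile_order(batch0)
print(tiled[:3])         # [0 1 2]    -> start of the first tile
print(tiled[1024:1027])  # [32 33 34] -> start of the second tile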

Tensor Storage

The Tensor class has three types of storage: OwnedStorage, BorrowedStorage, and DeviceStorage, and it provides a constructor for each type.

OwnedStorage is used to store data in host DRAM. Each data type is stored in a vector corresponding to that data type, and the vector itself is held in a shared pointer. This way, if the Tensor object is copied, the underlying storage is simply reference counted rather than copied as well.

BorrowedStorage is used to borrow buffers from torch, numpy, etc.

DeviceStorage is used to store data in device DRAM or device L1. It also holds the underlying buffer in a shared pointer, again so that Tensor objects can be copied without copying the underlying storage.
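
A hedged sketch of how the first two storage types arise in practice (constructor behavior as described in the API below; assuming torch is imported): constructing from a torch tensor borrows its buffer, while constructing from a Python list copies the data into owned host storage.

py_tensor = torch.randn((1, 1, 32, 32))
borrowed = ttnn.Tensor(py_tensor)  # BorrowedStorage: shares torch's buffer
owned = ttnn.Tensor(
    py_tensor.reshape(-1).tolist(),
    [1, 1, 32, 32],
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)  # OwnedStorage: data copied into a host-side vector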

Tensor API

class ttnn.Tensor

The class constructor supports tensors of rank 4 and takes the following arguments:

| Argument | Description | Data type | Valid range | Required |
|----------|-------------|-----------|-------------|----------|
| data | Data to store in TT tensor | List[float/int] | | Yes |
| shape | Shape of TT tensor | List[int[4]] | | Yes |
| data_type | Data type of numbers in TT tensor | ttnn.DataType | ttnn.DataType.BFLOAT16, ttnn.DataType.FLOAT32, ttnn.DataType.UINT32, ttnn.DataType.BFLOAT8_B, ttnn.DataType.BFLOAT4_B | Yes |
| layout | Layout of tensor data in memory | ttnn.Layout | ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE | Yes |
| device | Device on which tensor will be created | ttnn.Device | Host or TT accelerator device | No |
| mem_config | Layout of tensor in TT Accelerator device memory banks | ttnn.MemoryConfig | | No |

__init__(*args, **kwargs)

Overloaded function.

  1. __init__(self: ttnn._ttnn.tensor.Tensor, arg0: ttnn._ttnn.tensor.Tensor) -> None

  2. __init__(self: ttnn._ttnn.tensor.Tensor, data: list[float], shape: Annotated[list[int], FixedSize(4)], data_type: ttnn._ttnn.tensor.DataType, layout: ttnn._ttnn.tensor.Layout, tile: Optional[ttnn._ttnn.tensor.Tile] = None) -> None

    | Argument | Name |
    |----------|------|
    | arg0 | data |
    | arg1 | shape |
    | arg2 | data_type |
    | arg3 | layout |

    Example of creating a TT Tensor on host:

    py_tensor = torch.randn((1, 1, 32, 32))
    ttnn.Tensor(
        py_tensor.reshape(-1).tolist(),
        py_tensor.size(),
        ttnn.DataType.BFLOAT16,
        ttnn.Layout.ROW_MAJOR,
    )
    
  3. __init__(self: ttnn._ttnn.tensor.Tensor, data: list[float], shape: Annotated[list[int], FixedSize(4)], data_type: ttnn._ttnn.tensor.DataType, layout: ttnn._ttnn.tensor.Layout, device: ttnn._ttnn.device.IDevice = None, tile: Optional[ttnn._ttnn.tensor.Tile] = None) -> None

    | Argument | Name |
    |----------|------|
    | arg0 | data |
    | arg1 | shape |
    | arg2 | data_type |
    | arg3 | layout |
    | arg4 | device |

    Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device.

    Note that a TT Tensor in ROW_MAJOR layout on a TT Accelerator device must have the size of its last dimension divisible by 2.

    Example of creating a TT Tensor on TT accelerator device:

    py_tensor = torch.randn((1, 1, 32, 32))
    tt_device = ttnn.CreateDevice(0)
    # ...
    ttnn.Tensor(
        py_tensor.reshape(-1).tolist(),
        py_tensor.size(),
        ttnn.DataType.BFLOAT16,
        ttnn.Layout.ROW_MAJOR,
        tt_device
    )
    
  4. __init__(self: ttnn._ttnn.tensor.Tensor, data: list[float], shape: Annotated[list[int], FixedSize(4)], data_type: ttnn._ttnn.tensor.DataType, layout: ttnn._ttnn.tensor.Layout, device: ttnn._ttnn.device.IDevice = None, memory_config: ttnn._ttnn.tensor.MemoryConfig, tile: Optional[ttnn._ttnn.tensor.Tile] = None) -> None

    | Argument | Name |
    |----------|------|
    | arg0 | data |
    | arg1 | shape |
    | arg2 | data_type |
    | arg3 | layout |
    | arg4 | device |
    | arg5 | mem_config |

    Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device.

    Note that a TT Tensor in ROW_MAJOR layout on a TT Accelerator device must have the size of its last dimension divisible by 2.

    Example of creating a TT Tensor on TT accelerator device with specified mem_config:

    py_tensor = torch.randn((1, 1, 32, 32))
    tt_device = ttnn.CreateDevice(0)
    mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.SINGLE_BANK)
    # ...
    ttnn.Tensor(
        py_tensor.reshape(-1).tolist(),
        py_tensor.size(),
        ttnn.DataType.BFLOAT16,
        ttnn.Layout.ROW_MAJOR,
        tt_device,
        mem_config
    )
    
  5. __init__(self: ttnn._ttnn.tensor.Tensor, tensor: object, data_type: Optional[ttnn._ttnn.tensor.DataType] = None, strategy: dict[str, str] = {}, tile: Optional[ttnn._ttnn.tensor.Tile] = None) -> None

    | Argument | Description |
    |----------|-------------|
    | tensor | Pytorch or Numpy Tensor |
    | data_type | TT Tensor data type |
    | tile | TT Tile Spec |

    Example of creating a TT Tensor that uses torch.Tensor’s storage as its own storage:

    py_tensor = torch.randn((1, 1, 32, 32))
    ttnn.Tensor(py_tensor)
    
  6. __init__(self: ttnn._ttnn.tensor.Tensor, tensor: object, data_type: Optional[ttnn._ttnn.tensor.DataType] = None, device: ttnn._ttnn.device.IDevice = None, layout: ttnn._ttnn.tensor.Layout = <Layout.ROW_MAJOR: 0>, mem_config: ttnn._ttnn.tensor.MemoryConfig = MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt), tile: Optional[ttnn._ttnn.tensor.Tile] = None) -> None

    | Argument | Description |
    |----------|-------------|
    | tensor | Pytorch or Numpy Tensor |
    | data_type | TT Tensor data type |
    | device | TT device ptr |
    | layout | TT layout |
    | mem_config | TT memory_config |
    | tile | TT Tile Spec |

    Example of creating a TT Tensor from a numpy tensor:

    device = ttnn.open_device(device_id=0)
    py_tensor = np.zeros((1, 1, 32, 32))
    ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)
    
buffer(self: ttnn._ttnn.tensor.Tensor) -> ttnn._ttnn.tensor.owned_buffer_for_uint8_t | ttnn._ttnn.tensor.owned_buffer_for_uint16_t | ttnn._ttnn.tensor.owned_buffer_for_int32_t | ttnn._ttnn.tensor.owned_buffer_for_uint32_t | ttnn._ttnn.tensor.owned_buffer_for_float32_t | ttnn._ttnn.tensor.owned_buffer_for_bfloat16_t | ttnn._ttnn.tensor.borrowed_buffer_for_uint8_t | ttnn._ttnn.tensor.borrowed_buffer_for_uint16_t | ttnn._ttnn.tensor.borrowed_buffer_for_int32_t | ttnn._ttnn.tensor.borrowed_buffer_for_uint32_t | ttnn._ttnn.tensor.borrowed_buffer_for_float32_t | ttnn._ttnn.tensor.borrowed_buffer_for_bfloat16_t

Get the underlying buffer.

The tensor must be on the CPU when calling this function.

buffer = tt_tensor.cpu().buffer() # move TT Tensor to host and get the buffer
device(self: ttnn._ttnn.tensor.Tensor) -> ttnn._ttnn.device.IDevice

Get the device of the tensor.

device = tt_tensor.device()
get_dtype(self: ttnn._ttnn.tensor.Tensor) -> ttnn._ttnn.tensor.DataType

Get dtype of TT Tensor.

dtype = tt_tensor.get_dtype()
get_layout(self: ttnn._ttnn.tensor.Tensor) -> ttnn._ttnn.tensor.Layout

Get memory layout of TT Tensor.

layout = tt_tensor.get_layout()
pad(self: ttnn._ttnn.tensor.Tensor, arg0: Annotated[list[int], FixedSize(4)], arg1: Annotated[list[int], FixedSize(4)], arg2: float) -> ttnn._ttnn.tensor.Tensor

Pad TT Tensor with given pad value arg2.

The input tensor must be on host and in ROW_MAJOR layout.

Returns an output tensor that contains the input tensor placed at the given start indices arg1, with the pad value everywhere else.

| Argument | Description | Data type | Valid range | Required |
|----------|-------------|-----------|-------------|----------|
| arg0 | Shape of output tensor | List[int[4]] | | Yes |
| arg1 | Start indices to place input tensor in output tensor | List[int[4]] | Values along each dim must be <= (output_tensor_shape[i] - input_tensor_shape[i]) | Yes |
| arg2 | Value to pad input tensor | float | | Yes |

input_tensor_shape = [1, 1, 3, 3]
output_tensor_shape = [1, 2, 5, 5]
input_tensor_start = [0, 1, 1, 1]
pad_value = 0

inp = torch.Tensor(
    [ 1, 2, 3,
    4, 5, 6,
    7, 8, 9 ]
)
tt_tensor = ttnn.Tensor(
    inp.tolist(),
    input_tensor_shape,
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)
tt_tensor_padded = tt_tensor.pad(output_tensor_shape, input_tensor_start, pad_value)

print("Input tensor:")
print(tt_tensor)
print("\nPadded tensor:")
print(tt_tensor_padded)

Example output:

Input tensor:
[ [[[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]]] dtype=bfloat16 ]

Padded tensor:
[ [[[0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0]],

    [[0, 0, 0, 0, 0],
    [0, 1, 2, 3, 0],
    [0, 4, 5, 6, 0],
    [0, 7, 8, 9, 0],
    [0, 0, 0, 0, 0]]] dtype=bfloat16 ]
pad_to_tile(self: ttnn._ttnn.tensor.Tensor, arg0: float) -> ttnn._ttnn.tensor.Tensor

Pads TT Tensor with given pad value arg0.

The input tensor must be on host and in ROW_MAJOR layout.

Returns an output tensor that contains the input tensor padded with the pad value in the last two dims to multiples of 32.

Padding will be added to the right and bottom of the tensor.

| Argument | Description | Data type | Valid range | Required |
|----------|-------------|-----------|-------------|----------|
| arg0 | Value to pad input tensor | float | | Yes |

input_tensor_shape = [1, 1, 3, 3]
pad_value = 0

inp = torch.Tensor(
    [ 1, 2, 3,
    4, 5, 6,
    7, 8, 9 ]
)
tt_tensor = ttnn.Tensor(
    inp.tolist(),
    input_tensor_shape,
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)
tt_tensor_padded = tt_tensor.pad_to_tile(pad_value)

print("Input tensor:")
print(tt_tensor)
print("\nPadded tensor:")
print(tt_tensor_padded)

Example output:

Input tensor:
[ [[[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]]] dtype=bfloat16 ]

Padded tensor:
[ [[[1, 2, 3, 0, ..., 0],
    [4, 5, 6, 0, ..., 0],
    [7, 8, 9, 0, ..., 0],
    [0, 0, 0, 0, ..., 0],
    ...,
    [0, 0, 0, 0, ..., 0]]] dtype=bfloat16 ]
storage_type(self: ttnn._ttnn.tensor.Tensor) -> ttnn._ttnn.tensor.StorageType

Get the storage type of the tensor, e.g. to check whether it is on host or on device.

storage_type = tt_tensor.storage_type()
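
A hedged usage sketch (the ttnn.StorageType member name here is an assumption, not confirmed by this page): branch on the storage type to decide whether the tensor must be brought to host before reading its buffer.

if tt_tensor.storage_type() == ttnn.StorageType.DEVICE:  # assumed enum member
    tt_tensor = tt_tensor.cpu()  # bring the tensor to host first
buffer = tt_tensor.buffer()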
to(*args, **kwargs)

Overloaded function.

  1. to(self: ttnn._ttnn.tensor.Tensor, device: ttnn._ttnn.device.IDevice, mem_config: ttnn._ttnn.tensor.MemoryConfig = MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt), cq_id: int = 0) -> ttnn._ttnn.tensor.Tensor

    Move TT Tensor from host to TT accelerator device.

    Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device.

    If arg1 is not supplied, a default MemoryConfig with interleaved memory layout is used.

    | Argument | Description | Data type | Valid range | Required |
    |----------|-------------|-----------|-------------|----------|
    | arg0 | Device to which tensor will be moved | ttnn.Device | TT accelerator device | Yes |
    | arg1 | MemoryConfig of tensor on TT accelerator device | ttnn.MemoryConfig | | No |
    | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No |

    tt_tensor = tt_tensor.to(tt_device)
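
    A hedged sketch of the optional arguments (mem_config and cq_id, as named in the signature above; the enum spellings are assumptions): move the tensor into interleaved L1 instead of the default interleaved DRAM.

    l1_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)
    tt_tensor = tt_tensor.to(tt_device, l1_config, cq_id=0)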
    
  2. to(self: ttnn._ttnn.tensor.Tensor, mesh_device: ttnn._ttnn.multi_device.MeshDevice, mem_config: ttnn._ttnn.tensor.MemoryConfig = MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt), cq_id: int = 0) -> ttnn._ttnn.tensor.Tensor

    Move TT Tensor from host to TT accelerator mesh device.

    Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device.

    If arg1 is not supplied, a default MemoryConfig with interleaved memory layout is used.

    | Argument | Description | Data type | Valid range | Required |
    |----------|-------------|-----------|-------------|----------|
    | arg0 | MeshDevice to which tensor will be moved | ttnn.MeshDevice | TT accelerator device | Yes |
    | arg1 | MemoryConfig of tensor on TT accelerator device | ttnn.MemoryConfig | | No |
    | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No |

    tt_tensor = tt_tensor.to(mesh_device)
    
  3. to(self: ttnn._ttnn.tensor.Tensor, target_layout: ttnn._ttnn.tensor.Layout, worker: ttnn._ttnn.device.IDevice = None) -> ttnn._ttnn.tensor.Tensor

    Convert TT Tensor to the provided memory layout. Available layout conversions are:

    • ROW_MAJOR to TILE

    • TILE to ROW_MAJOR

    | Argument | Description | Data type | Valid range | Required |
    |----------|-------------|-----------|-------------|----------|
    | arg0 | Target memory layout | ttnn.Layout | ROW_MAJOR, TILE | Yes |
    | arg1 | Worker thread performing layout conversion (optional) | ttnn.Device | Thread tied to TT accelerator device | No |

    tt_tensor = tt_tensor.to(ttnn.Layout.TILE, worker)
    
  4. to(self: ttnn._ttnn.tensor.Tensor, target_layout: ttnn._ttnn.tensor.Layout, mesh_device: ttnn._ttnn.multi_device.MeshDevice = None) -> ttnn._ttnn.tensor.Tensor

    Convert TT Tensor to the provided memory layout. Available layout conversions are:

    • ROW_MAJOR to TILE

    • TILE to ROW_MAJOR

    | Argument | Description | Data type | Valid range | Required |
    |----------|-------------|-----------|-------------|----------|
    | arg0 | Target memory layout | ttnn.Layout | ROW_MAJOR, TILE | Yes |
    | arg1 | MeshDevice performing layout conversion (optional) | ttnn.MeshDevice | TT accelerator mesh device | No |

    tt_tensor = tt_tensor.to(ttnn.Layout.TILE, mesh_device)
    
unpad(self: ttnn._ttnn.tensor.Tensor, arg0: Annotated[list[int], FixedSize(4)], arg1: Annotated[list[int], FixedSize(4)]) -> ttnn._ttnn.tensor.Tensor

Unpad this TT Tensor.

This tensor must be on host and in ROW_MAJOR layout.

Returns an output tensor containing the region of the input tensor from start indices arg0 to end indices arg1 (inclusive).

| Argument | Description | Data type | Valid range | Required |
|----------|-------------|-----------|-------------|----------|
| arg0 | Start indices of input tensor | List[int[4]] | Values along each dim must be < input_tensor_shape[i] and <= output_tensor_end[i] | Yes |
| arg1 | End indices of input tensor in output tensor | List[int[4]] | Values along each dim must be < input_tensor_shape[i] | Yes |

input_tensor_shape = [1, 1, 5, 5]
output_tensor_start = [0, 0, 1, 1]
output_tensor_end = [0, 0, 3, 3]

inp = torch.Tensor(
    [ 0, 0, 0, 0, 0,
    0, 1, 2, 3, 0,
    0, 4, 5, 6, 0,
    0, 7, 8, 9, 0,
    0, 0, 0, 0, 0 ]
)
tt_tensor = ttnn.Tensor(
    inp.tolist(),
    input_tensor_shape,
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)
tt_tensor_unpadded = tt_tensor.unpad(output_tensor_start, output_tensor_end)

print("Input tensor:")
print(tt_tensor)
print("\nUnpadded tensor:")
print(tt_tensor_unpadded)

Example output:

Input tensor:
[ [[[0, 0, 0, 0, 0],
    [0, 1, 2, 3, 0],
    [0, 4, 5, 6, 0],
    [0, 7, 8, 9, 0],
    [0, 0, 0, 0, 0]]] dtype=bfloat16 ]

Unpadded tensor:
[ [[[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]]] dtype=bfloat16 ]
unpad_from_tile(self: ttnn._ttnn.tensor.Tensor, arg0: list[int]) -> ttnn._ttnn.tensor.Tensor

Unpads TT Tensor to the given output tensor shape arg0.

The input tensor must be on host and in ROW_MAJOR layout.

This function expects the real data to be aligned in the top left of the tensor.

Returns an output tensor with padding removed from the right and bottom of the input tensor.

| Argument | Description | Data type | Valid range | Required |
|----------|-------------|-----------|-------------|----------|
| arg0 | Shape of output tensor | List[int[4]] | All dims must match the input tensor dims apart from the last two; for each of the last two dims: input_tensor_shape[i] must be a multiple of 32, and input_tensor_shape[i] - 32 < output_tensor_shape[i] <= input_tensor_shape[i] | Yes |

input_tensor_shape = [1, 1, 32, 32]
output_tensor_shape = [1, 1, 3, 3]

inp = torch.arange(start=1.0, end=10.0).reshape(1, 1, 3, 3)
inp = torch.nn.functional.pad(inp, [0, input_tensor_shape[3] - inp.shape[3], 0, input_tensor_shape[2] - inp.shape[2]]).reshape(-1)
tt_tensor = ttnn.Tensor(
    inp.tolist(),
    input_tensor_shape,
    ttnn.DataType.BFLOAT16,
    ttnn.Layout.ROW_MAJOR,
)
tt_tensor_unpadded = tt_tensor.unpad_from_tile(output_tensor_shape)

print("Input tensor:")
print(tt_tensor)
print("\nUnpadded tensor:")
print(tt_tensor_unpadded)

Example output:

Input tensor:
[ [[[1, 2, 3, 0, ..., 0],
    [4, 5, 6, 0, ..., 0],
    [7, 8, 9, 0, ..., 0],
    [0, 0, 0, 0, ..., 0],
    ...,
    [0, 0, 0, 0, ..., 0]]] dtype=bfloat16 ]

Unpadded tensor:
[ [[[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]]] dtype=bfloat16 ]

MemoryConfig

class ttnn.MemoryConfig

Class defining memory configuration for storing tensor data on TT Accelerator device. There are eight DRAM memory banks on TT Accelerator device, indexed as 0, 1, 2, …, 7.

__init__(self: ttnn._ttnn.tensor.MemoryConfig, memory_layout: ttnn._ttnn.tensor.TensorMemoryLayout = <TensorMemoryLayout.INTERLEAVED: 0>, buffer_type: ttnn._ttnn.tensor.BufferType = <BufferType.DRAM: 0>, shard_spec: Optional[ttnn._ttnn.tensor.ShardSpec] = None) -> None

Creates a MemoryConfig object. If memory_layout is set to TensorMemoryLayout.INTERLEAVED (the default), tensor data will be interleaved across multiple DRAM banks on the TT Accelerator device. If it is set to TensorMemoryLayout.SINGLE_BANK, tensor data will be stored in a single bank.

Example of creating a MemoryConfig specifying that tensor data should be stored in a single DRAM bank:

mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.SINGLE_BANK)
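
A couple of additional hedged sketches (assuming ttnn.BufferType is exposed as in the __init__ signature above): the default interleaved DRAM configuration, and an interleaved configuration in device L1.

dram_interleaved = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)  # the default
l1_interleaved = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)      # interleave across L1 banks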

Examples of converting between PyTorch Tensor and TT Tensor

Remember that TT Tensors must have rank 4.

Converting a PyTorch Tensor to a TT Tensor

This example shows how to create a TT Tensor tt_tensor from a PyTorch tensor. The created tensor will be in ROW_MAJOR layout and stored on TT accelerator device.

py_tensor = torch.randn((1, 1, 32, 32))
tt_tensor = ttnn.Tensor(py_tensor, ttnn.bfloat16).to(tt_device)

Converting a TT Tensor to a PyTorch Tensor

This example shows how to move a TT Tensor output from device to host and how to convert a TT Tensor to a PyTorch tensor.

# move TT Tensor output from TT accelerator device to host
tt_output = tt_output.cpu()

# create PyTorch tensor from TT Tensor using to_torch() member function
py_output = tt_output.to_torch()
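
Putting the two conversions together, a hedged end-to-end round trip (assuming tt_device is an open TT accelerator device):

py_in = torch.randn((1, 1, 32, 32))
tt_dev = ttnn.Tensor(py_in, ttnn.bfloat16).to(tt_device)  # host -> device (ROW_MAJOR, BFLOAT16)
py_out = tt_dev.cpu().to_torch()                          # device -> host -> PyTorch tensor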