Using TT-NN

Note

These basic snippets currently work on Grayskull only. We are working on updating the API for other architectures, like Wormhole.

Note

If you are using a wheel or a Docker Release Image, you will need to install PyTorch for these examples to work: `pip install torch`

Basic Examples

1. Converting from and to torch tensor

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn

# Host-side float32 torch tensor of zeros with shape (2, 4).
torch_input_tensor = torch.zeros(2, 4, dtype=torch.float32)

# torch -> TT-NN, storing the values as bfloat16.
tensor = ttnn.from_torch(
    torch_input_tensor,
    dtype=ttnn.bfloat16,
)

# TT-NN -> torch.
torch_output_tensor = ttnn.to_torch(tensor)

2. Running an operation on the device

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

# Open device 0; every TT-NN tensor below is placed on it.
device_id = 0
device = ttnn.open_device(device_id=device_id)

# Elementwise exponential of a (4, 7) tensor, computed on the device and
# converted back into a torch tensor.
torch_input_tensor_a = torch.rand(4, 7, dtype=torch.float32)
input_tensor_a = ttnn.from_torch(
    torch_input_tensor_a,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
output_tensor = ttnn.exp(input_tensor_a)
torch_output_tensor = ttnn.to_torch(output_tensor)

# Matrix product (4, 7) @ (7, 1) using the Python `@` operator.
torch_input_tensor_b = torch.rand(7, 1, dtype=torch.float32)
input_tensor_b = ttnn.from_torch(
    torch_input_tensor_b,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
matmul_output_tensor = input_tensor_a @ input_tensor_b
torch_matmul_output_tensor = ttnn.to_torch(matmul_output_tensor)

print(torch_matmul_output_tensor)

# Release the device once all work is done.
ttnn.close_device(device)

3. Using `__getitem__` to slice the tensor

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


# Note that, unlike slicing a torch tensor, this does not return a view.

import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

torch_input_tensor = torch.rand(3, 96, 128, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Taking [:1, 32:64, 32:64] of a (3, 96, 128) tensor yields logical shape (1, 32, 32).
output_tensor = input_tensor[:1, 32:64, 32:64]  # this particular slice will run on the device
torch_output_tensor = ttnn.to_torch(output_tensor)

ttnn.close_device(device)

4. Enabling program cache

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn
import time

device_id = 0
device = ttnn.open_device(device_id=device_id)

# Cache compiled programs so that running an identical operation again reuses
# the cached program instead of compiling it a second time.
device.enable_program_cache()

torch_input_tensor = torch.rand(2, 4, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring elapsed time; time.time() is a wall clock that can jump.

# Running the first time will compile the program and cache it
start_time = time.perf_counter()
output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)
end_time = time.perf_counter()
duration = end_time - start_time
print(f"duration of the first run: {duration}")
# stdout: duration of the first run: 0.6391518115997314

# Running the subsequent time will use the cached program
start_time = time.perf_counter()
output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)
end_time = time.perf_counter()
duration = end_time - start_time
print(f"duration of the second run: {duration}")
# stdout: duration of the second run: 0.0007393360137939453

ttnn.close_device(device)

5. Debugging intermediate tensors

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

torch_input_tensor = torch.rand(32, 32, dtype=torch.float32)
input_tensor = ttnn.from_torch(
    torch_input_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# Both config overrides are applied through manage_config around the exp call.
# Setting "comparison_mode_pcc" is optional in case default value of 0.9999 is too high.
with ttnn.manage_config("enable_comparison_mode", True), ttnn.manage_config("comparison_mode_pcc", 0.9998):
    output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)

ttnn.close_device(device)

6. Tracing the graph of operations

Note

This basic snippet is under construction, and may not work on all hardware architectures.

import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

# Run the whole round trip (torch -> ttnn -> exp -> torch) inside
# ttnn.tracer.trace() so that its graph of operations can be rendered afterwards.
with ttnn.tracer.trace():
    torch_input_tensor = torch.rand(32, 32, dtype=torch.float32)
    input_tensor = ttnn.from_torch(
        torch_input_tensor,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )
    output_tensor = ttnn.exp(input_tensor)
    torch_output_tensor = ttnn.to_torch(output_tensor)

# Render the traced graph that produced torch_output_tensor to an SVG file.
ttnn.tracer.visualize(torch_output_tensor, file_name="exp_trace.svg")

ttnn.close_device(device)

7. Using tt_lib operations in TT-NN

tt_lib operations are missing some of the features of TT-NN operations, such as graph tracing. To support these features, TT-NN provides a different way to call tt_lib operations that enables the missing features.

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn


device_id = 0
device = ttnn.open_device(device_id=device_id)

# A (1, 1, 2, 4) float32 tensor, moved to the device as bfloat16 in tile layout.
torch_input_tensor = torch.rand(1, 1, 2, 4, dtype=torch.float32)
input_tensor = ttnn.from_torch(
    torch_input_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# exp has been migrated from tt_lib to ttnn, so it is called directly
# through the ttnn namespace.
output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)

ttnn.close_device(device)

8. Enabling Logging

# To print the currently executing TT-NN operations
# (enable_fast_runtime_mode has to be disabled for enable_logging to take effect)
export TTNN_CONFIG_OVERRIDES='{"enable_fast_runtime_mode": false, "enable_logging": true}'

# To print the currently executing TT-NN and tt_lib operations and their input tensors to stdout
export TT_LOGGER_TYPES=Op
export TT_LOGGER_LEVEL=Debug

Logging is not a substitute for profiling. Please refer to Profiling TT-NN Operations for instructions on how to profile operations.

Note

Logging is only available when compiling with `CONFIG=assert` or `CONFIG=debug`.

9. Supported Python Operators

Note

This basic snippet is under construction, and may not work on all hardware architectures.

import torch
import ttnn

# The tensors below live on a device, so open one first (and close it at the end).
device_id = 0
device = ttnn.open_device(device_id=device_id)

input_tensor_a: ttnn.Tensor = ttnn.from_torch(torch.rand(2, 4), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b: ttnn.Tensor = ttnn.from_torch(torch.rand(2, 4), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Matrix multiplication needs matching inner dimensions, so use a (4, 2)
# right-hand operand: (2, 4) @ (4, 2) -> (2, 2).
input_tensor_c: ttnn.Tensor = ttnn.from_torch(torch.rand(4, 2), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Add (supports broadcasting)
input_tensor_a + input_tensor_b

# Subtract (supports broadcasting)
input_tensor_a - input_tensor_b

# Multiply (supports broadcasting)
input_tensor_a * input_tensor_b

# Matrix Multiply
input_tensor_a @ input_tensor_c

# Equals
input_tensor_a == input_tensor_b

# Not equals
input_tensor_a != input_tensor_b

# Greater than
input_tensor_a > input_tensor_b

# Greater than or equals
input_tensor_a >= input_tensor_b

# Less than
input_tensor_a < input_tensor_b

# Less than or equals
input_tensor_a <= input_tensor_b

ttnn.close_device(device)

10. Changing the string representation of the tensor

Note

This basic snippet is under construction, and may not work on all hardware architectures.

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import ttnn

# Change the string representation used when TT-NN tensors are printed.
print_profile = "full"  # Profile can be set to "empty", "short" or "full"
ttnn.set_printoptions(profile=print_profile)

11. Visualize using Web Browser

Note

This basic snippet is under construction, and may not work on all hardware architectures.

Set the following environment variables as needed:

# enable_fast_runtime_mode - This has to be disabled to enable logging
# enable_logging - Synchronize main thread after every operation and log the operation
# report_name (optional) - Name of the report used by the visualizer. If not provided, then no data will be dumped to disk
# enable_detailed_buffer_report (if report_name is set) - Enable to visualize the detailed buffer report after every operation
# enable_graph_report (if report_name is set) - Enable to visualize the graph after every operation
# enable_detailed_tensor_report (if report_name is set) - Enable to visualize the values of input and output tensors of every operation
# enable_comparison_mode (if report_name is set) - Enable to test the output of operations against their golden implementation


# If running a pytest that is located inside of tests/ttnn, use this config (unless you want to override "report_name" manually)
export TTNN_CONFIG_OVERRIDES='{
    "enable_fast_runtime_mode": false,
    "enable_logging": true,
    "enable_graph_report": false,
    "enable_detailed_buffer_report": false,
    "enable_detailed_tensor_report": false,
    "enable_comparison_mode": false
}'

# Otherwise, use this config and make sure to set "report_name"
export TTNN_CONFIG_OVERRIDES='{
    "enable_fast_runtime_mode": false,
    "enable_logging": true,
    "report_name": "<name of the run in the visualizer>",
    "enable_graph_report": false,
    "enable_detailed_buffer_report": false,
    "enable_detailed_tensor_report": false,
    "enable_comparison_mode": false
}'

# Additionally, a json file can be used to override the config values
export TTNN_CONFIG_PATH=<path to the file>

Run the code, for example:

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

# Two 2048x2048 float32 inputs.
torch_input_tensor_a = torch.rand(2048, 2048, dtype=torch.float32)
torch_input_tensor_b = torch.rand(2048, 2048, dtype=torch.float32)

# Both inputs are converted with identical settings: bfloat16, tile layout,
# placed on the device using the L1 memory config.
l1_tensor_kwargs = dict(
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
input_tensor_a = ttnn.from_torch(torch_input_tensor_a, **l1_tensor_kwargs)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, **l1_tensor_kwargs)

# Add the inputs (result also requested in L1), then explicitly deallocate
# the inputs since they are no longer needed.
output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(input_tensor_a)
ttnn.deallocate(input_tensor_b)

# Convert the result back to torch, then deallocate the TT-NN output as well.
torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.deallocate(output_tensor)

ttnn.close_device(device)

Open the visualizer by running the following command:

# Start the TT-NN visualizer (the path is relative; presumably run from the repository root — confirm)
python ttnn/visualizer/app.py

12. Register pre- and/or post-operation hooks

Note

This basic snippet is under construction, and may not work on all hardware architectures.

import torch
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

torch_input_tensor = torch.rand((1, 32, 64), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)


def pre_hook_to_print_args_and_kwargs(operation, args, kwargs):
    """Pre-operation hook: print the operation with its positional and keyword arguments."""
    message = f"Pre-hook called for {operation}. Args: {args}, kwargs: {kwargs}"
    print(message)


def post_hook_to_print_output(operation, args, kwargs, output):
    """Post-operation hook: print the operation together with its output."""
    message = f"Post-hook called for {operation}. Output: {output}"
    print(message)


# Register the pre-hook first, then the post-hook, around the expression
# whose TT-NN operations should be observed.
with ttnn.register_pre_operation_hook(pre_hook_to_print_args_and_kwargs):
    with ttnn.register_post_operation_hook(post_hook_to_print_output):
        ttnn.exp(input_tensor) * 2 + 1

ttnn.close_device(device)

13. Query all operations

import ttnn

# Query all operations registered with TT-NN. The call is left as a bare
# expression so an interactive session (REPL/notebook) displays its result.
ttnn.query_registered_operations()

14. Falling back to torch


# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import ttnn

torch_input_tensor = torch.zeros(2, 4, dtype=torch.float32)

input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16)

# Recommended approach: convert to torch, run the operation in torch,
# then convert the result back to a bfloat16 TT-NN tensor.
tensor = ttnn.to_torch(input_tensor)
tensor = torch.nn.functional.silu(tensor)
output_tensor = ttnn.from_torch(
    tensor,
    dtype=ttnn.bfloat16,
)

# Alternative approach that only works with some operations:
# ask TT-NN for the fallback version of ttnn.silu and call it directly.
silu_fallback = ttnn.get_fallback_function(ttnn.silu)
output_tensor = silu_fallback(input_tensor)

15. Capturing the graph of C++ functions, buffer allocations, etc.

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn


# Fixed seed so the random input is reproducible.
torch.manual_seed(0)

size, scalar = 128, 3

with ttnn.manage_device(0) as device:
    torch_input_tensor = torch.rand((size,), dtype=torch.bfloat16)

    # Everything between begin_graph_capture and end_graph_capture ends up in
    # captured_graph. RunMode.NO_DISPATCH presumably records the calls without
    # dispatching work to the device — TODO confirm.
    ttnn.graph.begin_graph_capture(ttnn.graph.RunMode.NO_DISPATCH)
    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = input_tensor + scalar
    # torch_rank=1 asks for a rank-1 torch tensor, matching the (size,) input.
    output_tensor = ttnn.to_torch(output_tensor, torch_rank=1)
    captured_graph = ttnn.graph.end_graph_capture()

    # Show the captured graph as text, then render it to an SVG file.
    ttnn.graph.pretty_print(captured_graph)

    ttnn.graph.visualize(captured_graph, file_name="graph.svg")