Using ttnn
Note
These basic snippets currently work on Grayskull only. We are working on updating the API for other architectures, like Wormhole.
Basic Examples
1. Converting from and to torch tensor
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
torch_input_tensor = torch.zeros(2, 4, dtype=torch.float32)
tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16)
torch_output_tensor = ttnn.to_torch(tensor)
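To sanity-check the round trip, compare the result against the original tensor in torch. A minimal sketch (exact equality holds here because zeros survive the bfloat16 cast):
# the host copy comes back as bfloat16 here, so cast to float32 before comparing
assert torch.equal(torch_input_tensor, torch_output_tensor.to(torch.float32))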
2. Running an operation on the device
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor_a = torch.rand(4, 7, dtype=torch.float32)
input_tensor_a = ttnn.from_torch(torch_input_tensor_a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.exp(input_tensor_a)
torch_output_tensor = ttnn.to_torch(output_tensor)
torch_input_tensor_b = torch.rand(7, 1, dtype=torch.float32)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
matmul_output_tensor = input_tensor_a @ input_tensor_b
torch_matmul_output_tensor = ttnn.to_torch(matmul_output_tensor)
print(torch_matmul_output_tensor)
ttnn.close_device(device)
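If you prefer not to pair open_device and close_device by hand, the same flow can use the ttnn.manage_device context manager (used in Example 15 below), which closes the device automatically when the block exits. A minimal sketch:
import torch
import ttnn
with ttnn.manage_device(0) as device:
    torch_input_tensor = torch.rand(4, 7, dtype=torch.float32)
    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.exp(input_tensor)  # runs on the device
    torch_output_tensor = ttnn.to_torch(output_tensor)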
3. Using __getitem__ to slice the tensor
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
# Note that this is not a view, unlike slicing a torch tensor
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor = torch.rand(3, 96, 128, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = input_tensor[:1, 32:64, 32:64] # this particular slice will run on the device
torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.close_device(device)
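Because the slice is a copy rather than a view, the host tensor returned by ttnn.to_torch is fully materialized. A quick check of the expected slicing semantics on the host copy:
# slicing (3, 96, 128) with [:1, 32:64, 32:64] yields a (1, 32, 32) tensor
print(torch_output_tensor.shape)  # torch.Size([1, 32, 32])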
4. Enabling program cache
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
import time
device_id = 0
device = ttnn.open_device(device_id=device_id)
ttnn.enable_program_cache(device)
torch_input_tensor = torch.rand(2, 4, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Running the first time will compile the program and cache it
start_time = time.time()
output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)
end_time = time.time()
duration = end_time - start_time
print(f"duration of the first run: {duration}")
# stdout: duration of the first run: 0.6391518115997314
# Running the subsequent time will use the cached program
start_time = time.time()
output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)
end_time = time.time()
duration = end_time - start_time
print(f"duration of the second run: {duration}")
# stdout: duration of the second run: 0.0007393360137939453
ttnn.close_device(device)
5. Debugging intermediate tensors
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor = torch.rand(32, 32, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
with ttnn.manage_config("enable_comparison_mode", True):
    with ttnn.manage_config(
        "comparison_mode_pcc", 0.9998
    ):  # This is optional in case the default value of 0.9999 is too high
        output_tensor = ttnn.exp(input_tensor)
torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.close_device(device)
6. Tracing the graph of operations
Note
This basic snippet is under construction, and may not work on all hardware architectures.
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
with ttnn.tracer.trace():
    torch_input_tensor = torch.rand(32, 32, dtype=torch.float32)
    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.exp(input_tensor)
    torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.tracer.visualize(torch_output_tensor, file_name="exp_trace.svg")
ttnn.close_device(device)
7. Using tt_lib operations in ttnn
tt_lib operations are missing some of the features of ttnn operations, such as graph tracing. To support these features, ttnn provides a different way to call tt_lib operations that enables them.
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor = torch.rand(1, 1, 2, 4, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.exp(input_tensor) # exp migrated to ttnn
torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.close_device(device)
8. Enabling Logging
# To print currently executing ttnn operations
export TTNN_CONFIG_OVERRIDES='{"enable_fast_runtime_mode": false, "enable_logging": true}'
# To print the currently executing ttnn and tt_lib operations and their input tensors to stdout
export TT_METAL_LOGGER_TYPES=Op
export TT_METAL_LOGGER_LEVEL=Debug
Logging is not a substitute for profiling. Please refer to Profiling ttnn Operations for instructions on how to profile operations.
Note
The logging is only available when compiling with CONFIG=assert or CONFIG=debug.
9. Supported Python Operators
Note
This basic snippet is under construction, and may not work on all hardware architectures.
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
input_tensor_a: ttnn.Tensor = ttnn.from_torch(torch.rand(2, 4), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b: ttnn.Tensor = ttnn.from_torch(torch.rand(2, 4), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# Add (supports broadcasting)
input_tensor_a + input_tensor_b
# Subtract (supports broadcasting)
input_tensor_a - input_tensor_b
# Multiply (supports broadcasting)
input_tensor_a * input_tensor_b
# Matrix Multiply
input_tensor_a @ input_tensor_b
# Equals
input_tensor_a == input_tensor_b
# Not equals
input_tensor_a != input_tensor_b
# Greater than
input_tensor_a > input_tensor_b
# Greater than or equals
input_tensor_a >= input_tensor_b
# Less than
input_tensor_a < input_tensor_b
# Less than or equals
input_tensor_a <= input_tensor_b
ttnn.close_device(device)
10. Changing the string representation of the tensor
Note
This basic snippet is under construction, and may not work on all hardware architectures.
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import ttnn
# Profile can be set to "empty", "short" or "full"
ttnn.set_printoptions(profile="full")
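For example, printing a host tensor after switching to the full profile (the torch import and the tensor below are illustrative additions, not part of the original snippet):
import torch
tensor = ttnn.from_torch(torch.rand(2, 4), dtype=ttnn.bfloat16)
print(tensor)  # the full profile is expected to print every element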
11. Visualize using Web Browser
Note
This basic snippet is under construction, and may not work on all hardware architectures.
Set the following environment variables as needed
# enable_fast_runtime_mode - This has to be disabled to enable logging
# enable_logging - Synchronize main thread after every operation and log the operation
# report_name (optional) - Name of the report used by the visualizer. If not provided, then no data will be dumped to disk
# enable_detailed_buffer_report (if report_name is set) - Enable to visualize the detailed buffer report after every operation
# enable_graph_report (if report_name is set) - Enable to visualize the graph after every operation
# enable_detailed_tensor_report (if report_name is set) - Enable to visualize the values of input and output tensors of every operation
# enable_comparison_mode (if report_name is set) - Enable to test the output of operations against their golden implementation
# If running a pytest that is located inside of tests/ttnn, use this config (unless you want to override "report_name" manually)
export TTNN_CONFIG_OVERRIDES='{
    "enable_fast_runtime_mode": false,
    "enable_logging": true,
    "enable_graph_report": false,
    "enable_detailed_buffer_report": false,
    "enable_detailed_tensor_report": false,
    "enable_comparison_mode": false
}'
# Otherwise, use this config and make sure to set "report_name"
export TTNN_CONFIG_OVERRIDES='{
    "enable_fast_runtime_mode": false,
    "enable_logging": true,
    "report_name": "<name of the run in the visualizer>",
    "enable_graph_report": false,
    "enable_detailed_buffer_report": false,
    "enable_detailed_tensor_report": false,
    "enable_comparison_mode": false
}'
# Additionally, a JSON file can be used to override the config values
export TTNN_CONFIG_PATH=<path to the file>
Run the code, for example:
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor_a = torch.rand(2048, 2048, dtype=torch.float32)
torch_input_tensor_b = torch.rand(2048, 2048, dtype=torch.float32)
input_tensor_a = ttnn.from_torch(
    torch_input_tensor_a,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
input_tensor_b = ttnn.from_torch(
    torch_input_tensor_b,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(input_tensor_a)
ttnn.deallocate(input_tensor_b)
torch_output_tensor = ttnn.to_torch(output_tensor)
ttnn.deallocate(output_tensor)
ttnn.close_device(device)
Open the visualizer by running the following command:
python ttnn/visualizer/app.py
12. Register pre- and/or post-operation hooks
Note
This basic snippet is under construction, and may not work on all hardware architectures.
import torch
import ttnn
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch_input_tensor = torch.rand((1, 32, 64), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
def pre_hook_to_print_args_and_kwargs(operation, args, kwargs):
    print(f"Pre-hook called for {operation}. Args: {args}, kwargs: {kwargs}")
def post_hook_to_print_output(operation, args, kwargs, output):
    print(f"Post-hook called for {operation}. Output: {output}")
with ttnn.register_pre_operation_hook(pre_hook_to_print_args_and_kwargs), ttnn.register_post_operation_hook(post_hook_to_print_output):
    ttnn.exp(input_tensor) * 2 + 1
ttnn.close_device(device)
13. Query all operations
import ttnn
ttnn.query_registered_operations()
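Assuming the call returns a Python sequence of the registered operations, you can inspect or count them; a minimal sketch:
operations = ttnn.query_registered_operations()
print(len(operations))  # assumes the returned value supports len()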
14. Falling back to torch
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
torch_input_tensor = torch.zeros(2, 4, dtype=torch.float32)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16)
# Recommended approach
tensor = ttnn.to_torch(input_tensor)
tensor = torch.nn.functional.silu(tensor)
output_tensor = ttnn.from_torch(tensor, dtype=ttnn.bfloat16)
# Alternative approach that only works with some operations
output_tensor = ttnn.get_fallback_function(ttnn.silu)(input_tensor)
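Because ttnn.get_fallback_function returns an ordinary Python callable, the wrapped fallback can be stored and reused; a minimal sketch:
fallback_silu = ttnn.get_fallback_function(ttnn.silu)
output_tensor = fallback_silu(input_tensor)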
15. Capturing the graph of C++ functions, buffer allocations, etc.
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import ttnn
torch.manual_seed(0)
size = 128
scalar = 3
with ttnn.manage_device(0) as device:
    torch_input_tensor = torch.rand((size,), dtype=torch.bfloat16)
    ttnn.graph.begin_graph_capture(ttnn.graph.RunMode.NO_DISPATCH)
    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = input_tensor + scalar
    output_tensor = ttnn.to_torch(output_tensor, torch_rank=1)
    captured_graph = ttnn.graph.end_graph_capture()
    ttnn.graph.pretty_print(captured_graph)
    ttnn.graph.visualize(captured_graph, file_name="graph.svg")