Adding New TT-NN Operation
Note
This document is meant for contributors to TT-NN.
Not all operations may be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others).
FAQ
What is a TT-NN operation?
A TT-NN operation is a function that takes in one or more input tensors and produces one or more output tensors. It is implemented in C++ and can be called from Python.
What steps are needed to add a TT-NN operation in C++?
There are two options for writing a new operation:
a. Implement a device operation in C++. A device operation is a struct that satisfies DeviceOperationConcept and specifies how to create output tensors and a program to run on the device.
b. Implement an operation in C++ that calls other operations. This type of operation simply defines an invoke() method that calls other operations.
In either case, register the struct using ttnn::register_operation.
What steps are needed to add a TT-NN operation in Python?
Take an existing registered C++ operation and add a Python binding for it using ttnn::bind_registered_operation. The operation will be auto-registered in Python. If the operation is called ttnn::add in C++, then the Python binding will be ttnn.add.
(Optional) Attach golden function to the operation using ttnn.attach_golden_function. This is useful for debugging and testing.
Example of Adding a new Device Operation
Let’s implement ttnn.example (it will simply copy the input tensor to the output tensor on the device).
C++ Implementation
Step 1: Implement device operation
In order to add a new device operation, follow the directory structure shown below:
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.hpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.cpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<program_factory_0>_program_factory.cpp
Note
Add as many program factories as needed; at minimum, one program factory is required.
A concrete example of a device operation can be found in ttnn/cpp/ttnn/operations/examples/example/device
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>
#include <variant>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/types.hpp"
#include "ttnn/decorators.hpp"

namespace ttnn::operations::examples {

struct ExampleDeviceOperation {
    // Define the operation attributes. This is used to store all non-tensor variables needed by the operation
    struct operation_attributes_t {
        bool attribute;
        int some_other_attribute;
    };

    // Define the tensor arguments. This is used to store all tensors passed into and/or out of the operation
    // Tensor arguments don't need to be just input tensors; they can be output tensors, input/output tensors,
    // optional tensors, etc.
    struct tensor_args_t {
        // This example will use a tensor that can only be used as an input
        const Tensor& input_tensor;

        // However, the following examples show what else can be done with tensor_args_t

        // An example of a tensor that can be used for input/output or just for pre-allocated output
        // Tensor& io_tensor;

        // An example of an optional tensor
        // std::optional<Tensor> optional_output_tensor;

        // An example of a vector of tensors
        // std::vector<Tensor> vector_of_tensors;

        // An example of a tuple of tensors
        // std::tuple<Tensor, ...> tuple_of_tensors;

        // An example of a vector of optional tensors
        // std::vector<std::optional<Tensor>> vector_of_optional_tensors;

        // An example of an arbitrarily nested combination of tensors
        // std::tuple<std::vector<std::optional<Tensor>>, std::optional<Tensor>> some_crazy_tuple_of_tensors;
    };

    // Define the return types for the spec(s) of the operation
    // Can be a single ttnn::TensorSpec, std::optional<ttnn::TensorSpec>, std::vector<ttnn::TensorSpec>,
    // std::tuple<ttnn::TensorSpec, ...>, etc.
    using spec_return_value_t = ttnn::TensorSpec;

    // Define the return types for the tensor(s) of the operation
    // Can be a single Tensor, std::optional<Tensor>, std::vector<Tensor>, std::tuple<Tensor, ...>, etc.
    using tensor_return_value_t = Tensor;

    // Note: spec_return_value_t and tensor_return_value_t should follow the same pattern,
    // i.e. if spec_return_value_t is a std::vector<std::optional<ttnn::TensorSpec>> then tensor_return_value_t
    // should be std::vector<std::optional<Tensor>>

    struct SingleCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    struct MultiCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
            std::size_t num_cores;
            std::size_t num_cores_y;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    using program_factory_t = std::variant<SingleCore, MultiCore>;

    // Mandatory methods

    // Select the program factory based on the operation attributes and tensor args
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it creates a program. Usually has more checks
    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it reuses a program. Usually has fewer checks
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);

    // Compute the output specs based on the operation attributes and tensor args
    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

    // Create the output tensors based on the operation attributes and tensor args
    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
};

} // namespace ttnn::operations::examples

namespace ttnn::prim {
ttnn::operations::examples::ExampleDeviceOperation::tensor_return_value_t example(const Tensor& input_tensor);
} // namespace ttnn::prim
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include "ttnn/device_operation.hpp"

namespace ttnn::operations::examples {

ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory(
    const operation_attributes_t& /*operation_attributes*/, const tensor_args_t& /*tensor_args*/) {
    bool some_condition_based_on_operation_attributes_and_or_tensor_args = true;
    if (some_condition_based_on_operation_attributes_and_or_tensor_args) {
        return SingleCore{};
    }
    return MultiCore{};
}

void ExampleDeviceOperation::validate_on_program_cache_miss(
    const operation_attributes_t& /*attributes*/, const tensor_args_t& /*tensor_args*/) {}

void ExampleDeviceOperation::validate_on_program_cache_hit(
    const operation_attributes_t& /*attributes*/, const tensor_args_t& /*tensor_args*/) {}

ExampleDeviceOperation::spec_return_value_t ExampleDeviceOperation::compute_output_specs(
    const operation_attributes_t&, const tensor_args_t& tensor_args) {
    const auto& input_tensor = tensor_args.input_tensor;
    return TensorSpec(
        input_tensor.logical_shape(),
        tt::tt_metal::TensorLayout(
            input_tensor.dtype(), tt::tt_metal::PageConfig(input_tensor.layout()), MemoryConfig{}));
}

ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.input_tensor.device());
}

} // namespace ttnn::operations::examples

namespace ttnn::prim {
ttnn::operations::examples::ExampleDeviceOperation::tensor_return_value_t example(const Tensor& input_tensor) {
    using OperationType = ttnn::operations::examples::ExampleDeviceOperation;
    auto operation_attributes = OperationType::operation_attributes_t{true, 42};
    auto tensor_args = OperationType::tensor_args_t{input_tensor};

    return ttnn::device_operation::launch<OperationType>(operation_attributes, tensor_args);
}
} // namespace ttnn::prim
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>
#include <tt-metalium/tensor_accessor_args.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create(
    const operation_attributes_t& /*operation_attributes*/,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto* src_buffer = input_tensor.buffer();
    auto* dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tile_size(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tile_size(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    CoreCoord compute_with_storage_grid_size = {1, 1};
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    std::vector<uint32_t> reader_compile_time_args;
    tt::tt_metal::TensorAccessorArgs(*src_buffer).append_to(reader_compile_time_args);
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index};
    tt::tt_metal::TensorAccessorArgs(*dst_buffer).append_to(writer_compile_time_args);

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
}

void ExampleDeviceOperation::SingleCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& /*operation_attributes*/,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto* src_buffer = input_tensor.buffer();
    auto* dst_buffer = output_tensor.buffer();

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = src_buffer->address();
    }

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = dst_buffer->address();
    }
}

} // namespace ttnn::operations::examples
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>
#include <tt-metalium/tensor_accessor_args.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create(
    const operation_attributes_t& /*operation_attributes*/,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto* src_buffer = input_tensor.buffer();
    auto* dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tile_size(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tile_size(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    tt::tt_metal::IDevice* device = input_tensor.device();

    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    std::vector<uint32_t> reader_compile_time_args;
    tt::tt_metal::TensorAccessorArgs(*src_buffer).append_to(reader_compile_time_args);
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index};
    tt::tt_metal::TensorAccessorArgs(*dst_buffer).append_to(writer_compile_time_args);

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id,
         .unary_writer_kernel_id = unary_writer_kernel_id,
         .num_cores = num_cores,
         .num_cores_y = num_cores_y}};
}

void ExampleDeviceOperation::MultiCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& /*operation_attributes*/,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;
    auto& num_cores = cached_program.shared_variables.num_cores;
    auto& num_cores_y = cached_program.shared_variables.num_cores_y;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto* src_buffer = input_tensor.buffer();
    auto* dst_buffer = output_tensor.buffer();

    for (uint32_t i = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
            runtime_args[0] = src_buffer->address();
        }

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
            runtime_args[0] = dst_buffer->address();
        }
    }
}

} // namespace ttnn::operations::examples
Step 2: Implement the operation in C++
In order to add a new operation, add the following file:
ttnn/cpp/ttnn/operations/<category>/<operation_name>/<operation_name>.hpp
A concrete example:
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "device/example_device_operation.hpp"

namespace ttnn::operations::examples {

// A composite operation is an operation that calls multiple operations in sequence
// It is written using invoke and can be used to call multiple primitive and/or composite operations
struct CompositeExampleOperation {
    // The user will be able to call this method as `Tensor output = ttnn::composite_example(input_tensor)` after the
    // op is registered
    static Tensor invoke(const Tensor& input_tensor) {
        auto copy = prim::example(input_tensor);
        auto another_copy = prim::example(copy);
        return another_copy;
    }
};

} // namespace ttnn::operations::examples

namespace ttnn {
constexpr auto composite_example =
    ttnn::register_operation<"ttnn::composite_example", operations::examples::CompositeExampleOperation>();
} // namespace ttnn
Python Implementation
Step 1: Add Python binding
In order to add a Python binding for the operation, follow the directory structure shown below:
ttnn/python/ttnn/operations/<category>/<operation_name>/<operation_name>_nanobind.hpp
ttnn/python/ttnn/operations/<category>/<category>_nanobind.hpp
A concrete example:
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-nanobind/nanobind_fwd.hpp"

namespace ttnn::operations::examples {
namespace nb = nanobind;
void bind_example_operation(nb::module_& mod);
} // namespace ttnn::operations::examples
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-nanobind/nanobind_fwd.hpp"

namespace ttnn::operations::examples {

namespace nb = nanobind;
void py_module(nb::module_& mod);

} // namespace ttnn::operations::examples
Finally, call the binding function declared in examples/example/example_nanobind.hpp from wherever you want the operation's Python module to be added.
Step 2: (Optional) Add golden function for the operation in Python
A golden function can be added to an operation in order to compare its output with an equivalent torch implementation.
Add the following code in a Python file:
import ttnn

# For the golden function, use the same signature as the operation
# Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s
# And arguments not needed by torch can be ignored using `*args` and `**kwargs`
def golden_function(input_tensor: "torch.Tensor", *args, **kwargs):
    output_tensor: "torch.Tensor" = ...
    return output_tensor

# TT-NN tensors are automatically converted to torch tensors before the golden function is called
# And the outputs are converted back to TT-NN tensors
# But in some cases you may need to preprocess the inputs and postprocess the outputs manually
# In order to preprocess the inputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!
def preprocess_golden_function_inputs(args, kwargs):
    # e.g.
    ttnn_input_tensor = args[0]
    return ttnn.to_torch(ttnn_input_tensor)

# In order to postprocess the outputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!
def postprocess_golden_function_outputs(args, kwargs, outputs):
    # e.g.
    ttnn_input_tensor = args[0]
    torch_output_tensor = outputs[0]
    return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device)

ttnn.attach_golden_function(
    ttnn.example,
    golden_function=golden_function,
    preprocess_golden_function_inputs=preprocess_golden_function_inputs,  # Optional
    postprocess_golden_function_outputs=postprocess_golden_function_outputs,  # Optional
)
Note
ttnn.example is the name of the operation in Python because the operation was registered as ttnn::example in C++.
Step 3: (Optional) Add example usage to docs
It is good practice to include an example demonstrating how to use the new function.
The simplest method is to add an Example section directly in the documentation passed to the bind_registered_operation function. However, this approach makes it difficult to keep the example up to date and prevents the snippet from being tested.
A better approach is to place the example code in a test file and have it included automatically during the documentation build process.
In the file examples_mapping.py, each function is mapped to an example usage snippet that will appear in its documentation.
Add the new operation to the FUNCTION_TO_EXAMPLES_MAPPING_DICT dictionary, as shown below:
FUNCTION_TO_EXAMPLES_MAPPING_DICT = {
    ...
    "ttnn.example": example.test_example,
    ...
}
Place the example usage function in a new file named test_example_examples.py (or an existing file, if appropriate).
Make sure the file is imported at the top of examples_mapping.py:
# ...
from . import test_data_movement_examples as data_movement
from . import test_core_examples as core
# Import the new file
from . import test_example_examples as example
# ...
Implement the example as a standard ttnn pytest:
def test_example(device):
    # Create tensor
    tensor = ttnn.rand((2, 3), ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

    # Call the new operation
    output_tensor = ttnn.example(tensor)
This ensures that all example code snippets are executed and validated in the TT-NN CI pipeline.