Adding New TT-NN Operation
Note
This document is meant for contributors to TT-NN.
Not all operations may be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others).
FAQ
What is a TT-NN operation?
A TT-NN operation is a function that takes in one or more input tensors and produces one or more output tensors. It is implemented in C++ and can be called from Python.
What steps are needed to add a TT-NN operation in C++?
There are two options for writing a new operation:

a. Implement a device operation in C++. A device operation is a struct that specifies how to create output tensors and a program to run on the device.
b. Implement a composite operation in C++. A composite operation simply defines an operator() method that calls other operations.

In either case, register the struct using ttnn::register_operation.
What steps are needed to add a TT-NN operation in Python?
Take an existing registered C++ operation and add a Python binding for it using ttnn::bind_registered_operation. The operation will be auto-registered in Python. If the operation is called ttnn::add in C++, then the Python binding will be ttnn.add (see the sketch after this list).
(Optional) Attach a golden function to the operation using ttnn.attach_golden_function. This is useful for debugging and testing.
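For example, assuming an operation was registered as ttnn::add in C++, the resulting Python binding can be exercised as follows (a minimal sketch; shapes and dtypes are illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Move two host tensors to the device and call the bound operation
a = ttnn.from_torch(torch.rand(32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand(32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.add(a, b)  # dispatches to the registered C++ operation

ttnn.close_device(device)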
Example of Adding a New Device Operation
Let’s implement ttnn.example (it will simply copy the input tensor to the output tensor on the device).
C++ Implementation
Step 1: Implement device operation
In order to add a new device operation, follow the directory structure shown below:
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.hpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.cpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<program_factory_0>_program_factory.cpp
Note
Add as many program factories as needed; at minimum, one program factory is required.
A concrete example of a device operation can be found in ttnn/cpp/ttnn/operations/examples/example/device
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>
#include <variant>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/types.hpp"
#include "ttnn/decorators.hpp"

namespace ttnn::operations::examples {

struct ExampleDeviceOperation {
    // Define the operation attributes. This is used to store all variables needed by the operation that aren't tensors
    struct operation_attributes_t {
        bool attribute;
        int some_other_attribute;
    };

    // Define the tensor arguments. This is used to store all tensors passed into and/or out of the operation
    // Tensor arguments don't need to be just input tensors; they can be output tensors, input/output tensors, optional
    // tensors, etc.
    struct tensor_args_t {
        // This example will use a tensor that can only be used as an input
        const Tensor& input_tensor;

        // However, the following examples show what else can be done with tensor_args_t

        // An example of a tensor that can be used for input/output or just for pre-allocated output
        // Tensor& io_tensor;

        // An example of an optional tensor
        // std::optional<Tensor> optional_output_tensor;

        // An example of a vector of tensors
        // std::vector<Tensor> vector_of_tensors;

        // An example of a tuple of tensors
        // std::tuple<Tensor, ...> tuple_of_tensors;

        // An example of a vector of optional tensors
        // std::vector<std::optional<Tensor>> vector_of_optional_tensors;

        // An example of a more complex tuple of tensors
        // std::tuple<std::vector<std::optional<Tensor>>, std::optional<Tensor>> some_crazy_tuple_of_tensors;
    };

    // Define the return types for the spec(s) of the operation
    // Can be a single ttnn::TensorSpec, std::optional<ttnn::TensorSpec>, std::vector<ttnn::TensorSpec>,
    // std::tuple<ttnn::TensorSpec> etc.
    using spec_return_value_t = ttnn::TensorSpec;

    // Define the return types for the tensor(s) of the operation
    // Can be a single Tensor, std::optional<Tensor>, std::vector<Tensor>, std::tuple<Tensor, ...> etc.
    using tensor_return_value_t = Tensor;

    // Note spec_return_value_t and tensor_return_value_t should follow the same pattern
    // i.e. if spec_return_value_t is a std::vector<std::optional<ttnn::TensorSpec>> then tensor_return_value_t should
    // be std::vector<std::optional<Tensor>>

    struct SingleCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    struct MultiCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
            std::size_t num_cores;
            std::size_t num_cores_y;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    using program_factory_t = std::variant<SingleCore, MultiCore>;

    // Mandatory methods

    // Select the program factory based on the operation attributes and tensor args
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it creates a program. Usually will have more checks
    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it reuses a program. Usually will have fewer checks
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);

    // Compute the output specs based on the operation attributes and tensor args
    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

    // Create the output tensors based on the operation attributes and tensor args
    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

    // API call to map user arguments to operation attributes and tensor args.
    // This is the only method that is called by the user.
    // The user will be able to call the operation using
    // `tensor_return_value_t output = ttnn::prim::example(input_tensor)` after the op is registered.
    // Keep in mind that the overload with a `queue_id` argument will be added automatically for primitive
    // operations, so the user can also call this operation using
    // `tensor_return_value_t output = ttnn::prim::example(queue_id, input_tensor)`
    static std::tuple<operation_attributes_t, tensor_args_t> invoke(const Tensor& input_tensor);

    // Optional methods

    // In case the operation needs a custom hash function, the following method can be implemented
    /* static tt::stl::hash::hash_t compute_program_hash(
        const operation_attributes_t&, const tensor_args_t&);
    */

    // In case the operation needs a custom create_op_performance_model, this method can be implemented
    /*
    static tt::tt_metal::operation::OpPerformanceModel create_op_performance_model(
        const operation_attributes_t&,
        const tensor_args_t&,
        tensor_return_value_t&);
    */
};

} // namespace ttnn::operations::examples

// Register the operation with the ttnn::register_operation API to make it available to the user as ttnn::prim::example
namespace ttnn::prim {
constexpr auto example =
    ttnn::register_operation<"ttnn::prim::example", ttnn::operations::examples::ExampleDeviceOperation>();
} // namespace ttnn::prim
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"

namespace ttnn::operations::examples {

ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    bool some_condition_based_on_operation_attributes_and_or_tensor_args = true;
    if (some_condition_based_on_operation_attributes_and_or_tensor_args) {
        return SingleCore{};
    }
    return MultiCore{};
}

void ExampleDeviceOperation::validate_on_program_cache_miss(
    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}

void ExampleDeviceOperation::validate_on_program_cache_hit(
    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}

ExampleDeviceOperation::spec_return_value_t ExampleDeviceOperation::compute_output_specs(
    const operation_attributes_t&, const tensor_args_t& tensor_args) {
    const auto& input_tensor = tensor_args.input_tensor;
    return TensorSpec(
        input_tensor.logical_shape(),
        tt::tt_metal::TensorLayout(
            input_tensor.dtype(), tt::tt_metal::PageConfig(input_tensor.layout()), MemoryConfig{}));
}

ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.input_tensor.device());
}

std::tuple<ExampleDeviceOperation::operation_attributes_t, ExampleDeviceOperation::tensor_args_t>
ExampleDeviceOperation::invoke(const Tensor& input_tensor) {
    return {operation_attributes_t{true, 42}, tensor_args_t{input_tensor}};
}

} // namespace ttnn::operations::examples
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create(
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    tt::tt_metal::IDevice* device = input_tensor.device();

    CoreCoord compute_with_storage_grid_size = {1, 1};
    uint32_t num_cores_x = compute_with_storage_grid_size.x;
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    auto eltwise_unary_kernel_group_1_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        auto eltwise_unary_kernel_group_2_id = tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
}

void ExampleDeviceOperation::SingleCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = src_buffer->address();
    }

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = dst_buffer->address();
    }
}

} // namespace ttnn::operations::examples
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create(
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    tt::tt_metal::IDevice* device = input_tensor.device();

    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    uint32_t num_cores_x = compute_with_storage_grid_size.x;
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    auto eltwise_unary_kernel_group_1_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        auto eltwise_unary_kernel_group_2_id = tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id,
         .unary_writer_kernel_id = unary_writer_kernel_id,
         .num_cores = num_cores,
         .num_cores_y = num_cores_y}};
}

void ExampleDeviceOperation::MultiCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;
    auto& num_cores = cached_program.shared_variables.num_cores;
    auto& num_cores_y = cached_program.shared_variables.num_cores_y;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
            runtime_args[0] = src_buffer->address();
        }

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
            runtime_args[0] = dst_buffer->address();
        }
    }
}

} // namespace ttnn::operations::examples
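At this point the primitive is fully registered, so (as the header comments note) it can be invoked directly from C++. A minimal usage sketch; the include path and the run_example function are illustrative, not part of the repository:

// Usage sketch: include path below is an assumption about the install layout
#include "ttnn/operations/examples/example/device/example_device_operation.hpp"

void run_example(const ttnn::Tensor& input_tensor) {
    // Default-queue overload generated by ttnn::register_operation
    ttnn::Tensor output = ttnn::prim::example(input_tensor);

    // A `queue_id` overload is also generated automatically for primitive operations:
    // ttnn::Tensor output2 = ttnn::prim::example(queue_id, input_tensor);
}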
Step 2: Implement the operation in C++
In order to add a new operation, add the following file:
ttnn/cpp/ttnn/operations/<category>/<operation_name>/<operation_name>.hpp
A concrete example:
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "device/example_device_operation.hpp"

namespace ttnn::operations::examples {

// A composite operation is an operation that calls multiple operations in sequence
// It is written using invoke and can be used to call multiple primitive and/or composite operations
struct CompositeExampleOperation {
    // The user will be able to call this method as `Tensor output = ttnn::composite_example(input_tensor)` after the
    // op is registered
    static Tensor invoke(const Tensor& input_tensor) {
        auto copy = prim::example(input_tensor);
        auto another_copy = prim::example(copy);
        return another_copy;
    }
};

} // namespace ttnn::operations::examples

namespace ttnn {
constexpr auto composite_example =
    ttnn::register_operation<"ttnn::composite_example", operations::examples::CompositeExampleOperation>();
} // namespace ttnn
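Once registered, the composite operation is callable like any other TT-NN operation. A minimal usage sketch mirroring the comment in the struct above; the include path and the run_composite_example function are illustrative:

// Usage sketch: include path below is an assumption about the install layout
#include "ttnn/operations/examples/example/example.hpp"

ttnn::Tensor run_composite_example(const ttnn::Tensor& input_tensor) {
    // Copies the input twice via ttnn::prim::example under the hood
    return ttnn::composite_example(input_tensor);
}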
Python Implementation
Step 1: Add Python binding
In order to add a Python binding for the operation, follow the directory structure shown below:
ttnn/python/ttnn/operations/<category>/<operation_name>/<operation_name>_pybind.hpp
ttnn/python/ttnn/operations/<category>/<category>_pybind.hpp
A concrete example:
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-pybind/pybind_fwd.hpp"

namespace ttnn::operations::examples {
namespace py = pybind11;

void bind_example_operation(py::module& module);

} // namespace ttnn::operations::examples

/*
void bind_example_operation(py::module& module) {
    bind_registered_operation(
        module,
        ttnn::prim::example,
        R"doc(example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",

        // Add pybind overloads for the C++ APIs that should be exposed to python
        // There should be no logic here, just a call to `self` with the correct arguments
        // The overload with `queue_id` argument will be added automatically for primitive operations
        // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or
        // `ttnn.prim.example(input_tensor, queue_id=queue_id)`
        ttnn::pybind_overload_t{
            [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                return self(input_tensor);
            },
            py::arg("input_tensor")});

    bind_registered_operation(
        module,
        ttnn::composite_example,
        R"doc(composite_example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",

        // Add pybind overloads for the C++ APIs that should be exposed to python
        // There should be no logic here, just a call to `self` with the correct arguments
        ttnn::pybind_overload_t{
            [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                return self(input_tensor);
            },
            py::arg("input_tensor")});
}
*/
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-pybind/pybind_fwd.hpp"

namespace ttnn::operations::examples {
namespace py = pybind11;

void py_module(py::module& module);

} // namespace ttnn::operations::examples
Finally, call the binding function declared in examples/example/example_pybind.hpp wherever you want the operation's Python bindings to be added.
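For instance, py_module might simply forward to bind_example_operation. A hypothetical sketch; the exact file layout and include names may differ in the repository:

// Hypothetical implementation of the category's py_module
#include "example_pybind.hpp"

namespace ttnn::operations::examples {

void py_module(py::module& module) {
    // Registers ttnn.prim.example and ttnn.composite_example on this module
    bind_example_operation(module);
}

} // namespace ttnn::operations::examples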
Step 2: (Optional) Add golden function for the operation in Python
A golden function can be added to an operation in order to compare its output with an equivalent torch implementation. Add the following code in a Python file:
import ttnn

# For the golden function, use the same signature as the operation
# Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s
# And arguments not needed by torch can be ignored using `*args` and `**kwargs`
def golden_function(input_tensor: "torch.Tensor", *args, **kwargs):
    output_tensor: "torch.Tensor" = ...
    return output_tensor

# TT-NN Tensors are converted to torch tensors before calling the golden function automatically
# And the outputs are converted back to TT-NN Tensors
# But in some cases you may need to preprocess the inputs and postprocess the outputs manually

# In order to preprocess the inputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!
def preprocess_golden_function_inputs(args, kwargs):
    # i.e.
    ttnn_input_tensor = args[0]
    return ttnn.to_torch(ttnn_input_tensor)

# In order to postprocess the outputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!
def postprocess_golden_function_outputs(args, kwargs, outputs):
    # i.e.
    ttnn_input_tensor = args[0]
    torch_output_tensor = outputs[0]
    return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device)

ttnn.attach_golden_function(
    ttnn.example,
    golden_function=golden_function,
    preprocess_golden_function_inputs=preprocess_golden_function_inputs,  # Optional
    postprocess_golden_function_outputs=postprocess_golden_function_outputs,  # Optional
)
Note
ttnn.example is the name of the operation in Python because the operation was registered as ttnn::example in C++.
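Once attached, the golden function can be used to cross-check the device output against torch. A hedged sketch, assuming ttnn.get_golden_function is available (as used in TT-NN tests) and that ttnn.example copies its input; shapes are illustrative:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.rand(32, 32)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Run the device operation and the attached golden (torch) reference
output_tensor = ttnn.example(input_tensor)
golden_output = ttnn.get_golden_function(ttnn.example)(torch_input)

assert torch.allclose(ttnn.to_torch(output_tensor), golden_output)
ttnn.close_device(device)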