Adding New TT-NN Operation

Note

This document is meant for contributors to TT-NN.

Not all operations may be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others).

FAQ

What is a TT-NN operation?

A TT-NN operation is a function that takes in one or more input tensors and produces one or more output tensors. It is implemented in C++ and can be called from Python.
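
For instance, a registered operation such as ttnn::add (used as an example later in this FAQ) is called like an ordinary function. The wrapper below is a minimal, illustrative sketch only; the helper name is made up, and it assumes the relevant TT-NN headers are included and that both tensors already live on a device:

// Hypothetical helper, not part of the repository.
ttnn::Tensor add_tensors(const ttnn::Tensor& a, const ttnn::Tensor& b) {
    // ttnn::add takes two input tensors and returns the output tensor
    return ttnn::add(a, b);
}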

What steps are needed to add a TT-NN operation in C++?

  1. Implement the operation in C++. There are two options (a condensed sketch follows this list; full examples appear later in this document):

     a. Implement a device operation. A device operation is a struct that specifies how to create output tensors and a program to run on the device.

     b. Implement a composite operation. A composite operation simply defines an invoke method that calls other operations.

  2. Register the struct using ttnn::register_operation.
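
To make the two steps concrete before the full walkthrough below, here is a condensed sketch of option (b), a composite operation, together with its registration. The names MyCompositeOperation and ttnn::my_composite are illustrative only, and the sketch assumes ttnn/decorators.hpp plus the headers of the operations being composed are included:

namespace ttnn::operations::my_category {

// Hypothetical composite operation; mirrors the full example later in this document.
struct MyCompositeOperation {
    static Tensor invoke(const Tensor& input_tensor) {
        // Call other, already-registered operations here
        return ttnn::prim::example(input_tensor);
    }
};

}  // namespace ttnn::operations::my_category

// Step 2: register the struct so it can be called as ttnn::my_composite
namespace ttnn {
constexpr auto my_composite =
    ttnn::register_operation<"ttnn::my_composite", operations::my_category::MyCompositeOperation>();
}  // namespace ttnn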

What steps are needed to add a TT-NN operation in Python?

  1. Take an existing registered C++ operation and add a Python binding for it using ttnn::bind_registered_operation. The operation will be auto-registered in Python: if the operation is called ttnn::add in C++, then the Python binding will be ttnn.add. (A condensed sketch of the binding call follows this list.)

  2. (Optional) Attach a golden function to the operation using ttnn.attach_golden_function. This is useful for debugging and testing.
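
A condensed sketch of step 1 (the full example_pybind.hpp listing appears later in this document). The function name is made up, and it assumes the usual TT-NN pybind headers and the header declaring ttnn::composite_example are included:

// Hypothetical binding function; mirrors the full example later in this document.
void bind_my_operation(py::module& module) {
    bind_registered_operation(
        module,
        ttnn::composite_example,
        R"doc(composite_example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",
        // One pybind overload per C++ call signature that should be exposed to Python
        ttnn::pybind_overload_t{
            [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                return self(input_tensor);
            },
            py::arg("input_tensor")});
}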

Example of Adding a New Device Operation

Let’s implement ttnn.example (it will simply copy the input tensor to the output tensor on the device).

C++ Implementation

Step 1: Implement device operation

In order to add a new device operation, follow the directory structure shown below:

ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.hpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.cpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<program_factory_0>_program_factory.cpp

Note

Add as many program factories as needed; the minimum requirement is one program factory.

A concrete example of a device operation can be found in ttnn/cpp/ttnn/operations/examples/example/device

ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>
#include <variant>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/types.hpp"
#include "ttnn/decorators.hpp"

namespace ttnn::operations::examples {

struct ExampleDeviceOperation {
    // Define the operation attributes. This is used to store all variables needed by operations that aren't tensors
    struct operation_attributes_t {
        bool attribute;
        int some_other_attribute;
    };

    // Define the tensor arguments. This is used to store all tensors passed in and/or out of the operation
    // Tensor arguments don't need to be just input tensors, they can be output tensors, input/output tensors, optional
    // tensors, etc.
    struct tensor_args_t {
        // This example will use a tensor that can only be used as an input
        const Tensor& input_tensor;

        // However, the following examples show what else can be done with tensor_args_t

        // An example of the tensor that can be used for input/output or just for pre-allocated output
        // Tensor& io_tensor;

        // An example of an optional tensor
        // std::optional<Tensor> optional_output_tensor;

        // An example of a vector of tensors
        // std::vector<Tensor> vector_of_tensors;

        // An example of a tuple of tensors
        // std::tuple<Tensor, ...> tuple_of_tensors;

        // An example of a vector of optional tensors
        // std::vector<std::optional<Tensor>> vector_of_optional_tensors;

        // An example of a more complex combination of tensors
        // std::tuple<std::vector<std::optional<Tensor>>, std::optional<Tensor>> some_crazy_tuple_of_tensors;
    };

    // Define the return types for the spec(s) of the operation
    // Can be a single ttnn::TensorSpec, std::optional<ttnn::TensorSpec>, std::vector<ttnn::TensorSpec>,
    // std::tuple<ttnn::TensorSpec> etc.
    using spec_return_value_t = ttnn::TensorSpec;

    // Define the return types for the tensor(s) of the operation
    // Can be a single Tensor, std::optional<Tensor>, std::vector<Tensor>, std::tuple<Tensor, ...> etc.
    using tensor_return_value_t = Tensor;

    // Note spec_return_value_t and tensor_return_value_t should follow the same pattern
    // i.e. if spec_return_value_t is a std::vector<std::optional<ttnn::TensorSpec>> then tensor_return_value_t should
    // be std::vector<std::optional<Tensor>>

    struct SingleCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    struct MultiCore {
        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
        struct shared_variables_t {
            tt::tt_metal::KernelHandle unary_reader_kernel_id;
            tt::tt_metal::KernelHandle unary_writer_kernel_id;
            std::size_t num_cores;
            std::size_t num_cores_y;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& tensor_return_value);
    };

    using program_factory_t = std::variant<SingleCore, MultiCore>;

    // Mandatory methods

    // Select the program factory based on the operation attributes and tensor args
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it creates a program. Usually will have more checks
    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

    // Validate the operation when it reuses a program. Usually will have fewer checks
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);

    // Compute the output specs based on the operation attributes and tensor args
    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

    // Create the output tensors based on the operation attributes and tensor args
    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

    // API call to map user arguments to operation attributes and tensor args.
    // This is the only method that is called by the user.
    // The user will be able to call the operation using
    // `tensor_return_value_t output = ttnn::prim::example(input_tensor)` after the op is registered.
    // Keep in mind that the overload with a `queue_id` argument will be added automatically for primitive operations,
    // so the user can also call this operation using
    // `tensor_return_value_t output = ttnn::prim::example(queue_id, input_tensor)`
    static std::tuple<operation_attributes_t, tensor_args_t> invoke(const Tensor& input_tensor);

    // Optional methods

    // In case the operation needs a custom hash function, the following method can be implemented
    /* static tt::stl::hash::hash_t compute_program_hash(
        const operation_attributes_t&, const tensor_args_t&);
    */

    // In case the operation needs a custom create_op_performance_model, this method can be implemented
    /*
    static tt::tt_metal::operation::OpPerformanceModel create_op_performance_model(
        const operation_attributes_t&,
        const tensor_args_t&,
        tensor_return_value_t&);
    */
};

}  // namespace ttnn::operations::examples

// Register the operation with the ttnn::register_operation API to make it available to the user as ttnn::prim::example
namespace ttnn::prim {
constexpr auto example =
    ttnn::register_operation<"ttnn::prim::example", ttnn::operations::examples::ExampleDeviceOperation>();
}  // namespace ttnn::prim
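
For reference, once the struct above is registered, the primitive can be invoked directly from C++. The wrapper below is a minimal, illustrative sketch only: the helper name is made up, the include path is assumed, and `input_tensor` is expected to already reside on a device in tile layout.

// Assumed include path for the header shown above.
#include "ttnn/operations/examples/example/device/example_device_operation.hpp"

// Hypothetical helper, not part of the repository.
ttnn::Tensor run_prim_example(const ttnn::Tensor& input_tensor) {
    // Calls ExampleDeviceOperation::invoke under the hood; an overload taking a queue id
    // as the first argument is generated automatically for primitive operations.
    return ttnn::prim::example(input_tensor);
}
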
ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"

namespace ttnn::operations::examples {

ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    bool some_condition_based_on_operation_attributes_and_or_tensor_args = true;
    if (some_condition_based_on_operation_attributes_and_or_tensor_args) {
        return SingleCore{};
    }
    return MultiCore{};
}

void ExampleDeviceOperation::validate_on_program_cache_miss(
    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}

void ExampleDeviceOperation::validate_on_program_cache_hit(
    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}

ExampleDeviceOperation::spec_return_value_t ExampleDeviceOperation::compute_output_specs(
    const operation_attributes_t&, const tensor_args_t& tensor_args) {
    const auto& input_tensor = tensor_args.input_tensor;
    return TensorSpec(
        input_tensor.logical_shape(),
        tt::tt_metal::TensorLayout(
            input_tensor.dtype(), tt::tt_metal::PageConfig(input_tensor.layout()), MemoryConfig{}));
}

ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
    return create_device_tensor(output_spec, tensor_args.input_tensor.device());
}

std::tuple<ExampleDeviceOperation::operation_attributes_t, ExampleDeviceOperation::tensor_args_t>
ExampleDeviceOperation::invoke(const Tensor& input_tensor) {
    return {operation_attributes_t{true, 42}, tensor_args_t{input_tensor}};
}

}  // namespace ttnn::operations::examples
ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create(
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    CoreCoord compute_with_storage_grid_size = {1, 1};
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
}

void ExampleDeviceOperation::SingleCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = src_buffer->address();
    }

    {
        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
        runtime_args[0] = dst_buffer->address();
    }
}

}  // namespace ttnn::operations::examples
ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "example_device_operation.hpp"
#include <tt-metalium/work_split.hpp>

namespace ttnn::operations::examples {
ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create(
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    using namespace tt;
    using namespace tt::tt_metal;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    tt::tt_metal::Program program{};

    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);

    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;

    tt::tt_metal::IDevice* device = input_tensor.device();

    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);

    uint32_t src0_cb_index = tt::CBIndex::c_0;
    uint32_t num_input_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, single_tile_size);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    uint32_t output_cb_index = tt::CBIndex::c_2;
    uint32_t num_output_tiles = 2;
    tt::tt_metal::CircularBufferConfig cb_output_config =
        tt::tt_metal::CircularBufferConfig(
            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
            .set_page_size(output_cb_index, single_tile_size_output);
    tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};

    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
        all_cores,
        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));

    std::vector<uint32_t> compute_kernel_args_group_1 = {
        num_tiles_per_core_group_1,  // per_core_block_cnt
        1                            // per_core_block_size
    };

    bool math_approx_mode = false;
    tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
        core_group_1,
        tt::tt_metal::ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1});

    if (!core_group_2.ranges().empty()) {
        std::vector<uint32_t> compute_kernel_args_group_2 = {
            num_tiles_per_core_group_2,  // per_core_block_cnt
            1                            // per_core_block_size
        };

        tt::tt_metal::CreateKernel(
            program,
            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
            core_group_2,
            tt::tt_metal::ComputeConfig{
                .math_fidelity = MathFidelity::HiFi4,
                .math_approx_mode = math_approx_mode,
                .compile_args = compute_kernel_args_group_2});
    }

    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};
        uint32_t num_tiles_per_core = 0;
        if (core_group_1.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_1;
        } else if (core_group_2.contains(core)) {
            num_tiles_per_core = num_tiles_per_core_group_2;
        } else {
            TT_ASSERT(false, "Core not in specified core ranges");
        }

        tt::tt_metal::SetRuntimeArgs(
            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

        tt::tt_metal::SetRuntimeArgs(
            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
        num_tiles_written += num_tiles_per_core;
    }

    return {
        std::move(program),
        {.unary_reader_kernel_id = unary_reader_kernel_id,
         .unary_writer_kernel_id = unary_writer_kernel_id,
         .num_cores = num_cores,
         .num_cores_y = num_cores_y}};
}

void ExampleDeviceOperation::MultiCore::override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    auto& program = cached_program.program;
    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;
    auto& num_cores = cached_program.shared_variables.num_cores;
    auto& num_cores_y = cached_program.shared_variables.num_cores_y;

    const auto& input_tensor = tensor_args.input_tensor;
    auto& output_tensor = tensor_return_value;

    auto src_buffer = input_tensor.buffer();
    auto dst_buffer = output_tensor.buffer();

    for (uint32_t i = 0; i < num_cores; i++) {
        CoreCoord core = {i / num_cores_y, i % num_cores_y};

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
            runtime_args[0] = src_buffer->address();
        }

        {
            auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
            runtime_args[0] = dst_buffer->address();
        }
    }
}

}  // namespace ttnn::operations::examples

Step 2: Implement the operation in C++

In order to add a new operation, add the following file:

ttnn/cpp/ttnn/operations/<category>/<operation_name>/<operation_name>.hpp

A concrete example:

ttnn/cpp/ttnn/operations/examples/example/example.hpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "device/example_device_operation.hpp"

namespace ttnn::operations::examples {

// A composite operation is an operation that calls multiple operations in sequence
// It is written using invoke and can be used to call multiple primitive and/or composite operations
struct CompositeExampleOperation {
    // The user will be able to call this method as `Tensor output = ttnn::composite_example(input_tensor)` after the op
    // is registered
    static Tensor invoke(const Tensor& input_tensor) {
        auto copy = prim::example(input_tensor);
        auto another_copy = prim::example(copy);
        return another_copy;
    }
};

}  // namespace ttnn::operations::examples

namespace ttnn {
constexpr auto composite_example =
    ttnn::register_operation<"ttnn::composite_example", operations::examples::CompositeExampleOperation>();
}  // namespace ttnn
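
Once registered, the composite operation is an ordinary C++ callable. The snippet below is a minimal, illustrative sketch only: the helper name is made up, the include path is assumed, and `input_tensor` is expected to already reside on a device.

// Assumed include path for the header shown above.
#include "ttnn/operations/examples/example/example.hpp"

// Hypothetical helper, not part of the repository.
ttnn::Tensor run_composite_example(const ttnn::Tensor& input_tensor) {
    // Dispatches two ttnn::prim::example calls, as defined in CompositeExampleOperation::invoke above.
    return ttnn::composite_example(input_tensor);
}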

Python Implementation

Step 1: Add Python binding

In order to add a Python binding for the operation, follow the directory structure shown below:

ttnn/cpp/ttnn/operations/<category>/<operation_name>/<operation_name>_pybind.hpp
ttnn/cpp/ttnn/operations/<category>/<category>_pybind.hpp

A concrete example:

ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-pybind/pybind_fwd.hpp"

namespace ttnn::operations::examples {
namespace py = pybind11;

void bind_example_operation(py::module& module);

}  // namespace ttnn::operations::examples

/*
void bind_example_operation(py::module& module) {
    bind_registered_operation(
        module,
        ttnn::prim::example,
        R"doc(example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",

        // Add pybind overloads for the C++ APIs that should be exposed to python
        // There should be no logic here, just a call to `self` with the correct arguments
        // The overload with `queue_id` argument will be added automatically for primitive operations
        // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or
        // `ttnn.prim.example(input_tensor, queue_id=queue_id)`
        ttnn::pybind_overload_t{
            [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                return self(input_tensor);
            },
            py::arg("input_tensor")});

    bind_registered_operation(
        module,
        ttnn::composite_example,
        R"doc(composite_example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",

        // Add pybind overloads for the C++ APIs that should be exposed to python
        // There should be no logic here, just a call to `self` with the correct arguments
        ttnn::pybind_overload_t{
            [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                return self(input_tensor);
            },
            py::arg("input_tensor")});
}
*/
ttnn/cpp/ttnn/operations/examples/examples_pybind.hpp
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn-pybind/pybind_fwd.hpp"

namespace ttnn::operations::examples {
namespace py = pybind11;

void py_module(py::module& module);

}  // namespace ttnn::operations::examples

Finally, call the binding function declared in examples/example/example_pybind.hpp from wherever you want the operation to be added to the Python module; one possible way to wire this up is sketched below.
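
For illustration, the examples category could be wired up roughly as follows. This is a hedged sketch, not the exact code in the repository; the file name and the parent-module hook are assumptions.

// e.g. ttnn/cpp/ttnn/operations/examples/examples_pybind.cpp (hypothetical file)
#include "examples_pybind.hpp"

#include "example/example_pybind.hpp"

namespace ttnn::operations::examples {

void py_module(py::module& module) {
    // Bind every operation that belongs to this category
    bind_example_operation(module);
}

}  // namespace ttnn::operations::examples

// ...and from the top-level pybind setup (assumed call site, requires the full pybind11 headers):
// auto m_examples = module.def_submodule("examples", "examples operations");
// ttnn::operations::examples::py_module(m_examples);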

Step 2: (Optional) Add golden function for the operation in Python

A golden function can be added to an operation in order to compare its output with an equivalent torch implementation.

Add the following code in a Python file:

import ttnn

# For the golden function, use the same signature as the operation
# Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s
# And arguments not needed by torch can be ignored using `*args` and `**kwargs`
def golden_function(input_tensor: "torch.Tensor", *args, **kwargs):
    output_tensor:  "torch.Tensor" = ...
    return output_tensor

# TT-NN Tensors are converted to torch tensors before calling the golden function automatically
# And the outputs are converted back to TT-NN Tensors
# But in some cases you may need to preprocess the inputs and postprocess the outputs manually

# In order to preprocess the inputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def preprocess_golden_function_inputs(args, kwargs):
    # i.e.
    ttnn_input_tensor = args[0]
    return ttnn.to_torch(ttnn_input_tensor)

# In order to postprocess the outputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def postprocess_golden_function_outputs(args, kwargs, output):
    # i.e.
    ttnn_input_tensor = args[0]
    # `output` is whatever the golden function returned (a single torch.Tensor in this example)
    torch_output_tensor = output
    return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device)

ttnn.attach_golden_function(
    ttnn.example,
    golden_function=golden_function,
    preprocess_golden_function_inputs=preprocess_golden_function_inputs, # Optional
    postprocess_golden_function_outputs=postprocess_golden_function_outputs # Optional
)

Note

ttnn.example is the name of the operation in Python because the operation was registered as ttnn::example in C++.