Adding New TT-NN Operation

Note

This document is meant for contributors to TT-NN.

Not all operations may be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others).

FAQ

What is a TT-NN operation?

A TT-NN operation is a function that takes in one or more input tensors and produces one or more output tensors. It is implemented in C++ and can be called from Python.

What steps are needed to add a TT-NN operation in C++?

  1. There are two options for writing a new operation: option (a) is to write a device operation and option (b) is to write a composite operation.

     a. Implement a device operation in C++. A device operation is a struct that specifies how to create output tensors and a program to run on the device.

     b. Implement a composite operation in C++. A composite operation simply defines an operator() method that calls other operations.

  2. Register the struct using ttnn::register_operation.
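
For reference, registration for the example device operation covered later in this document looks like this (repeated here from the full example below):

namespace ttnn::prim {
constexpr auto example =
    ttnn::register_operation<"ttnn::prim::example", ttnn::operations::examples::ExampleDeviceOperation>();
}  // namespace ttnn::prim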

What steps are needed to add a TT-NN operation in Python?

  1. Take an existing registered C++ operation and add a Python binding for it using ttnn::bind_registered_operation. The operation will be automatically registered in Python. If the operation is called ttnn::add in C++, then the Python binding will be ttnn.add.

  2. (Optional) Attach a golden function to the operation using ttnn.attach_golden_function. This is useful for debugging and testing.

Example of Adding a new Device Operation

Let’s implement ttnn.example (it will simply copy the input tensor to the output tensor on the device).

C++ Implementation

Step 1: Implement device operation

In order to add a new device operation, follow the directory structure shown below:

ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.hpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<operation_name>_device_operation.cpp
ttnn/cpp/ttnn/operations/<category>/<operation_name>/device/<program_factory_0>_program_factory.cpp

Note

Add as many program factories as needed; at least one program factory is required.

A concrete example of a device operation can be found in ttnn/cpp/ttnn/operations/examples/example/device

ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.hpp
  1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5#pragma once
  6
  7#include <optional>
  8#include <variant>
  9
 10#include "ttnn/tensor/tensor.hpp"
 11#include "ttnn/core.hpp"
 12#include "ttnn/device_operation.hpp"
 13#include "ttnn/types.hpp"
 14#include "ttnn/decorators.hpp"
 15
 16namespace ttnn::operations::examples {
 17
 18struct ExampleDeviceOperation {
 19    // Define the operation attributes. This is used to store all variables needed by the operation that aren't tensors
 20    struct operation_attributes_t {
 21        bool attribute;
 22        int some_other_attribute;
 23    };
 24
 25    // Define the tensor arguments. This is used to store all tensors passed into and/or out of the operation
 26    // Tensor arguments don't need to be just input tensors, they can be output tensors, input/output tensors, optional
 27    // tensors, etc.
 28    struct tensor_args_t {
 29        // This example will use a tensor that can only be used as an input
 30        const Tensor& input_tensor;
 31
 32        // However, the following examples show what else can be done with tensor_args_t
 33
 34        // An example of the tensor that can be used for input/output or just for pre-allocated output
 35        // Tensor& io_tensor;
 36
 37        // An example of an optional tensor
 38        // std::optional<Tensor> optional_output_tensor;
 39
 40        // An example of a vector of tensors
 41        // std::vector<Tensor> vector_of_tensors;
 42
 43        // An example of a tuple of tensors
 44        // std::tuple<Tensor, ...> tuple_of_tensors;
 45
 46        // An example of a vector of optional tensors
 47        // std::vector<std::optional<Tensor>> vector_of_optional_tensors;
 48
 49        // An example of a more complex nested structure of tensors
 50        // std::tuple<std::vector<std::optional<Tensor>>, std::optional<Tensor>> some_crazy_tuple_of_tensors;
 51    };
 52
 53    // Define the return types for the spec(s) of the operation
 54    // Can be a single ttnn::TensorSpec, std::optional<ttnn::TensorSpec>, std::vector<ttnn::TensorSpec>,
 55    // std::tuple<ttnn::TensorSpec> etc.
 56    using spec_return_value_t = ttnn::TensorSpec;
 57
 58    // Define the return types for the tensor(s) of the operation
 59    // Can be a single Tensor, std::optional<Tensor>, std::vector<Tensor>, std::tuple<Tensor, ...> etc.
 60    using tensor_return_value_t = Tensor;
 61
 62    // Note spec_return_value_t and tensor_return_value_t should follow the same pattern
 63    // i.e. if spec_return_value_t is a std::vector<std::optional<ttnn::TensorSpec>> then tensor_return_value_t should
 64    // be std::vector<std::optional<Tensor>>
 65
 66    struct SingleCore {
 67        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
 68        struct shared_variables_t {
 69            tt::tt_metal::KernelHandle unary_reader_kernel_id;
 70            tt::tt_metal::KernelHandle unary_writer_kernel_id;
 71        };
 72        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
 73
 74        static cached_program_t create(
 75            const operation_attributes_t& operation_attributes,
 76            const tensor_args_t& tensor_args,
 77            tensor_return_value_t& tensor_return_value);
 78
 79        static void override_runtime_arguments(
 80            cached_program_t& cached_program,
 81            const operation_attributes_t& operation_attributes,
 82            const tensor_args_t& tensor_args,
 83            tensor_return_value_t& tensor_return_value);
 84    };
 85
 86    struct MultiCore {
 87        // Shared variables are the variables that are shared between the create and override_runtime_arguments methods
 88        struct shared_variables_t {
 89            tt::tt_metal::KernelHandle unary_reader_kernel_id;
 90            tt::tt_metal::KernelHandle unary_writer_kernel_id;
 91            std::size_t num_cores;
 92            std::size_t num_cores_y;
 93        };
 94        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
 95
 96        static cached_program_t create(
 97            const operation_attributes_t& operation_attributes,
 98            const tensor_args_t& tensor_args,
 99            tensor_return_value_t& tensor_return_value);
100
101        static void override_runtime_arguments(
102            cached_program_t& cached_program,
103            const operation_attributes_t& operation_attributes,
104            const tensor_args_t& tensor_args,
105            tensor_return_value_t& tensor_return_value);
106    };
107
108    using program_factory_t = std::variant<SingleCore, MultiCore>;
109
110    // Mandatory methods
111
112    // Select the program factory based on the operation attributes and tensor args
113    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
114
115    // Validate the operation when it creates a program. Usually will have more checks
116    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
117
118    // Validate the operation when it reuses a program. Usually will have fewer checks
119    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
120
121    // Compute the output specs based on the operation attributes and tensor args
122    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
123
124    // Create the output tensors based on the operation attributes and tensor args
125    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
126
127    // API call to map user arguments to operation attributes and tensor args.
128    // This is the only method that is called by the user
 129    // The user will be able to call the operation using `tensor_return_value_t output =
 130    // ttnn::prim::example(input_tensor)` after the op is registered. Keep in mind that the overload with `queue_id`
 131    // argument will be added automatically for primitive operations. So, the user can also call this operation using
 132    // `tensor_return_value_t output = ttnn::prim::example(queue_id, input_tensor)`
133    static std::tuple<operation_attributes_t, tensor_args_t> invoke(const Tensor& input_tensor);
134
135    // Optional methods
136
 137    // In case the operation needs a custom hash function, the following method can be implemented
138    /* static tt::stl::hash::hash_t compute_program_hash(
139        const operation_attributes_t&, const tensor_args_t&);
140    */
141
142    // In case the operation needs a custom create_op_performance_model, this method can be implemented
143    /*
 144    static tt::tt_metal::operation::OpPerformanceModel create_op_performance_model(
145        const operation_attributes_t&,
146        const tensor_args_t&,
147        tensor_return_value_t&);
148    */
149};
150
151}  // namespace ttnn::operations::examples
152
153// Register the operation with the ttnn::register_operation API to make it available to the user as ttnn::prim::example
154namespace ttnn::prim {
155constexpr auto example =
156    ttnn::register_operation<"ttnn::prim::example", ttnn::operations::examples::ExampleDeviceOperation>();
157}  // namespace ttnn::prim
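
Once registered, the primitive can be called like an ordinary C++ function. Below is a minimal usage sketch; the wrapper function run_example_copy and the include path are assumptions made for illustration and are not part of the repository example:

#include "ttnn/operations/examples/example/device/example_device_operation.hpp"

// Hypothetical helper (illustration only): runs the example primitive on a device tensor.
// ttnn::prim::example maps the arguments through ExampleDeviceOperation::invoke and then runs
// the selected program factory on the tensor's device.
ttnn::Tensor run_example_copy(const ttnn::Tensor& input_tensor) {
    ttnn::Tensor output = ttnn::prim::example(input_tensor);
    // An overload taking a queue id is generated automatically for primitive operations:
    // ttnn::Tensor output2 = ttnn::prim::example(queue_id, input_tensor);
    return output;
}
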
ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp
 1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
 2//
 3// SPDX-License-Identifier: Apache-2.0
 4
 5#include "example_device_operation.hpp"
 6
 7namespace ttnn::operations::examples {
 8
 9ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory(
10    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
11    bool some_condition_based_on_operation_attributes_and_or_tensor_args = true;
12    if (some_condition_based_on_operation_attributes_and_or_tensor_args) {
13        return SingleCore{};
14    }
15    return MultiCore{};
16}
17
18void ExampleDeviceOperation::validate_on_program_cache_miss(
19    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}
20
21void ExampleDeviceOperation::validate_on_program_cache_hit(
22    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}
23
24ExampleDeviceOperation::spec_return_value_t ExampleDeviceOperation::compute_output_specs(
25    const operation_attributes_t&, const tensor_args_t& tensor_args) {
26    const auto& input_tensor = tensor_args.input_tensor;
27    return TensorSpec(
28        input_tensor.logical_shape(),
29        tt::tt_metal::TensorLayout(
30            input_tensor.dtype(), tt::tt_metal::PageConfig(input_tensor.layout()), MemoryConfig{}));
31}
32
33ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors(
34    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
35    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
36    return create_device_tensor(output_spec, tensor_args.input_tensor.device());
37}
38
39std::tuple<ExampleDeviceOperation::operation_attributes_t, ExampleDeviceOperation::tensor_args_t>
40ExampleDeviceOperation::invoke(const Tensor& input_tensor) {
41    return {operation_attributes_t{true, 42}, tensor_args_t{input_tensor}};
42}
43
44}  // namespace ttnn::operations::examples
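
In a real operation, the placeholder condition in select_program_factory and the empty validation functions above would contain actual logic. The following is a minimal sketch of what that might look like; the tile-count threshold and the specific TT_FATAL checks are assumptions made for illustration, not part of the repository example:

// Illustrative bodies that could replace the placeholders above (assumptions, not repository code).

// Pick the single-core factory for tiny inputs, the multi-core factory otherwise.
ExampleDeviceOperation::program_factory_t ExampleDeviceOperation::select_program_factory(
    const operation_attributes_t& /*operation_attributes*/, const tensor_args_t& tensor_args) {
    uint32_t num_tiles = tensor_args.input_tensor.physical_volume() / tt::constants::TILE_HW;
    if (num_tiles <= 1) {
        return SingleCore{};
    }
    return MultiCore{};
}

// On a program cache miss the full set of checks runs; on a cache hit a cheaper subset is usually enough.
void ExampleDeviceOperation::validate_on_program_cache_miss(
    const operation_attributes_t& /*attributes*/, const tensor_args_t& tensor_args) {
    const auto& input = tensor_args.input_tensor;
    TT_FATAL(input.storage_type() == tt::tt_metal::StorageType::DEVICE, "Example: input tensor must be on device");
    TT_FATAL(input.layout() == tt::tt_metal::Layout::TILE, "Example: input tensor must be in TILE layout");
}
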
ttnn/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp
  1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5#include "example_device_operation.hpp"
  6#include <tt-metalium/work_split.hpp>
  7
  8namespace ttnn::operations::examples {
  9ExampleDeviceOperation::SingleCore::cached_program_t ExampleDeviceOperation::SingleCore::create(
 10    const operation_attributes_t& operation_attributes,
 11    const tensor_args_t& tensor_args,
 12    tensor_return_value_t& tensor_return_value) {
 13    using namespace tt;
 14    using namespace tt::tt_metal;
 15
 16    const auto& input_tensor = tensor_args.input_tensor;
 17    auto& output_tensor = tensor_return_value;
 18
 19    auto src_buffer = input_tensor.buffer();
 20    auto dst_buffer = output_tensor.buffer();
 21
 22    tt::tt_metal::Program program{};
 23
 24    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
 25    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
 26    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
 27    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);
 28
 29    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;
 30
 31    tt::tt_metal::IDevice* device = input_tensor.device();
 32
 33    CoreCoord compute_with_storage_grid_size = {1, 1};
 34    uint32_t num_cores_x = compute_with_storage_grid_size.x;
 35    uint32_t num_cores_y = compute_with_storage_grid_size.y;
 36    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
 37        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);
 38
 39    uint32_t src0_cb_index = tt::CBIndex::c_0;
 40    uint32_t num_input_tiles = 2;
 41    tt::tt_metal::CircularBufferConfig cb_src0_config =
 42        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
 43            .set_page_size(src0_cb_index, single_tile_size);
 44    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
 45
 46    uint32_t output_cb_index = tt::CBIndex::c_2;
 47    uint32_t num_output_tiles = 2;
 48    tt::tt_metal::CircularBufferConfig cb_output_config =
 49        tt::tt_metal::CircularBufferConfig(
 50            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
 51            .set_page_size(output_cb_index, single_tile_size_output);
 52    auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
 53
 54    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
 55    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
 56    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
 57    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};
 58
 59    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
 60        program,
 61        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
 62        all_cores,
 63        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
 64
 65    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
 66        program,
 67        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
 68        all_cores,
 69        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));
 70
 71    std::vector<uint32_t> compute_kernel_args_group_1 = {
 72        num_tiles_per_core_group_1,  // per_core_block_cnt
 73        1                            // per_core_block_size
 74    };
 75
 76    bool math_approx_mode = false;
 77    auto eltwise_unary_kernel_group_1_id = tt::tt_metal::CreateKernel(
 78        program,
 79        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
 80        core_group_1,
 81        tt::tt_metal::ComputeConfig{
 82            .math_fidelity = MathFidelity::HiFi4,
 83            .math_approx_mode = math_approx_mode,
 84            .compile_args = compute_kernel_args_group_1});
 85
 86    if (!core_group_2.ranges().empty()) {
 87        std::vector<uint32_t> compute_kernel_args_group_2 = {
 88            num_tiles_per_core_group_2,  // per_core_block_cnt
 89            1                            // per_core_block_size
 90        };
 91
 92        auto eltwise_unary_kernel_group_2_id = tt::tt_metal::CreateKernel(
 93            program,
 94            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
 95            core_group_2,
 96            tt::tt_metal::ComputeConfig{
 97                .math_fidelity = MathFidelity::HiFi4,
 98                .math_approx_mode = math_approx_mode,
 99                .compile_args = compute_kernel_args_group_2});
100    }
101
102    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
103        CoreCoord core = {i / num_cores_y, i % num_cores_y};
104        uint32_t num_tiles_per_core = 0;
105        if (core_group_1.contains(core)) {
106            num_tiles_per_core = num_tiles_per_core_group_1;
107        } else if (core_group_2.contains(core)) {
108            num_tiles_per_core = num_tiles_per_core_group_2;
109        } else {
110            TT_ASSERT(false, "Core not in specified core ranges");
111        }
112
113        tt::tt_metal::SetRuntimeArgs(
114            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});
115
116        tt::tt_metal::SetRuntimeArgs(
117            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
118        num_tiles_written += num_tiles_per_core;
119    }
120
121    return {
122        std::move(program),
123        {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}};
124}
125
126void ExampleDeviceOperation::SingleCore::override_runtime_arguments(
127    cached_program_t& cached_program,
128    const operation_attributes_t& operation_attributes,
129    const tensor_args_t& tensor_args,
130    tensor_return_value_t& tensor_return_value) {
131    auto& program = cached_program.program;
132    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
133    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;
134
135    const auto& input_tensor = tensor_args.input_tensor;
136    auto& output_tensor = tensor_return_value;
137
138    auto src_buffer = input_tensor.buffer();
139    auto dst_buffer = output_tensor.buffer();
140
141    {
142        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
143        runtime_args[0] = src_buffer->address();
144    }
145
146    {
147        auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
148        runtime_args[0] = dst_buffer->address();
149    }
150}
151
152}  // namespace ttnn::operations::examples
ttnn/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp
  1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5#include "example_device_operation.hpp"
  6#include <tt-metalium/work_split.hpp>
  7
  8namespace ttnn::operations::examples {
  9ExampleDeviceOperation::MultiCore::cached_program_t ExampleDeviceOperation::MultiCore::create(
 10    const operation_attributes_t& operation_attributes,
 11    const tensor_args_t& tensor_args,
 12    tensor_return_value_t& tensor_return_value) {
 13    using namespace tt;
 14    using namespace tt::tt_metal;
 15
 16    const auto& input_tensor = tensor_args.input_tensor;
 17    auto& output_tensor = tensor_return_value;
 18
 19    auto src_buffer = input_tensor.buffer();
 20    auto dst_buffer = output_tensor.buffer();
 21
 22    tt::tt_metal::Program program{};
 23
 24    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype());
 25    uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
 26    tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.dtype());
 27    uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);
 28
 29    uint32_t num_tiles = input_tensor.physical_volume() / tt::constants::TILE_HW;
 30
 31    tt::tt_metal::IDevice* device = input_tensor.device();
 32
 33    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
 34    uint32_t num_cores_x = compute_with_storage_grid_size.x;
 35    uint32_t num_cores_y = compute_with_storage_grid_size.y;
 36    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
 37        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);
 38
 39    uint32_t src0_cb_index = tt::CBIndex::c_0;
 40    uint32_t num_input_tiles = 2;
 41    tt::tt_metal::CircularBufferConfig cb_src0_config =
 42        tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
 43            .set_page_size(src0_cb_index, single_tile_size);
 44    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
 45
 46    uint32_t output_cb_index = tt::CBIndex::c_2;
 47    uint32_t num_output_tiles = 2;
 48    tt::tt_metal::CircularBufferConfig cb_output_config =
 49        tt::tt_metal::CircularBufferConfig(
 50            num_output_tiles * single_tile_size_output, {{output_cb_index, cb_data_format_output}})
 51            .set_page_size(output_cb_index, single_tile_size_output);
 52    auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
 53
 54    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
 55    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
 56    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
 57    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};
 58
 59    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
 60        program,
 61        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
 62        all_cores,
 63        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
 64
 65    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
 66        program,
 67        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
 68        all_cores,
 69        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));
 70
 71    std::vector<uint32_t> compute_kernel_args_group_1 = {
 72        num_tiles_per_core_group_1,  // per_core_block_cnt
 73        1                            // per_core_block_size
 74    };
 75
 76    bool math_approx_mode = false;
 77    auto eltwise_unary_kernel_group_1_id = tt::tt_metal::CreateKernel(
 78        program,
 79        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
 80        core_group_1,
 81        tt::tt_metal::ComputeConfig{
 82            .math_fidelity = MathFidelity::HiFi4,
 83            .math_approx_mode = math_approx_mode,
 84            .compile_args = compute_kernel_args_group_1});
 85
 86    if (!core_group_2.ranges().empty()) {
 87        std::vector<uint32_t> compute_kernel_args_group_2 = {
 88            num_tiles_per_core_group_2,  // per_core_block_cnt
 89            1                            // per_core_block_size
 90        };
 91
 92        auto eltwise_unary_kernel_group_2_id = tt::tt_metal::CreateKernel(
 93            program,
 94            "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/compute/eltwise_sfpu.cpp",
 95            core_group_2,
 96            tt::tt_metal::ComputeConfig{
 97                .math_fidelity = MathFidelity::HiFi4,
 98                .math_approx_mode = math_approx_mode,
 99                .compile_args = compute_kernel_args_group_2});
100    }
101
102    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
103        CoreCoord core = {i / num_cores_y, i % num_cores_y};
104        uint32_t num_tiles_per_core = 0;
105        if (core_group_1.contains(core)) {
106            num_tiles_per_core = num_tiles_per_core_group_1;
107        } else if (core_group_2.contains(core)) {
108            num_tiles_per_core = num_tiles_per_core_group_2;
109        } else {
110            TT_ASSERT(false, "Core not in specified core ranges");
111        }
112
113        tt::tt_metal::SetRuntimeArgs(
114            program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});
115
116        tt::tt_metal::SetRuntimeArgs(
117            program, unary_writer_kernel_id, core, {dst_buffer->address(), num_tiles_per_core, num_tiles_written});
118        num_tiles_written += num_tiles_per_core;
119    }
120
121    return {
122        std::move(program),
123        {.unary_reader_kernel_id = unary_reader_kernel_id,
124         .unary_writer_kernel_id = unary_writer_kernel_id,
125         .num_cores = num_cores,
126         .num_cores_y = num_cores_y}};
127}
128
129void ExampleDeviceOperation::MultiCore::override_runtime_arguments(
130    cached_program_t& cached_program,
131    const operation_attributes_t& operation_attributes,
132    const tensor_args_t& tensor_args,
133    tensor_return_value_t& tensor_return_value) {
134    auto& program = cached_program.program;
135    auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id;
136    auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;
137    auto& num_cores = cached_program.shared_variables.num_cores;
138    auto& num_cores_y = cached_program.shared_variables.num_cores_y;
139
140    const auto& input_tensor = tensor_args.input_tensor;
141    auto& output_tensor = tensor_return_value;
142
143    auto src_buffer = input_tensor.buffer();
144    auto dst_buffer = output_tensor.buffer();
145
146    for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) {
147        CoreCoord core = {i / num_cores_y, i % num_cores_y};
148
149        {
150            auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
151            runtime_args[0] = src_buffer->address();
152        }
153
154        {
155            auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
156            runtime_args[0] = dst_buffer->address();
157        }
158    }
159}
160
161}  // namespace ttnn::operations::examples

Step 2: Implement the operation in C++

In order to add a new operation, add the following file:

ttnn/cpp/ttnn/operations/<category>/<operation_name>/<operation_name>.hpp

A concrete example:

ttnn/cpp/ttnn/operations/examples/example/example.hpp
 1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
 2//
 3// SPDX-License-Identifier: Apache-2.0
 4
 5#pragma once
 6
 7#include "device/example_device_operation.hpp"
 8
 9namespace ttnn::operations::examples {
10
11// A composite operation is an operation that calls multiple operations in sequence
12// It is written using invoke and can be used to call multiple primitive and/or composite operations
13struct CompositeExampleOperation {
14    // The user will be able to call this method as `Tensor output = ttnn::composite_example(input_tensor)` after the op
15    // is registered
16    static Tensor invoke(const Tensor& input_tensor) {
17        auto copy = prim::example(input_tensor);
18        auto another_copy = prim::example(copy);
19        return another_copy;
20    }
21};
22
23}  // namespace ttnn::operations::examples
24
25namespace ttnn {
26constexpr auto composite_example =
27    ttnn::register_operation<"ttnn::composite_example", operations::examples::CompositeExampleOperation>();
28}  // namespace ttnn
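
Like the primitive, the registered composite operation is callable directly from C++. A short usage sketch (the wrapper run_composite_example and the include path are hypothetical, for illustration only):

#include "ttnn/operations/examples/example/example.hpp"

// Hypothetical helper (illustration only): calls the composite operation, which in turn
// runs ttnn::prim::example twice on the device.
ttnn::Tensor run_composite_example(const ttnn::Tensor& input_tensor) {
    return ttnn::composite_example(input_tensor);
}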

Python Implementation

Step 1: Add Python binding

In order to add a Python binding for the operation, follow the directory structure shown below:

ttnn/python/ttnn/operations/<category>/<operation_name>/<operation_name>_pybind.hpp
ttnn/python/ttnn/operations/<category>/<category>_pybind.hpp

A concrete example:

ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp
 1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
 2//
 3// SPDX-License-Identifier: Apache-2.0
 4
 5#pragma once
 6
 7#include "ttnn-pybind/pybind_fwd.hpp"
 8
 9namespace ttnn::operations::examples {
10namespace py = pybind11;
11
12void bind_example_operation(py::module& module);
13
14}  // namespace ttnn::operations::examples
15
16/*
17void bind_example_operation(py::module& module) {
18    bind_registered_operation(
19        module,
20        ttnn::prim::example,
21        R"doc(example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",
22
23        // Add pybind overloads for the C++ APIs that should be exposed to python
24        // There should be no logic here, just a call to `self` with the correct arguments
25        // The overload with `queue_id` argument will be added automatically for primitive operations
26        // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or
27        // `ttnn.prim.example(input_tensor, queue_id=queue_id)`
28        ttnn::pybind_overload_t{
29            [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
30                return self(input_tensor);
31            },
32            py::arg("input_tensor")});
33
34    bind_registered_operation(
35        module,
36        ttnn::composite_example,
37        R"doc(composite_example(input_tensor: ttnn.Tensor) -> ttnn.Tensor)doc",
38
39        // Add pybind overloads for the C++ APIs that should be exposed to python
40        // There should be no logic here, just a call to `self` with the correct arguments
41        ttnn::pybind_overload_t{
42            [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
43                return self(input_tensor);
44            },
45            py::arg("input_tensor")});
46}
47*/
ttnn/cpp/ttnn/operations/examples/examples_pybind.hpp
 1// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
 2//
 3// SPDX-License-Identifier: Apache-2.0
 4
 5#pragma once
 6
 7#include "ttnn-pybind/pybind_fwd.hpp"
 8
 9namespace ttnn::operations::examples {
10namespace py = pybind11;
11
12void py_module(py::module& module);
13
14}  // namespace ttnn::operations::examples

Finally, call the binding function declared in examples/example/example_pybind.hpp from wherever you want the operation to be added, as sketched below.
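
For example, the category-level py_module can simply forward to the per-operation binding function. A minimal sketch, assuming the implementation lives in a file such as examples_pybind.cpp (the file name and include paths are assumptions; adjust them to the actual repository layout):

// Hypothetical examples_pybind.cpp (sketch): wires the per-operation binding into the category module.
#include <pybind11/pybind11.h>

#include "ttnn/operations/examples/examples_pybind.hpp"
#include "ttnn/operations/examples/example/example_pybind.hpp"

namespace ttnn::operations::examples {

void py_module(py::module& module) {
    // Bind ttnn::prim::example and ttnn::composite_example into this module.
    bind_example_operation(module);
}

}  // namespace ttnn::operations::examples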

Step 2: (Optional) Add a golden function for the operation in Python

A golden function can be added to an operation in order to compare its output with an equivalent torch implementation.

Add the following code in a Python file:

import ttnn

# For the golden function, use the same signature as the operation
# Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s
# And arguments not needed by torch can be ignored using `*args` and `**kwargs`
def golden_function(input_tensor: "torch.Tensor", *args, **kwargs):
    output_tensor:  "torch.Tensor" = ...
    return output_tensor

# TT-NN Tensors are converted to torch tensors before calling the golden function automatically
# And the outputs are converted back to TT-NN Tensors
# But in some cases you may need to preprocess the inputs and postprocess the outputs manually

# In order to preprocess the inputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def preprocess_golden_function_inputs(args, kwargs):
    # i.e.
    ttnn_input_tensor = args[0]
    return ttnn.to_torch(ttnn_input_tensor)

# In order to postprocess the outputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def postprocess_golden_function_outputs(args, kwargs, output):
    # i.e.
    ttnn_input_tensor = args[0]
    torch_output_tensor = output[0]
    return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device)

ttnn.attach_golden_function(
    ttnn.example,
    golden_function=golden_function,
    preprocess_golden_function_inputs=preprocess_golden_function_inputs, # Optional
    postprocess_golden_function_outputs=postprocess_golden_function_outputs # Optional
)

Note

ttnn.example is the name of the operation in Python because the operation was registered as ttnn::example in C++.