Adding an Op

This guide will walk you through the process of adding a new Op end to end in tt-mlir; in this case we will be adding a matmul operation. Note that the matmul op was added as part of the same changeset as this guide, so it can be useful to reference the diff alongside this guide to see the changes in full.

This guide will cover the following steps:

1. Define the Op in the TTIR frontend dialect
2. Define the Op in the TTNN backend dialect
3. Convert / Implement the Op in the TTNN passes

1. Define the Op in the TTIR frontend dialect

We will start by defining the Op in the TTIR dialect. The TTIR Ops are defined in a tablegen file located at include/ttmlir/Dialect/TTIR/IR/TTIROps.td.

Tablegen is a domain-specific language for defining ops/types/attributes in MLIR and LLVM; these definitions constitute the dialect's Operation Definition Specification (ODS).

Here is an example of defining matmul in the TTIR dialect:

def TTIR_MatmulOp : TTIR_NamedOp<"matmul"> {
    let summary = "Matrix multiplication operation.";
    let description = [{
      The `matmul` operation computes the matrix multiplication of two tensors.

      This operation performs matrix multiplication between tensors `a` and `b`. It supports optional
      transposition of either input tensor before multiplication. For 2D tensors, this computes the standard
      matrix product. For tensors with more dimensions, it applies batched matrix multiplication.

      Example:
      ```mlir
      // Basic matrix multiplication of 2D tensors
      %a = ... : tensor<3x4xf32>  // Matrix A with shape [3,4]
      %b = ... : tensor<4x5xf32>  // Matrix B with shape [4,5]
      %result = ttir.matmul(%a, %b) :
          (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>

      // Batched matrix multiplication with transposition
      %a = ... : tensor<2x3x4xf32>  // Batch of 2 matrices with shape [3,4]
      %b = ... : tensor<2x5x4xf32>  // Batch of 2 matrices with shape [5,4]
      %result = ttir.matmul(%a, %b) {
          transpose_a = false,  // Don't transpose A
          transpose_b = true    // Transpose B before multiplication
      } : (tensor<2x3x4xf32>, tensor<2x5x4xf32>) -> tensor<2x3x5xf32>
      ```

      Inputs:
      - `a` (Tensor): The first input tensor.
      - `b` (Tensor): The second input tensor.

      Attributes:
      - `transpose_a` (Boolean, default=false): Whether to transpose tensor `a` before multiplication.
      - `transpose_b` (Boolean, default=false): Whether to transpose tensor `b` before multiplication.

      Output:
      - `result` (Tensor): The result of the matrix multiplication.

      Note: The inner dimensions of the input tensors must be compatible for matrix multiplication.
      If `a` has shape [..., m, k] and `b` has shape [..., k, n], then the result will have shape [..., m, n].
      If `transpose_a` is true, then `a` is treated as having shape [..., k, m].
      If `transpose_b` is true, then `b` is treated as having shape [..., n, k].
    }];

    let arguments = (ins AnyRankedTensor:$a,
                         AnyRankedTensor:$b,
                         DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                         DefaultValuedAttr<BoolAttr, "false">:$transpose_b);

    let results = (outs AnyRankedTensor:$result);

    let hasVerifier = 1;

    let hasCanonicalizer = 1;
}

There are many things to break down here, starting from the top:

  • def in tablegen is used to define a concrete type; it has a 1-1 mapping to a generated C++ class, and for this particular case the build will end up generating the file build/include/ttmlir/Dialect/TTIR/IR/TTIROps.h.inc.
  • It inherits from class TTIR_NamedOp, which ultimately builds on TTIR_DPSOp. Classes in tablegen don't define a concrete type, but rather an interface that augments or constrains inherited defs. TTIR_DPSOp is a class that defines the common attributes for all TTIR Ops that implement Destination Passing Style (DPS) semantics. DPS just means that the result tensor is passed as an argument to the operation, which will be critical for modeling buffer allocation / lifetimes.
  • Next we have a list of arguments. These arguments consist of a mixture of Types (i.e. AnyRankedTensor) and Attributes. Read more about Types & Attributes here.
    • AnyRankedTensor is part of the tablegen standard library; it aliases MLIR's builtin tensor type, with the added constraint that the tensor has a static rank. As much as possible we want to use the builtin types and infrastructure provided by MLIR.
  • Next we have a list of results, in this case just one, which aliases the output tensor. One drawback of DPS is that the result tensor and the output tensor will appear to have different SSA names in the IR, but they really alias the same object. This can make writing some passes more cumbersome.
  • Some op definitions also include extraClassDeclaration, which enables us to inject member functions, written directly in C++, into the generated class. For DPS ops this is used to satisfy the DPS interface, which requires an implementation for getting the mutated output tensor.
  • Finally, we have hasVerifier = 1; this tells MLIR that we have a verifier function that will be called to validate the operation, which is good practice to ensure that the IR is well formed. Similarly, hasCanonicalizer = 1 declares that the op provides canonicalization patterns.
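To make the DPS idea above concrete, here is an illustrative analogy in plain Python (not code from the repository): DPS works like a NumPy-style `out=` parameter, where the caller preallocates the destination buffer and the op both fills it in and returns it, so the "result" and the "output" are two names for the same object.

```python
# Illustrative analogy for Destination Passing Style (DPS):
# the caller preallocates the output buffer, the op writes into
# it and returns it, so `result` and `out` alias one object.
def matmul_dps(a, b, out):
    """Write the matrix product of a and b into out, then return out."""
    rows, inner = len(a), len(b)
    cols = len(b[0])
    for i in range(rows):
        for j in range(cols):
            out[i][j] = sum(a[i][k] * b[k][j] for k in range(inner))
    return out  # same object as the `out` argument

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
out = [[0, 0], [0, 0]]
result = matmul_dps(a, b, out)
assert result is out  # different names in source, same underlying buffer
```

This mirrors the drawback noted above: in the IR, the result and output values carry different SSA names even though they alias the same storage.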

We can now try building and opening the TTIROps.h.inc file to see the generated C++ code. We will actually get a linker error, because hasVerifier = 1 automatically declares a verifier function that we still need to implement.

Let's head over to lib/Dialect/TTIR/IR/TTIROps.cpp and implement the verifier.

// MatmulOp verification
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
  ::mlir::RankedTensorType inputBType = getB().getType();
  ::mlir::RankedTensorType outputType = getType();

  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());

  // Verify that input A is at least a 1D tensor.
  if (inputAType.getRank() < 1) {
    return emitOpError("Input A must be at least a 1D tensor");
  }

  // Verify that input B is at least a 1D tensor.
  if (inputBType.getRank() < 1) {
    return emitOpError("Input B must be at least a 1D tensor");
  }

  // If input A is a vector (1D tensor), 1 is prepended to its dimensions for
  // the purpose of the matrix multiplication. After the matrix
  // multiplication, the prepended dimension is removed. Otherwise, check if
  // the LHS needs to be transposed.
  if (inputAType.getRank() == 1) {
    inputAShape.insert(inputAShape.begin(), 1);
  } else if (getTransposeA()) {
    std::swap(inputAShape[inputAShape.size() - 1],
              inputAShape[inputAShape.size() - 2]);
  }

  // If input B is a vector (1D tensor), a 1 is appended to its dimensions for
  // the purpose of the matrix-vector product and removed afterwards.
  // Otherwise, check if the RHS needs to be transposed.
  if (inputBType.getRank() == 1) {
    inputBShape.push_back(1);
  } else if (getTransposeB()) {
    std::swap(inputBShape[inputBShape.size() - 1],
              inputBShape[inputBShape.size() - 2]);
  }

  // Verify that input A and input B have matching inner dimensions.
  if (inputAShape[inputAShape.size() - 1] !=
      inputBShape[inputBShape.size() - 2]) {
    return emitOpError("Input A[-1](")
           << inputAShape[inputAShape.size() - 1] << ") and B[-2]("
           << inputBShape[inputBShape.size() - 2]
           << ") must have matching inner dimensions";
  }

  llvm::SmallVector<int64_t> expectedOutputShape;
  // Verify that the batch dimensions are broadcast compatible and construct
  // the expected output shape. If both input A and input B are at most 2D
  // tensors, the batch dimensions are trivially broadcast compatible.
  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
    llvm::SmallVector<int64_t> inputABatchDims(inputAShape.begin(),
                                               inputAShape.end() - 2);
    llvm::SmallVector<int64_t> inputBBatchDims(inputBShape.begin(),
                                               inputBShape.end() - 2);

    // Verify that the batch dimensions of input A and B are broadcast
    // compatible.
    llvm::SmallVector<int64_t, 4> broadcastedShape;
    if (!mlir::OpTrait::util::getBroadcastedShape(
            inputABatchDims, inputBBatchDims, broadcastedShape)) {

      return emitOpError("Batch dimensions of input A(" +
                         ttmlir::utils::join(inputABatchDims, ",") +
                         ") and B(" +
                         ttmlir::utils::join(inputBBatchDims, ",") +
                         ") are not broadcast compatible");
    }

    // Insert the broadcasted batch dimensions in the expected output shape.
    expectedOutputShape = std::move(broadcastedShape);
  }

  // Insert the input A and B inner dimensions into the expected output shape.
  // Consider the case where input A and B are vectors. In that case,
  // the dimension 1 is omitted from the output shape.
  if (inputAType.getRank() > 1) {
    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
  }

  if (inputBType.getRank() > 1) {
    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
  }

  // Check the case of a vector-vector product. At this moment we don't
  // support scalars in IR, hence check that the output is at least a 1D
  // tensor of size 1.
  if (expectedOutputShape.size() == 0) {
    if (outputType.getRank() < 1) {
      return emitOpError("Scalar output is not supported, output must be at "
                         "least a 1D tensor");
    }

    if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) {
      return emitOpError("Scalar output must be a 1D tensor of size 1");
    }

    return success();
  }

  // Verify that the output shape is correct.
  if (outputShape.size() != expectedOutputShape.size()) {
    return emitOpError("Output shape rank(")
           << outputShape.size()
           << ") must match the expected output shape rank("
           << expectedOutputShape.size() << ")";
  }

  // Verify each dim of the output shape.
  for (auto [index, outputDim, expectedDim] : llvm::zip(
           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
    if (outputDim != expectedDim) {
      return emitOpError("Output shape dimension[")
             << index << "](" << outputDim
             << ") doesn't match the expected output shape dimension[" << index
             << "](" << expectedDim << ")";
    }
  }

  return success();
}
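The shape rules this verifier enforces (vector promotion, optional transposes, inner-dimension matching, and batch broadcasting) can be sketched in plain Python. This is an illustrative model of the logic, not code from the repository:

```python
from itertools import zip_longest

def broadcast(dims_a, dims_b):
    """NumPy-style broadcasting of two batch-dimension lists."""
    out = []
    for x, y in zip_longest(reversed(dims_a), reversed(dims_b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"not broadcast compatible: {dims_a} vs {dims_b}")
        out.append(max(x, y))
    return out[::-1]

def matmul_output_shape(a_shape, b_shape, transpose_a=False, transpose_b=False):
    """Model of the verifier's expected-output-shape computation."""
    a, b = list(a_shape), list(b_shape)
    a_was_vector, b_was_vector = len(a) == 1, len(b) == 1
    if a_was_vector:                 # prepend 1 for a vector LHS
        a = [1] + a
    elif transpose_a:
        a[-2], a[-1] = a[-1], a[-2]
    if b_was_vector:                 # append 1 for a vector RHS
        b = b + [1]
    elif transpose_b:
        b[-2], b[-1] = b[-1], b[-2]
    if a[-1] != b[-2]:
        raise ValueError("inner dimensions must match")
    shape = broadcast(a[:-2], b[:-2])
    if not a_was_vector:             # vector dims are dropped from the result
        shape.append(a[-2])
    if not b_was_vector:
        shape.append(b[-1])
    return shape or [1]              # vector-vector product is a size-1 tensor
```

For example, `matmul_output_shape([2, 3, 4], [2, 5, 4], transpose_b=True)` yields `[2, 3, 5]`, matching the batched example in the op description above.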

2. Define the Op in the TTNN backend dialect

Next we will define the Op in the TTNN dialect. TTNN Ops are defined in the same way, but in their respective set of dialect files. Refer to the previous section for details; the process is the same.

TTNNOps.td

def TTNN_MatmulOp : TTNN_Op<"matmul", [TTNN_ComputeKernelConfigOpInterface]> {
    let arguments = (ins AnyRankedTensor:$a,
                         AnyRankedTensor:$b,
                         DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                         DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
                         OptionalAttr<AnyAttrOf<[
                            TTNN_MatmulMultiCoreReuseProgramConfigAttr,
                            TTNN_MatmulMultiCoreReuseMultiCastProgramConfigAttr,
                            TTNN_MatmulMultiCoreReuseMultiCast1DProgramConfigAttr,
                            TTNN_MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfigAttr
                         ]>>:$matmul_program_config,
                         OptionalAttr<StrAttr>:$activation,
                         OptionalAttr<TTNN_DeviceComputeKernelConfig>:$compute_config);

    let results = (outs AnyRankedTensor:$result);

    let builders = [
      OpBuilder<(ins "Type":$result, "Value":$a, "Value":$b,
                     "bool":$transpose_a, "bool":$transpose_b,
                     "Attribute":$matmul_program_config,
                     "StringAttr":$activation),
      [{
        build($_builder, $_state, result, a, b, transpose_a, transpose_b,
              matmul_program_config, activation, /*compute_config=*/nullptr);
      }]>
    ];

    let hasVerifier = 1;
}

TTNNOps.cpp

// MatmulOp verification
::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
  ::mlir::RankedTensorType inputBType = getB().getType();
  ::mlir::RankedTensorType outputType = getResult().getType();

  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());

  // Verify that input A is at least a 1D tensor.
  if (inputAType.getRank() < 1) {
    return emitOpError("Input A must be at least a 1D tensor");
  }

  // Verify that input B is at least a 1D tensor.
  if (inputBType.getRank() < 1) {
    return emitOpError("Input B must be at least a 1D tensor");
  }

  // If input A is a vector (1D tensor), 1 is prepended to its dimensions for
  // the purpose of the matrix multiplication. After the matrix multiplication,
  // the prepended dimension is removed. Otherwise, check if the LHS needs to be
  // transposed.
  if (inputAType.getRank() == 1) {
    inputAShape.insert(inputAShape.begin(), 1);
  } else if (getTransposeA()) {
    std::swap(inputAShape[inputAShape.size() - 1],
              inputAShape[inputAShape.size() - 2]);
  }

  // If input B is a vector (1D tensor), a 1 is appended to its dimensions for
  // the purpose of the matrix-vector product and removed afterwards. Otherwise,
  // check if the RHS needs to be transposed.
  if (inputBType.getRank() == 1) {
    inputBShape.push_back(1);
  } else if (getTransposeB()) {
    std::swap(inputBShape[inputBShape.size() - 1],
              inputBShape[inputBShape.size() - 2]);
  }

  // Verify that input A and input B have matching inner dimensions.
  if (inputAShape[inputAShape.size() - 1] !=
      inputBShape[inputBShape.size() - 2]) {
    return emitOpError("Input A[-1](")
           << inputAShape[inputAShape.size() - 1] << ") and B[-2]("
           << inputBShape[inputBShape.size() - 2]
           << ") must have matching inner dimensions";
  }

  llvm::SmallVector<int64_t> expectedOutputShape;
  // Verify that the batch dimensions are broadcast compatible and construct
  // the expected output shape. If both input A and input B are at most 2D
  // tensors, the batch dimensions are trivially broadcast compatible.
  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
    llvm::SmallVector<int64_t> inputABatchDims(inputAShape.begin(),
                                               inputAShape.end() - 2);
    llvm::SmallVector<int64_t> inputBBatchDims(inputBShape.begin(),
                                               inputBShape.end() - 2);

    // Verify that the batch dimensions of input A and B are broadcast
    // compatible.
    llvm::SmallVector<int64_t, 4> broadcastedShape;
    if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
                                            broadcastedShape)) {

      return emitOpError("Batch dimensions of input A(" +
                         ttmlir::utils::join(inputABatchDims, ",") +
                         ") and B(" +
                         ttmlir::utils::join(inputBBatchDims, ",") +
                         ") are not broadcast compatible");
    }

    // Insert the broadcasted batch dimensions in the expected output shape.
    expectedOutputShape = std::move(broadcastedShape);
  }

  // Insert the input A and B inner dimensions into the expected output shape.
  // Consider the case where input A and B are vectors. In that case,
  // the dimension 1 is omitted from the output shape.
  if (inputAType.getRank() > 1) {
    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
  }

  if (inputBType.getRank() > 1) {
    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
  }

  // Check the case of a vector-vector product. At this moment we don't support
  // scalars in IR, hence check that the output is at least a 1D tensor of size 1.
  if (expectedOutputShape.size() == 0) {
    if (outputType.getRank() < 1) {
      return emitOpError("Scalar output is not supported, output must be at "
                         "least a 1D tensor");
    }

    if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) {
      return emitOpError("Scalar output must be a 1D tensor of size 1");
    }

    return success();
  }

  // Verify that the output shape is correct.
  if (outputShape.size() != expectedOutputShape.size()) {
    return emitOpError("Output shape rank(")
           << outputShape.size()
           << ") must match the expected output shape rank("
           << expectedOutputShape.size() << ")";
  }

  // Verify each dim of the output shape.
  for (auto [index, outputDim, expectedDim] : llvm::zip(
           llvm::seq(outputShape.size()), outputShape, expectedOutputShape)) {
    if (outputDim != expectedDim) {
      return emitOpError("Output shape dimension[")
             << index << "](" << outputDim
             << ") doesn't match the expected output shape dimension[" << index
             << "](" << expectedDim << ")";
    }
  }

  return success();
}

For more details on adding ops to the TTNN dialect, refer to TTNN Dialect Contribution Guidelines.

Adding constraint/runtime APIs

We need to implement two APIs when adding a TTNN Op, namely getOpConstraints and getOpRuntime. More details about this can be found here.

3. Convert / Implement the Op in the TTNN passes

TTIR to TTNN

Next we will implement the conversion from the TTIR matmul Op to the TTNN matmul Op. This is a trivial conversion, as the Ops are identical in their semantics, so the changeset isn't going to be very instructive, but we will at least point to the files involved. The conversion is implemented in the ConvertTTIRToTTNNPass pass in file lib/Conversion/TTIRToTTNN/TTIRToTTNNPass.cpp.

Zooming into class ConvertTTIRToTTNNPass, we can see that we implement the pass interface via the member function void runOnOperation() final. This function will be called for every operation matching the type specified in the pass tablegen file. Taking a quick look at include/ttmlir/Conversion/Passes.td, we can see:

def ConvertTTIRToTTNN: Pass<"convert-ttir-to-ttnn", "::mlir::ModuleOp"> {

This means that runOnOperation will be called for every ModuleOp in the graph; usually there is only one ModuleOp, which serves as the root of the graph.

Inside runOnOperation is usually where we define a rewrite pattern set that can match much more complicated patterns (nested inside of the ModuleOp's regions) than just a single operation. In the runOnOperation method you will see the call to populateTTIRToTTNNPatterns(...), which actually generates the rewrite patterns. The populateTTIRToTTNNPatterns(...) method is defined in lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp.

  patterns
      .add<TensorEmptyConversionPattern,
           NamedFullConversionPattern<ttir::ZerosOp, ttnn::ZerosOp>,
           NamedFullConversionPattern<ttir::OnesOp, ttnn::OnesOp>,
           FullOpConversionPattern,
           ToLayoutOpConversionPattern,
           QuantizationOpConversionPattern<ttir::QuantizeUnrolledOp, ttnn::QuantizeOp>,
           QuantizationOpConversionPattern<ttir::DequantizeUnrolledOp, ttnn::DequantizeOp>,
           RequantizeOpConversionPattern,
           ElementwiseBinaryOpConversionPattern<ttir::AddOp, ttnn::AddOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LogicalRightShiftOp, ttnn::LogicalRightShiftOp>,
           ElementwiseBinaryOpConversionPattern<ttir::SubtractOp, ttnn::SubtractOp>,
           ElementwiseBinaryOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
           ElementwiseBinaryOpConversionPattern<ttir::DivOp, ttnn::DivideOp>,
           ElementwiseBinaryOpConversionPattern<ttir::EqualOp, ttnn::EqualOp>,
           ElementwiseBinaryOpConversionPattern<ttir::NotEqualOp, ttnn::NotEqualOp>,
           ElementwiseBinaryOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
           ElementwiseBinaryOpConversionPattern<ttir::GreaterThanOp, ttnn::GreaterThanOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LessEqualOp, ttnn::LessEqualOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LessThanOp, ttnn::LessThanOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LogicalAndOp, ttnn::LogicalAndOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
           ElementwiseBinaryOpConversionPattern<ttir::LogicalXorOp, ttnn::LogicalXorOp>,
           ElementwiseOpConversionPattern<ttir::BitwiseAndOp, ttnn::BitwiseAndOp>,
           ElementwiseOpConversionPattern<ttir::LogicalLeftShiftOp, ttnn::LogicalLeftShiftOp>,
           ElementwiseOpConversionPattern<ttir::BitwiseOrOp, ttnn::BitwiseOrOp>,
           ElementwiseOpConversionPattern<ttir::BitwiseXorOp, ttnn::BitwiseXorOp>,
           ElementwiseOpConversionPattern<ttir::MaximumOp, ttnn::MaximumOp>,
           ElementwiseOpConversionPattern<ttir::MinimumOp, ttnn::MinimumOp>,
           ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
           ElementwiseOpConversionPattern<ttir::Atan2Op, ttnn::Atan2Op>,
           ElementwiseOpConversionPattern<ttir::AbsOp, ttnn::AbsOp>,
           ElementwiseOpConversionPattern<ttir::CbrtOp, ttnn::CbrtOp>,
           ElementwiseOpConversionPattern<ttir::FloorOp, ttnn::FloorOp>,
           ElementwiseOpConversionPattern<ttir::IsFiniteOp, ttnn::IsFiniteOp>,
           ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
           ElementwiseOpConversionPattern<ttir::BitwiseNotOp, ttnn::BitwiseNotOp>,
           ElementwiseOpConversionPattern<ttir::MishOp, ttnn::MishOp>,
           ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
           ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
           ElementwiseOpConversionPattern<ttir::Relu6Op, ttnn::Relu6Op>,
           ElementwiseOpConversionPattern<ttir::GeluOp, ttnn::GeluOp>,
           ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
           ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
           ElementwiseOpConversionPattern<ttir::SignOp, ttnn::SignOp>,
           ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
           ElementwiseOpConversionPattern<ttir::HardsigmoidOp, ttnn::HardsigmoidOp>,
           ElementwiseOpConversionPattern<ttir::SiluOp, ttnn::SiluOp>,
           ElementwiseOpConversionPattern<ttir::Log1pOp, ttnn::Log1pOp>,
           ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
           ElementwiseOpConversionPattern<ttir::ExpOp, ttnn::ExpOp>,
           ElementwiseOpConversionPattern<ttir::ErfOp, ttnn::ErfOp>,
           ElementwiseOpConversionPattern<ttir::ErfcOp, ttnn::ErfcOp>,
           ElementwiseOpConversionPattern<ttir::LogOp, ttnn::LogOp>,
           ElementwiseOpConversionPattern<ttir::CeilOp, ttnn::CeilOp>,
           ElementwiseOpConversionPattern<ttir::SinOp, ttnn::SinOp>,
           ElementwiseOpConversionPattern<ttir::AsinOp, ttnn::AsinOp>,
           ElementwiseOpConversionPattern<ttir::AsinhOp, ttnn::AsinhOp>,
           ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
           ElementwiseOpConversionPattern<ttir::AcosOp, ttnn::AcosOp>,
           ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
           ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
           ElementwiseOpConversionPattern<ttir::TanOp, ttnn::TanOp>,
           ElementwiseOpConversionPattern<ttir::TanhOp, ttnn::TanhOp>,
           ElementwiseOpConversionPattern<ttir::AtanOp, ttnn::AtanOp>,
           Pooling2dOpConversionPattern<ttir::MaxPool2dOp, ttnn::MaxPool2dOp>,
           Pooling2dOpConversionPattern<ttir::MaxPool2dWithIndicesOp, ttnn::MaxPool2dWithIndicesOp>,
           Pooling2dOpConversionPattern<ttir::AvgPool2dOp, ttnn::AvgPool2dOp>,
           GlobalAvgPool2dOpConversionPattern,
           ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
           ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
           ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
           ReductionOpConversionPattern<ttir::MinOp, ttnn::MinOp>,
           ReductionProdOpConversionPattern,
           ReductionArgMaxOpConversionPattern,
           ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
           BroadcastOpConversionPattern,
           PadOpConversionPattern,
           PowOpConversionPattern,
           EmbeddingOpConversionPattern,
           EmbeddingBackwardOpConversionPattern,
           RepeatOpConversionPattern,
           CumSumOpConversionPattern,
           RepeatInterleaveOpConversionPattern,
           SoftmaxOpConversionPattern,
           SortOpConversionPattern,
           BitcastConvertOPConversionPattern,
           TypecastOpConversionPattern,
           ClampOpConversionPattern<ttir::ClampScalarOp, ttnn::ClampScalarOp>,
           ClampOpConversionPattern<ttir::ClampTensorOp, ttnn::ClampTensorOp>,
           ConcatOpConversionPattern,
           ReshapeOpConversionPattern,
           SliceOpConversionPattern<ttir::SliceStaticOp, ttnn::SliceStaticOp>,
           SliceOpConversionPattern<ttir::SliceDynamicOp, ttnn::SliceDynamicOp>,
           SqueezeOpConversionPattern,
           UnsqueezeOpConversionPattern,
           ConstantOpConversionPattern,
           LinearOpConversionPattern,
           BatchNormInferenceOpConversionPattern,
           BatchNormTrainingOpConversionPattern,
           RMSNormOpConversionPattern,
           DistributedRMSNormOpConversionPattern,
           DistributedLayerNormOpConversionPattern,
           LayerNormOpConversionPattern,
           GroupNormOpConversionPattern,
           MatmulOpConversionPattern,
           SparseMatmulOpConversionPattern,
           AllToAllDispatchOpConversionPattern,
           AllToAllDispatchMetadataOpConversionPattern,
           AllToAllCombineOpConversionPattern,
           SelectiveReduceCombineOpConversionPattern,
           MoeExpertTokenRemapOpConversionPattern,
           Conv2dOpConversionPattern,
           Conv3dOpConversionPattern,
           ConvTranspose2dOpConversionPattern,
           MeshShardOpConversionPattern,
           AllReduceOpConversionPattern,
           AllReduceAsyncOpConversionPattern,
           AllGatherOpConversionPattern,
           MeshPartitionOpConversionPattern,
           ReduceScatterOpConversionPattern,
           CollectivePermuteOpConversionPattern,
           ArangeOpConversionPattern,
           RandOpConversionPattern,
           UpdateCacheOpConversionPattern,
           PagedFillCacheOpConversionPattern,
           PagedUpdateCacheOpConversionPattern,
           SamplingOpConversionPattern,
           FillCacheOpConversionPattern,
           ScatterOpConversionPattern,
           GatherOpConversionPattern,
           PermuteOpConversionPattern,
           UpsampleOpConversionPattern,
           AllToAllOpConversionPattern,
           CollectiveBroadcastOpConversionPattern,
           ConcatenateHeadsOpConversionPattern,
           ScaledDotProductAttentionOpConversionPattern,
           ScaledDotProductAttentionDecodeOpConversionPattern,
           PagedScaledDotProductAttentionDecodeOpConversionPattern,
           PagedFlashMultiLatentAttentionDecodeOpConversionPattern,
           SplitQueryKeyValueAndSplitHeadsOpConversionPattern,
           GeluBackwardOpConversionPattern,
           DropoutOpConversionPattern,
           DebugOpConversionPattern<debug::DumpOp, ttnn::DumpTensorOp>,
           TopKOpConversionPattern,
           TopKRouterGptOpConversionPattern
           >(typeConverter, ctx);
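Conceptually, the conversion driver walks the module and applies whichever registered pattern matches each op, repeating until the IR reaches a fixed point. A toy model of that mechanism in Python (illustrative only; the op and pattern names here are hypothetical, and the real MLIR driver is far more sophisticated):

```python
# Toy model of a dialect-conversion driver: each "pattern" is a
# function that either rewrites an op (returning the new op) or
# returns None to signal "no match".
def matmul_pattern(op):
    if op["name"] == "ttir.matmul":
        return {**op, "name": "ttnn.matmul"}
    return None

def add_pattern(op):
    if op["name"] == "ttir.add":
        return {**op, "name": "ttnn.add"}
    return None

def convert(module, patterns):
    """Apply patterns to every op until a fixed point is reached."""
    changed = True
    while changed:
        changed = False
        for i, op in enumerate(module):
            for pattern in patterns:
                new_op = pattern(op)
                if new_op is not None and new_op != op:
                    module[i] = new_op
                    changed = True
    return module

module = [{"name": "ttir.matmul", "operands": ["%a", "%b"]},
          {"name": "ttir.add", "operands": ["%c", "%d"]}]
converted = convert(module, [matmul_pattern, add_pattern])
```

Once a pattern no longer matches any op, the driver terminates; in the real pass, the type converter additionally rewrites the operand and result types as each op is replaced.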

More information on rewrite patterns and their capabilities can be found in the MLIR documentation here and here.

For matmul, we defined a dedicated conversion pattern for ttir::MatmulOp:

namespace {
class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
public:
  using OpConversionPattern<ttir::MatmulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::MatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newOp = rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
        op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
        adaptor.getB(), adaptor.getTransposeA(), adaptor.getTransposeB(),
        /*matmul_program_config=*/nullptr, /*activation=*/nullptr);
    if (auto attr = op->getAttr("ttcore.weight_dtype")) {
      newOp->setAttr("ttcore.weight_dtype", attr);
    }
    return success();
  }
};
} // namespace

The pattern is then invoked as part of the rewrite set shown above, via the MatmulOpConversionPattern entry.

TTNN to EmitC

Similarly, we also need to add a pattern to convert from the TTNN dialect to the EmitC dialect.

The method that populates the rewrite patterns can be found in lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp:

void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
                                 mlir::RewritePatternSet &patterns,
                                 TypeConverter &typeConverter) {
  // Device ops
  //
  patterns.add<TTDeviceOpConversionPattern>(typeConverter, ctx);
  patterns.add<GetDeviceOpConversionPattern>(typeConverter, ctx);

  // Memory ops
  //
  // clang-format off
  patterns.add<ToLayoutOpConversionPattern,
               ToMemoryConfigOpConversionPattern,
               TypecastOpConversionPattern,
               ToDeviceOpConversionPattern,
               FromDeviceOpConversionPattern,
               DeallocateOpConversionPattern>(typeConverter, ctx);
  // clang-format on

  // Tensor ops
  //
  // clang-format off
  patterns.add<EmptyOpConversionPattern,
               NamedFullOpConversionPattern<mlir::tt::ttnn::ZerosOp>,
               NamedFullOpConversionPattern<mlir::tt::ttnn::OnesOp>,
               FullOpConversionPattern,
               DefaultOpConversionPattern<mlir::tt::ttnn::ArangeOp>,
               DefaultOpConversionPattern<mlir::tt::ttnn::ConstantOp>,
               RandOpConversionPattern,
               AssignOpConversionPattern>(typeConverter, ctx);
  // clang-format on

  // Eltwise unary ops
  //
  patterns
      .add<EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::AbsOp>,
           EltwiseUnaryCompositeOpConversionPattern<mlir::tt::ttnn::CbrtOp>,
           ClampOpConversionPattern<::mlir::tt::ttnn::ClampScalarOp>,
           ClampOpConversionPattern<mlir::tt::ttnn::ClampTensorOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::FloorOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::IsFiniteOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::LogicalNotOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::BitwiseNotOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::NegOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::ReluOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::RsqrtOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::Relu6Op>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::HardsigmoidOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::SiluOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::MishOp>,
           ElementwiseUnaryWithFloatParameterOpConversionPattern<
               mlir::tt::ttnn::LeakyReluOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::GeluOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::SqrtOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::SignOp>,
           EltwiseUnaryWithVectorAndFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::SigmoidOp>,
           EltwiseUnaryCompositeWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::Log1pOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::ReciprocalOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::ExpOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::ErfOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::ErfcOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::CeilOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::SinOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::AsinOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::AsinhOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::CosOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::AcosOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::Expm1Op>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::TanOp>,
           EltwiseUnaryWithOutputAndApproxModeOpConversionPattern<
               mlir::tt::ttnn::TanhOp>,
           EltwiseUnaryOpConversionPattern<mlir::tt::ttnn::AtanOp>,
           EltwiseUnaryWithFastAndApproximateModeOpConversionPattern<
               mlir::tt::ttnn::LogOp>>(typeConverter, ctx);

  // Eltwise binary ops
  //
  patterns.add<
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::AddOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LogicalRightShiftOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::SubtractOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::MultiplyOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LogicalAndOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LogicalOrOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LogicalXorOp>,
      EltwiseBinaryCompositeOpConversionPattern<mlir::tt::ttnn::BitwiseAndOp>,
      EltwiseBinaryCompositeOpConversionPattern<mlir::tt::ttnn::BitwiseOrOp>,
      EltwiseBinaryCompositeOpConversionPattern<mlir::tt::ttnn::BitwiseXorOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::EqualOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::NotEqualOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::GreaterEqualOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::GreaterThanOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LessEqualOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::LessThanOp>,
      EltwiseBinaryNGCompositeOpConversionPattern<mlir::tt::ttnn::MaximumOp>,
      EltwiseBinaryNGCompositeOpConversionPattern<mlir::tt::ttnn::MinimumOp>,
      EltwiseBinaryOpConversionPattern<mlir::tt::ttnn::DivideOp>,
      EltwiseBinaryCompositeOpConversionPattern<
          mlir::tt::ttnn::LogicalLeftShiftOp>,
      EltwiseBinaryNGCompositeOpConversionPattern<mlir::tt::ttnn::RemainderOp>,
      EltwiseBinaryNGCompositeOpConversionPattern<mlir::tt::ttnn::PowTensorOp>,
      EltwiseBinaryCompositeOpConversionPattern<mlir::tt::ttnn::Atan2Op>,
      PowScalarOpConversionPattern>(typeConverter, ctx);

  // Experimental binary backward ops
  //
  patterns.add<ExperimentalGeluBackwardOpConversionPattern>(typeConverter, ctx);

  // Experimental dropout op
  //
  patterns.add<DropoutOpConversionPattern>(typeConverter, ctx);

  // Eltwise ternary ops
  //
  patterns.add<EltwiseTernaryOpConversionPattern<mlir::tt::ttnn::WhereOp>>(
      typeConverter, ctx);

  // Tensor manipulation ops
  //
  patterns
      .add<TransposeOpConversionPattern, ConcatOpConversionPattern,
           ReshapeOpConversionPattern, RepeatOpConversionPattern,
           RepeatInterleaveOpConversionPattern, SliceStaticOpConversionPattern,
           SliceDynamicOpConversionPattern, SortOpConversionPattern,
           PermuteOpConversionPattern, PadOpConversionPattern,
           TTNNToEmitCTopKOpConversionPattern,
           TTNNToEmitCSamplingOpConversionPattern, GatherOpConversionPattern>(
          typeConverter, ctx);

  // Quantization ops.
  //
  patterns.add<QuantizationOpConversionPattern<mlir::tt::ttnn::QuantizeOp>,
               QuantizationOpConversionPattern<mlir::tt::ttnn::DequantizeOp>,
               RequantizeOpConversionPattern>(typeConverter, ctx);

  // Matmul ops
  //
  patterns.add<LinearOpConversionPattern, MatmulOpConversionPattern,
               SparseMatmulOpConversionPattern>(typeConverter, ctx);

  // Reduction ops
  //
  patterns.add<ReductionOpConversionPattern<mlir::tt::ttnn::SumOp>,
               ReductionOpConversionPattern<mlir::tt::ttnn::MeanOp>,
               ReductionOpConversionPattern<mlir::tt::ttnn::MaxOp>,
               ReductionOpConversionPattern<mlir::tt::ttnn::MinOp>,
               ProdOpConversionPattern, ArgMaxOpConversionPattern,
               TTNNToEmitCTopKRouterGptOpConversionPattern>(typeConverter, ctx);

  // Pooling ops
  //
  patterns.add<AvgPool2dOpConversionPattern>(typeConverter, ctx);
  patterns.add<MaxPool2dOpConversionPattern>(typeConverter, ctx);
  patterns.add<MaxPool2dWithIndicesOpConversionPattern>(typeConverter, ctx);
  patterns.add<GlobalAvgPool2dOpConversionPattern>(typeConverter, ctx);
  patterns.add<UpsampleOpConversionPattern>(typeConverter, ctx);

  // Convolution ops
  //
  patterns.add<PrepareConv2dWeightsOpConversionPattern>(typeConverter, ctx);
  patterns.add<PrepareConv2dBiasOpConversionPattern>(typeConverter, ctx);
  patterns.add<PrepareConvTranspose2dWeightsOpConversionPattern>(typeConverter,
                                                                 ctx);
  patterns.add<PrepareConvTranspose2dBiasOpConversionPattern>(typeConverter,
                                                              ctx);
  patterns.add<Conv2dOpConversionPattern>(typeConverter, ctx);
  patterns.add<Conv3dOpConversionPattern>(typeConverter, ctx);
  patterns.add<ConvTranspose2dOpConversionPattern>(typeConverter, ctx);

  // Other ops
  //
  patterns
      .add<SoftmaxOpConversionPattern, EmbeddingOpConversionPattern,
           DefaultOpConversionPattern<mlir::tt::ttnn::EmbeddingBackwardOp>,
           CumSumOpConversionPattern, BatchNormInferenceOpConversionPattern,
           BatchNormTrainingOpConversionPattern, RMSNormOpConversionPattern,
           RMSNormPreAllGatherOpConversionPattern,
           DistributedRMSNormOpConversionPattern, LayerNormOpConversionPattern,
           LayerNormPreAllGatherOpConversionPattern,
           LayerNormPostAllGatherOpConversionPattern,
           GroupNormOpConversionPattern>(typeConverter, ctx);

  // CCL ops
  //
  patterns.add<AllGatherOpConversionPattern>(typeConverter, ctx);
  patterns.add<AllReduceOpConversionPattern>(typeConverter, ctx);
  patterns.add<AllReduceAsyncOpConversionPattern>(typeConverter, ctx);
  patterns.add<ReduceScatterOpConversionPattern>(typeConverter, ctx);
  patterns.add<ScatterOpConversionPattern>(typeConverter, ctx);
  patterns.add<MeshPartitionOpConversionPattern>(typeConverter, ctx);
  patterns.add<MeshShardOpConversionPattern>(typeConverter, ctx);
  patterns.add<DistributeTensorOpConversionPattern>(typeConverter, ctx);
  patterns.add<AggregateTensorOpConversionPattern>(typeConverter, ctx);
  patterns.add<PointToPointOpConversionPattern>(typeConverter, ctx);
  patterns.add<AllToAllDispatchOpConversionPattern>(typeConverter, ctx);
  patterns.add<AllToAllDispatchMetadataOpConversionPattern>(typeConverter, ctx);
  patterns.add<AllToAllCombineOpConversionPattern>(typeConverter, ctx);
  patterns.add<MoeExpertTokenRemapOpConversionPattern>(typeConverter, ctx);

  // KV Cache ops
  //
  patterns.add<UpdateCacheOpConversionPattern>(typeConverter, ctx);
  patterns.add<PagedUpdateCacheOpConversionPattern>(typeConverter, ctx);
  patterns.add<PagedFillCacheOpConversionPattern>(typeConverter, ctx);
  patterns.add<DefaultOpConversionPattern<mlir::tt::ttnn::FillCacheOp>>(
      typeConverter, ctx);

  // Tensor serialization ops
  //
  patterns.add<DumpTensorOpConversionPattern>(typeConverter, ctx);
  patterns.add<LoadTensorOpConversionPattern>(typeConverter, ctx);

  // Trace ops
  //
  patterns.add<WriteTensorOpConversionPattern>(typeConverter, ctx);
  patterns.add<BeginTraceCaptureOpConversionPattern>(typeConverter, ctx);
  patterns.add<EndTraceCaptureOpConversionPattern>(typeConverter, ctx);
  patterns.add<CaptureOrExecuteTraceOpConversionPattern>(typeConverter, ctx);
  patterns.add<ExecuteTraceOpConversionPattern>(typeConverter, ctx);

  // Arith ops
  //
  patterns.add<ArithConstantOpConversionPattern>(typeConverter, ctx);

  // Tuple ops
  //
  patterns.add<GetTupleElementOpConversionPattern>(typeConverter, ctx);
  patterns.add<TupleOpConversionPattern>(typeConverter, ctx);

  // LoadCached op
  //
  patterns.add<LoadCachedOpConversionPattern>(typeConverter, ctx);

  // Module op
  //
  patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);

  // FuncOp
  //
  patterns.add<FuncOpConversionPattern>(typeConverter, ctx);

  // Transformers ops
  //
  patterns.add<ConcatenateHeadsOpConversionPattern>(typeConverter, ctx);
  patterns.add<SplitQueryKeyValueAndSplitHeadsOpConversionPattern>(
      typeConverter, ctx);
  patterns.add<RotaryEmbeddingLlamaOpConversionPattern>(typeConverter, ctx);
  patterns.add<RotaryEmbeddingOpConversionPattern>(typeConverter, ctx);
  patterns.add<NLPConcatHeadsDecodeOpConversionPattern>(typeConverter, ctx);
  patterns.add<PagedScaledDotProductAttentionDecodeOpConversionPattern>(
      typeConverter, ctx);
  patterns.add<PagedFlashMultiLatentAttentionDecodeOpConversionPattern>(
      typeConverter, ctx);
  patterns.add<ScaledDotProductAttentionDecodeOpConversionPattern>(
      typeConverter, ctx);
  patterns.add<ScaledDotProductAttentionOpConversionPattern>(typeConverter,
                                                             ctx);
  patterns.add<NLPCreateQKVHeadsDecodeOpConversionPattern>(typeConverter, ctx);
}

Writing conversion patterns to EmitC is a little tricky at first. In the general case, we will be converting an op that has operands (SSA values) and attributes (e.g. data type) as arguments, and we want to flatten these arguments at the call site.

We'll use EmitC's CallOpaqueOp as the target op. Let's take a look at our matmul IR within TTNN dialect:

"ttnn.matmul"(%2, %4, %5) : (tensor<64x128xbf16, #ttnn_layout4>, tensor<128x96xbf16, #ttnn_layout6>, tensor<64x96xbf16, #ttnn_layout7>) -> tensor<64x96xbf16, #ttnn_layout7>

Now let's look at matmul's call signature in TTNN lib:

    static Tensor invoke(
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const bool transpose_a = false,
        const bool transpose_b = false,
        const std::optional<const MemoryConfig>& memory_config = std::nullopt,
        const std::optional<const DataType> dtype = std::nullopt,
        const std::optional<const MatmulProgramConfig>& program_config = std::nullopt,
        const std::optional<const std::string>& activation = std::nullopt,
        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
        const std::optional<const CoreGrid> core_grid = std::nullopt,
        const std::optional<const tt::tt_metal::Tile>& output_tile = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt,
        const std::optional<const DeviceGlobalCircularBuffer>& global_cb = std::nullopt);

If we look closely, we'll notice that the IR has far fewer arguments than the actual signature of the op. Since we're lowering to EmitC, which gets translated into actual C++ code, we need to correct for this (ideally the op would be perfectly modelled with all the arguments, but that is not the case today).

We do this by filling in the gaps. EmitC's CallOpaqueOp takes in an array of attributes and an array of operands, which need to be combined. The combining is done by extending the array of attributes with "pointers" into the operands, like so:

    llvm::SmallVector<mlir::Attribute> args{
        emitter.emit(srcOp.getA()),
        emitter.emit(srcOp.getB()),
        emitter.emit(srcOp.getTransposeA()),
        emitter.emit(srcOp.getTransposeB()),
        emitter.emit(std::nullopt) | emitter.getMemoryConfig(srcOp.getResult()),
        emitter.emit(emitter.getOutputDtype(srcOp.getResult())),
        /*program_config=*/emitter.emit(std::nullopt),
        emitter.emit(srcOp.getActivation()),
        emitter.emit(srcOp.getComputeConfig()),
    };

Pointers are denoted with IndexTypes wrapped into IntegerAttrs. Attributes are converted into EmitC's OpaqueAttr, which can, for practical purposes, be treated as strings: a BoolAttr carrying "false" as its value needs to be converted into an OpaqueAttr whose value is the string "false", which is what the convertBoolAttr function does.
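To make the operand-pointer scheme concrete, here is a self-contained sketch (not the actual translator code) of how an args array mixing opaque attribute strings and operand indices can be flattened into a call: index entries are substituted with the corresponding operand's name, and everything else is printed verbatim.

```cpp
#include <cstddef>
#include <string>
#include <variant>
#include <vector>

// An argument is either an opaque attribute rendered as-is (e.g. "false",
// "std::nullopt") or an index "pointer" into the operand list.
using Arg = std::variant<std::string, std::size_t>;

// Flatten the args array against the operands, mirroring how the EmitC
// translator resolves {args = [...]} when printing a call_opaque op.
std::string renderCall(const std::string &callee, const std::vector<Arg> &args,
                       const std::vector<std::string> &operands) {
  std::string out = callee + "(";
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (i > 0)
      out += ", ";
    if (const std::size_t *idx = std::get_if<std::size_t>(&args[i]))
      out += operands[*idx]; // pointer: substitute the operand's name
    else
      out += std::get<std::string>(args[i]); // attribute: print verbatim
  }
  return out + ")";
}
```

For example, with operands {"v6", "v9", "v12"} and args {0, 1, "false", "std::nullopt", 2}, this produces ttnn::matmul(v6, v9, false, std::nullopt, v12), the same shape as the generated C++ call for matmul.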

This is our final converted EmitC CallOpaqueOp:

emitc.call_opaque "ttnn::matmul"(%3, %6, %9) {args = [0 : index, 1 : index, #emitc.opaque<"false">, #emitc.opaque<"false">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, #emitc.opaque<"std::nullopt">, 2 : index]} : (!emitc.opaque<"ttnn::Tensor">, !emitc.opaque<"ttnn::Tensor">, !emitc.opaque<"ttnn::Tensor">) -> !emitc.opaque<"ttnn::Tensor">

which, when translated to C++ code, looks like:

ttnn::matmul(v6, v9, false, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, v12);

Full conversion pattern for matmul op:

namespace {
class MatmulOpConversionPattern
    : public TTNNToEmitCBaseOpConversionPattern<mlir::tt::ttnn::MatmulOp> {

public:
  using TTNNToEmitCBaseOpConversionPattern<
      mlir::tt::ttnn::MatmulOp>::TTNNToEmitCBaseOpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::tt::ttnn::MatmulOp srcOp,
                  mlir::tt::ttnn::MatmulOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ttnn_to_emitc::EmitCTTNNEmitter<mlir::tt::ttnn::MatmulOp> emitter(
        srcOp, adaptor, rewriter);

    llvm::SmallVector<mlir::Attribute> args{
        emitter.emit(srcOp.getA()),
        emitter.emit(srcOp.getB()),
        emitter.emit(srcOp.getTransposeA()),
        emitter.emit(srcOp.getTransposeB()),
        emitter.emit(std::nullopt) | emitter.getMemoryConfig(srcOp.getResult()),
        emitter.emit(emitter.getOutputDtype(srcOp.getResult())),
        /*program_config=*/emitter.emit(std::nullopt),
        emitter.emit(srcOp.getActivation()),
        emitter.emit(srcOp.getComputeConfig()),
    };

    emitter.replaceOp(*this, args);

    return success();
  }
};
} // namespace

4. Add a compiler unit test for the Op

So far we have defined the Op in the TTIR and TTNN dialects, implemented verifiers, and added conversion passes. Now we need to add a unit test to ensure that the pass works correctly. The compiler unit tests are located in the test/ttmlir/Dialect area. In this case we'll add a test under the TTNN subdirectory, since we are testing the ConvertTTIRToTTNNPass.

test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline -o %t %s
// RUN: FileCheck %s --input-file=%t
module {
  func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
    // CHECK: "ttnn.matmul"
    %1 = "ttir.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<128x96xbf16>) -> tensor<64x96xbf16>
    return %1 : tensor<64x96xbf16>
  }
}

Unit tests in MLIR are typically written using a tool called FileCheck. Please refer to the LLVM FileCheck documentation for a tutorial and more information about the RUN and CHECK directives.

A few things to point out specifically regarding tt-mlir dialects:

  • ttcore.system_desc: This is a 1-1 mapping to the SystemDesc flatbuffer schema that is used to describe the system configuration. This is a required attribute tagged on the top level module for all tt-mlir dialects.
  • The --ttnn-layout pass is a prerequisite before running convert-ttir-to-ttnn. This pass is responsible for converting the input tensors to the device memory space and tile layout before lowering to TTNN.
  • This test is asserting that ttir.matmul converts to ttnn.matmul.

To run the test, you can use the following command:

cmake --build build -- check-ttmlir

You can also manually run ttmlir-opt on the test file to see the resulting output:

./build/bin/ttmlir-opt --ttcore-register-device="system-desc-path=<PATH_TO_SYSTEM_DESC>" --ttir-to-ttnn-runtime-pipeline test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir

5. Define flatbuffer schema for the Op

Next we will define the flatbuffer schema for the Op. The schema must capture all tensor inputs, outputs, and attributes of the Op, i.e. everything the runtime needs to execute the Op.

The schema can be placed in an existing .fbs file located in the include/ttmlir/Target/TTNN/operations directory.

If no suitable .fbs file exists for the operation category, feel free to create new .fbs files as needed. After creating a new .fbs file, remember to add a corresponding cmake target in the include/ttmlir/Target/TTNN/CMakeLists.txt file.

include/ttmlir/Target/TTNN/CMakeLists.txt

  operations/matmul.fbs

In our case, we can add our schema to include/ttmlir/Target/TTNN/operations/matmul.fbs directly, without needing to create a new file.

include/ttmlir/Target/TTNN/operations/matmul.fbs

table MatmulOp {
  a: tt.target.ttnn.TensorRef;
  b: tt.target.ttnn.TensorRef;
  out: tt.target.ttnn.TensorRef;
  transpose_a: bool;
  transpose_b: bool;
  matmul_program_config: tt.target.ttnn.MatmulProgramConfig;
  activation: string;
  compute_config: tt.target.ttnn.DeviceComputeKernelConfig;
}

Flatbuffer tables with the Ref suffix, such as TensorRef, are used to represent live values during runtime, decoupled from the Desc-suffixed tables which carry the type and attribute information for the object.

After creating the schema for our new operation type, we need to register it in the OpType union within program.fbs. This file serves as the main entry point for all program information, where the OpType union collects and defines all supported operation types and their corresponding schemas.

include/ttmlir/Target/TTNN/program.fbs

  MatmulOp,

If a new .fbs file was created, don't forget to include the new file in include/ttmlir/Target/TTNN/program.fbs.

include "ttmlir/Target/TTNN/operations/matmul.fbs";

More information about writing flatbuffer schemas can be found in the flatbuffers documentation.

6. Serialize the Op in the flatbuffer format

In the previous section we defined the flatbuffer schema for the matmul Op; now let's put our new schema definition to use. The schema is used as input to a program called flatc, which generates C++ code (or code for any other supported language) for serializing and deserializing the schema. The generated code can be found in build/include/ttmlir/Target/TTNN/program_generated.h.

Let's head over to lib/Target/TTNN/TTNNToFlatbuffer.cpp to define a createOp overloaded function that does the conversion from MLIR to flatbuffer:

::flatbuffers::Offset<::tt::target::ttnn::MatmulOp>
createOp(FlatbufferObjectCache &cache, MatmulOp op) {
  auto a = cache.at<::tt::target::ttnn::TensorRef>(
      getOperandThroughDPSOps(op.getA()));
  auto b = cache.at<::tt::target::ttnn::TensorRef>(
      getOperandThroughDPSOps(op.getB()));
  auto output =
      cache.getOrCreateNoSharding(op.getResult(), tensorValueToFlatbuffer,

                                  /*local_shape*/ std::nullopt);

  using MatmulConfigType = ::tt::target::ttnn::MatmulProgramConfig;
  MatmulConfigType matmulProgramConfigType = MatmulConfigType::NONE;
  ::flatbuffers::Offset<void> matmulProgramConfigDesc;
  if (auto matmulProgramConfig = op.getMatmulProgramConfigAttr()) {
    if (auto config =
            mlir::dyn_cast<ttnn::MatmulMultiCoreReuseProgramConfigAttr>(
                matmulProgramConfig)) {
      matmulProgramConfigType =
          MatmulConfigType::MatmulMultiCoreReuseProgramConfig;
      matmulProgramConfigDesc = toFlatbuffer(cache, config).Union();
    } else if (auto config = mlir::dyn_cast<
                   ttnn::MatmulMultiCoreReuseMultiCastProgramConfigAttr>(
                   matmulProgramConfig)) {
      matmulProgramConfigType =
          MatmulConfigType::MatmulMultiCoreReuseMultiCastProgramConfig;
      matmulProgramConfigDesc = toFlatbuffer(cache, config).Union();
    } else if (auto config = mlir::dyn_cast<
                   ttnn::MatmulMultiCoreReuseMultiCast1DProgramConfigAttr>(
                   matmulProgramConfig)) {
      matmulProgramConfigType =
          MatmulConfigType::MatmulMultiCoreReuseMultiCast1DProgramConfig;
      matmulProgramConfigDesc = toFlatbuffer(cache, config).Union();
    } else if (
        auto config = mlir::dyn_cast<
            ttnn::MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfigAttr>(
            matmulProgramConfig)) {
      matmulProgramConfigType = MatmulConfigType::
          MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig;
      matmulProgramConfigDesc = toFlatbuffer(cache, config).Union();
    }
  }

  auto activation = toFlatbuffer(cache, op.getActivation()).value_or(0);

  std::optional<
      ::flatbuffers::Offset<::tt::target::ttnn::DeviceComputeKernelConfig>>
      computeConfig = toFlatbuffer(cache, op.getComputeConfig());

  return ::tt::target::ttnn::CreateMatmulOp(
      *cache.fbb, a, b, output, op.getTransposeA(), op.getTransposeB(),
      matmulProgramConfigType, matmulProgramConfigDesc, activation,
      computeConfig.value_or(0));
}

Lots of things are happening here, let's break it down:

  • FlatbufferObjectCache: This is a helper class used to cache objects in the flatbuffer that are created during the serialization process. This is necessary for managing value lifetimes and identifiers; at the same time it is an optimization that avoids having multiple copies of the same object. For example, a TensorRef with multiple uses could naively be recreated once for each use, but with the cache we can ensure that the object is only created once and all uses point to the same flatbuffer offset. The cache is passed around to all serialization functions and should be used whenever creating a new object.
  • getOperandThroughDPSOps: In section 1 we discussed DPS semantics and the drawback of having the result alias the output tensor. This is one of those cases where we need a helper function to trace through the output operands to find the original SSA name, in order to associate it with the original TensorRef.
  • CreateMatmulOp: The autogenerated function from the flatbuffer schema that actually serializes the data into the flatbuffer format.
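As a toy illustration of the create-once behavior described above (a simplified stand-in, not the real FlatbufferObjectCache), consider a cache keyed by a value's identifier:

```cpp
#include <unordered_map>

// Simplified stand-in for FlatbufferObjectCache: each value is serialized at
// most once, and every later lookup returns the same flatbuffer offset.
struct OffsetCache {
  std::unordered_map<int, int> offsets; // value id -> flatbuffer offset
  int nextOffset = 100;                 // arbitrary starting offset

  int getOrCreate(int valueId) {
    auto it = offsets.find(valueId);
    if (it != offsets.end())
      return it->second; // already serialized: reuse the existing object
    return offsets[valueId] = nextOffset++; // serialize once, remember offset
  }
};
```

A TensorRef with multiple uses thus maps to a single offset, which is exactly the deduplication the cache provides.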

We can finally generate a binary with our new Op! We can use the following command:

./build/bin/ttmlir-opt --ttcore-register-device="system-desc-path=<PATH_TO_SYSTEM_DESC>" --ttir-to-ttnn-runtime-pipeline test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir | ./build/bin/ttmlir-translate --ttnn-to-flatbuffer -o out.ttnn

And we can inspect the binary with ttrt:

ttrt read out.ttnn

Note: If the above ttrt command yields a segfault, a clean build of your workspace may be required: Build Instructions

7. Add runtime support for the Op

Next, we want to add runtime support for the Op by parsing the flatbuffer and invoking the TTNN API.

runtime/lib/ttnn/operations/matmul/matmul.cpp

void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  const ::ttnn::Tensor &lhs = tensorPool.getTTNNTensorAndValidate(op->a());
  const ::ttnn::Tensor &rhs = tensorPool.getTTNNTensorAndValidate(op->b());

  auto outputMemoryConfig =
      ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
          ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(op->out()));
  LOG_ASSERT(::tt::runtime::ttnn::utils::inSystemMemory(op->out()) ||
                 outputMemoryConfig,
             "Memory config must exist for device tensors");

  ::ttnn::DataType outputDataType = utils::getDataType(op->out());

  std::optional<::ttnn::operations::matmul::MatmulProgramConfig>
      matmulProgramConfig = utils::createMatmulProgramConfigIfNeeded(op);

  std::optional<std::string> activation =
      op->activation() ? std::make_optional(op->activation()->str())
                       : std::nullopt;

  std::optional<::ttnn::DeviceComputeKernelConfig> computeConfig;
  if (op->compute_config()) {
    computeConfig =
        utils::createDeviceComputeKernelConfig(op->compute_config());
  }

  ::ttnn::Tensor output = ::ttnn::matmul(
      lhs, rhs, op->transpose_a(), op->transpose_b(), outputMemoryConfig,
      outputDataType, matmulProgramConfig,
      /*activation=*/activation, /*compute_kernel_config=*/computeConfig,
      /*core_grid=*/std::nullopt, /*output_tile=*/std::nullopt,
      /* optional_output_tensor=*/std::nullopt);

  tensorPool.insertTTNNTensorAndValidate(op->out(), output);
}

A couple of things to note from above:

  • Most runtime op functions will follow a similar pattern: they take in some additional data structures for managing the program context.
    • Program context tracks the state of the current program. It stores intermediate tensors and devices.
  • tensorPool.getTTNNTensorAndValidate(op->a()): each TensorRef carries a global_id, a unique identifier for the tensor that was generated and managed by the FlatbufferObjectCache. This is how the runtime is intended to look up live tensors.
  • Some operations may belong to a larger set of operations. For example, any eltwise unary operations can be added in runtime/lib/ttnn/operations/eltwise/unary.cpp directly without needing to create a new file.
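Flatbuffer accessors return null (or a zero sentinel) for fields that were left unset, so runtime code typically converts them into std::optional before calling the TTNN API, as the activation handling above does. A minimal sketch of that pattern (the helper name is hypothetical):

```cpp
#include <optional>
#include <string>

// Hypothetical helper: convert a possibly-null flatbuffer string field
// (e.g. op->activation()) into std::optional for the TTNN API call.
std::optional<std::string> toOptionalString(const char *field) {
  return field ? std::make_optional(std::string(field)) : std::nullopt;
}
```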

If a new file is created for the op, we need to add a new source to runtime/lib/ttnn/operations/CMakeLists.txt and a new case to runtime/lib/ttnn/program_executor.cpp.

To update runtime/lib/ttnn/operations/CMakeLists.txt, include the path to the source file in TTNN_OPS_SRCS:

runtime/lib/ttnn/operations/CMakeLists.txt

  ${CMAKE_CURRENT_SOURCE_DIR}/matmul/matmul.cpp

To update runtime/lib/ttnn/program_executor.cpp, add a new case to the runOperation method of ProgramExecutor:

runtime/lib/ttnn/program_executor.cpp

  case ::tt::target::ttnn::OpType::MatmulOp: {
    return operations::matmul::run(op->type_as_MatmulOp(), getContext());
  }
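The dispatch itself is a plain switch over the OpType union tag. A toy analogue of the mechanism (the enum values and run functions here are stand-ins, not the real ones):

```cpp
#include <string>

// Toy analogue of ProgramExecutor::runOperation: the flatbuffer stores a type
// tag plus a union payload, and the executor switches on the tag to invoke the
// matching operations::<namespace>::run function.
enum class OpType { MatmulOp, AddOp };

std::string runOperation(OpType type) {
  switch (type) {
  case OpType::MatmulOp:
    return "matmul::run"; // would call operations::matmul::run(...)
  case OpType::AddOp:
    return "add::run";
  }
  return "unknown";
}
```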

We can test our changes with ttrt (don't forget to rebuild ttrt):

ttrt run out.ttnn

8. Add a silicon unit test for the Op

After adding runtime support, we're ready to test our Op on silicon. All silicon tests are located under test/ttmlir/Silicon. The process is similar to adding a compiler unit test.

In our specific case, we create a unit test here:

test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o %t.mlir %s
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn %t.mlir
module {
  func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
    // CHECK: "ttnn.matmul"
    %1 = "ttir.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<128x96xbf16>) -> tensor<64x96xbf16>
    return %1 : tensor<64x96xbf16>
  }

  func.func @matmul_transpose_lhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<128x128xbf16> {
    // CHECK: "ttnn.matmul"
    %1 = "ttir.matmul"(%arg0, %arg1) <{transpose_a = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<128x128xbf16>
    return %1 : tensor<128x128xbf16>
  }

  func.func @matmul_transpose_rhs(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> {
    // CHECK: "ttnn.matmul"
    %1 = "ttir.matmul"(%arg0, %arg1) <{transpose_b = true}>: (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x64xbf16>
    return %1 : tensor<64x64xbf16>
  }
}

A couple of things to point out about this process:

  • Tests placed under test/ttmlir/Dialect will only test the compiler's capability of compiling the module. If you want the module to run on silicon in CI, the test must be placed under test/ttmlir/Silicon.
  • Notice the differences between the compilation headers of test/ttmlir/Silicon/TTNN/matmul/simple_matmul.mlir and test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir
    • --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%": The system-desc-path option specifies the location of the system descriptor required for compiling the module. This is crucial for silicon tests, as modules compiled with different system descriptors may vary in silicon compatibility. Ensuring the system descriptor accurately reflects the target hardware is essential for running the module correctly.
    • // RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn %t.mlir: This runs ttmlir-translate, which serializes the output mlir module to a flatbuffer binary. We added the logic for this serialization in the Serialize the Op in the flatbuffer format section.

9. Add an EmitC test for the Op

The Op should be tested in the EmitC (C++ codegen) path as well.

TTNN EmitC tests live in the test/ttmlir/EmitC/TTNN path. In our case, the test is in test/ttmlir/EmitC/TTNN/matmul/matmul.mlir.

test/ttmlir/EmitC/TTNN/matmul/matmul.mlir

// TODO(dmilinkovic): re-enable CPU-hoisted const-eval once EmitC support for CPU-hoisted ops lands - issue #6100.
// RUN: ttmlir-opt --ttir-to-ttnn-common-pipeline="enable-cpu-hoisted-const-eval=false system-desc-path=%system_desc_path%" -o %t.mlir %s
//
// RUN: ttmlir-opt --ttnn-common-to-runtime-pipeline -o %t_rt.mlir %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %basename_t.ttnn %t_rt.mlir
//
// RUN: ttmlir-opt --ttnn-common-to-emitc-pipeline -o %t2.mlir %t.mlir
// RUN: ttmlir-translate --mlir-to-cpp -o %basename_t.cpp %t2.mlir

func.func @matmul(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
  %1 = "ttir.matmul"(%arg0, %arg1) : (tensor<64x128xbf16>, tensor<128x96xbf16>) -> tensor<64x96xbf16>
  return %1 : tensor<64x96xbf16>
}

The first three RUN lines create a flatbuffer. The fourth and fifth convert to the EmitC dialect and translate it to C++, outputting the result to a matmul.mlir.cpp file.

Additionally, the op's header file operations/matmul/matmul.hpp should be added to the list of includes in tools/ttnn-standalone/ttnn-precompiled.hpp:

#include "operations/ccl/all_gather/all_gather.hpp"
#include "operations/ccl/all_to_all_combine/all_to_all_combine.hpp"
#include "operations/ccl/all_to_all_dispatch/all_to_all_dispatch.hpp"
#include "operations/ccl/ccl_host_types.hpp"
#include "operations/conv/conv2d/conv2d.hpp"
#include "operations/conv/conv2d/prepare_conv2d_weights.hpp"
#include "operations/conv/conv_transpose2d/conv_transpose2d.hpp"
#include "operations/core/core.hpp"
#include "operations/creation/creation.hpp"
#include "operations/data_movement/concat/concat.hpp"
#include "operations/data_movement/gather/gather.hpp"
#include "operations/data_movement/pad/pad.hpp"
#include "operations/data_movement/permute/permute.hpp"
#include "operations/data_movement/repeat/repeat.hpp"
#include "operations/data_movement/repeat_interleave/repeat_interleave.hpp"
#include "operations/data_movement/scatter/scatter.hpp"
#include "operations/data_movement/slice/slice.hpp"
#include "operations/data_movement/sort/sort.hpp"
#include "operations/data_movement/transpose/transpose.hpp"
#include "operations/eltwise/binary/binary.hpp"
#include "operations/eltwise/binary/binary_composite.hpp"
#include "operations/eltwise/quantization/quantization.hpp"
#include "operations/eltwise/unary/unary_composite.hpp"
#include "operations/embedding/embedding.hpp"
#include "operations/embedding_backward/embedding_backward.hpp"
#include "operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp"
#include "operations/experimental/ccl/all_to_all_dispatch_metadata/all_to_all_dispatch_metadata.hpp"
#include "operations/experimental/ccl/rms_allgather/rms_allgather.hpp"
#include "operations/experimental/conv3d/conv3d.hpp"
#include "operations/experimental/dropout/dropout.hpp"
#include "operations/experimental/transformer/nlp_concat_heads/nlp_concat_heads.hpp"
#include "operations/experimental/unary_backward/gelu_backward/gelu_backward.hpp"
#include "operations/kv_cache/kv_cache.hpp"
#include "operations/matmul/matmul.hpp"
#include "operations/normalization/batch_norm/batch_norm.hpp"
#include "operations/normalization/groupnorm/groupnorm.hpp"
#include "operations/normalization/layernorm/layernorm.hpp"
#include "operations/normalization/layernorm_distributed/layernorm_post_all_gather.hpp"
#include "operations/normalization/layernorm_distributed/layernorm_pre_all_gather.hpp"
#include "operations/normalization/rmsnorm/rmsnorm.hpp"
#include "operations/normalization/softmax/softmax.hpp"
#include "operations/pool/generic/generic_pools.hpp"
#include "operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "operations/pool/upsample/upsample.hpp"
#include "operations/rand/rand.hpp"
#include "operations/reduction/accumulation/cumsum/cumsum.hpp"
#include "operations/reduction/argmax/argmax.hpp"
#include "operations/reduction/generic/generic_reductions.hpp"
#include "operations/reduction/prod/prod.hpp"
#include "operations/trace.hpp"
#include "operations/transformer/concatenate_heads/concatenate_heads.hpp"
#include "operations/transformer/sdpa/sdpa.hpp"
#include "operations/transformer/sdpa_decode/sdpa_decode.hpp"
#include "operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.hpp"
#include "tt-metalium/bfloat16.hpp"
#include "ttnn/common/queue_id.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device.hpp"
#include "ttnn/global_semaphore.hpp"
#include "ttnn/operations/copy/typecast/typecast.hpp"
#include "ttnn/operations/experimental/paged_cache/paged_cache.hpp"
#include "ttnn/operations/experimental/topk_router_gpt/topk_router_gpt.hpp"
#include "ttnn/operations/experimental/transformer/nlp_concat_heads_decode/nlp_concat_heads_decode.hpp"
#include "ttnn/operations/experimental/transformer/nlp_create_qkv_heads_decode/nlp_create_qkv_heads_decode.hpp"
#include "ttnn/operations/experimental/transformer/rotary_embedding/rotary_embedding.hpp"
#include "ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp"
#include "ttnn/operations/normalization/layernorm/layernorm.hpp"
#include "ttnn/operations/reduction/topk/topk.hpp"
#include "ttnn/tensor/serialization.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/types.hpp"
#include "workarounds.hpp"

10. Add a builder test for the Op

Builder tests verify end-to-end numerical correctness — they compile through the full TTIR → TTNN pipeline and execute on silicon, comparing results against PyTorch golden values. They complement the structural lit tests from steps 4 and 8.

10a. Add the op to ttir_builder.py

If the op does not yet have a builder method in tools/builder/ttir/ttir_builder.py, add one now. This involves:

  1. Writing the builder method that calls _op_proxy with the appropriate TTIR op class, operands, and keyword attributes.
  2. Registering a golden function in tools/golden/mapping.py so _op_proxy can compare compiler output against the PyTorch reference.

Refer to Adding a new op to ttir-builder for the full workflow, including how to write golden functions and register them in GOLDEN_MAPPINGS.
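As a rough illustration of what a golden function must compute, here is a minimal sketch of matmul semantics in torch. The function name and keyword signature below are illustrative only; the actual entry in tools/golden/mapping.py and its registration shape in GOLDEN_MAPPINGS may differ.

```python
import torch

# Hypothetical golden function: mirrors ttir.matmul semantics, including the
# optional transpositions applied before the multiply. Batched inputs are
# handled by torch.matmul's own broadcasting.
def matmul_golden(a, b, transpose_a=False, transpose_b=False):
    if transpose_a:
        a = a.transpose(-2, -1)
    if transpose_b:
        b = b.transpose(-2, -1)
    return torch.matmul(a, b)
```

The golden runs entirely in torch on the host, so _op_proxy can compare the compiled device output against it elementwise.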

10b. Write the builder test

Builder tests live under test/python/golden/ttir_ops/<category>/. Add a new file (or extend an existing one) for the op. The test parametrizes over shapes, dtypes, and target backends, then delegates to compile_and_execute_ttir, which builds the TTIR module, compiles it, runs it on device, and checks numerical correctness against the golden.

test/python/golden/ttir_ops/matmul/test_matmul.py

# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from typing import List, Optional
from conftest import x86_only, get_request_kwargs
from builder.base.builder_utils import Operand, Shape
from builder.ttir.ttir_builder import TTIRBuilder
from builder.base.builder_apis import compile_and_execute_ttir
from test_utils import (
    shapes_list_str,
    shape_str,
)

pytestmark = pytest.mark.frontend("ttir")


@pytest.mark.parametrize("shape", [(128, 128), (4, 128, 128)], ids=shape_str)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["f32"])
@pytest.mark.parametrize("transpose_a", [False, True])
@pytest.mark.parametrize("transpose_b", [False, True])
@pytest.mark.parametrize("target", ["ttnn", "emitpy"])
def test_matmul(
    shape: Shape,
    dtype: torch.dtype,
    transpose_a: bool,
    transpose_b: bool,
    target: str,
    request,
    device,
):
    def module(builder: TTIRBuilder):
        @builder.func([shape, shape], [dtype, dtype])
        def matmul(
            in0: Operand,
            in1: Operand,
            builder: TTIRBuilder,
            unit_attrs: Optional[List[str]] = None,
        ):
            return builder.matmul(
                in0,
                in1,
                transpose_a=transpose_a,
                transpose_b=transpose_b,
                unit_attrs=unit_attrs,
            )

    pipeline_options = []
    compile_and_execute_ttir(
        module,
        **get_request_kwargs(request),
        target=target,
        device=device,
        pipeline_options=pipeline_options,
    )


@pytest.mark.parametrize(
    "shapes",
    [
        [
            (10, 64, 64),
            (64, 64),
        ]
    ],
    ids=shapes_list_str,
)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["f32"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "without_bias"])
@pytest.mark.parametrize("transpose_a", [False, True])
@pytest.mark.parametrize("transpose_b", [False, True])
@pytest.mark.parametrize("target", ["ttnn", "emitpy"])
def test_linear(
    shapes: List[Shape],
    dtype: torch.dtype,
    has_bias: bool,
    transpose_a: bool,
    transpose_b: bool,
    target: str,
    request,
    device,
):
    bias_shape = None
    if has_bias:
        bias_shape = (shapes[1][-1],)

    def module(builder: TTIRBuilder):
        # Set up input shapes and types based on whether bias is used
        if has_bias:
            input_shapes = shapes + [bias_shape]
            input_types = [dtype, dtype, dtype]
        else:
            input_shapes = shapes
            input_types = [dtype, dtype]

        @builder.func(input_shapes, input_types)
        def linear(*args, unit_attrs: Optional[List[str]] = None):
            # The builder is always passed as the last positional argument
            builder = args[-1]
            inputs = args[:-1]

            in0 = inputs[0]
            in1 = inputs[1]
            bias = inputs[2] if len(inputs) > 2 else None

            return builder.linear(
                in0,
                in1,
                bias,
                transpose_a=transpose_a,
                transpose_b=transpose_b,
                unit_attrs=unit_attrs,
            )

    compile_and_execute_ttir(
        module,
        **get_request_kwargs(request),
        device=device,
        target=target,
    )


@x86_only
@pytest.mark.parametrize(
    "shapes",
    [
        [(10, 64, 32), (32, 128), (128,)],
        [(10, 20), (20, 30)],
    ],
    ids=["3D_with_bias", "2D_no_bias"],
)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["f32"])
@pytest.mark.parametrize("target", ["ttnn", "emitpy"])
def test_hoisted_linear(
    shapes: List[Shape], dtype: torch.dtype, target: str, request, device
):
    def module(builder: TTIRBuilder):
        @builder.func(shapes, [dtype] * len(shapes))
        def hoisted_linear(
            *inputs,
            unit_attrs: Optional[List[str]] = None,
        ):
            builder = inputs[-1]
            in0 = inputs[0]
            in1 = inputs[1]
            # inputs also contains the builder as its last element, so a bias
            # operand is present only when there are more than 3 entries
            bias = inputs[2] if len(inputs) > 3 else None
            return builder.linear(in0, in1, bias, unit_attrs=["ttir.should_hoist"])

    compile_and_execute_ttir(
        module,
        **get_request_kwargs(request),
        target=target,
        device=device,
    )


@x86_only
@pytest.mark.parametrize(
    "shapes",
    [
        [(10, 20), (20, 30)],
        [(5, 10, 20), (5, 20, 30)],
    ],
    ids=["standard_2D_matmul", "3D_batched_matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["f32"])
@pytest.mark.parametrize("target", ["ttnn", "emitpy"])
def test_hoisted_matmul(
    shapes: List[Shape], dtype: torch.dtype, target: str, request, device
):
    def module(builder: TTIRBuilder):
        @builder.func(shapes, [dtype] * len(shapes))
        def hoisted_matmul(
            in0: Operand,
            in1: Operand,
            builder: TTIRBuilder,
            unit_attrs: Optional[List[str]] = None,
        ):
            return builder.matmul(in0, in1, unit_attrs=["ttir.should_hoist"])

    compile_and_execute_ttir(
        module,
        **get_request_kwargs(request),
        target=target,
        device=device,
    )

Key conventions:

  • pytestmark = pytest.mark.frontend("ttir") — required file-wide mark.
  • Shape parameters must be named shape, shapes, input_shape, or input_shapes; dtype parameters must be named dtype, dtypes, or input_dtypes. See Builder Testing for the full rules.
  • target must be a parametrized dimension even for single-target tests (the device fixture reads it to initialize the right backend).
  • The inner @builder.func([shapes], [dtypes]) decorator wires up MLIR function arguments; the computation is expressed with builder.<op>(...) calls inside.

Run with:

pytest test/python/golden/ttir_ops/matmul/test_matmul.py

For full parametrization rules, skip marks (skip_config, x86_only), and test-report requirements refer to Builder Testing.

11. Add CPU-hoisting support (if applicable)

CPU-hoisting moves selected TTIR ops off the device and executes them on the host CPU, improving numerical precision (full f32/i32) and reducing peak DRAM/L1 usage. It is appropriate for standard elementwise, reduction, normalization, and similar ops. Skip this step for complex, model-specific ops (e.g., ScaledDotProductAttentionDecodeOp) that have no meaningful host execution path and should always run on device.

Hoisted ops are lowered through two independent paths:

  • Runtime target (flatbuffer path): TTIR → Linalg/TOSA → LLVM IR → .so dylib, embedded in the flatbuffer and loaded at runtime via dlopen().
  • EmitPy target: TTIR → CallOpaqueOp("ttir_cpu.<op>") → pure-torch implementation in ttir_cpu.py.

11a. TTIR to Linalg/TOSA conversion pattern

Add a conversion pattern in lib/Conversion/TTIRToLinalg/ and register it in the appropriate populate function. Elementwise ops typically lower to linalg.generic or a TOSA equivalent; see existing patterns in the same directory for reference.

11b. TTIR to Linalg/TOSA lit test

Add test/ttmlir/Conversion/TTIRToLinalg/<op_name>.mlir. See existing tests in that directory for the RUN/FileCheck boilerplate and CHECK conventions for both Linalg and TOSA targets.

11c. Decomposition alternative to 11a–11b (runtime path only)

If the op has no natural Linalg/TOSA equivalent but decomposes cleanly into ops that already have Linalg support (e.g., DotGeneralOp → MatmulOp), steps 11a–11b can be skipped. Add the op as illegal under DecompMode::CPUFallback in lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp and add the decomposition pattern under lib/Conversion/TTIRToTTIRDecomposition/. The runtime CPU pipeline runs this decomposition before the Linalg lowering, and the hoisting validation uses the same check.

Note: decomposition does not help the EmitPy path — the EmitPy CPU pipeline has no prior decomposition step, so steps 11d–11e are still required.

11d. TTIR to EmitPy CPU conversion pattern

Add a pattern in lib/Conversion/TTIRToEmitPy/TTIRCPUToEmitPyPass.cpp. For elementwise ops, use the existing generic templates; for ops with non-trivial attributes write a custom pattern using EmitPyCallBuilder. See existing patterns in the same file for reference.

11e. Torch implementation in ttir_cpu

Add a pure-torch function to tools/tt-alchemist/templates/python/local/ttir_cpu.py mirroring the TTIR op semantics (EmitPy target only; the runtime target goes through Linalg → LLVM). Use **_ to absorb unused kwargs, use builtins.* when a local name shadows a Python builtin, and compute in float32 for reductions and normalization. Match the TTIR semantics exactly: this is the reference implementation for const-eval.
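A minimal sketch of what such a function might look like for matmul, following the conventions above. The exact signature and keyword names in the real ttir_cpu.py are assumptions here, not the actual file contents:

```python
import torch

# Hypothetical ttir_cpu entry for matmul. The trailing **_ absorbs any extra
# kwargs the generated ttir_cpu.<op> call site passes through, so new
# attributes on the op do not break the host implementation.
def matmul(in0, in1, transpose_a=False, transpose_b=False, **_):
    a = in0.transpose(-2, -1) if transpose_a else in0
    b = in1.transpose(-2, -1) if transpose_b else in1
    return torch.matmul(a, b)
```

Matmul needs no float32 widening, but a reduction or normalization op would cast its inputs to float32 before accumulating, per the convention above.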

11f. Tests for CPU-hoisted ops

Linalg conversion lit test: run the test added in 11b:

llvm-lit test/ttmlir/Conversion/TTIRToLinalg/<op_name>.mlir

EmitPy lit test: add a function to test/ttmlir/EmitPy/cpu_hoisted_ops.mlir. The convention is to run the op twice in one function — once tagged {ttir.should_hoist} and once on device — then subtract the results, exercising both paths. Check for the expected ttir_cpu.<op> call in the generated Python. See existing functions in that file for the pattern.

Builder test: add a test_cpu_hoistable_* test to test/python/golden/ttir_ops/<category>/test_<category>.py, passing unit_attrs=["ttir.should_hoist"] to the builder call and parametrizing over all three targets (ttnn, ttmetal, emitpy). See existing test_cpu_hoistable_* tests for exact parametrization and decorator patterns (e.g., @x86_only).