Decomposing an Op in TTIR

This guide explains how to add a new operation to the TTIR dialect and decompose it into an existing one. We’ll focus on adding an Index operation, which is decomposed into the Slice operation. The decomposition is implemented as an MLIR conversion pass, because the conversion framework lets us mark operations or dialects as legal or illegal and supports type conversion, among other features.

This guide will cover the following steps:

1. Define the Op in the TTIR frontend dialect

More information about this step can be found here: Define the Op in the TTIR frontend dialect

I updated TTIROps.td as follows:

def TTIR_IndexOp: TTIR_NamedOp<"index"> {
    let summary = "Tensor indexing operation.";
    let description = [{
      The `index` operation extracts a sub-tensor (slice) from the input tensor along a specified dimension.

      This operation selects elements from the input tensor along a single dimension based on the specified
      begin, end, and step indices. It's similar to Python's slicing notation `tensor[:, begin:end:step, :]`
      where the slicing is applied only to the specified dimension.

      Example:
      ```mlir
      // Extract elements with indices 1, 3, 5 from dimension 0 of a 1D tensor
      %input = ... : tensor<6xf32>  // Input tensor with values: [1, 2, 3, 4, 5, 6]
      %output = ttir.empty() : tensor<3xf32>  // Output tensor shape
      %result = ttir.index(%input, %output) {
          dim = 0 : i32,    // Dimension to index
          begin = 1 : i32,  // Start index
          end = 6 : i32,    // End index (exclusive)
          step = 2 : i32    // Step size
      } : tensor<6xf32>, tensor<3xf32> -> tensor<3xf32>
      // Result: [2, 4, 6]

      // Extract columns 0 and 2 from a 2D tensor
      %input = ... : tensor<3x4xf32>  // Input tensor with values:
                                      // [[1, 2, 3, 4],
                                      //  [5, 6, 7, 8],
                                      //  [9, 10, 11, 12]]
      %output = ttir.empty() : tensor<3x2xf32>  // Output tensor shape
      %result = ttir.index(%input, %output) {
          dim = 1 : i32,    // Index along columns (dimension 1)
          begin = 0 : i32,  // Start from first column
          end = 3 : i32,    // Stop before column index 3 (exclusive)
          step = 2 : i32    // Take every other column
      } : tensor<3x4xf32>, tensor<3x2xf32> -> tensor<3x2xf32>
      // Result:
      // [[1, 3],
      //  [5, 7],
      //  [9, 11]]
      ```

      Inputs:
      - `input` (Tensor): The input tensor to index.
      - `output` (Tensor): The destination tensor; it must have the shape of the result.

      Attributes:
      - `dim` (Integer): The dimension along which to index.
      - `begin` (Integer): The starting index.
      - `end` (Integer): The ending index (exclusive).
      - `step` (Integer): The step size between indices.

      Outputs:
      - `result` (Tensor): The indexed tensor.

      Note: The shape of the output tensor is the same as the input tensor except for the indexed dimension,
      which will have size `ceil((end - begin) / step)`. The indices selected will be `begin`, `begin + step`,
      `begin + 2*step`, etc., up to but not including `end`.
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         I32Attr:$dim,
                         I32Attr:$begin,
                         I32Attr:$end,
                         I32Attr:$step);

    let results = (outs AnyRankedTensor:$result);

    let hasVerifier = 1;
}

The verification function has been added as well:

// IndexOp verification
::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  int32_t dim = getDim();
  int32_t begin = getBegin();
  int32_t end = getEnd();
  int32_t step = getStep();

  // Verify that the input is at least 1D tensor
  if (inputType.getRank() < 1) {
    return emitOpError("Input must be at least a 1D tensor");
  }

  // Validate that the output tensor has the same element type as the input
  // tensor
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError(
        "Output tensor must have the same element type as the input tensor");
  }

  // Verify the output tensor rank
  if (inputType.getRank() != outputType.getRank()) {
    return emitOpError(
        "Output tensor must have the same rank as the input tensor");
  }

  // Verify that the dim attribute is within the bounds of the input tensor
  if (dim < 0 || dim >= inputType.getRank()) {
    return emitOpError() << "Invalid dimension index " << dim
                         << ". Input tensor rank is " << inputType.getRank();
  }

  // Verify begin, end, step and the output tensor dimensions
  int64_t dimSize = inputShape[dim];

  // Adjust negative begin and end
  int32_t adjustedBegin = (begin < 0) ? (begin + dimSize) : begin;
  int32_t adjustedEnd = (end < 0) ? (end + dimSize) : end;

  std::ostringstream inputShapeStream;
  inputShapeStream << "(";
  for (size_t i = 0; i < inputShape.size(); ++i) {
    inputShapeStream << inputShape[i];
    if (i != inputShape.size() - 1) {
      inputShapeStream << ", ";
    }
  }
  inputShapeStream << ")";
  std::string inputShapeStr = inputShapeStream.str();

  if (adjustedBegin < 0 || adjustedBegin >= dimSize) {
    return emitOpError() << "Invalid begin index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "), got " << begin
                         << ". Input shape: " << inputShapeStr;
  }
  if (adjustedEnd < 0 || adjustedEnd > dimSize) {
    return emitOpError() << "Invalid end index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "], got " << end
                         << ". Input shape: " << inputShapeStr;
  }

  auto formatValueMessage = [](int value, int adjustedValue) {
    return value < 0 ? std::to_string(adjustedValue) + " (" +
                           std::to_string(value) + ")"
                     : std::to_string(value);
  };
  std::string beginValueMessage = formatValueMessage(begin, adjustedBegin);
  std::string endValueMessage = formatValueMessage(end, adjustedEnd);

  if (step == 0) {
    return emitOpError("Step value for dimension " + std::to_string(dim) +
                       " cannot be zero");
  }

  if (step > 0 && adjustedBegin > adjustedEnd) {
    return emitOpError() << "For positive step, begin index must be less "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  if (step < 0 && adjustedBegin < adjustedEnd) {
    return emitOpError() << "For negative step, begin index must be greater "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  // Calculate the expected size of the output dimension
  int32_t expectedDimSize =
      (std::abs(adjustedEnd - adjustedBegin) + std::abs(step) - 1) /
      std::abs(step);
  if (outputType.getDimSize(dim) != expectedDimSize) {
    return emitOpError() << "Mismatch in dimension " << std::to_string(dim)
                         << " of the output tensor: expected size "
                         << expectedDimSize << ", but got "
                         << outputType.getDimSize(dim);
  }

  return success();
}
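The index arithmetic in the verifier can be tried out in isolation. The following is a minimal sketch in plain C++ with no MLIR dependency; `normalizeIndex` and `expectedDimSize` are hypothetical helper names introduced only for illustration and are not part of the codebase. It assumes `step` is non-zero and that `begin` and `end` have already passed the range checks shown above.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Hypothetical helper (not part of tt-mlir): wraps a negative index the same
// way the verifier does, by adding the size of the indexed dimension.
int64_t normalizeIndex(int64_t index, int64_t dimSize) {
  return index < 0 ? index + dimSize : index;
}

// Hypothetical helper: number of elements selected along the indexed
// dimension, computed as ceil(|end - begin| / |step|) in integer arithmetic.
// Assumes step != 0 and that begin/end are within the verified ranges.
int64_t expectedDimSize(int64_t begin, int64_t end, int64_t step,
                        int64_t dimSize) {
  assert(step != 0 && "step must be non-zero");
  int64_t adjustedBegin = normalizeIndex(begin, dimSize);
  int64_t adjustedEnd = normalizeIndex(end, dimSize);
  int64_t absStep = std::abs(step);
  // Adding (absStep - 1) before dividing rounds the quotient up.
  return (std::abs(adjustedEnd - adjustedBegin) + absStep - 1) / absStep;
}
```

For the first example in the Op description (`begin = 1`, `end = 6`, `step = 2` on a dimension of size 6), `expectedDimSize(1, 6, 2, 6)` evaluates to 3, which matches the `tensor<3xf32>` output.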

2. Create a conversion pattern

A conversion pattern defines how MLIR should rewrite the Op. It can be implemented in either C++ or TableGen. Currently, we only have the C++ implementation; TableGen format will be added in the future.

C++ conversion pattern

For the Index operation, we use the C++ conversion pattern because it involves changing the Op’s attribute types from integers to arrays, which TableGen is not flexible enough to express.

// This transformation adjusts IndexOp attributes so that `begin`, `end`, and
// `step` become arrays, where each array element corresponds to a dimension of
// the input tensor. For dimensions other than the sliced dimension, default
// values are used.
//
namespace {
struct IndexToSliceConversionPattern
    : public OpConversionPattern<ttir::IndexOp> {
  using OpConversionPattern<ttir::IndexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::IndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType =
        ::mlir::dyn_cast<mlir::RankedTensorType>(adaptor.getInput().getType());
    // dyn_cast yields null unless the input is a ranked tensor.
    if (!inputType) {
      return failure();
    }

    int64_t rank = inputType.getRank();
    llvm::SmallVector<mlir::Attribute, 4> begins, ends, steps;

    for (int64_t i = 0; i < rank; ++i) {
      if (i == op.getDim()) {
        begins.push_back(rewriter.getI32IntegerAttr(adaptor.getBegin()));
        ends.push_back(rewriter.getI32IntegerAttr(adaptor.getEnd()));
        steps.push_back(rewriter.getI32IntegerAttr(adaptor.getStep()));
      } else {
        begins.push_back(rewriter.getI32IntegerAttr(0));
        ends.push_back(rewriter.getI32IntegerAttr(inputType.getDimSize(i)));
        steps.push_back(rewriter.getI32IntegerAttr(1));
      }
    }

    auto newOp = rewriter.create<ttir::SliceOp>(
        op.getLoc(), op.getType(), adaptor.getInput(), adaptor.getOutput(),
        rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends),
        rewriter.getArrayAttr(steps));

    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
} // namespace

The matchAndRewrite method from OpConversionPattern is implemented to replace the matched Op with the newly created Op. Since decomposition is implemented as a conversion pass, OpAdaptor is used to access the operands of the original Op in their converted types, together with its attributes. Finally, we create the new Op and call the replaceOp method on ConversionPatternRewriter to replace the original Op.
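To make the attribute expansion concrete, here is a minimal sketch of the same per-dimension logic in plain C++, without MLIR. `SliceParams` and `expandIndexToSlice` are hypothetical names used only for illustration; the real pattern builds MLIR ArrayAttr values instead of `std::vector`. The sketch assumes a static input shape and a valid `dim`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the three attribute arrays passed to the Slice
// Op; the real pattern stores them as MLIR ArrayAttr values.
struct SliceParams {
  std::vector<int32_t> begins;
  std::vector<int32_t> ends;
  std::vector<int32_t> steps;
};

// Expands the single-dimension Index parameters into one entry per input
// dimension. Dimensions other than `dim` are taken whole: begin 0, end equal
// to the dimension size, step 1.
SliceParams expandIndexToSlice(const std::vector<int64_t> &inputShape,
                               int32_t dim, int32_t begin, int32_t end,
                               int32_t step) {
  assert(dim >= 0 && static_cast<std::size_t>(dim) < inputShape.size());
  SliceParams params;
  for (std::size_t i = 0; i < inputShape.size(); ++i) {
    if (i == static_cast<std::size_t>(dim)) {
      params.begins.push_back(begin);
      params.ends.push_back(end);
      params.steps.push_back(step);
    } else {
      params.begins.push_back(0);
      params.ends.push_back(static_cast<int32_t>(inputShape[i]));
      params.steps.push_back(1);
    }
  }
  return params;
}
```

For an input of shape 4x32x32 with `dim = 2`, `begin = 0`, `end = 32`, and `step = 2`, this yields begins `[0, 0, 0]`, ends `[4, 32, 32]`, and steps `[1, 1, 2]`.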

TableGen conversion pattern

TODO

3. Register the created conversion pattern

To register the new pattern, go to the populateTTIRToTTIRDecompositionPatterns function in TTIRToTTIRDecomposition.cpp and add it to the RewritePatternSet using the add method. Once that is done, mark the decomposed Op as illegal in the runOnOperation method of TTIRToTTIRDecompositionPass in TTIRToTTIRDecompositionPass.cpp.

You should also add a silicon test, as described here: Add a silicon unit test for the Op. This is what the silicon test for the Index operation looks like:

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
  func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<4x32x16xbf16> {
    %0 = ttir.empty() : tensor<4x32x16xbf16>
    // CHECK: = "ttnn.slice"
    %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 2: i32}> : (tensor<4x32x32xbf16>, tensor<4x32x16xbf16>) -> tensor<4x32x16xbf16>
    return %1 : tensor<4x32x16xbf16>
  }
}