Decomposing an Op in TTIR
This guide explains how to add and decompose a new operation in the TTIR dialect. We'll focus on adding an `Index` operation, which will be decomposed into the `Slice` operation. The decomposition is implemented as a conversion pass in MLIR, since conversion passes allow us to mark operations or dialects as legal or illegal, perform type conversion, and so on.
This guide will cover the following steps:
1. Define the Op in the TTIR frontend dialect
More information regarding this step can be found here: Define the Op in the TTIR frontend dialect.
We updated `TTIROps.td` as follows:
def TTIR_IndexOp: TTIR_NamedOp<"index"> {
    let summary = "Tensor indexing operation.";
    let description = [{
      The `index` operation extracts a sub-tensor (slice) from the input tensor along a specified dimension.

      This operation selects elements from the input tensor along a single dimension based on the specified
      begin, end, and step indices. It's similar to Python's slicing notation `tensor[:, begin:end:step, :]`
      where the slicing is applied only to the specified dimension.

      Example:
      ```mlir
      // Extract elements with indices 1, 3, 5 from dimension 0 of a 1D tensor
      %input = ... : tensor<6xf32>            // Input tensor with values: [1, 2, 3, 4, 5, 6]
      %output = ttir.empty() : tensor<3xf32>  // Output tensor shape
      %result = ttir.index(%input, %output) {
          dim = 0 : i32,    // Dimension to index
          begin = 1 : i32,  // Start index
          end = 6 : i32,    // End index (exclusive)
          step = 2 : i32    // Step size
      } : tensor<6xf32>, tensor<3xf32> -> tensor<3xf32>
      // Result: [2, 4, 6]

      // Extract columns 0 and 2 from a 2D tensor
      %input = ... : tensor<3x4xf32>            // Input tensor with values:
                                                // [[1, 2, 3, 4],
                                                //  [5, 6, 7, 8],
                                                //  [9, 10, 11, 12]]
      %output = ttir.empty() : tensor<3x2xf32>  // Output tensor shape
      %result = ttir.index(%input, %output) {
          dim = 1 : i32,    // Index along columns (dimension 1)
          begin = 0 : i32,  // Start from first column
          end = 3 : i32,    // End at third column (exclusive)
          step = 2 : i32    // Take every other column
      } : tensor<3x4xf32>, tensor<3x2xf32> -> tensor<3x2xf32>
      // Result:
      // [[1, 3],
      //  [5, 7],
      //  [9, 11]]
      ```

      Inputs:
      - `input` (Tensor): The input tensor to index.

      Attributes:
      - `dim` (Integer): The dimension along which to index.
      - `begin` (Integer): The starting index.
      - `end` (Integer): The ending index (exclusive).
      - `step` (Integer): The step size between indices.

      Outputs:
      - `result` (Tensor): The indexed tensor.

      Note: The shape of the output tensor is the same as the input tensor except for the indexed dimension,
      which will have size `ceil((end - begin) / step)`. The indices selected will be `begin`, `begin + step`,
      `begin + 2*step`, etc., up to but not including `end`.
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         I32Attr:$dim,
                         I32Attr:$begin,
                         I32Attr:$end,
                         I32Attr:$step);

    let results = (outs AnyRankedTensor:$result);

    let hasVerifier = 1;
}
The verification function has been added as well:
// IndexOp verification
::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  int32_t dim = getDim();
  int32_t begin = getBegin();
  int32_t end = getEnd();
  int32_t step = getStep();

  // Verify that the input is at least 1D tensor
  if (inputType.getRank() < 1) {
    return emitOpError("Input must be at least a 1D tensor");
  }

  // Validate that the output tensor has the same element type as the input
  // tensor
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError(
        "Output tensor must have the same element type as the input tensor");
  }

  // Verify the output tensor rank
  if (inputType.getRank() != outputType.getRank()) {
    return emitOpError(
        "Output tensor must have the same rank as the input tensor");
  }

  // Verify that the dim attribute is within the bounds of the input tensor
  if (dim < 0 || dim >= inputType.getRank()) {
    return emitOpError() << "Invalid dimension index " << dim
                         << ". Input tensor rank is " << inputType.getRank();
  }

  // Verify begin, end, step and the output tensor dimensions
  int64_t dimSize = inputShape[dim];

  // Adjust negative begin and end
  int32_t adjustedBegin = (begin < 0) ? (begin + dimSize) : begin;
  int32_t adjustedEnd = (end < 0) ? (end + dimSize) : end;

  std::ostringstream inputShapeStream;
  inputShapeStream << "(";
  for (size_t i = 0; i < inputShape.size(); ++i) {
    inputShapeStream << inputShape[i];
    if (i != inputShape.size() - 1) {
      inputShapeStream << ", ";
    }
  }
  inputShapeStream << ")";
  std::string inputShapeStr = inputShapeStream.str();

  if (adjustedBegin < 0 || adjustedBegin >= dimSize) {
    return emitOpError() << "Invalid begin index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "), got " << begin
                         << ". Input shape: " << inputShapeStr;
  }
  if (adjustedEnd < 0 || adjustedEnd > dimSize) {
    return emitOpError() << "Invalid end index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "], got " << end
                         << ". Input shape: " << inputShapeStr;
  }

  auto formatValueMessage = [](int value, int adjustedValue) {
    return value < 0 ? std::to_string(adjustedValue) + " (" +
                           std::to_string(value) + ")"
                     : std::to_string(value);
  };
  std::string beginValueMessage = formatValueMessage(begin, adjustedBegin);
  std::string endValueMessage = formatValueMessage(end, adjustedEnd);

  if (step == 0) {
    return emitOpError("Step value for dimension " + std::to_string(dim) +
                       " cannot be zero");
  }
  if (step > 0 && adjustedBegin > adjustedEnd) {
    return emitOpError() << "For positive step, begin index must be less "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }
  if (step < 0 && adjustedBegin < adjustedEnd) {
    return emitOpError() << "For negative step, begin index must be greater "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  // Calculate the expected size of the output dimension
  int32_t expectedDimSize =
      (std::abs(adjustedEnd - adjustedBegin) + std::abs(step) - 1) /
      std::abs(step);
  if (outputType.getDimSize(dim) != expectedDimSize) {
    return emitOpError() << "Mismatch in dimension " << std::to_string(dim)
                         << " of the output tensor: expected size "
                         << expectedDimSize << ", but got "
                         << outputType.getDimSize(dim);
  }

  return success();
}
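The central relationship the verifier enforces is the output-size rule from the op description: the indexed dimension of the output must have size `ceil((end - begin) / step)`. As a quick, standalone sanity check of that arithmetic (this snippet is purely illustrative and is not part of the compiler), the two examples from the op description work out as follows:

```cpp
#include <cassert>
#include <cstdlib>

// Expected size of the indexed dimension: ceil(|end - begin| / |step|).
// Mirrors the expectedDimSize computation in the verifier above.
int expectedDimSize(int begin, int end, int step) {
  return (std::abs(end - begin) + std::abs(step) - 1) / std::abs(step);
}

int main() {
  // 1D example: tensor<6xf32>, begin=1, end=6, step=2 -> indices 1, 3, 5.
  assert(expectedDimSize(1, 6, 2) == 3);
  // 2D example: dim 1 of tensor<3x4xf32>, begin=0, end=3, step=2 -> columns 0, 2.
  assert(expectedDimSize(0, 3, 2) == 2);
  return 0;
}
```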
2. Create a conversion pattern
A conversion pattern defines how MLIR should rewrite the Op. It can be implemented in either C++ or TableGen. Currently, we only have the C++ implementation; TableGen format will be added in the future.
C++ conversion pattern
For the `Index` operation, we use the C++ conversion pattern because the decomposition changes the op's `begin`, `end`, and `step` attributes from scalar integers to arrays, which TableGen patterns lack the flexibility to express.
// This transformation adjusts IndexOp attributes so that `begin`, `end`, and
// `step` become arrays, where each array element corresponds to a dimension of
// the input tensor. For dimensions other than the sliced dimension, default
// values are used.
//
namespace {
struct IndexToSliceConversionPattern
    : public OpConversionPattern<ttir::IndexOp> {
  using OpConversionPattern<ttir::IndexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::IndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType =
        ::mlir::dyn_cast<mlir::RankedTensorType>(adaptor.getInput().getType());
    if (!inputType || !inputType.hasRank()) {
      return failure();
    }

    int64_t rank = inputType.getRank();
    llvm::SmallVector<mlir::Attribute, 4> begins, ends, steps;

    for (int64_t i = 0; i < rank; ++i) {
      if (i == op.getDim()) {
        begins.push_back(rewriter.getI32IntegerAttr(adaptor.getBegin()));
        ends.push_back(rewriter.getI32IntegerAttr(adaptor.getEnd()));
        steps.push_back(rewriter.getI32IntegerAttr(adaptor.getStep()));
      } else {
        begins.push_back(rewriter.getI32IntegerAttr(0));
        ends.push_back(rewriter.getI32IntegerAttr(inputType.getDimSize(i)));
        steps.push_back(rewriter.getI32IntegerAttr(1));
      }
    }

    auto newOp = rewriter.create<ttir::SliceOp>(
        op.getLoc(), op.getType(), adaptor.getInput(), adaptor.getOutput(),
        rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends),
        rewriter.getArrayAttr(steps));

    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
} // namespace
The `matchAndRewrite` method from `OpConversionPattern` is implemented to replace the matched Op with the newly created Op. Since decomposition is implemented as a conversion pass, `OpAdaptor` is used to access the attributes of the original Op in their converted types. Finally, we instantiate the new Op and call the `replaceOp` method on `ConversionPatternRewriter` to replace the original Op.
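As a side note, `ConversionPatternRewriter` (via `RewriterBase`) also provides `replaceOpWithNewOp`, which fuses the create-and-replace steps into a single call. A minimal sketch of how the tail of `matchAndRewrite` could be written with it, assuming the same `op`, `adaptor`, `rewriter`, and attribute arrays as in the pattern above:

```cpp
// Sketch only: assumes the surrounding matchAndRewrite body above, i.e. the
// same `op`, `adaptor`, `rewriter`, and the `begins`/`ends`/`steps` arrays.
rewriter.replaceOpWithNewOp<ttir::SliceOp>(
    op, op.getType(), adaptor.getInput(), adaptor.getOutput(),
    rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends),
    rewriter.getArrayAttr(steps));
return success();
```

Both spellings are equivalent; the pattern above keeps the explicit `create` / `replaceOp` pair, which makes the newly created `SliceOp` easy to inspect or extend before the replacement.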
TableGen conversion pattern
TODO
3. Register the created conversion pattern
To register the new pattern, go to the `populateTTIRToTTIRDecompositionPatterns` function in `TTIRToTTIRDecomposition.cpp` and add it to the `RewritePatternSet` using the `add` method. After that is done, you should mark the decomposed op as illegal in the `runOnOperation` method of `TTIRToTTIRDecompositionPass` in `TTIRToTTIRDecompositionPass.cpp`.
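Putting those two registration steps together, the result looks roughly like the sketch below. The exact signatures in tt-mlir may differ (for example, which arguments `populateTTIRToTTIRDecompositionPatterns` takes and whether a `TypeConverter` is threaded through), so treat this as an outline rather than the literal code:

```cpp
// Sketch, assuming the function and pass names referenced above; the real
// code in TTIRToTTIRDecomposition.cpp / TTIRToTTIRDecompositionPass.cpp may
// take additional arguments.
void populateTTIRToTTIRDecompositionPatterns(mlir::MLIRContext *ctx,
                                             mlir::RewritePatternSet &patterns,
                                             mlir::TypeConverter &typeConverter) {
  // ... existing decomposition patterns ...
  patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx);
}

void TTIRToTTIRDecompositionPass::runOnOperation() {
  mlir::ConversionTarget target(getContext());
  // ... existing legality rules ...
  target.addIllegalOp<ttir::IndexOp>(); // force the Index -> Slice decomposition

  // ... build the pattern set and run the conversion as the pass already does ...
}
```

Marking `ttir::IndexOp` illegal is what forces the conversion framework to apply the pattern: any `ttir.index` remaining after the pass is reported as a conversion failure.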
You should also add a silicon test, as described here: Add a silicon unit test for the Op. This is what the silicon test for the `Index` operation looks like:
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
  func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<4x32x16xbf16> {
    %0 = ttir.empty() : tensor<4x32x16xbf16>
    // CHECK: = "ttnn.slice"
    %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 2: i32}> : (tensor<4x32x32xbf16>, tensor<4x32x16xbf16>) -> tensor<4x32x16xbf16>
    return %1 : tensor<4x32x16xbf16>
  }
}