Decomposing an Op in TTIR
This guide explains how to add and decompose a new operation in the TTIR dialect. We’ll focus on adding an Index operation, which will be decomposed into the Slice operation. The decomposition is implemented as a conversion pass in MLIR because a conversion pass lets us mark operations or dialects as legal or illegal, perform type conversion, and so on.
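This guide covers three steps: defining the Op in the TTIR frontend dialect, creating a conversion pattern, and registering the created conversion pattern.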
1. Define the Op in the TTIR frontend dialect
More information about this step is available in Define the Op in the TTIR frontend dialect.
I updated TTIROps.td as follows:
```
def TTIR_IndexOp: TTIR_DPSOp<"index"> {
    let summary = "Index op.";
    let description = [{
      Extract a sub-tensor (slice) from the input tensor along a specified dimension.
      The `begin`, `end`, and `step` attributes define the start, stop, and step indices for the
      selected dimension (`dim`) of the tensor.
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         I32Attr:$dim,
                         I32Attr:$begin,
                         I32Attr:$end,
                         I32Attr:$step);

    let results = (outs AnyRankedTensor:$result);

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}
```
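To make the description above concrete, here is a minimal reference model of what ttir.index computes, written as standalone C++ on a 2D tensor stored as nested std::vector rows. The Matrix alias and the referenceIndex function are hypothetical and not part of tt-mlir; the sketch assumes the arguments have already passed IndexOp::verify(), and it resolves negative begin and end the same way the verifier below does.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical 2D tensor: a list of equally sized rows.
using Matrix = std::vector<std::vector<int>>;

// Hypothetical reference model of ttir.index on a 2D tensor (not tt-mlir code).
// Assumes the arguments already satisfy IndexOp::verify().
Matrix referenceIndex(const Matrix &input, int dim, int64_t begin, int64_t end,
                      int64_t step) {
  if (dim != 0 && dim != 1) {
    throw std::invalid_argument("dim must be 0 or 1 for a 2D tensor");
  }
  if (step == 0) {
    throw std::invalid_argument("step cannot be zero");
  }
  const int64_t rows = static_cast<int64_t>(input.size());
  const int64_t cols = rows > 0 ? static_cast<int64_t>(input[0].size()) : 0;
  const int64_t dimSize = (dim == 0) ? rows : cols;

  // Negative begin/end count from the end of the dimension, as in the verifier.
  if (begin < 0) begin += dimSize;
  if (end < 0) end += dimSize;

  // Positions selected along `dim`: begin, begin + step, ... up to (not
  // including) end, walking downwards when step is negative.
  std::vector<int64_t> picked;
  for (int64_t i = begin; step > 0 ? i < end : i > end; i += step) {
    picked.push_back(i);
  }

  Matrix out;
  if (dim == 0) {
    // Keep only the selected rows; every column is preserved.
    for (int64_t r : picked) out.push_back(input[r]);
  } else {
    // Keep every row, but only the selected columns.
    for (const auto &row : input) {
      std::vector<int> newRow;
      for (int64_t c : picked) newRow.push_back(row[c]);
      out.push_back(newRow);
    }
  }
  return out;
}
```

For example, indexing {{1, 2, 3, 4}, {5, 6, 7, 8}} along dim 1 with begin = 0, end = 4, step = 2 keeps columns 0 and 2, giving {{1, 3}, {5, 7}}; every other dimension passes through unchanged, which is why the output rank always equals the input rank.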
The verification function has been added as well:
```cpp
// IndexOp verification
::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  int32_t dim = getDim();
  int32_t begin = getBegin();
  int32_t end = getEnd();
  int32_t step = getStep();

  // Verify that the input is at least 1D tensor
  if (inputType.getRank() < 1) {
    return emitOpError("Input must be at least a 1D tensor");
  }

  // Validate that the output tensor has the same element type as the input
  // tensor
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError(
        "Output tensor must have the same element type as the input tensor");
  }

  // Verify the output tensor rank
  if (inputType.getRank() != outputType.getRank()) {
    return emitOpError(
        "Output tensor must have the same rank as the input tensor");
  }

  // Verify that the dim attribute is within the bounds of the input tensor
  if (dim < 0 || dim >= inputType.getRank()) {
    return emitOpError() << "Invalid dimension index " << dim
                         << ". Input tensor rank is " << inputType.getRank();
  }

  // Verify begin, end, step and the output tensor dimensions
  int64_t dimSize = inputShape[dim];

  // Adjust negative begin and end
  int32_t adjustedBegin = (begin < 0) ? (begin + dimSize) : begin;
  int32_t adjustedEnd = (end < 0) ? (end + dimSize) : end;

  std::ostringstream inputShapeStream;
  inputShapeStream << "(";
  for (size_t i = 0; i < inputShape.size(); ++i) {
    inputShapeStream << inputShape[i];
    if (i != inputShape.size() - 1) {
      inputShapeStream << ", ";
    }
  }
  inputShapeStream << ")";
  std::string inputShapeStr = inputShapeStream.str();

  if (adjustedBegin < 0 || adjustedBegin >= dimSize) {
    return emitOpError() << "Invalid begin index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "), got " << begin
                         << ". Input shape: " << inputShapeStr;
  }
  if (adjustedEnd < 0 || adjustedEnd > dimSize) {
    return emitOpError() << "Invalid end index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "], got " << end
                         << ". Input shape: " << inputShapeStr;
  }

  auto formatValueMessage = [](int value, int adjustedValue) {
    return value < 0 ? std::to_string(adjustedValue) + " (" +
                           std::to_string(value) + ")"
                     : std::to_string(value);
  };
  std::string beginValueMessage = formatValueMessage(begin, adjustedBegin);
  std::string endValueMessage = formatValueMessage(end, adjustedEnd);

  if (step == 0) {
    return emitOpError("Step value for dimension " + std::to_string(dim) +
                       " cannot be zero");
  }

  if (step > 0 && adjustedBegin > adjustedEnd) {
    return emitOpError() << "For positive step, begin index must be less "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  if (step < 0 && adjustedBegin < adjustedEnd) {
    return emitOpError() << "For negative step, begin index must be greater "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  // Calculate the expected size of the output dimension
  int32_t expectedDimSize =
      (std::abs(adjustedEnd - adjustedBegin) + std::abs(step) - 1) /
      std::abs(step);
  if (outputType.getDimSize(dim) != expectedDimSize) {
    return emitOpError() << "Mismatch in dimension " << std::to_string(dim)
                         << " of the output tensor: expected size "
                         << expectedDimSize << ", but got "
                         << outputType.getDimSize(dim);
  }

  return success();
}
```
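The size check at the end of the verifier can be exercised in isolation. The sketch below is a hypothetical plain-C++ mirror of the begin/end/step rules in IndexOp::verify(); the name expectedIndexDimSize is an assumption, not a tt-mlir function. It returns the expected size of the indexed output dimension, or std::nullopt wherever the verifier would emit an error.

```cpp
#include <cstdint>
#include <cstdlib>
#include <optional>

// Hypothetical standalone mirror of the begin/end/step checks in
// IndexOp::verify(). Returns the expected size of the indexed output
// dimension, or std::nullopt where the verifier would emit an error.
std::optional<int64_t> expectedIndexDimSize(int64_t dimSize, int64_t begin,
                                            int64_t end, int64_t step) {
  // Negative indices count from the end of the dimension.
  const int64_t b = (begin < 0) ? begin + dimSize : begin;
  const int64_t e = (end < 0) ? end + dimSize : end;

  if (b < 0 || b >= dimSize) return std::nullopt;  // begin in [-dimSize, dimSize)
  if (e < 0 || e > dimSize) return std::nullopt;   // end in [-dimSize, dimSize]
  if (step == 0) return std::nullopt;              // zero step is rejected
  if (step > 0 && b > e) return std::nullopt;      // positive step: begin <= end
  if (step < 0 && b < e) return std::nullopt;      // negative step: begin >= end

  // Number of selected elements: ceil(|e - b| / |step|).
  const int64_t span = std::abs(e - b);
  const int64_t stride = std::abs(step);
  return (span + stride - 1) / stride;
}
```

With the values from the silicon test later in this guide (dimension size 32, begin = 0, end = 32, step = 2) it returns 16, matching the tensor<4x32x16xbf16> output type. The (span + stride - 1) / stride form is the usual integer ceiling division, so a trailing partial step still counts as one selected element.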
2. Create a conversion pattern
A conversion pattern defines how MLIR should rewrite the Op. It can be implemented in either C++ or TableGen. Currently, this guide covers only the C++ implementation; the TableGen format will be added in the future.
C++ conversion pattern
For the Index operation, we use a C++ conversion pattern because the conversion changes the types of the Op's begin, end, and step attributes from integers to arrays, and TableGen patterns are not flexible enough to express this.
```cpp
// This transformation adjusts IndexOp attributes so that `begin`, `end`, and
// `step` become arrays, where each array element corresponds to a dimension of
// the input tensor. For dimensions other than the sliced dimension, default
// values are used.
//
struct IndexToSliceConversionPattern
    : public OpConversionPattern<ttir::IndexOp> {
  using OpConversionPattern<ttir::IndexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::IndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType =
        ::mlir::dyn_cast<mlir::RankedTensorType>(adaptor.getInput().getType());
    if (!inputType || !inputType.hasRank()) {
      return failure();
    }

    int64_t rank = inputType.getRank();
    llvm::SmallVector<mlir::Attribute, 4> begins, ends, steps;

    for (int64_t i = 0; i < rank; ++i) {
      if (i == op.getDim()) {
        begins.push_back(rewriter.getI32IntegerAttr(adaptor.getBegin()));
        ends.push_back(rewriter.getI32IntegerAttr(adaptor.getEnd()));
        steps.push_back(rewriter.getI32IntegerAttr(adaptor.getStep()));
      } else {
        begins.push_back(rewriter.getI32IntegerAttr(0));
        ends.push_back(rewriter.getI32IntegerAttr(inputType.getDimSize(i)));
        steps.push_back(rewriter.getI32IntegerAttr(1));
      }
    }

    auto newOp = rewriter.create<ttir::SliceOp>(
        op.getLoc(), op.getType(), adaptor.getInput(), adaptor.getOutput(),
        rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends),
        rewriter.getArrayAttr(steps));

    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
```
The matchAndRewrite method of OpConversionPattern is implemented to replace the matched Op with the newly created Op. Since the decomposition is implemented as a conversion pass, OpAdaptor is used to access the original Op's operands, already remapped to their converted types, together with its attributes. Finally, we instantiate the new Op and call the replaceOp method on ConversionPatternRewriter to replace the original Op.
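The core of the pattern is the loop that expands IndexOp's scalar begin, end, and step into one entry per input dimension. The hypothetical sketch below reproduces only that loop, with plain vectors standing in for mlir::ArrayAttr, so the resulting SliceOp bounds can be inspected without building MLIR; SliceBounds and expandIndexToSlice are illustrative names, not tt-mlir code.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical plain-vector stand-in for the three ArrayAttrs that
// IndexToSliceConversionPattern passes to ttir::SliceOp.
struct SliceBounds {
  std::vector<int32_t> begins;
  std::vector<int32_t> ends;
  std::vector<int32_t> steps;
};

// Mirrors the per-dimension loop in matchAndRewrite: the indexed dimension
// keeps IndexOp's begin/end/step, every other dimension is taken in full.
SliceBounds expandIndexToSlice(const std::vector<int64_t> &shape, int64_t dim,
                               int32_t begin, int32_t end, int32_t step) {
  SliceBounds bounds;
  const int64_t rank = static_cast<int64_t>(shape.size());
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim) {
      bounds.begins.push_back(begin);
      bounds.ends.push_back(end);
      bounds.steps.push_back(step);
    } else {
      // Full range [0, shape[i]) with unit step.
      bounds.begins.push_back(0);
      bounds.ends.push_back(static_cast<int32_t>(shape[i]));
      bounds.steps.push_back(1);
    }
  }
  return bounds;
}
```

For the 4x32x32 input used in the silicon test, indexing dim 2 with begin = 0, end = 32, step = 2 yields begins = {0, 0, 0}, ends = {4, 32, 32}, and steps = {1, 1, 2}: the full range on dims 0 and 1 and every second element on dim 2.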
TableGen conversion pattern
TODO
3. Register the created conversion pattern
To register the new pattern, go to the populateTTIRToTTIRDecompositionPatterns function in TTIRToTTIRDecomposition.cpp and add it to the RewritePatternSet using its add method. Then mark the decomposed Op as illegal in the runOnOperation method of TTIRToTTIRDecompositionPass in TTIRToTTIRDecompositionPass.cpp.
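Marking the Op illegal is what makes the decomposition mandatory: in MLIR's dialect conversion framework, an operation explicitly marked illegal that no registered pattern can rewrite causes the conversion to fail. The toy model below illustrates that interplay between registered patterns and illegal ops. It is a hypothetical simplification in plain C++ (ops are just names, and Pattern and applyConversion are illustrative names), not the MLIR API.

```cpp
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy model of a dialect conversion (NOT the MLIR API): an "op" is just its
// name, and a pattern maps an op name to the name of its replacement.
using Pattern = std::function<std::string(const std::string &)>;

// Rewrites every op that has a registered pattern, then reports failure if any
// op marked illegal is still present, the way a conversion fails when an
// illegal op cannot be legalized.
bool applyConversion(std::vector<std::string> &ops,
                     const std::map<std::string, Pattern> &patterns,
                     const std::set<std::string> &illegalOps) {
  for (auto &op : ops) {
    auto it = patterns.find(op);
    if (it != patterns.end()) {
      op = it->second(op);
    }
  }
  for (const auto &op : ops) {
    if (illegalOps.count(op) != 0) {
      return false;  // an illegal op survived: the conversion fails
    }
  }
  return true;
}
```

Registering a "ttir.index" to "ttir.slice" pattern and marking "ttir.index" illegal makes applyConversion succeed; dropping the registration while keeping the Op illegal makes it return false, which mirrors the pass failing on the leftover Op.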
You should also add a silicon test as described in Add a silicon unit test for the Op. This is what the silicon test for the Index operation looks like:
```mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
  func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<4x32x16xbf16> {
    %0 = tensor.empty() : tensor<4x32x16xbf16>
    // CHECK: %[[C:.*]] = "ttnn.slice"[[C:.*]]
    %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 2: i32}> : (tensor<4x32x32xbf16>, tensor<4x32x16xbf16>) -> tensor<4x32x16xbf16>
    return %1 : tensor<4x32x16xbf16>
  }
}
```