Adding an Op
This guide walks you through adding a new Op end to end in tt-mlir. As an example, we will add a `matmul` operation. The `matmul` op was added in the same changeset as this guide, so it can be useful to read that diff alongside this guide to see the changes in full.
This guide will cover the following steps:

1. Define the Op in the TTIR frontend dialect
2. Define the Op in the TTNN backend dialect
3. Convert / Implement the Op in the TTNN passes
4. Add a compiler unit test for the Op
5. Define flatbuffer schema for the Op
6. Serialize the Op in the flatbuffer format
7. Add runtime support for the Op
8. Add a silicon unit test for the Op
1. Define the Op in the TTIR frontend dialect
We will start by defining the Op in the TTIR dialect. The TTIR Ops are defined in a tablegen file located at `include/ttmlir/Dialect/TTIR/IR/TTIROps.td`.

Tablegen is a domain-specific language for defining ops/types/attributes in MLIR and LLVM; these definitions constitute the dialect's Operation Definition Specification (ODS).
Here is an example of defining `matmul` in the TTIR dialect:

```tablegen
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
    let summary = "Matrix multiply operation.";
    let description = [{
      Matrix multiply operation.
    }];

    let arguments = (ins AnyRankedTensor:$a,
                         AnyRankedTensor:$b,
                         AnyRankedTensor:$output);

    let results = (outs AnyRankedTensor:$result);

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}
```
There are many things to break down here, starting from the top:

- `def` in tablegen is used to define a concrete type. It has a 1-1 mapping to a generated C++ class; for this particular case the build will end up generating the file `build/include/ttmlir/Dialect/TTIR/IR/TTIROps.h.inc`.
- It inherits from `class TTIR_DPSOp`. Classes in tablegen don't define a concrete type, but rather an interface that augments or constrains inherited `def`s. `TTIR_DPSOp` is a class that defines the common attributes for all TTIR Ops that implement Destination Passing Style (DPS) semantics. DPS just means that the result tensor is passed as an argument to the operation, which will be critical for modeling buffer allocation / lifetimes. Note the 3rd argument `AnyRankedTensor:$output`.
- Next we have a list of `arguments`. These arguments consist of a mixture of `Type`s (e.g. `AnyRankedTensor`) and `Attribute`s. Read more about Types & Attributes in the MLIR documentation. `AnyRankedTensor` is part of a tablegen standard library which type aliases to MLIR's builtin Tensor type, with the added constraint that the tensor has a static rank. As much as possible we want to use the builtin types and infrastructure provided by MLIR.
- Next we have a list of `results`, in this case just 1, which aliases the `output` tensor. One drawback of DPS is that the result tensor and the output tensor will appear to have different SSA names in the IR, even though they alias the same object. This can make writing some passes more cumbersome.
- Next we have `extraClassDeclaration`, which enables us to inject member functions, written directly in C++, into the generated class. We do this here to satisfy the DPS interface, which requires an implementation for getting the mutated output tensor.
- Finally, we have `hasVerifier = 1`, which tells MLIR that we provide a verifier function that will be called to validate the operation. This is good practice to ensure that the IR is well formed.
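To make the DPS idea concrete outside of MLIR, here is a minimal standalone C++ sketch. `Matrix` and `matmulInto` are hypothetical names used only for illustration: the caller allocates the destination `out`, the function writes into it, and the returned reference is the very same object, which mirrors how `$result` aliases `$output` in the op definition above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix used only for this illustration.
struct Matrix {
  std::size_t rows = 0, cols = 0;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Destination passing style: the caller provides `out`, and the returned
// reference is the same object as `out` (the "result" aliases the "output").
Matrix &matmulInto(const Matrix &a, const Matrix &b, Matrix &out) {
  assert(a.cols == b.rows && out.rows == a.rows && out.cols == b.cols);
  for (std::size_t i = 0; i < a.rows; ++i) {
    for (std::size_t j = 0; j < b.cols; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < a.cols; ++k) {
        acc += a.at(i, k) * b.at(k, j);
      }
      out.at(i, j) = acc;
    }
  }
  return out;
}
```

Because the buffer is owned by the caller, a compiler can decide where `out` lives and how long it lives, which is why DPS helps with modeling allocation.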
We can now try building and opening the `TTIROps.h.inc` file to see the generated C++ code. We will actually get a linker error, because `hasVerifier = 1` automatically declared a verifier function that we still need to implement.

Let's head over to `lib/Dialect/TTIR/IR/TTIROps.cpp` and implement the verifier.
```cpp
// MatmulOp verification
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
  ::mlir::RankedTensorType inputBType = getB().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();

  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());

  // Verify that the input A is at least 1D tensor
  if (inputAType.getRank() < 1) {
    return emitOpError("Input A must be at least a 1D tensor");
  }

  // Verify that the input B is at least 1D tensor
  if (inputBType.getRank() < 1) {
    return emitOpError("Input B must be at least a 1D tensor");
  }

  // If input A is a vector (1D tensor), 1 is prepended to its dimension for the
  // purpose of the matrix multiply. After the matrix multiply, the prepended
  // dimension is removed.
  if (inputAType.getRank() == 1) {
    inputAShape.insert(inputAShape.begin(), 1);
  }

  // If input B is a vector (1D tensor), a 1 is appended to its dimension for
  // the purpose of the matrix-vector product and removed after.
  if (inputBType.getRank() == 1) {
    inputBShape.push_back(1);
  }

  // Verify that the input A and input B have matching inner dimensions
  if (inputAShape[inputAShape.size() - 1] !=
      inputBShape[inputBShape.size() - 2]) {
    return emitOpError(
        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
        ") must have matching inner dimensions");
  }

  llvm::SmallVector<int64_t> expectedOutputShape;
  // Verify that the batch dimensions are broadcast compatible and construct the
  // expected output shape
  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;

    if (inputAShape.size() > 2) {
      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
                             inputAShape.end() - 2);
    }

    if (inputBShape.size() > 2) {
      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
                             inputBShape.end() - 2);
    }

    // Verify that the batch dimensions of input A and B are broadcast
    // compatible
    llvm::SmallVector<int64_t, 4> broadcastedShape;
    if (!mlir::OpTrait::util::getBroadcastedShape(
            inputABatchDims, inputBBatchDims, broadcastedShape)) {

      return emitOpError("Batch dimensions of input A(" +
                         ttmlir::utils::join(inputABatchDims, ",") +
                         ") and B(" +
                         ttmlir::utils::join(inputBBatchDims, ",") +
                         ") are not broadcast compatible");
    }

    // Insert the broadcasted batch dimensions in the expected output shape
    expectedOutputShape.insert(expectedOutputShape.begin(),
                               broadcastedShape.begin(),
                               broadcastedShape.end());
  }

  // Insert the input A and B inner dimensions in expected output shape
  // Consider the case where input A and B are vectors. In that case,
  // the dimension 1 is omitted from the output shape.
  if (inputAType.getRank() > 1) {
    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
  }

  if (inputBType.getRank() > 1) {
    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
  }

  // Check the case of a vector-vector product. At this moment we don't support
  // scalars in IR, hence check that the output is at least 1D tensor of size 1.
  if (expectedOutputShape.size() == 0) {
    if (outputType.getRank() < 1) {
      return emitOpError("Scalar output is not supported, output must be at "
                         "least a 1D tensor");
    }

    if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) {
      return emitOpError("Scalar output must be a 1D tensor of size 1");
    }

    return llvm::success();
  }

  // Verify that the output shape is correct
  if (outputShape.size() != expectedOutputShape.size()) {
    return emitOpError("Output shape rank(" +
                       std::to_string(outputShape.size()) +
                       ") must match the expected output shape rank(" +
                       std::to_string(expectedOutputShape.size()) + ")");
  }

  // Verify each dim of the output shape
  for (size_t i = 0; i < outputShape.size(); i++) {
    if (outputShape[i] != expectedOutputShape[i]) {
      return emitOpError(
          "Output shape dimension[" + std::to_string(i) + "](" +
          std::to_string(outputShape[i]) +
          ") doesn't match the expected output shape dimension[" +
          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
          ")");
    }
  }

  return success();
}
```
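The shape rules this verifier enforces can be tried out without building MLIR. The sketch below is a standalone C++17 reimplementation under stated assumptions: `inferMatmulShape` and `broadcastBatch` are hypothetical helpers, dynamic dimensions are ignored, and `std::nullopt` stands in for the cases where the verifier would emit an error. It follows the same steps: promote 1D inputs, check inner dimensions, broadcast batch dimensions right-aligned, then drop the promoted dimensions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Right-aligned broadcasting of two batch shapes: each pair of dimensions
// must be equal, or one of them must be 1. Missing dimensions act as 1.
std::optional<Shape> broadcastBatch(const Shape &a, const Shape &b) {
  const std::size_t rank = std::max(a.size(), b.size());
  Shape result(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      return std::nullopt;
    }
    result[rank - 1 - i] = (da == 1) ? db : da;
  }
  return result;
}

// Expected output shape of matmul(a, b), or nullopt if the verifier would fail.
std::optional<Shape> inferMatmulShape(Shape a, Shape b) {
  if (a.empty() || b.empty()) return std::nullopt;  // inputs must be >= 1D
  const bool aIsVector = a.size() == 1;
  const bool bIsVector = b.size() == 1;
  if (aIsVector) a.insert(a.begin(), 1);  // [K] -> [1, K]
  if (bIsVector) b.push_back(1);          // [K] -> [K, 1]

  // Inner dimensions must match: A[-1] == B[-2].
  if (a[a.size() - 1] != b[b.size() - 2]) return std::nullopt;

  Shape out;
  if (a.size() > 2 || b.size() > 2) {
    const Shape aBatch(a.begin(), a.end() - 2);
    const Shape bBatch(b.begin(), b.end() - 2);
    std::optional<Shape> batch = broadcastBatch(aBatch, bBatch);
    if (!batch) return std::nullopt;
    out = *batch;
  }

  // Keep M and N only for inputs that were not promoted from vectors.
  if (!aIsVector) out.push_back(a[a.size() - 2]);
  if (!bIsVector) out.push_back(b[b.size() - 1]);

  // Vector-vector product: scalars are unsupported, so expect shape [1].
  if (out.empty()) out.push_back(1);
  return out;
}
```

For example, `inferMatmulShape({2, 1, 64, 128}, {3, 128, 96})` yields `{2, 3, 64, 96}`, and a `{128}` x `{128, 96}` product yields `{96}`.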
2. Define the Op in the TTNN backend dialect
Next we will define the Op in the TTNN dialect. TTNN Ops are defined in the same way, but in their own set of dialect files. Refer to the previous section for details; the process is the same.
TTNNOps.td

```tablegen
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul",
      [DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints"]>]
    > {
    let arguments = (ins AnyRankedTensor:$a,
                         AnyRankedTensor:$b,
                         AnyRankedTensor:$output);

    let results = (outs AnyRankedTensor:$result);

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let hasVerifier = 1;
}
```
TTNNOps.cpp

```cpp
// MatmulOp verification
::mlir::LogicalResult mlir::tt::ttnn::MatmulOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
  ::mlir::RankedTensorType inputBType = getB().getType();
  ::mlir::RankedTensorType outputType = getOutput().getType();

  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());

  // Verify that the input A is at least 1D tensor
  if (inputAType.getRank() < 1) {
    return emitOpError("Input A must be at least a 1D tensor");
  }

  // Verify that the input B is at least 1D tensor
  if (inputBType.getRank() < 1) {
    return emitOpError("Input B must be at least a 1D tensor");
  }

  // If input A is a vector (1D tensor), 1 is prepended to its dimension for the
  // purpose of the matrix multiply. After the matrix multiply, the prepended
  // dimension is removed.
  if (inputAType.getRank() == 1) {
    inputAShape.insert(inputAShape.begin(), 1);
  }

  // If input B is a vector (1D tensor), a 1 is appended to its dimension for
  // the purpose of the matrix-vector product and removed after.
  if (inputBType.getRank() == 1) {
    inputBShape.push_back(1);
  }

  // Verify that the input A and input B have matching inner dimensions
  if (inputAShape[inputAShape.size() - 1] !=
      inputBShape[inputBShape.size() - 2]) {
    return emitOpError(
        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
        ") must have matching inner dimensions");
  }

  llvm::SmallVector<int64_t> expectedOutputShape;
  // Verify that the batch dimensions are broadcast compatible and construct the
  // expected output shape
  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;

    if (inputAShape.size() > 2) {
      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
                             inputAShape.end() - 2);
    }

    if (inputBShape.size() > 2) {
      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
                             inputBShape.end() - 2);
    }

    // Verify that the batch dimensions of input A and B are broadcast
    // compatible
    llvm::SmallVector<int64_t, 4> broadcastedShape;
    if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
                                            broadcastedShape)) {

      return emitOpError("Batch dimensions of input A(" +
                         ttmlir::utils::join(inputABatchDims, ",") +
                         ") and B(" +
                         ttmlir::utils::join(inputBBatchDims, ",") +
                         ") are not broadcast compatible");
    }

    // Insert the broadcasted batch dimensions in the expected output shape
    expectedOutputShape.insert(expectedOutputShape.begin(),
                               broadcastedShape.begin(),
                               broadcastedShape.end());
  }

  // Insert the input A and B inner dimensions in expected output shape
  // Consider the case where input A and B are vectors. In that case,
  // the dimension 1 is omitted from the output shape.
  if (inputAType.getRank() > 1) {
    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
  }

  if (inputBType.getRank() > 1) {
    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
  }

  // Check the case of a vector-vector product. At this moment we don't support
  // scalars in IR, hence check that the output is at least 1D tensor of size 1.
  if (expectedOutputShape.size() == 0) {
    if (outputType.getRank() < 1) {
      return emitOpError("Scalar output is not supported, output must be at "
                         "least a 1D tensor");
    }

    if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) {
      return emitOpError("Scalar output must be a 1D tensor of size 1");
    }

    return llvm::success();
  }

  // Verify that the output shape is correct
  if (outputShape.size() != expectedOutputShape.size()) {
    return emitOpError("Output shape rank(" +
                       std::to_string(outputShape.size()) +
                       ") must match the expected output shape rank(" +
                       std::to_string(expectedOutputShape.size()) + ")");
  }

  // Verify each dim of the output shape
  for (size_t i = 0; i < outputShape.size(); i++) {
    if (outputShape[i] != expectedOutputShape[i]) {
      return emitOpError(
          "Output shape dimension[" + std::to_string(i) + "](" +
          std::to_string(outputShape[i]) +
          ") doesn't match the expected output shape dimension[" +
          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
          ")");
    }
  }

  return success();
}
```
3. Convert / Implement the Op in the TTNN passes
Next we will implement the conversion from the TTIR `matmul` Op to the TTNN `matmul` Op.

This is a trivial conversion, as the Ops have identical semantics, so the changeset isn't going to be very instructive, but it will at least point to the files involved. The conversion is implemented in the `ConvertTTIRToTTNNPass` pass in the file `lib/Conversion/TTIRToTTNN/TTIRToTTNNPass.cpp`.
Zooming into `class ConvertTTIRToTTNNPass`, we can see that we implement the pass interface via the member function `void runOnOperation() final`. This function will be called for every operation matching the type specified in the pass tablegen file. Taking a quick look at `include/ttmlir/Conversion/Passes.td`, we can see:

```tablegen
def ConvertTTIRToTTNN: Pass<"convert-ttir-to-ttnn", "::mlir::ModuleOp"> {
```

This means that `runOnOperation` will be called for every `ModuleOp` in the graph. Usually there is only one `ModuleOp`, which serves as the root of the graph.
Inside `runOnOperation` is usually where we define a rewrite pattern set that can match much more complicated patterns (nested inside of the `ModuleOp`'s regions) than just a single operation. In the `runOnOperation` method you will see the call to `populateTTIRToTTNNPatterns(...)`, which populates the pattern set with the rewrite patterns. `populateTTIRToTTNNPatterns(...)` is defined in `lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp`:
```cpp
patterns
    .add<TensorEmptyConversionPattern,
         OnesOpConversionPattern,
         ToLayoutOpConversionPattern,
         ElementwiseOpConversionPattern<ttir::AbsOp, ttnn::AbsOp>,
         ElementwiseOpConversionPattern<ttir::AddOp, ttnn::AddOp>,
         ElementwiseOpConversionPattern<ttir::CbrtOp, ttnn::CbrtOp>,
         ElementwiseOpConversionPattern<ttir::FloorOp, ttnn::FloorOp>,
         ElementwiseOpConversionPattern<ttir::IsFiniteOp, ttnn::IsFiniteOp>,
         ElementwiseOpConversionPattern<ttir::LogicalAndOp, ttnn::LogicalAndOp>,
         ElementwiseOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
         ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
         ElementwiseOpConversionPattern<ttir::LogicalXorOp, ttnn::LogicalXorOp>,
         ElementwiseOpConversionPattern<ttir::BitwiseAndOp, ttnn::BitwiseAndOp>,
         ElementwiseOpConversionPattern<ttir::BitwiseOrOp, ttnn::BitwiseOrOp>,
         ElementwiseOpConversionPattern<ttir::BitwiseXorOp, ttnn::BitwiseXorOp>,
         ElementwiseOpConversionPattern<ttir::BitwiseNotOp, ttnn::BitwiseNotOp>,
         ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
         ElementwiseOpConversionPattern<ttir::EqualOp, ttnn::EqualOp>,
         ElementwiseOpConversionPattern<ttir::NotEqualOp, ttnn::NotEqualOp>,
         ElementwiseOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
         ElementwiseOpConversionPattern<ttir::GreaterThanOp, ttnn::GreaterThanOp>,
         ElementwiseOpConversionPattern<ttir::LessEqualOp, ttnn::LessEqualOp>,
         ElementwiseOpConversionPattern<ttir::LessThanOp, ttnn::LessThanOp>,
         ElementwiseOpConversionPattern<ttir::MaximumOp, ttnn::MaximumOp>,
         ElementwiseOpConversionPattern<ttir::MinimumOp, ttnn::MinimumOp>,
         ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
         ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
         ElementwiseOpConversionPattern<ttir::GeluOp, ttnn::GeluOp>,
         ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
         ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
         ElementwiseOpConversionPattern<ttir::SignOp, ttnn::SignOp>,
         ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
         ElementwiseOpConversionPattern<ttir::Log1pOp, ttnn::Log1pOp>,
         ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
         ElementwiseOpConversionPattern<ttir::ExpOp, ttnn::ExpOp>,
         ElementwiseOpConversionPattern<ttir::LogOp, ttnn::LogOp>,
         ElementwiseOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
         ElementwiseOpConversionPattern<ttir::CeilOp, ttnn::CeilOp>,
         ElementwiseOpConversionPattern<ttir::SinOp, ttnn::SinOp>,
         ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
         ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
         ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
         ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
         ElementwiseOpConversionPattern<ttir::TanOp, ttnn::TanOp>,
         ElementwiseOpConversionPattern<ttir::TanhOp, ttnn::TanhOp>,
         ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
         ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
         ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
         ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
         BroadcastOpConversionPattern,
         EmbeddingOpConversionPattern,
         EmbeddingBackwardOpConversionPattern,
         SoftmaxOpConversionPattern,
         TransposeOpConversionPattern,
         TypecastOpConversionPattern,
         ClampOpConversionPattern,
         ConcatOpConversionPattern,
         ReshapeOpConversionPattern,
         SliceOpConversionPattern,
         SqueezeOpConversionPattern,
         UnsqueezeOpConversionPattern,
         ConstantOpConversionPattern,
         LinearOpConversionPattern,
         MatmulOpConversionPattern,
         Conv2dOpConversionPattern,
         ConvTranspose2dOpConversionPattern,
         MaxPool2dOpConversionPattern,
         SubtractOpConversionPattern,
         MeshShardOpConversionPattern,
         AllReduceOpConversionPattern,
         AllGatherOpConversionPattern,
         ArangeOpConversionPattern,
         UpdateCacheOpConversionPattern,
         FillCacheOpConversionPattern,
         ScatterOpConversionPattern,
         PermuteOpConversionPattern
         >(typeConverter, ctx);
```
More information on rewrite patterns and their capabilities can be found in the MLIR documentation.
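Conceptually, a pattern set is a collection of rewrites keyed by the op they match, applied to every op nested under the root `ModuleOp`. The following is a deliberately simplified standalone C++ model of that idea. `Op`, `PatternSet`, and `applyPatterns` are hypothetical names, and this is not how MLIR's conversion driver is actually implemented (there are no types, legality targets, or rollback here).

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an operation: a name and its operand names.
struct Op {
  std::string name;
  std::vector<std::string> operands;
};

// Each pattern rewrites one matched op in place.
using Pattern = std::function<void(Op &)>;
using PatternSet = std::unordered_map<std::string, Pattern>;

// Visit every op and apply the pattern registered for its name. Returns
// false if some op had no matching pattern, loosely mirroring a conversion
// that fails because an op was left unconverted.
bool applyPatterns(std::vector<Op> &ops, const PatternSet &patterns) {
  bool allConverted = true;
  for (Op &op : ops) {
    auto it = patterns.find(op.name);
    if (it == patterns.end()) {
      allConverted = false;
      continue;
    }
    it->second(op);
  }
  return allConverted;
}
```

Registering a rewrite for `"ttir.matmul"` that renames the op to `"ttnn.matmul"` while keeping its operands is the toy equivalent of adding `MatmulOpConversionPattern` to the set.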
For matmul, we defined a dedicated conversion pattern. It replaces the TTIR op with a TTNN op of the converted result type, forwarding the `a`, `b`, and `output` operands from the adaptor:
```cpp
class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
public:
  using OpConversionPattern<ttir::MatmulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::MatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
        op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
        adaptor.getB(), adaptor.getOutput());
    return success();
  }
};
```
The pattern is invoked by adding `MatmulOpConversionPattern` to the rewrite set shown above.

Note: We also need to add this op to the C++ emitter in `lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp`; see `populateTTNNToEmitCPatterns(...)`.
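Conceptually, `replaceOpWithNewOp` builds the new op from the given operands and replaces every use of the old op's result with the new op's result before the old op goes away. The standalone C++ model below illustrates just that use-rewiring step; `SimpleOp` and `replaceWithNewOp` are hypothetical, values are plain integer ids, and MLIR's dialect conversion additionally defers and tracks these replacements, which is not modeled here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical SSA-style op: operands and result are integer value ids.
struct SimpleOp {
  std::string name;
  std::vector<int> operands;
  int result;
};

// Build a replacement op from the same operands, redirect every use of the
// old result to the new result, then overwrite (erase) the old op.
void replaceWithNewOp(std::vector<SimpleOp> &ops, std::size_t idx,
                      const std::string &newName, int newResult) {
  const int oldResult = ops[idx].result;
  const SimpleOp replacement{newName, ops[idx].operands, newResult};
  for (SimpleOp &user : ops) {
    for (int &operand : user.operands) {
      if (operand == oldResult) {
        operand = newResult;
      }
    }
  }
  ops[idx] = replacement;
}
```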
4. Add a compiler unit test for the Op
So far we have defined the Op in the TTIR and TTNN dialects, implemented verifiers, and written conversion passes. Now we need to add a unit test to ensure that the pass is working correctly. The compiler unit tests are located in the `test/ttmlir/Dialect` area. In this case we'll add a test under the `TTNN` subdirectory since we are testing the `ConvertTTIRToTTNNPass`.
test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir

```mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
module attributes {} {
  func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
    %0 = tensor.empty() : tensor<64x96xbf16>
    // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]]
    %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
    return %1 : tensor<64x96xbf16>
  }
}
```
Unit tests in MLIR are typically written using a tool called `FileCheck`; please refer to the LLVM FileCheck documentation for a tutorial and more information about the `RUN` and `CHECK` directives.

A few things to point out specifically regarding tt-mlir dialects:

- `tt.system_desc`: This is a 1-1 mapping to the `SystemDesc` flatbuffer schema that is used to describe the system configuration. This is a required attribute tagged on the top level module for all tt-mlir dialects.
- Pass `--ttir-layout` is a prerequisite before running `convert-ttir-to-ttnn`. This pass is responsible for converting the input tensors to device memory space and tile layout before lowering to TTNN.
- This test is asserting that `ttir.matmul` converts to `ttnn.matmul`.
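As a rough mental model of how the `CHECK` lines are applied: FileCheck scans the tool's output and requires each pattern to match, in order, somewhere after the previous match. The sketch below implements only that ordering rule for plain substrings. Real FileCheck also supports regexes (`{{...}}`), pattern variables such as `[[C:.*]]`, and many other directives, none of which are modeled here; `checkInOrder` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// True if every pattern occurs in `output`, each one starting after the end
// of the previous match (plain substring matching only).
bool checkInOrder(const std::string &output,
                  const std::vector<std::string> &patterns) {
  std::size_t pos = 0;
  for (const std::string &pattern : patterns) {
    const std::size_t found = output.find(pattern, pos);
    if (found == std::string::npos) {
      return false;
    }
    pos = found + pattern.size();
  }
  return true;
}
```

This ordering rule is why the layout `CHECK` line must appear before the `ttnn.matmul` line in the test: the layout attribute is printed at the top of the module, ahead of the function body.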
To run the test, you can use the following command:

```bash
cmake --build build -- check-ttmlir
```

You can also manually run `ttmlir-opt` on the test file to see the resulting output:

```bash
./build/bin/ttmlir-opt --ttir-load-system-desc="path=<PATH_TO_SYSTEM_DESC>" --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir
```
5. Define flatbuffer schema for the Op
Next we will define the flatbuffer schema for the Op. The schema must capture all tensor inputs, outputs, and attributes of the Op, i.e. everything the runtime needs to execute the Op.
include/ttmlir/Target/TTNN/program.fbs

```
table MatmulOp {
  in0: tt.target.TensorRef;
  in1: tt.target.TensorRef;
  out: tt.target.TensorRef;
}
```
Note the type `TensorRef`: flatbuffer tables with the suffix `Ref` are used to represent live values during the runtime, decoupled from the underlying tables with the suffix `Desc`, which carry the type and attribute information for the object.

We also add this new op to the `union OpType`, which is the variant type for all ops.

More information about writing flatbuffer schemas can be found in the flatbuffers documentation.
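A flatbuffer `union` is a tagged variant: each op entry records which alternative it holds plus that alternative's table. As a loose C++ analogy only (this is not the flatc-generated API; `TensorRef`, `MatmulOp`, and `UnaryOp` below are hypothetical simplifications of the schema), it behaves much like a `std::variant`:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

// Simplified stand-in for tt.target.TensorRef: a handle to a live tensor.
struct TensorRef {
  uint32_t global_id;
};

// Hypothetical mirrors of two op tables.
struct MatmulOp {
  TensorRef in0, in1, out;
};
struct UnaryOp {
  TensorRef in, out;
};

// Analogue of `union OpType`: exactly one alternative is active at a time.
using OpType = std::variant<MatmulOp, UnaryOp>;

// Report which alternative is active, dispatching with std::visit.
std::string opName(const OpType &op) {
  return std::visit(
      [](const auto &o) -> std::string {
        using T = std::decay_t<decltype(o)>;
        if constexpr (std::is_same_v<T, MatmulOp>) {
          return "MatmulOp";
        } else {
          return "UnaryOp";
        }
      },
      op);
}
```

Forgetting to add a new table to the union is the schema-level equivalent of leaving an alternative out of the variant: the op simply cannot be represented in a program.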
6. Serialize the Op in the flatbuffer format
In the previous section we defined the flatbuffer schema for the `matmul` Op; now let's put our new schema definition to use. The schema is used as input to a program called `flatc`, which generates C++ code (or code in other languages) for serializing and deserializing the schema. This generated code can be found in `build/include/ttmlir/Target/TTNN/program_generated.h`.

Let's head over to `lib/Target/TTNN/TTNNToFlatbuffer.cpp` to define a `createOp` overloaded function that does the conversion from MLIR to flatbuffer:
```cpp
::flatbuffers::Offset<::tt::target::ttnn::MatmulOp>
createOp(FlatbufferObjectCache &cache, MatmulOp op) {
  auto in0 =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA()));
  auto in1 =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB()));
  auto output = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getResult()));
  return ::tt::target::ttnn::CreateMatmulOp(*cache.fbb, in0, in1, output);
}
```
A lot is happening here, so let's break it down:

- `FlatbufferObjectCache`: This is a helper class that is used to cache objects in the flatbuffer that are created during the serialization process. This is necessary for managing value lifetimes and identifiers, and at the same time it is an optimization to avoid having multiple copies of the same object. For example, a `TensorRef` with multiple uses could naively be recreated, once for each use, but with the cache we can ensure that the object is only created once and all uses point to the same flatbuffer offset. The cache is passed around to all serialization functions and should be used whenever creating a new object.
- `getOperandThroughDPSOps`: In section 1 we discussed DPS semantics and the drawback of having the result alias the output tensor. This is one of those cases where we need a helper function to trace through the output operands to find the original SSA name, in order to associate it with the original `TensorRef`.
- `CreateMatmulOp`: The autogenerated function from the flatbuffer schema that actually serializes the data into the flatbuffer format.
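The caching and alias-tracing ideas can be illustrated without flatbuffers. In this hypothetical sketch, values are integer ids, `addDpsAlias` records that a DPS op's result aliases its output operand, `traceThroughDPS` follows those links (in the spirit of `getOperandThroughDPSOps`), and `getOrCreate` hands out exactly one id per underlying value. The real `FlatbufferObjectCache` stores flatbuffer offsets and has a different interface; only the idea is modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>

using ValueId = int;

class ObjectCache {
public:
  // Record that `result` is a DPS result aliasing the operand `output`.
  void addDpsAlias(ValueId result, ValueId output) { aliasOf_[result] = output; }

  // Follow result -> output links until reaching the value owning the buffer.
  ValueId traceThroughDPS(ValueId v) const {
    for (auto it = aliasOf_.find(v); it != aliasOf_.end(); it = aliasOf_.find(v)) {
      v = it->second;
    }
    return v;
  }

  // One id per underlying value: the first request creates it, later
  // requests (including through aliases) reuse it.
  uint32_t getOrCreate(ValueId v) {
    auto [it, inserted] = ids_.try_emplace(traceThroughDPS(v), nextId_);
    if (inserted) {
      ++nextId_;
    }
    return it->second;
  }

private:
  std::unordered_map<ValueId, ValueId> aliasOf_;
  std::unordered_map<ValueId, uint32_t> ids_;
  uint32_t nextId_ = 0;
};
```

With `%0 = tensor.empty` as value 2 and the matmul result as value 3, calling `addDpsAlias(3, 2)` makes both names resolve to the same cached entry, which is exactly what the serializer needs for the `out` field.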
We can finally generate a binary with our new Op! We can use the following command:

```bash
./build/bin/ttmlir-opt --ttir-load-system-desc="path=<PATH_TO_SYSTEM_DESC>" --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir | ./build/bin/ttmlir-translate --ttnn-to-flatbuffer -o out.ttnn
```

And we can inspect it with `ttrt`:

```bash
ttrt read out.ttnn
```
7. Add runtime support for the Op
Next, we want to add runtime support for the Op by parsing the flatbuffer and invoking the TTNN API.
runtime/lib/ttnn/operations/matmul/matmul.cpp

```cpp
void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
  const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
  DEBUG_ASSERT(lhs.is_allocated());
  DEBUG_ASSERT(rhs.is_allocated());
  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());

  const std::optional<const ::tt::tt_metal::MemoryConfig> memoryConfig =
      std::make_optional(outputMemoryConfig);

  const std::optional<const ::ttnn::DataType> dtype =
      std::make_optional(outputDataType);

  ::ttnn::Tensor out = ::ttnn::matmul(
      lhs, rhs, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, dtype,
      /*programConfig*/ std::nullopt, /*activation*/ std::nullopt,
      /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt);

  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
```
A couple of things to note from above:

- Most runtime op functions will follow a similar pattern: they take in some additional data structures for managing the program context. The program context tracks the state of the current program; it stores intermediate tensors and devices.
- `tensorPool.at(op->in0()->global_id())`: `global_id` is a unique identifier for the tensor that was generated and managed by the `FlatbufferObjectCache`. This is how it's intended to be used by the runtime.
- Some operations may belong to a larger set of operations. For example, any eltwise unary operation can be added in `runtime/lib/ttnn/operations/eltwise/unary.cpp` directly without needing to create a new file.
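The tensor pool used above behaves like a map from `global_id` to live tensors: each op reads its inputs with `at` and publishes its output with `insert_or_assign`, so later ops can find it by the same id. A minimal standalone sketch, assuming a hypothetical `FakeTensor` in place of `::ttnn::Tensor` and a hypothetical `TensorPool` class (the real `ProgramTensorPool` tracks more state than this):

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical placeholder for ::ttnn::Tensor.
struct FakeTensor {
  std::vector<float> data;
};

class TensorPool {
public:
  // Look up a tensor produced earlier; std::unordered_map::at throws
  // std::out_of_range if no op ever published this id.
  const FakeTensor &at(uint32_t globalId) const { return tensors_.at(globalId); }

  // Publish (or overwrite) the tensor produced for `globalId`.
  void insert_or_assign(uint32_t globalId, FakeTensor t) {
    tensors_.insert_or_assign(globalId, std::move(t));
  }

  bool contains(uint32_t globalId) const { return tensors_.count(globalId) != 0; }

private:
  std::unordered_map<uint32_t, FakeTensor> tensors_;
};
```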
If a new file is created for the op, we need to add a new source to `runtime/lib/ttnn/operations/CMakeLists.txt` and a new case to `runtime/lib/ttnn/program.cpp`.

To update `runtime/lib/ttnn/operations/CMakeLists.txt`, include the path to the source file in `TTNN_OPS_SRCS`:

runtime/lib/ttnn/operations/CMakeLists.txt

```cmake
${CMAKE_CURRENT_SOURCE_DIR}/matmul/matmul.cpp
```

To update `runtime/lib/ttnn/program.cpp`, add a new case to the `runOperation` method of `ProgramExecutor`:

runtime/lib/ttnn/program.cpp

```cpp
case ::tt::target::ttnn::OpType::MatmulOp: {
  return operations::matmul::run(op->type_as_MatmulOp(), context);
}
```
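The `runOperation` switch is plain tag-based dispatch. Stripped of flatbuffer and TTNN types, its shape can be sketched as follows. The `OpType` enum, `Operation` struct, `ProgramContext`, and `run` function here are hypothetical stand-ins for the generated tag and the per-op runtime functions; in this sketch an unhandled type throws, which makes a forgotten `case` easy to spot.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical mirror of the generated OpType tag.
enum class OpType { NONE, MatmulOp, SoftmaxOp };

struct Operation {
  OpType type;
};

// Hypothetical program context that only logs which ops ran.
struct ProgramContext {
  std::vector<std::string> log;
};

namespace operations::matmul {
void run(const Operation &, ProgramContext &context) {
  context.log.push_back("matmul");
}
} // namespace operations::matmul

// One case per op type; anything unhandled is reported as an error.
void runOperation(const Operation &op, ProgramContext &context) {
  switch (op.type) {
  case OpType::MatmulOp: {
    return operations::matmul::run(op, context);
  }
  default:
    throw std::runtime_error("Unsupported operation type");
  }
}
```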
We can test our changes with `ttrt` (don't forget to rebuild `ttrt`):

```bash
ttrt run out.ttnn
```
8. Add a silicon unit test for the Op
After adding runtime support, we're ready to test our Op on silicon. All silicon tests are located under `test/ttmlir/Silicon`. The process is similar to adding a compiler unit test.

In our specific case, we create a unit test here: `test/ttmlir/Silicon/TTNN/simple_matmul.mlir`:

test/ttmlir/Silicon/TTNN/simple_matmul.mlir

```mlir
{{#include ../../../test/ttmlir/Silicon/TTNN/simple_matmul.mlir}}
```
A couple of things to point out about this process:

- Tests placed under `test/ttmlir/Dialect` will only test the compiler's capability of compiling the module. If you want the module to run on silicon in CI, the test must be placed under `test/ttmlir/Silicon`.
- Notice the differences between the compilation headers of `test/ttmlir/Silicon/TTNN/simple_matmul.mlir` and `test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir`:
  - `--ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%"`: The `system-desc-path` option specifies the location of the system descriptor required for compiling the module. This is crucial for silicon tests, as modules compiled with different system descriptors may vary in silicon compatibility. Ensuring the system descriptor accurately reflects the target hardware is essential for running the module correctly.
  - `// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn`: This runs `ttmlir-translate`, which serializes the output MLIR module to a flatbuffer binary. We added the logic for this serialization in the "Serialize the Op in the flatbuffer format" section.