ttir.abs (tt::ttir::AbsOp)

Eltwise absolute op.

Eltwise absolute operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.add (tt::ttir::AddOp)

Eltwise add.

Eltwise add operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface, TTIR_GenericRegionOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.alloc (tt::ttir::AllocOp)

Alloc op.

Tensor alloc operation.

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| address | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| memory_space | ::mlir::tt::MemorySpaceAttr | TT MemorySpace. Enum cases: system (`System`), mmio (`SystemMMIO`), dram (`DeviceDRAM`), l1 (`DeviceL1`) |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.broadcast (tt::ttir::BroadcastOp)

Broadcast operation.

Broadcast op.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension | ::mlir::ArrayAttr | 64-bit integer array attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.concat (tt::ttir::ConcatOp)

Concat op.

Concat tensors along a given dimension.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
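
The description above says only that the inputs are joined along `dim`. As a rough illustration of that behaviour, here is a minimal Python sketch on nested lists. It assumes a non-negative `dim` and inputs whose shapes match on every other axis; the helper name `concat` is hypothetical and not part of the dialect.

```python
def concat(tensors, dim):
    """Concatenate nested-list tensors along axis `dim` (assumed non-negative)."""
    if dim == 0:
        # Along the outermost axis, concatenation is plain list joining.
        return [item for tensor in tensors for item in tensor]
    # Otherwise walk the outer axis in lockstep and recurse one level down.
    return [concat(list(parts), dim - 1) for parts in zip(*tensors)]


a = [[1, 2], [3, 4]]  # shape 2x2
b = [[5], [6]]        # shape 2x1
print(concat([a, b], 1))  # [[1, 2, 5], [3, 4, 6]]
```

The result has the same rank as the inputs, and only the size of axis `dim` grows.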

ttir.constant (tt::ttir::ConstantOp)

Constant op.

Produces tensor filled with given constant value.

Examples:
%0 = "ttir.constant"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
// %0: [[0, 0, 0], [0, 0, 0]]
%1 = "ttir.constant"() {value = dense<[0.2, 1.3]> : tensor<2xf32>} : () -> tensor<2xf32>
// %1: [0.2, 1.3]

Traits: ConstantLike

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| value | ::mlir::ElementsAttr | constant vector/tensor attribute |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
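
The first documented example uses a splat value (`dense<0>`), which fills every element of the result with the same number. The Python sketch below mimics that fill on nested lists; `splat` is a hypothetical helper used only for illustration.

```python
def splat(value, shape):
    """Build a nested list of the given shape with every element equal to `value`."""
    if not shape:
        return value
    return [splat(value, shape[1:]) for _ in range(shape[0])]


# Mirrors dense<0> : tensor<2x3xi32> from the example above.
print(splat(0, [2, 3]))  # [[0, 0, 0], [0, 0, 0]]
```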

ttir.conv2d (tt::ttir::Conv2dOp)

Conv2d operation.

Applies a 2D convolution over an input image composed of several input planes.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| stride_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| stride_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| dilation_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| dilation_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| groups | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_left | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_right | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_top | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_bottom | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| weight | ranked tensor of any type values |
| bias | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
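
This page does not state how the stride, dilation, and padding attributes determine the output size. The sketch below applies the conventional convolution size formula (the one used by common frameworks) to a single spatial axis. Whether this op verifies exactly this relation is an assumption, and `conv2d_output_size` is a hypothetical helper; the kernel extent is assumed to come from the shape of `weight`.

```python
def conv2d_output_size(in_size, kernel, stride, dilation, pad_before, pad_after):
    """Output extent along one spatial axis, using the conventional formula."""
    # A dilated kernel spans dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_before + pad_after - effective_kernel) // stride + 1


# Height uses stride_height, dilation_height, padding_top and padding_bottom;
# width uses stride_width, dilation_width, padding_left and padding_right.
print(conv2d_output_size(32, 3, 1, 1, 1, 1))  # 32
print(conv2d_output_size(32, 3, 2, 1, 1, 1))  # 16
```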

ttir.dealloc (tt::ttir::DeallocOp)

Dealloc op.

Tensor dealloc operation.

Operands:

| Operand | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.div (tt::ttir::DivOp)

Eltwise divide.

Eltwise divide operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.embedding (tt::ttir::EmbeddingOp)

Embedding op.

Embedding operation.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| weight | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
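
The op description gives no semantics beyond its name. Assuming the conventional meaning of an embedding lookup, where `input` holds row indices into the `weight` table, the behaviour could be sketched in Python as follows. The helper `embedding` and the sample data are illustrative only.

```python
def embedding(indices, weight):
    """Gather rows of `weight` selected by a 1-D list of integer indices."""
    return [list(weight[i]) for i in indices]


weight = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # 3 entries, embedding width 2
print(embedding([2, 0, 2], weight))  # [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
```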

ttir.exp (tt::ttir::ExpOp)

Eltwise exponential op.

Eltwise exponential operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.generic (tt::ttir::GenericOp)

Generically dispatch work to a grid of cores.

This generic op carries a region that represents the work each core does. The region is expected to have the same signature as the op itself. The op is expected to be lowered to a backend-specific form by a consuming backend. It is heavily inspired by the linalg.generic op, so the linalg.generic documentation is a useful reference for more details.

%5 = "ttir.generic"(%1, %3, %4) <{
  grid = #tt.grid<1x1>,                     // The grid range of cores to dispatch work to.
  indexing_maps = [#map, #map, #map],       // Affine maps for indexing into the input/output tensors. See linalg.generic
  iterator_types = [#parallel, #parallel],  // Iterator types for the input/output tensors. See linalg.generic
  operandSegmentSizes = array<i32: 2, 1>    // Sizes of the operand segments, i.e. 2 inputs and 1 output.
}> ({
^bb0(%arg2: memref<64x128xf32, #l1_>, %arg3: memref<64x128xf32, #l1_>, %arg4: memref<64x128xf32, #l1_>):
    // Region body, would contain some computation that represents the work each core does.
}) : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| grid | ::mlir::tt::GridAttr | TT grid attribute |
| indexing_maps | ::mlir::ArrayAttr | AffineMap array attribute |
| iterator_types | ::mlir::ArrayAttr | |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.ge (tt::ttir::GreaterEqualOp)

Eltwise greater than or equal to.

Eltwise greater than or equal to operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.kernel (tt::ttir::KernelOp)

Kernel call.

A generic kernel call operation. A consuming backend is expected to pattern-match this operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| op | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
| kind | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values or non-0-ranked.memref of any type values |
| outputs | variadic of ranked tensor of any type values or non-0-ranked.memref of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values or non-0-ranked.memref of any type values |

ttir.matmul (tt::ttir::MatmulOp)

Matrix multiply operation.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| a | ranked tensor of any type values |
| b | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
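
As a reference point for the expected values, the Python sketch below computes a plain 2-D matrix product of operands `a` and `b`. It assumes rank-2 inputs with matching inner dimensions; whether this op also accepts batched or higher-rank operands is not stated on this page.

```python
def matmul(a, b):
    """Multiply an MxK matrix by a KxN matrix, both given as lists of rows."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions of a and b must match")
    cols = len(b[0])
    return [
        [sum(a_row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for a_row in a
    ]


print(matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```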

ttir.max (tt::ttir::MaxOp)

Max reduction op.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| keep_dim | ::mlir::BoolAttr | bool attribute |
| dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
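
ttir.max shares its `keep_dim` and `dim_arg` attributes with the other reduction ops on this page (ttir.mean and ttir.sum). The page does not define them, so the sketch below assumes the usual convention: `dim_arg` lists the axes to reduce, and `keep_dim` keeps each reduced axis with size 1 instead of dropping it. Accepting negative axes is a further assumption, and `reduced_shape` is a hypothetical helper.

```python
def reduced_shape(shape, dim_arg, keep_dim):
    """Result shape of a reduction over the axes in `dim_arg`."""
    rank = len(shape)
    # Normalize negative axes, e.g. -1 means the last axis (assumed convention).
    dims = {d % rank for d in dim_arg}
    if keep_dim:
        return [1 if axis in dims else size for axis, size in enumerate(shape)]
    return [size for axis, size in enumerate(shape) if axis not in dims]


print(reduced_shape([4, 8, 16], [1], True))    # [4, 1, 16]
print(reduced_shape([4, 8, 16], [1], False))   # [4, 16]
print(reduced_shape([4, 8, 16], [-1], False))  # [4, 8]
```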

ttir.max_pool2d (tt::ttir::MaxPool2dOp)

Applies a 2D max pooling over an input signal composed of several input planes.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| kernel_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| kernel_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| stride_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| stride_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| dilation_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| dilation_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| ceil_mode | ::mlir::BoolAttr | bool attribute |
| padding_left | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_right | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_top | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| padding_bottom | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |
| original_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| original_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
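
To make the kernel and stride attributes concrete, here is a minimal Python sketch of max pooling over a single 2D plane. It deliberately covers only the simplest case: no padding, dilation 1, and `ceil_mode` false. These restrictions, and the helper name `max_pool2d`, are assumptions of the sketch rather than properties of the op.

```python
def max_pool2d(plane, kernel_h, kernel_w, stride_h, stride_w):
    """Max-pool one 2D plane (list of rows): no padding, dilation 1, floor mode."""
    height, width = len(plane), len(plane[0])
    out_h = (height - kernel_h) // stride_h + 1
    out_w = (width - kernel_w) // stride_w + 1
    return [
        [
            max(
                plane[r * stride_h + i][c * stride_w + j]
                for i in range(kernel_h)
                for j in range(kernel_w)
            )
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]


plane = [
    [1, 3, 2, 4],
    [5, 6, 1, 2],
    [7, 2, 9, 0],
    [1, 8, 3, 4],
]
print(max_pool2d(plane, 2, 2, 2, 2))  # [[6, 4], [8, 9]]
```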

ttir.maximum (tt::ttir::MaximumOp)

Eltwise maximum op.

Calculates the element-wise maximum of the input tensors' values and stores the result in the output tensor.

Example:
%lhs: [[3, 2, 7], [1, 4, 4]]
%rhs: [[1, 4, 2], [1, 2, 3]]
"ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]]

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |
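
The example from the op description can be reproduced with a few lines of Python. This is a reference sketch for checking values on 2-D nested lists, not the op's implementation; the helper name `maximum` is illustrative.

```python
def maximum(lhs, rhs):
    """Element-wise maximum of two equally shaped 2-D nested lists."""
    return [
        [max(x, y) for x, y in zip(lhs_row, rhs_row)]
        for lhs_row, rhs_row in zip(lhs, rhs)
    ]


lhs = [[3, 2, 7], [1, 4, 4]]
rhs = [[1, 4, 2], [1, 2, 3]]
print(maximum(lhs, rhs))  # [[3, 4, 7], [1, 4, 4]]
```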

ttir.mean (tt::ttir::MeanOp)

Mean reduction op.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| keep_dim | ::mlir::BoolAttr | bool attribute |
| dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.multiply (tt::ttir::MultiplyOp)

Eltwise multiply.

Eltwise multiply operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface, TTIR_GenericRegionOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.neg (tt::ttir::NegOp)

Eltwise negate op.

Eltwise negate operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.reciprocal (tt::ttir::ReciprocalOp)

Eltwise reciprocal.

Eltwise reciprocal operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.relu (tt::ttir::ReluOp)

Eltwise ReLU.

Eltwise ReLU operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.reshape (tt::ttir::ReshapeOp)

Reshape op.

Reshape tensor.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| shape | ::mlir::ArrayAttr | 32-bit integer array attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
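
A reshape only makes sense when the element count is preserved. The Python sketch below illustrates that constraint by arranging a flat, row-major list into the requested `shape`. Row-major ordering and a rank of at least 1 are assumptions of the sketch, and `reshape` is a hypothetical helper.

```python
from math import prod


def reshape(flat, shape):
    """Arrange a flat row-major list into nested lists of the given shape."""
    if prod(shape) != len(flat):
        raise ValueError("reshape must preserve the number of elements")
    if len(shape) == 1:
        return list(flat)
    if shape[0] == 0:
        return []
    step = len(flat) // shape[0]
    return [
        reshape(flat[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


print(reshape([1, 2, 3, 4, 5, 6], [2, 3]))  # [[1, 2, 3], [4, 5, 6]]
print(reshape([1, 2, 3, 4, 5, 6], [3, 2]))  # [[1, 2], [3, 4], [5, 6]]
```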

ttir.sigmoid (tt::ttir::SigmoidOp)

Eltwise sigmoid.

Eltwise sigmoid operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.softmax (tt::ttir::SoftmaxOp)

Softmax operation.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dimension | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
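
For reference, the standard softmax of a 1-D slice can be sketched in Python as below. The `dimension` attribute presumably selects the axis along which such slices are taken, but that reading is an assumption. The sketch subtracts the maximum before exponentiating, a common trick for numerical stability that may or may not match what a backend does.

```python
import math


def softmax(values):
    """Softmax of a 1-D list, shifted by the maximum for numerical stability."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


print(softmax([0.0, 0.0]))  # [0.5, 0.5]
```

The outputs are positive and sum to 1 (up to floating-point rounding) for any non-empty input.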

ttir.sqrt (tt::ttir::SqrtOp)

Eltwise square root.

Eltwise square root operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.squeeze (tt::ttir::SqueezeOp)

Squeeze op.

Squeeze tensor.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
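
Assuming the usual meaning of squeeze (remove the size-1 axis at position `dim`), the effect on the tensor shape can be sketched in Python as follows. Support for negative `dim` values is an assumption, and `squeeze_shape` is a hypothetical helper.

```python
def squeeze_shape(shape, dim):
    """Shape after removing the size-1 axis at position `dim`."""
    dim %= len(shape)  # normalize a negative axis (assumed convention)
    if shape[dim] != 1:
        raise ValueError("only an axis of size 1 can be squeezed")
    return shape[:dim] + shape[dim + 1:]


print(squeeze_shape([1, 64, 128], 0))   # [64, 128]
print(squeeze_shape([64, 128, 1], -1))  # [64, 128]
```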

ttir.subtract (tt::ttir::SubtractOp)

Eltwise subtract.

Eltwise subtract operation.

Traits: AttrSizedOperandSegments

Interfaces: DestinationStyleOpInterface, TTIROpInterface, TTIR_ElementwiseOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| inputs | variadic of ranked tensor of any type values |
| outputs | variadic of ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| results | variadic of ranked tensor of any type values |

ttir.sum (tt::ttir::SumOp)

Sum reduction op.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| keep_dim | ::mlir::BoolAttr | bool attribute |
| dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.to_layout (tt::ttir::ToLayoutOp)

Layout op.

ToLayout operation. Transitions tensors from one layout to another. Some examples include:

  • Transitioning between different memory spaces, e.g. DRAM to L1.
  • Transitioning between different data types, e.g. f32 to f16.
  • Transitioning between different tile sizes, e.g. 1x16 to 32x32.
  • Transitioning between different tensor sharding.
  • Some combination of the above.

#layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
%1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

ttir.transpose (tt::ttir::TransposeOp)

Transpose op.

Transpose tensor along two given dimensions.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dim0 | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| dim1 | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
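
Transposing along two dimensions swaps the corresponding entries of the tensor shape. A minimal Python sketch of that shape rule follows; `transpose_shape` is a hypothetical helper, and accepting negative axes is an assumption.

```python
def transpose_shape(shape, dim0, dim1):
    """Shape after swapping axes `dim0` and `dim1`."""
    out = list(shape)
    # Python list indexing also handles negative axes such as -1.
    out[dim0], out[dim1] = out[dim1], out[dim0]
    return out


print(transpose_shape([2, 3, 4], 0, 2))    # [4, 3, 2]
print(transpose_shape([2, 3, 4], -1, -2))  # [2, 4, 3]
```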

ttir.unsqueeze (tt::ttir::UnsqueezeOp)

Unsqueeze op.

Unsqueeze tensor.

Interfaces: DestinationStyleOpInterface, TTIROpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
| operand_constraints | ::mlir::ArrayAttr | |

Operands:

| Operand | Description |
| --- | --- |
| input | ranked tensor of any type values |
| output | ranked tensor of any type values |

Results:

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |
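
Assuming the usual meaning of unsqueeze (insert a new axis of size 1 at position `dim`), the shape rule can be sketched in Python as below. Negative `dim` values are interpreted relative to the output rank, which is an assumed convention, and `unsqueeze_shape` is a hypothetical helper.

```python
def unsqueeze_shape(shape, dim):
    """Shape after inserting a size-1 axis at position `dim`."""
    out_rank = len(shape) + 1
    dim %= out_rank  # normalize a negative axis against the output rank
    return shape[:dim] + [1] + shape[dim:]


print(unsqueeze_shape([64, 128], 0))   # [1, 64, 128]
print(unsqueeze_shape([64, 128], -1))  # [64, 128, 1]
```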

ttir.yield (tt::ttir::YieldOp)

Yield op.

Yield operation. MLIR requires it to mark the end of a dispatch region.

Traits: AlwaysSpeculatableImplTrait, ReturnLike, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| values | variadic of ranked tensor of any type values or non-0-ranked.memref of any type values |