ttnn.leaky_relu
(tt::ttnn::LeakyReluOp)
Eltwise leaky relu operation.
The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise activation function over its input tensor. It is defined as:
$$y = \begin{cases} x & \text{if } x > 0 \\ \text{parameter} \cdot x & \text{if } x \le 0 \end{cases}$$
where `parameter` is a small, user-defined constant that determines the slope for negative inputs.
Attributes:
- `parameter` (float): The slope for negative values.
Inputs:
- `input` (Tensor): The input tensor to be activated.
Outputs:
- `output` (Tensor): The tensor after applying the Leaky ReLU activation.
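A minimal plain-Python sketch of the Leaky ReLU definition above, applied to a flat list of floats. It only illustrates the element-wise math; the name `leaky_relu_reference` is hypothetical, and this is not the TTNN kernel or its API.

```python
from typing import List


def leaky_relu_reference(values: List[float], parameter: float) -> List[float]:
    """Hypothetical reference: y = x if x > 0, else parameter * x."""
    return [x if x > 0 else parameter * x for x in values]


# Negative inputs are scaled by `parameter`; zero maps to zero.
print(leaky_relu_reference([-2.0, -0.5, 0.0, 3.0], parameter=0.25))
# [-0.5, -0.125, 0.0, 3.0]
```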
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
parameter | ::mlir::FloatAttr | 32-bit float attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.abs
(tt::ttnn::AbsOp)
Eltwise absolute.
Eltwise absolute operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.add
(tt::ttnn::AddOp)
Eltwise add.
Eltwise add operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.all_gather
(tt::ttnn::AllGatherOp)
All gather op.
Tensor All Gather operation
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
num_links | ::mlir::IntegerAttr | 32-bit signed integer attribute |
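The description above is terse. Under the common collective-communication meaning of all-gather (an assumption, not confirmed by this page), every device contributes its shard and every device ends up with all shards concatenated along `dim`. The sketch below simulates that with nested Python lists; `num_links` is assumed to affect only how data moves between devices, not the result, so it is ignored. The helper names are hypothetical.

```python
from typing import List

Matrix = List[List[int]]


def concat_2d(tensors: List[Matrix], dim: int) -> Matrix:
    """Concatenate 2-D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        return [list(row) for t in tensors for row in t]
    if dim == 1:
        # Join the i-th row of every tensor into one longer row.
        return [sum((list(t[i]) for t in tensors), []) for i in range(len(tensors[0]))]
    raise ValueError("this sketch only supports dim 0 or 1")


def all_gather_sim(shards: List[Matrix], dim: int) -> List[Matrix]:
    """Hypothetical simulation: one shard per device in, full tensor per device out."""
    gathered = concat_2d(shards, dim)
    # Each device receives its own copy of the gathered tensor.
    return [[list(row) for row in gathered] for _ in shards]


per_device = all_gather_sim([[[1, 2]], [[3, 4]]], dim=1)
print(per_device)  # [[[1, 2, 3, 4]], [[1, 2, 3, 4]]]
```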
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.alloc
(tt::ttnn::AllocOp)
Alloc op.
Tensor Alloc operation
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
address | ::mlir::IntegerAttr | 64-bit signless integer attribute |
size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
buffer_type | ::mlir::tt::ttnn::BufferTypeAttr | TTNN Buffer Type. Enum cases: `dram` (`DRAM`), `l1` (`L1`), `system_memory` (`SystemMemory`), `l1_small` (`L1Small`), `trace` (`Trace`) |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.cbrt
(tt::ttnn::CbrtOp)
Eltwise cubic root.
Eltwise cubic root operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface
, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.ceil
(tt::ttnn::CeilOp)
Eltwise ceil.
Eltwise ceil operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface
, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.clamp
(tt::ttnn::ClampOp)
Clamp op.
Clamp tensor values to a specified range.
Example: %arg0: [[0, 1, 2, 3, 4, 5, 6, 7]], min: 2.000000e+00, max: 5.000000e+00
"ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}> -> %out: [[2, 2, 2, 3, 4, 5, 5, 5]]
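A minimal plain-Python sketch of the clamp semantics, reproducing the example above. It assumes `min <= max`; what the op does when `min > max` is not documented here. The name `clamp_reference` is hypothetical.

```python
from typing import List


def clamp_reference(values: List[float], min_value: float, max_value: float) -> List[float]:
    """Hypothetical reference: raise values below min_value, lower values above max_value."""
    return [min(max(x, min_value), max_value) for x in values]


# Same data as the example above: min = 2.0, max = 5.0.
print(clamp_reference([float(v) for v in range(8)], 2.0, 5.0))
# [2.0, 2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0]
```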
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
min | ::mlir::FloatAttr | 32-bit float attribute |
max | ::mlir::FloatAttr | 32-bit float attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | variadic of ranked tensor of any type values |
ttnn.concat
(tt::ttnn::ConcatOp)
Concat op.
Concat tensors along a given dimension.
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
output | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.conv2d
(tt::ttnn::Conv2dOp)
Conv2d operation.
Applies a 2D convolution over an input image composed of several input planes.
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
in_channels | ::mlir::IntegerAttr | 32-bit signless integer attribute |
out_channels | ::mlir::IntegerAttr | 32-bit signless integer attribute |
batch_size | ::mlir::IntegerAttr | 32-bit signless integer attribute |
input_height | ::mlir::IntegerAttr | 32-bit signless integer attribute |
input_width | ::mlir::IntegerAttr | 32-bit signless integer attribute |
kernel_height | ::mlir::IntegerAttr | 32-bit signless integer attribute |
kernel_width | ::mlir::IntegerAttr | 32-bit signless integer attribute |
stride_height | ::mlir::IntegerAttr | 32-bit signless integer attribute |
stride_width | ::mlir::IntegerAttr | 32-bit signless integer attribute |
padding_height | ::mlir::IntegerAttr | 32-bit signless integer attribute |
padding_width | ::mlir::IntegerAttr | 32-bit signless integer attribute |
dilation_height | ::mlir::IntegerAttr | 32-bit signless integer attribute |
dilation_width | ::mlir::IntegerAttr | 32-bit signless integer attribute |
groups | ::mlir::IntegerAttr | 32-bit signless integer attribute |
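This page does not state how the output spatial size follows from these attributes. Assuming PyTorch-style semantics (padding applied symmetrically to both sides of each spatial dimension), the standard per-dimension formula is sketched below. The function name is hypothetical, and the layout of the result tensor is not covered.

```python
def conv2d_output_size(size: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    """Output length of one spatial dimension (PyTorch-style formula, assumed)."""
    # A dilated kernel spans dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# 32x32 input, 3x3 kernel, stride 2, padding 1, dilation 1:
out_h = conv2d_output_size(32, kernel=3, stride=2, padding=1, dilation=1)
out_w = conv2d_output_size(32, kernel=3, stride=2, padding=1, dilation=1)
print(out_h, out_w)  # 16 16
```

Apply it once with `input_height`, `kernel_height`, `stride_height`, `padding_height`, `dilation_height` and once with the `*_width` counterparts.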
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
weight | ranked tensor of any type values |
bias | ranked tensor of any type values |
output | ranked tensor of any type values |
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.cos
(tt::ttnn::CosOp)
Eltwise cosine.
Eltwise cosine operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.deallocate
(tt::ttnn::DeallocateOp)
Deallocate op.
Tensor Deallocate operation
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
force | ::mlir::BoolAttr | bool attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
ttnn.div
(tt::ttnn::DivOp)
Eltwise divide.
Eltwise divide operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.embedding
(tt::ttnn::EmbeddingOp)
Embedding op.
Embedding operation.
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
output | ranked tensor of any type values |
weight | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.empty
(tt::ttnn::EmptyOp)
Empty op.
Tensor empty operation
Interfaces: NoMemoryEffect (MemoryEffectOpInterface), TTNN_OpModelInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
shape | ::mlir::tt::ttnn::ShapeAttr | TTNN shape attribute |
dtype | ::mlir::tt::DataTypeAttr | TT DataTypes. Enum cases: `f32` (`Float32`), `f16` (`Float16`), `bf16` (`BFloat16`), `bfp_f8` (`BFP_Float8`), `bfp_bf8` (`BFP_BFloat8`), `bfp_f4` (`BFP_Float4`), `bfp_bf4` (`BFP_BFloat4`), `bfp_f2` (`BFP_Float2`), `bfp_bf2` (`BFP_BFloat2`), `u32` (`UInt32`), `u16` (`UInt16`), `u8` (`UInt8`) |
layout | ::mlir::tt::ttnn::LayoutAttr | TTNN Layout. Enum cases: `row_major` (`RowMajor`), `tile` (`Tile`), `invalid` (`Invalid`) |
memory_config | ::mlir::tt::ttnn::MemoryConfigAttr | TTNN MemoryConfig attribute |
Operands:
Operand | Description |
---|---|
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.eq
(tt::ttnn::EqualOp)
Eltwise equal to.
Eltwise equal to operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.exp
(tt::ttnn::ExpOp)
Eltwise exponential.
Eltwise exponential operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.expm1
(tt::ttnn::Expm1Op)
Eltwise exponential minus one.
Performs an element-wise exponential minus one operation, exp(x) - 1, on the operand tensor and stores the result in the output tensor.
Example: %a: [[0, 1], [0, 0]] "ttnn.expm1"(%a, %out) -> %out: [[0, 1.71828], [0, 0]]
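A plain-Python reference of the same element-wise computation using the standard library's `math.expm1`, plus a check of why a dedicated exp-minus-one op is useful: for inputs near zero, computing `exp(x) - 1` directly cancels most significant digits. This is an illustration in double precision, not the TTNN kernel; the name `expm1_reference` is hypothetical.

```python
import math
from typing import List


def expm1_reference(values: List[List[float]]) -> List[List[float]]:
    """Element-wise exp(x) - 1 over a 2-D nested list, using math.expm1."""
    return [[math.expm1(x) for x in row] for row in values]


print(expm1_reference([[0.0, 1.0], [0.0, 0.0]]))

# For tiny x, the naive form loses precision that math.expm1 keeps.
tiny = 1e-10
naive = math.exp(tiny) - 1.0
accurate = math.expm1(tiny)
print(abs(naive - tiny) > abs(accurate - tiny))  # True
```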
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.floor
(tt::ttnn::FloorOp)
Eltwise floor op.
Eltwise floor operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.from_device
(tt::ttnn::FromDeviceOp)
FromDevice op.
This op retrieves the input tensor from the given device.
Interfaces: TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.full
(tt::ttnn::FullOp)
Full op.
Tensor full operation
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
fillValue | ::mlir::FloatAttr | 32-bit float attribute |
Operands:
Operand | Description |
---|---|
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.gelu
(tt::ttnn::GeluOp)
Eltwise GELU.
Eltwise GELU operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.get_device
(tt::ttnn::GetDeviceOp)
Get Device op.
This op returns the current runtime device.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh_shape | ::mlir::tt::ttnn::MeshShapeAttr | TTNN mesh shape |
Results:
Result | Description |
---|---|
device | TT device |
ttnn.ge
(tt::ttnn::GreaterEqualOp)
Eltwise greater than or equal to.
Eltwise greater than or equal to operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.gt
(tt::ttnn::GreaterThanOp)
Eltwise greater than.
Eltwise greater than operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.isfinite
(tt::ttnn::IsFiniteOp)
Eltwise isfinite op.
Eltwise isfinite operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.le
(tt::ttnn::LessEqualOp)
Eltwise less than or equal to.
Eltwise less than or equal to operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.lt
(tt::ttnn::LessThanOp)
Eltwise less than.
Eltwise less than operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.log1p
(tt::ttnn::Log1pOp)
Eltwise log1p operation.
Performs an element-wise logarithm of one plus the operand, log(1 + x), on the operand tensor and puts the result in the output tensor.
Example: %a: [0.0, -0.999, 7.0, 6.38905621, 15.0] "ttnn.log1p"(%a, %out) -> %out: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
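A plain-Python check of the example using the standard library's `math.log1p` in double precision. The results agree with the example to about four decimal places; the `-0.999` entry differs in the fifth decimal, which would be consistent with the example having been computed in lower precision such as float32 (an assumption). The name `log1p_reference` is hypothetical.

```python
import math
from typing import List


def log1p_reference(values: List[float]) -> List[float]:
    """Element-wise log(1 + x) using math.log1p."""
    return [math.log1p(x) for x in values]


inputs = [0.0, -0.999, 7.0, 6.38905621, 15.0]
expected = [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]  # values from the example above
for got, want in zip(log1p_reference(inputs), expected):
    print(f"{got:.8f} vs {want}")
```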
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.log
(tt::ttnn::LogOp)
Eltwise logarithm.
Eltwise logarithm operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.logical_and
(tt::ttnn::LogicalAndOp)
Eltwise logical and.
Eltwise logical and operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.logical_not
(tt::ttnn::LogicalNotOp)
Eltwise logical not op.
Eltwise logical not operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.logical_or
(tt::ttnn::LogicalOrOp)
Eltwise logical or.
Eltwise logical or operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.logical_xor
(tt::ttnn::LogicalXorOp)
Eltwise logical xor.
Eltwise logical xor operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.matmul
(tt::ttnn::MatmulOp)
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
a | ranked tensor of any type values |
b | ranked tensor of any type values |
output | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.max
(tt::ttnn::MaxOp)
Max reduction op.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
keep_dim | ::mlir::BoolAttr | bool attribute |
dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
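How `dim_arg` and `keep_dim` determine the result shape is not spelled out here. The sketch below assumes the usual reduction convention: listed dimensions (negative values counting from the end) are reduced, an absent `dim_arg` reduces every dimension, and `keep_dim` keeps each reduced dimension with size 1. The function name is hypothetical; `ttnn.mean` and `ttnn.sum`, which carry the same attributes, presumably follow the same rule.

```python
from typing import List, Optional, Sequence


def reduced_shape(shape: Sequence[int], dim_arg: Optional[Sequence[int]] = None,
                  keep_dim: bool = False) -> List[int]:
    """Result shape of a reduction over `dim_arg` (assumed semantics)."""
    rank = len(shape)
    # Normalize negative dimensions; None means "reduce everything".
    dims = set(range(rank)) if dim_arg is None else {d % rank for d in dim_arg}
    if keep_dim:
        return [1 if i in dims else n for i, n in enumerate(shape)]
    return [n for i, n in enumerate(shape) if i not in dims]


print(reduced_shape([2, 3, 4], dim_arg=[1], keep_dim=True))    # [2, 1, 4]
print(reduced_shape([2, 3, 4], dim_arg=[-1], keep_dim=False))  # [2, 3]
```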
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.max_pool2d
(tt::ttnn::MaxPool2dOp)
Applies a 2D max pooling over an input signal composed of several input planes.
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
batch_size | ::mlir::IntegerAttr | 32-bit signed integer attribute |
input_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
input_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
channels | ::mlir::IntegerAttr | 32-bit signed integer attribute |
kernel_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
kernel_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
stride_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
stride_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
dilation_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
dilation_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
ceil_mode | ::mlir::BoolAttr | bool attribute |
padding_height | ::mlir::IntegerAttr | 32-bit signed integer attribute |
padding_width | ::mlir::IntegerAttr | 32-bit signed integer attribute |
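A naive plain-Python sketch of what 2D max pooling computes for a single batch element and a single channel. For brevity it assumes `padding_height = padding_width = 0`, `dilation_height = dilation_width = 1` and `ceil_mode = false`; the function name is hypothetical and unrelated to the TTNN implementation.

```python
from typing import List

Matrix = List[List[float]]


def max_pool2d_single_channel(x: Matrix, kernel_height: int, kernel_width: int,
                              stride_height: int, stride_width: int) -> Matrix:
    """Naive 2-D max pooling with no padding, dilation 1 and ceil_mode=False."""
    height, width = len(x), len(x[0])
    out_h = (height - kernel_height) // stride_height + 1
    out_w = (width - kernel_width) // stride_width + 1
    return [
        [
            # Maximum over the kernel window whose top-left corner is (i*sh, j*sw).
            max(x[i * stride_height + a][j * stride_width + b]
                for a in range(kernel_height) for b in range(kernel_width))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(max_pool2d_single_channel(image, 2, 2, 2, 2))  # [[6, 8], [14, 16]]
```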
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
output | ranked tensor of any type values |
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.maximum
(tt::ttnn::MaximumOp)
Eltwise maximum op.
Calculates the element-wise maximum of the input tensors' values and stores the result in the output tensor.
Example: %lhs: [[3, 2, 7], [1, 4, 4]] %rhs: [[1, 4, 2], [1, 2, 3]] "ttnn.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]]
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.mean
(tt::ttnn::MeanOp)
Mean reduction op.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
keep_dim | ::mlir::BoolAttr | bool attribute |
dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.minimum
(tt::ttnn::MinimumOp)
Eltwise minimum op.
Calculates the element-wise minimum of the input tensors' values and stores the result in the output tensor.
Example: %lhs: [[3, 2, 7], [1, 4, 4]] %rhs: [[1, 4, 2], [1, 2, 3]] "ttnn.minimum"(%lhs, %rhs, %out) -> %out: [[1, 2, 2], [1, 2, 3]]
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.multiply
(tt::ttnn::MultiplyOp)
Eltwise multiply.
Eltwise multiply operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.neg
(tt::ttnn::NegOp)
Eltwise negate.
Eltwise negate operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.ne
(tt::ttnn::NotEqualOp)
Eltwise not equal to.
Eltwise not equal to operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.reciprocal
(tt::ttnn::ReciprocalOp)
Eltwise reciprocal.
Eltwise reciprocal operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.reduce_scatter
(tt::ttnn::ReduceScatterOp)
Reduce scatter op.
Tensor Reduce Scatter operation
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
scatter_split_dim | ::mlir::IntegerAttr | 32-bit signed integer attribute |
math_op | ::mlir::IntegerAttr | TTNN Reduce Operation Type. Enum cases: `sum` (`Sum`), `mean` (`Mean`), `max` (`Max`), `min` (`Min`), `std` (`Std`), `var` (`Var`) |
num_links | ::mlir::IntegerAttr | 32-bit signed integer attribute |
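Under the usual meaning of reduce-scatter (assumed here, not stated on this page), the per-device tensors are first combined element-wise with `math_op`, and the reduced tensor is then split along `scatter_split_dim` so that each device keeps one chunk. The sketch simulates `math_op = sum` on 1-D lists (so the split dimension is 0), requires the length to divide evenly by the device count, and ignores `num_links`. The function name is hypothetical.

```python
from typing import List


def reduce_scatter_sum(per_device: List[List[int]]) -> List[List[int]]:
    """Sum element-wise across devices, then give device i the i-th equal chunk."""
    num_devices = len(per_device)
    reduced = [sum(column) for column in zip(*per_device)]
    if len(reduced) % num_devices != 0:
        raise ValueError("length must divide evenly by the number of devices")
    chunk = len(reduced) // num_devices
    return [reduced[i * chunk:(i + 1) * chunk] for i in range(num_devices)]


print(reduce_scatter_sum([[1, 2, 3, 4], [10, 20, 30, 40]]))  # [[11, 22], [33, 44]]
```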
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.relu
(tt::ttnn::ReluOp)
Eltwise ReLU.
Eltwise ReLU operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, OpModel, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.remainder
(tt::ttnn::RemainderOp)
Eltwise remainder.
Performs an element-wise remainder of the dividend tensor `lhs` by the divisor tensor `rhs` and produces a result tensor. As the example shows, each result element takes the sign of the dividend.
Example:
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "ttnn.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
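The example corresponds to truncated-division remainder (as with C's `%` on integers or `fmod`). Python's own `%` operator floors instead, so a faithful reference has to be written explicitly. A minimal integer sketch; the name `remainder_reference` is hypothetical.

```python
from typing import List


def remainder_reference(lhs: List[int], rhs: List[int]) -> List[int]:
    """Element-wise remainder whose sign follows the dividend (truncated division)."""
    result = []
    for a, b in zip(lhs, rhs):
        r = abs(a) % abs(b)
        result.append(r if a >= 0 else -r)
    return result


print(remainder_reference([17, -17, 17, -17], [3, 3, -3, -3]))  # [2, -2, 2, -2]
# Python's floor-based % disagrees when the signs differ:
print([a % b for a, b in zip([17, -17, 17, -17], [3, 3, -3, -3])])  # [2, 1, -1, -2]
```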
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.reshape
(tt::ttnn::ReshapeOp)
Reshape op.
Reshape tensor.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
shape | ::mlir::ArrayAttr | 32-bit integer array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.rsqrt
(tt::ttnn::RsqrtOp)
Eltwise rsqrt.
Eltwise rsqrt operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.sigmoid
(tt::ttnn::SigmoidOp)
Eltwise sigmoid.
Eltwise sigmoid operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.sign
(tt::ttnn::SignOp)
Eltwise sign operation.
Returns the sign of the operand element-wise and produces a result tensor.
Example: %a: [[3, -2, 0], [1, -4, 4]] "ttnn.sign"(%a, %out) -> %out: [[1, -1, 0], [1, -1, 1]]
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.sin
(tt::ttnn::SinOp)
Eltwise sine.
Eltwise sine operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.slice
(tt::ttnn::SliceOp)
Slice op.
Extracts a portion of a tensor based on the specified start (`begins`), stop (`ends`), and `step` indices for each dimension.
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
begins | ::mlir::ArrayAttr | 32-bit integer array attribute |
ends | ::mlir::ArrayAttr | 32-bit integer array attribute |
step | ::mlir::ArrayAttr | 32-bit integer array attribute |
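A minimal sketch of the slicing rule on nested lists, assuming Python-style half-open ranges `[begin, end)` with a positive step in each dimension; how the op treats negative indices or steps is not covered here. The helper name is hypothetical.

```python
from typing import Any, Sequence


def slice_reference(x: Any, begins: Sequence[int], ends: Sequence[int], step: Sequence[int]) -> Any:
    """Slice the leading len(begins) dimensions of a nested list."""
    if not begins:
        return x
    # Slice this dimension, then recurse into the remaining ones.
    selected = x[begins[0]:ends[0]:step[0]]
    return [slice_reference(item, begins[1:], ends[1:], step[1:]) for item in selected]


tensor = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(slice_reference(tensor, begins=[0, 1], ends=[3, 4], step=[2, 2]))  # [[1, 3], [9, 11]]
```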
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
output | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.softmax
(tt::ttnn::SoftmaxOp)
Softmax op.
Softmax operation.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dimension | ::mlir::IntegerAttr | 32-bit signed integer attribute |
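Softmax along `dimension` maps each slice to exp(x_i) / sum_j exp(x_j). A common, numerically safer formulation subtracts the slice maximum first, which leaves the result unchanged but avoids overflow. The sketch handles only the last dimension of a 2-D nested list; whether the TTNN kernel uses the same max-subtraction trick is not stated here, and the function name is hypothetical.

```python
import math
from typing import List


def softmax_last_dim(x: List[List[float]]) -> List[List[float]]:
    """Row-wise softmax over the last dimension of a 2-D nested list."""
    out = []
    for row in x:
        m = max(row)  # subtracting the max keeps math.exp from overflowing
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Without the max subtraction, math.exp(1000.0) would raise OverflowError.
print(softmax_last_dim([[1000.0, 1000.0]]))  # [[0.5, 0.5]]
```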
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.sqrt
(tt::ttnn::SqrtOp)
Eltwise sqrt.
Eltwise sqrt operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.subtract
(tt::ttnn::SubtractOp)
Eltwise subtract.
Eltwise subtract operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
ttnn.sum
(tt::ttnn::SumOp)
Sum reduction op.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
keep_dim | ::mlir::BoolAttr | bool attribute |
dim_arg | ::mlir::ArrayAttr | 32-bit integer array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.to_device
(tt::ttnn::ToDeviceOp)
ToDevice op.
This op sends the input tensor to the given device with the given memory config.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
memory_config | ::mlir::tt::ttnn::MemoryConfigAttr | TTNN MemoryConfig attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.to_layout
(tt::ttnn::ToLayoutOp)
ToLayout op.
This op wraps all layout information gathered from ttir.toLayout. The optimizer uses and updates it to perform optimizations, and it is later broken down into specific memory/layout operations (toDevice, toMemoryConfig, etc.). The TTNN backend currently uses this op only for tilize/untilize, so all other attributes are marked optional. Once ttnn::to_layout supports the other attributes, the optional tag can be removed.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
layout | ::mlir::tt::ttnn::LayoutAttr | TTNN Layout. Enum cases: `row_major` (`RowMajor`), `tile` (`Tile`), `invalid` (`Invalid`) |
dtype | ::mlir::tt::DataTypeAttr | TT DataTypes. Enum cases: `f32` (`Float32`), `f16` (`Float16`), `bf16` (`BFloat16`), `bfp_f8` (`BFP_Float8`), `bfp_bf8` (`BFP_BFloat8`), `bfp_f4` (`BFP_Float4`), `bfp_bf4` (`BFP_BFloat4`), `bfp_f2` (`BFP_Float2`), `bfp_bf2` (`BFP_BFloat2`), `u32` (`UInt32`), `u16` (`UInt16`), `u8` (`UInt8`) |
memory_config | ::mlir::tt::ttnn::MemoryConfigAttr | TTNN MemoryConfig attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
device | TT device |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.to_memory_config
(tt::ttnn::ToMemoryConfigOp)
ToMemoryConfig op.
This op converts the memory config of the input tensor based on the given memory config. It handles:
- DRAM to L1
- L1 to DRAM
- Interleaved to sharded
- Sharded to interleaved
- Sharded to sharded (reshard)
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
memory_config | ::mlir::tt::ttnn::MemoryConfigAttr | TTNN MemoryConfig attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.transpose
(tt::ttnn::TransposeOp)
Transpose op.
Transpose tensor along two given dimensions.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dim0 | ::mlir::IntegerAttr | 32-bit signed integer attribute |
dim1 | ::mlir::IntegerAttr | 32-bit signed integer attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.typecast
(tt::ttnn::TypecastOp)
Typecast op.
This op converts the input tensor to the given data type.
Interfaces: TTNN_OpModelInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
dtype | ::mlir::tt::DataTypeAttr | TT DataTypes. Enum cases: `f32` (`Float32`), `f16` (`Float16`), `bf16` (`BFloat16`), `bfp_f8` (`BFP_Float8`), `bfp_bf8` (`BFP_BFloat8`), `bfp_f4` (`BFP_Float4`), `bfp_bf4` (`BFP_BFloat4`), `bfp_f2` (`BFP_Float2`), `bfp_bf2` (`BFP_BFloat2`), `u32` (`UInt32`), `u16` (`UInt16`), `u8` (`UInt8`) |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
ttnn.where
(tt::ttnn::WhereOp)
Eltwise where.
Eltwise where operation.
Traits: AttrSizedOperandSegments
Interfaces: DestinationStyleOpInterface, TTNN_OpModelInterface
Operands:
Operand | Description |
---|---|
inputs | variadic of ranked tensor of any type values |
outputs | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |