ttir_builder.ops

class ttir_builder.ops.TTIRBuilderOps

Bases: object

abs(in0, unit_attrs=None)

Creates ttir.abs.

Elementwise absolute value operation.

Computes the absolute value of each element in the input tensor.

// Compute absolute values of all elements in %input
%result = ttir.abs(%input, %output) : tensor<4x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32>
// Input tensor:
// [[-2.5,  3.7,  0.0,  1.2], ... ]
// Output tensor:
// [[2.5, 3.7, 0.0, 1.2], ... ]

// Example with integer tensor
%result = ttir.abs(%int_input, %int_output) : tensor<10xi32>, tensor<10xi32> -> tensor<10xi32>
// Input tensor:
// [-5, 0, 3, -2, ...]
// Output tensor:
// [5, 0, 3, 2, ...]
  • Parameters:
    • in0 (Operand) – Input tensor to compute absolute value of
    • unit_attrs (Optional[List[str]]) – Optional list of unit attributes
  • Return type: *OpView*

add(in0, in1, unit_attrs=None)

Creates ttir.add.

Elementwise addition operation.

Performs elementwise addition between two tensors. For each pair of corresponding elements, adds the element in the second tensor to the element in the first tensor.

Mathematical definition: add(x, y) = x + y

// Add corresponding elements
%result = ttir.add(%lhs, %rhs, %output) : tensor<3xf32>, tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensors:
// lhs: [3.5, 0.0, -1.2]
// rhs: [1.5, 2.0, -3.2]
// Output tensor:
// [5.0, 2.0, -4.4]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the elementwise sum of the inputs
  • Return type: *OpView*

all_gather(input, all_gather_dim=None, cluster_axis=None)

Creates ttir.all_gather.

Gather tensor data from all devices.

Collects tensor data from every device along the selected axis of the device mesh (cluster_axis) and concatenates the pieces along all_gather_dim.

For a mesh shape of [2,4] with device IDs: [[0, 1, 2, 3], [4, 5, 6, 7]]

  • If cluster_axis=0: Gathers along columns (0,4), (1,5), (2,6), (3,7)
  • If cluster_axis=1: Gathers along rows (0,1,2,3), (4,5,6,7)
// Gather tensor data from all devices along dimension 0
%result = ttir.all_gather(%input) {all_gather_dim = 0, cluster_axis = 1} : tensor<32x64xf32> -> tensor<128x64xf32>
// Input tensor on device 0:
// [[1.0, 2.0],
//  [3.0, 4.0]]
// Output tensor after gathering (only the first row of each device's shard is shown):
// [[1.0, 2.0],  // from device 0
//  [5.0, 6.0],  // from device 1
//  [9.0, 10.0], // from device 2
//  [13.0, 14.0]] // from device 3
  • Parameters:
    • input (Operand) – Input tensor to be gathered
    • all_gather_dim (int, optional) – Dimension along which to concatenate gathered tensors
    • cluster_axis (int, optional) – Axis of device mesh for gathering (0 or 1)
  • Return type: *OpView*
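The device grouping described above can be simulated on the host to check which devices exchange data. This is a hypothetical Python sketch, not part of ttir_builder: `mesh_groups` and `simulate_all_gather` are illustrative names, each shard is assumed to be a 2-D nested list, and concatenation is assumed to happen along dimension 0.

```python
def mesh_groups(mesh, cluster_axis):
    """Group device IDs of a 2-D mesh: one group per column (axis 0) or per row (axis 1)."""
    if cluster_axis == 0:
        return [list(column) for column in zip(*mesh)]
    return [list(row) for row in mesh]


def simulate_all_gather(shards, mesh, cluster_axis):
    """Concatenate each group's shards along dim 0; every group member gets the result."""
    result = {}
    for group in mesh_groups(mesh, cluster_axis):
        gathered = [row for device in group for row in shards[device]]
        for device in group:
            result[device] = gathered
    return result


mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]
shards = {d: [[float(d)]] for d in range(8)}  # one 1x1 shard per device
print(mesh_groups(mesh, 0))                      # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(simulate_all_gather(shards, mesh, 1)[0])   # [[0.0], [1.0], [2.0], [3.0]]
```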

all_reduce(input, reduce_type, cluster_axis)

Creates ttir.all_reduce.

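AllReduce operation.

Reduces the input tensor across the devices along the selected axis of the device mesh (cluster_axis) using the reduction given by reduce_type; every participating device receives the reduced result.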

  • Parameters:
    • input (Operand) – Input tensor to be reduced
    • reduce_type (str) – Type of reduction operation (e.g., “sum”, “max”)
    • cluster_axis (int) – Axis of device mesh for reduction (0 or 1)
  • Return type: *OpView*
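A hypothetical host-side sketch of the usual all-reduce semantics, reducing one scalar per device within each group of a 2-D mesh. The `REDUCERS` table only covers the two reduce_type strings named above ("sum" and "max"); the full set accepted by ttir.all_reduce is not listed here, and the helper names are illustrative.

```python
# Hypothetical mapping from reduce_type strings to Python reductions.
REDUCERS = {"sum": sum, "max": max}


def mesh_groups(mesh, cluster_axis):
    """Group device IDs of a 2-D mesh: columns for axis 0, rows for axis 1."""
    if cluster_axis == 0:
        return [list(column) for column in zip(*mesh)]
    return [list(row) for row in mesh]


def simulate_all_reduce(values, mesh, reduce_type, cluster_axis):
    """values: dict device_id -> float; each group member receives the group's reduction."""
    reduce_fn = REDUCERS[reduce_type]
    result = {}
    for group in mesh_groups(mesh, cluster_axis):
        reduced = reduce_fn(values[device] for device in group)
        for device in group:
            result[device] = reduced
    return result


mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]
values = {device: float(device) for device in range(8)}
print(simulate_all_reduce(values, mesh, "sum", cluster_axis=1)[0])  # 6.0
print(simulate_all_reduce(values, mesh, "max", cluster_axis=0)[0])  # 4.0
```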

arange(result, start, end, step, arange_dimension, unit_attrs=None)

Creates ttir.arange.

Creates a 1-D tensor of sequential values.

Returns a 1-D tensor of ceil((end - start) / step) elements (for a positive step), starting at start and increasing by step, with every value strictly less than end.

  • Parameters:
    • start (int) – Starting value
    • end (int) – Ending value (exclusive)
    • step (int) – Step size between values
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
    • result (Value | OpView | Operation)
    • arange_dimension (int)
  • Returns: 1-D tensor with sequential values
  • Return type: *OpView*
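For a positive step, the element count is the ceiling of (end - start) / step, which matches what Python's built-in range produces. A minimal sketch assuming integer arguments and step > 0; `arange_length` and `arange_values` are hypothetical helpers, not part of ttir_builder:

```python
def arange_length(start, end, step=1):
    """Number of elements in [start, end) taken with stride step (step > 0)."""
    # Integer ceiling division avoids floating-point rounding.
    return max(0, -(-(end - start) // step))


def arange_values(start, end, step=1):
    """Reference values: start, start + step, ... while < end."""
    return list(range(start, end, step))


print(arange_values(0, 10, 3))  # [0, 3, 6, 9]
print(arange_length(0, 10, 3))  # 4
```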

argmax(in0, dim_arg, keep_dim=False, unit_attrs=None)

Creates ttir.argmax.

Argmax reduction operation.

Returns the indices of the maximum values along the specified dimensions.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (List[int]) – Dimensions to reduce over
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: False)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor containing the indices of maximum values
  • Return type: *OpView*
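A reference sketch of argmax over a single dimension of a 2-D nested list, including the keep_dim behaviour. It assumes one reduction dimension (the real dim_arg takes a list) and resolves ties to the first index; how ttir.argmax breaks ties is not specified here. `argmax_2d` is a hypothetical name.

```python
def argmax_2d(matrix, dim, keep_dim=False):
    """Index of the maximum along one dimension of a 2-D nested list.

    Ties resolve to the first occurrence (an assumption of this sketch).
    """
    if dim == 1:
        # One index per row, searching across columns.
        idx = [max(range(len(row)), key=row.__getitem__) for row in matrix]
        return [[i] for i in idx] if keep_dim else idx
    # dim == 0: one index per column, searching down the rows.
    idx = [max(range(len(matrix)), key=lambda r: matrix[r][c])
           for c in range(len(matrix[0]))]
    return [idx] if keep_dim else idx


m = [[1.0, 5.0, 2.0], [7.0, 0.0, 7.0]]
print(argmax_2d(m, dim=1))                 # [1, 0]
print(argmax_2d(m, dim=0, keep_dim=True))  # [[1, 0, 1]]
```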

atan(in0, unit_attrs=None)

Creates ttir.atan.

Elementwise arctangent operation.

Computes the inverse tangent (arctangent) of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with arctangent values
  • Return type: *OpView*

bitwise_and(in0, in1, unit_attrs=None)

Creates ttir.bitwise_and.

Elementwise bitwise AND operation.

Performs elementwise bitwise AND operation between two tensors. For each pair of corresponding elements, performs a bitwise AND on their binary representations.

This operation is typically used with integer data types and has the following properties:

  • Commutative: bitwise_and(x, y) = bitwise_and(y, x)
  • Associative: bitwise_and(x, bitwise_and(y, z)) = bitwise_and(bitwise_and(x, y), z)
  • Identity: bitwise_and(x, -1) = x
  • Zero: bitwise_and(x, 0) = 0
// Bitwise AND with integer tensors
%result = ttir.bitwise_and(%lhs, %rhs, %output) : tensor<3xi8>, tensor<3xi8>, tensor<3xi8> -> tensor<3xi8>
// Input tensors:
// lhs: [5, 3, -1]     (binary: [00000101, 00000011, 11111111])
// rhs: [3, 6, -127]   (binary: [00000011, 00000110, 10000001])
// Output tensor:
// [1, 2, -127]        (binary: [00000001, 00000010, 10000001])
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*
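Python integers follow two's-complement semantics for `&`, so elementwise AND on signed 8-bit values can be checked directly. In an i8 tensor the bit pattern 11111111 reads as -1 and 10000001 as -127; the hypothetical `to_i8` helper below shows that wrap-around. This is a semantic sketch, not a ttir_builder call.

```python
def to_i8(value):
    """Wrap an integer into the signed 8-bit range [-128, 127] (two's complement)."""
    return ((value + 128) % 256) - 128


lhs = [5, 3, -1]    # 0b00000101, 0b00000011, 0b11111111 as i8
rhs = [3, 6, -127]  # 0b00000011, 0b00000110, 0b10000001 as i8
result = [to_i8(a & b) for a, b in zip(lhs, rhs)]
print(result)  # [1, 2, -127]

# Identity and zero properties listed above, checked over every i8 value.
for x in range(-128, 128):
    assert (x & -1) == x
    assert (x & 0) == 0
```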

bitwise_or(in0, in1, unit_attrs=None)

Creates ttir.bitwise_or.

Elementwise bitwise OR operation.

Performs elementwise bitwise OR operation between two tensors. For each pair of corresponding elements, performs a bitwise OR on their binary representations.

This operation is typically used with integer data types and has the following properties:

  • Commutative: bitwise_or(x, y) = bitwise_or(y, x)
  • Associative: bitwise_or(x, bitwise_or(y, z)) = bitwise_or(bitwise_or(x, y), z)
  • Identity: bitwise_or(x, 0) = x
  • Absorbing element: bitwise_or(x, -1) = -1
// Bitwise OR with integer tensors
%result = ttir.bitwise_or(%lhs, %rhs, %output) : tensor<3xi8>, tensor<3xi8>, tensor<3xi8> -> tensor<3xi8>
// Input tensors:
// lhs: [5, 3, -1]     (binary: [00000101, 00000011, 11111111])
// rhs: [3, 6, -127]   (binary: [00000011, 00000110, 10000001])
// Output tensor:
// [7, 7, -1]          (binary: [00000111, 00000111, 11111111])
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

bitwise_xor(in0, in1, unit_attrs=None)

Creates ttir.bitwise_xor.

Elementwise bitwise XOR operation.

Performs elementwise bitwise XOR (exclusive OR) operation between two tensors. For each pair of corresponding elements, performs a bitwise XOR on their binary representations.

// Bitwise XOR with integer tensors
%result = ttir.bitwise_xor(%input1, %input2, %output) : tensor<2x2xi32>, tensor<2x2xi32>, tensor<2x2xi32> -> tensor<2x2xi32>
// Input1 tensor:
// [[1, 3],  // binary: [[0001, 0011],
//  [5, 7]]  //         [0101, 0111]]
// Input2 tensor:
// [[2, 3],  // binary: [[0010, 0011],
//  [6, 7]]  //         [0110, 0111]]
// Output tensor:
// [[3, 0],  // binary: [[0011, 0000],
//  [3, 0]]  //         [0011, 0000]]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the bitwise XOR of corresponding elements
  • Return type: *OpView*
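The XOR example can be reproduced elementwise with Python's `^` operator on nested lists. This is only a sketch of the semantics, not a ttir_builder call:

```python
lhs = [[1, 3], [5, 7]]
rhs = [[2, 3], [6, 7]]
result = [[a ^ b for a, b in zip(row_l, row_r)] for row_l, row_r in zip(lhs, rhs)]
print(result)  # [[3, 0], [3, 0]]

# XOR of a value with itself clears every bit; XOR with 0 leaves it unchanged.
assert all((x ^ x) == 0 and (x ^ 0) == x for x in range(-1000, 1000))
```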

broadcast(in0, in1, broadcast_dimensions, unit_attrs=None)

Creates ttir.broadcast.

Tensor broadcast operation.

Broadcasts a tensor to a new shape by replicating its values along specified dimensions. The broadcast_dimensions parameter specifies how dimensions of the input map to dimensions of the output.

// Broadcast a 1D tensor to 2D
%result = ttir.broadcast(%input, %output, broadcast_dimensions = [1]) : tensor<3xf32>, tensor<2x3xf32> -> tensor<2x3xf32>
// Input tensor:
// [1.0, 2.0, 3.0]
// Output tensor:
// [[1.0, 2.0, 3.0],
//  [1.0, 2.0, 3.0]]
  • Parameters:
    • in0 (Operand) – Input tensor to broadcast
    • in1 (Operand) – Output tensor with target shape
    • broadcast_dimensions (List[int]) – List of dimension mappings from input to output
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The broadcasted tensor
  • Return type: *OpView*

cbrt(in0, unit_attrs=None)

Creates ttir.cbrt.

Elementwise cube root operation.

Computes the cube root (∛) of each element in the input tensor. For each element, returns the real number that, when cubed, equals the input value. Unlike the square root, the cube root is defined for negative as well as positive numbers.

// Compute cube root of all elements
%result = ttir.cbrt(%input, %output) : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensor:
// [8.0, 27.0, -8.0, 1.0]
// Output tensor:
// [2.0, 3.0, -2.0, 1.0]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the cube root of each element in the input tensor
  • Return type: *OpView*
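In Python, `(-8.0) ** (1/3)` evaluates to a complex number rather than -2.0, so a real cube root has to handle the sign explicitly (Python 3.11 and later also provide `math.cbrt`). A minimal sketch of the elementwise semantics; `cbrt` here is a hypothetical helper:

```python
import math


def cbrt(x):
    """Real cube root, defined for negative inputs as well."""
    return math.copysign(abs(x) ** (1.0 / 3.0), x)


values = [8.0, 27.0, -8.0, 1.0]
print([round(cbrt(v), 6) for v in values])  # [2.0, 3.0, -2.0, 1.0]
```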

ceil(in0, unit_attrs=None)

Creates ttir.ceil.

Elementwise ceiling operation.

Computes the ceiling of each element in the input tensor, rounding up to the nearest integer. This operation is idempotent.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with ceiling values
  • Return type: *OpView*

clamp_scalar(in0, min_arg=None, max_arg=None, unit_attrs=None)

  • Parameters:
    • in0 (Value | OpView | Operation)
    • min_arg (float | None)
    • max_arg (float | None)
    • unit_attrs (List[str] | None)
  • Return type: *OpView*

clamp_tensor(in0, in1, in2, in3, unit_attrs=None)

  • Parameters:
    • in0 (Value | OpView | Operation)
    • in1 (Value | OpView | Operation)
    • in2 (Value | OpView | Operation)
    • in3 (Value | OpView | Operation)
    • unit_attrs (List[str] | None)
  • Return type: *OpView*

collective_permute(input, source_target_pairs)

Creates ttir.collective_permute.

Collective permute operation.

Collective permute op. Takes a tensor whose shards are spread across multiple devices and moves each device's shard according to source_target_pairs, a list of [source, destination] device ID pairs.

Example

For a 1x2 mesh, the following pairs move the shard on device 0 to device 1 and the shard on device 1 to device 0:

%source_target_pairs: [[0, 1], [1, 0]]

If a device is not the destination of any pair, its shard is filled with zeros. For example, with the pairs below, device 1 receives device 0's shard, while device 0 receives nothing and its shard contains only 0 values:

%source_target_pairs: [[0, 1]]

  • Parameters:
    • input (Operand) – The input tensor to be permuted
    • source_target_pairs (List[Tuple[int, int]]) – List of (source, target) device ID pairs
  • Return type: *OpView*
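A hypothetical host-side simulation of the data movement, treating each device's shard as a flat list. It follows the rules stated above: each pair copies the source shard to the destination, and a device that is never a destination ends up with zeros. `simulate_collective_permute` is an illustrative name, not part of ttir_builder.

```python
def simulate_collective_permute(shards, source_target_pairs):
    """shards: dict device_id -> list of values. Returns the permuted shards."""
    # Start from zero-filled shards; devices that receive nothing keep them.
    result = {device: [0] * len(shard) for device, shard in shards.items()}
    for src, dest in source_target_pairs:
        result[dest] = list(shards[src])
    return result


shards = {0: [1, 2], 1: [3, 4]}
print(simulate_collective_permute(shards, [[0, 1], [1, 0]]))  # {0: [3, 4], 1: [1, 2]}
print(simulate_collective_permute(shards, [[0, 1]]))          # {0: [0, 0], 1: [1, 2]}
```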

concat(ins, dim=0, unit_attrs=None)

Creates ttir.concat.

Tensor concatenation operation.

Concatenates the given sequence of tensors in the given dimension. All tensors must have the same shape, except in the concatenating dimension.

  • Parameters:
    • ins (List[Operand]) – List of input tensors to concatenate
    • dim (int, optional) – Dimension along which to concatenate (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Concatenated tensor
  • Return type: *OpView*
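A pure-Python sketch of the shape rule for 2-D inputs: concatenation along one dimension requires every other dimension to match. `concat_2d` is a hypothetical helper restricted to rank 2 for brevity.

```python
def concat_2d(tensors, dim=0):
    """Concatenate 2-D nested lists along dim 0 (stack rows) or dim 1 (join columns)."""
    if dim == 0:
        width = len(tensors[0][0])
        if any(len(row) != width for t in tensors for row in t):
            raise ValueError("all tensors must have the same number of columns")
        return [list(row) for t in tensors for row in t]
    height = len(tensors[0])
    if any(len(t) != height for t in tensors):
        raise ValueError("all tensors must have the same number of rows")
    return [[x for t in tensors for x in t[r]] for r in range(height)]


a = [[1, 2], [3, 4]]
b = [[5, 6]]
print(concat_2d([a, b], dim=0))           # [[1, 2], [3, 4], [5, 6]]
print(concat_2d([a, [[9], [9]]], dim=1))  # [[1, 2, 9], [3, 4, 9]]
```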

conv2d(in0, weight, bias, in1, stride, padding, dilation, groups, unit_attrs=None)

Creates ttir.conv2d.

Conv2d operation.

Applies a 2D convolution over an input image composed of several input planes. This operation performs a 2D convolution on the input tensor using the provided weight tensor and optional bias. It supports configurable stride, padding, dilation, and grouping parameters.

// Basic 2D convolution
%input = ... : tensor<1x28x28x3xf32>    // Batch size 1, 28x28 image, 3 channels
%weight = ... : tensor<16x3x3x3xf32>    // 16 output channels, 3 input channels, 3x3 kernel
%bias = ... : tensor<1x1x1x16xf32>      // Bias for 16 output channels
%output = ttir.empty() : tensor<1x26x26x16xf32>  // Output shape with no padding
%result = ttir.conv2d(%input, %weight, %bias, %output) {
    stride = [1, 1],
    padding = [0, 0, 0, 0],
    dilation = [1, 1],
    groups = 1
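} : tensor<1x28x28x3xf32>, tensor<16x3x3x3xf32>, tensor<1x1x1x16xf32>, tensor<1x26x26x16xf32> -> tensor<1x26x26x16xf32>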
  • Parameters:
    • in0 (Operand) – Input tensor in (N, H_in, W_in, C) format
    • weight (Operand) – Weight tensor in (O, C/G, K_H, K_W) format
    • bias (Optional[Operand]) – Optional bias tensor in (1, 1, 1, O) format
    • in1 (Operand) – Output tensor specification
    • stride (Union[int, List[int]], optional) – Stride for height and width dimensions (default: 1)
    • padding (Union[int, List[int]], optional) – Padding for all sides or [top, left, bottom, right] (default: 0)
    • dilation (Union[int, List[int]], optional) – Spacing between kernel elements (default: 1)
    • groups (int, optional) – Number of blocked connections from input to output channels (default: 1)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Output tensor after convolution
  • Return type: *OpView*
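The output spatial size follows the standard convolution arithmetic (the same formula PyTorch documents for Conv2d), assuming ttir.conv2d uses it too. A small sketch for one spatial axis; `conv2d_output_size` is a hypothetical helper:

```python
def conv2d_output_size(size, kernel, stride=1, pad_before=0, pad_after=0, dilation=1):
    """Output length along one spatial axis of a 2-D convolution (floor division)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_before + pad_after - effective_kernel) // stride + 1


h_out = conv2d_output_size(28, 3)
w_out = conv2d_output_size(28, 3)
print(h_out, w_out)  # 26 26
```

Applied to the 28x28 input with a 3x3 kernel, stride 1, no padding, and dilation 1, this gives the 26x26 output used in the example.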

conv_transpose2d(in0, weight, bias, in1, stride, padding, output_padding, dilation, groups, unit_attrs=None)

Creates ttir.conv_transpose2d.

2D transposed convolution operation.

Applies a 2D transposed convolution over an input image. This operation can be seen as the gradient of Conv2d with respect to its input. Also known as a deconvolution or fractionally strided convolution.

// Apply 2D transposed convolution
%result = ttir.conv_transpose2d(%input, %weight, %bias, %output,
                              stride = [2, 2], padding = [1, 1],
                              output_padding = [1, 1], dilation = [1, 1],
                              groups = 1) :
    tensor<1x1x4x4xf32>, tensor<1x1x3x3xf32>, tensor<1xf32>, tensor<1x1x8x8xf32> -> tensor<1x1x8x8xf32>
// Input tensor: 4x4 feature map
// Weight tensor: 3x3 kernel
// Output tensor: 8x8 upsampled feature map, since (4 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 8
  • Parameters:
    • in0 (Operand) – Input tensor of shape (batch, in_channels, height, width)
    • weight (Operand) – Weight tensor of shape (in_channels, out_channels/groups, kernel_height, kernel_width)
    • bias (Optional[Operand]) – Optional bias tensor of shape (out_channels)
    • in1 (Operand) – Output tensor shape reference
    • stride (Union[int, List[int]]) – Stride of the convolution
    • padding (Union[int, List[int]]) – Padding added to input
    • output_padding (Union[int, List[int]]) – Additional size added to output shape
    • dilation (Union[int, List[int]]) – Dilation of the kernel
    • groups (int) – Number of blocked connections from input to output channels
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The output tensor after transposed convolution
  • Return type: *OpView*
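The output size of a transposed convolution along each spatial axis is usually computed as (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1, the formula PyTorch documents for ConvTranspose2d; this sketch assumes ttir.conv_transpose2d follows the same rule. `conv_transpose2d_output_size` is a hypothetical helper:

```python
def conv_transpose2d_output_size(size, kernel, stride=1, padding=0,
                                 output_padding=0, dilation=1):
    """Output length along one spatial axis of a 2-D transposed convolution."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


print(conv_transpose2d_output_size(5, 3, stride=2, padding=1, output_padding=1))  # 10
```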

cos(in0, unit_attrs=None)

Creates ttir.cos.

Elementwise cosine operation.

Computes the cosine of each element in the input tensor. Input values are expected to be in radians.

// Compute cosine of all elements
%result = ttir.cos(%input, %output) : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensor (in radians):
// [0.0, 3.14159, 1.5708, -1.5708]
// Output tensor:
// [1.0, -1.0, 0.0, 0.0]
  • Parameters:
    • in0 (Operand) – Input tensor (values in radians)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the cosine of each element in the input tensor
  • Return type: *OpView*

cumsum(in0, in1, dim, unit_attrs=None)

Creates ttir.cumsum.

Cumulative sum operation.

Computes the cumulative sum of elements along a specified dimension. For each element at index i in the dimension, computes the sum of all elements with indices ≤ i in that dimension.

// Compute cumulative sum along dimension 1
%result = ttir.cumsum(%input, %output, dim = 1) : tensor<2x3xf32>, tensor<2x3xf32> -> tensor<2x3xf32>
// Input tensor:
// [[1.0, 2.0, 3.0],
//  [4.0, 5.0, 6.0]]
// Output tensor:
// [[1.0, 3.0, 6.0],
//  [4.0, 9.0, 15.0]]
  • Parameters:
    • in0 (Operand) – Input tensor
    • in1 (Operand) – Output tensor
    • dim (int) – Dimension along which to compute cumulative sum
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the cumulative sums along the specified dimension
  • Return type: *OpView*
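A reference sketch of cumsum on a 2-D nested list using `itertools.accumulate`; `cumsum_2d` is a hypothetical helper that handles dim 0 and dim 1 only.

```python
from itertools import accumulate


def cumsum_2d(matrix, dim):
    """Cumulative sum along dim 1 (across each row) or dim 0 (down each column)."""
    if dim == 1:
        return [list(accumulate(row)) for row in matrix]
    # dim == 0: accumulate down each column, then transpose back to rows.
    columns = [list(accumulate(col)) for col in zip(*matrix)]
    return [list(row) for row in zip(*columns)]


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(cumsum_2d(x, dim=1))  # [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
print(cumsum_2d(x, dim=0))  # [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
```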

dequantize(in0, scale, zero_point, dtype, unit_attrs=None)

Creates ttir.dequantize.

Dequantize integer tensor to floating-point tensor.

Converts a quantized integer tensor back into a floating-point tensor using the specified scale and zero_point parameters. For each element in the input tensor, computes: output[i] = (input[i] - zero_point) * scale

// Dequantize uint8 tensor to float32
%result = ttir.dequantize(%input, %output) {scale = 0.1 : f32, zero_point = 128 : i32} : tensor<2x2xui8>, tensor<2x2xf32> -> tensor<2x2xf32>
// Input tensor:
// [[143, 126],
//  [128, 165]]
// Output tensor:
// [[1.5, -0.2],
//  [0.0, 3.7]]
  • Parameters:
    • in0 (Operand) – Input quantized integer tensor to be dequantized
    • scale (float) – Scale factor used in the original quantization
    • zero_point (int) – Integer value that represents 0.0 in the quantized space
    • dtype (torch.dtype) – Target floating-point data type (e.g., torch.float32)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The dequantized floating-point tensor
  • Return type: *OpView*
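A sketch of the affine dequantization formula above. Since the codes in the example (up to 165) and the zero point 128 lie outside the signed 8-bit range, they are treated here as unsigned 8-bit values; `dequantize` is an illustrative pure-Python helper, not the ttir_builder method.

```python
def dequantize(values, scale, zero_point):
    """Affine dequantization (q - zero_point) * scale, applied elementwise."""
    return [[(q - zero_point) * scale for q in row] for row in values]


quantized = [[143, 126], [128, 165]]  # unsigned 8-bit codes
result = dequantize(quantized, scale=0.1, zero_point=128)
print([[round(v, 4) for v in row] for row in result])  # [[1.5, -0.2], [0.0, 3.7]]
```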

div(in0, in1, unit_attrs=None)

Creates ttir.div.

Elementwise division operation.

Performs elementwise division between two tensors. For each pair of corresponding elements, divides the element in the first tensor by the element in the second tensor.

Note: Division by zero behavior depends on the implementation and data type.

Mathematical definition: div(x, y) = x / y

// Divide corresponding elements
%result = ttir.div(%lhs, %rhs, %output) : tensor<3xf32>, tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensors:
// lhs: [3.5, 0.0, -1.2]
// rhs: [1.5, 2.0, -3.2]
// Output tensor:
// [2.333, 0.0, 0.375]
  • Parameters:
    • in0 (Operand) – First input tensor (dividend)
    • in1 (Operand) – Second input tensor (divisor)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the elementwise quotient of the inputs
  • Return type: *OpView*

dot_general(in0, in1, out0, batch_dims_lhs, contract_dims_lhs, batch_dims_rhs, contract_dims_rhs, unit_attrs=None)

Creates ttir.dot_general.

Generalized dot product operation.

Flexible tensor operation that generalizes matrix multiplication by letting the user specify which dimensions of the two tensors to contract (sum over) and which to treat as batch dimensions (matched pairwise and kept in the output). Matrix multiplication is a special case of this operation, in which the contraction happens along the last axis of the first tensor and the second-to-last axis of the second tensor. Based on the StableHLO DotGeneral op: https://openxla.org/stablehlo/spec#dot_general

  • Parameters:
    • in0 (Operand) – Left-hand side input tensor
    • in1 (Operand) – Right-hand side input tensor
    • out0 (Operand) – Output tensor
    • batch_dims_lhs (List[int]) – Batch dimensions for the left-hand side tensor
    • contract_dims_lhs (List[int]) – Contracting dimensions for the left-hand side tensor
    • batch_dims_rhs (List[int]) – Batch dimensions for the right-hand side tensor
    • contract_dims_rhs (List[int]) – Contracting dimensions for the right-hand side tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*
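A slow but explicit pure-Python reference for the contraction rule, assuming StableHLO's output layout (batch dimensions first, then the free dimensions of the left operand, then those of the right). It is a sketch for understanding the semantics; the function and helper names are hypothetical and unrelated to ttir_builder.

```python
from itertools import product


def shape_of(t):
    """Shape of a rectangular nested list (a scalar has shape [])."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return dims


def element_at(t, index):
    """Read one element of a nested list given a full index."""
    for i in index:
        t = t[i]
    return t


def build(shape, fn, prefix=()):
    """Create a nested list of the given shape, filling position p with fn(p)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def dot_general(lhs, rhs, batch_lhs, contract_lhs, batch_rhs, contract_rhs):
    ls, rs = shape_of(lhs), shape_of(rhs)
    free_lhs = [d for d in range(len(ls)) if d not in batch_lhs + contract_lhs]
    free_rhs = [d for d in range(len(rs)) if d not in batch_rhs + contract_rhs]
    # Output layout (StableHLO convention): batch, lhs free, then rhs free dims.
    out_shape = ([ls[d] for d in batch_lhs] + [ls[d] for d in free_lhs]
                 + [rs[d] for d in free_rhs])
    nb, nf = len(batch_lhs), len(free_lhs)

    def compute(out_index):
        batch, f_lhs, f_rhs = out_index[:nb], out_index[nb:nb + nf], out_index[nb + nf:]
        total = 0
        # Sum over every combination of contracting-dimension indices.
        for c in product(*(range(ls[d]) for d in contract_lhs)):
            li, ri = [0] * len(ls), [0] * len(rs)
            for dims, vals, idx in ((batch_lhs, batch, li), (free_lhs, f_lhs, li),
                                    (contract_lhs, c, li), (batch_rhs, batch, ri),
                                    (free_rhs, f_rhs, ri), (contract_rhs, c, ri)):
                for d, v in zip(dims, vals):
                    idx[d] = v
            total += element_at(lhs, li) * element_at(rhs, ri)
        return total

    return build(out_shape, compute)


# Plain matrix multiplication: contract lhs dim 1 with rhs dim 0, no batch dims.
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(dot_general(a, b, [], [1], [], [0]))  # [[19, 22], [43, 50]]
```

For a batched matrix multiplication, pass the batch dimension in both batch lists, e.g. `dot_general(x, y, [0], [2], [0], [1])` for shapes (B, M, K) and (B, K, N).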

embedding(in0, in1, unit_attrs=None)

Creates ttir.embedding.

Embedding lookup operation.

Performs a lookup in an embedding table (in1) using indices (in0). Returns a tensor containing the embeddings for the given indices.

// Lookup embeddings for indices
%result = ttir.embedding(%indices, %weights, %output) : tensor<2xi32>, tensor<4x3xf32>, tensor<2x3xf32> -> tensor<2x3xf32>
// Indices tensor:
// [1, 3]  // Looking up embeddings at indices 1 and 3
// Weights tensor (embedding table):
// [[0.1, 0.2, 0.3],  // embedding 0
//  [0.4, 0.5, 0.6],  // embedding 1
//  [0.7, 0.8, 0.9],  // embedding 2
//  [1.0, 1.1, 1.2]]  // embedding 3
// Output tensor:
// [[0.4, 0.5, 0.6],  // embedding for index 1
//  [1.0, 1.1, 1.2]]  // embedding for index 3
  • Parameters:
    • in0 (Operand) – Input tensor containing indices
    • in1 (Operand) – Weight tensor containing embeddings
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the embeddings for the input indices
  • Return type: *OpView*
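The lookup amounts to selecting one row of the table per index. A minimal sketch with nested lists; `embedding` here is an illustrative helper, not the ttir_builder method:

```python
def embedding(indices, weights):
    """Return a copy of the embedding-table row for each index."""
    return [list(weights[i]) for i in indices]


table = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
print(embedding([1, 3], table))  # [[0.4, 0.5, 0.6], [1.0, 1.1, 1.2]]
```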

eq(in0, in1, unit_attrs=None)

Creates ttir.eq.

Elementwise equality comparison operation.

Performs an elementwise equality comparison between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if the elements are equal
  • 0 (false) if the elements are not equal

Note that special handling may be required for floating-point NaN values, as NaN is not equal to any value, including itself.

Mathematical definition: equal(x, y) = x == y

// Compare elements for equality
%result = ttir.eq(%lhs, %rhs, %output) : tensor<3xf32>, tensor<3xf32>, tensor<3xi1> -> tensor<3xi1>
// Input tensors:
// lhs: [1.0, 2.0, 3.0]
// rhs: [1.0, 2.0, 4.0]
// Output tensor:
// [1, 1, 0]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

exp(in0, unit_attrs=None)

Creates ttir.exp.

Elementwise exponential operation.

Computes the exponential of each element in the input tensor. For each element x, returns e^x, where e is Euler’s number (approximately 2.71828).

// Compute exponential of all elements
%result = ttir.exp(%input, %output) : tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensor:
// [0.0, 1.0, 2.0]
// Output tensor:
// [1.0, 2.71828, 7.38906]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the exponential of each element in the input tensor
  • Return type: *OpView*

expm1(in0, unit_attrs=None)

Creates ttir.expm1.

Elementwise exponential minus one operation.

Computes e^x - 1 for each element in the input tensor, where e is Euler’s number. This operation provides better numerical precision than computing exp(x) - 1 directly, especially for small values of x.

// Compute exp(x) - 1 for all elements
%result = ttir.expm1(%input, %output) : tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensor:
// [0.0, 0.1, -0.1]
// Output tensor:
// [0.0, 0.10517, -0.09516]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing exp(x) - 1 for each element x in the input tensor
  • Return type: *OpView*
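The precision benefit is easy to observe in double precision with Python's standard `math.expm1`: for a tiny x, `math.exp(x)` is first rounded to a double close to 1.0, so subtracting 1 keeps only a few correct digits. A short demonstration:

```python
import math

x = 1e-10
naive = math.exp(x) - 1.0   # exp(x) is rounded to a double near 1.0 before subtracting
accurate = math.expm1(x)    # computes e**x - 1 without forming e**x first
reference = x + x * x / 2   # Taylor series x + x**2/2 + ...; later terms are negligible here

print(f"naive:    {naive!r}")
print(f"accurate: {accurate!r}")
print(f"relative error of naive result: {abs(naive - reference) / reference:.1e}")
```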

fill_cache(in0, in1, batch_offset=0, unit_attrs=None)

Creates ttir.fill_cache.

Cache fill operation.

Fills a cache tensor with new values starting at a specified batch offset. This operation is typically used in sequence models to initialize or update cached states.

// Fill cache with new values at batch offset 1
%result = ttir.fill_cache(%new_values, %cache, batch_offset = 1) : tensor<2x3xf32>, tensor<4x3xf32> -> tensor<4x3xf32>
// New values tensor:
// [[1.0, 2.0, 3.0],
//  [4.0, 5.0, 6.0]]
// Cache tensor before:
// [[0.1, 0.2, 0.3],
//  [0.4, 0.5, 0.6],
//  [0.7, 0.8, 0.9],
//  [1.0, 1.1, 1.2]]
// Cache tensor after:
// [[0.1, 0.2, 0.3],
//  [1.0, 2.0, 3.0],
//  [4.0, 5.0, 6.0],
//  [1.0, 1.1, 1.2]]
  • Parameters:
    • in0 (Operand) – New values to fill into cache
    • in1 (Operand) – Cache tensor to be filled
    • batch_offset (int, optional) – Starting position in the batch dimension (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The updated cache tensor
  • Return type: *OpView*

floor(in0, unit_attrs=None)

Creates ttir.floor.

Elementwise floor operation.

Computes the floor of each element in the input tensor, rounding down to the nearest integer. This operation is idempotent.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with floor values
  • Return type: *OpView*
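Python's `math.floor` and `math.ceil` illustrate the rounding directions of the floor and ceil operations on a few values, including negatives, and their idempotence. These functions return Python ints; the element type of the TTIR result is not covered by this sketch.

```python
import math

values = [-1.5, -0.2, 0.0, 0.7, 2.0]
floors = [math.floor(v) for v in values]
ceils = [math.ceil(v) for v in values]
print(floors)  # [-2, -1, 0, 0, 2]
print(ceils)   # [-1, 0, 0, 1, 2]

# Idempotence: applying the operation twice gives the same result as once.
assert [math.floor(f) for f in floors] == floors
assert [math.ceil(c) for c in ceils] == ceils
```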

gather(input, start_indices, offset_dims, collapsed_slice_dims, operand_batching_dims, start_indices_batching_dims, start_index_map, index_vector_dim, slice_sizes, indices_are_sorted=False, unit_attrs=None)

Creates ttir.gather.

Gather operation.

Collects slices from an input tensor at positions specified by start indices. This operation is based on the StableHLO Gather operation and allows for flexible slicing and indexing of tensors. It can be used to implement operations like array indexing, slicing, dynamic indexing, and more complex gathering patterns.

// Basic gather example: gather elements from a 2D tensor using indices
%input = ... : tensor<5x3xf32>         // Input tensor with shape [5,3]
%indices = ... : tensor<1xi64>         // Index vector with value [2] (one entry per start_index_map entry)
%output = ttir.empty() : tensor<3xf32> // Output tensor
%result = ttir.gather(%input, %indices, %output) {
    offset_dims = [0],                 // Output dimensions that are gathered from input
    collapsed_slice_dims = [0],        // Input dimensions that are collapsed
    operand_batching_dims = [],        // Batch dimensions of the input
    start_indices_batching_dims = [],  // Batch dimensions of the indices
    start_index_map = [0],             // Maps indices to input dimensions
    index_vector_dim = 0,              // Which dimension of indices contains the index vector
    slice_sizes = [1, 3],              // Size of the slice to extract from each position
    indices_are_sorted = false         // Whether indices are sorted
} : tensor<5x3xf32>, tensor<1xi64>, tensor<3xf32> -> tensor<3xf32>
// %result holds row 2 of %input
  • Parameters:
    • input (Operand) – The tensor from which to gather values
    • start_indices (Operand) – Tensor containing the starting indices for slices
    • offset_dims (List[int]) – Output dimensions that correspond to dimensions of the gathered slice
    • collapsed_slice_dims (List[int]) – Input dimensions that are collapsed when gathering
    • operand_batching_dims (List[int]) – Batch dimensions of the input tensor
    • start_indices_batching_dims (List[int]) – Batch dimensions of the indices tensor
    • start_index_map (List[int]) – Maps index values to input dimensions
    • index_vector_dim (int) – Which dimension of indices contains the index vector
    • slice_sizes (List[int]) – Size of the slice to extract from each position
    • indices_are_sorted (bool, optional) – Whether indices are sorted (for optimization)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The gathered tensor
  • Return type: *OpView*

ge(in0, in1, unit_attrs=None)

Creates ttir.ge.

Elementwise greater than or equal to comparison operation.

Performs elementwise greater than or equal to comparison between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if the left element is greater than or equal to the right element
  • 0 (false) if the left element is less than the right element

Mathematical definition: greater_equal(x, y) = x >= y

// Compare elements for greater than or equal to
%result = ttir.ge(%lhs, %rhs, %output) : tensor<4xf32>, tensor<4xf32>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1.0, 2.0, 3.0, 2.0]
// rhs: [1.0, 2.0, 4.0, 5.0]
// Output tensor:
// [1, 1, 0, 0]
  • Parameters:
    • in0 (Operand) – First input tensor (left-hand side)
    • in1 (Operand) – Second input tensor (right-hand side)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

gelu(in0, unit_attrs=None)

Creates ttir.gelu.

Elementwise GELU operation.

Computes the GELU (Gaussian Error Linear Unit) of each element in the input tensor. GELU is a smooth, non-monotonic activation function that weights each input x by Φ(x), the cumulative distribution function of the standard normal distribution.

Mathematical definition: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

// Compute GELU of all elements
%result = ttir.gelu(%input, %output) : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensor:
// [1.0, -0.5, 2.0, -2.0]
// Output tensor:
// [0.841, -0.154, 1.954, -0.046]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the GELU values of each element in the input tensor
  • Return type: *OpView*
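A reference implementation of the exact erf-based definition above, using only the standard library. Some frameworks also offer a tanh-based approximation; which form ttir.gelu lowers to is not stated here.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


inputs = [1.0, -0.5, 2.0, -2.0]
print([round(gelu(v), 3) for v in inputs])  # [0.841, -0.154, 1.954, -0.046]
```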

get_dimension_size(in0, dimension=0, unit_attrs=None)

Creates ttir.get_dimension_size.

Dimension size query operation.

Produces the size of the given dimension of the operand.

%operand: [[3, 2, 7], [1, 4, 4]]   // shape 2x3, so dimension 0 has size 2
"ttir.get_dimension_size"(%operand, value = dense<0>, %out) -> %out: [[2]]
  • Parameters:
    • in0 (Operand) – Input tensor operand to get dimension size from
    • dimension (int, optional) – Index of the dimension whose size is returned (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

gt(in0, in1, unit_attrs=None)

Creates ttir.gt.

Elementwise greater than comparison operation.

Performs elementwise greater than comparison between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if the left element is greater than the right element
  • 0 (false) if the left element is less than or equal to the right element

Mathematical definition: greater(x, y) = x > y

// Compare elements for greater than
%result = ttir.gt(%lhs, %rhs, %output) : tensor<4xf32>, tensor<4xf32>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1.0, 2.0, 3.0, 2.0]
// rhs: [1.0, 1.0, 4.0, 5.0]
// Output tensor:
// [0, 1, 0, 0]
  • Parameters:
    • in0 (Operand) – First input tensor (left-hand side)
    • in1 (Operand) – Second input tensor (right-hand side)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

index(in0, dim, begin, end, step, unit_attrs=None)

Creates ttir.index.

Tensor indexing operation.

Indexes into the input tensor along the specified dimension using a range of indices.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim (int) – Dimension to index into
    • begin (int) – Starting index
    • end (int) – Ending index (exclusive)
    • step (int) – Step size between indices
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The indexed tensor
  • Return type: *OpView*
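A sketch of the indexing semantics on nested lists: along dimension dim it keeps positions begin, begin + step, ... up to but excluding end, which is Python slice notation. `index` here is a hypothetical recursive helper assuming non-negative begin/end and a positive step.

```python
def index(tensor, dim, begin, end, step):
    """Keep positions begin, begin + step, ... (< end) along dimension dim."""
    if dim == 0:
        return tensor[begin:end:step]
    # Recurse into each sub-tensor until the target dimension is reached.
    return [index(sub, dim - 1, begin, end, step) for sub in tensor]


x = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
print(index(x, dim=1, begin=1, end=5, step=2))  # [[1, 3], [6, 8]]
```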

is_finite(in0, unit_attrs=None)

Creates ttir.is_finite.

Elementwise finite check operation.

Checks if each element in the input tensor is finite (neither infinite nor NaN). For each element, returns a boolean value indicating whether the element is finite.

Mathematical definition: isfinite(x) = x ∈ ℝ

// Check if elements are finite
%result = ttir.is_finite(%input, %output) : tensor<4xf32>, tensor<4xi1> -> tensor<4xi1>
// Input tensor:
// [1.0, inf, -inf, nan]
// Output tensor:
// [true, false, false, false]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

le(in0, in1, unit_attrs=None)

Creates ttir.le.

Elementwise less than or equal to comparison operation.

Performs elementwise less than or equal to comparison between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if the left element is less than or equal to the right element
  • 0 (false) if the left element is greater than the right element

Mathematical definition: less_equal(x, y) = x <= y

// Compare elements for less than or equal to
%result = ttir.le(%lhs, %rhs, %output) : tensor<4xf32>, tensor<4xf32>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1.0, 2.0, 3.0, 2.0]
// rhs: [1.0, 2.0, 4.0, 5.0]
// Output tensor:
// [1, 1, 1, 1]
  • Parameters:
    • in0 (Operand) – First input tensor (left-hand side)
    • in1 (Operand) – Second input tensor (right-hand side)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

leaky_relu(in0, parameter=0.01, unit_attrs=None)

Creates ttir.leaky_relu.

Elementwise leaky ReLU activation operation.

Computes a leaky version of the Rectified Linear Unit (ReLU) activation function. For each element x in the input tensor:

  • If x > 0: returns x
  • If x ≤ 0: returns parameter * x

The parameter controls the slope for negative values, allowing a small gradient when the unit is not active.

// Compute leaky ReLU with slope 0.01 for negative values
%result = ttir.leaky_relu(%input, %output) : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensor:
// [2.0, -1.0, 0.0, -3.0]
// Output tensor:
// [2.0, -0.01, 0.0, -0.03]
  • Parameters:
    • in0 (Operand) – Input tensor to be activated
    • parameter (float, optional) – Slope for negative values (default: 0.01)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the leaky ReLU activation values
  • Return type: *OpView*
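A plain-Python reference model of the piecewise rule above can be handy for checking expected values such as those in the example. The helper name leaky_relu_ref is hypothetical and only mirrors the math, not the builder call.

```python
from typing import List


def leaky_relu_ref(xs: List[float], parameter: float = 0.01) -> List[float]:
    """Reference model: x if x > 0, otherwise parameter * x, applied elementwise."""
    return [x if x > 0 else parameter * x for x in xs]


out = leaky_relu_ref([2.0, -1.0, 0.0, -3.0])
print(out)  # approximately [2.0, -0.01, 0.0, -0.03]
```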

linear(in0, in1, bias=None, transpose_a=False, transpose_b=False, unit_attrs=None)

Creates ttir.linear.

Linear transformation operation.

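Applies a linear transformation: the matrix product of in0 and in1, plus an optional bias. The transpose_a and transpose_b flags transpose in0 and in1, respectively, before the multiplication, so with transpose_b=True this is the familiar y = xA^T + b, with A given by in1.

  • Parameters:
    • in0 (Operand) – Input tensor
    • in1 (Operand) – Weight matrix
    • bias (Optional[Operand], optional) – Bias vector (default: None)
    • transpose_a (bool, optional) – If True, transpose in0 before the multiplication (default: False)
    • transpose_b (bool, optional) – If True, transpose in1 before the multiplication (default: False)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes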
  • Returns: Output tensor after linear transformation
  • Return type: *OpView*
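The following plain-Python sketch shows how the transpose flags interact with the matrix product. It assumes that ttir.linear computes op(in0) @ op(in1) + bias, where op transposes its argument when the matching flag is set, and that a 1-D bias is added to every output row. The helpers transpose and linear_ref are hypothetical.

```python
from typing import List, Optional

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns of a 2-D nested list."""
    return [list(col) for col in zip(*m)]


def linear_ref(a: Matrix, b: Matrix, bias: Optional[List[float]] = None,
               transpose_a: bool = False, transpose_b: bool = False) -> Matrix:
    """Reference model (assumed semantics): op(a) @ op(b) + bias."""
    if transpose_a:
        a = transpose(a)
    if transpose_b:
        b = transpose(b)
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    out = [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
           for i in range(len(a))]
    if bias is not None:
        # Assumed: one bias value per output column, added to every row.
        out = [[v + bias[j] for j, v in enumerate(row)] for row in out]
    return out


x = [[1.0, 2.0]]                          # 1 x 2 input
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 x 2 weight (out_features x in_features)
print(linear_ref(x, w, bias=[0.5, 0.5, 0.5], transpose_b=True))  # [[1.5, 2.5, 3.5]]
```

Storing the weight as out_features x in_features and setting transpose_b=True reproduces the y = xA^T + b convention.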

log(in0, unit_attrs=None)

Creates ttir.log.

Elementwise natural logarithm operation.

Computes the natural logarithm of each element in the input tensor. For each element x, returns ln(x), where ln is the natural logarithm.

// Compute natural logarithm of all elements
%result = ttir.log(%input, %output) : tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensor:
// [1.0, 2.71828, 7.38906]
// Output tensor:
// [0.0, 1.0, 2.0]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the natural logarithm of each element in the input tensor
  • Return type: *OpView*

log1p(in0, unit_attrs=None)

Creates ttir.log1p.

Elementwise natural logarithm of one plus the input operation.

The log1p operation computes the natural logarithm of one plus each element in the input tensor. For each element x, it returns ln(1 + x). This operation is more accurate than computing log(1 + x) directly for x values close to zero, and it is defined for x > -1. For values less than or equal to -1, the behavior depends on the implementation (may return NaN or negative infinity).

// Compute log1p of all elements
%result = ttir.log1p(%input, %output) : tensor<5xf32>, tensor<5xf32> -> tensor<5xf32>
// Input tensor:
// [0.0, -0.999, 7.0, 6.38905621, 15.0]
// Output tensor:
// [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the log1p values of the input tensor
  • Return type: *OpView*
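Python's standard math module has the same pair of functions, which makes the accuracy claim easy to check. This snippet illustrates only the floating-point issue; it says nothing about how ttir.log1p itself is implemented.

```python
import math

x = 1e-17
naive = math.log(1.0 + x)  # 1.0 + 1e-17 rounds to exactly 1.0 in float64, so this is 0.0
accurate = math.log1p(x)   # avoids forming 1 + x, so the tiny value survives (about 1e-17)
print(naive, accurate)
```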

logical_and(in0, in1, unit_attrs=None)

Creates ttir.logical_and.

Elementwise logical AND operation.

Performs elementwise logical AND operation between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if both elements are 1 (true)
  • 0 (false) if at least one element is 0 (false)

This operation is idempotent, meaning logical_and(x, x) = x.

// Logical AND operation
%result = ttir.logical_and(%lhs, %rhs, %output) : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1, 0, 1, 0]
// rhs: [1, 1, 0, 1]
// Output tensor:
// [1, 0, 0, 0]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

logical_not(in0, unit_attrs=None)

Creates ttir.logical_not.

Elementwise logical NOT operation.

Computes the logical NOT of each element in the input tensor. For each element x, returns True if x is False, and False if x is True.

// Compute logical NOT of all elements
%result = ttir.logical_not(%input, %output) : tensor<3xi1>, tensor<3xi1> -> tensor<3xi1>
// Input tensor:
// [true, false, true]
// Output tensor:
// [false, true, false]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the logical NOT of each element in the input tensor
  • Return type: *OpView*

logical_or(in0, in1, unit_attrs=None)

Creates ttir.logical_or.

Elementwise logical OR operation.

Performs elementwise logical OR operation between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if at least one element is 1 (true)
  • 0 (false) if both elements are 0 (false)

This operation is idempotent, meaning logical_or(x, x) = x.

Mathematical definition: logical_or(x, y) = x || y

// Logical OR operation
%result = ttir.logical_or(%lhs, %rhs, %output) : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1, 0, 1, 0]
// rhs: [1, 1, 0, 1]
// Output tensor:
// [1, 1, 1, 1]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

logical_xor(in0, in1, unit_attrs=None)

Creates ttir.logical_xor.

Elementwise logical XOR operation.

Performs elementwise logical XOR (exclusive OR) operation between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if exactly one element is 1 (true)
  • 0 (false) if both elements are the same (both 0 or both 1)

Mathematical definition: logical_xor(x, y) = (x || y) && !(x && y)

// Logical XOR operation
%result = ttir.logical_xor(%lhs, %rhs, %output) : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1, 0, 1, 0]
// rhs: [1, 1, 0, 1]
// Output tensor:
// [0, 1, 1, 1]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*

lt(in0, in1, unit_attrs=None)

Creates ttir.lt.

Elementwise less than comparison operation.

The lt operation performs an elementwise less than comparison between two tensors. For each pair of corresponding elements, it returns:

  • 1 (true) if the left element is less than the right element
  • 0 (false) if the left element is greater than or equal to the right element

Mathematical definition: less(x, y) = x < y

// Compare elements for less than
%result = ttir.lt(%lhs, %rhs, %output) : tensor<4xf32>, tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensors:
// lhs: [1.0, 2.0, 3.0, 2.0]
// rhs: [1.0, 2.0, 4.0, 5.0]
// Output tensor:
// [0, 0, 1, 1]  // 1 where lhs < rhs, 0 otherwise
  • Parameters:
    • in0 (Operand) – First input tensor (left-hand side)
    • in1 (Operand) – Second input tensor (right-hand side)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A boolean tensor with 1s where left < right and 0s otherwise
  • Return type: *OpView*

matmul(in0, in1, bias=None, unit_attrs=None)

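Creates ttir.matmul.

Matrix multiplication operation.

Computes the matrix product of in0 and in1.

  • Parameters:
    • in0 (Operand) – First input tensor (left-hand side)
    • in1 (Operand) – Second input tensor (right-hand side)
    • bias (Optional[Operand], optional) – Optional bias operand (default: None)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*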

max(in0, dim_arg=None, keep_dim=True, unit_attrs=None)

Creates ttir.max.

Maximum reduction operation.

Returns the maximum values along the specified dimension.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (int, optional) – Dimension to reduce over (default: None, reduces over all dimensions)
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: True)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with maximum values
  • Return type: *OpView*

max_pool2d(in0, in1, kernel_height, kernel_width, stride_height, stride_width, dilation_height, dilation_width, ceil_mode, padding_left, padding_right, padding_top, padding_bottom, unit_attrs=None)

Creates ttir.max_pool2d.

Max pooling operation.

Applies a 2D max pooling over an input signal composed of several input planes.

  • Parameters:
    • in0 (Operand) – Input tensor
    • in1 (Value | OpView | Operation)
    • kernel_height (int) – Height of the pooling window
    • kernel_width (int) – Width of the pooling window
    • stride_height (int) – Vertical stride of the pooling window
    • stride_width (int) – Horizontal stride of the pooling window
    • dilation_height (int) – Vertical spacing between kernel elements
    • dilation_width (int) – Horizontal spacing between kernel elements
    • ceil_mode (bool) – When True, use ceil instead of floor to compute the output shape
    • padding_left (int) – Padding added to the left side of the input
    • padding_right (int) – Padding added to the right side of the input
    • padding_top (int) – Padding added to the top of the input
    • padding_bottom (int) – Padding added to the bottom of the input
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Output tensor after max pooling
  • Return type: *OpView*
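ceil_mode only affects how the output shape is computed. The sketch below computes the output length along one spatial axis using the formula PyTorch uses for MaxPool2d; that ttir.max_pool2d follows exactly the same rule is an assumption. Call it once with the height parameters (kernel_height, stride_height, dilation_height, padding_top, padding_bottom) and once with the width parameters. pool_output_size is a hypothetical helper.

```python
def pool_output_size(in_size: int, kernel: int, stride: int, dilation: int,
                     pad_before: int, pad_after: int, ceil_mode: bool) -> int:
    """Output length of a pooling window sliding along one axis (PyTorch-style rule, assumed)."""
    # Number of input positions spanned by one dilated window.
    effective_kernel = dilation * (kernel - 1) + 1
    span = in_size + pad_before + pad_after - effective_kernel
    if not ceil_mode:
        return span // stride + 1
    out = (span + stride - 1) // stride + 1  # integer ceil division
    # PyTorch drops a last window that would start inside the trailing padding.
    if (out - 1) * stride >= in_size + pad_before:
        out -= 1
    return out


# A 2x2 window with stride 2 halves a 32x32 input along each axis.
h = pool_output_size(32, kernel=2, stride=2, dilation=1, pad_before=0, pad_after=0, ceil_mode=False)
w = pool_output_size(32, kernel=2, stride=2, dilation=1, pad_before=0, pad_after=0, ceil_mode=False)
print(h, w)  # 16 16
```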

maximum(in0, in1, unit_attrs=None)

Creates ttir.maximum.

Elementwise maximum operation.

Returns the element-wise maximum of two tensors.

  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with maximum values
  • Return type: *OpView*

mean(in0, dim_arg=[0], keep_dim=True, unit_attrs=None)

Creates ttir.mean.

Mean reduction operation.

Computes the mean of elements along the dimensions given by dim_arg. Note that this builder defaults dim_arg to [0], so calling it without dim_arg reduces only dimension 0, not all dimensions.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (List[int], optional) – Dimensions to reduce over (default: [0])
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: True)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with mean values
  • Return type: *OpView*
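To make the default dim_arg=[0] concrete, here is a plain-Python reference model of a mean over one axis of a 2-D input. It only illustrates the arithmetic and the keep_dim shapes; mean_2d is a hypothetical helper restricted to rank 2.

```python
from typing import List, Union

Matrix = List[List[float]]


def mean_2d(x: Matrix, dim: int = 0, keep_dim: bool = True) -> Union[Matrix, List[float]]:
    """Reference model of a mean reduction over one axis of a 2-D nested list."""
    if dim == 0:
        # Average down each column: one result per column.
        means = [sum(col) / len(col) for col in zip(*x)]
        return [means] if keep_dim else means
    # dim == 1: average across each row, one result per row.
    means = [sum(row) / len(row) for row in x]
    return [[m] for m in means] if keep_dim else means


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(mean_2d(x))                         # [[2.5, 3.5, 4.5]]  (dim 0, keep_dim=True)
print(mean_2d(x, dim=1, keep_dim=False))  # [2.0, 5.0]
```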

mesh_shard(input, shard_type, shard_direction, shard_shape, shard_dims)

Creates ttir.mesh_shard.

Shard a tensor across a device mesh.

Distributes a tensor across multiple devices in a mesh according to the specified sharding configuration. The sharding can be performed along one or more dimensions of the tensor.

// Shard a tensor across a 2x2 device mesh
%result = ttir.mesh_shard(%input) {
    shard_type = "block",
    shard_direction = "row",
    shard_shape = [2, 2],
    shard_dims = [0, 1]
} : tensor<128x128xf32> -> tensor<64x64xf32>
// Input tensor on single device:
// [[1.0, 2.0, ...],
//  [3.0, 4.0, ...]]
// Output tensor sharded across devices:
// Device 0: [[1.0, 2.0], [3.0, 4.0]]
// Device 1: [[1.1, 2.1], [3.1, 4.1]]
// Device 2: [[1.2, 2.2], [3.2, 4.2]]
// Device 3: [[1.3, 2.3], [3.3, 4.3]]
  • Parameters:
    • input (Operand) – Input tensor to be sharded
    • shard_type (str) – Type of sharding (e.g., “block”, “cyclic”)
    • shard_direction (str) – Direction of sharding (e.g., “row”, “col”)
    • shard_shape (Tuple[int, ...]) – Shape of the device mesh
    • shard_dims (Tuple[int, ...]) – Tensor dimensions to shard along
  • Return type: *OpView*

min(in0, dim_arg=None, keep_dim=True, unit_attrs=None)

Creates ttir.min.

Minimum reduction operation.

Returns the minimum values along the specified dimension.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (int, optional) – Dimension to reduce over (default: None, reduces over all dimensions)
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: True)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with minimum values
  • Return type: *OpView*

minimum(in0, in1, unit_attrs=None)

Creates ttir.minimum.

Elementwise minimum operation.

Returns the element-wise minimum of two tensors. This operation is idempotent and partially broadcastable.

  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with minimum values
  • Return type: *OpView*

multiply(in0, in1, unit_attrs=None)

Creates ttir.multiply.

Elementwise multiplication operation.

Performs elementwise multiplication between two tensors. For each pair of corresponding elements, multiplies the element in the first tensor by the element in the second tensor.

Mathematical definition: multiply(x, y) = x * y

// Multiply corresponding elements
%result = ttir.multiply(%lhs, %rhs, %output) : tensor<3xf32>, tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensors:
// lhs: [3.5, 0.0, -1.2]
// rhs: [1.5, 2.0, -3.2]
// Output tensor:
// [5.25, 0.0, 3.84]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the elementwise product of the inputs
  • Return type: *OpView*

ne(in0, in1, unit_attrs=None)

Creates ttir.ne.

Elementwise inequality comparison operation.

Performs elementwise inequality comparison between two tensors. For each pair of corresponding elements, returns:

  • 1 (true) if the elements are not equal
  • 0 (false) if the elements are equal

Note: Special handling may be required for floating-point NaN values, as NaN is not equal to any value, including itself. This means ne(NaN, NaN) should return true.

Mathematical definition: not_equal(x, y) = x != y

// Compare elements for inequality
%result = ttir.ne(%lhs, %rhs, %output) : tensor<4xf32>, tensor<4xf32>, tensor<4xi1> -> tensor<4xi1>
// Input tensors:
// lhs: [1.0, 2.0, 3.0, 2.0]
// rhs: [1.0, 2.0, 4.0, 5.0]
// Output tensor:
// [0, 0, 1, 1]
  • Parameters:
    • in0 (Operand) – First input tensor
    • in1 (Operand) – Second input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Return type: *OpView*
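Python floats follow IEEE 754 comparison rules, so a one-line reference model reproduces both the example above and the NaN behavior described in the note. ne_ref is a hypothetical helper, not the builder method.

```python
import math
from typing import List


def ne_ref(lhs: List[float], rhs: List[float]) -> List[int]:
    """Reference model: 1 where elements differ, 0 where they are equal (IEEE 754 semantics)."""
    return [int(a != b) for a, b in zip(lhs, rhs)]


nan = math.nan
print(ne_ref([1.0, 2.0, 3.0, 2.0], [1.0, 2.0, 4.0, 5.0]))  # [0, 0, 1, 1]
print(ne_ref([nan], [nan]))                                # [1], since NaN != NaN
```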

neg(in0, unit_attrs=None)

Creates ttir.neg.

Elementwise negate operation.

Computes the negation of each element in the input tensor. For each element, returns the negation of the value.

Mathematical definition: neg(x) = -x

// Compute negation of all elements
%result = ttir.neg(%input, %output) : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>
// Input tensor:
// [1.7, 2.0, -0.3, 4.5]
// Output tensor:
// [-1.7, -2.0, 0.3, -4.5]
  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the negation of each input element
  • Return type: *OpView*

ones(shape, unit_attrs=None)

Creates ttir.ones.

Creates a tensor filled with ones.

Returns a tensor of given shape filled with ones.

  • Parameters:
    • shape (Shape) – Shape of the output tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor of ones with specified shape
  • Return type: *OpView*

pad(in0, in1, padding, value, unit_attrs=None)

Creates ttir.pad.

Tensor padding operation.

Pads a tensor with a constant value. The padding amount is specified for each dimension and can be asymmetric (different padding at the start and end of each dimension).

  • Parameters:
    • in0 (Operand) – Input tensor to pad
    • in1 (Operand) – Output tensor
    • padding (List[int]) – Amount of padding for each dimension
    • value (int) – Value to use for padding
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The padded tensor
  • Return type: *OpView*

permute(in0, in1, permutation, unit_attrs=None)

Creates ttir.permute.

Tensor permutation operation.

Permutes the dimensions of the input tensor according to the given permutation.

  • Parameters:
    • in0 (Operand) – Input tensor
    • in1 (Value | OpView | Operation)
    • permutation (List[int]) – The desired ordering of dimensions
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with permuted dimensions
  • Return type: *OpView*

pow(in0, in1, unit_attrs=None)

Creates ttir.pow.

Elementwise power operation.

Takes the first tensor to the power of the second tensor element-wise.

  • Parameters:
    • in0 (Operand) – First input tensor (base)
    • in1 (Operand) – Second input tensor (exponent)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with power values
  • Return type: *OpView*

prod(in0, dim_arg, keep_dim=False, unit_attrs=None)

Creates ttir.prod.

Product reduction operation.

Computes the product of elements along specified dimensions.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (List[int]) – Dimensions to reduce over
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: False)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with product values
  • Return type: *OpView*

quantize(in0, scale, zero_point, dtype, unit_attrs=None)

Creates ttir.quantize.

Quantize floating-point tensor to integer tensor.

Converts a floating-point tensor into a quantized integer tensor using the specified scale and zero_point parameters. For each element in the input tensor, computes: output[i] = round(input[i] / scale) + zero_point

// Quantize float32 tensor to uint8
%result = ttir.quantize(%input, %output) {scale = 0.1 : f32, zero_point = 128 : i32} : tensor<2x2xf32>, tensor<2x2xui8> -> tensor<2x2xui8>
// Input tensor:
// [[1.5, -0.2],
//  [0.0, 3.7]]
// Output tensor:
// [[143, 126],
//  [128, 165]]
  • Parameters:
    • in0 (Operand) – Input floating-point tensor to be quantized
    • scale (float) – Scale factor for quantization (each integer step represents this value)
    • zero_point (int) – Integer value that represents 0.0 in the quantized space
    • dtype (torch.dtype) – Target integer data type for quantization (e.g., torch.int8)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The quantized integer tensor
  • Return type: *OpView*
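The arithmetic in the example can be checked with a small plain-Python reference model. It assumes ties round away from zero and that results are clamped to the unsigned 8-bit range [0, 255]; neither the rounding mode nor the clamping behavior of ttir.quantize is specified above, and quantize_ref is a hypothetical helper.

```python
import math
from typing import List


def round_half_away(v: float) -> int:
    """Round to the nearest integer, sending ties away from zero."""
    return int(math.copysign(math.floor(abs(v) + 0.5), v))


def quantize_ref(xs: List[float], scale: float, zero_point: int,
                 qmin: int = 0, qmax: int = 255) -> List[int]:
    """Reference model: round(x / scale) + zero_point, clamped to [qmin, qmax] (assumed)."""
    return [min(qmax, max(qmin, round_half_away(x / scale) + zero_point)) for x in xs]


print(quantize_ref([1.5, -0.2, 0.0, 3.7], scale=0.1, zero_point=128))  # [143, 126, 128, 165]
```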

reciprocal(in0, unit_attrs=None)

Creates ttir.reciprocal.

Elementwise reciprocal operation.

Computes the reciprocal (1/x) of each element in the input tensor. This operation is involutive (applying it twice returns the original value).

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with reciprocal values
  • Return type: *OpView*

reduce_scatter(input, reduce_type, scatter_dim, cluster_axis)

Creates ttir.reduce_scatter.

Reduce scatter operation.

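Reduces the input tensor across the devices along cluster_axis of the device mesh using the reduction given by reduce_type, then scatters the result so that each device receives one slice of it along scatter_dim.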

  • Parameters:
    • input (Operand) – Input tensor to be reduced and scattered
    • reduce_type (str) – Type of reduction operation (e.g., “sum”, “max”)
    • scatter_dim (int) – Dimension along which to scatter the reduced results
    • cluster_axis (int) – Axis of device mesh for reduction (0 or 1)
  • Return type: *OpView*

relu(in0, unit_attrs=None)

Creates ttir.relu.

Elementwise ReLU activation operation.

Computes the Rectified Linear Unit function, relu(x) = max(0, x), for each element in the input tensor. This operation is idempotent (applying it multiple times has the same effect as applying it once).

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with ReLU activation values
  • Return type: *OpView*

remainder(in0, in1, unit_attrs=None)

Creates ttir.remainder.

Elementwise remainder operation.

Computes the element-wise remainder of division (modulo operation).

  • Parameters:
    • in0 (Operand) – First input tensor (dividend)
    • in1 (Operand) – Second input tensor (divisor)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with remainder values
  • Return type: *OpView*

repeat(in0, dims, unit_attrs=None)

Creates ttir.repeat.

Tensor repeat operation.

Repeats the tensor along each dimension the number of times given by dims.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dims (List[int]) – Number of repetitions for each dimension
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with repeated elements
  • Return type: *OpView*

repeat_interleave(in0, in1, repeats, dim, unit_attrs=None)

Creates ttir.repeat_interleave.

Tensor repeat interleave operation.

Repeats elements of a tensor along a dimension by interleaving the repeated elements.

  • Parameters:
    • in0 (Operand) – Input tensor
    • in1 (Value | OpView | Operation)
    • repeats (int) – Number of repetitions for each element
    • dim (int) – Dimension along which to repeat
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with interleaved repeated elements
  • Return type: *OpView*
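The difference between repeat and repeat_interleave is easiest to see on a 1-D input. This sketch assumes the two ops follow the same conventions as PyTorch's Tensor.repeat and torch.repeat_interleave; both helpers are hypothetical and 1-D only.

```python
from typing import List


def repeat_1d(xs: List[int], times: int) -> List[int]:
    """Tile the whole sequence `times` times (repeat)."""
    return xs * times


def repeat_interleave_1d(xs: List[int], repeats: int) -> List[int]:
    """Repeat each element `repeats` times in place (repeat_interleave)."""
    return [x for x in xs for _ in range(repeats)]


print(repeat_1d([1, 2, 3], 2))             # [1, 2, 3, 1, 2, 3]
print(repeat_interleave_1d([1, 2, 3], 2))  # [1, 1, 2, 2, 3, 3]
```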

requantize(in0, scale, zero_point, dtype, unit_attrs=None)

Creates ttir.requantize.

Requantize integer tensor to new scale and zero-point.

Converts a quantized integer tensor from one quantization scheme to another using new scale and zero-point parameters. For each element in the input tensor, computes: output[i] = round((input[i] - input_zero_point) * (input_scale / output_scale)) + output_zero_point, where input_scale and input_zero_point describe the input's existing quantization, and output_scale and output_zero_point are the new scale and zero_point.

// Requantize uint8 tensor to new scale and zero-point
%result = ttir.requantize(%input, %output) {scale = 0.2 : f32, zero_point = 100 : i32} : tensor<2x2xui8>, tensor<2x2xui8> -> tensor<2x2xui8>
// Input tensor (scale=0.1, zero_point=128):
// [[143, 126],
//  [128, 165]]
// Output tensor (scale=0.2, zero_point=100, ties rounded away from zero):
// [[108, 99],
//  [100, 119]]
  • Parameters:
    • in0 (Operand) – Input quantized integer tensor to be requantized
    • scale (float) – New scale factor for requantization
    • zero_point (int) – New integer value that represents 0.0 in the quantized space
    • dtype (torch.dtype) – Target integer data type (e.g., torch.int8)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The requantized integer tensor with new scale and zero-point
  • Return type: *OpView*
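A plain-Python reference model of the requantization formula reproduces the example values. It assumes ties round away from zero and results are clamped to [0, 255]; the rounding mode of ttir.requantize is not stated above, and requantize_ref is a hypothetical helper.

```python
import math
from typing import List


def round_half_away(v: float) -> int:
    """Round to the nearest integer, sending ties away from zero."""
    return int(math.copysign(math.floor(abs(v) + 0.5), v))


def requantize_ref(qs: List[int], in_scale: float, in_zero_point: int,
                   out_scale: float, out_zero_point: int,
                   qmin: int = 0, qmax: int = 255) -> List[int]:
    """Reference model: re-express quantized values under a new scale and zero point."""
    ratio = in_scale / out_scale
    return [min(qmax, max(qmin, round_half_away((q - in_zero_point) * ratio) + out_zero_point))
            for q in qs]


print(requantize_ref([143, 126, 128, 165], 0.1, 128, 0.2, 100))  # [108, 99, 100, 119]
print(round(18.5))  # 18: Python's built-in round() breaks ties to even
```

The rounding mode matters here: with Python's built-in round(), which breaks ties to even, the last element would become 118 instead of 119.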

reshape(in0, shape, unit_attrs=None)

Creates ttir.reshape.

Tensor reshape operation.

The reshape operation changes the shape of a tensor without changing the data or number of elements. The total number of elements in the tensor must remain the same after reshaping. This operation is commonly used in neural networks to change the dimensionality of tensors between layers.

// Reshape a 2x3 tensor to a 1x6 tensor
%input = ... : tensor<2x3xf32>  // Input tensor with shape [2,3]
%output = ttir.empty() : tensor<1x6xf32>  // Output tensor with shape [1,6]
%result = ttir.reshape(%input, %output) {shape = [1, 6]} :
    tensor<2x3xf32>, tensor<1x6xf32> -> tensor<1x6xf32>
  • Parameters:
    • in0 (Operand) – Input tensor to reshape
    • shape (Shape) – The new shape for the tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The reshaped tensor
  • Return type: *OpView*

reverse(in0, dims, unit_attrs=None)

Creates ttir.reverse.

Tensor reverse operation.

Reverses the order of elements along specified dimensions. The input and output shapes must match.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dims (List[int]) – Dimensions to reverse
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with reversed elements
  • Return type: *OpView*

rsqrt(in0, unit_attrs=None)

Creates ttir.rsqrt.

Elementwise reciprocal square root operation.

Computes the reciprocal of the square root (1/√x) of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with reciprocal square root values
  • Return type: *OpView*

select(in0, dim=0, begin=0, length=2, stride=2, unit_attrs=None)

Creates ttir.select.

Tensor selection operation.

Selects a slice of the input tensor along the specified dimension with given stride.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim (int, optional) – Dimension to select from (default: 0)
    • begin (int, optional) – Starting index (default: 0)
    • length (int, optional) – Length of the slice (default: 2)
    • stride (int, optional) – Stride between elements (default: 2)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The selected slice of the tensor
  • Return type: *OpView*

sigmoid(in0, unit_attrs=None)

Creates ttir.sigmoid.

Elementwise sigmoid activation operation.

Computes the sigmoid function (1/(1 + e^-x)) for each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with sigmoid activation values
  • Return type: *OpView*

sign(in0, unit_attrs=None)

Creates ttir.sign.

Elementwise sign operation.

Returns the sign (-1, 0, or 1) of each element in the input tensor. This operation is idempotent.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with sign values
  • Return type: *OpView*

sin(in0, unit_attrs=None)

Creates ttir.sin.

Elementwise sine operation.

Computes the sine of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with sine values
  • Return type: *OpView*

softmax(in0, dimension=1, unit_attrs=None)

Creates ttir.softmax.

Softmax operation.

Applies the softmax function along one dimension of an n-dimensional input tensor, rescaling the elements so that they lie in the range [0, 1] and sum to 1 along that dimension.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dimension (int, optional) – Dimension along which softmax is computed (default: 1)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Output tensor after softmax
  • Return type: *OpView*
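A 1-D plain-Python reference model of softmax(x)_i = exp(x_i) / Σ_j exp(x_j), the quantity computed along the chosen dimension. Subtracting the maximum first is a standard numerical-stability trick that leaves the result unchanged; whether the TTIR lowering does this internally is not stated here. softmax_ref is a hypothetical helper.

```python
import math
from typing import List


def softmax_ref(xs: List[float]) -> List[float]:
    """Reference model of softmax over a 1-D list."""
    m = max(xs)  # shift by the max so exp() cannot overflow; the ratios are unchanged
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax_ref([1.0, 2.0, 3.0])
print(probs, sum(probs))  # three increasing values in [0, 1] that sum to (about) 1.0
```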

sqrt(in0, unit_attrs=None)

Creates ttir.sqrt.

Elementwise square root operation.

Computes the square root of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with square root values
  • Return type: *OpView*

squeeze(in0, dim=0, unit_attrs=None)

Creates ttir.squeeze.

Tensor squeeze operation.

Removes a dimension of size 1 from the shape of a tensor. The dimension to remove is given by dim (default: 0); it is squeezed only if its size is 1.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim (Optional[int], optional) – Dimension to squeeze (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with specified dimensions of size 1 removed
  • Return type: *OpView*

subtract(in0, in1, unit_attrs=None)

Creates ttir.subtract.

Elementwise subtraction operation.

Performs elementwise subtraction between two tensors. For each pair of corresponding elements, subtracts the element in the second tensor from the element in the first tensor.

Mathematical definition: subtract(x, y) = x - y

// Subtract corresponding elements
%result = ttir.subtract(%lhs, %rhs, %output) : tensor<3xf32>, tensor<3xf32>, tensor<3xf32> -> tensor<3xf32>
// Input tensors:
// lhs: [3.5, 0.0, -1.2]
// rhs: [1.5, 2.0, -3.2]
// Output tensor:
// [2.0, -2.0, 2.0]
  • Parameters:
    • in0 (Operand) – First input tensor (minuend)
    • in1 (Operand) – Second input tensor (subtrahend)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the elementwise difference of the inputs
  • Return type: *OpView*

sum(in0, dim_arg=[0], keep_dim=True, unit_attrs=None)

Creates ttir.sum.

Sum reduction operation.

The sum operation computes the sum of elements along the dimensions given by dim_arg. If keep_dim is True, the reduced dimensions are retained with a size of 1. Note that this builder defaults dim_arg to [0], so calling it without dim_arg reduces only dimension 0, not all dimensions.

// Sum along dimension 1
%input = ... : tensor<2x3xf32>
%output = ttir.empty() : tensor<2xf32>
%result = ttir.sum(%input, %output) {keep_dim = false, dim_arg = [1 : i32]} : tensor<2x3xf32>, tensor<2xf32> -> tensor<2xf32>
// Input: [[1.0, 2.0, 3.0],
//         [4.0, 5.0, 6.0]]
// Output: [6.0, 15.0]  // Sum of each row
  • Parameters:
    • in0 (Operand) – Input tensor
    • dim_arg (List[int], optional) – Dimensions to reduce over (default: [0])
    • keep_dim (bool, optional) – If True, retains reduced dimensions with length 1 (default: True)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with summed values
  • Return type: *OpView*
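The shape rule described above for dim_arg and keep_dim can be written as a small helper. reduced_shape is hypothetical and assumes non-negative dimension indices; it computes shapes only, not values.

```python
from typing import List, Sequence


def reduced_shape(shape: Sequence[int], dim_arg: Sequence[int], keep_dim: bool) -> List[int]:
    """Shape left after reducing `shape` over the (non-negative) dims in `dim_arg`."""
    dims = set(dim_arg)
    if keep_dim:
        # Reduced dimensions stay in place with size 1.
        return [1 if i in dims else n for i, n in enumerate(shape)]
    # Reduced dimensions are dropped entirely.
    return [n for i, n in enumerate(shape) if i not in dims]


print(reduced_shape([2, 3], [1], keep_dim=False))  # [2], the tensor<2xf32> result in the example
print(reduced_shape([2, 3], [1], keep_dim=True))   # [2, 1]
```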

tan(in0, unit_attrs=None)

Creates ttir.tan.

Elementwise tangent operation.

Computes the tangent of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with tangent values
  • Return type: *OpView*

tanh(in0, unit_attrs=None)

Creates ttir.tanh.

Elementwise hyperbolic tangent operation.

Computes the hyperbolic tangent of each element in the input tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with hyperbolic tangent values
  • Return type: *OpView*

tilize(in0, output_type, unit_attrs=None)

Creates ttir.tilize.

Convert tensor to tiled layout.

Transforms a tensor into a tiled layout format, where data is organized into regular blocks or tiles. This can improve memory access patterns and cache utilization for certain operations.

// Convert tensor to tiled layout
%result = ttir.tilize(%input) : tensor<128x128xf32> -> tensor<128x128xf32, #tiled<32x32>>
// Input tensor (standard layout):
// [[1.5, 2.0, ...],
//  [3.0, 4.0, ...]]
// Output tensor (tiled 32x32 layout):
// Same values but organized in 32x32 tiles
  • Parameters:
    • in0 (Operand) – Input tensor to be tiled
    • output_type (RankedTensorType) – Target type specifying the desired tiled layout
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The tensor with tiled layout
  • Return type: *OpView*
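The idea of a tiled layout can be sketched in plain Python: the same values are reordered so that each tile's elements are contiguous. to_tile_order is a hypothetical helper, and the sketch assumes tiles visited in row-major order with row-major order inside each tile; real hardware tiles (e.g. 32x32, which may be further subdivided into faces) can use a different in-tile ordering.

```python
# Hypothetical illustration (not ttir_builder API): reorder a row-major
# matrix into tile-major order. The in-tile ordering used by real hardware
# may differ from the simple row-major order assumed here.

def to_tile_order(matrix, tile_h, tile_w):
    rows, cols = len(matrix), len(matrix[0])
    if rows % tile_h or cols % tile_w:
        raise ValueError("matrix dimensions must be multiples of the tile size")
    out = []
    # Visit tiles in row-major order...
    for tile_row in range(0, rows, tile_h):
        for tile_col in range(0, cols, tile_w):
            # ...and emit each tile's elements contiguously, row by row.
            for r in range(tile_row, tile_row + tile_h):
                out.extend(matrix[r][tile_col:tile_col + tile_w])
    return out
```

For a 4x4 matrix holding 0..15 in row-major order and 2x2 tiles, the first tile contributes [0, 1, 4, 5]: the values are unchanged, only their order in memory differs.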

to_layout(in0, output_type, unit_attrs=None, **kwargs)

Creates ttir.to_layout.

Layout operation.

The ToLayout operation transitions a tensor from one layout to another. Examples include:

  • Transitioning between different memory spaces, e.g. DRAM to L1.
  • Transitioning between different data types, e.g. f32 to f16.
  • Transitioning between different tile sizes, e.g. 1x16 to 32x32.
  • Transitioning between different tensor shardings.
  • Some combination of the above.
// Move a 64x128 f32 tensor from system memory (#system) to L1 (#l1_)
#layout = #ttcore.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #ttcore.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
%1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
  • Parameters:
    • in0 (Operand) – Input tensor to be transformed
    • output_type (RankedTensorType) – Target type specifying the desired layout
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
    • **kwargs (dict) – Additional keyword arguments for layout transformation
  • Returns: The tensor with transformed layout
  • Return type: *OpView*

transpose(in0, dim0=0, dim1=1, unit_attrs=None)

Creates ttir.transpose.

Tensor transpose operation.

Swaps two dimensions of a tensor.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim0 (int, optional) – First dimension to swap (default: 0)
    • dim1 (int, optional) – Second dimension to swap (default: 1)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with swapped dimensions
  • Return type: *OpView*

typecast(in0, out, unit_attrs=None)

Creates ttir.typecast.

Elementwise type casting operation.

Casts each element in the input tensor to the type of the output tensor. The output type can be any supported tensor element type.

// Cast float32 to int32
%result = ttir.typecast(%input, %output) : tensor<2x2xf32>, tensor<2x2xi32> -> tensor<2x2xi32>
// Input tensor:
// [[1.7, 2.3],
//  [3.8, 4.1]]
// Output tensor:
// [[1, 2],
//  [3, 4]]
  • Parameters:
    • in0 (Operand) – Input tensor to cast
    • out (Operand) – Output tensor with desired type
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing the input values cast to the output type
  • Return type: *OpView*
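The f32-to-i32 example above drops the fractional part (1.7 becomes 1, 3.8 becomes 3). A minimal sketch of that behaviour follows; typecast_to_int is a hypothetical helper, and because the example shows only positive inputs, truncation toward zero for negative values (rather than flooring) is an assumption, as is ignoring overflow outside the int32 range.

```python
import math

# Hypothetical reference helper (not ttir_builder API): elementwise float -> int
# cast on a nested list, assuming truncation toward zero.

def typecast_to_int(t):
    if isinstance(t, list):
        return [typecast_to_int(x) for x in t]
    return math.trunc(t)  # drops the fractional part; no int32 saturation modeled
```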

unsqueeze(in0, dim=0, unit_attrs=None)

Creates ttir.unsqueeze.

Tensor unsqueeze operation.

Adds a dimension of size 1 at the specified position.

  • Parameters:
    • in0 (Operand) – Input tensor
    • dim (Optional[int], optional) – Position to insert the new dimension (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor with a new dimension of size 1 inserted
  • Return type: *OpView*

untilize(in0, output_type, unit_attrs=None)

Creates ttir.untilize.

Convert tensor from tiled to standard layout.

Transforms a tensor from a tiled layout back to a standard row-major or column-major layout. This is the inverse operation of tilize.

// Convert tensor from tiled to standard layout
%result = ttir.untilize(%input) : tensor<128x128xf32, #tiled<32x32>> -> tensor<128x128xf32>
// Input tensor (tiled 32x32 layout):
// Data organized in 32x32 tiles
// Output tensor (standard layout):
// [[1.5, 2.0, ...],
//  [3.0, 4.0, ...]]
  • Parameters:
    • in0 (Operand) – Input tensor with tiled layout
    • output_type (RankedTensorType) – Target type specifying the desired standard layout
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The tensor with standard layout
  • Return type: *OpView*
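As the inverse of tilizing, untilizing reads tile-contiguous data back into a row-major matrix. from_tile_order is a hypothetical helper, and it assumes the same simplified ordering as a basic tile-major layout (tiles in row-major order, row-major inside each tile); real hardware tile layouts may differ.

```python
# Hypothetical illustration (not ttir_builder API): rebuild a row-major
# matrix from data stored tile by tile, assuming row-major order both
# across tiles and within each tile.

def from_tile_order(flat, rows, cols, tile_h, tile_w):
    if rows % tile_h or cols % tile_w or len(flat) != rows * cols:
        raise ValueError("shape does not match the tiled data")
    matrix = [[None] * cols for _ in range(rows)]
    values = iter(flat)
    for tile_row in range(0, rows, tile_h):
        for tile_col in range(0, cols, tile_w):
            # Each tile's elements are contiguous in `flat`.
            for r in range(tile_row, tile_row + tile_h):
                for c in range(tile_col, tile_col + tile_w):
                    matrix[r][c] = next(values)
    return matrix
```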

update_cache(in0, in1, in2, batch_offset=0, unit_attrs=None)

Creates ttir.update_cache.

Cache update operation.

Updates a cache tensor by combining new values with existing cache values, starting at a specified batch offset. This operation is typically used in sequence models to maintain and update cached states.

// Update cache with new values at batch offset 1
%result = ttir.update_cache(%new_values, %old_cache, %mask, batch_offset = 1) :
    tensor<2x3xf32>, tensor<4x3xf32>, tensor<2xi1> -> tensor<4x3xf32>
// New values tensor:
// [[1.0, 2.0, 3.0],
//  [4.0, 5.0, 6.0]]
// Old cache tensor:
// [[0.1, 0.2, 0.3],
//  [0.4, 0.5, 0.6],
//  [0.7, 0.8, 0.9],
//  [1.0, 1.1, 1.2]]
// Mask tensor:
// [true, false]  // Only update first new value
// Output tensor:
// [[0.1, 0.2, 0.3],
//  [1.0, 2.0, 3.0],  // Updated with first new value
//  [0.7, 0.8, 0.9],  // Kept old value due to mask
//  [1.0, 1.1, 1.2]]
  • Parameters:
    • in0 (Operand) – New values to update cache with
    • in1 (Operand) – Cache tensor to be updated
    • in2 (Operand) – Mask tensor indicating which values to update
    • batch_offset (int, optional) – Starting position in batch dimension (default: 0)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: The updated cache tensor
  • Return type: *OpView*
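The masked update in the example can be traced with a short sketch. update_cache_example is hypothetical and reproduces only the illustrative example above (new row i is written to cache row batch_offset + i when mask[i] is true); it is not a verified model of the op's kernel semantics.

```python
import copy

# Hypothetical helper (not ttir_builder API) that reproduces only the
# illustrative ttir.update_cache example: new row i goes to cache row
# batch_offset + i when mask[i] is true; otherwise the old row is kept.

def update_cache_example(new_values, cache, mask, batch_offset=0):
    result = copy.deepcopy(cache)  # leave the input cache untouched
    for i, (row, use_new) in enumerate(zip(new_values, mask)):
        if use_new:
            result[batch_offset + i] = list(row)
    return result
```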

upsample2d(in0, in1, scale_factor, mode='nearest', unit_attrs=None)

Creates ttir.upsample2d.

  • Parameters:
    • in0 (Value | OpView | Operation)
    • in1 (Value | OpView | Operation)
    • scale_factor (int | List[int])
    • mode (str)
    • unit_attrs (List[str] | None)
  • Return type: OpView
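This entry documents only the signature. For orientation, the sketch below shows what nearest-neighbour 2-D upsampling by an integer scale factor does in general, which is the technique the default mode='nearest' names. It assumes a single-channel H x W grid and integer scale factors; the actual op's tensor layout (which axes are spatial, how channels and batch are handled) and the role of in1 are not described here, and upsample_nearest_2d is a hypothetical helper.

```python
# Hypothetical illustration (not ttir_builder API): nearest-neighbour
# upsampling of a single-channel H x W grid stored as nested lists.

def upsample_nearest_2d(image, scale_factor):
    if isinstance(scale_factor, int):
        scale_h = scale_w = scale_factor
    else:
        scale_h, scale_w = scale_factor
    # Output pixel (r, c) copies input pixel (r // scale_h, c // scale_w):
    # each input row is repeated scale_h times and each element scale_w times.
    return [
        [row[c // scale_w] for c in range(len(row) * scale_w)]
        for row in image
        for _ in range(scale_h)
    ]
```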

view_layout(in0, output_type, reinterpret_layout=False, unit_attrs=None)

Creates ttir.view_layout.

Create a new view of a tensor with a different layout.

Creates a new view of the input tensor with a different layout without copying or moving data. This is useful for reinterpreting the same data with different layout metadata.

  • If reinterpret_layout is true (reinterpretLayout on the op), the layout view change can include a data type cast; note that this does not change the format of the data in memory.
  • All ViewLayout ops can trivially be converted to ToLayout ops.
#layout = #ttcore.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #ttcore.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
%1 = "ttir.view_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
  • Parameters:
    • in0 (Operand) – Input tensor to create new view from
    • output_type (RankedTensorType) – Type of output tensor with desired layout
    • reinterpret_layout (bool, optional) – If true, allows a data type cast in the layout view change (default: False)
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A new view of the tensor with the specified layout
  • Return type: *OpView*
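The reinterpret_layout bullet distinguishes a type change that leaves the bytes in memory untouched from a conversion of the values themselves. A plain-Python analogy using the standard struct module (not ttir_builder) shows the difference for a single f32 value; the helper names are hypothetical.

```python
import struct

# Analogy only (not ttir_builder API): the same 4 bytes read as a different
# type, versus converting the value to that type.

def reinterpret_f32_as_i32(x):
    # Pack as little-endian IEEE-754 float32, then read the same bytes as int32.
    return struct.unpack("<i", struct.pack("<f", x))[0]


def cast_f32_to_i32(x):
    # Value conversion: the number is preserved (fraction truncated), bytes change.
    return int(x)
```

Reinterpreting 1.0 yields 1065353216 (bit pattern 0x3F800000), while casting yields 1: only the former matches "does not change the format of the data in memory".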

where(in0, in1, in2, unit_attrs=None)

Creates ttir.where.

Elementwise conditional selection operation.

For each element position, selects between two values based on a boolean condition:

  • If the condition is true (non-zero), selects from the first value tensor
  • If the condition is false (zero), selects from the second value tensor

Supports broadcasting according to standard broadcasting rules.

// Basic selection between two tensors
%result = ttir.where(%cond, %true_vals, %false_vals) :
    tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32> -> tensor<2x2xf32>
// Input tensors:
// %cond: [[1, 0], [0, 1]]
// %true_vals: [[1.0, 2.0], [3.0, 4.0]]
// %false_vals: [[5.0, 6.0], [7.0, 8.0]]
// Output tensor:
// [[1.0, 6.0], [7.0, 4.0]]

// With broadcasting (scalar condition)
%result = ttir.where(%scalar_cond, %true_vals, %false_vals) :
    tensor<i1>, tensor<2x2xf32>, tensor<2x2xf32> -> tensor<2x2xf32>
  • Parameters:
    • in0 (Operand) – Condition tensor (predicate)
    • in1 (Operand) – Tensor containing values to select when condition is true
    • in2 (Operand) – Tensor containing values to select when condition is false
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: A tensor containing, at each position, the element selected from in1 or in2
  • Return type: *OpView*
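The selection rule and the broadcasting behaviour can be sketched on nested lists. reference_where is a hypothetical helper, not ttir_builder API; it assumes "standard broadcasting rules" means NumPy-style rules (trailing axes aligned, size-1 axes stretched) and that all lists are non-empty.

```python
# Hypothetical reference helper (not ttir_builder API): elementwise select with
# NumPy-style broadcasting over nested Python lists (non-empty lists assumed).

def _shape(t):
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return shape


def reference_where(cond, on_true, on_false):
    shapes = [_shape(t) for t in (cond, on_true, on_false)]
    rank = max(len(s) for s in shapes)

    # Broadcast shape: align trailing axes; size-1 axes stretch to match.
    out_shape = []
    for axis in range(rank):
        sizes = set()
        for s in shapes:
            offset = rank - len(s)
            if axis >= offset and s[axis - offset] != 1:
                sizes.add(s[axis - offset])
        if len(sizes) > 1:
            raise ValueError("shapes are not broadcast-compatible")
        out_shape.append(sizes.pop() if sizes else 1)

    def fetch(t, s, idx):
        # Read t at output index idx, using index 0 along size-1 (broadcast) axes.
        offset = rank - len(s)
        for axis, size in enumerate(s):
            t = t[0 if size == 1 else idx[axis + offset]]
        return t

    def build(idx):
        if len(idx) == rank:
            c = fetch(cond, shapes[0], idx)  # non-zero / True selects on_true
            src, s = (on_true, shapes[1]) if c else (on_false, shapes[2])
            return fetch(src, s, idx)
        return [build(idx + (i,)) for i in range(out_shape[len(idx)])]

    return build(())
```

Called with the tensors from the first MLIR example, it returns [[1.0, 6.0], [7.0, 4.0]]; with a scalar condition such as True, the whole of the true-value tensor is selected.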

zeros(shape, data_type=None, unit_attrs=None)

Creates ttir.zeros.

Creates a tensor filled with zeros.

Returns a tensor of the given shape filled with zeros.

  • Parameters:
    • shape (Shape) – Shape of the output tensor
    • data_type (Type | None, optional) – Element type of the output tensor
    • unit_attrs (Optional[List[str]], optional) – Optional list of unit attributes
  • Returns: Tensor of zeros with specified shape
  • Return type: *OpView*