TT-LIB
Overview
The tt_lib Python module is a unified Python interface to the Tensor library located within tt_eager. This library currently only supports 4-dimensional tensors with shape [W, Z, Y, X], in ROW_MAJOR layout, and with the BFLOAT16 data type.
Some OPs in this library may change the layout of input tensors and pad them to better match the expectations of execution kernels on the TT Accelerator device. These OPs unpad the result tensor before returning it to the caller.
One limitation is that a tensor in ROW_MAJOR layout on a TT Accelerator device must have a last dimension X whose size is divisible by 2. You cannot create tensors that violate this constraint on a TT Accelerator device, nor send them to a TT Accelerator device with ttnn.Tensor.to(). However, you can supply such tensors to OPs from the TT-LIB library, which automatically pad the last dimension before moving the tensor to the TT Accelerator device. To use this functionality, call ttnn.SetDefaultDevice(tt_device) to set your TT Accelerator device as the default device that will be used to execute operations on tensors that reside on the host machine.
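A minimal setup sketch (ttnn.open_device and ttnn.from_torch are assumptions for this example; only ttnn.SetDefaultDevice is described above):

import torch
import ttnn

# Open a TT Accelerator device and make it the default for OPs that
# receive host-resident tensors.
device = ttnn.open_device(device_id=0)  # assumed device-creation API
ttnn.SetDefaultDevice(device)

# A ROW_MAJOR host tensor with an odd last dimension; a TT-LIB OP can
# pad it automatically before moving it to the device.
x = ttnn.from_torch(torch.randn(1, 1, 32, 31, dtype=torch.bfloat16))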
Operation Infrastructure
TT-LIB has an operation infrastructure which is used to launch, profile, and cache operations generically.
To add a new operation that plugs into the infrastructure, all that is needed is a struct that implements the methods required by the operation interface. Below is an example of how to declare a new on-device operation with all of the methods required by the interface.
New Device Operation
struct <NewOperation> {
    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
    operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};
New Device Operation with a member
struct <NewOperation> {
    int some_member;

    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
    operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};
New Device Operation with Optional Input Tensors
struct <NewOperation> {
    void validate(const std::vector<Tensor> &input_tensors,
                  const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
    operation::ProgramWithCallbacks create_program(
        const std::vector<Tensor>& input_tensors,
        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
        std::vector<Tensor> &output_tensors) const;
};
New Device Operation with Optional Output Tensors
If an operation is expected to leverage optional output tensors, instead use validate_with_output_tensors and create_output_tensors, which take an additional output_tensors parameter.
struct <NewOperation> {
    void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
    std::vector<std::optional<Tensor>> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
    operation::ProgramWithOptionalOutputTensors create_program(const std::vector<Tensor>& input_tensors, std::vector<std::optional<Tensor>> &output_tensors) const;
};
Profiler
The profiler is supported out of the box for any OP.
Two special methods can optionally be implemented to set the preferred_name and parallelization_strategy.
// Implement `get_parallelization_strategy` to set the parallelization strategy on the profiler
struct <NewOperation> {
    <ParallelizationStrategyEnum> get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;
};
Fast Dispatch
Fast dispatch allows programs/kernels to be enqueued to run, so host code does not have to wait for OPs/programs to finish running. The enqueued programs run asynchronously to the host code. To wait for kernels to complete, either read a tensor from device to host with tensor.cpu():
- ttnn.Tensor.cpu(self: ttnn._ttnn.tensor.Tensor, blocking: bool = True, cq_id: ttnn._ttnn.types.QueueId = QueueId(0)) → ttnn._ttnn.tensor.Tensor
  Move TT Tensor from TT accelerator device to host device.

  tt_tensor = tt_tensor.cpu()

or perform only a wait by synchronizing with the device.
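A minimal sketch of both approaches (ttnn.open_device, ttnn.from_torch, ttnn.relu, and ttnn.synchronize_device are assumptions for this example; the blocking read uses the Tensor.cpu() signature above):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
x = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16),
                    layout=ttnn.TILE_LAYOUT, device=device)
y = ttnn.relu(x)  # enqueued; returns without waiting for the kernel

# Option 1: a blocking read transfers the result and waits for the kernels.
host_y = y.cpu(blocking=True)

# Option 2: only wait for outstanding work, without reading data back.
ttnn.synchronize_device(device)  # assumed synchronization API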
Program Caching
Program caching lets an operation cache its program and simply reload it the next time the same operation is used.
It can be enabled by running:
tt::tt_metal::program_cache::enable()
And it can be disabled by running:
tt::tt_metal::program_cache::disable_and_clear()
Number of entries can be queried using:
tt::tt_metal::program_cache::num_entries()
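From Python, the program cache can be toggled per device; a minimal sketch (the device methods named here are assumptions mirroring the C++ calls above):

import ttnn

device = ttnn.open_device(device_id=0)
device.enable_program_cache()              # assumed counterpart of enable()
# ... run OPs; repeated invocations of the same OP reuse the cached program ...
device.disable_and_clear_program_cache()   # assumed counterpart of disable_and_clear()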
In order for an OP to be cacheable, it needs to implement the following:
struct <NewOperation> {
    // Mandatory methods

    // Return type of `create_program` needs to implement override_runtime_args_callback
    // i.e.:
    operation::ProgramWithCallbacks create_program(const std::vector<Tensor> &input_tensors) const {
        Program program{};
        // ...
        auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id](
            const Program &program,
            const std::vector<Buffer*>& input_buffers,
            const std::vector<Buffer*>& output_buffers
        ) {
            auto src_dram_buffer = input_buffers.at(0);
            auto dst_dram_buffer = output_buffers.at(0);

            CoreCoord core = {0, 0};

            {
                auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
                runtime_args[0] = src_dram_buffer->address();
            }

            {
                auto &runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
                runtime_args[0] = dst_dram_buffer->address();
            }
        };

        return {std::move(program), override_runtime_args_callback};
    }
};
Logs
To see logs related to operation infrastructure, use the following environment variables:
export TT_METAL_LOGGER_TYPES=Op
export TT_METAL_LOGGER_LEVEL=Debug
The logs will print the currently running OP and information related to program caching, e.g.:
Op | DEBUG | Operation Type: silu (fallback operation)
Op | DEBUG | Operation Attributes: ()
Op | DEBUG | Input Tensors: {tt::tt_metal::Tensor(storage=tt::tt_metal::DeviceStorage(memory_config=tt::tt_metal::MemoryConfig(memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED, buffer_type=tt::tt_metal::BufferType::DRAM)), shape={1, 1, 1, 1280}, dtype=tt::tt_metal::DataType::BFLOAT16, layout=tt::tt_metal::Layout::ROW_MAJOR)}
Op | DEBUG | Operation Type: tt::tt_metal::LayoutConversionOnHost
Op | DEBUG | Operation Attributes: (target_layout=tt::tt_metal::Layout::TILE)
Op | DEBUG | Input Tensors: {tt::tt_metal::Tensor(storage=tt::tt_metal::OwnedStorage(), shape={1, 1, 320, 1280}, dtype=tt::tt_metal::DataType::BFLOAT16, layout=tt::tt_metal::Layout::ROW_MAJOR)}
...
Op | DEBUG | Program Cache: MISS - Compiling new program "tt::tt_metal::EltwiseUnary(op_type=tt::tt_metal::UnaryOpType::Enum::GELU, param=1)_tt::tt_metal::Tensor(storage=tt::tt_metal::DeviceStorage(memory_config=tt::tt_metal::MemoryConfig(memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED, buffer_type=tt::tt_metal::BufferType::DRAM)), shape={1, 1, 32, 32}, dtype=tt::tt_metal::DataType::BFLOAT16, layout=tt::tt_metal::Layout::TILE)"
Op | DEBUG | Operation Name: tt::tt_metal::EltwiseUnary
Op | DEBUG | Operation Attributes: (op_type=tt::tt_metal::UnaryOpType::Enum::GELU, param=0)
Op | DEBUG | Input Tensors: {tt::tt_metal::Tensor(storage=tt::tt_metal::DeviceStorage(memory_config=tt::tt_metal::MemoryConfig(memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED, buffer_type=tt::tt_metal::BufferType::DRAM)), shape={1, 1, 32, 32}, dtype=tt::tt_metal::DataType::BFLOAT16, layout=tt::tt_metal::Layout::TILE)}
TT-LIB API through tt_lib
Primary Operations
- ttnn.operations.moreh.softmax(): Moreh Softmax Operation
- ttnn.operations.moreh.softmax_backward(): Moreh Softmax Backward Operation
- ttnn.operations.moreh.softmin(): Moreh Softmin Operation
- ttnn.operations.moreh.softmin_backward(): Moreh Softmin Backward Operation
- ttnn.operations.moreh.logsoftmax(): Moreh LogSoftmax Operation
- ttnn.operations.moreh.logsoftmax_backward(): Moreh LogSoftmax Backward Operation
- ttnn.operations.moreh.mean(): Moreh Mean Operation
- ttnn.operations.moreh.mean_backward(): Moreh Mean Backward Operation
- ttnn.operations.moreh.group_norm(): Moreh Group Norm Operation
- ttnn.operations.moreh.group_norm_backward(): Moreh Group Norm Backward Operation
- ttnn.operations.moreh.norm(): Moreh Norm Operation
- ttnn.operations.moreh.norm_backward(): Moreh Norm Backward Operation
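For example, a minimal call sketch (the dim argument and the input construction via ttnn.from_torch are assumptions; the operation path is taken from the list above):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
x = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16),
                    layout=ttnn.TILE_LAYOUT, device=device)
# Softmax along the last dimension (dim argument assumed).
y = ttnn.operations.moreh.softmax(x, dim=3)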
Enums
- class ttnn.BcastOpMath
  Members: ADD, SUB, MUL
- class ttnn.BcastOpDim
  Members: H, W, HW
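These enum members select the math operation and the broadcast dimension for broadcast OPs; accessing a member is straightforward:

import ttnn

math_op = ttnn.BcastOpMath.ADD   # element-wise add with broadcasting
bcast_dim = ttnn.BcastOpDim.H    # broadcast along the H dimension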
Fallback Operations
These operations are currently not supported on the TT accelerator device and will execute on the host machine using PyTorch.
- tt_lib.fallback_ops.full(size: List[int], fill_value: float) → Tensor
  Creates a ttnn.Tensor of shape size filled with value fill_value.

  Argument   | Description                            | Data type | Valid range    | Required
  size       | Shape of output tensor                 | List[int] | list of 4 ints | Yes
  fill_value | Value with which to fill output tensor | float     |                | Yes
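  A minimal usage sketch (the shape and value are arbitrary):

  import tt_lib

  # Create a 4D host tensor of shape [1, 1, 32, 32] filled with 1.0.
  ones = tt_lib.fallback_ops.full([1, 1, 32, 32], 1.0)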
- tt_lib.fallback_ops.tensor_slice(input: Tensor, slices: List[slice | ellipsis]) → Tensor
  Creates a ttnn.Tensor from input using slices. To use ..., pass in ... or Ellipsis. To use :, pass in slice(None).

  Argument | Description                              | Data type             | Valid range | Required
  input    | Input tensor                             | Tensor                |             | Yes
  slices   | List of slices to slice the input tensor | List[slice, Ellipsis] |             | Yes
- tt_lib.fallback_ops.reshape(input: ~ttnn._ttnn.tensor.Tensor, N: int, C: int, H: int, W: int, output_layout: ~ttnn._ttnn.tensor.Layout | None = <Layout.TILE: 1>, output_on_device: bool | None = True) → Tensor
  Returns a new ttnn.Tensor with the same data and number of elements as input, but with the specified shape [N, C, H, W].

  Argument         | Description                                    | Data type | Valid range     | Required
  input            | Input tensor                                   | Tensor    |                 | Yes
  N                | Size of the first dimension of output tensor   | int       |                 | Yes
  C                | Size of the second dimension of output tensor  | int       |                 | Yes
  H                | Size of the third dimension of output tensor   | int       |                 | Yes
  W                | Size of the fourth dimension of output tensor  | int       |                 | Yes
  output_layout    | Output layout                                  | Layout    | default is TILE | No
  output_on_device | Output on device                               | bool      | default is True | No
- tt_lib.fallback_ops.chunk(input: Tensor, chunks: int, dim: int = 0) → List[Tensor]
  Attempts to split a ttnn.Tensor into the specified number of chunks. Each chunk is a new copy of part of the input tensor.
  If the tensor size along the given dimension dim is divisible by chunks, all returned chunks will be the same size. If it is not divisible by chunks, all returned chunks will be the same size, except the last one. If such a division is not possible, this function may return fewer than the specified number of chunks.

  Argument | Description                               | Data type | Valid range               | Required
  input    | Input tensor                              | Tensor    |                           | Yes
  chunks   | Number of chunks to return                | int       |                           | Yes
  dim      | Dimension along which to split the tensor | int       | 0, 1, 2, 3 (default is 0) | No
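  A short sketch of the behavior described above (the input construction via ttnn.from_torch is an assumption for the example):

  import torch
  import ttnn
  import tt_lib

  x = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16))
  # Splitting the last dimension (size 32) into 3 chunks yields sizes 11, 11, 10.
  pieces = tt_lib.fallback_ops.chunk(x, chunks=3, dim=3)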
- tt_lib.fallback_ops.conv2d(input: Tensor, weight: Tensor, bias: Tensor | None = None, stride: int | Tuple = 1, padding: int | str | Tuple = 0, dilation: int | Tuple = 1, groups: int = 1) → Tensor
  Applies a 2D convolution over an input image composed of several input planes.

  Argument | Description                                                           | Data type                            | Valid range                                                    | Required
  input    | Input tensor                                                          | Tensor                               |                                                                | Yes
  weight   | Weights tensor                                                        | Tensor                               |                                                                | Yes
  bias     | Bias tensor                                                           | Tensor                               |                                                                | No
  stride   | Stride of the convolution                                             | int or tuple[int] (size 2)           | positive ints (default is 1)                                   | No
  padding  | Padding added to all four sides of the input                          | int or tuple[int] (size 2) or string | positive ints (default is 0); for string, 'valid' or 'same'    | No
  dilation | Spacing between kernel elements                                       | int or (int, int)                    | positive ints (default is 1)                                   | No
  groups   | Number of blocked connections from input channels to output channels | int                                  | positive ints (default is 1)                                   | No
- tt_lib.fallback_ops.group_norm(input: Tensor, num_groups: int, weight: Tensor | None = None, bias: Tensor | None = None, eps: float = 1e-05) → Tensor
  Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization.

  \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

  Argument   | Description                                               | Data type | Valid range                                         | Required
  input      | Input tensor                                              | Tensor    |                                                     | Yes
  num_groups | Number of groups to separate the input channels into      | int       | must evenly divide the number of channels in input  | Yes
  weight     | Weight tensor \(\gamma\)                                  | Tensor    |                                                     | No
  bias       | Bias tensor \(\beta\)                                     | Tensor    |                                                     | No
  eps        | A value added to the denominator for numerical stability  | float     | default is 1e-05                                    | No
- tt_lib.fallback_ops.layer_norm(input: Tensor, normalized_shape: int | List[int], weight: Tensor | None = None, bias: Tensor | None = None, eps: float = 1e-05) → Tensor
  Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization.

  \[y = \frac{x - \text{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} * \gamma + \beta\]

  Argument         | Description                                               | Data type        | Valid range      | Required
  input            | Input tensor                                              | Tensor           |                  | Yes
  normalized_shape | Shape over which to normalize                             | int or List[int] |                  | Yes
  weight           | Weight tensor \(\gamma\)                                  | Tensor           |                  | No
  bias             | Bias tensor \(\beta\)                                     | Tensor           |                  | No
  eps              | A value added to the denominator for numerical stability  | float            | default is 1e-05 | No
- tt_lib.fallback_ops.pad(input: ~ttnn._ttnn.tensor.Tensor, pad: ~typing.Tuple[int], mode: str = 'constant', value: int | None = None, output_layout: ~ttnn._ttnn.tensor.Layout | None = <Layout.TILE: 1>, output_on_device: bool | None = True) → Tensor
  Pads a tensor. pad determines how much padding to add. Values in pad specify padding starting from the last dimension of input tensor input and moving forward. pad is an m-element tuple, where m/2 is less than or equal to the number of input dimensions and m is even.

  Argument         | Description                                                | Data type  | Valid range                                                      | Required
  input            | Input tensor                                               | Tensor     |                                                                  | Yes
  pad              | The padding size by which to pad some dimensions of input  | Tuple[int] |                                                                  | Yes
  mode             | Padding mode                                               | string     | constant, reflect, replicate, or circular (default is constant)  | No
  value            | Fill value for constant padding                            | int        | default is 0                                                     | No
  output_layout    | Output layout                                              | Layout     | default is TILE                                                  | No
  output_on_device | Output on device                                           | bool       | default is True                                                  | No
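  A short sketch of the padding semantics (the input construction via ttnn.from_torch is an assumption for the example):

  import torch
  import ttnn
  import tt_lib

  x = ttnn.from_torch(torch.randn(1, 1, 30, 30, dtype=torch.bfloat16))
  # pad=(1, 2) would pad only the last dimension: 1 element before, 2 after.
  # pad=(1, 1, 2, 2) pads the last dimension by 1 on each side and the
  # second-to-last dimension by 2 on each side, giving shape [1, 1, 34, 32].
  padded = tt_lib.fallback_ops.pad(x, (1, 1, 2, 2), mode="constant", value=0)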
- tt_lib.fallback_ops.interpolate(input: Tensor, size: int | Tuple[int] | None = None, scale_factor: float | Tuple[float] | None = None, mode: str = 'nearest', align_corners: bool | None = None, recompute_scale_factor: bool | None = None, antialias: bool = False) → Tensor
  Down/up-samples the input to either the given size or the given scale_factor. The algorithm used for interpolation is determined by mode.

  Argument               | Description                                                | Data type    | Valid range                                                                                 | Required
  input                  | Input tensor                                               | Tensor       |                                                                                             | Yes
  size                   | Output spatial size                                        | Tuple[int]   | default is None                                                                             | No
  scale_factor           | Multiplier for spatial size                                | Tuple[float] | default is None                                                                             | No
  mode                   | Algorithm used for upsampling                              | string       | nearest, linear, bilinear, bicubic, trilinear, area, or nearest-exact (default is nearest)  | No
  align_corners          | Whether to align center or corner points of corner pixels  | bool         | default is None                                                                             | No
  recompute_scale_factor | Recompute the scale_factor for use in interpolation        | bool         | default is None                                                                             | No
  antialias              | Flag to apply anti-aliasing                                | bool         | default is False                                                                            | No
- tt_lib.fallback_ops.repeat(input: Tensor, sizes: List[int]) → Tensor
  Returns the input tensor input repeated along the specified dims.

  Argument | Description                                              | Data type | Valid range | Required
  input    | Input tensor                                             | Tensor    |             | Yes
  sizes    | The number of times to repeat the tensor along each dim  | List[int] |             | Yes
- tt_lib.fallback_ops.repeat_interleave(input: Tensor, repeats: Tensor | int, dim: int | None = None, *, output_size: int | None = None) → Tensor
  Returns a tensor with repeated elements of input tensor input.

  Argument    | Description                                                  | Data type     | Valid range | Required
  input       | Input tensor                                                 | Tensor        |             | Yes
  repeats     | The number of repetitions for each element                   | Tensor or int |             | Yes
  dim         | The dimension along which to repeat values                   | int           |             | No
  output_size | Total output size for the given axis (e.g., sum of repeats)  | int           |             | No
- tt_lib.fallback_ops.concat(tensors: List[Tensor], dim: int = 0) → Tensor
  Concatenates input tensors in list tensors on provided dimension dim. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

  Argument | Description                               | Data type    | Valid range                  | Required
  tensors  | Input tensors                             | List[Tensor] |                              | Yes
  dim      | The dimension along which to concatenate  | int          | 0, 1, 2, or 3 (default is 0) | No
- tt_lib.fallback_ops.silu(input: Tensor) → Tensor
  Applies the Sigmoid Linear Unit (SiLU) function, element-wise. The SiLU function is also known as the swish function.

  \[\text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}\]

  Argument | Description  | Data type | Valid range | Required
  input    | Input tensor | Tensor    |             | Yes
- tt_lib.fallback_ops.softmax(input: Tensor, dim: int | None = None) → Tensor
  Applies a softmax function to input tensor input. Softmax is defined as:

  \[\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

  It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.

  Argument | Description                                       | Data type | Valid range   | Required
  input    | Input tensor                                      | Tensor    |               | Yes
  dim      | A dimension along which Softmax will be computed  | int       | 0, 1, 2, or 3 | No
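  A minimal usage sketch (the input construction via ttnn.from_torch and the input shape are assumptions for the example):

  import torch
  import ttnn
  import tt_lib

  x = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16))
  # Each row along the last dimension becomes a probability distribution.
  probs = tt_lib.fallback_ops.softmax(x, dim=3)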
- class tt_lib.fallback_ops.Conv2d(weights: Tensor, biases: Tensor | None, in_channels: int, out_channels: int, kernel_size: int | Tuple, stride: int | Tuple = 1, padding: int | str | Tuple = 0, dilation: int | Tuple = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros')
  Applies a 2D convolution over an input signal composed of several input planes.
  In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, H, W)\) and output \((N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:

  \[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]

  where \(\star\) is a valid 2D cross-correlation operator, \(N\) is batch size, \(C\) denotes the number of channels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.

  Argument     | Description                                                           | Data type                   | Valid range                                                | Required
  weights      | Weights tensor                                                        | Tensor                      |                                                            | Yes
  biases       | Bias tensor                                                           | Tensor                      |                                                            | Yes
  in_channels  | Number of channels in the input image                                 | int                         |                                                            | Yes
  out_channels | Number of channels produced by the convolution                        | int                         |                                                            | Yes
  kernel_size  | Size of the convolving kernel                                         | int or Tuple[int]           |                                                            | Yes
  stride       | Stride of the convolution                                             | int or Tuple[int]           | default is 1                                               | No
  padding      | Padding added to all four sides of the input                          | int or Tuple[int] or string | default is 0; for string, 'valid' or 'same'                | No
  dilation     | Spacing between kernel elements                                       | int or Tuple[int]           | default is 1                                               | No
  groups       | Number of blocked connections from input channels to output channels | int                         | default is 1                                               | No
  bias         | If True, adds a learnable bias to the output                          | bool                        | default is True                                            | No
  padding_mode | Padding mode                                                          | string                      | zeros, reflect, replicate, or circular (default is zeros)  | No
- class tt_lib.fallback_ops.BatchNorm2d(weights: Tensor, biases: Tensor, running_mean: Tensor, running_var: Tensor, num_batches_tracked: Tensor, num_features: int, eps: float = 1e-05, momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True)
  Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

  \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

  Argument            | Description                                                      | Data type  | Valid range      | Required
  weights             | Weights tensor                                                   | Tensor     |                  | Yes
  biases              | Bias tensor                                                      | Tensor     |                  | Yes
  running_mean        | Tracked running mean tensor                                      | Tensor     |                  | Yes
  running_var         | Tracked running variance tensor                                  | Tensor     |                  | Yes
  num_batches_tracked | Number of batches tracked tensor                                 | Tensor     |                  | Yes
  num_features        | C from an expected input of size (N, C, H, W)                    | int        |                  | Yes
  eps                 | A value added to the denominator for numerical stability         | float      | default is 1e-05 | No
  momentum            | The value used for the running_mean and running_var computation  | float/None | default is 0.1   | No
  affine              | Controls initialization of weights and biases                    | bool       | default is True  | No
  track_running_stats | Whether to track the running mean and variance                   | bool       | default is True  | No
- class tt_lib.fallback_ops.GroupNorm(weights: Tensor, biases: Tensor, num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True)
  Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization.

  \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

  affine is a boolean value; when set to True, this module has learnable per-channel affine parameters initialized to ones (for weights) and zeros (for biases).

  Argument     | Description                                               | Data type | Valid range      | Required
  weights      | Weights tensor                                            | Tensor    |                  | Yes
  biases       | Bias tensor                                               | Tensor    |                  | Yes
  num_groups   | Number of groups to separate the channels into            | int       |                  | Yes
  num_channels | Number of channels expected in input                      | int       |                  | Yes
  eps          | A value added to the denominator for numerical stability  | float     | default is 1e-05 | No
  affine       | Controls initialization of weights and biases             | bool      | default is True  | No
- class tt_lib.fallback_ops.LayerNorm(weights: Tensor, biases: Tensor, normalized_shape: int | List[int], eps: float = 1e-05, elementwise_affine: bool = True)
  Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization.

  \[y = \frac{x - \text{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} * \gamma + \beta\]

  elementwise_affine is a boolean value; when set to True, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases).

  Argument           | Description                                               | Data type        | Valid range      | Required
  weights            | Weights tensor                                            | Tensor           |                  | Yes
  biases             | Bias tensor                                               | Tensor           |                  | Yes
  normalized_shape   | Shape over which to normalize                             | int or List[int] |                  | Yes
  eps                | A value added to the denominator for numerical stability  | float            | default is 1e-05 | No
  elementwise_affine | Controls initialization of weights and biases             | bool             | default is True  | No
- class tt_lib.fallback_ops.MaxPool2d(kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] = None, padding: int | Tuple[int, int] = 0, dilation: int | Tuple[int, int] = 1, return_indices: bool = False, ceil_mode: bool = False, channels_last=False, reshape_2d=False)
  Applies a 2D max pooling over an input signal composed of several input planes.
  In the simplest case, the output value of the layer with input size \((N, C, H, W)\), output \((N, C, H_{out}, W_{out})\) and kernel_size \((kH, kW)\) can be precisely described as:

  \[\begin{split}\begin{aligned} out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, \text{stride[1]} \times w + n) \end{aligned}\end{split}\]

  Argument       | Description                                                          | Data type         | Valid range            | Required
  kernel_size    | The size of the window to take a max over                            | int or Tuple[int] |                        | Yes
  stride         | The stride of the window                                             | int or Tuple[int] | default is kernel_size | No
  padding        | Implicit negative infinity padding to be added on both sides         | int or Tuple[int] | default is 0           | No
  dilation       | A parameter that controls the stride of elements in the window       | int or Tuple[int] | default is 1           | No
  return_indices | If True, will return the max indices along with the outputs          | bool              | default is False       | No
  ceil_mode      | If True, will use ceil instead of floor to compute the output shape  | bool              | default is False       | No
- class tt_lib.fallback_ops.AdaptiveAvgPool2d(output_size: int | None | Tuple[int | None, int | None], channels_last=False)
  Applies a 2D adaptive average pooling over an input signal composed of several input planes. The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.

  Argument    | Description                          | Data type | Valid range                             | Required
  output_size | The target output size of the image  | int       | int/None or tuple of int/None (size 2)  | Yes
- class tt_lib.fallback_ops.ceil(input: Tensor)
  Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.

  Argument | Description           | Data type | Valid range | Required
  input    | Input tensor for ceil | Tensor    |             | Yes
- class tt_lib.fallback_ops.floor(input: Tensor)
  Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.

  Argument | Description            | Data type | Valid range | Required
  input    | Input tensor for floor | Tensor    |             | Yes
- class tt_lib.fallback_ops.trunc(input: Tensor)
  Returns a new tensor with the truncated integer values of the elements of input.

  Argument | Description            | Data type | Valid range | Required
  input    | Input tensor for trunc | Tensor    |             | Yes
- class tt_lib.fallback_ops.unary_fmod(input: Tensor, other: float)
  Applies the mod operation; the result has the same sign as the dividend input and its absolute value is less than that of other.

  Argument | Description  | Data type | Valid range | Required
  input    | Input tensor | Tensor    |             | Yes
  other    | Scalar       | float     |             | Yes
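  A short numeric illustration of the sign rule described above (the input construction via ttnn.from_torch is an assumption for the example):

  import torch
  import ttnn
  import tt_lib

  x = ttnn.from_torch(torch.full((1, 1, 32, 32), -7.0, dtype=torch.bfloat16))
  # The result keeps the sign of the dividend:
  # fmod(-7.0, 3.0) == -1.0 (a Euclidean mod would give 2.0 instead).
  result = tt_lib.fallback_ops.unary_fmod(x, 3.0)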
- class tt_lib.fallback_ops.binary_fmod(input: Tensor, other: Tensor)
  Applies the mod operation element-wise; the result has the same sign as the dividend input and its absolute value is less than that of other.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.bitwise_not(input: Tensor)
  Computes the bitwise NOT of the given input tensor.

  Argument | Description  | Data type | Valid range | Required
  input    | Input tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.unary_bitwise_or(input: Tensor, other: int)
  Computes the bitwise OR of input and other.

  Argument | Description     | Data type | Valid range | Required
  input    | Input tensor    | Tensor    |             | Yes
  other    | Immediate value | int       |             | Yes
- class tt_lib.fallback_ops.unary_bitwise_and(input: Tensor, other: int)
  Computes the bitwise AND of input and other.

  Argument | Description     | Data type | Valid range | Required
  input    | Input tensor    | Tensor    |             | Yes
  other    | Immediate value | int       |             | Yes
- class tt_lib.fallback_ops.unary_bitwise_xor(input: Tensor, other: int)
  Computes the bitwise XOR of input and other.

  Argument | Description     | Data type | Valid range | Required
  input    | Input tensor    | Tensor    |             | Yes
  other    | Immediate value | int       |             | Yes
- class tt_lib.fallback_ops.binary_bitwise_or(input: Tensor, other: Tensor)
  Computes the bitwise OR of input and other.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.binary_bitwise_and(input: Tensor, other: Tensor)
  Computes the bitwise AND of input and other.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.binary_bitwise_xor(input: Tensor, other: Tensor)
  Computes the bitwise XOR of input and other.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.unary_bitwise_left_shift(input: Tensor, other: int)
  Computes the left arithmetic shift of input by other bits. The input tensor must be of integral type.

  Argument | Description     | Data type | Valid range | Required
  input    | Input tensor    | Tensor    |             | Yes
  other    | Immediate value | int       |             | Yes
- class tt_lib.fallback_ops.unary_bitwise_right_shift(input: Tensor, other: int)
  Computes the right arithmetic shift of input by other bits. The input tensor must be of integral type. If the value of the right operand is negative or is greater than or equal to the number of bits in the promoted left operand, the behavior is undefined.

  Argument | Description     | Data type | Valid range | Required
  input    | Input tensor    | Tensor    |             | Yes
  other    | Immediate value | int       |             | Yes
- class tt_lib.fallback_ops.binary_bitwise_left_shift(input: Tensor, other: Tensor)
  Computes the left arithmetic shift of input by other bits, element-wise. The input tensors must be of integral type.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.binary_bitwise_right_shift(input: Tensor, other: Tensor)
  Computes the right arithmetic shift of input by other bits, element-wise. The input tensors must be of integral type. If the value of the right operand is negative or is greater than or equal to the number of bits in the promoted left operand, the behavior is undefined.

  Argument | Description   | Data type | Valid range | Required
  input    | First tensor  | Tensor    |             | Yes
  other    | Second tensor | Tensor    |             | Yes
- class tt_lib.fallback_ops.torch_argmax(input: Tensor, dim: int = None, keepdim: bool = False)
  Returns the indices of the maximum values of a tensor along a dimension.

  Argument | Description                                    | Data type | Valid range      | Required
  input    | Input tensor                                   | Tensor    |                  | Yes
  dim      | Dimension along which to compute the argmax    | int       | default is None  | No
  keepdim  | Whether to retain the dimensionality of input  | bool      | default is False | No
- class tt_lib.fallback_ops.torch_argmin(input: Tensor, dim: int, keepdim: bool)
  Returns the indices of the minimum values of a tensor along a dimension.

  Argument | Description                                    | Data type | Valid range | Required
  input    | Input tensor                                   | Tensor    |             | Yes
  dim      | Dimension along which to compute the argmin    | int       |             | Yes
  keepdim  | Whether to retain the dimensionality of input  | bool      |             | Yes
Experimental Operations
Operations in this section are experimental, don’t have full support, and may behave in unexpected ways.
Fused Operations from tt_lib
Mini-Graph Library
We have a variety of common operations that require fusing multiple base operations together.
- tt_lib.fused_ops.linear.Linear(in_features: int, out_features: int, weight: List[int | float], bias, device)
  Returns a function that performs a Linear operation with optional bias. weight must be the weight as a tilized list of values.
- tt_lib.fused_ops.layernorm.Layernorm(gamma: float, beta: float, epsilon: float, H, W, device, num_dims=2)
  Returns a function that performs LayerNorm with the given parameters. H and W correspond to normalized_shape in the PyTorch LayerNorm spec.
  Note: the only num_dims value supported at the moment is 2.
Complex Operations (Type 2)
The Type 2 complex representation allows for more flexible storage than the earlier one while providing the same set of operations; specifically, this storage allows for compute without the cost of the split-concat implicit in the Type 1 contiguous representation.