Improving Model Performance
This guide covers best practices and techniques for optimizing the performance of PyTorch models running on single-chip Tenstorrent hardware using the TT-XLA frontend of the Forge compiler.
Overview
Optimization Levels - Compiler optimization levels (0, 1, 2) that trade compile time against runtime performance
Device Warmup - Eliminate first-run overhead by performing warmup iterations
Data Formats - Use bfloat16 and bfloat8_b for faster computation and reduced memory usage, including manual mixed precision via per-tensor weight dtype overrides
Runtime Trace - Reduce host-device communication overhead by recording and replaying command sequences
Batch Size Tuning - Find the optimal batch size to maximize throughput for your model
For a complete working example, see the code below from examples/pytorch/mnist_performant.py, which demonstrates all these optimizations together.
MNIST performant example
{{#include ../../examples/pytorch/mnist_performant.py}}
Let’s break down each performance optimization in detail.
1. Optimization Levels
The optimization_level compiler option controls multiple optimization passes from TT-MLIR in a coordinated way. TT-XLA offers three levels (0, 1, 2).
To set the optimization level, use:
torch_xla.set_custom_compile_options({
"optimization_level": 1,
})
Optimization Levels Breakdown
Level 0 (Default)
All MLIR optimizer passes disabled
All tensors in DRAM
Use for: Iterating fast, safest option
Compilation time: Fastest
Runtime performance: Slowest
Level 1 (Recommended)
Basic optimizations enabled
Const-eval of Conv2D weights preprocessing and fusion patterns
All tensors in DRAM
Use for: General model compilation, good balance
Compilation time: Moderate
Runtime performance: Good
Level 2
Advanced optimizations enabled, all level 1 plus:
Places as many tensors as possible in SRAM instead of DRAM
Use for: Maximum performance
Compilation time: Slower (one-time cost)
Runtime performance: Best
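A common workflow is to iterate at level 0 and switch to level 1 or 2 for benchmarking. The sketch below shows one way to keep that choice in a single place. `make_compile_options` and `apply_compile_options` are hypothetical helper names introduced here for illustration; only the `optimization_level` key and the `torch_xla.set_custom_compile_options` call come from this guide.

```python
# Hypothetical helpers; only the "optimization_level" key and the
# torch_xla.set_custom_compile_options call are taken from this guide.
VALID_LEVELS = (0, 1, 2)


def make_compile_options(optimization_level=1):
    """Build the compile-options dict for a given optimization level."""
    if optimization_level not in VALID_LEVELS:
        raise ValueError(
            f"optimization_level must be one of {VALID_LEVELS}, "
            f"got {optimization_level!r}"
        )
    return {"optimization_level": optimization_level}


def apply_compile_options(options):
    """Pass the options to TT-XLA if torch_xla is installed."""
    try:
        import torch_xla
    except ImportError:
        return False  # Nothing applied; lets the sketch run without TT-XLA.
    torch_xla.set_custom_compile_options(options)
    return True


# Fast compile while debugging, slower compile for the final benchmark.
debug_options = make_compile_options(0)
bench_options = make_compile_options(2)
```

Validating the level up front means a typo fails immediately on the host rather than surfacing later during compilation.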
2. Device Warmup
Run at least 3 dummy iterations before measuring performance:
# Warmup iterations.
with torch.no_grad():
    for _ in range(3):
        output = model(input)
Why Warmup is Necessary
The first iteration is extremely slow because it runs:
Model compilation and optimization
Op kernel compilation
Transferring model weights to the device
Const-eval of model weights and constants
Caching of op kernels on the device
The second iteration is needed for:
Capturing runtime trace to reduce op dispatch overhead (Section 4)
All of the above is a one-time fixed cost; all subsequent iterations of the model are orders of magnitude faster.
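To keep the one-time costs out of a benchmark, the warmup calls can be separated from the timed calls. The following is a minimal sketch using only the standard library. `time_iterations` and `fake_forward` are hypothetical names, and the stand-in workload replaces a real model so the example runs anywhere.

```python
import time


def time_iterations(run_once, warmup=3, measured=10):
    """Run `warmup` untimed calls, then return per-call times in seconds."""
    for _ in range(warmup):
        run_once()  # Absorbs compilation, weight transfer, and trace capture.
    timings = []
    for _ in range(measured):
        start = time.perf_counter()
        run_once()
        timings.append(time.perf_counter() - start)
    return timings


# Stand-in workload. With a real model, run_once would wrap the forward pass
# and copy the output back to the host (e.g. `model(input).to("cpu")`) so the
# timer covers the complete execution.
calls = []


def fake_forward():
    calls.append(1)
    return sum(range(1000))


timings = time_iterations(fake_forward, warmup=3, measured=5)
print(f"mean time per iteration: {sum(timings) / len(timings):.6f} s")
```

Reporting the mean (or median) of the measured iterations only gives a number that reflects steady-state performance.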
3. Data Formats
TT hardware supports multiple lower-precision data formats (see the TT-NN data type documentation). For use through TT-XLA, try the following:
bfloat16
bfloat8_b
bfloat16
To use bfloat16, convert your model in PyTorch before compiling:
# Convert model weights and operations to bfloat16.
model = model.to(dtype=torch.bfloat16)
Ensure your input tensors match the model’s data type:
inputs = inputs.to(torch.bfloat16)
bfloat16 (Brain Floating Point 16-bit) provides:
Faster computation compared to fp32
Reduced memory usage (50% of fp32)
Better utilization on TT hardware
Minimal to no accuracy loss for most workloads
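To get a feel for the precision trade-off, the sketch below emulates bfloat16 rounding in pure Python. bfloat16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits of a float32 value, so the rounding can be done on the bit pattern. `to_bfloat16` and `weight_bytes` are hypothetical helper names for illustration; the sketch assumes finite inputs and does not handle NaN.

```python
import struct


def to_bfloat16(value):
    """Round a finite float to the nearest bfloat16 value (returned as float)."""
    # Reinterpret the float32 encoding as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Round to nearest, ties to even, on the 16 bits that bfloat16 drops.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def weight_bytes(num_params, bytes_per_value):
    """Storage needed for `num_params` values of a given width."""
    return num_params * bytes_per_value


print(to_bfloat16(3.14159))        # 3.140625
print(weight_bytes(1_000_000, 4))  # fp32: 4000000 bytes
print(weight_bytes(1_000_000, 2))  # bfloat16: 2000000 bytes
```

The relative rounding error stays within 2^-8 (about 0.4%), which is why many workloads tolerate the format well, while the storage per value is halved.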
bfloat8_b
Enable bfp_bf8 weight conversion using compile options. The model must be cast to bfloat16 before compilation.
torch_xla.set_custom_compile_options({
"experimental_weight_dtype": "bfp_bf8", # Cast matmul weights to bfloat8_b
})
bfloat8_b (Block Float 8-bit) weight conversion casts matmul weights to bfp_bf8 format, providing faster computation and reduced memory usage.
Notes
Possibility of accuracy loss for some workloads
Verify output: Check that accuracy is acceptable for your use case
Automatic conversion: Weights are automatically converted during compilation
Not always beneficial: Profile your specific model to verify improvement
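One way to verify output is to compare the lower-precision results against a higher-precision reference with a similarity metric. The sketch below uses the Pearson correlation coefficient (PCC). The choice of metric, the 0.99 threshold, and the helper names `pcc` and `accuracy_ok` are assumptions made for illustration, and the sample values are placeholders rather than measured outputs.

```python
import math


def pcc(reference, candidate):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(reference)
    if n != len(candidate) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_r = sum(reference) / n
    mean_c = sum(candidate) / n
    cov = sum((r - mean_r) * (c - mean_c) for r, c in zip(reference, candidate))
    var_r = sum((r - mean_r) ** 2 for r in reference)
    var_c = sum((c - mean_c) ** 2 for c in candidate)
    if var_r == 0 or var_c == 0:
        raise ValueError("PCC is undefined for constant sequences")
    return cov / math.sqrt(var_r * var_c)


def accuracy_ok(reference, candidate, threshold=0.99):
    """Return True if the candidate tracks the reference closely enough."""
    return pcc(reference, candidate) >= threshold


# Placeholder values standing in for flattened model outputs.
reference = [0.10, 0.25, -0.40, 0.80, 1.20]  # e.g. bfloat16 weights
candidate = [0.11, 0.24, -0.38, 0.79, 1.22]  # e.g. bfp_bf8 weights
print(accuracy_ok(reference, candidate))
```

For classification models, comparing task accuracy on a validation set is a more direct check; a correlation metric is mainly useful as a quick first signal.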
Per-Tensor Weight Dtype Overrides (Manual Mixed Precision)
When uniform weight conversion causes accuracy degradation in specific layers, you can override dtypes on a per-tensor basis. This lets you keep sensitive layers at higher precision (e.g. bf16) while converting the rest to a lower format (e.g. bfp_bf8 or bfp_bf4).
Pass a dict mapping parameter names to target dtypes to apply_weight_dtype_overrides():
from tt_torch import apply_weight_dtype_overrides
# Override specific weights by name (glob patterns supported).
apply_weight_dtype_overrides(model, {
"fc2.weight": "bfp_bf8",
})
Call this after creating the model and before torch.compile. See examples/pytorch/mnist_performant.py for a complete working example.
Note: Currently only matmul/linear layer weight overrides are supported. Convolution weights on lower data types are not yet supported through the compiler.
For more advanced usage including JSON configs, the tt-gen-weight-template CLI, and implementation details, see Mixed Precision.
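Before applying glob-based overrides, it can help to preview which parameters a pattern would select. The sketch below does this with Python's `fnmatch` module. It is not the implementation of `apply_weight_dtype_overrides`, whose exact matching rules are not described here; `preview_overrides` is a hypothetical helper, and the parameter names are made up for illustration (with a real model they could come from `model.named_parameters()`).

```python
from fnmatch import fnmatchcase


def preview_overrides(param_names, overrides):
    """Map each parameter name to the dtype of a matching glob pattern."""
    resolved = {}
    for name in param_names:
        for pattern, dtype in overrides.items():
            if fnmatchcase(name, pattern):
                # Assumption of this sketch: the last matching pattern wins.
                resolved[name] = dtype
    return resolved


# Illustrative parameter names for a small two-layer model.
param_names = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
overrides = {"fc*.weight": "bfp_bf8"}

print(preview_overrides(param_names, overrides))
# {'fc1.weight': 'bfp_bf8', 'fc2.weight': 'bfp_bf8'}
```

Checking the preview against the intended layers catches patterns that match too much or nothing at all before any compilation time is spent.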
4. Runtime Trace
What is Runtime Trace?
Runtime tracing is a performance optimization that removes part of the host-to-device communication overhead. The commands used to dispatch operations are recorded once, and executing the trace replays them as a single command.
How to Enable
Step 1: Set environment variable before importing torch_xla:
import os
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000" # ~10MB
Step 2: Enable trace in compiler options:
torch_xla.set_custom_compile_options({
"enable_trace": "true",
})
Requirements
TT_RUNTIME_TRACE_REGION_SIZE should be set (recommended: "10000000", or 10 MB)
The trace region size determines how much memory is allocated in DRAM for storing the trace. Adjust based on your model.
If you see trace-related errors, try increasing this value.
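Because the environment variable has to be in place before torch_xla is imported, the ordering is easy to get wrong. The sketch below bundles both steps and warns if torch_xla is already loaded. `configure_trace` is a hypothetical helper; the environment variable name and the `enable_trace` option are the ones given above.

```python
import os
import sys
import warnings

TRACE_ENV_VAR = "TT_RUNTIME_TRACE_REGION_SIZE"


def configure_trace(region_size_bytes=10_000_000):
    """Set the trace region size and return the matching compile options."""
    if "torch_xla" in sys.modules:
        # Too late for the setting to be picked up reliably.
        warnings.warn(f"{TRACE_ENV_VAR} should be set before importing torch_xla")
    os.environ[TRACE_ENV_VAR] = str(region_size_bytes)
    return {"enable_trace": "true"}


trace_options = configure_trace()

# Only now import torch_xla and pass the options on:
#   import torch_xla
#   torch_xla.set_custom_compile_options(trace_options)
```

If trace-related errors appear, calling `configure_trace` with a larger `region_size_bytes` at the top of the script is the only change needed.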
5. Batch Size Tuning
Batch size impacts:
Throughput (samples/second) - larger batches typically (not always) increase throughput
Latency (time to complete one batch) - larger batches increase the time each sample waits for its result
Memory usage - larger batches require more device memory
Tuning Process
Start with typical batch sizes (e.g., 1, 2, 4, 8, 16, 32)
Measure throughput for each batch size
Increase batch size until one of the following happens:
Throughput plateaus or starts decreasing (smaller batches can sometimes use SRAM much more effectively, giving greater overall throughput than bigger batches)
Memory is exhausted (OOM error)
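The tuning process can be automated with a small sweep. The sketch below is one possible harness, written with the standard library only. `measure_throughput` and `sweep_batch_sizes` are hypothetical names, treating `MemoryError` or `RuntimeError` as an out-of-memory signal is an assumption, and the throughput numbers in the usage example are made up so that the code runs without hardware.

```python
import time


def measure_throughput(run_batch, batch_size, warmup=3, iters=10):
    """Return samples/second for `run_batch(batch_size)` after warmup."""
    for _ in range(warmup):
        run_batch(batch_size)
    start = time.perf_counter()
    for _ in range(iters):
        run_batch(batch_size)
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


def sweep_batch_sizes(measure, batch_sizes=(1, 2, 4, 8, 16, 32)):
    """Measure each batch size in order and return (best_size, all_results)."""
    results = {}
    for batch_size in batch_sizes:
        try:
            results[batch_size] = measure(batch_size)
        except (MemoryError, RuntimeError):
            break  # Treated as out of memory; larger batches are skipped.
    if not results:
        raise RuntimeError("no batch size could be measured")
    best = max(results, key=results.get)
    return best, results


# Usage with made-up throughput numbers; batch sizes above 8 "run out of memory".
fake_throughput = {1: 100.0, 2: 180.0, 4: 300.0, 8: 290.0}


def fake_measure(batch_size):
    if batch_size not in fake_throughput:
        raise MemoryError("simulated OOM")
    return fake_throughput[batch_size]


best, results = sweep_batch_sizes(fake_measure)
print(best, results)
```

The sweep keeps every measurement and picks the maximum rather than stopping at the first drop, since a smaller batch can outperform a larger one. With a real model, `measure` would be something like `lambda b: measure_throughput(run_batch, b)`.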