Improving Model Performance

This guide covers best practices and techniques for optimizing the performance of PyTorch models running on single-chip Tenstorrent hardware using the tt-xla frontend of the forge compiler.

Overview

  1. Optimization Levels - Compiler optimization levels (0, 1, 2) to balance compile and runtime performance
  2. Device Warmup - Eliminate first-run overhead by performing warmup iterations
  3. Data Formats - Use bfloat16 and bfloat8_b for faster computation and reduced memory usage
  4. Runtime Trace - Reduce host-device communication overhead by recording and replaying command sequences
  5. Batch Size Tuning - Find the optimal batch size to maximize throughput for your model

For a complete working example, see the code below from examples/pytorch/mnist_performant.py, which demonstrates all these optimizations together.

MNIST performant example:

# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import os
import time

# Required to enable runtime tracing.
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class MNISTCNNDropoutModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = self.dropout1(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)

        x = self.dropout2(x)
        x = self.fc2(x)

        output = F.log_softmax(x, dim=1)
        return output


def mnist_performant():
    """Minimal example of running MNIST CNN model with all performance options enabled."""
    # Initialize model.
    model = MNISTCNNDropoutModel()

    # Put it in inference mode.
    model = model.eval()

    # Convert weights and ops to bfloat16.
    model = model.to(dtype=torch.bfloat16)

    # Set relevant compiler options.
    torch_xla.set_custom_compile_options(
        {
            # Set to highest optimization level.
            "optimization_level": 2,
            # Enable runtime trace.
            "enable_trace": "true",
            # Cast weights and ops to bfloat8_b.
            "enable_bfp8_conversion": "true",
        }
    )

    # Compile the model for TT backend.
    model.compile(backend="tt")

    # Connect the device.
    device = xm.xla_device()

    # Move model to device.
    model = model.to(device)

    # Set batch size to optimal value.
    batch_size = 64

    # Warm up the device with 3 runs. This is needed because the first 2 iterations are slow.
    warmup_input = generate_input(batch_size, torch.bfloat16)
    run_inference(model, device, warmup_input, loop_count=3, verbose=False)

    # Run fast inference loop and measure performance.
    inference_input = generate_input(batch_size, torch.bfloat16)
    run_inference(model, device, inference_input, loop_count=128, verbose=True)


def run_inference(model, device, input, loop_count, verbose=True):
    """Run inference and measure performance."""
    iteration_times = []
    # Run fast inference loop.
    with torch.no_grad():
        for i in range(loop_count):
            start = time.perf_counter_ns()

            # Move input to device.
            device_input = input.to(device)
            # Run the model.
            output = model(device_input)
            # Move output back to CPU.
            output = output.to("cpu")

            end = time.perf_counter_ns()

            iteration_times.append(end - start)
            if verbose:
                print(f"Iteration {i} took:\t{iteration_times[-1] / 1_000_000} ms")

    # Calculate and print average throughput.
    batch_size = input.shape[0]
    total_time = sum(iteration_times)
    samples_per_second = batch_size * loop_count / (total_time / 1_000_000_000)
    if verbose:
        print(f"Average throughput: {round(samples_per_second)} samples/second")


def generate_input(batch_size, dtype):
    """Helper to generate random inputs for inference."""
    return torch.randn((batch_size, 1, 28, 28), dtype=dtype)


if __name__ == "__main__":
    # By default torch_xla uses the CPU device, so we have to set it to the TT device.
    xr.set_device_type("TT")

    mnist_performant()

Let's break down each performance optimization in detail.


1. Optimization Levels

The optimization_level compiler option controls multiple optimization passes from tt-mlir in a coordinated way. tt-xla offers three levels (0, 1, 2).

To set the optimization level, use:

torch_xla.set_custom_compile_options({
    "optimization_level": "1",
})

Optimization Levels Breakdown

Level 0 (Default)

  • All MLIR optimizer passes disabled
    • All tensors in DRAM
  • Use for: Iterating fast, safest option
  • Compilation time: Fastest
  • Runtime performance: Slowest

Level 1

  • Basic optimizations enabled
    • Const-eval of Conv2D weights preprocessing and fusion patterns
    • All tensors in DRAM
  • Use for: General model compilation, good balance
  • Compilation time: Moderate
  • Runtime performance: Good

Level 2

  • Advanced optimizations enabled: everything from level 1, plus:
    • Place as many tensors as possible in SRAM instead of DRAM
  • Use for: Maximum performance
  • Compilation time: Slower (one-time cost)
  • Runtime performance: Best
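
To compare levels in practice, compile and run the same model once per level and record both the first-iteration (compile) time and the steady-state throughput. The sketch below is a minimal, non-authoritative illustration: it reuses MNISTCNNDropoutModel, generate_input, and run_inference from the example above, and assumes each level is measured in a fresh process run, with the level set before the model is compiled.

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Sketch: measure compile cost and steady-state performance for one level.
# Run once per level (0, 1, 2) in a fresh process and compare the numbers.
level = 2  # change to 0, 1, or 2

xr.set_device_type("TT")
torch_xla.set_custom_compile_options({"optimization_level": level})

model = MNISTCNNDropoutModel().eval().to(dtype=torch.bfloat16)
model.compile(backend="tt")

device = xm.xla_device()
model = model.to(device)
inputs = generate_input(64, torch.bfloat16)

# The first iteration includes compilation, so it approximates compile cost.
start = time.perf_counter()
run_inference(model, device, inputs, loop_count=1, verbose=False)
print(f"Level {level}: first (compile) iteration took {time.perf_counter() - start:.1f} s")

# Warm up once more, then measure steady-state runtime performance.
run_inference(model, device, inputs, loop_count=2, verbose=False)
run_inference(model, device, inputs, loop_count=32, verbose=True)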

2. Device Warmup

Run at least 3 dummy iterations before measuring performance:

# Warmup iterations.
with torch.no_grad():
    for _ in range(3):
        output = model(input)

Why Warmup is Necessary

The first iteration is extremely slow because it performs:

  • Model compilation and optimization
  • Op kernel compilation
  • Transfer of model weights to the device
  • Const-eval of model weights and constants
  • Caching of op kernels on device

The second iteration is needed for:

  • Capturing runtime trace to reduce op dispatch overhead (Section 4)

All of the above is a one-time fixed cost; all subsequent iterations of the model will be orders of magnitude faster.
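
To see this fixed cost directly, time each iteration individually; the first one or two iterations should be dramatically slower than the rest. A minimal sketch, assuming model, device, and inputs are already set up and compiled as in the example above:

import time

import torch

# Sketch: observe the one-time warmup cost by timing iterations individually.
# Assumes `model`, `device`, and `inputs` are set up as in the example above.
with torch.no_grad():
    for i in range(5):
        start = time.perf_counter()
        output = model(inputs.to(device)).to("cpu")
        print(f"Iteration {i}: {(time.perf_counter() - start) * 1000:.1f} ms")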


3. Data Formats

TT hardware supports multiple lower-precision data formats (docs). For use through tt-xla, try the following:

  • bfloat16
  • bfloat8_b

bfloat16

To use bfloat16, convert your model in PyTorch before compiling:

# Convert model weights and operations to bfloat16.
model = model.to(dtype=torch.bfloat16)

Ensure your input tensors match the model's data type:

inputs = inputs.to(torch.bfloat16)

bfloat16 (Brain Floating Point 16-bit) provides:

  • Faster computation compared to fp32
  • Reduced memory usage (50% of fp32)
  • Better utilization on TT hardware
  • Minimal to no accuracy loss for most workloads

bfloat8_b

Enable bfp8 conversion using compile options. The model MUST be cast to bfloat16 before compilation.

torch_xla.set_custom_compile_options({
    "enable_bfp8_conversion": "true",  # Enable bfloat8_b
})

bfloat8_b (Block Float 8-bit) provides even faster computation and more memory reduction.

Notes

  • Possibility of accuracy loss for some workloads
  • Verify output: Check that accuracy is acceptable for your use case
  • Automatic conversion: The model is automatically converted to bfloat8_b during compilation
  • Not always beneficial: Profile your specific model to verify improvement
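
One way to act on the "Verify output" note is to compare the bfp8 device outputs against a full-precision CPU reference on the same inputs. The sketch below assumes MNISTCNNDropoutModel from the example above and random weights; it is illustrative only, and for a trained model you would load real weights and compare a task-level metric such as classification accuracy.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Sketch: sanity-check bfloat8_b accuracy against an fp32 CPU reference.
# Assumes MNISTCNNDropoutModel from the example above; with a trained model
# you would load real weights instead of comparing random initializations.
xr.set_device_type("TT")

cpu_model = MNISTCNNDropoutModel().eval()
state_dict = cpu_model.state_dict()

torch_xla.set_custom_compile_options({"enable_bfp8_conversion": "true"})

tt_model = MNISTCNNDropoutModel().eval()
tt_model.load_state_dict(state_dict)
tt_model = tt_model.to(dtype=torch.bfloat16)  # required before bfp8 conversion
tt_model.compile(backend="tt")

device = xm.xla_device()
tt_model = tt_model.to(device)

inputs = torch.randn((64, 1, 28, 28))

with torch.no_grad():
    reference = cpu_model(inputs)
    device_out = tt_model(inputs.to(torch.bfloat16).to(device)).to("cpu").float()

# Task-level check: do the two models predict the same classes?
agreement = (reference.argmax(dim=1) == device_out.argmax(dim=1)).float().mean()
print(f"Top-1 agreement with the fp32 reference: {agreement:.2%}")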

4. Runtime Trace

What is Runtime Trace?

Runtime tracing is a performance optimization that reduces host-to-device communication overhead: the commands used to dispatch operations are recorded once and then replayed as a single command each time the trace is executed.

How to Enable

Step 1: Set environment variable before importing torch_xla:

import os
os.environ["TT_RUNTIME_TRACE_REGION_SIZE"] = "10000000"  # ~10MB

Step 2: Enable trace in compiler options:

torch_xla.set_custom_compile_options({
    "enable_trace": "true",
})

Requirements

  • TT_RUNTIME_TRACE_REGION_SIZE should be set (recommended: "10000000" or 10MB)
    • The trace region size determines how much memory is allocated in DRAM for storing the trace. Adjust based on your model.
    • If you see trace-related errors, try increasing this value.
  • The program cache must be enabled: TT_RUNTIME_ENABLE_PROGRAM_CACHE must be set to "1" (this is set by default)
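
Putting both requirements together, a minimal setup sketch looks like the following (TT_RUNTIME_ENABLE_PROGRAM_CACHE is already "1" by default, so setting it explicitly is only a safeguard):

import os

# Sketch: make sure trace-related settings are in place before torch_xla is imported.
os.environ.setdefault("TT_RUNTIME_TRACE_REGION_SIZE", "10000000")  # ~10 MB trace region
os.environ.setdefault("TT_RUNTIME_ENABLE_PROGRAM_CACHE", "1")      # "1" is already the default

import torch_xla

torch_xla.set_custom_compile_options({"enable_trace": "true"})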

5. Batch Size Tuning

Batch size impacts:

  • Throughput (samples/second) - larger batches typically (not always) increase throughput
  • Latency - larger batches increase the time until any individual sample's result is available, since the whole batch must finish first
  • Memory usage - larger batches require more device memory
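
For example (illustrative numbers): if a batch of 64 completes in 8 ms, throughput is 64 / 0.008 = 8000 samples/second, but every sample in that batch waits the full 8 ms for its result; a batch of 1 that completes in 1 ms delivers only 1000 samples/second, yet each result is ready after 1 ms.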

Tuning Process

  1. Start with typical values (e.g., 1, 2, 4, 8, 16, 32)
  2. Measure throughput for each batch size (see the sweep sketch after this list)
  3. Increase batch size until:
    • Throughput plateaus or starts decreasing
      • Sometimes smaller batches use SRAM much more effectively, yielding higher overall throughput than larger batches
    • Memory is exhausted (OOM error)
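
A simple sweep can automate steps 1-3. The sketch below reuses generate_input and run_inference from the example above and assumes the model is already compiled and moved to the device; sweep_batch_sizes, the candidate sizes, and the loop counts are illustrative choices, not part of the library.

import time

import torch


def sweep_batch_sizes(model, device, candidates=(1, 2, 4, 8, 16, 32, 64, 128)):
    """Sketch: measure steady-state throughput for a range of batch sizes.

    Reuses generate_input and run_inference from the example above. Stop
    growing the batch once throughput drops or the device runs out of memory.
    """
    best = None
    for batch_size in candidates:
        inputs = generate_input(batch_size, torch.bfloat16)

        # Each new batch size changes the input shape, so warm up again to
        # keep compilation and trace capture out of the timed loop.
        run_inference(model, device, inputs, loop_count=3, verbose=False)

        loop_count = 32
        start = time.perf_counter()
        run_inference(model, device, inputs, loop_count=loop_count, verbose=False)
        elapsed = time.perf_counter() - start

        throughput = batch_size * loop_count / elapsed
        print(f"batch_size={batch_size}: {throughput:.0f} samples/second")
        if best is None or throughput > best[1]:
            best = (batch_size, throughput)

    print(f"Best batch size: {best[0]} ({best[1]:.0f} samples/second)")
    return best[0]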