Getting Started

This document walks you through how to set up TT-XLA. TT-XLA is a front end for TT-Forge that is primarily used to ingest JAX models via jit compilation, providing a StableHLO (SHLO) graph to the TT-MLIR compiler. TT-XLA leverages PJRT to integrate JAX, TT-MLIR, and Tenstorrent hardware. Please see this blog post for more information about the PJRT project. This project started as a fork of iree-pjrt, but has since been refactored and diverged.

NOTE: Currently, only Tenstorrent Nebula boards are supported.

This is the main Getting Started page. There are two additional Getting Started pages depending on what you want to do. They are all described here, with links provided to each.

NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.

Setup Options

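TT-XLA can be used to run JAX models on Tenstorrent's AI hardware. Because TT-XLA is open source, you can also develop and add features to it. Setup instructions differ based on the task. You have the following options, listed in order of difficulty:

  • Installing a Wheel - use this method to run models in a Python virtual environment; it is covered on this page.
  • Using a Docker container - use this method to keep your environment completely separate; see Getting Started with Docker.
  • Building from Source - use this method if you plan to develop TT-XLA further; see Getting Started with Building from Source.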

Configuring Hardware

Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, this section of the walkthrough shows you how to do a quick setup using TT-Installer.

  1. Configure your hardware with TT-Installer using the Software Installation section here.

  2. Reboot your machine.

  3. Make sure hugepages is enabled:

sudo systemctl enable --now 'dev-hugepages\x2d1G.mount'
sudo systemctl enable --now tenstorrent-hugepages.service
  4. After running the TT-Installer script, rebooting, and enabling hugepages, activate the virtual environment that TT-Installer set up:

source ~/.tenstorrent-venv/bin/activate

  5. With the environment active, check that everything is configured by running:

tt-smi

You should see the Tenstorrent System Management Interface. It allows you to view real-time stats, diagnostics, and health info about your Tenstorrent device.
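If you prefer to script part of this check, the following minimal Python sketch only confirms that the device node and the 1G hugepages mount exist. The two paths are assumptions taken from the docker run command in Getting Started with Docker (/dev/tenstorrent and /dev/hugepages-1G); the sketch does not replace tt-smi, which reports actual device health.

```python
from pathlib import Path

# Assumed paths, taken from the docker run command in the Docker guide.
REQUIRED_PATHS = {
    "tenstorrent_devices": Path("/dev/tenstorrent"),
    "hugepages_1g_mount": Path("/dev/hugepages-1G"),
}


def check_host_setup(paths=REQUIRED_PATHS):
    """Return a dict mapping each check name to whether its path exists."""
    return {name: path.exists() for name, path in paths.items()}


if __name__ == "__main__":
    for name, ok in check_host_setup().items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```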

Installing a Wheel and Running an Example

To install a wheel and run an example model, do the following:

  1. Make sure you are in an active virtual environment. This walkthrough uses the same environment you activated to look at TT-SMI in the Configuring Hardware section. If you are using multiple TT-Forge front ends to run models, you may want to set up a separate virtual environment instead. For example:
python3 -m venv .xla-venv
source .xla-venv/bin/activate
  2. Install the wheel in your active virtual environment:
pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/

NOTE: You can pull pre-releases (these may not be stable, so proceed with caution) by adding the --pre flag directly after pip install. You can also choose a wheel from the nightly release page.

  3. You are now ready to try running a model. Navigate to the section of the TT-Forge repo that contains TT-XLA demos.

  4. For this walkthrough, the demo in the gpt2 folder is used. The requirements.txt file in the gpt2 folder shows that flax and transformers are needed to run the demo. Install them:

pip install flax transformers
  5. Download the demo.py file from the gpt2 folder to a location where you can run it with your virtual environment active. The demo takes a piece of text and tries to predict the next word that logically follows.

  6. Run the model:

python demo.py
  7. If all goes well, you should see the prompt "The capital of France is", the predicted next token, the probability it will occur, and a ranked list of other tokens that could follow instead.
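The demo's output is a ranking of candidate next tokens by probability. As a rough illustration of that last step only (not the demo's actual code), the following pure-Python sketch applies a softmax to made-up logits for a handful of candidates and returns the top k; a real model scores every token in its vocabulary.

```python
import math


def rank_next_tokens(logits, k=3):
    """Turn raw next-token logits into probabilities and return the top k.

    logits: dict mapping candidate token -> raw score from the model.
    Returns a list of (token, probability) pairs, most likely first.
    """
    # Subtract the largest logit before exponentiating for numerical stability.
    top = max(logits.values())
    exps = {tok: math.exp(score - top) for tok, score in logits.items()}
    total = sum(exps.values())
    probs = {tok: e / total for tok, e in exps.items()}
    return sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]


# Made-up logits for a few candidates after "The capital of France is".
example_logits = {" Paris": 5.0, " the": 3.0, " a": 2.5, " located": 1.0}
for token, prob in rank_next_tokens(example_logits):
    print(f"{token!r}: {prob:.3f}")
```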

Other Setup Options

If you want to keep your environment completely separate in a Docker container, or you want to develop TT-XLA further, see the following pages:

  • Getting Started with Docker
  • Getting Started with Building from Source

Where to Go Next

Now that you have set up the TT-XLA wheel, you can compile and run other demos. See the TT-XLA folder in the TT-Forge repo for other demos you can try.

Getting Started with Docker

This document walks you through how to set up TT-XLA using a Docker image. There are two other available options for getting started:

  • Installing a Wheel - if you do not want to use Docker, and prefer to use a virtual environment by itself instead, use this method.
  • Building from Source - if you plan to develop TT-XLA further, you must build from source, and should use this method.

Configuring Hardware

Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, follow the instructions on the Getting Started page.

Setting up the Docker Container

This section walks through the installation steps for using a Docker container for your project.

To install, do the following:

  1. Install Docker if you do not already have it:
sudo apt update
sudo apt install docker.io -y
sudo systemctl start docker
sudo systemctl enable docker
  2. Test that Docker is installed:
docker --version
  3. Add your user to the Docker group:
sudo usermod -aG docker $USER
newgrp docker
  4. Run the Docker container:
docker run -it --rm \
  --device /dev/tenstorrent \
  -v /dev/hugepages-1G:/dev/hugepages-1G \
  ghcr.io/tenstorrent/tt-xla-slim:latest

NOTE: You cannot isolate devices in containers. You must pass through all devices even if you are only using one. You can do this by passing --device /dev/tenstorrent. Do not try to pass --device /dev/tenstorrent/1 or similar, as this type of device-in-container isolation will result in fatal errors later on during execution.

  5. To check that the container is running, open a new terminal tab on the host machine and run the following:
docker ps
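The passthrough rule in the note above can be enforced when scripting container launches. This minimal sketch builds the same docker run argument list as step 4 and rejects per-device paths such as /dev/tenstorrent/1; the helper name is hypothetical.

```python
import shlex

IMAGE = "ghcr.io/tenstorrent/tt-xla-slim:latest"


def docker_run_args(device="/dev/tenstorrent", image=IMAGE):
    """Return the docker run argv from step 4, refusing per-device passthrough."""
    # Only the whole /dev/tenstorrent directory may be passed through;
    # individual devices such as /dev/tenstorrent/1 cause fatal errors later.
    if device.rstrip("/") != "/dev/tenstorrent":
        raise ValueError(f"pass through /dev/tenstorrent, not {device!r}")
    return [
        "docker", "run", "-it", "--rm",
        "--device", "/dev/tenstorrent",
        "-v", "/dev/hugepages-1G:/dev/hugepages-1G",
        image,
    ]


print(shlex.join(docker_run_args()))
```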

Running Models in Docker

This section shows you how to run a model using Docker. The provided example is from the TT-Forge repo. Do the following:

  1. Inside your running Docker container, clone the TT-Forge repo:
git clone https://github.com/tenstorrent/tt-forge.git
  2. Add the TT-Forge repo to your Python path:
export PYTHONPATH=$(pwd)/tt-forge:$PYTHONPATH
  3. Navigate into the TT-Forge directory and initialize its submodules:
cd tt-forge
git submodule update --init --recursive
  4. Navigate back out of the TT-Forge directory:
cd ..

  5. Run a model. For this example, the demo.py for opt_125m is used. Like gpt2, this model predicts the next word in a sentence. Its requirements.txt file shows that you need to install flax and transformers:

pip install flax transformers
  6. After installation completes, run the following:
python tt-forge/demos/tt-xla/opt_125m/demo.py

If all goes well, you should see the example prompt 'The capital of France is', followed by the predicted next token, the probability it will occur, and a table of other likely choices.

Where to Go Next

Now that you have set up TT-XLA, you can compile and run your own models, or try some of the other demos. You can find TT-XLA demos in the TT-Forge directory.

Getting Started with Building from Source

This document describes how to build the TT-XLA project on your local machine. You must build from source if you want to develop for TT-XLA. If you only want to run models, choose one of the following sets of instructions instead:

  • Installing a Wheel
  • Getting Started with Docker

NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.

Configuring Hardware

Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, follow the instructions on the Getting Started page.

System Dependencies

TT-XLA has the following system dependencies:

  • Ubuntu 22.04
  • Python 3.11
  • python3.11-venv
  • Clang 17
  • GCC 11
  • Ninja
  • CMake 4.0.3

Installing Python

If your system already has Python installed, make sure it is Python 3.11:

python3 --version

If not, install Python:

sudo apt install python3.11 python3.11-venv
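To perform the same version check from a Python setup script, compare sys.version_info against the required version. A minimal sketch:

```python
import sys

REQUIRED = (3, 11)  # Python version listed under System Dependencies


def python_ok(version_info=sys.version_info):
    """Return True if the interpreter's (major, minor) matches REQUIRED."""
    return tuple(version_info[:2]) == REQUIRED


if not python_ok():
    print(f"Found Python {sys.version.split()[0]}, but TT-XLA expects 3.11")
```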

Installing CMake 4.0.3

To install CMake 4 or higher, do the following:

  1. Install CMake 4.0.3:
pip install cmake==4.0.3
  2. Check that the correct version of CMake is installed:
cmake --version

If you see cmake version 4.0.3 you are ready for the next section.
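If you want a setup script to verify this automatically, the following sketch parses the first line of cmake --version, relying only on the "cmake version X.Y.Z" format quoted above. It returns None when cmake is not on PATH.

```python
import re
import shutil
import subprocess


def parse_cmake_version(text):
    """Extract (major, minor, patch) from `cmake --version` output, or None."""
    match = re.search(r"cmake version (\d+)\.(\d+)\.(\d+)", text)
    return tuple(int(g) for g in match.groups()) if match else None


def installed_cmake_version():
    """Return the version of the cmake on PATH, or None if cmake is not found."""
    if shutil.which("cmake") is None:
        return None
    out = subprocess.run(["cmake", "--version"], capture_output=True, text=True).stdout
    return parse_cmake_version(out)


print("CMake 4.0.3 found:", installed_cmake_version() == (4, 0, 3))
```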

Installing Clang 17

To install Clang 17, do the following:

  1. Install Clang 17:
wget https://apt.llvm.org/llvm.sh
chmod u+x llvm.sh
sudo ./llvm.sh 17
sudo apt install -y libc++-17-dev libc++abi-17-dev
sudo ln -s /usr/bin/clang-17 /usr/bin/clang
sudo ln -s /usr/bin/clang++-17 /usr/bin/clang++
  2. Check which GCC installation Clang 17 has selected; it should be GCC 11:
clang -v
  3. Look for the line that starts with Selected GCC installation:. If it points to a version other than GCC 11, and GCC 11 is not listed among the candidates, install GCC 11:
sudo apt-get install gcc-11 lib32stdc++-11-dev lib32gcc-11-dev
  4. If GCC 12 is installed and selected by default, remove its directory so that Clang selects GCC 11 instead:
sudo rm -rf /usr/bin/../lib/gcc/x86_64-linux-gnu/12
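The check in steps 2 to 4 can be scripted. This sketch runs clang -v, which prints its toolchain details to stderr, and extracts the major version from the Selected GCC installation: line. It assumes the selected path ends in the GCC major version, as in /usr/bin/../lib/gcc/x86_64-linux-gnu/11.

```python
import re
import shutil
import subprocess

# Matches e.g. "Selected GCC installation: /usr/bin/../lib/gcc/x86_64-linux-gnu/11"
SELECTED_GCC = re.compile(r"^Selected GCC installation: .*/(\d+)[ \t]*$", re.MULTILINE)


def selected_gcc_version(clang_v_output):
    """Return the GCC major version Clang selected, or None if not found."""
    match = SELECTED_GCC.search(clang_v_output)
    return int(match.group(1)) if match else None


def check_clang_gcc():
    """Run `clang -v` (if clang is on PATH) and return the selected GCC version."""
    if shutil.which("clang") is None:
        return None
    # clang -v writes its version and toolchain details to stderr.
    result = subprocess.run(["clang", "-v"], capture_output=True, text=True)
    return selected_gcc_version(result.stderr)


print("Selected GCC:", check_clang_gcc())
```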

Installing Ninja

To install Ninja, do the following:

sudo apt install ninja-build

Installing OpenMPI

To install OpenMPI, do the following:

sudo wget -q https://github.com/dmakoviichuk-tt/mpi-ulfm/releases/download/v5.0.7-ulfm/openmpi-ulfm_5.0.7-1_amd64.deb -O /tmp/openmpi-ulfm.deb && sudo apt install /tmp/openmpi-ulfm.deb

Installing Additional Dependencies

TT-XLA additionally requires the following libraries:

sudo apt install protobuf-compiler libprotobuf-dev
sudo apt install ccache
sudo apt install libnuma-dev
sudo apt install libhwloc-dev
sudo apt install libboost-all-dev

Build Process

TT-XLA integration with the TT-MLIR compiler is still in progress. Currently TT-XLA depends on the TT-MLIR toolchain to build from source. This build flow provides an easy way to experiment with TT-XLA, StableHLO, and the TT-MLIR infrastructure. The build process will be updated in the future to enhance the user experience.

Building the TT-MLIR Toolchain

Before compiling TT-XLA, you must build the TT-MLIR toolchain by following the environment setup instructions in the TT-MLIR documentation.

Building TT-XLA

Before running the commands below, make sure the environment variable TTMLIR_TOOLCHAIN_DIR points to the TT-MLIR toolchain directory created during the TT-MLIR environment setup, for example:

export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/

Optionally, set export LOGGER_LEVEL=DEBUG to enable debug logs, or export LOGGER_LEVEL=VERBOSE for even more verbose logs, such as intermediate IR printed by compiler passes.

To build TT-XLA, do the following:

  1. Make sure you are not in the TT-MLIR build directory, and you are in the location where you want TT-XLA to install.

  2. Clone TT-XLA:

git clone https://github.com/tenstorrent/tt-xla.git
  3. Navigate into the TT-XLA folder:
cd tt-xla
  4. Initialize third-party submodules:
git submodule update --init --recursive
  5. Run the following commands to build TT-XLA (this builds the PJRT plugin and installs it into the venv):
source venv/activate
cmake -G Ninja -B build # add -DCMAKE_BUILD_TYPE=Debug for a debug build
cmake --build build
  6. To verify that everything is working correctly, run the following command:
python -c "import jax; print(jax.devices('tt'))"

The command should output all available TT devices, e.g. [TTDevice(id=0, arch=Wormhole_b0)]

  7. (Optional) To build the TT-XLA wheel, run the following commands:
cd python_package
python setup.py bdist_wheel

This produces a self-contained wheel file, python_package/dist/pjrt_plugin_tt*.whl. To install it, run the following from the python_package directory:

pip install dist/pjrt_plugin_tt*.whl
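The device check in step 6 can also be wrapped in a helper that fails gracefully, for example in a CI script. This sketch relies only on jax.devices raising RuntimeError when no backend named "tt" can be initialized, which is how JAX reports an unknown or broken platform; the helper name is hypothetical.

```python
def tt_devices():
    """Return the TT devices visible to JAX, or an empty list if there are none."""
    try:
        import jax
    except ImportError:
        return []  # JAX is not installed in this environment
    try:
        return jax.devices("tt")
    except RuntimeError:
        # JAX raises RuntimeError when no backend named "tt" can be initialized,
        # e.g. when the PJRT plugin is not installed or failed to load.
        return []


devices = tt_devices()
print(devices or "No TT devices found; check the build output and tt-smi")
```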

The wheel has the following structure:

pjrt_plugin_tt/                     # PJRT plugin package
    |-- __init__.py
    |-- pjrt_plugin_tt.so               # PJRT plugin binary
    |-- tt-metal/                       # tt-metal runtime dependencies (kernels, riscv compiler/linker, etc.)
    `-- lib/                            # shared library dependencies (tt-mlir, tt-metal)
jax_plugin_tt/                      # Thin JAX wrapper
    `-- __init__.py                     # imports and sets up pjrt_plugin_tt for JAX
torch_plugin_tt/                    # Thin PyTorch/XLA wrapper
    `-- __init__.py                     # imports and sets up pjrt_plugin_tt for PyTorch/XLA

It contains a custom Tenstorrent PJRT plugin (pjrt_plugin_tt.so) and its dependencies (tt-mlir and tt-metal). Additionally, there are thin wrappers for JAX (jax_plugin_tt) and PyTorch/XLA (torch_plugin_tt) that import the PJRT plugin and set it up for use with the respective frameworks.
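To confirm that a built wheel contains the three packages shown above, you can list its top-level directories with Python's zipfile module, since a wheel is a zip archive. This is a minimal sketch; it checks package names only, not the plugin binary or its dependencies.

```python
import zipfile

# Top-level packages listed in the wheel structure above.
EXPECTED_PACKAGES = {"pjrt_plugin_tt", "jax_plugin_tt", "torch_plugin_tt"}


def wheel_packages(wheel):
    """Return top-level package directories in a wheel (path or file object)."""
    with zipfile.ZipFile(wheel) as whl:
        tops = {name.split("/", 1)[0] for name in whl.namelist() if "/" in name}
    # Skip wheel metadata directories such as <name>-<version>.dist-info.
    return {t for t in tops if not t.endswith((".dist-info", ".data"))}


def missing_packages(wheel):
    """Return the expected packages that are not present in the wheel."""
    return EXPECTED_PACKAGES - wheel_packages(wheel)
```

Pass missing_packages the path of the .whl file produced in step 7; an empty set means all three packages are present.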

Testing

The TT-XLA repo contains various tests in the tests directory. To run an individual test, use pytest -svv, which disables output capturing and increases verbosity so that all potential error messages are shown. Multi-chip tests can run only on specific Tenstorrent hardware, so they are organized in folders named after the Tenstorrent cards or systems they run on. For example, run pytest -v tests/jax/multi_chip/n300 only on a system with an n300 Tenstorrent card. Single-chip tests can run on any system with the command pytest -v tests/jax/single_chip.
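For scripting, the folder convention above maps directly to pytest invocations. This hypothetical helper builds the command; the framework argument (defaulting to jax) and card folder names other than n300 are assumptions about the directory layout.

```python
import sys


def pytest_command(card=None, framework="jax"):
    """Build a pytest argv for single-chip tests, or multi-chip tests for one card.

    card: name of the multi-chip folder (e.g. "n300"); None selects single-chip tests.
    """
    if card is None:
        path = f"tests/{framework}/single_chip"
    else:
        # Multi-chip tests are grouped in folders named after the card/system.
        path = f"tests/{framework}/multi_chip/{card}"
    return [sys.executable, "-m", "pytest", "-v", path]


# Example on an n300 system: subprocess.run(pytest_command("n300"), check=True)
```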

Common Build Errors

  • Building TT-XLA requires clang-17. Please make sure that clang-17 is installed on the system and that clang/clang++ link to the correct versions of the respective tools.
  • Please also see the TT-MLIR docs for common build errors.

Pre-commit

Pre-commit applies a git hook to the local repository such that linting is checked and applied on every git commit action. Install it from the root of the repository using:

source venv/activate
pre-commit install

If you have already committed something locally before installing the pre-commit hooks, you can run this command to check all files:

pre-commit run --all-files

For more information, see the pre-commit documentation.

Test Infra

Test infra consists of main "tester" classes and a few helper classes. Its main goal is to make writing tests easy.

Here is a brief class diagram of the infra:

[Figure: class diagram of the test infra]

Op and Graph Tests

The op tester exposes easy-to-use functions:

run_op_test(...)
run_op_test_with_random_inputs(...)

They wrap the instantiation of OpTester and all the underlying complexity. Users just need to pass the op (a Python function) they want to test to one of these functions, like this:

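import jax
import jax.numpy as jnp
import pytest

@pytest.mark.parametrize("x_shape, y_shape", [((32, 32), (32, 32))])  # example shapes
def test_add(x_shape: tuple, y_shape: tuple):
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)
    # run_op_test_with_random_inputs is provided by the TT-XLA test infra
    run_op_test_with_random_inputs(add, [x_shape, y_shape])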

and that's it.

GraphTester is at the moment identical to OpTester, and it too exposes

run_graph_test(...)
run_graph_test_with_random_inputs(...)

which are meant to be used in the same way as for op tests.

Model Tests

Models are tested by subclassing one of the *ModelTester classes and overriding the required methods. Please read the docstring of the class you want to inherit from for more information.

JAX Model Example

First, you define a model:

from typing import Sequence

import flax.linen as nn
import jax


class MNISTMLPModel(nn.Module):
    hidden_sizes: tuple[int, ...]

    @nn.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))

        for h in self.hidden_sizes:
            x = nn.Dense(features=h)(x)
            x = nn.relu(x)

        x = nn.Dense(features=10)(x)
        x = nn.softmax(x)

        return x

Then you define a tester by inheriting JaxModelTester:

# JaxModelTester, ComparisonConfig and RunMode are provided by the TT-XLA test infra.
class MNISTMLPTester(JaxModelTester):
    def __init__(
        self,
        hidden_sizes: Sequence[int],
        comparison_config: ComparisonConfig = ComparisonConfig(),
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        self._hidden_sizes = hidden_sizes
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        return MNISTMLPModel(self._hidden_sizes)

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> jax.Array:
        key = jax.random.PRNGKey(37)
        img = jax.random.normal(key, (4, 28, 28, 1))  # B, H, W, C
        # Channels is 1 as MNIST is in grayscale.
        return img

    # @override
    def _get_forward_method_args(self):
        inp = self._get_input_activations()
        parameters = self._model.init(jax.random.PRNGKey(42), inp)
        return [parameters, inp]

Finally, you run the test:

import pytest


@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
    return MNISTMLPTester(request.param)

@pytest.mark.parametrize(
    "inference_tester", [(256, 128, 64)], indirect=True, ids=lambda val: f"{val}"
)
def test_mnist_mlp_inference(inference_tester: MNISTMLPTester):
    inference_tester.test()

Code Generation Guide

Convert JAX or PyTorch models into standalone Python or C++ source code targeting Tenstorrent hardware.


Quick Reference

Framework   | Backend option | Output         | Standalone?
PyTorch/JAX | codegen_py     | Python (.py)   | No (requires TT-XLA build)
PyTorch/JAX | codegen_cpp    | C++ (.cpp, .h) | Yes

New to code generation? Start with the Python Code Generation Tutorial for a hands-on walkthrough.


Overview

Code generation (powered by TT-Alchemist) transforms your model into human-readable source code that directly calls the TT-NN library, enabling:

  • Customization - Modify generated code to add optimizations or integrate with existing infrastructure
  • Model Portability - Extract models into standalone code deployable without the full framework stack
  • Inspection & Debugging - Examine generated source to understand exact operations performed
  • Education - Study how high-level framework operations translate to TT-NN library calls

Technical Note: Internally referred to as TT-Alchemist, EmitPy (Python generation), or EmitC (C++ generation).


Basic Usage

PyTorch

Configure code generation options before compiling your model:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Configure code generation
options = {
    "backend": "codegen_py",              # Or "codegen_cpp" for C++
    "export_path": "torch_codegen_output" # Output directory
}
torch_xla.set_custom_compile_options(options)

# Standard PyTorch workflow
device = xm.xla_device()
model = YourModel()  # placeholder: substitute your own torch.nn.Module
model.compile(backend="tt")
model = model.to(device)
x = torch.randn(32, 32).to(device)

# Trigger code generation
output = model(x)

Output location: torch_codegen_output/ directory containing:

  • ttir.mlir - TTIR intermediate representation
  • *.py or *.cpp/*.h - Generated source files

JAX

Pass compile options directly to jax.jit():

import jax
import jax.numpy as jnp
from flax import nnx

def forward(graphdef, state, x):
    model = nnx.merge(graphdef, state)
    return model(x)

# JIT compile with code generation
jitted_forward = jax.jit(
    forward,
    compiler_options={
        "backend": "codegen_py",          # Or "codegen_cpp" for C++
        "export_path": "jax_codegen_output"
    }
)

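# Example model and input (substitute your own nnx.Module and data)
model = nnx.Linear(32, 32, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
x = jnp.ones((32, 32))

# Trigger code generation
result = jitted_forward(graphdef, state, x)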

Output location: jax_codegen_output/ directory containing:

  • ttir.mlir - TTIR intermediate representation
  • *.py or *.cpp/*.h - Generated source files

Configuration Options

Required Options

Option      | Type   | Description
backend     | string | Code generation target: "codegen_py" generates Python code, "codegen_cpp" generates C++ code
export_path | string | Directory for generated code (created if it doesn't exist)

Example Configurations

Python Generation:

options = {
    "backend": "codegen_py",
    "export_path": "./generated_python"
}

C++ Generation:

options = {
    "backend": "codegen_cpp",
    "export_path": "./generated_cpp"
}
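If you build these dictionaries programmatically, a small helper can catch a mistyped backend or a missing export_path before compilation starts. This is a hypothetical sketch that only knows the two code generation backends documented here; other backends are deliberately rejected.

```python
VALID_BACKENDS = {"codegen_py", "codegen_cpp"}


def make_codegen_options(backend, export_path):
    """Return a code generation options dict, validating both required options."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {sorted(VALID_BACKENDS)}, got {backend!r}")
    if not export_path:
        # Without export_path, compilation fails (see Export Path Not Set in Troubleshooting).
        raise ValueError("export_path is required for code generation backends")
    return {"backend": backend, "export_path": str(export_path)}


python_options = make_codegen_options("codegen_py", "./generated_python")
cpp_options = make_codegen_options("codegen_cpp", "./generated_cpp")
```

The returned dict can then be passed to torch_xla.set_custom_compile_options for PyTorch, or as compiler_options to jax.jit for JAX, as shown in Basic Usage.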

Generated Output

Directory Structure

After code generation completes, your export_path directory contains:

<export_path>/
├── ttir.mlir          # TTIR intermediate representation (debugging)
├── main.py/cpp        # Generated Python/C++ code
└── run                # Execution script

File Descriptions

ttir.mlir - Tenstorrent Intermediate Representation

  • High-level representation after initial compilation
  • Useful for debugging compilation issues
  • Human-readable MLIR dialect

Generated Python (*.py) - Python Implementation

  • Direct TT-NN API calls
  • Human-readable and modifiable
  • Not standalone - requires TT-XLA build to execute
  • Includes run script for execution

Generated C++ (*.cpp, *.h) - C++ Implementation

  • Direct TT-NN API calls
  • Human-readable and modifiable
  • Fully standalone - only requires TT-NN library
  • Can be integrated into existing C++ projects

Code Generation Behavior

Expected Process Flow

  1. ✅ Model compiles through TT-XLA pipeline
  2. ✅ Code generation writes files to export_path
  3. ⚠️ Process may terminate with error (expected behavior)

Process Termination

Important: The process terminating after code generation is expected behavior when running through the frontend.

You'll see this error message:

Standalone solution was successfully generated. Executing codegen through the frontend is not supported yet.
Unfortunately your program will now crash :(
ERROR:root:Caught an exception when exiting the process.
RuntimeError: Bad StatusOr access: UNIMPLEMENTED: Error code: 12

This is normal! Your code was generated successfully. The error simply indicates that continuing execution through the frontend isn't currently supported.
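Because of this, an automated pipeline should not treat a non-zero exit code as failure on its own. This minimal sketch runs a codegen script and judges success from the output instead; it searches both stdout and stderr, since which stream carries the message is not assumed here.

```python
import subprocess

# First part of the message printed after successful code generation.
SUCCESS_MARKER = "Standalone solution was successfully generated"


def run_codegen(cmd):
    """Run a codegen script; return (generated_ok, exit_code).

    Success is judged from the output, because the frontend currently exits
    with an error after generating code even when generation succeeded.
    """
    proc = subprocess.run(cmd, capture_output=True, text=True)
    output = proc.stdout + proc.stderr  # the marker's stream is not assumed
    return SUCCESS_MARKER in output, proc.returncode


# Example, using the tutorial's script path:
# ok, code = run_codegen(["python", "examples/jax/codegen_via_options_example.py"])
```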

Verifying Success

Check that code generation succeeded:

ls -la <export_path>/

You should see:

  • ✅ ttir.mlir file exists
  • ✅ Generated source files (.py or .cpp/.h)
  • ✅ File sizes are non-zero
  • ✅ (Python only) Executable run script exists
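The checklist above can be automated. This minimal sketch returns a list of problems, where an empty list means every check passed; the file patterns and the run script check follow this guide's description of the output directory, and nothing else about the layout is assumed.

```python
import os
from pathlib import Path


def verify_codegen_output(export_path, backend="codegen_py"):
    """Return a list of problems found in a code generation output directory."""
    root = Path(export_path)
    problems = []
    ttir = root / "ttir.mlir"
    if not ttir.is_file() or ttir.stat().st_size == 0:
        problems.append("ttir.mlir is missing or empty")
    patterns = ["*.py"] if backend == "codegen_py" else ["*.cpp", "*.h"]
    sources = [p for pat in patterns for p in root.glob(pat)]
    if not sources:
        problems.append("no generated source files found")
    problems += [f"{p.name} is empty" for p in sources if p.stat().st_size == 0]
    if backend == "codegen_py":
        # Python output should include an executable run script.
        run = root / "run"
        if not (run.is_file() and os.access(run, os.X_OK)):
            problems.append("run script is missing or not executable")
    return problems
```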

Use Cases

1. Custom Optimization

Scenario: Hand-tune generated code for specific workloads

Benefits:

  • Modify operation ordering
  • Adjust memory layouts
  • Add custom kernel calls

Best for: Performance-critical applications, specialized hardware configurations

2. Model Deployment & Portability

Scenario: Deploy models without the full JAX/PyTorch stack

Benefits:

  • Smaller deployment footprint
  • Fewer dependencies
  • Direct control over execution

Best for: Production environments, edge devices, embedded systems

3. Model Inspection & Debugging

Scenario: Understand what operations your model performs

Benefits:

  • Examine exact TT-NN API calls
  • Identify performance bottlenecks
  • Understand memory access patterns

Best for: Performance optimization, debugging accuracy issues

4. Educational & Research

Scenario: Learn how high-level operations translate to hardware

Benefits:

  • Study framework→hardware mapping
  • Experiment with low-level optimizations
  • Understand compiler transformations

Best for: Learning, research, optimization experiments


Advanced Topics

Alternative: Code Generation via Serialization

Note: Most users should use compile options (documented above). This method is provided for advanced use cases.

You can also invoke code generation by hooking into the serialization infrastructure and running TT-Alchemist directly on the results.

When to use this:

  • Custom compilation workflows
  • Integration with existing build systems
  • Automated pipeline generation


Next Steps

  1. Try the tutorial: Follow the Python Code Generation Tutorial for hands-on experience
  2. Experiment: Try both codegen_py and codegen_cpp backends
  3. Inspect code: Examine generated code to understand TT-NN API usage
  4. Customize: Modify generated code to optimize for your use case
  5. Deploy: Integrate generated C++ code into your applications

Questions or issues? Visit TT-XLA GitHub Issues for support.

Tutorial: Generate Python Code from Your Model

Learn how to convert PyTorch and JAX models into Python source code using TT-XLA's code generation feature.

What You'll Learn

By the end of this tutorial, you'll be able to:

  • Generate Python code from PyTorch/JAX models using TT-XLA
  • Execute the generated code on Tenstorrent hardware
  • Inspect and understand the TT-NN API calls in the generated code

Time to complete: ~15 minutes

What is Code Generation?

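Code generation (also called EmitPy, and powered by TT-Alchemist) transforms your high-level model into human-readable Python source code that directly calls the TT-NN library. This lets you inspect and modify your model at the TT-NN level. Note that generated Python code is not standalone: it still requires a TT-XLA build to execute, whereas generated C++ code only requires the TT-NN library.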

For complete conceptual overview and all options, see the Code Generation Guide.


Prerequisites

Before starting, ensure you have:

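  • Access to Tenstorrent hardware (via IRD or a physical device). If you need to reserve hardware through IRD, start at Step 1.
  • The TT-XLA Docker image ghcr.io/tenstorrent/tt-xla/tt-xla-ird-ubuntu-22-04:latest. If you are already running inside this container, skip to Step 2.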

Step-by-Step Guide

Step 1: Reserve Hardware and Start Docker Container

Reserve Tenstorrent hardware with the TT-XLA Docker image:

ird reserve --docker-image ghcr.io/tenstorrent/tt-xla/tt-xla-ird-ubuntu-22-04:latest [additional ird options]

Tip: The [additional ird options] should include your typical IRD configuration like architecture, number of chips, etc.

Step 2: Clone and Setup TT-XLA

Inside your Docker container, clone the repository:

git clone https://github.com/tenstorrent/tt-xla.git
cd tt-xla
git checkout sdjordjevic/tt_xla_codegen

Initialize submodules (required for dependencies):

git submodule update --init --recursive

Expected output: Git will download all third-party dependencies. This may take a few minutes.

Step 3: Build TT-XLA

Set up the build environment and compile the project:

# Set the toolchain directory (pre-installed in Docker image)
export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/

# Activate the Python virtual environment
source venv/activate

# Configure the build
cmake -G Ninja -B build

# Build the project (this may take 10-15 minutes)
cmake --build build

Debug Build: Add -DCMAKE_BUILD_TYPE=Debug to the cmake configure command if you need debug symbols.

Step 4: Run Code Generation Example

Choose your framework and run the example:

PyTorch:

python examples/pytorch/codegen_via_options_example.py

JAX:

python examples/jax/codegen_via_options_example.py

What Happens During Code Generation

Both examples configure TT-XLA with these options:

options = {
    "backend": "codegen_py",  # Generate Python code
    "export_path": "model",   # Output directory name
}

The process will:

  1. ✅ Compile your model through the TT-XLA pipeline
  2. ✅ Generate Python source code in the model/ directory
  3. ⚠️ Terminate with an error message (this is expected behavior)

Expected terminal output:

Standalone solution was successfully generated. Executing codegen through the frontend is not supported yet. Unfortunately your program will now crash :(
ERROR:root:Caught an exception when exiting the process. Exception:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch_xla/__init__.py", line 216, in _prepare_to_exit
    _XLAC._prepare_to_exit()
RuntimeError: Bad StatusOr access: UNIMPLEMENTED: Error code: 12

Don't worry! Despite the error message, your code was generated successfully. This termination is a known limitation when running code generation through the frontend.

Generated Files

Check the model/ directory for your generated code:

ls -la model/

You should see:

  • main.py - Generated Python code with TT-NN API calls
  • run - Executable shell script to run the generated code
  • ttir.mlir - TTIR intermediate representation (for debugging)

Step 5: Execute the Generated Code

Navigate to the model directory and run the execution script:

cd model
./run

What the run Script Does

The script automatically:

  • Sets up the Python environment with TT-NN dependencies
  • Configures Tenstorrent hardware settings
  • Executes the generated Python code
  • Displays inference results

Expected output: You should see inference results printed to the console, showing your model running successfully on Tenstorrent hardware.


Next Steps

Now that you've successfully generated and executed code:

Inspect the Generated Code

Open model/main.py to see how your PyTorch/JAX operations map to TT-NN API calls:

cat model/main.py

Look for patterns like:

  • Tensor allocation and initialization
  • Operation implementations (matrix multiply, activation functions, etc.)
  • Memory management and device synchronization
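To get a quick overview of which TT-NN operations the generated code uses, you can count calls in main.py. This sketch assumes the generated code refers to the library by the module name ttnn (for example ttnn.matmul(...)); adjust the pattern if your generated code imports it differently.

```python
import re
from collections import Counter

# Matches calls such as ttnn.matmul( or ttnn.operations.foo(
TTNN_CALL = re.compile(r"\bttnn\.([A-Za-z_][\w.]*)\s*\(")


def ttnn_call_counts(source):
    """Count TT-NN calls by function name in generated Python source."""
    return Counter(TTNN_CALL.findall(source))


# Usage: print(ttnn_call_counts(open("model/main.py").read()).most_common(10))
```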

Customize the Generated Code

Try modifying operations in main.py:

  • Change tensor shapes or data types
  • Add print statements to debug intermediate values
  • Optimize memory layouts or operation ordering

Generate C++ Code

Want C++ instead of Python? Change the backend:

options = {
    "backend": "codegen_cpp",  # Generate C++ code
    "export_path": "model_cpp",
}

The generated C++ code is fully standalone and can be integrated into existing C++ projects.

You can also generate TT-NN code for a ResNet model by following the ResNet example.


Summary

What you accomplished:

  • ✅ Built TT-XLA from source in Docker
  • ✅ Generated Python code from a PyTorch/JAX model
  • ✅ Executed the generated code on Tenstorrent hardware
  • ✅ Learned where to find and inspect the generated code

Key takeaways:

  • Code generation creates inspectable implementations
  • The process intentionally terminates after generation (current limitation)
  • Generated code can be modified and optimized for your use case

Need help? Visit the TT-XLA Issues page or check the Code Generation Guide for more details.

Troubleshooting

Code Generation Fails

Symptom:

ERROR: tt-alchemist generatePython failed

Cause: Code generation process encountered an error

Solutions:

  1. Check export path is writable:

    mkdir -p <export_path>
    touch <export_path>/test && rm <export_path>/test
    
  2. Verify TTIR was generated:

    ls -lh <export_path>/ttir.mlir
    

    If ttir.mlir is missing or empty (0 bytes), compilation failed before code generation.

  3. Check for compilation errors: Review the full output for errors before the "generatePython failed" message.

  4. Try with minimal model: Test with a simple model to isolate the issue:

    import torch

    class MinimalModel(torch.nn.Module):
        def forward(self, x):
            return x + 1
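Solution 1 can also be run from Python, which is convenient inside a test script. This minimal sketch mirrors the mkdir/touch/rm check: it creates the directory if needed, then writes and removes a temporary file in it.

```python
import os
import tempfile


def export_path_writable(export_path):
    """Create export_path if needed and confirm a file can be written there."""
    try:
        os.makedirs(export_path, exist_ok=True)
        # NamedTemporaryFile is deleted automatically when the with-block exits.
        with tempfile.NamedTemporaryFile(dir=export_path):
            pass
        return True
    except OSError:
        return False
```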
    

Export Path Not Set

Symptom:

Compile option 'export_path' must be provided when backend is not 'TTNNFlatbuffer'

Cause: The export_path option is missing

Solution: Add export_path to your compiler options:

options = {
    "backend": "codegen_py",
    "export_path": "./output"  # ← Add this
}

Generated Code Execution Fails

Symptom: Errors when running generated Python code via ./run

Possible Causes & Solutions:

  1. TT-XLA not built:

    cd /path/to/tt-xla
    cmake --build build
    
  2. Hardware not accessible:

    tt-smi  # Should show your Tenstorrent devices
    
  3. Wrong hardware configuration:

    • Verify generated code matches your hardware setup
    • Check device IDs and chip counts
    • Rebuild TT-XLA if hardware configuration changed
  4. Missing dependencies:

    source venv/activate  # Ensure virtual environment is active
    

Generated C++ Code Won't Compile

Symptom: C++ compilation errors in generated code

Solutions:

  1. Check TT-NN headers are available:

    find /opt/ttmlir-toolchain -name "ttnn*.h"
    
  2. Verify C++ compiler version: Generated code requires C++17 or later

  3. Link against TT-NN library: Ensure your build system links the TT-NN library correctly
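The header search in solution 1 can be done from Python as well, for example to feed include directories into a build script. This sketch is a direct equivalent of the find command above; the default root is the toolchain path used in this guide.

```python
from pathlib import Path


def find_ttnn_headers(root="/opt/ttmlir-toolchain"):
    """Python equivalent of `find <root> -name "ttnn*.h"`; returns sorted paths."""
    base = Path(root)
    if not base.is_dir():
        return []
    return sorted(base.rglob("ttnn*.h"))


# Include directories for a compiler command line:
# sorted({str(p.parent) for p in find_ttnn_headers()})
```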