Getting Started

This document walks you through how to set up TT-XLA. TT-XLA is a front end for TT-Forge that ingests JAX models via jit compilation and PyTorch models through torch-xla, producing StableHLO (SHLO) graphs for the TT-MLIR compiler. TT-XLA leverages PJRT to integrate JAX, TT-MLIR, and Tenstorrent hardware. Please see this blog post for more information about the PJRT project.
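
For example, once the plugin is installed (see the installation options below), a jit-compiled JAX computation can target a Tenstorrent device. This is a minimal sketch; the jax.devices("tt") query matches the verification step later in this guide:

import jax
import jax.numpy as jnp

# Tenstorrent devices exposed by the TT PJRT plugin.
tt_device = jax.devices("tt")[0]

# jit compilation lowers the function to StableHLO, which TT-MLIR compiles for the device.
x = jax.device_put(jnp.ones((32, 32)), tt_device)
y = jax.jit(lambda a: a @ a.T)(x)
print(y.shape)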

NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.

Prerequisites

1. Set Up the Hardware

  • Follow the instructions for the Tenstorrent device you are using at: Hardware Setup

2. Install Software (choose one)

TT-XLA installation options:

  • Install a wheel: see Installing a Wheel and Running an Example
  • Use a Docker container: see Using a Docker Container to Run an Example
  • Build from source: see Building from Source


Installing a Wheel and Running an Example

To install a wheel and run an example model, do the following:

Step 1. Install the Latest Wheel:

Install the latest wheel with pip:

pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/

Then run the tt-forge-install script to install any missing system dependencies:

tt-forge-install

Step 2. Run some models:

Use wget to fetch each demo script into your current directory, and install packages with pip when noted.

MNIST (small CNN)

The mnist.py example runs a simple CNN on a Tenstorrent device and compares the output against a CPU reference.

wget https://raw.githubusercontent.com/tenstorrent/tt-xla/main/examples/pytorch/mnist.py
python mnist.py

You should see the model output and a PCC (Pearson Correlation Coefficient) check confirming the TT device output matches the CPU reference.
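
The script's PCC comparison is conceptually similar to the following sketch (the pcc helper is illustrative, not the demo's actual code):

import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation between flattened outputs; values near 1.0 indicate a close match.
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

# e.g., assert pcc(tt_output.cpu(), cpu_output) > 0.99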

Tiny Llama (Hugging Face transformers)

The tiny_llama_demo.py example in the TT-Forge repo loads a small LLM from Hugging Face, compiles it with torch.compile(..., backend="tt"), and prints top-token predictions. You must download the script and install transformers (and its dependencies); the wheel install in Step 1 does not include them. The first run also downloads model weights from Hugging Face over the network.

wget https://raw.githubusercontent.com/tenstorrent/tt-forge/main/demos/tt-xla/nlp/pytorch/tiny_llama_demo.py
pip install transformers
python tiny_llama_demo.py

You should see the prompt "The capital of France is", the predicted next token, its probability, and a ranked list of alternative tokens that could follow instead, for example:

Prompt: `The capital of France is`
Top prediction: `Paris`

Rank   Token           Probability
-----------------------------------
1      'Paris'         36.9141%
2      'located'       10.0098%
3      'the'           8.8867%
4      'a'             4.2480%
5      'in'            2.5391%
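
Under the hood, the demo follows the pattern below. This is a simplified sketch, not the script's exact code; the model ID is illustrative, and only the torch.compile(..., backend="tt") call is confirmed by the demo description:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative; the demo script pins its own model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

# Compile for the Tenstorrent backend registered by the TT-XLA plugin.
model = torch.compile(model, backend="tt")

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Rank the most likely next tokens, as in the demo output above.
probs = torch.softmax(logits[0, -1], dim=-1)
top = torch.topk(probs, k=5)
for rank, (p, idx) in enumerate(zip(top.values, top.indices), start=1):
    print(rank, repr(tokenizer.decode(int(idx))), f"{p.item():.4%}")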

Using a Docker Container to Run an Example

This section walks you through running the example models in a Docker container.

Step 1. Run the Docker container:

docker run -it --rm \
--device /dev/tenstorrent \
-v /dev/hugepages-1G:/dev/hugepages-1G \
ghcr.io/tenstorrent/tt-xla-slim:latest

NOTE: You cannot isolate devices in containers. You must pass through all devices even if you are only using one. You can do this by passing --device /dev/tenstorrent. Do not try to pass --device /dev/tenstorrent/1 or similar, as this type of device-in-container isolation will result in fatal errors later on during execution.
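
Once inside the container, you can sanity-check that the plugin sees the hardware using the same device query as the build-from-source verification step (this assumes the image ships with JAX support):

python -c "import jax; print(jax.devices('tt'))"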

  • If you want to check that the container is running, open a second terminal and run the following:

    docker ps
    

Step 2. Run models in Docker:

Use wget to fetch each demo script into your current directory, and install packages with pip when noted.

MNIST (small CNN)

The mnist.py example runs a simple CNN on a Tenstorrent device and compares the output against a CPU reference.

wget https://raw.githubusercontent.com/tenstorrent/tt-xla/main/examples/pytorch/mnist.py
python mnist.py

You should see the model output and a PCC (Pearson Correlation Coefficient) check confirming the TT device output matches the CPU reference.

Tiny Llama (Hugging Face transformers)

The tiny_llama_demo.py example in the TT-Forge repo loads a small LLM from Hugging Face, compiles it with torch.compile(..., backend="tt"), and prints top-token predictions. You must download the script and install transformers (and its dependencies); the slim image does not include them. The first run also downloads model weights from Hugging Face over the network.

wget https://raw.githubusercontent.com/tenstorrent/tt-forge/main/demos/tt-xla/nlp/pytorch/tiny_llama_demo.py
pip install transformers
python tiny_llama_demo.py

You should see the prompt "The capital of France is", the predicted next token, its probability, and a ranked list of alternative tokens that could follow instead, for example:

Prompt: `The capital of France is`
Top prediction: `Paris`

Rank   Token           Probability
-----------------------------------
1      'Paris'         36.9141%
2      'located'       10.0098%
3      'the'           8.8867%
4      'a'             4.2480%
5      'in'            2.5391%

Building from Source

Build from source if you want to develop TT-XLA itself.

Step 1: Prerequisites

  • TT-XLA has the following system dependencies:

    • Ubuntu 24.04
    • Python 3.12
    • python3.12-venv
    • Clang 20
    • GCC 13
    • Ninja
    • CMake 4.0.3
  • TT-XLA additionally requires the following libraries:

    sudo apt install protobuf-compiler libprotobuf-dev ccache \
        libnuma-dev libhwloc-dev libboost-all-dev libnsl-dev
    

Step 2: Building the TT-MLIR Toolchain

  • Before compiling TT-XLA, build the TT-MLIR toolchain by following the build instructions in the TT-MLIR documentation.

  • After building the toolchain, set the following environment variables:

Variable               Required   Description
-----------------------------------------------------------------------------------------
TTMLIR_TOOLCHAIN_DIR   Yes        Path to the TT-MLIR toolchain (e.g., /opt/ttmlir-toolchain/)
TTXLA_LOGGER_LEVEL     No         Set to DEBUG or VERBOSE for detailed logs
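
For example:

export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/
export TTXLA_LOGGER_LEVEL=DEBUG   # optional, for detailed logs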

Step 3: Building TT-XLA

Make sure you are not in the TT-MLIR build directory, and that you are in the directory where you want to clone and build TT-XLA.

  1. Clone TT-XLA:

    git clone https://github.com/tenstorrent/tt-xla.git
    
  2. Navigate into the TT-XLA folder:

    cd tt-xla
    
  3. Initialize third-party submodules:

    git submodule update --init --recursive
    
  4. Run the following set of commands to build TT-XLA (this builds the PJRT plugin and installs it into the venv):

    source venv/activate
    cmake -G Ninja -B build # add -DCMAKE_BUILD_TYPE=Debug for a debug build
    cmake --build build
    
  5. To verify that everything is working correctly, run the following command:

    python -c "import jax; print(jax.devices('tt'))"
    

    The command should output all available TT devices, e.g. [TTDevice(id=0, arch=Wormhole_b0)]

  6. (optional) If you want to build the TT-XLA wheel, run the following command:

    cd python_package
    python setup.py bdist_wheel
    

    The above command outputs a python_package/dist/pjrt_plugin_tt*.whl file which is self-contained. To install the created wheel, run:

    pip install dist/pjrt_plugin_tt*.whl
    

    The wheel has the following structure:

    pjrt_plugin_tt/                     # PJRT plugin package
       |-- __init__.py
       |-- pjrt_plugin_tt.so               # PJRT plugin binary
       |-- tt-metal/                       # tt-metal runtime dependencies (kernels, riscv compiler/linker, etc.)
       `-- lib/                            # shared library dependencies (tt-mlir, tt-metal)
    jax_plugin_tt/                      # Thin JAX wrapper
       `-- __init__.py                     # imports and sets up pjrt_plugin_tt for XLA
    torch_plugin_tt                     # Thin PyTorch/XLA wrapper
       `-- __init__.py                     # imports and sets up pjrt_plugin_tt for PyTorch/XLA
    

    It contains a custom Tenstorrent PJRT plugin (pjrt_plugin_tt.so) and its dependencies (tt-mlir and tt-metal). Additionally, there are thin wrappers for JAX (jax_plugin_tt) and PyTorch/XLA (torch_plugin_tt) that import the PJRT plugin and set it up for use with the respective frameworks.
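
    Assuming the wrappers behave as described, a quick PyTorch-side smoke test of the installed wheel might look like the following (torch_xla API names vary across releases, so treat this as a sketch rather than a confirmed command):

    python -c "import torch_plugin_tt; import torch_xla.core.xla_model as xm; print(xm.xla_device())"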

Testing

The TT-XLA repo contains various tests in the tests directory. To run an individual test, use pytest -svv so that all error output is captured.

Multi-chip tests can run only on specific Tenstorrent hardware, so they are organized into folders named after the Tenstorrent cards/systems they target. For example, pytest -v tests/jax/multi_chip/n300 runs only on a system with an n300 Tenstorrent card. Single-chip tests run on any system:

pytest -v tests/jax/single_chip

Common Build Errors

  • Building TT-XLA requires clang-20. Make sure clang-20 is installed on the system and that clang/clang++ point to the correct versions of the respective tools; you can verify this as shown below.
  • Please also see the TT-MLIR docs for common build errors.
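
You can check the active compiler versions with:

clang --version    # should report clang 20.x
clang++ --version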

Pre-commit

Pre-commit installs a git hook in the local repository so that lint checks run on every git commit. Install it from the root of the repository using:

source venv/activate
pre-commit install

If you have already committed something locally before installing the pre-commit hooks, you can run this command to check all files:

pre-commit run --all-files

For more information, please visit the pre-commit documentation.

Where to Go Next