Getting Started
This document walks you through how to set up TT-XLA. TT-XLA is a front end for TT-Forge that ingests JAX models via jit compilation and PyTorch models through torch-xla, producing StableHLO (SHLO) graphs for the TT-MLIR compiler. TT-XLA leverages PJRT to integrate JAX, TT-MLIR, and Tenstorrent hardware. Please see this blog post for more information about the PJRT project.
NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.
Prerequisites
1. Set Up the Hardware
- Follow the instructions for the Tenstorrent device you are using at: Hardware Setup
2. Install Software (choose one)
- Option 1 (quick path): Use the TT-Installer by following the Software Installation instructions.
- Option 2 (manual path): For more control, follow the manual software dependencies installation guide.
TT-XLA Installation Options
- Option 1: Installing a Wheel and Running an Example
Choose this option if you want to run models.
- Option 2: Using a Docker Container to Run an Example
Choose this option if you want to keep the environment for running models separate from your existing environment.
- Option 3: Building from Source
Choose this option if you want to develop TT-XLA further. It is a more involved process that you are unlikely to need if you only want to run models.
Installing a Wheel and Running an Example
To install a wheel and run an example model, do the following:
Step 1. Install the Latest Wheel:
Install the latest wheel with pip.
pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/
Run the tt-forge-install script to install missing system dependencies.
tt-forge-install
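If JAX is available in your environment, you can optionally confirm that the plugin is visible by running the same device query used later in this guide; it should list at least one TT device:
python -c "import jax; print(jax.devices('tt'))"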
Step 2. Run some models:
Use wget to fetch each demo script into your current directory, and install packages with pip when noted.
MNIST (small CNN)
The mnist.py example runs a simple CNN on a Tenstorrent device and compares the output against a CPU reference.
wget https://raw.githubusercontent.com/tenstorrent/tt-xla/main/examples/pytorch/mnist.py
python mnist.py
You should see the model output and a PCC (Pearson Correlation Coefficient) check confirming the TT device output matches the CPU reference.
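For reference, the PCC check boils down to a Pearson correlation between the two flattened output tensors. The following is a minimal sketch of that computation, not the demo's exact code; the helper name and the 0.99 threshold are illustrative:
import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation = cosine similarity of the mean-centered, flattened tensors.
    a = a.flatten().float()
    b = b.flatten().float()
    a = a - a.mean()
    b = b - b.mean()
    return float(torch.dot(a, b) / (a.norm() * b.norm()))

# Illustrative usage: compare device output against the CPU reference.
# assert pcc(tt_output.cpu(), cpu_output) > 0.99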
Tiny Llama (Hugging Face transformers)
The tiny_llama_demo.py example in the TT-Forge repo loads a small LLM from Hugging Face, compiles it with torch.compile(..., backend="tt"), and prints top-token predictions. You must download the script and install transformers (and its dependencies); the wheel install in Step 1 does not include them. The first run also downloads model weights from Hugging Face over the network.
wget https://raw.githubusercontent.com/tenstorrent/tt-forge/main/demos/tt-xla/nlp/pytorch/tiny_llama_demo.py
pip install transformers
python tiny_llama_demo.py
You should see the prompt "The capital of France is", the predicted next token, the probability it will occur, and a list of other ranked options that could follow instead, for example:
Prompt: `The capital of France is`
Top prediction: `Paris`
Rank  Token      Probability
-----------------------------------
1     'Paris'    36.9141%
2     'located'  10.0098%
3     'the'       8.8867%
4     'a'         4.2480%
5     'in'        2.5391%
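Conceptually, the demo follows the usual Hugging Face pattern and routes compilation to the Tenstorrent backend via torch.compile. The sketch below is a condensed illustration, not the demo's exact code; the checkpoint name, device handling, and output formatting are assumptions, so refer to tiny_llama_demo.py for the authoritative version:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the demo script pins its own model.
name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

# Route compilation to the Tenstorrent backend, as the demo does. The "tt"
# backend is registered by the TT-XLA plugin; the demo shows the exact imports.
model = torch.compile(model, backend="tt")

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Rank candidate next tokens by probability, like the table above.
probs = torch.softmax(logits[0, -1], dim=-1)
top = torch.topk(probs, k=5)
for rank, (p, idx) in enumerate(zip(top.values, top.indices), start=1):
    print(rank, repr(tokenizer.decode(int(idx))), f"{p.item():.4%}")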
Using a Docker Container to Run an Example
This section walks through the installation steps for using a Docker container for your project.
- Prerequisite: Docker must be installed. See the official Docker installation guide if needed.
Step 1. Run the Docker container:
docker run -it --rm \
--device /dev/tenstorrent \
-v /dev/hugepages-1G:/dev/hugepages-1G \
ghcr.io/tenstorrent/tt-xla-slim:latest
NOTE: You cannot isolate devices in containers. You must pass through all devices even if you are only using one. You can do this by passing --device /dev/tenstorrent. Do not try to pass --device /dev/tenstorrent/1 or similar, as this type of device-in-container isolation will result in fatal errors later on during execution.
- If you want to check that the container is running, open a new terminal tab and run the following:
docker ps
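Because the container is started with --rm, anything you download inside it is lost when it exits. If you want demo scripts and model weights to persist across runs, you can additionally bind-mount a host directory as your working directory; the /workspace path below is just a convention, not something the image requires:
docker run -it --rm \
  --device /dev/tenstorrent \
  -v /dev/hugepages-1G:/dev/hugepages-1G \
  -v "$PWD":/workspace -w /workspace \
  ghcr.io/tenstorrent/tt-xla-slim:latest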
Step 2: Running Models in Docker
Use wget to fetch each demo script into your current directory, and install packages with pip when noted.
MNIST (small CNN)
The mnist.py example runs a simple CNN on a Tenstorrent device and compares the output against a CPU reference.
wget https://raw.githubusercontent.com/tenstorrent/tt-xla/main/examples/pytorch/mnist.py
python mnist.py
You should see the model output and a PCC (Pearson Correlation Coefficient) check confirming the TT device output matches the CPU reference.
Tiny Llama (Hugging Face transformers)
The tiny_llama_demo.py example in the TT-Forge repo loads a small LLM from Hugging Face, compiles it with torch.compile(..., backend="tt"), and prints top-token predictions. You must download the script and install transformers (and its dependencies); the slim image does not include them. The first run also downloads model weights from Hugging Face over the network.
wget https://raw.githubusercontent.com/tenstorrent/tt-forge/main/demos/tt-xla/nlp/pytorch/tiny_llama_demo.py
pip install transformers
python tiny_llama_demo.py
You should see the prompt "The capital of France is", the predicted next token, the probability it will occur, and a list of other ranked options that could follow instead, for example:
Prompt: `The capital of France is`
Top prediction: `Paris`
Rank  Token      Probability
-----------------------------------
1     'Paris'    36.9141%
2     'located'  10.0098%
3     'the'       8.8867%
4     'a'         4.2480%
5     'in'        2.5391%
Building from Source
Build from source if you are a developer who wants to work on TT-XLA itself.
Step 1: Prerequisites
- TT-XLA has the following system dependencies:
- Ubuntu 24.04
- Python 3.12
- python3.12-venv
- Clang 20
- GCC 13
- Ninja
- CMake 4.0.3
- TT-XLA additionally requires the following libraries:
sudo apt install protobuf-compiler libprotobuf-dev
sudo apt install ccache
sudo apt install libnuma-dev
sudo apt install libhwloc-dev
sudo apt install libboost-all-dev
sudo apt install libnsl-dev
Step 2: Building the TT-MLIR Toolchain
- Before compiling TT-XLA, the TT-MLIR toolchain needs to be built:
- Clone the tt-mlir repo.
- Follow the TT-MLIR build instructions to set up the environment and build the toolchain.
- After building the toolchain, set the following environment variables:
| Variable | Required | Description |
|---|---|---|
| TTMLIR_TOOLCHAIN_DIR | Yes | Path to the TT-MLIR toolchain (e.g., /opt/ttmlir-toolchain/) |
| TTXLA_LOGGER_LEVEL | No | Set to DEBUG or VERBOSE for detailed logs |
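For example, with the toolchain installed in the default location shown above, you might set:
export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/
export TTXLA_LOGGER_LEVEL=DEBUG   # optional; enables detailed logs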
Step 3: Building TT-XLA
Make sure you are not in the TT-MLIR build directory, and that you are in the location where you want TT-XLA installed.
- Clone TT-XLA:
git clone https://github.com/tenstorrent/tt-xla.git
- Navigate into the TT-XLA folder:
cd tt-xla
- Initialize third-party submodules:
git submodule update --init --recursive
- Run the following set of commands to build TT-XLA (this builds the PJRT plugin and installs it into the venv):
source venv/activate
cmake -G Ninja -B build  # -DCMAKE_BUILD_TYPE=Debug in case you want a debug build
cmake --build build
- To verify that everything is working correctly, run the following command:
python -c "import jax; print(jax.devices('tt'))"
The command should output all available TT devices, e.g.:
[TTDevice(id=0, arch=Wormhole_b0)]
- (Optional) If you want to build the TT-XLA wheel, run the following commands:
cd python_package
python setup.py bdist_wheel
The above command outputs a self-contained python_package/dist/pjrt_plugin_tt*.whl file. To install the created wheel, run:
pip install dist/pjrt_plugin_tt*.whl
The wheel has the following structure:
pjrt_plugin_tt/        # PJRT plugin package
|-- __init__.py
|-- pjrt_plugin_tt.so  # PJRT plugin binary
|-- tt-metal/          # tt-metal runtime dependencies (kernels, riscv compiler/linker, etc.)
`-- lib/               # shared library dependencies (tt-mlir, tt-metal)
jax_plugin_tt/         # Thin JAX wrapper
`-- __init__.py        # imports and sets up pjrt_plugin_tt for XLA
torch_plugin_tt/       # Thin PyTorch/XLA wrapper
`-- __init__.py        # imports and sets up pjrt_plugin_tt for PyTorch/XLA
It contains a custom Tenstorrent PJRT plugin (pjrt_plugin_tt.so) and its dependencies (tt-mlir and tt-metal). Additionally, there are thin wrappers for JAX (jax_plugin_tt) and PyTorch/XLA (torch_plugin_tt) that import the PJRT plugin and set it up for use with the respective frameworks.
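Beyond the one-line device query, a slightly larger smoke test can confirm that jitted computations actually execute on the device. This is a minimal sketch, assuming the build venv is active; the function and shapes are arbitrary:
import jax
import jax.numpy as jnp

# Pick the first Tenstorrent device exposed through the PJRT plugin.
tt_device = jax.devices("tt")[0]

# jit tracing produces a StableHLO graph, which TT-XLA hands to TT-MLIR.
@jax.jit
def f(x):
    return x @ x.T + 1.0

# Place the input on the TT device so the jitted computation runs there.
x = jax.device_put(jnp.ones((32, 32), dtype=jnp.float32), tt_device)
print(f(x))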
Testing
The TT-XLA repo contains various tests in the tests directory. To run an individual test, pytest -svv is recommended so that all potential error messages are captured. Multi-chip tests can run only on specific Tenstorrent hardware, so they are organized into folders named after the Tenstorrent cards/systems they run on. For example, pytest -v tests/jax/multi_chip/n300 can only be run on a system with an n300 Tenstorrent card. Single-chip tests can be run on any system with the command pytest -v tests/jax/single_chip.
Common Build Errors
- Building TT-XLA requires clang-20. Please make sure that clang-20 is installed on the system and that clang/clang++ link to the correct versions of the respective tools.
- Please also see the TT-MLIR docs for common build errors.
Pre-commit
Pre-commit installs a git hook in the local repository so that linting is checked and applied on every git commit. Install it from the root of the repository using:
source venv/activate
pre-commit install
If you have already committed something locally before installing the pre-commit hooks, you can run this command to check all files:
pre-commit run --all-files
For more information, please visit pre-commit.
Where to Go Next
- Try more examples in the TT-XLA examples directory
- Learn about Improving Model Performance
- Explore Code Generation to convert models into standalone code