Getting Started
This document walks you through how to set up TT-XLA. TT-XLA is a front end for TT-Forge that is primarily used to ingest JAX models via jit compilation, providing a StableHLO (SHLO) graph to the TT-MLIR compiler. TT-XLA leverages PJRT to integrate JAX, TT-MLIR, and Tenstorrent hardware. Please see this blog post for more information about the PJRT project. This project started as a fork of iree-pjrt, but has since been refactored and diverged.
NOTE: Currently, only Tenstorrent Nebula boards are supported.
This is the main Getting Started page. There are two additional Getting Started pages depending on what you want to do. They are all described here, with links provided to each.
The following topics are covered:
- Setup Options
- Configuring Hardware
- Installing a Wheel and Running an Example
- Other Setup Options
- Where to Go Next
NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.
Setup Options
TT-XLA can be used to run JAX models on Tenstorrent's AI hardware. Because TT-XLA is open source, you can also develop and add features to it. Setup instructions differ based on the task. You have the following options, listed in order of difficulty:
- Installing a Wheel and Running an Example - You should choose this option if you want to run models.
- Using a Docker Container to Run an Example - Choose this option if you want to keep the environment for running models separate from your existing environment.
- Building from Source - This option is best if you want to develop TT-XLA further. It's a more complex process you are unlikely to need if you want to stick with running a model.
Configuring Hardware
Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, this section of the walkthrough shows you how to do a quick setup using TT-Installer.
- Configure your hardware with TT-Installer using the Software Installation section here.
- Reboot your machine.
- Make sure hugepages are enabled:
sudo systemctl enable --now 'dev-hugepages\x2d1G.mount'
sudo systemctl enable --now tenstorrent-hugepages.service
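To confirm the hugepages setup took effect after the reboot, you can use standard Linux checks (generic commands, not part of TT-Installer):
systemctl status tenstorrent-hugepages.service
mount | grep hugepages-1G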
- After you run the TT-Installer script, complete the reboot, and set up hugepages, make sure you activate the virtual environment it sets up:
source ~/.tenstorrent-venv/bin/activate
- After your environment is running, check that everything is configured by typing the following:
tt-smi
You should see the Tenstorrent System Management Interface. It allows you to view real-time stats, diagnostics, and health info about your Tenstorrent device.

Installing a Wheel and Running an Example
To install a wheel and run an example model, do the following:
- Make sure you are in an active virtual environment. This walkthrough uses the same environment you activated to look at TT-SMI in the Configuring Hardware section. If you are using multiple TT-Forge front ends to run models, you may want to set up a separate virtual environment instead. For example:
python3 -m venv .xla-venv
source .xla-venv/bin/activate
- Install the wheel in your active virtual environment:
pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/
NOTE: You can pull pre-releases (these may not be stable, so proceed with caution) by adding the --pre flag directly after pip install. You can also choose a wheel from the nightly release page.
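To quickly confirm the plugin installed correctly, you can run the same device check used later in the build-from-source instructions (this assumes JAX is available in the same environment):
python -c "import jax; print(jax.devices('tt'))"
If the wheel is set up correctly, this prints the available TT devices.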
- You are now ready to try running a model. Navigate to the section of the TT-Forge repo that contains TT-XLA demos.
- For this walkthrough, the demo in the gpt2 folder is used. The requirements.txt file in the gpt2 folder shows that flax and transformers are necessary to run the demo. Install them:
pip install flax transformers
- Download the demo.py file from the gpt2 folder to a location inside your activated virtual environment where you can run it. The demo you are about to run takes a piece of text and tries to predict the next word that logically follows.
- Run the model:
python demo.py
- If all goes well you should see the prompt "The capital of France is", the predicted next token, the probability it will occur, and a list of other ranked options that could follow instead.
Other Setup Options
If you want to keep your environment completely separate in a Docker container, or you want to develop TT-XLA further, see the Getting Started with Docker and Getting Started with Building from Source pages.
Where to Go Next
Now that you have set up the TT-XLA wheel, you can compile and run other demos. See the TT-XLA folder in the TT-Forge repo for other demos you can try.
Getting Started with Docker
This document walks you through how to set up TT-XLA using a Docker image. There are two other available options for getting started:
- Installing a Wheel - if you do not want to use Docker, and prefer to use a virtual environment by itself instead, use this method.
- Building from Source - if you plan to develop TT-XLA further, you must build from source, and should use this method.
The following topics are covered:
Configuring Hardware
Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, follow the instructions on the Getting Started page.
Setting up the Docker Container
This section walks through the installation steps for using a Docker container for your project.
To install, do the following:
- Install Docker if you do not already have it:
sudo apt update
sudo apt install docker.io -y
sudo systemctl start docker
sudo systemctl enable docker
- Test that Docker is installed:
docker --version
- Add your user to the Docker group:
sudo usermod -aG docker $USER
newgrp docker
- Run the Docker container:
docker run -it --rm \
--device /dev/tenstorrent \
-v /dev/hugepages-1G:/dev/hugepages-1G \
ghcr.io/tenstorrent/tt-xla-slim:latest
NOTE: You cannot isolate devices in containers. You must pass through all devices even if you are only using one. You can do this by passing --device /dev/tenstorrent. Do not try to pass --device /dev/tenstorrent/1 or similar, as this type of device-in-container isolation will result in fatal errors later on during execution.
- If you want to check that it is running, open a new tab with the Same Command option and run the following:
docker ps
Running Models in Docker
This section shows you how to run a model using Docker. The provided example is from the TT-Forge repo. Do the following:
- Inside your running Docker container, clone the TT-Forge repo:
git clone https://github.com/tenstorrent/tt-forge.git
- Set the path for Python:
export PYTHONPATH=/tt-forge:$PYTHONPATH
- Navigate into TT-Forge and run the following command:
git submodule update --init --recursive
- Navigate back out of the TT-Forge directory.
- Run a model. For this example, the demo.py for opt_125m is used. Similar to gpt2, this model predicts what the next word in a sentence is likely to be. The requirements.txt file shows that you need to install flax and transformers:
pip install flax transformers
- After completing installation, run the following:
python tt-forge/demos/tt-xla/opt_125m/demo.py
If all goes well, you should get an example prompt saying 'The capital of France is.' The prediction for the next term is listed, along with the probability it will occur. This is followed by a table of other likely choices.
Where to Go Next
Now that you have set up TT-XLA, you can compile and run your own models, or try some of the other demos. You can find TT-XLA demos in the TT-Forge directory.
Getting Started with Building from Source
This document describes how to build the TT-XLA project on your local machine. You must build from source if you want to develop for TT-XLA. If you only want to run models, please choose one of the following sets of instructions instead:
- Installing a Wheel and Running an Example - You should choose this option if you want to run models.
- Using a Docker Container to Run an Example - Choose this option if you want to keep the environment for running models separate from your existing environment.
The following topics are covered:
NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.
Configuring Hardware
Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, follow the instructions on the Getting Started page.
System Dependencies
TT-XLA has the following system dependencies:
- Ubuntu 22.04
- Python 3.11
- python3.11-venv
- Clang 17
- GCC 11
- Ninja
- CMake 4.0.3
Installing Python
If your system already has Python installed, make sure it is Python 3.11:
python3 --version
If not, install Python:
sudo apt install python3.11
Installing CMake 4.0.3
To install CMake 4 or higher, do the following:
- Install CMake 4.0.3:
pip install cmake==4.0.3
- Check that the correct version of CMake is installed:
cmake --version
If you see cmake version 4.0.3 you are ready for the next section.
Installing Clang 17
To install Clang 17, do the following:
- Install Clang 17:
wget https://apt.llvm.org/llvm.sh
chmod u+x llvm.sh
sudo ./llvm.sh 17
sudo apt install -y libc++-17-dev libc++abi-17-dev
sudo ln -s /usr/bin/clang-17 /usr/bin/clang
sudo ln -s /usr/bin/clang++-17 /usr/bin/clang++
- Check that Clang 17 has selected GCC 11 as its GCC installation:
clang -v
- Look for the line that starts with Selected GCC installation:. If it shows something other than GCC 11, and you do not see GCC 11 listed as an option, please install GCC 11 using:
sudo apt-get install gcc-11 lib32stdc++-11-dev lib32gcc-11-dev
- If you see GCC 12 installed and listed as the default choice, uninstall it with:
sudo rm -rf /usr/bin/../lib/gcc/x86_64-linux-gnu/12
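To confirm GCC 11 is present and that Clang now selects it, you can re-run the standard checks:
gcc-11 --version
clang -v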
Installing Ninja
To install Ninja, do the following:
sudo apt install ninja-build
Installing OpenMPI
To install OpenMPI, do the following:
sudo wget -q https://github.com/dmakoviichuk-tt/mpi-ulfm/releases/download/v5.0.7-ulfm/openmpi-ulfm_5.0.7-1_amd64.deb -O /tmp/openmpi-ulfm.deb && sudo apt install /tmp/openmpi-ulfm.deb
Installing Additional Dependencies
TT-XLA additionally requires the following libraries:
sudo apt install protobuf-compiler libprotobuf-dev
sudo apt install ccache
sudo apt install libnuma-dev
sudo apt install libhwloc-dev
sudo apt install libboost-all-dev
Build Process
TT-XLA integration with the TT-MLIR compiler is still in progress. Currently TT-XLA depends on the TT-MLIR toolchain to build from source. This build flow provides an easy way to experiment with TT-XLA, StableHLO, and the TT-MLIR infrastructure. The build process will be updated in the future to enhance the user experience.
Building the TT-MLIR Toolchain
Before compiling TT-XLA, the TT-MLIR toolchain needs to be built:
- Clone the tt-mlir repo.
- Follow the TT-MLIR build instructions to set up the environment and build the toolchain.
Building TT-XLA
Before running these commands to build TT-XLA, please ensure that the environment variable TTMLIR_TOOLCHAIN_DIR is set to point to the TT-MLIR toolchain directory created above as part of the TT-MLIR environment setup (for example export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/). You can also set export LOGGER_LEVEL=DEBUG in order to enable debug logs, or export LOGGER_LEVEL=VERBOSE to enable even more verbose logs like printing intermediate IR in compiler passes. To build TT-XLA do the following:
- Make sure you are not in the TT-MLIR build directory, and that you are in the location where you want TT-XLA to be installed.
- Clone TT-XLA:
git clone https://github.com/tenstorrent/tt-xla.git
- Navigate into the TT-XLA folder:
cd tt-xla
- Initialize third-party submodules:
git submodule update --init --recursive
- Run the following set of commands to build TT-XLA (this will build the PJRT plugin and install it into venv):
source venv/activate
cmake -G Ninja -B build # -DCMAKE_BUILD_TYPE=Debug in case you want debug build
cmake --build build
- To verify that everything is working correctly, run the following command:
python -c "import jax; print(jax.devices('tt'))"
The command should output all available TT devices, e.g. [TTDevice(id=0, arch=Wormhole_b0)]
- (optional) If you want to build the TT-XLA wheel, run the following command:
cd python_package
python setup.py bdist_wheel
The above command outputs a python_package/dist/pjrt_plugin_tt*.whl file which is self-contained. To install the created wheel, run:
pip install dist/pjrt_plugin_tt*.whl
The wheel has the following structure:
pjrt_plugin_tt/ # PJRT plugin package
|-- __init__.py
|-- pjrt_plugin_tt.so # PJRT plugin binary
|-- tt-metal/ # tt-metal runtime dependencies (kernels, riscv compiler/linker, etc.)
`-- lib/ # shared library dependencies (tt-mlir, tt-metal)
jax_plugin_tt/ # Thin JAX wrapper
`-- __init__.py # imports and sets up pjrt_plugin_tt for XLA
torch_plugin_tt # Thin PyTorch/XLA wrapper
`-- __init__.py # imports and sets up pjrt_plugin_tt for PyTorch/XLA
It contains a custom Tenstorrent PJRT plugin (pjrt_plugin_tt.so) and its dependencies (tt-mlir and tt-metal). Additionally, there are thin wrappers for JAX (jax_plugin_tt) and PyTorch/XLA (torch_plugin_tt) that import the PJRT plugin and set it up for use with the respective frameworks.
Testing
The TT-XLA repo contains various tests in the tests directory. To run an individual test, pytest -svv is recommended in order to capture all potential error messages down the line. Multi-chip tests can be run only on specific Tenstorrent hardware, therefore these tests are structured in folders named by the Tenstorrent cards/systems they can be run on. For example, you can run pytest -v tests/jax/multi_chip/n300 only on a system with an n300 Tenstorrent card. Single-chip tests can be run on any system with the command pytest -v tests/jax/single_chip.
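For example, the commands described above look like this in practice (the single-test path is a placeholder; substitute the test you care about):
pytest -v tests/jax/single_chip
pytest -v tests/jax/multi_chip/n300          # requires an n300 system
pytest -svv tests/jax/single_chip/<path_to_test_file>::<test_name>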
Common Build Errors
- Building TT-XLA requires clang-17. Please make sure that clang-17 is installed on the system and clang/clang++ links to the correct version of the respective tools.
- Please also see the TT-MLIR docs for common build errors.
Pre-commit
Pre-commit applies a git hook to the local repository such that linting is checked and applied on every git commit action. Install it from the root of the repository using:
source venv/activate
pre-commit install
If you have already committed something locally before installing the pre-commit hooks, you can run this command to check all files:
pre-commit run --all-files
For more information please visit pre-commit.
Test Infra
The test infra consists of the main "tester" classes and a few helper ones. Its main goal is to make test writing easy.
Here is a brief class diagram of the infra:

Op and Graph Tests
The op tester exposes easy-to-use functions:
run_op_test(...)
run_op_test_with_random_inputs(...)
They wrap the instantiation of the OpTester and all the underlying complexity. You just need to pass the op (a Python function) you want to test to one of these functions, like this:
def test_add(x_shape: tuple, y_shape: tuple):
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)

    run_op_test_with_random_inputs(add, [x_shape, y_shape])
and that's it.
GraphTester is at the moment identical to OpTester, and it too exposes
run_graph_test(...)
run_graph_test_with_random_inputs(...)
which are meant to be used in the same way as for op tests.
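For illustration, a graph test looks just like the op test above, only with a multi-op function; this sketch assumes the same imports (jax, jnp) and infra helpers:
def test_multiply_add(x_shape: tuple, y_shape: tuple):
    def multiply_add(x: jax.Array, y: jax.Array) -> jax.Array:
        prod = jnp.multiply(x, y)
        return jnp.add(prod, y)

    run_graph_test_with_random_inputs(multiply_add, [x_shape, y_shape])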
Model Tests
Models are tested by inheriting one of the *ModelTester classes and overriding the required methods. Please read the docstring of the class you want to inherit for more information.
Jax Model Example
First, you define a model:
class MNISTMLPModel(nn.Module):
    hidden_sizes: tuple[int]

    @nn.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))

        for h in self.hidden_sizes:
            x = nn.Dense(features=h)(x)
            x = nn.relu(x)

        x = nn.Dense(features=10)(x)
        x = nn.softmax(x)

        return x
Then you define a tester by inheriting JaxModelTester:
class MNISTMLPTester(JaxModelTester):
    def __init__(
        self,
        hidden_sizes: Sequence[int],
        comparison_config: ComparisonConfig = ComparisonConfig(),
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        self._hidden_sizes = hidden_sizes
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        return MNISTMLPModel(self._hidden_sizes)

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> Sequence[jax.Array]:
        key = jax.random.PRNGKey(37)
        img = jax.random.normal(key, (4, 28, 28, 1))  # B, H, W, C
        # Channels is 1 as MNIST is in grayscale.
        return img

    # @override
    def _get_forward_method_args(self):
        inp = self._get_input_activations()
        parameters = self._model.init(jax.random.PRNGKey(42), inp)
        return [parameters, inp]
Finally, you run the test:
@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
    return MNISTMLPTester(request.param)


@pytest.mark.parametrize(
    "inference_tester", [(256, 128, 64)], indirect=True, ids=lambda val: f"{val}"
)
def test_mnist_mlp_inference(inference_tester: MNISTMLPTester):
    inference_tester.test()
Model Auto-Discovery Tests
Overview
- What: A pytest-based runner that auto-discovers Torch models from tt-forge-models and generates tests for inference and training across parallelism modes.
- Why: Standardize model testing, reduce bespoke tests in repos, and scale coverage as models are added or updated.
- Scope: Discovers loader.py under <model>/pytorch/ in third_party/tt_forge_models, queries variants, and runs each combination of:
  - Run mode: inference, training
  - Parallelism: single_device, data_parallel, tensor_parallel
Note: Discovery currently targets PyTorch models only. JAX model auto-discovery is planned.
Prerequisites
- A working TT-XLA development environment, built and ready to run tests, with pytest installed.
- third_party/tt_forge_models git submodule initialized and up to date:
git submodule update --init --recursive third_party/tt_forge_models
- Device availability matching your chosen parallelism mode (e.g., multiple devices for data/tensor parallel).
- Optional internet access for per-model pip installs during test execution.
- The env-var IRD_LF_CACHE set to point to the large file cache / webserver for the s3 bucket mirror. Reach out to the team for details.
Quick start / commonly used commands
Warning: Since the number of models and variants supported here is high (1000+), it is a good idea to run with --collect-only first to see what will be discovered/collected before running non-targeted pytest commands locally.
Also, running the full matrix can collect thousands of tests and may install per-model Python packages during execution. Prefer targeted runs locally using -m, -k, or an exact node id.
Tip: Use -q --collect-only to list tests with full path shown, remove --collect-only and use -vv when running.
- List all tests without running:
pytest --collect-only -q tests/runner/test_models.py |& tee collect.log
- List only tensor-parallel expected-passing tests on n300-llmbox (remove --collect-only to run):
pytest --collect-only -q tests/runner/test_models.py -m "tensor_parallel and expected_passing and n300_llmbox" --arch n300-llmbox |& tee tests.log
- Run a specific collected test node id exactly:
pytest -vv tests/runner/test_models.py::test_all_models[llama/sequence_classification/pytorch-llama_3_2_1b-single_device-full-inference] |& tee test.log
- Validate test_config files for typos and model name changes, which is useful when making updates:
pytest -svv --validate-test-config tests/runner/test_models.py |& tee validate.log
- List all expected-passing llama inference tests for n150 (using substring -k and markers with -m):
pytest -q --collect-only -k "llama" tests/runner/test_models.py -m "n150 and expected_passing and inference" |& tee tests.log
tests/runner/test_models.py::test_all_models[deepcogito/pytorch-v1_preview_llama_3b-single_device-full-inference]
tests/runner/test_models.py::test_all_models[huggyllama/pytorch-llama_7b-single_device-full-inference]
tests/runner/test_models.py::test_all_models[llama/sequence_classification/pytorch-llama_3_8b_instruct-single_device-full-inference]
tests/runner/test_models.py::test_all_models[llama/sequence_classification/pytorch-llama_3_1_8b-single_device-full-inference]
tests/runner/test_models.py::test_all_models[llama/causal_lm/pytorch-llama_3_8b-single_device-full-inference]
tests/runner/test_models.py::test_all_models[llama/causal_lm/pytorch-llama_3_8b_instruct-single_device-full-inference]
tests/runner/test_models.py::test_all_models[llama/causal_lm/pytorch-llama_3_1_8b-single_device-full-inference]
<snip>
21/3048 tests collected (3027 deselected) in 3.53s
How discovery and parametrization work
- The runner scans third_party/tt_forge_models/**/pytorch/loader.py (the git submodule) and imports ModelLoader to call query_available_variants().
- For every discovered variant, the runner generates tests across run modes and parallelism.
- Implementation highlights:
  - Discovery and IDs: tests/runner/test_utils.py (setup_test_discovery, discover_loader_paths, create_test_entries, create_test_id_generator)
  - Main test: tests/runner/test_models.py
  - Config loading/validation: tests/runner/test_config/config_loader.py (merges YAML into Python with validation)
Test IDs and filtering
- Test ID format: <relative_model_path>-<variant_name>-<parallelism>-full-<run_mode>
- Examples: squeezebert/pytorch-squeezebert-mnli-single_device-full-inference, ...-data_parallel-full-training
- Filter by substring with -k or by markers with -m:
pytest -q -k "qwen_2_5_vl/pytorch-3b_instruct" tests/runner/test_models.py
pytest -q -m "training and tensor_parallel" tests/runner/test_models.py
Take a look at model-test-passing.json and related .json files inside .github/workflows/test-matrix-presets to see how filtering works for CI jobs.
Parallelism modes
- single_device: Standard execution on one device.
- data_parallel: Inputs are automatically batched to xr.global_runtime_device_count(); shard spec inferred on batch dim 0.
- tensor_parallel: Mesh derived from loader.get_mesh_config(num_devices); execution sharded by model dimension.
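As a point of reference, the device count that data_parallel batches against comes from the standard torch_xla runtime API; you can inspect it directly (a quick check, not part of the runner itself):
import torch_xla.runtime as xr
print(xr.global_runtime_device_count())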
Per-model requirements
- If a model provides requirements.txt next to its loader.py, the runner will:
  - Freeze the current environment
  - Install those requirements (and optional requirements.nodeps.txt with --no-deps)
  - Run tests
  - Uninstall newly added packages and restore version changes
- Environment toggles:
  - TT_XLA_DISABLE_MODEL_REQS=1 to disable install/uninstall management
  - TT_XLA_REQS_DEBUG=1 to print pip operations for debugging
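For example, to run a targeted subset while skipping the per-model requirements management, you can combine a toggle with a normal runner invocation (illustrative command; adjust the -k filter to your needs):
TT_XLA_DISABLE_MODEL_REQS=1 pytest -vv -k "resnet" tests/runner/test_models.py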
Test configuration and statuses
- Central configuration is authored as YAML in tests/runner/test_config/* and loaded/validated by tests/runner/test_config/config_loader.py (merged into Python at runtime).
- Example: tests/runner/test_config/test_config_inference_single_device.yaml for all single-device inference test tagging, and tests/runner/test_config/test_config_inference_data_parallel.yaml for data parallel inference test tagging.
- Each entry is keyed by the collected test ID and can specify:
  - Status: EXPECTED_PASSING, KNOWN_FAILURE_XFAIL, NOT_SUPPORTED_SKIP, UNSPECIFIED, EXCLUDE_MODEL
  - Comparators: required_pcc, assert_pcc, assert_allclose, allclose_rtol, allclose_atol
  - Metadata: bringup_status, reason, custom markers (e.g., push, nightly)
  - Architecture scoping: supported_archs used for filtering by CI job, and optional arch_overrides used if test_config entries need to be modified based on arch.
YAML to Python loading and validation
- The YAML files in tests/runner/test_config/* are the single source of truth. At runtime, tests/runner/test_config/config_loader.py:
  - Loads and merges all YAML fragments into a single Python dictionary keyed by collected test IDs
  - Normalizes enum-like values (accepts both names like EXPECTED_PASSING and values like expected_passing)
  - Applies --arch <archname>-specific arch_overrides when provided
  - Validates field names/types and raises helpful errors on typos or invalid values
  - Uses ruamel.yaml for parsing, which will flag duplicate mapping keys and detect duplicate test entries both within a single YAML file and across multiple YAML files. Duplicates cause validation errors with clear messages.
Model status and bringup_status guidance
Use tests/runner/test_config/* to declare intent for each collected test ID. Typical fields:
- status (from ModelTestStatus) controls filtering of tests in CI:
  - EXPECTED_PASSING: Test is green and should run in Nightly CI. Optionally set thresholds.
  - KNOWN_FAILURE_XFAIL: Known failure that should xfail; include reason and bringup_status to set them statically, otherwise they will be set dynamically at runtime.
  - NOT_SUPPORTED_SKIP: Skip on this architecture or generally unsupported; provide reason and (optionally) bringup_status.
  - UNSPECIFIED: Default for new tests; runs in Experimental Nightly until triaged.
  - EXCLUDE_MODEL: Deselect from auto-run entirely (rare; use for temporary exclusions).
- bringup_status (from BringupStatus) summarizes current health for Superset dashboard reporting: PASSED (set automatically on pass), INCORRECT_RESULT (e.g., PCC mismatch), FAILED_FE_COMPILATION (frontend compile error), FAILED_TTMLIR_COMPILATION (tt-mlir compile error), FAILED_RUNTIME (runtime crash), NOT_STARTED, UNKNOWN.
- reason: Short human-readable context, ideally with a link to a tracking issue.
- Comparator controls: prefer required_pcc; use assert_pcc=False sparingly as a temporary measure.
Examples
- Passing with a tuned PCC threshold, when the decrease is reasonable and understood:
"resnet/pytorch-resnet_50_hf-single_device-full-inference": {
    "status": ModelTestStatus.EXPECTED_PASSING,
    "required_pcc": 0.98,
}
- Known compile failure (xfail) with issue link:
"clip/pytorch-openai/clip-vit-base-patch32-single_device-full-inference": {
"status": ModelTestStatus.KNOWN_FAILURE_XFAIL,
"bringup_status": BringupStatus.FAILED_TTMLIR_COMPILATION,
"reason": "Error Message - Github issue link",
}
- If there is a minor unexpected PCC mismatch, open a ticket, decrease the threshold, and set bringup_status/reason as follows:
"wide_resnet/pytorch-wide_resnet101_2-single_device-full-inference": {
    "status": ModelTestStatus.EXPECTED_PASSING,
    "required_pcc": 0.96,
    "bringup_status": BringupStatus.INCORRECT_RESULT,
    "reason": "PCC regression after consteval changes - Github Issue Link",
}
- If there is a severe unexpected PCC mismatch, open a ticket, disable the PCC check, and set bringup_status/reason as follows:
"gpt_neo/causal_lm/pytorch-gpt_neo_2_7B-single_device-full-inference": {
    "status": ModelTestStatus.EXPECTED_PASSING,
    "assert_pcc": False,
    "bringup_status": BringupStatus.INCORRECT_RESULT,
    "reason": "AssertionError: PCC comparison failed. Calculated: pcc=-1.0000001192092896. Required: pcc=0.99 - Github Issue Link",
}
- Architecture-specific overrides (e.g., PCC thresholds, status, etc.):
"qwen_3/embedding/pytorch-embedding_8b-single_device-full-inference": {
    "status": ModelTestStatus.EXPECTED_PASSING,
    "arch_overrides": {
        "n150": {
            "status": ModelTestStatus.NOT_SUPPORTED_SKIP,
            "reason": "Too large for single chip",
            "bringup_status": BringupStatus.FAILED_RUNTIME,
        },
    },
},
Targeting architectures
- Use --arch {n150,p150,n300,n300-llmbox} on the pytest command line to enable arch_overrides resolution in config in case there are specific overrides (like PCC requirements, checking enablement, tagging) per arch.
- Tests are also marked with supported arch markers (or defaults), so you can select subsets using -m, for example:
pytest -q -m n300 --arch n300 tests/runner/test_models.py
pytest -q -m n300_llmbox --arch n300-llmbox tests/runner/test_models.py
Placeholder models (report-only)
- Placeholder models are declared in YAML at tests/runner/test_config/test_config_placeholders.yaml and list important customer ModelGroup.RED models not yet merged, typically marked with BringupStatus.NOT_STARTED. These entries are loaded using the same config loader as other YAML files.
- tests/runner/test_models.py::test_placeholder_models emits report entries with the placeholder marker; these are used for reporting on the Superset dashboard and run in tt-xla Nightly CI (typically via model-test-xfail.json).
- Be sure to remove the placeholder at the same time the real model is added to avoid duplicate reports.
CI setup
- Push/PR: A small, fast subset runs on each pull request (e.g., tests marked push). This provides quick signal without large queues.
- Nightly: The broad model matrix (inference/training across supported parallelism) runs nightly and reports to the Superset dashboard. Tests are selected via markers and tests/runner/test_config/* statuses/arch tags like ModelTestStatus.EXPECTED_PASSING.
- Experimental nightly: New or experimental models not yet promoted/tagged in tests/runner/test_config/* (typically unspecified) run separately. These do not report to Superset until promoted with proper status/markers.
Adding a new model to run in Nightly CI
Adding a model is not difficult, but it potentially involves two projects (tt-xla and tt-forge-models). If the model is already added to tt-forge-models and uplifted into tt-xla, skip steps 1-4.
- In tt-forge-models/<model>/pytorch/loader.py, implement a ModelLoader if one doesn't already exist, exposing:
  - query_available_variants() and get_model_info(variant=...)
  - load_model(...) and load_inputs(...)
  - load_shard_spec(...) (if needed) and get_mesh_config(num_devices) (for tensor parallel)
- Optionally add requirements.txt (and requirements.nodeps.txt) next to loader.py for per-model dependencies.
- Contribute the model upstream: open a PR in the tt-forge-models repository and land it (see the tt-forge-models repo: https://github.com/tenstorrent/tt-forge-models).
- Uplift the third_party/tt_forge_models submodule in tt-xla to the merged commit so the loader is discoverable. Update the submodule and commit the pointer:
git submodule update --remote third_party/tt_forge_models
git add third_party/tt_forge_models
git commit -m "Uplift tt-forge-models submodule to <version> to include <model>"
- Verify the test appears via --collect-only and run the desired flavor locally if needed.
- Add or update the corresponding entry in tests/runner/test_config/* to set status/thresholds/markers/arch support so that the model test is run in tt-xla Nightly CI. Look at existing tests for reference.
- Remove any corresponding placeholder entry from PLACEHOLDER_MODELS in test_config_placeholders.yaml if it exists.
- Locally run pytest -q --validate-test-config tests/runner/test_models.py to validate tests/runner/test_config/* updates (on-PR jobs run it too).
- Open a PR in tt-xla for the changes, consider running the full set of expected passing models on CI to qualify the tt_forge_models uplift (if it is risky), and land the PR in tt-xla main when confident in the changes.
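Put together, a typical local sequence for the uplift and validation steps above might look like this (the commit message and -k filter are placeholders):
git submodule update --remote third_party/tt_forge_models
git add third_party/tt_forge_models
git commit -m "Uplift tt-forge-models submodule to include <model>"
pytest -q --collect-only tests/runner/test_models.py -k "<model>"
pytest -q --validate-test-config tests/runner/test_models.py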
Troubleshooting
- Discovery/import errors show as: Cannot import path: <loader.py>: <error>; add per-model requirements or set TT_XLA_DISABLE_MODEL_REQS=1 to isolate issues.
- Runtime/compilation failures are recorded with a bring-up status and reason in test properties; check the test report's tags and error_message.
- Some models may be temporarily excluded from discovery; see logs printed during collection.
- Use -vv and --collect-only for detailed collection/ID debugging.
Future enhancements
- Expand auto-discovery beyond PyTorch to include JAX models
- Automate updates of tests/runner/test_config/*, potentially based on results of Nightly CI, with automatic promotion of tests from Experimental Nightly to stable Nightly.
- Broader usability improvements and workflow polish tracked in issue #1307
Reference
- tests/runner/test_models.py: main parametrized pytest runner
- tests/runner/test_utils.py: discovery, IDs, DynamicTorchModelTester
- tests/runner/requirements.py: per-model requirements context manager
- tests/runner/conftest.py: config attachment, markers, --arch, config validation
- tests/runner/test_config/*.yaml: YAML test config files (source of truth)
- tests/runner/test_config/config_loader.py: loads/merges/validates YAML into Python at runtime
- third_party/tt_forge_models/config.py: Parallelism and model metadata
Code Generation Guide
Convert JAX or PyTorch models into standalone Python or C++ source code targeting Tenstorrent hardware.
Quick Reference
| Framework | Backend Option | Output | Standalone? |
|---|---|---|---|
| PyTorch/JAX | codegen_py | Python (.py) | No (requires TT-XLA build) |
| PyTorch/JAX | codegen_cpp | C++ (.cpp, .h) | Yes |
New to code generation? Start with the Python Code Generation Tutorial for a hands-on walkthrough.
Overview
Code generation (powered by TT-Alchemist) transforms your model into human-readable source code that directly calls the TT-NN library, enabling:
- Customization - Modify generated code to add optimizations or integrate with existing infrastructure
- Model Portability - Extract models into standalone code deployable without the full framework stack
- Inspection & Debugging - Examine generated source to understand exact operations performed
- Education - Study how high-level framework operations translate to TT-NN library calls
Technical Note: Internally referred to as TT-Alchemist, EmitPy (Python generation), or EmitC (C++ generation).
Basic Usage
PyTorch
Configure code generation options before compiling your model:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Configure code generation
options = {
    "backend": "codegen_py",              # Or "codegen_cpp" for C++
    "export_path": "torch_codegen_output"  # Output directory
}
torch_xla.set_custom_compile_options(options)
# Standard PyTorch workflow
device = xm.xla_device()
model = YourModel()
model.compile(backend="tt")
model = model.to(device)
x = torch.randn(32, 32).to(device)
# Trigger code generation
output = model(x)
Output location: torch_codegen_output/ directory containing:
- ttir.mlir - TTIR intermediate representation
- *.py or *.cpp/*.h - Generated source files
JAX
Pass compile options directly to jax.jit():
import jax
from flax import nnx
def forward(graphdef, state, x):
    model = nnx.merge(graphdef, state)
    return model(x)

# JIT compile with code generation
jitted_forward = jax.jit(
    forward,
    compiler_options={
        "backend": "codegen_py",  # Or "codegen_cpp" for C++
        "export_path": "jax_codegen_output"
    }
)
# Trigger code generation
result = jitted_forward(graphdef, state, x)
Output location: jax_codegen_output/ directory containing:
- ttir.mlir - TTIR intermediate representation
- *.py or *.cpp/*.h - Generated source files
Configuration Options
Codegen Options
| Option | Type | Description |
|---|---|---|
| backend | string | Code generation target: • "codegen_py" - Generate Python code • "codegen_cpp" - Generate C++ code |
| export_path | string | Directory for generated code (created if it doesn't exist) |
| dump_inputs | bool | Whether to dump model input and parameter tensors to disk (False by default) |
Example Configurations
Python Generation:
options = {
    "backend": "codegen_py",
    "export_path": "./generated_python",
    "dump_inputs": True
}
C++ Generation:
options = {
    "backend": "codegen_cpp",
    "export_path": "./generated_cpp"
    # dump_inputs -> default False
}
Generated Output
Directory Structure
After code generation completes, your export_path directory contains:
<export_path>/
├── ttir.mlir # TTIR intermediate representation (debugging)
├── main.py/cpp # Generated Python/C++ code
├── run # Execution script
└── input_tensors/ # Directory with dumped tensors if specified by dump_inputs option
File Descriptions
ttir.mlir - Tenstorrent Intermediate Representation
- High-level representation after initial compilation
- Useful for debugging compilation issues
- Human-readable MLIR dialect
Generated Python (*.py) - Python Implementation
- Direct TT-NN API calls
- Human-readable and modifiable
- Not standalone - requires TT-XLA build to execute
- Includes run script for execution
Generated C++ (*.cpp, *.h) - C++ Implementation
- Direct TT-NN API calls
- Human-readable and modifiable
- Fully standalone - only requires TT-NN library
- Can be integrated into existing C++ projects
input_tensors/ - Serialized model inputs and parameters (created when dump_inputs: True)
- Used by the generated code to load real model inputs and weights instead of random values
Code Generation Behavior
Expected Process Flow
- ✅ Model compiles through TT-XLA pipeline
- ✅ Code generation writes files to export_path
- ⚠️ Process may terminate with an error (expected behavior)
Process Termination
Important: The process terminating after code generation is expected behavior when running through the frontend.
You'll see this error message:
Standalone solution was successfully generated. Executing codegen through the frontend is not supported yet.
Unfortunately your program will now crash :(
ERROR:root:Caught an exception when exiting the process.
RuntimeError: Bad StatusOr access: UNIMPLEMENTED: Error code: 12
This is normal! Your code was generated successfully. The error simply indicates that continuing execution through the frontend isn't currently supported.
Verifying Success
Check that code generation succeeded:
ls -la <export_path>/
You should see:
- ✅ ttir.mlir file exists
- ✅ Generated source files (.py or .cpp/.h)
- ✅ File sizes are non-zero
- ✅ (Python only) Executable run script exists
Use Cases
1. Custom Optimization
Scenario: Hand-tune generated code for specific workloads
Benefits:
- Modify operation ordering
- Adjust memory layouts
- Add custom kernel calls
Best for: Performance-critical applications, specialized hardware configurations
2. Model Deployment & Portability
Scenario: Deploy models without the full JAX/PyTorch stack
Benefits:
- Smaller deployment footprint
- Fewer dependencies
- Direct control over execution
Best for: Production environments, edge devices, embedded systems
3. Model Inspection & Debugging
Scenario: Understand what operations your model performs
Benefits:
- Examine exact TT-NN API calls
- Identify performance bottlenecks
- Understand memory access patterns
Best for: Performance optimization, debugging accuracy issues
4. Educational & Research
Scenario: Learn how high-level operations translate to hardware
Benefits:
- Study framework→hardware mapping
- Experiment with low-level optimizations
- Understand compiler transformations
Best for: Learning, research, optimization experiments
Advanced Topics
Alternative: Code Generation via Serialization
Note: Most users should use compile options (documented above). This method is provided for advanced use cases.
You can also invoke code generation by hooking into the serialization infrastructure and running TT-Alchemist directly on the results.
When to use this:
- Custom compilation workflows
- Integration with existing build systems
- Automated pipeline generation
Examples:
- PyTorch: [
examples/pytorch/custom_module.py](../../examples/pytorch/codegen/custom_module.py - JAX:
examples/jax/custom_module.py
Related Documentation
- Python Code Generation Tutorial - Step-by-step hands-on tutorial
- Getting Started Guide - Main TT-XLA setup
- Building from Source - Development setup
Next Steps
- Try the tutorial: Follow the Python Code Generation Tutorial for hands-on experience
- Experiment: Try both
codegen_pyandcodegen_cppbackends - Inspect code: Examine generated code to understand TT-NN API usage
- Customize: Modify generated code to optimize for your use case
- Deploy: Integrate generated C++ code into your applications
Questions or issues? Visit TT-XLA GitHub Issues for support.
Tutorial: Generate Python Code from Your Model
Learn how to convert PyTorch and JAX models into standalone Python code using TT-XLA's code generation feature.
What You'll Learn
By the end of this tutorial, you'll be able to:
- Generate Python code from PyTorch/JAX models using TT-XLA
- Execute the generated code on Tenstorrent hardware
- Inspect and understand the TT-NN API calls in the generated code
Time to complete: ~15 minutes
What is Code Generation?
Code generation (also called "EmitPy" or powered by "TT-Alchemist") transforms your high-level model into human-readable Python source code that directly calls the TT-NN library. This lets you inspect, modify, and deploy models without the full TT-XLA runtime.
For complete conceptual overview and all options, see the Code Generation Guide.
Prerequisites
Before starting, ensure you have:
- Access to Tenstorrent hardware (via IRD or physical device) - jump to Step 1
-
- TT-XLA Docker image: ghcr.io/tenstorrent/tt-xla/tt-xla-ird-ubuntu-22-04:latest - jump to Step 2
Step-by-Step Guide
Step 1: Reserve Hardware and Start Docker Container
Reserve Tenstorrent hardware with the TT-XLA Docker image:
ird reserve --docker-image ghcr.io/tenstorrent/tt-xla/tt-xla-ird-ubuntu-22-04:latest [additional ird options]
Tip: The [additional ird options] should include your typical IRD configuration like architecture, number of chips, etc.
Step 2: Clone and Setup TT-XLA
Inside your Docker container, clone the repository:
git clone https://github.com/tenstorrent/tt-xla.git
cd tt-xla
Initialize submodules (required for dependencies):
git submodule update --init --recursive
Expected output: Git will download all third-party dependencies.
Step 3: Build TT-XLA
Set up the build environment and compile the project:
# Activate the Python virtual environment
source venv/activate
# Configure the build
cmake -G Ninja -B build
# Build the project (this may take 10-15 minutes)
cmake --build build
Debug Build: Add -DCMAKE_BUILD_TYPE=Debug to the cmake configure command if you need debug symbols.
Step 4: Run Code Generation Example
Choose your framework and run the example:
PyTorch:
python examples/pytorch/codegen/codegen_via_options_example.py
JAX:
python examples/jax/codegen/codegen_via_options_example.py
What Happens During Code Generation
Both examples configure TT-XLA with these options:
options = {
    # Code generation options
    "backend": "codegen_py",
    # Export path
    "export_path": "model",
}
The process will:
- ✅ Compile your model through the TT-XLA pipeline
- ✅ Generate Python source code in the model/ directory
- ⚠️ Terminate with an error message (this is expected behavior)
Expected terminal output:
Standalone solution was successfully generated. Executing codegen through the frontend is not supported yet. Unfortunately your program will now crash :(
ERROR:root:Caught an exception when exiting the process. Exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.11/dist-packages/torch_xla/__init__.py", line 216, in _prepare_to_exit
_XLAC._prepare_to_exit()
RuntimeError: Bad StatusOr access: UNIMPLEMENTED: Error code: 12
Don't worry! Despite the error message, your code was generated successfully. This termination is a known limitation when running code generation through the frontend.
Generated Files
Check the model/ directory for your generated code:
ls -la model/
You should see:
- main.py - Generated Python code with TT-NN API calls
- run - Executable shell script to run the generated code
- ttir.mlir - TTIR intermediate representation (for debugging)
Step 5: Generate the optimized code
You can specify different optimization options to produce more performant code. For example, you can supply the following set of options to generate optimized code.
options = {
    # Code generation options
    "backend": "codegen_py",
    # Optimizer options
    "enable_optimizer": True,
    "enable_memory_layout_analysis": True,
    "enable_l1_interleaved": False,
    # Export path
    "export_path": "model",
}
Link to other optimizer options to be added here. [#1849] TT-XLA Optimizer Docs
Step 6: Dump model inputs
By default, the generated code loads random inputs and parameters. To dump actual model inputs and parameters to disk for use during code execution, you need to run a separate command.
Note: Due to a current limitation (#1851), you must dump tensors and generate code in two separate runs. Once resolved, both can be done together.
First Run: Dump Inputs
options = {
    # Tensor dumping options
    "dump_inputs": True,
    # Export path
    "export_path": "model",
}
Second Run: Generate Code (with or without optimizer)
options = {
    # Code generation options
    "backend": "codegen_py",
    # Optimizer options (optional)
    # "enable_optimizer": True,
    # "enable_memory_layout_analysis": True,
    # "enable_l1_interleaved": False,
    # Export path
    "export_path": "model",
}
You can run these in either order. The generated code will automatically load the dumped tensors if they exist in the model/ directory.
Step 7: Execute the Generated Code
Navigate to the model directory and run the execution script:
cd model
./run
What the run Script Does
The script automatically:
- Sets up the Python environment with TT-NN dependencies
- Configures Tenstorrent hardware settings
- Executes the generated Python code
- Displays inference results
Expected output: You should see inference results printed to the console, showing your model running successfully on Tenstorrent hardware.
Next Steps
Now that you've successfully generated and executed code:
Inspect the Generated Code
Open model/main.py to see how your PyTorch/JAX operations map to TT-NN API calls:
cat model/main.py
Look for patterns like:
- Tensor allocation and initialization
- Operation implementations (matrix multiply, activation functions, etc.)
- Memory management and device synchronization
Customize the Generated Code
Try modifying operations in main.py:
- Change tensor shapes or data types
- Add print statements to debug intermediate values
- Optimize memory layouts or operation ordering
Generate C++ Code
Want C++ instead of Python? Change the backend:
options = {
    "backend": "codegen_cpp",  # Generate C++ code
    "export_path": "model_cpp",
}
The generated C++ code is fully standalone and can be integrated into existing C++ projects.
Generate resnet TTNN code using following example:
Learn More
- Code Generation Guide - Complete reference for all options and use cases
- PyTorch Example Source - Full example code
- JAX Example Source - Full example code
Summary
What you accomplished:
- ✅ Built TT-XLA from source in Docker
- ✅ Generated Python code from a PyTorch/JAX model
- ✅ Executed the generated code on Tenstorrent hardware
- ✅ Learned where to find and inspect the generated code
Key takeaways:
- Code generation creates inspectable implementations
- The process intentionally terminates after generation (current limitation)
- Generated code can be modified and optimized for your use case
Need help? Visit the TT-XLA Issues page or check the Code Generation Guide for more details.
Troubleshooting
Code Generation Fails
Symptom:
ERROR: tt-alchemist generatePython failed
Cause: Code generation process encountered an error
Solutions:
- Check export path is writable:
mkdir -p <export_path>
touch <export_path>/test && rm <export_path>/test
- Verify TTIR was generated:
ls -lh <export_path>/ttir.mlir
If ttir.mlir is missing or empty (0 bytes), compilation failed before code generation.
- Check for compilation errors: Review the full output for errors before the "generatePython failed" message.
- Try with minimal model: Test with a simple model to isolate the issue:
class MinimalModel(torch.nn.Module):
    def forward(self, x):
        return x + 1
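If the minimal model works, you can push it through the codegen path using the same options and calls shown earlier in this guide (a sketch; MinimalModel and the export path are illustrative):
import torch
import torch_xla
import torch_xla.core.xla_model as xm

class MinimalModel(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Same codegen options as in the Basic Usage section
torch_xla.set_custom_compile_options({
    "backend": "codegen_py",
    "export_path": "minimal_codegen_output",
})

device = xm.xla_device()
model = MinimalModel()
model.compile(backend="tt")
model = model.to(device)

x = torch.randn(32, 32).to(device)
output = model(x)  # triggers compilation and code generation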
Export Path Not Set
Symptom:
Compile option 'export_path' must be provided when backend is not 'TTNNFlatbuffer'
Cause: The export_path option is missing
Solution: Add export_path to your compiler options:
options = {
    "backend": "codegen_py",
    "export_path": "./output"  # ← Add this
}
Generated Code Execution Fails
Symptom: Errors when running generated Python code via ./run
Possible Causes & Solutions:
- TT-XLA not built:
cd /path/to/tt-xla
cmake --build build
- Hardware not accessible:
tt-smi # Should show your Tenstorrent devices
- Wrong hardware configuration:
- Verify generated code matches your hardware setup
- Check device IDs and chip counts
- Rebuild TT-XLA if hardware configuration changed
-
Missing dependencies:
source venv/activate # Ensure virtual environment is active
Generated C++ Code Won't Compile
Symptom: C++ compilation errors in generated code
Solutions:
-
Check TT-NN headers are available:
find /opt/ttmlir-toolchain -name "ttnn*.h" -
Verify C++ compiler version: Generated code requires C++17 or later
-
Link against TT-NN library: Ensure your build system links the TT-NN library correctly
Tools
This section covers tools available in the TT-XLA project.
Available Tools
- Explorer - Tool for exploring and analyzing models
Explorer
Explorer is an interactive GUI tool from TT-MLIR for visualizing and experimenting with model graphs (including Tenstorrent's MLIR dialects), compiling and executing your model on Tenstorrent hardware.
What is Explorer?
Explorer is a visual debugging and performance analysis tool that allows you to:
- Visualize MLIR graphs: Inspect your model graph with hierarchical visualization
- Compile and execute your model: Compile your model to Tenstorrent hardware and execute it
- Debug performance: Identify bottlenecks and see effects of optimizations on runtime performance
Building with Explorer
Explorer is only available when building TT-XLA from source. It is not included in pre-built wheels. It is disabled by default in TT-XLA. You can enable it by building with the TTXLA_ENABLE_EXPLORER CMake option:
cmake -G Ninja -B build -DCMAKE_BUILD_TYPE=Release -DTTXLA_ENABLE_EXPLORER=ON
cmake --build build
Note: Enabling Explorer also enables Tracy performance tracing (
TTMLIR_ENABLE_PERF_TRACE), which may slow down execution. For production deployments or performance benchmarking, consider building with -DTTXLA_ENABLE_EXPLORER=OFF.
Using Explorer
After building with Explorer enabled, launch the tool by running:
tt-explorer
This will start the interactive GUI for analyzing your model's compilation and execution.
Example graph to try out
module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %arg2: tensor<64x128xbf16>, %arg3: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
%0 = ttir.empty() : tensor<64x128xbf16>
%1 = "ttir.add"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
%2 = ttir.empty() : tensor<64x128xbf16>
%3 = "ttir.add"(%arg2, %arg3, %2) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
%4 = ttir.empty() : tensor<64x128xbf16>
%5 = "ttir.add"(%1, %3, %4) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
%6 = ttir.empty() : tensor<64x128xbf16>
%7 = "ttir.relu"(%5, %6) : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %7 : tensor<64x128xbf16>
}
}
Learn More
For detailed documentation on how to use Explorer, including tutorials and advanced features, see the TT-MLIR Explorer Documentation.
Explorer is based on Google's Model Explorer with added support for Tenstorrent hardware compilation and execution.