Getting Started

This document walks you through how to set up TT-XLA. TT-XLA is a front end for TT-Forge, primarily used to ingest JAX models via jit compile and provide a StableHLO (SHLO) graph to the TT-MLIR compiler. TT-XLA leverages PJRT to integrate JAX, TT-MLIR, and Tenstorrent hardware. Please see this blog post for more information about the PJRT project. This project is a fork of iree-pjrt.
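Because TT-XLA consumes the StableHLO that JAX produces during jit compilation, you can inspect that graph yourself with standard JAX APIs. The following is a minimal, hardware-independent sketch (no Tenstorrent plugin required); the exact textual output depends on your JAX version.

import jax
import jax.numpy as jnp

def add(x, y):
    return jnp.add(x, y)

x = jnp.ones((2, 2))
y = jnp.ones((2, 2))

# Lower the jitted function and print its StableHLO module; this is the kind
# of graph TT-XLA hands off to the TT-MLIR compiler.
print(jax.jit(add).lower(x, y).as_text())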

NOTE: Currently, only Tenstorrent Nebula boards are supported.

This is the main Getting Started page. There are two additional Getting Started pages depending on what you want to do. They are all described here, with links provided to each.

The following topics are covered:

NOTE: If you encounter issues, please request assistance on the TT-XLA Issues page.

Setup Options

TT-XLA can be used to run JAX models on Tenstorrent's AI hardware. Because TT-XLA is open source, you can also develop and add features to it. Setup instructions differ based on the task. You have the following options, listed in order of difficulty:

Configuring Hardware

Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, this section of the walkthrough shows you how to do a quick setup using TT-Installer.

  1. Configure your hardware with TT-Installer using the Quick Installation section here.

  2. Reboot your machine.

  3. After running this script and completing the reboot, make sure you activate the virtual environment it sets up - source ~/.tenstorrent-venv/bin/activate.

  4. After your environment is running, to check that everything is configured, type the following:

tt-smi

You should see the Tenstorrent System Management Interface. It allows you to view real-time stats, diagnostics, and health info about your Tenstorrent device.

[TT-SMI screenshot]

Installing a Wheel and Running an Example

This section walks you through downloading and installing a wheel. If you only want to run models, you can install the wheel in whatever environment you prefer.

  1. Make sure you are in an active virtual environment. This walkthrough uses the same environment you activated to look at TT-SMI in the Configuring Hardware section. If you are using multiple TT-Forge front ends to run models, you may want to set up a separate virtual environment instead. For example:
python3 -m venv .xla-venv
source .xla-venv/bin/activate
  2. Download the wheel in your active virtual environment:
pip install pjrt-plugin-tt --extra-index-url https://pypi.eng.aws.tenstorrent.com/
  3. You are now ready to try running a model. Navigate to the section of the TT-Forge repo that contains TT-XLA demos.

  4. For this walkthrough, the demo in the gpt2 folder is used. In the gpt2 folder, in the requirements.txt file, you can see that flax and transformers are necessary to run the demo. Install them:

pip install flax transformers
  5. Download the demo.py file from the gpt2 folder inside your activated virtual environment, in a place where you can run it. The demo you are about to run takes a piece of text and tries to predict the next word that logically follows.

  6. Run the model:

python demo.py
  7. If all goes well, you should see the prompt "The capital of France is", the predicted next token, the probability it will occur, and a list of other ranked options that could follow instead. You can also run your own JAX code on the device, as shown in the sketch after this list.
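The following is a minimal sketch, assuming the wheel is installed in your active virtual environment and a Tenstorrent device is visible; requesting the "tt" platform triggers the plugin registration described later in this guide.

import jax
import jax.numpy as jnp

# Requesting the "tt" platform registers the Tenstorrent PJRT plugin.
tt_device = jax.devices("tt")[0]

@jax.jit
def scaled_sum(x, y):
    return jnp.sum(x * 2.0 + y)

# Place the inputs on the Tenstorrent device and run the jitted function there.
x = jax.device_put(jnp.ones((32, 32)), tt_device)
y = jax.device_put(jnp.ones((32, 32)), tt_device)
print(scaled_sum(x, y))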

Other Setup Options

If you want to keep your environment completely separate in a Docker container, or you want to develop TT-XLA further, this section links you to the pages with those options:

Where to Go Next

Now that you have set up the TT-XLA wheel, you can compile and run other demos. See the TT-XLA folder in the TT-Forge repo for other demos you can try.

Getting Started with Docker

This document walks you through how to set up TT-XLA using a Docker image. There are two other available options for getting started:

  • Installing a Wheel - if you do not want to use Docker, and prefer to use a virtual environment by itself instead, use this method.
  • Building from Source - if you plan to develop TT-XLA further, you must build from source, and should use this method.

The following topics are covered:

Configuring Hardware

Before setup can happen, you must configure your hardware. You can skip this section if you already completed the configuration steps. Otherwise, this section of the walkthrough shows you how to do a quick setup using TT-Installer.

  1. Configure your hardware with TT-Installer using the Quick Installation section here.

  2. Reboot your machine.

  3. After running this script and completing the reboot, make sure you activate the virtual environment it sets up - source ~/.tenstorrent-venv/bin/activate.

  4. When your environment is running, to check that everything is configured, type the following:

tt-smi

You should see the Tenstorrent System Management Interface. It allows you to view real-time stats, diagnostics, and health info about your Tenstorrent device.

[TT-SMI screenshot]

Setting up the Docker Container

This section walks through the installation steps for using a Docker container for your project.

To install, do the following:

  1. Install Docker if you do not already have it:
sudo apt update
sudo apt install docker.io -y
sudo systemctl start docker
sudo systemctl enable docker
  2. Test that Docker is installed:
docker --version
  3. Add your user to the Docker group:
sudo usermod -aG docker $USER
newgrp docker
  4. Run the Docker container:
docker run -it --rm \
  --device /dev/tenstorrent \
  -v /dev/hugepages-1G:/dev/hugepages-1G \
  ghcr.io/tenstorrent/tt-forge/tt-xla-slim:latest
  5. If you want to check that it is running, open a new terminal tab (for example, using your terminal's Same Command option) and run the following:
docker ps

Running Models in Docker

This section shows you how to run a model using Docker. The provided example is from the TT-Forge repo. Do the following:

  1. Inside your running Docker container, clone the TT-Forge repo:
git clone https://github.com/tenstorrent/tt-forge.git
  2. Set the path for Python:
export PYTHONPATH=/tt-forge:$PYTHONPATH
  3. Navigate into TT-Forge and run the following command:
git submodule update --init --recursive
  4. Navigate back out of the TT-Forge directory.

  5. Run a model. Similar to gpt2, this model predicts what the next word in a sentence is likely to be. For this model, the demo.py for opt_125m is used. The requirements.txt file shows that you need to install flax and transformers:

pip install flax transformers
  6. After completing installation, run the following:
python tt-forge/demos/tt-xla/opt_125m/demo.py

If all goes well, you should get an example prompt saying 'The capital of France is.' The prediction for the next term is listed, along with the probability it will occur. This is followed by a table of other likely choices.
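If the demo cannot find a device, you can quickly confirm that JAX sees the Tenstorrent hardware from inside the container. This is a minimal sketch, assuming the tt-xla-slim image ships with the Tenstorrent PJRT plugin pre-installed:

import jax

# Should list the available Tenstorrent device(s); an error here usually means
# /dev/tenstorrent was not passed through to the container.
print(jax.devices("tt"))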

Where to Go Next

Now that you have set up TT-XLA, you can compile and run your own models, or try some of the other demos. You can find TT-XLA demos in the TT-Forge directory.

Getting Started with Building from Source

This document describes how to build the TT-XLA project on your local machine. You must build from source if you want to develop for TT-XLA. If you only want to run models, please choose one of the following sets of instructions instead:

Build Process

TT-XLA integration with the TT-MLIR compiler is still in progress. Currently, building TT-XLA from source depends on the TT-MLIR toolchain. This build flow provides an easy way to experiment with TT-XLA, StableHLO, and the TT-MLIR infrastructure. The build process will be updated in the future to improve the user experience.

TT-MLIR Toolchain

Before compiling TT-XLA, the TT-MLIR toolchain needs to be built:

TT-XLA

Before running these commands to build TT-XLA, please ensure that the environment variable TTMLIR_TOOLCHAIN_DIR is set to point to the TT-MLIR toolchain directory created above as part of the TT-MLIR environment setup (for example, export TTMLIR_TOOLCHAIN_DIR=/opt/ttmlir-toolchain/). You can also set export LOGGER_LEVEL=DEBUG to enable debug logs.

git clone git@github.com:tenstorrent/tt-xla.git
cd tt-xla
source venv/activate
cmake -G Ninja -B build # -DCMAKE_BUILD_TYPE=Debug in case you want debug build
cmake --build build

Wheel Build

To build a wheel, run:

>> cd python_package
>> python setup.py bdist_wheel

This will output a python_package/dist/pjrt_plugin_tt*.whl file, which is self-contained and can be installed using:

pip install pjrt_plugin_tt*.whl

The wheel has the following structure:

jax_plugins
`-- pjrt_plugin_tt
    |-- __init__.py
    |-- pjrt_plugin_tt.so   # Plugin itself.
    `-- tt-mlir             # Entire tt-mlir installation folder
        `-- install
            |-- include
            |   `-- ...
            |-- lib
            |   |-- libTTMLIRCompiler.so
            |   |-- libTTMLIRRuntime.so
            |   `-- ...
            `-- tt-metal    # We need to set TT_METAL_HOME to this dir when loading plugin
                |-- runtime
                |   `-- ...
                |-- tt_metal
                |   `-- ...
                `-- ttnn
                    `-- ...

It contains a custom Tenstorrent PJRT plugin (an .so file), an __init__.py file that holds a Python function for registering the PJRT plugin with JAX, and the tt-mlir installation directory. The installation directory is needed to dynamically link the TT-MLIR libraries at runtime and to resolve various tt-metal dependencies, without which the plugin does not work.

Structuring the wheel and its folders this way allows JAX to automatically register the plugin upon usage. Do the following:

>> pip install pjrt_plugin_tt*.whl
>> python
# Python console
>>> import jax
>>> tt_device = jax.devices("tt")  # this will trigger plugin registration

Testing

The TT-XLA repo contains various tests in the tests directory. To run them all, please run pytest -v tests from the project root directory. To run an individual test, pytest -svv is recommended in order to capture all potential error messages down the line.

Common Build Errors

  • Building TT-XLA requires clang-17. Please make sure that clang-17 is installed on the system and that clang/clang++ link to the correct versions of the respective tools.
  • Please also see the TT-MLIR docs for common build errors.

Pre-commit

Pre-commit applies a git hook to the local repository such that linting is checked and applied on every git commit action. Install it from the root of the repository using:

source venv/activate
pre-commit install

If you have already committed something locally before installing the pre-commit hooks, you can run this command to check all files:

pre-commit run --all-files

For more information please visit pre-commit.

Test infra

The test infra consists of the main "tester" classes and a few helper ones. Its main goal is to make test writing easy.

Here is a brief class diagram of the infra:

[infra class diagram]

Op and graph tests

The op tester exposes easy-to-use functions:

run_op_test(...)
run_op_test_with_random_inputs(...)

They wrap the instantiation of the OpTester and all the underlying complexity. Users just need to pass the op (a Python function) they want to test to one of these functions, like this:

import jax
import jax.numpy as jnp

# run_op_test_with_random_inputs is imported from the repo's test infra.

def test_add(x_shape: tuple, y_shape: tuple):
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)

    run_op_test_with_random_inputs(add, [x_shape, y_shape])

and that's it.
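If you need specific input values rather than random ones, run_op_test is the one to use. The sketch below assumes it accepts the op followed by a list of concrete inputs, mirroring run_op_test_with_random_inputs; check the function's docstring in the test infra for the exact signature.

import jax
import jax.numpy as jnp

# run_op_test is imported from the repo's test infra; the argument order here
# is an assumption based on run_op_test_with_random_inputs.

def test_add_with_fixed_inputs():
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)

    x = jnp.full((32, 32), 2.0)
    y = jnp.full((32, 32), 3.0)
    run_op_test(add, [x, y])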

GraphTester is at the moment identical to OpTester, and it too exposes

run_graph_test(...)
run_graph_test_with_random_inputs(...)

which are meant to be used in the same way as for op tests.
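For example, a graph test looks just like the op test above, except that the tested function composes several ops (the helper is assumed to come from the same test infra module as the op test helpers):

import jax
import jax.numpy as jnp

# run_graph_test_with_random_inputs is imported from the repo's test infra.

def test_multiply_add(x_shape: tuple, y_shape: tuple):
    def multiply_add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(jnp.multiply(x, y), x)

    run_graph_test_with_random_inputs(multiply_add, [x_shape, y_shape])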

Model tests

Models are tested by inheriting from one of the *ModelTester classes and overriding the required methods. Please read the docstring of the class you want to inherit from for more information.

Jax model example

First, we define a model

import jax
import flax.linen as nn

class MNISTMLPModel(nn.Module):
    hidden_sizes: tuple[int]

    @nn.compact
    def __call__(self, x: jax.Array):
        x = x.reshape((x.shape[0], -1))

        for h in self.hidden_sizes:
            x = nn.Dense(features=h)(x)
            x = nn.relu(x)

        x = nn.Dense(features=10)(x)
        x = nn.softmax(x)

        return x

Then we define a tester by inheriting JaxModelTester

from typing import Sequence

import jax
import flax.linen as nn

# JaxModelTester, ComparisonConfig, and RunMode are imported from the repo's test infra.

class MNISTMLPTester(JaxModelTester):
    def __init__(
        self,
        hidden_sizes: Sequence[int],
        comparison_config: ComparisonConfig = ComparisonConfig(),
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        self._hidden_sizes = hidden_sizes
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        return MNISTMLPModel(self._hidden_sizes)

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> Sequence[jax.Array]:
        key = jax.random.PRNGKey(37)
        img = jax.random.normal(key, (4, 28, 28, 1))  # B, H, W, C
        # Channels is 1 as MNIST is in grayscale.
        return img

    # @override
    def _get_forward_method_args(self):
        inp = self._get_input_activations()
        parameters = self._model.init(jax.random.PRNGKey(42), inp)
        return [parameters, inp]

and finally, we run the test

import pytest

@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
    return MNISTMLPTester(request.param)

@pytest.mark.parametrize(
    "inference_tester", [(256, 128, 64)], indirect=True, ids=lambda val: f"{val}"
)
def test_mnist_mlp_inference(inference_tester: MNISTMLPTester):
    inference_tester.test()