Hardware: n150, n300, T3000, p100, p150, p300c, Galaxy | Time: 10 min | Status: Draft

JAX and PyTorch/XLA on Tenstorrent

The venv-forge / tt-forge-venv environment ships JAX, torch-xla, and the TT PJRT plugin pre-installed. There is no installation step — just activate and start computing.

TT-QuietBox® 2 users: All four p300c chips appear as TT devices (jax.devices() returns four entries). jax.pmap can then spread work across them, one slice of the leading array axis per chip.


Activate the environment

source ~/tt-forge-venv/bin/activate

Can't find ~/tt-forge-venv? Developer images put the forge env at /opt/venv-forge and symlink it to ~/tt-forge-venv automatically. If you're on a system where only one path exists, create the link yourself:

# /opt/venv-forge exists but ~/tt-forge-venv doesn't:
ln -s /opt/venv-forge ~/tt-forge-venv

# ~/tt-forge-venv exists but /opt/venv-forge doesn't (needs sudo):
sudo ln -s ~/tt-forge-venv /opt/venv-forge
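
If you would rather check from Python which of the two locations is usable before creating a link, the lookup can be sketched as below. This is a minimal sketch: find_forge_venv and DEFAULT_CANDIDATES are hypothetical names introduced here, not part of any TT tooling, and the only assumption is that a usable environment contains bin/activate.

```python
from pathlib import Path

# The two locations named in this lesson. Both names below are illustrative.
DEFAULT_CANDIDATES = ["~/tt-forge-venv", "/opt/venv-forge"]


def find_forge_venv(candidates=DEFAULT_CANDIDATES):
    """Return the first candidate that contains bin/activate, else None."""
    for candidate in candidates:
        path = Path(candidate).expanduser()
        if (path / "bin" / "activate").is_file():
            return path
    return None


if __name__ == "__main__":
    found = find_forge_venv()
    print(f"forge env: {found}" if found else "no forge env found")
```

If this reports no environment at either path, the symlink commands above will not help and the image itself needs checking.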

Note: The PJRT plugin requires tt_torch to be imported before jax so the TT shared libraries are loaded first. The verify command handles this automatically. If you call import jax without importing tt_torch first, JAX will fall back to CPU.
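
The import-order rule in the note is easy to break by accident in a larger script. The guard below is a hedged sketch of one way to detect that: import_order_ok is a hypothetical helper, and it assumes that sys.modules keeps modules in the order they were first imported (true for a normal Python 3.7+ session where nobody deletes and re-adds entries).

```python
import sys


def import_order_ok(modules=None):
    """Return True if it is still safe to use jax on a TT system.

    Safe means: jax has not been imported yet, or tt_torch was
    imported before jax. `modules` defaults to sys.modules.
    """
    names = list(sys.modules if modules is None else modules)
    if "jax" not in names:
        return True            # nothing imported yet, still time to fix the order
    if "tt_torch" not in names:
        return False           # jax loaded without tt_torch: expect CPU fallback
    return names.index("tt_torch") < names.index("jax")


if __name__ == "__main__":
    print("import order ok:", import_order_ok())
```

Calling it at the top of a script, before any JAX work, turns a silent CPU fallback into an explicit check.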

▶ Activate Forge Environment
source ~/tt-forge-venv/bin/activate && python3 -c "\n${JAX_DEVICE_CHECK_PY}\n"

Expected output:

TT devices: [TtDevice(id=0)]          # n150 / p300c
# or
TT devices: [TtDevice(id=0), TtDevice(id=1), TtDevice(id=2), TtDevice(id=3)]   # TT-QuietBox 2

▶ Check TT Devices
tt-smi


JAX — 30 seconds to tensor on silicon

JAX dispatches to TT hardware automatically via the PJRT plugin. No device placement code needed:

import tt_torch  # noqa: F401  (must come before jax so the TT shared libraries load first)
import jax
import jax.numpy as jnp

# Create arrays — they live on TT hardware
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))

# This runs on your TT chip
c = a @ b
print(c.shape)            # (1024, 1024)
print(c.devices())        # {TtDevice(id=0)}
print(c[0, 0])            # 1024.0

▶ Run JAX Quickstart
source ~/tt-forge-venv/bin/activate && python3 -c \


JIT compilation

@jax.jit compiles the function into an XLA program the first time it runs, then caches it. Subsequent calls hit the compiled path:

import tt_torch  # noqa: F401  (must come before jax so the TT shared libraries load first)
import jax
import jax.numpy as jnp

@jax.jit
def scaled_matmul(A, B, scale):
    return scale * (A @ B)

A = jnp.ones((256, 256))
B = jnp.ones((256, 256))

# First call: compiles + runs
result = scaled_matmul(A, B, 2.0)

# Subsequent calls: cached compiled kernel, fast
result = scaled_matmul(A, B, 3.0)
print(result[0, 0])       # 768.0
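
To see the compile-once behavior directly, you can count how often JAX traces the function. The sketch below relies on the fact that Python side effects inside a jitted function run only during tracing. trace_count is a name introduced here for illustration. The block has no TT-specific code, so it runs on whatever backend JAX selects; on a TT system, add import tt_torch above import jax as described in the note on import order.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, incremented only while JAX traces


@jax.jit
def scaled_matmul(A, B, scale):
    global trace_count
    trace_count += 1  # Python side effect: skipped on cached calls
    return scale * (A @ B)


A = jnp.ones((256, 256))
B = jnp.ones((256, 256))

scaled_matmul(A, B, 2.0)   # traces and compiles
scaled_matmul(A, B, 3.0)   # same shapes and dtypes: reuses the compiled program
traces_same_shape = trace_count

small = jnp.ones((128, 128))
scaled_matmul(small, small, 2.0)   # new input shape: traced again
traces_new_shape = trace_count

print(traces_same_shape, traces_new_shape)
```

Changing the value of scale does not trigger a new trace, but changing an input shape or dtype does, so keeping shapes fixed avoids repeated compilation.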

Transformer attention on TT hardware

A minimal scaled dot-product self-attention function (single head, no learned projections), the core operation of transformer-based LLMs:

import tt_torch  # noqa: F401  (must come before jax so the TT shared libraries load first)
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    """Scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ V

attention_jit = jax.jit(attention)

seq_len, d_model = 64, 128
Q = jnp.ones((seq_len, d_model))
K = jnp.ones((seq_len, d_model))
V = jnp.ones((seq_len, d_model))

out = attention_jit(Q, K, V)
print(out.shape)          # (64, 128)
print(out.devices())      # {TtDevice(id=0)}
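
The function above computes attention for a single head. A multi-head variant splits the feature dimension into independent heads, applies the same formula per head, and concatenates the results. The sketch below shows only that split and merge: it omits the learned Q/K/V and output projections a real transformer layer has, multi_head_attention and num_heads are names introduced here, and it has not been validated on TT hardware. On a TT system, add import tt_torch above import jax.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(Q, K, V, num_heads):
    """Scaled dot-product attention applied independently per head."""
    seq_len, d_model = Q.shape
    if d_model % num_heads:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads

    def split(x):
        # (seq, d_model) -> (heads, seq, d_head)
        return x.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(Q), split(K), split(V)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)   # (heads, seq, seq)
    weights = jax.nn.softmax(scores, axis=-1)
    per_head = weights @ v                                  # (heads, seq, d_head)
    # (heads, seq, d_head) -> (seq, d_model)
    return per_head.transpose(1, 0, 2).reshape(seq_len, d_model)


# num_heads controls reshapes, so it must be a static (compile-time) argument.
mha_jit = jax.jit(multi_head_attention, static_argnames="num_heads")

seq_len, d_model = 64, 128
Q = jnp.ones((seq_len, d_model))
K = jnp.ones((seq_len, d_model))
V = jnp.ones((seq_len, d_model))

out = mha_jit(Q, K, V, num_heads=8)
print(out.shape)
```

With all-ones inputs every attention weight is uniform, so the output is all ones as well, which makes a convenient sanity check.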

Multi-device with pmap (TT-QuietBox 2 / n300 / T3000)

jax.pmap maps a function over a leading batch dimension, one slice per device. On TT-QuietBox 2 with four p300c chips this uses all four in parallel:

import tt_torch  # noqa: F401  (must come before jax so the TT shared libraries load first)
import jax
import jax.numpy as jnp

devices = jax.devices()
n = len(devices)
print(f"Running across {n} TT device(s)")

# Replicate computation across all chips
@jax.pmap
def matmul_per_device(A):
    return A @ A.T

# Leading axis = number of devices
A = jnp.ones((n, 512, 512))
result = matmul_per_device(A)

print(result.shape)       # (4, 512, 512) on TT-QuietBox 2
print(result.sharding)    # shows per-device placement
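
In practice the data usually arrives as one flat batch rather than already shaped with a per-device leading axis. The sketch below shows one way to reshape a flat batch for pmap and flatten the result again. rows_per_device and row_sums are illustrative names, and the code assumes the batch size is an exact multiple of the device count. It also runs on a single device (including a CPU fallback); on a TT system, add import tt_torch above import jax.

```python
import jax
import jax.numpy as jnp

n = len(jax.devices())          # 4 on TT-QuietBox 2, 1 on a single-chip board
rows_per_device = 2             # illustrative per-device batch size
features = 4

# One flat batch of n * rows_per_device rows.
global_batch = jnp.arange(n * rows_per_device * features, dtype=jnp.float32)
global_batch = global_batch.reshape(n * rows_per_device, features)

# pmap needs the leading axis to equal the device count.
sharded = global_batch.reshape(n, rows_per_device, features)


@jax.pmap
def row_sums(x):
    # x is one device's slice: (rows_per_device, features)
    return x.sum(axis=-1)


per_device = row_sums(sharded)          # (n, rows_per_device)
flat = per_device.reshape(-1)           # back to one entry per input row

print(flat.shape)
```

If the batch size is not a multiple of the device count, the reshape fails, so pad or trim the batch first.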

▶ Run Multi-Device pmap Demo
source ~/tt-forge-venv/bin/activate && python3 -c \


PyTorch/XLA — PyTorch models on TT silicon

torch-xla is also pre-installed. Use xm.xla_device() to get the TT device and .to(device) to place tensors there — standard PyTorch idiom:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(f"TT device: {device}")

# PyTorch tensors on TT hardware
x = torch.randn(256, 256).to(device)
y = torch.randn(256, 256).to(device)

z = x @ y
xm.mark_step()           # flush the XLA graph

print(z.shape)            # torch.Size([256, 256])
print(z.device)           # xla:0

▶ Run PyTorch/XLA Demo
source ~/tt-forge-venv/bin/activate && python3 -c \

PyTorch model inference

import torch
import torch_xla.core.xla_model as xm
import torchvision.models as models

device = xm.xla_device()

# Standard torchvision model — no code changes needed
model = models.mobilenet_v2(weights="DEFAULT").to(device)
model.eval()

x = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(x)
    xm.mark_step()

print(output.shape)       # torch.Size([1, 1000])

Note: torch-xla (without forge.compile()) runs models via the XLA JIT path. For the full TT-Forge compiler pipeline with MLIR optimization, see the TT-Forge Image Classification lesson.


Hardware configuration

Wormhole and Blackhole® chips are configured identically at the JAX API level. jax.devices() returns one entry per chip, regardless of board type.

| Hardware | jax.devices() | Notes |
|---|---|---|
| n150 | [TtDevice(id=0)] | Single Wormhole chip |
| n300 | [TtDevice(id=0), TtDevice(id=1)] | Two Wormhole chips |
| T3000 | [TtDevice(id=0..7)] | Eight Wormhole chips |
| p300c | [TtDevice(id=0)] | Single Blackhole chip |
| TT-QuietBox 2 | [TtDevice(id=0..3)] | Four independent p300c chips |
| Galaxy | [TtDevice(id=0..31)] | 32 Wormhole chips |

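Set TT_METAL_ARCH_NAME before activating the environment if it isn't already set. Export only the one value that matches your hardware; running both lines leaves the second value in effect: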

export TT_METAL_ARCH_NAME=blackhole   # p300c / TT-QuietBox 2 / p150
export TT_METAL_ARCH_NAME=wormhole_b0 # n150 / n300 / T3000 / Galaxy
source ~/tt-forge-venv/bin/activate
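
When scripts are launched from Python rather than a shell, the same choice can be made programmatically. The sketch below encodes the board-to-architecture mapping from the two export lines above. ARCH_BY_BOARD and arch_env are hypothetical names introduced here, and boards not listed in those export lines are deliberately left out.

```python
import os

# Mapping taken from the export lines above; names are illustrative.
ARCH_BY_BOARD = {
    "p150": "blackhole",
    "p300c": "blackhole",
    "TT-QuietBox 2": "blackhole",
    "n150": "wormhole_b0",
    "n300": "wormhole_b0",
    "T3000": "wormhole_b0",
    "Galaxy": "wormhole_b0",
}


def arch_env(board, env=None):
    """Return a copy of env with TT_METAL_ARCH_NAME set if it was missing."""
    env = dict(os.environ if env is None else env)
    # setdefault keeps a value the user already exported.
    env.setdefault("TT_METAL_ARCH_NAME", ARCH_BY_BOARD[board])
    return env


if __name__ == "__main__":
    print(arch_env("n300", {})["TT_METAL_ARCH_NAME"])
```

The returned dictionary can be passed as the env argument of subprocess.run when starting a script inside the forge environment.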

Run the official TT-Forge demos

The tt-forge repo has validated GPT-2, ALBERT, ResNet, and OPT demos using JAX/Flax and PyTorch/XLA:

git clone https://github.com/tenstorrent/tt-forge.git ~/tt-forge
cd ~/tt-forge/demos/tt-xla/nlp/jax
source ~/tt-forge-venv/bin/activate
pip install -r requirements.txt
python gpt_demo.py

Expected output:

Model Variant: GPT2Variant.BASE
Prompt: Gravity Gravity Gravity Gravity Gravity
Next token: ' Gravity' (id: 24532)
Probability: 0.9876

Other demos in ~/tt-forge/demos/tt-xla/:

| Demo | Path | What it runs |
|---|---|---|
| GPT-2 | nlp/jax/gpt_demo.py | GPT-2 Base/Medium/Large/XL, next-token prediction |
| ALBERT | nlp/jax/albert_demo.py | ALBERT classification |
| OPT | nlp/jax/opt_demo.py | Meta OPT language model |
| ResNet | cnn/ | Image classification with JAX/Flax |

▶ Clone and Run TT-Forge Demos
git clone https://github.com/tenstorrent/tt-forge.git ~/tt-forge 2>/dev/null || (cd ~/tt-forge && git pull origin main)


What you just ran

venv-forge (pre-installed)
  pjrt_plugin_tt ─── connects JAX/torch-xla to TT hardware
  jax 0.7.1      ─── framework, JIT, pmap
  torch-xla 2.9  ─── PyTorch XLA backend (TT-patched)

One activation command → tensors on silicon.

No new venv, no pip install, no Python version change, no library compilation.


Next steps