JAX and PyTorch/XLA on Tenstorrent
The venv-forge / tt-forge-venv environment ships JAX, torch-xla, and the TT PJRT plugin
pre-installed. There is no installation step — just activate and start computing.
QB2 users: All four p300c chips appear as TT devices (`jax.devices()` returns four entries). `pmap` distributes work across them automatically.
Activate the environment
```bash
source ~/tt-forge-venv/bin/activate
```
Can't find `~/tt-forge-venv`? Developer images put the forge env at `/opt/venv-forge` and symlink it to `~/tt-forge-venv` automatically. If you're on a system where only one path exists, create the link yourself:

```bash
# /opt/venv-forge exists but ~/tt-forge-venv doesn't:
ln -s /opt/venv-forge ~/tt-forge-venv

# ~/tt-forge-venv exists but /opt/venv-forge doesn't (needs sudo):
sudo ln -s ~/tt-forge-venv /opt/venv-forge
```
Note: The PJRT plugin requires `tt_torch` to be imported before `jax` so the TT shared libraries are loaded first. The verify command handles this automatically. If you call `import jax` without importing `tt_torch` first, JAX will fall back to CPU.
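A minimal device check that respects this import order (a sketch; the exact verify script shipped with the lesson may differ):

```python
import tt_torch  # must come first: loads the TT shared libraries
import jax

print("TT devices:", jax.devices())
```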
Expected output:

```
TT devices: [TtDevice(id=0)]  # N150 / p300c
# or
TT devices: [TtDevice(id=0), TtDevice(id=1), TtDevice(id=2), TtDevice(id=3)]  # QB2
```
You can also run `tt-smi` to confirm the chips are visible at the driver level.
JAX — 30 seconds to tensor on silicon
JAX dispatches to TT hardware automatically via the PJRT plugin. No device placement code needed:
```python
import tt_torch  # load the TT shared libraries before jax (see note above)
import jax
import jax.numpy as jnp

# Create arrays — they live on TT hardware
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))

# This runs on your TT chip
c = a @ b

print(c.shape)      # (1024, 1024)
print(c.devices())  # {TtDevice(id=0)}
print(c[0, 0])      # 1024.0
```
JIT compilation
@jax.jit compiles the function into an XLA program the first time it runs,
then caches it. Subsequent calls hit the compiled path:
```python
import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

@jax.jit
def scaled_matmul(A, B, scale):
    return scale * (A @ B)

A = jnp.ones((256, 256))
B = jnp.ones((256, 256))

# First call: compiles + runs
result = scaled_matmul(A, B, 2.0)

# Subsequent calls: cached compiled kernel, fast
result = scaled_matmul(A, B, 3.0)
print(result[0, 0])  # 768.0
```
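To see the compile-then-cache behavior directly, time the first and second calls (a minimal sketch; `block_until_ready()` forces JAX's asynchronous dispatch to finish before the timer stops):

```python
import time
import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x @ x).sum()

x = jnp.ones((512, 512))

t0 = time.perf_counter()
f(x).block_until_ready()  # first call: trace + compile + run
t1 = time.perf_counter()
f(x).block_until_ready()  # second call: cached compiled executable
t2 = time.perf_counter()

print(f"compile+run: {t1 - t0:.3f}s   cached: {t2 - t1:.4f}s")
```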
Transformer attention on TT hardware
A minimal self-attention block — the core of every modern LLM. The single-head function comes first; a multi-head sketch follows below:
```python
import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    """Scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ V

attention_jit = jax.jit(attention)

seq_len, d_model = 64, 128
Q = jnp.ones((seq_len, d_model))
K = jnp.ones((seq_len, d_model))
V = jnp.ones((seq_len, d_model))

out = attention_jit(Q, K, V)
print(out.shape)      # (64, 128)
print(out.devices())  # {TtDevice(id=0)}
```
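To make this multi-head, one idiomatic route is `jax.vmap` over a leading heads axis (a sketch reusing the same `attention` function; the `d_model`-sized projection and head split/merge are omitted):

```python
import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    return jax.nn.softmax(scores, axis=-1) @ V

# vmap maps the single-head function over axis 0 (the heads axis)
multi_head = jax.jit(jax.vmap(attention))

num_heads, seq_len, d_head = 8, 64, 16
Q = jnp.ones((num_heads, seq_len, d_head))
K = jnp.ones((num_heads, seq_len, d_head))
V = jnp.ones((num_heads, seq_len, d_head))

out = multi_head(Q, K, V)
print(out.shape)  # (8, 64, 16)
```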
Multi-device with pmap (QB2 / N300 / T3K)
jax.pmap maps a function over a leading batch dimension, one slice per device.
On QB2 with four p300c chips this uses all four in parallel:
```python
import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

devices = jax.devices()
n = len(devices)
print(f"Running across {n} TT device(s)")

# Replicate computation across all chips
@jax.pmap
def matmul_per_device(A):
    return A @ A.T

# Leading axis = number of devices
A = jnp.ones((n, 512, 512))
result = matmul_per_device(A)

print(result.shape)     # (4, 512, 512) on QB2
print(result.sharding)  # shows per-device placement
```
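Devices can also communicate inside a `pmap` via collectives such as `jax.lax.pmean` (a sketch; it assumes the TT plugin supports cross-device collectives on your board):

```python
from functools import partial

import tt_torch  # load the TT shared libraries before jax
import jax
import jax.numpy as jnp

n = jax.device_count()

# axis_name lets the function body refer to the mapped device axis
@partial(jax.pmap, axis_name="chips")
def mean_across_chips(x):
    return jax.lax.pmean(x, axis_name="chips")

x = jnp.arange(float(n)).reshape(n, 1)  # a different value per chip
print(mean_across_chips(x))  # every entry equals the mean over all chips
```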
PyTorch/XLA — PyTorch models on TT silicon
torch-xla is also pre-installed. Use xm.xla_device() to get the TT device
and .to(device) to place tensors there — standard PyTorch idiom:
```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(f"TT device: {device}")

# PyTorch tensors on TT hardware
x = torch.randn(256, 256).to(device)
y = torch.randn(256, 256).to(device)
z = x @ y

xm.mark_step()  # flush the XLA graph

print(z.shape)   # torch.Size([256, 256])
print(z.device)  # xla:0
```
PyTorch model inference
```python
import torch
import torch_xla.core.xla_model as xm
import torchvision.models as models

device = xm.xla_device()

# Standard torchvision model — no code changes needed
model = models.mobilenet_v2(weights="DEFAULT").to(device)
model.eval()

x = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(x)
xm.mark_step()

print(output.shape)  # torch.Size([1, 1000])
```
Note: torch-xla (without `forge.compile()`) runs models via the XLA JIT path. For the full TT-Forge (tt-forge-fe) compiler pipeline with MLIR optimization, see the TT-Forge Image Classification lesson.
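The same idiom extends to training. Here is a minimal sketch of the standard torch-xla training step (toy model, random data; backward-pass operator coverage on TT hardware is an assumption here, not something this lesson verifies):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Toy model and random data for illustration only
model = nn.Linear(64, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 64).to(device)
y = torch.randint(0, 10, (32,)).to(device)

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
xm.optimizer_step(optimizer, barrier=True)  # optimizer.step() + graph flush
print(f"loss: {loss.item():.4f}")
```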
Hardware configuration
Wormhole and Blackhole chips are configured identically at the JAX API level.
jax.devices() returns one entry per chip, regardless of board type.
| Hardware | `jax.devices()` | Notes |
|---|---|---|
| N150 | `[TtDevice(id=0)]` | Single Wormhole chip |
| N300 | `[TtDevice(id=0), TtDevice(id=1)]` | Two Wormhole chips |
| T3K | `[TtDevice(id=0..7)]` | Eight Wormhole chips |
| p300c | `[TtDevice(id=0)]` | Single Blackhole chip |
| QB2 | `[TtDevice(id=0..3)]` | Four independent p300c chips |
| Galaxy | `[TtDevice(id=0..31)]` | 32 Wormhole chips |
Set `TT_METAL_ARCH_NAME` before activating the env if it isn't already set:

```bash
export TT_METAL_ARCH_NAME=blackhole    # p300c / QB2 / P150
export TT_METAL_ARCH_NAME=wormhole_b0  # N150 / N300 / T3K / Galaxy
source ~/tt-forge-venv/bin/activate
```
Run the official tt-forge demos
The tt-forge repo has validated GPT-2, ALBERT, ResNet, and OPT demos
using JAX/Flax and PyTorch/XLA:
```bash
git clone https://github.com/tenstorrent/tt-forge.git ~/tt-forge
cd ~/tt-forge/demos/tt-xla/nlp/jax
source ~/tt-forge-venv/bin/activate
pip install -r requirements.txt
python gpt_demo.py
```
Expected output:
```
Model Variant: GPT2Variant.BASE
Prompt: Gravity Gravity Gravity Gravity Gravity
Next token: ' Gravity' (id: 24532)
Probability: 0.9876
```
Other demos in ~/tt-forge/demos/tt-xla/:
| Demo | Path | What it runs |
|---|---|---|
| GPT-2 | `nlp/jax/gpt_demo.py` | GPT-2 Base/Medium/Large/XL, next-token prediction |
| ALBERT | `nlp/jax/albert_demo.py` | ALBERT classification |
| OPT | `nlp/jax/opt_demo.py` | Meta OPT language model |
| ResNet | `cnn/` | Image classification with JAX/Flax |
What you just ran
```
venv-forge (pre-installed)
  pjrt_plugin_tt ─── connects JAX/torch-xla to TT hardware
  jax 0.7.1      ─── framework, JIT, pmap
  torch-xla 2.9  ─── PyTorch XLA backend (TT-patched)
```
One activation command → tensors on silicon.
No new venv, no pip install, no Python version change, no library compilation.
Next steps
- TT-Forge Image Classification — `forge.compile()` for PyTorch models via the full MLIR compiler pipeline
- vLLM Production — LLM serving (Qwen3, Llama)
- JAX documentation — comprehensive JAX tutorials
- tt-forge demos — validated JAX and PyTorch/XLA examples