Tutorial: Generate Python Code from Your Model
Learn how to convert PyTorch and JAX models into standalone Python code using TT-XLA’s code generation feature.
What You’ll Learn
By the end of this tutorial, you’ll be able to:
Generate Python code from PyTorch/JAX models using TT-XLA
Execute the generated code on Tenstorrent hardware
Inspect and understand the TT-NN API calls in the generated code
Time to complete: ~15 minutes
What is Code Generation?
Code generation (also called “EmitPy”, and powered by “TT-Alchemist”) transforms your high-level model into human-readable Python source code that calls the TT-NN library directly. This lets you inspect, modify, and deploy models without the full TT-XLA runtime.
For a complete conceptual overview and all available options, see the Code Generation Guide.
Prerequisites
Before starting, ensure you have:
Step-by-Step Guide
Step 1: Reserve Hardware and Start Docker Container
Reserve Tenstorrent hardware with the TT-XLA Docker image:
ird reserve --docker-image ghcr.io/tenstorrent/tt-xla/tt-xla-ird-ubuntu-24-04:latest [additional ird options]
Tip: The [additional ird options] placeholder should include your typical IRD configuration, such as architecture and number of chips.
Step 2: Clone and Set Up TT-XLA
Inside your Docker container, clone the repository:
git clone https://github.com/tenstorrent/tt-xla.git
cd tt-xla
Initialize submodules (required for dependencies):
git submodule update --init --recursive
Expected output: Git will download all third-party dependencies.
Step 3: Build TT-XLA
Set up the build environment and compile the project:
# Activate the Python virtual environment
source venv/activate
# Configure the build
cmake -G Ninja -B build
# Build the project (this may take 10-15 minutes)
cmake --build build
Debug Build: Add -DCMAKE_BUILD_TYPE=Debug to the cmake configure command if you need debug symbols.
Step 4: Run Code Generation Example
Choose your framework and run the example:
PyTorch:
python examples/pytorch/codegen/python/custom_module.py
JAX:
python examples/jax/codegen/python/custom_module.py
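You may want to switch between the two examples from a script. The following minimal sketch builds and runs the command for either framework. It assumes you run it from the root of the tt-xla checkout with the virtual environment already activated. The names `EXAMPLES`, `build_command`, and `run_example` are illustrative and are not part of TT-XLA.

```python
import subprocess
import sys

# Example scripts from this tutorial, keyed by framework (paths relative to the tt-xla root).
EXAMPLES = {
    "pytorch": "examples/pytorch/codegen/python/custom_module.py",
    "jax": "examples/jax/codegen/python/custom_module.py",
}


def build_command(framework: str) -> list:
    """Return the command that runs the code generation example for `framework`."""
    key = framework.lower()
    if key not in EXAMPLES:
        raise ValueError(f"unknown framework {framework!r}; expected one of {sorted(EXAMPLES)}")
    # sys.executable is the interpreter running this script, so an activated venv is reused.
    return [sys.executable, EXAMPLES[key]]


def run_example(framework: str, repo_root: str = ".") -> int:
    """Run the chosen example from `repo_root` and return its exit code."""
    return subprocess.run(build_command(framework), cwd=repo_root).returncode
```

For example, calling run_example("jax") from the repository root should behave like the JAX command above.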
What Happens During Code Generation
Both examples invoke code generation with a call like this:
codegen_py(
forward, graphdef, state, x, export_path="model"
)
where forward is the model’s forward function that TT-XLA will compile and lower to TT-NN ops. In the JAX example, it is defined as:
def forward(graphdef, state, x):
model = nnx.merge(graphdef, state)
return model(x)
and graphdef/state come from nnx.split(model), while x is a representative input used to trace/compile the computation.
The process will:
✅ Compile your model through the TT-XLA pipeline
✅ Generate Python source code in the model/ directory
Generated Files
Check the model/ directory for your generated code:
ls -la model/
You should see:
main.py - Generated Python code with TT-NN API calls
run - Executable shell script that runs the generated code
tensors/ - Directory with the exported model input and parameter tensors
irs/ - VHLO, SHLO, TTIR, and TTNN intermediate representations (for debugging)
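You can run a small check like the one below to confirm that code generation produced everything listed above. It is a sketch that looks only for the entries named in this tutorial. tensors/ is treated as optional because it is skipped when export_tensors=False (see Step 6). The check also confirms that run is executable, since it is invoked as ./run. The function name `check_codegen_output` is hypothetical.

```python
import os
from pathlib import Path

# Entries this tutorial says code generation writes into export_path.
# tensors/ is deliberately not required: it is skipped when export_tensors=False.
REQUIRED_FILES = ("main.py", "run")
REQUIRED_DIRS = ("irs",)


def check_codegen_output(export_path: str = "model") -> list:
    """Return a list of problems found in the generated output directory (empty if none)."""
    root = Path(export_path)
    if not root.is_dir():
        return [f"{root} does not exist or is not a directory"]
    problems = []
    for name in REQUIRED_FILES:
        if not (root / name).is_file():
            problems.append(f"missing file: {name}")
    for name in REQUIRED_DIRS:
        if not (root / name).is_dir():
            problems.append(f"missing directory: {name}/")
    # The run script must be executable for `./run` to work.
    run_script = root / "run"
    if run_script.is_file() and not os.access(run_script, os.X_OK):
        problems.append("run is not executable (try: chmod +x run)")
    return problems
```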
Step 5: Generate Optimized Code
You can produce more performant code by passing compiler options to codegen_py through the compiler_options argument. The example below sets optimization_level, which accepts one of the supported levels: 0, 1, or 2.
# Any compile options you could specify when executing the model normally can also be used with codegen.
extra_options = {
"optimization_level": 0, # Levels 0, 1, and 2 are supported
}
codegen_py(
forward, graphdef, state, x, export_path="model", compiler_options=extra_options
)
For other optimizer options, see TT-XLA Optimizer Docs.
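Only levels 0, 1, and 2 are supported, so validating the level before a long compile can save time. The sketch below builds the compiler_options dictionary passed to codegen_py. The helper name `make_compiler_options` is hypothetical. Only the optimization_level key comes from this tutorial. Any other keyword arguments are forwarded unchecked, so their names must match the options listed in the TT-XLA Optimizer Docs.

```python
# Levels listed as supported in this tutorial.
SUPPORTED_OPTIMIZATION_LEVELS = (0, 1, 2)


def make_compiler_options(optimization_level: int = 0, **extra_options) -> dict:
    """Build the compiler_options dict passed to codegen_py/codegen_cpp.

    Only optimization_level is validated; any other keyword arguments are
    forwarded unchanged, so their names must match the TT-XLA option names.
    """
    # bool is a subclass of int (True == 1), so reject it explicitly.
    if isinstance(optimization_level, bool) or optimization_level not in SUPPORTED_OPTIMIZATION_LEVELS:
        raise ValueError(
            f"optimization_level must be one of {SUPPORTED_OPTIMIZATION_LEVELS}, "
            f"got {optimization_level!r}"
        )
    return {"optimization_level": optimization_level, **extra_options}
```

The result plugs into the codegen_py call shown above, for example compiler_options=make_compiler_options(1).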
Step 6: Export Model Input and Parameter Tensors
By default, model input and parameter tensors are exported to export_path/tensors/.
If you don’t need to dump these tensors, set the codegen_py parameter export_tensors=False. The generated code will use ttnn.ones for input and parameter tensors instead.
codegen_py(
forward, graphdef, state, x, export_path="model", export_tensors=False
)
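To tell which mode an output directory was generated in, check whether tensors/ contains any files. If it does not, check whether main.py refers to ttnn.ones. This is a rough heuristic sketch that assumes the layout described in this tutorial. A model could legitimately call ttnn.ones for other reasons, so treat the result as a hint rather than a guarantee. The function name `tensor_source` is hypothetical.

```python
from pathlib import Path


def tensor_source(export_path: str = "model") -> str:
    """Guess how the generated code obtains its inputs and parameters.

    Returns "exported" if tensors/ holds any files, "placeholder" if main.py
    mentions ttnn.ones, and "unknown" otherwise.
    """
    root = Path(export_path)
    tensors_dir = root / "tensors"
    if tensors_dir.is_dir() and any(p.is_file() for p in tensors_dir.rglob("*")):
        return "exported"
    main_py = root / "main.py"
    if main_py.is_file() and "ttnn.ones" in main_py.read_text():
        return "placeholder"
    return "unknown"
```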
Step 7: Execute the Generated Code
Navigate to the model directory and run the execution script:
cd model
./run
What the run Script Does
The script automatically:
Sets up the Python environment with TT-NN dependencies
Configures Tenstorrent hardware settings
Executes the generated Python code
Displays inference results
Expected output: You should see inference results printed to the console, showing your model running successfully on Tenstorrent hardware.
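If you prefer to drive execution from Python, for example in a test harness, a minimal subprocess-based sketch is shown below. It invokes the run script from inside the model directory and returns its console output. It assumes a POSIX environment and does not replicate anything the run script does internally. The name `run_generated_model` is hypothetical.

```python
import subprocess
from pathlib import Path
from typing import Optional


def run_generated_model(export_path: str = "model", timeout: Optional[float] = None) -> str:
    """Execute <export_path>/run from inside export_path and return its standard output.

    Raises subprocess.CalledProcessError if the script exits with a non-zero status,
    and subprocess.TimeoutExpired if it runs longer than `timeout` seconds.
    """
    root = Path(export_path).resolve()
    result = subprocess.run(
        [str(root / "run")],
        cwd=root,  # the tutorial runs ./run from inside the model directory
        capture_output=True,
        text=True,
        timeout=timeout,
        check=True,
    )
    return result.stdout
```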
Next Steps
Now that you’ve successfully generated and executed code:
Inspect the Generated Code
Open model/main.py to see how your PyTorch/JAX operations map to TT-NN API calls:
cat model/main.py
Look for patterns like:
Tensor allocation and initialization
Operation implementations (matrix multiply, activation functions, etc.)
Memory management and device synchronization
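Before reading main.py line by line, you can get a quick overview of which TT-NN operations it uses by counting the ttnn.* calls with a regular expression. This is a text-level sketch, not a parser. It assumes calls are written with the ttnn. prefix (as with ttnn.ones in Step 6), and it will also count matches inside comments or strings. The names `count_ttnn_calls` and `summarize` are hypothetical.

```python
import re
from collections import Counter
from pathlib import Path

# Matches dotted names such as ttnn.matmul( or ttnn.experimental.foo( ; the "(" marks a call,
# so attribute references like ttnn.DataType.BFLOAT16 are not counted.
TTNN_CALL = re.compile(r"\bttnn\.([A-Za-z_][\w.]*)\s*\(")


def count_ttnn_calls(source: str) -> Counter:
    """Count occurrences of each ttnn.<name>(...) call in Python source text."""
    return Counter(TTNN_CALL.findall(source))


def summarize(main_py: str = "model/main.py", top: int = 10) -> None:
    """Print the most frequent TT-NN calls in the generated file."""
    counts = count_ttnn_calls(Path(main_py).read_text())
    for name, n in counts.most_common(top):
        print(f"{n:5d}  ttnn.{name}")
```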
Customize the Generated Code
Try modifying operations in main.py:
Change tensor shapes or data types
Add print statements to debug intermediate values
Optimize memory layouts or operation ordering
Generate C++ Code
Want C++ instead of Python? Change the backend:
# Any compile options you could specify when executing the model normally can also be used with codegen.
extra_options = {
# "optimization_level": 0, # Levels 0, 1, and 2 are supported
}
codegen_cpp(
forward, graphdef, state, x, export_path="model", compiler_options=extra_options
)
The generated C++ code is fully standalone and can be integrated into existing C++ projects.
Generate ResNet TTNN code using the following example
Learn More
Code Generation Guide - Complete reference for all options and use cases
PyTorch Example Source - Full example code
JAX Example Source - Full example code
Summary
What you accomplished:
✅ Built TT-XLA from source in Docker
✅ Generated Python code from a PyTorch/JAX model
✅ Executed the generated code on Tenstorrent hardware
✅ Learned where to find and inspect the generated code
Key takeaways:
Code generation produces human-readable TT-NN code that you can inspect
The process intentionally terminates after generation (current limitation)
Generated code can be modified and optimized for your use case
Need help? Visit the TT-XLA Issues page or check the Code Generation Guide for more details.