MLP inference with TT-NN
In this example we will combine insight from the previous examples, and use TT-NN with PyTorch to perform a simple MLP inference task. This will demonstrate how to use TT-NN for tensor operations and model inference.
Let's create the example file,
ttnn_mlp_inference_mnist.py
Import the necessary libraries
Amongst these, torchvision is used to load the MNIST dataset, and ttnn is used for tensor operations on the Tenstorrent device.
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import ttnn
from loguru import logger
Open Tenstorrent device
Open the Tenstorrent device on which we will run our program.
# Open Tenstorrent device
# Acquires device 0 for all subsequent ttnn calls; must be paired with
# ttnn.close_device(device) once inference is done.
device = ttnn.open_device(device_id=0)
Load MNIST Test Data
Load and convert the MNIST 28x28 grayscale images to tensors and normalize them. Subsequently, let's create a DataLoader to iterate through the dataset. This will allow us to perform inference on each image in the dataset.
# Load MNIST data
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# MNIST test split, downloaded to ./data on first run.
testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
# batch_size=1: one image per inference step, in dataset order (no shuffling).
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
Load Pretrained MLP Weights
Load the pretrained MLP weights from a file.
# Pretrained weights
# Load the tensors exported by train_and_export_mlp.py. weights_only=True
# restricts torch.load to plain tensor data instead of arbitrary pickled
# objects — safer for checkpoint files (the saved dict contains only tensors).
weights = torch.load("mlp_mnist_weights.pt", weights_only=True)
W1 = weights["W1"]  # [128, 784] layer-1 weight
b1 = weights["b1"]  # [128]      layer-1 bias
W2 = weights["W2"]  # [64, 128]  layer-2 weight
b2 = weights["b2"]  # [64]       layer-2 bias
W3 = weights["W3"]  # [10, 64]   output-layer weight
b3 = weights["b3"]  # [10]       output-layer bias
Basic accuracy tracking, inference, loop, and image flattening
Loop through the first 5 images in the dataset, and flatten each 1x1x28x28 image batch into a 1x784 tensor. This is done to match the input shape of the MLP model.
correct = 0  # number of correct predictions so far
total = 0  # number of samples evaluated
for i, (image, label) in enumerate(testloader):
# Evaluate only the first 5 test images.
if i >= 5:
break
# Flatten the [1, 1, 28, 28] batch to [1, 784] to match the MLP input.
image = image.view(1, -1).to(torch.float32)
Convert to TT-NN Tensor
Convert the PyTorch tensor to TT-NN format with bfloat16 data type and TILE_LAYOUT. This is necessary for efficient computation on the Tenstorrent device.
# Input tensor
# from_torch uploads the flattened image to the device as bfloat16; compute
# ops expect TILE_LAYOUT, so convert the layout after the upload.
image_tt = ttnn.from_torch(image, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
image_tt = ttnn.to_layout(image_tt, ttnn.TILE_LAYOUT)
Layer 1 (Linear + ReLU)
Transposed weights are used to match TT-NN’s expected shape. Bias reshaped to 1x128 for broadcasting, and compute output 1.
# Layer 1
# W1 is stored as [out, in] = [128, 784]; the transpose matches TT-NN's
# expected [in, out] weight shape for ttnn.linear.
W1_tt = ttnn.from_torch(W1.T, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
W1_tt = ttnn.to_layout(W1_tt, ttnn.TILE_LAYOUT)
# Bias reshaped to [1, 128] so it broadcasts across the batch dimension.
b1_tt = ttnn.from_torch(b1.view(1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
b1_tt = ttnn.to_layout(b1_tt, ttnn.TILE_LAYOUT)
out1 = ttnn.linear(image_tt, W1_tt, bias=b1_tt)
out1 = ttnn.relu(out1)
Layer 2 (Linear + ReLU)
Same pattern as Layer 1, but with different weights and biases.
# Layer 2
# Same pattern as layer 1: transposed [128, 64] weight, [1, 64] bias.
W2_tt = ttnn.from_torch(W2.T, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
W2_tt = ttnn.to_layout(W2_tt, ttnn.TILE_LAYOUT)
b2_tt = ttnn.from_torch(b2.view(1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
b2_tt = ttnn.to_layout(b2_tt, ttnn.TILE_LAYOUT)
out2 = ttnn.linear(out1, W2_tt, bias=b2_tt)
out2 = ttnn.relu(out2)
Layer 3 (Output Layer)
Final layer with 10 outputs (for digits 0-9). No ReLU activation here, as this is the output layer.
# Layer 3
# Output layer: 10 logits, one per digit class. No ReLU — raw logits feed
# the argmax below.
W3_tt = ttnn.from_torch(W3.T, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
W3_tt = ttnn.to_layout(W3_tt, ttnn.TILE_LAYOUT)
b3_tt = ttnn.from_torch(b3.view(1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
b3_tt = ttnn.to_layout(b3_tt, ttnn.TILE_LAYOUT)
out3 = ttnn.linear(out2, W3_tt, bias=b3_tt)
Convert Back to PyTorch and sum results
Convert the TT-NN output tensor back to a PyTorch tensor, take the argmax to get the predicted digit, and update the running accuracy counters.
# Convert result back to torch
prediction = ttnn.to_torch(out3)  # [1, 10] logits
# argmax over the class dimension picks the most likely digit.
predicted_label = torch.argmax(prediction, dim=1).item()
correct += predicted_label == label.item()  # bool adds as 0/1
total += 1
logger.info(f"Sample {i+1}: Predicted={predicted_label}, Actual={label.item()}")
Full example and output
Let's put everything together in a complete example that can be run directly. This example will open a Tenstorrent device, load the pretrained weights, run MLP inference on the first five MNIST test images, and log the predictions along with the overall accuracy.
You can run the provided train_and_export_mlp.py script to generate the weights and save them to a file named mlp_mnist_weights.pt.
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import ttnn
from loguru import logger
def main():
    """Run MLP inference on MNIST using TT-NN and report accuracy.

    Loads pretrained weights from ``mlp_mnist_weights.pt`` (produced by
    ``train_and_export_mlp.py``), converts them once to TT-NN tile-layout
    tensors, then classifies the first five MNIST test images on the
    Tenstorrent device.
    """
    # Open Tenstorrent device; released in the finally block below.
    device = ttnn.open_device(device_id=0)
    try:
        logger.info("\n--- MLP Inference Using TT-NN on MNIST ---")

        # Load MNIST test data. ToTensor scales pixels to [0, 1];
        # Normalize((0.5,), (0.5,)) maps them to [-1, 1].
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

        # Pretrained weights. weights_only=True restricts torch.load to plain
        # tensor data instead of arbitrary pickled objects (the exported file
        # contains only tensors).
        weights = torch.load("mlp_mnist_weights.pt", weights_only=True)

        def to_tt(tensor):
            """Upload a torch tensor as a bfloat16 TILE_LAYOUT TT-NN tensor."""
            t = ttnn.from_torch(tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
            return ttnn.to_layout(t, ttnn.TILE_LAYOUT)

        # The weights are loop-invariant, so convert and upload them ONCE
        # rather than once per sample. Weights are stored [out, in]; the
        # transpose matches TT-NN's expected shape, and biases are reshaped
        # to [1, N] so they broadcast across the batch dimension.
        W1_tt = to_tt(weights["W1"].T)           # [784, 128]
        b1_tt = to_tt(weights["b1"].view(1, -1))  # [1, 128]
        W2_tt = to_tt(weights["W2"].T)           # [128, 64]
        b2_tt = to_tt(weights["b2"].view(1, -1))  # [1, 64]
        W3_tt = to_tt(weights["W3"].T)           # [64, 10]
        b3_tt = to_tt(weights["b3"].view(1, -1))  # [1, 10]

        correct = 0
        total = 0
        for i, (image, label) in enumerate(testloader):
            # Evaluate only the first 5 test images.
            if i >= 5:
                break
            # Flatten the [1, 1, 28, 28] batch to [1, 784] to match the MLP input.
            image = image.view(1, -1).to(torch.float32)
            image_tt = to_tt(image)

            # Three linear layers; ReLU after the first two, raw logits from the last.
            out1 = ttnn.relu(ttnn.linear(image_tt, W1_tt, bias=b1_tt))
            out2 = ttnn.relu(ttnn.linear(out1, W2_tt, bias=b2_tt))
            out3 = ttnn.linear(out2, W3_tt, bias=b3_tt)

            # Convert result back to torch; argmax over the 10 logits gives the digit.
            prediction = ttnn.to_torch(out3)
            predicted_label = torch.argmax(prediction, dim=1).item()
            correct += predicted_label == label.item()
            total += 1
            logger.info(f"Sample {i+1}: Predicted={predicted_label}, Actual={label.item()}")

        logger.info(f"\nTT-NN MLP Inference Accuracy: {correct}/{total} = {100.0 * correct / total:.2f}%")
    finally:
        # Always release the device, even if inference fails.
        ttnn.close_device(device)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from loguru import logger
# Define MLP model
class MLP(nn.Module):
    """Three-layer fully connected classifier for 28x28 MNIST digits.

    Architecture: 784 -> 128 -> 64 -> 10, with ReLU after the first two
    linear layers and raw logits from the last.
    """

    def __init__(self):
        super().__init__()
        # Attribute names fc1/fc2/fc3 are part of the export contract:
        # main() reads model.fcN.weight / model.fcN.bias by these names.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten everything after the batch dimension, then apply the stack.
        flat = x.view(x.size(0), -1)
        h1 = self.relu1(self.fc1(flat))
        h2 = self.relu2(self.fc2(h1))
        return self.fc3(h2)
def main():
    """Train a 784-128-64-10 MLP on MNIST and export its weights.

    Saves the layer weights and biases to ``mlp_mnist_weights.pt`` under the
    keys W1/b1/W2/b2/W3/b3 — the format the TT-NN inference example reads.
    """
    # MNIST training split; pixels normalized from [0, 1] to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

    # Standard supervised setup: cross-entropy loss, Adam optimizer.
    model = MLP()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for epoch in range(5):
        epoch_loss = 0.0
        for batch_images, batch_labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(batch_images), batch_labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        logger.info(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

    # Export raw tensors (detached copies) rather than a state_dict so the
    # inference script can index them by name without instantiating MLP.
    weights = {
        "W1": model.fc1.weight.detach().clone(),  # [128, 784]
        "b1": model.fc1.bias.detach().clone(),  # [128]
        "W2": model.fc2.weight.detach().clone(),  # [64, 128]
        "b2": model.fc2.bias.detach().clone(),  # [64]
        "W3": model.fc3.weight.detach().clone(),  # [10, 64]
        "b3": model.fc3.bias.detach().clone(),  # [10]
    }
    torch.save(weights, "mlp_mnist_weights.pt")
    logger.info("Weights saved to mlp_mnist_weights.pt")


if __name__ == "__main__":
    main()
Running the inference script will log the predicted and actual labels for the first five MNIST test images, along with the overall accuracy, as shown below
2025-06-23 09:51:47.723 | INFO | __main__:main:17 -
--- MLP Inference Using TT-NN on MNIST ---
2025-06-23 09:52:10.480 | INFO | __main__:main:85 - Sample 1: Predicted=7, Actual=7
2025-06-23 09:52:10.491 | INFO | __main__:main:85 - Sample 2: Predicted=2, Actual=2
2025-06-23 09:52:10.499 | INFO | __main__:main:85 - Sample 3: Predicted=1, Actual=1
2025-06-23 09:52:10.506 | INFO | __main__:main:85 - Sample 4: Predicted=0, Actual=0
2025-06-23 09:52:10.514 | INFO | __main__:main:85 - Sample 5: Predicted=4, Actual=4
2025-06-23 09:52:10.514 | INFO | __main__:main:87 -
TT-NN MLP Inference Accuracy: 5/5 = 100.00%