TT-NN Tracer and BERT Model Visualization Tutorial

This tutorial demonstrates how to use TT-NN’s tracer functionality to visualize tensor operations and computational graphs. We’ll explore:

1. Basic tensor operation tracing with PyTorch tensors
2. TT-NN tensor operations with reshaping
3. Tracing a BERT self-attention layer
4. Running and tracing a full BERT model for question answering

The tracer is a powerful debugging and optimization tool that helps understand how operations are executed on Tenstorrent hardware.

Import Libraries

Disable fast runtime mode to ensure all operations are properly traced. Fast runtime mode may skip some operations for performance, which we don’t want when debugging.

[ ]:
import os
from pathlib import Path
os.environ["TTNN_CONFIG_OVERRIDES"] = "{\"enable_fast_runtime_mode\": false}"

import torch
import transformers

import ttnn
from ttnn.tracer import trace, visualize

Configure Logging

Suppress warnings from the Transformers library for cleaner output.

[ ]:
transformers.logging.set_verbosity_error()

Example 1: Tracing PyTorch Operations

The tracer context manager captures all operations performed within its scope.

This example shows how basic PyTorch operations are tracked.

[ ]:
with trace():
    # Create a random integer tensor of shape (1, 64) with values between 0-99
    tensor = torch.randint(0, 100, (1, 64))
    # Apply exponential function element-wise
    # This demonstrates how mathematical operations are captured
    tensor = torch.exp(tensor)

# Visualize the computational graph of the traced operations
# This will show the flow from random tensor creation to exp operation
visualize(tensor)

Example 2: Tracing TT-NN Tensor Operations

This example demonstrates tracing operations that involve TT-NN tensors.

[ ]:
with trace():
    # Create a PyTorch tensor with shape (4, 64)
    tensor = torch.randint(0, 100, (4, 64))

    # Convert PyTorch tensor to TT-NN format
    # This operation moves data to the TT-NN representation
    tensor = ttnn.from_torch(tensor)

    # Reshape the tensor from (4, 64) to (2, 4, 32)
    # This demonstrates how reshape operations are handled in TT-NN
    tensor = ttnn.reshape(tensor, (2, 4, 32))

    # Convert back to PyTorch for visualization
    tensor = ttnn.to_torch(tensor)

# Visualize the graph showing PyTorch → TT-NN → reshape → PyTorch conversion
visualize(tensor)
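
As a quick sanity check outside the trace, the PyTorch → TT-NN → reshape → PyTorch round-trip can be compared against a pure PyTorch reshape. This is a minimal sketch using only the calls shown above; it assumes ttnn.reshape preserves row-major element order the way torch.reshape does.

[ ]:
# Compare the TT-NN round-trip against a pure PyTorch reference
reference = torch.randint(0, 100, (4, 64))
via_ttnn = ttnn.to_torch(ttnn.reshape(ttnn.from_torch(reference), (2, 4, 32)))
# Element-wise comparison avoids dtype mismatches in case the TT-NN
# conversion changes the integer dtype
assert bool((via_ttnn == reference.reshape(2, 4, 32)).all())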

Model and Config downloading

We define three helper functions to download the weights and configuration from Hugging Face.

To avoid repeated downloads, you can also set the TTNN_TUTORIALS_MODELS_TRACER_PATH environment variable. If it is defined, the model and configuration are loaded from the location it points to instead of being fetched from Hugging Face.
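
For example (the path below is a placeholder), the cache directory can be set before running the helpers:

[ ]:
# Optional: point the helpers below at a local cache so no network requests
# are made. The path is a placeholder; the directory is expected to contain
# config_google_bert.json, config_ttnn_bert.json, and the
# BertForQuestionAnswering weights used in Example 4.
# os.environ["TTNN_TUTORIALS_MODELS_TRACER_PATH"] = "/path/to/local/cache"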

[ ]:
def download_google_bert_model_and_config(
    model_name: str,
) -> tuple[transformers.models.bert.modeling_bert.BertSelfOutput, transformers.BertConfig]:
    model_location = model_name  # By default, download from Hugging Face

    # If TTNN_TUTORIALS_MODELS_TRACER_PATH is set, use it as the cache directory to avoid requests to Hugging Face
    cache_dir = os.getenv("TTNN_TUTORIALS_MODELS_TRACER_PATH")
    if cache_dir is not None:
        model_location = Path(cache_dir) / Path("config_google_bert.json")

    # Load the config (downloaded if cache_dir was not set) and construct a
    # randomly initialized BertSelfOutput layer from it
    config = transformers.BertConfig.from_pretrained(model_location)
    model = transformers.models.bert.modeling_bert.BertSelfOutput(config).eval()

    return model, config


def download_ttnn_bert_config(model_name: str) -> transformers.BertConfig:
    config_location = model_name  # By default, download from Hugging Face

    # If TTNN_TUTORIALS_MODELS_TRACER_PATH is set, use it as the cache directory to avoid requests to Hugging Face
    cache_dir = os.getenv("TTNN_TUTORIALS_MODELS_TRACER_PATH")
    if cache_dir is not None:
        config_location = Path(cache_dir) / Path("config_ttnn_bert.json")

    # Load config (download if cache_dir was not set)
    config = transformers.BertConfig.from_pretrained(config_location)

    return config


def download_ttnn_bert_model(model_name: str, config: transformers.BertConfig) -> transformers.BertForQuestionAnswering:
    model_location = model_name  # By default, download from Hugging Face

    # If TTNN_TUTORIALS_MODELS_TRACER_PATH is set, use it as the cache directory to avoid requests to Hugging Face
    cache_dir = os.getenv("TTNN_TUTORIALS_MODELS_TRACER_PATH")
    if cache_dir is not None:
        model_location = Path(cache_dir)

    # Load model weights (download if cache_dir was not set)
    model = transformers.BertForQuestionAnswering.from_pretrained(model_location, config=config).eval()

    return model
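
All three helpers rely on the same Transformers behavior: from_pretrained accepts either a Hugging Face model name or a local path. A minimal sketch of the two resolution modes, left commented out to avoid an eager download (the local path is a placeholder):

[ ]:
# Remote: resolved against the Hugging Face hub
# config = transformers.BertConfig.from_pretrained("google/bert_uncased_L-4_H-256_A-4")
# Local: resolved against the filesystem, with no network access
# config = transformers.BertConfig.from_pretrained("/path/to/cache/config_google_bert.json")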

Example 3: Tracing a BERT Layer

Load a small BERT configuration for demonstration: a tiny BERT variant with only 4 layers, a hidden size of 256, and 4 attention heads. Note that the helper instantiates just the BertSelfOutput sub-layer (the dense + layer-norm block that follows self-attention), not the full model.

[ ]:
model_name = "google/bert_uncased_L-4_H-256_A-4"
model, config = download_google_bert_model_and_config(model_name)

# Trace the BERT self-output layer operations
with trace():
    # Create dummy inputs matching the expected dimensions
    # hidden_states: output from self-attention (batch=1, seq_len=64, hidden_size=256)
    hidden_states = torch.rand((1, 64, config.hidden_size))
    # input_tensor: residual connection input
    input_tensor = torch.rand((1, 64, config.hidden_size))

    # Run the layer forward pass
    hidden_states = model(hidden_states, input_tensor)

    # Convert output to TT-NN format for visualization
    output = ttnn.from_torch(hidden_states)

# Visualize the BERT layer computation graph
visualize(output)
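
As a quick check, BertSelfOutput ends with a residual add and layer norm, so it preserves the input shape; the traced output should still be (1, 64, config.hidden_size). A small sketch, assuming the TT-NN tensor exposes a shape attribute:

[ ]:
# Expect (1, 64, 256) for this config
print(output.shape)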

Example 4: Tracing Models Written Using TT-NN

[ ]:
# Configure the dispatch core type based on the architecture
# ETH cores are used on newer architectures, WORKER cores on Grayskull
dispatch_core_type = ttnn.device.DispatchCoreType.ETH
if "grayskull" in os.environ.get("ARCH_NAME", ""):
    dispatch_core_type = ttnn.device.DispatchCoreType.WORKER

# Open device with custom configuration
# - l1_small_size: Set L1 memory allocation to 8KB for small tensors
# - dispatch_core_config: Configure which cores handle dispatch operations
device = ttnn.open_device(
    device_id=0,
    l1_small_size=8192,
    dispatch_core_config=ttnn.device.DispatchCoreConfig(dispatch_core_type)
)
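
If your TT-NN build provides it (an assumption; availability may vary by release), ttnn.manage_device offers a context-manager alternative that closes the device automatically:

[ ]:
# Sketch of the context-manager alternative (left commented out so it does
# not conflict with the device opened above):
# with ttnn.manage_device(device_id=0) as device:
#     ...  # run the traced model here; the device closes on exit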
[ ]:
from models.demos.bert.tt import ttnn_bert
from models.demos.bert.tt import ttnn_optimized_bert
from ttnn.model_preprocessing import preprocess_model_parameters


# Note: the helper is named run_and_trace_bert (not ttnn_bert) to avoid
# shadowing the imported ttnn_bert module
def run_and_trace_bert(bert):
    """
    Run and trace a complete BERT model for question answering.

    Args:
        bert: Either ttnn_bert or ttnn_optimized_bert module
    """
    # Use a larger BERT model fine-tuned for question answering
    model_name = "phiyodr/bert-large-finetuned-squad2"
    config = download_ttnn_bert_config(model_name)

    # Limit to 1 layer for faster execution in this demo
    # Full BERT-large has 24 layers
    config.num_hidden_layers = 1

    # Set batch size and sequence length for input
    batch_size = 8
    sequence_size = 384  # Standard for question answering tasks

    # ===== Model Parameter Preprocessing =====
    # Convert model parameters to TT-NN format and optimize for device
    # This includes weight packing, layout conversion, and memory placement
    model = download_ttnn_bert_model(model_name, config)
    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        custom_preprocessor=bert.custom_preprocessor,
        device=device,
    )

    # ===== Trace BERT Inference =====
    with trace():
        # Create dummy input tensors
        # input_ids: Token IDs from vocabulary
        input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)

        # token_type_ids: Segment IDs (0 for question, 1 for context in QA)
        torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)

        # position_ids: normally 0 to sequence_length-1; zeroed here as dummy input
        torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)

        # attention_mask: Mask for padding tokens (only for optimized version)
        # Shape differs between regular and optimized BERT implementations
        torch_attention_mask = torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None

        # Preprocess inputs for TT-NN format
        # This converts PyTorch tensors to device tensors with appropriate layout
        ttnn_bert_inputs = bert.preprocess_inputs(
            input_ids,
            torch_token_type_ids,
            torch_position_ids,
            torch_attention_mask,
            device=device,
        )

        # Run BERT model for question answering
        # Returns start and end logits for answer span prediction
        output = bert.bert_for_question_answering(
            config,
            *ttnn_bert_inputs,
            parameters=parameters,
        )

        # Move output back from device to host for visualization
        output = ttnn.from_device(output)

    # Visualize the complete BERT computation graph
    return visualize(output)


# Run the optimized BERT implementation
# This version includes TT-NN specific optimizations for better performance
run_and_trace_bert(ttnn_optimized_bert)
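
The non-optimized reference implementation can be traced the same way; per the docstring above, the helper accepts either module and handles the attention-mask difference internally:

[ ]:
# Optionally trace the baseline implementation for comparison
# run_and_trace_bert(ttnn_bert)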

Close the device

[ ]:
ttnn.close_device(device)