TT-NN Introduction

Welcome to TT-NN, a high-performance deep learning framework optimized for Tenstorrent’s AI accelerators. This tutorial will guide you through the fundamental concepts and operations needed to get started with TT-NN.

What You’ll Learn

  • Basic Setup: Device initialization and library importing

  • Tensor Management: Creating, moving, and manipulating tensors

  • PyTorch Integration: Seamless interoperability with PyTorch

  • Memory Optimization: Leveraging SRAM (L1) and DRAM for performance

  • Neural Network Operations: Building blocks for AI models

  • Advanced Features: Tensor sharding, compilation, and multi-device support

We recommend downloading and running this tutorial on your device! It’s available here.

1. Getting Started

TT-NN is implemented in C++ for optimal performance while providing Python bindings for ease of development and prototyping. This hybrid approach gives you the best of both worlds: high performance and developer productivity.

Importing the Library

Let’s start by importing TT-NN:

[ ]:
# Import the TT-NN library
import ttnn

# Display version information
print("TT-NN successfully imported!")

Device Initialization

Before performing any computations, we need to initialize a Tenstorrent device. The device ID (0) refers to the first available Tenstorrent device in your system:

[ ]:
# Initialize the first Tenstorrent device (device_id=0)
device = ttnn.open_device(device_id=0)

print(f"Device initialized successfully: {device}")
print(f"Device ID: {device.id()}")
print(f"Available compute cores: {device.compute_with_storage_grid_size()}")

2. Tensor Creation and Management

TT-NN tensors can exist in two locations:

  • Host (CPU): For data preparation and post-processing

  • Device (Tenstorrent hardware): For high-performance computation

Creating Host Tensors

Let’s start by creating a tensor on the host (CPU memory):

[ ]:
# Create a tensor filled with 1.0 values on the host (CPU)
# Shape: [10, 15] - 10 rows, 15 columns
host_tensor = ttnn.full([10, 15], 1.0)

print(f"Host tensor created:")
print(f"  Shape: {host_tensor.shape}")
print(f"  Data type: {host_tensor.dtype}")
print(f"  Device: {host_tensor.device()}")  # Should show None (host)
print(f"  Layout: {host_tensor.layout}")
print(f"  Memory config: {host_tensor.memory_config()}")

Moving Tensors to Device

To perform computations on Tenstorrent hardware, we need to transfer tensors from host to device:

[ ]:
# Transfer the host tensor to the device
device_tensor = ttnn.to_device(host_tensor, device)

print(f"Device tensor created:")
print(f"  Shape: {device_tensor.shape}")
print(f"  Device: {device_tensor.device()}")  # Should show the device ID
print(f"  Layout: {device_tensor.layout}")    # Same layout as host tensor
print(f"  Memory config: {device_tensor.memory_config()}")  # Default DRAM

Creating Tensors Directly on Device

For efficiency, you can also create tensors directly on the device without going through the host:

[ ]:
# Create a tensor with random values directly on the device
# This is more efficient as it avoids host->device transfer
device_tensor_2 = ttnn.rand([10, 15], device=device)

print(f"Direct device tensor created:")
print(f"  Shape: {device_tensor_2.shape}")
print(f"  Device: {device_tensor_2.device()}")
print(f"  Layout: {device_tensor_2.layout}")  # May default to different layout

3. PyTorch Interoperability

One of TT-NN’s key strengths is seamless integration with PyTorch, allowing you to leverage existing PyTorch code and models. You can easily convert between PyTorch tensors and TT-NN tensors.

[ ]:
# Import PyTorch for interoperability demonstrations
import torch

print(f"PyTorch version: {torch.__version__}")
print("Ready for PyTorch <-> TT-NN conversions!")
[ ]:
# Create a PyTorch tensor with random values
torch_tensor = torch.rand([10, 15])
print(f"Original PyTorch tensor shape: {torch_tensor.shape}")
print(f"Original PyTorch tensor dtype: {torch_tensor.dtype}")

# Convert PyTorch tensor to TT-NN tensor on host
host_ttnn_from_torch = ttnn.from_torch(torch_tensor)
print(f"\nTT-NN tensor from PyTorch (host):")
print(f"  Shape: {host_ttnn_from_torch.shape}")
print(f"  Layout: {host_ttnn_from_torch.layout}")
print(f"  Device: {host_ttnn_from_torch.device()}")

# Convert PyTorch tensor to TT-NN tensor directly on device with tile layout
# Tile layout is optimized for Tenstorrent hardware operations
device_ttnn_from_torch = ttnn.from_torch(
    torch_tensor,
    device=device,
    layout=ttnn.TILE_LAYOUT
)
print(f"\nTT-NN tensor from PyTorch (device, tiled):")
print(f"  Shape: {device_ttnn_from_torch.shape}")
print(f"  Layout: {device_ttnn_from_torch.layout}")
print(f"  Device: {device_ttnn_from_torch.device()}")
print(f"  Memory config: {device_ttnn_from_torch.memory_config()}")

Moving Tensors Back to Host

After performing computations on the device, you often need to transfer results back to the host for further processing or analysis:

[ ]:
# Create a device tensor for demonstration
device_tensor = ttnn.rand([10, 15], device=device)
print(f"Original device tensor: {device_tensor.device()}")

# Method 1: Transfer device tensor back to host using ttnn.from_device
host_tensor = ttnn.from_device(device_tensor)
print(f"Transferred to host using from_device(): {host_tensor.device()}")

# Method 2: Alternative syntax using .cpu() method (similar to PyTorch)
host_tensor_alt = device_tensor.cpu()
print(f"Transferred to host using .cpu(): {host_tensor_alt.device()}")

# Both methods produce equivalent results
print(f"Shapes match: {host_tensor.shape == host_tensor_alt.shape}")
print(f"Host tensor shape: {host_tensor.shape}")

Converting Back to PyTorch

TT-NN tensors can be seamlessly converted back to PyTorch tensors for further processing or integration with PyTorch-based pipelines:

[ ]:
# Convert TT-NN tensor (device or host) back to PyTorch tensor
# Note: Device tensors are automatically transferred to host during conversion
torch_tensor_result = ttnn.to_torch(device_tensor)

print(f"Converted back to PyTorch:")
print(f"  PyTorch tensor shape: {torch_tensor_result.shape}")
print(f"  PyTorch tensor dtype: {torch_tensor_result.dtype}")
print(f"  PyTorch tensor device: {torch_tensor_result.device}")

# Display tensor properties for comparison
print(f"\nTensor Properties Comparison:")
print(f"Host TT-NN tensor:")
print(f"  Shape: {host_tensor.shape}")
print(f"  Layout: {host_tensor.layout}")
print(f"  Data type: {host_tensor.dtype}")
print(f"  Memory config: {host_tensor.memory_config()}")
print(f"  Device: {host_tensor.device()}")

print(f"\nDevice TT-NN tensor:")
print(f"  Shape: {device_tensor.shape}")
print(f"  Layout: {device_tensor.layout}")
print(f"  Data type: {device_tensor.dtype}")
print(f"  Memory config: {device_tensor.memory_config()}")
print(f"  Device: {device_tensor.device()}")

4. Understanding Tensor Layouts

📖 Documentation: Tensor Layouts

TT-NN supports two primary tensor layouts that affect how data is stored in memory and how operations are performed:

Layout Types

🔲 ROW_MAJOR_LAYOUT - Traditional row-by-row data storage

🟦 TILE_LAYOUT - Optimized 32×32 tile-based storage

Why Tile Layout Matters

Tenstorrent hardware is specifically optimized for tiled data layouts. Most high-performance operations require tensors in tile layout for efficient execution. When converting to tile layout:

  • Tensors are automatically padded to fill complete 32×32 tiles (see the sketch after this list)

  • This padding is handled transparently - you don’t need to worry about it

  • Operations run significantly faster on tiled data
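To make the padding visible, here is a minimal sketch (it assumes the padded_shape attribute, which exposes the physical, tile-padded shape of a tensor):

t = ttnn.full([3, 4], 1.0, device=device)
t_tiled = ttnn.to_layout(t, ttnn.TILE_LAYOUT)

print(t_tiled.shape)         # Logical shape is still [3, 4]
print(t_tiled.padded_shape)  # Physical shape is padded up to one full tile: [32, 32]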

Default Behavior

By default, most tensor creation functions use row-major layout, but this can vary:

[ ]:
# Check default layouts for different tensor creation methods
host_tensor = ttnn.full([3,4], 1.0)
device_tensor = ttnn.full([3,4], 1.0, device=device)

print(f"Host tensor layout: {host_tensor.layout}")        # ROW_MAJOR_LAYOUT
print(f"Device tensor layout: {device_tensor.layout}")    # ROW_MAJOR_LAYOUT

# Note: ttnn.full() uses ROW_MAJOR_LAYOUT by default for both host and device

Functions with Different Default Layouts

However, some operations create tensors directly in tile layout for performance reasons:

[ ]:
# ttnn.rand() defaults to TILE_LAYOUT for device tensors
rand_tensor = ttnn.rand([10,15], device=device)

print(f"Random tensor layout: {rand_tensor.layout}")  # TILE_LAYOUT
print(f"This is because ttnn.rand() optimizes for device operations")

Layout Preservation During Transfer

When you transfer tensors between host and device, the layout is preserved:

[ ]:
# Create a host tensor (row-major by default)
host_tensor = ttnn.full([3, 4], 1.0)
print(f"Host tensor layout: {host_tensor.layout}")

# Transfer to device - layout is preserved
device_tensor = ttnn.to_device(host_tensor, device)
print(f"Device tensor layout: {device_tensor.layout}")

# The layout remains ROW_MAJOR_LAYOUT even on device
print("Layout preserved during host->device transfer")

Converting Between Layouts

You can explicitly convert between layouts using ttnn.to_layout(). This is often necessary to optimize performance:

[ ]:
# Start with row-major layout
print(f"Original layout: {device_tensor.layout}")

# Convert to tile layout for optimized operations
device_tensor = ttnn.to_layout(device_tensor, ttnn.TILE_LAYOUT)
print(f"After conversion: {device_tensor.layout}")

When converting from PyTorch tensors, you can specify the desired layout:

[ ]:
torch_tensor = torch.rand([10,15])
print(ttnn.from_torch(torch_tensor).layout)
print(ttnn.from_torch(torch_tensor, device=device).layout)
print(ttnn.from_torch(torch_tensor, device=device, layout=ttnn.TILE_LAYOUT).layout)

5. Data Types and Precision

TT-NN supports various data types optimized for AI workloads, ranging from high precision (float32) to ultra-compact formats (bfloat4_b) that maximize throughput and memory efficiency.

Supported Data Types

TT-NN supports the following data types, each optimized for different use cases:

Data Type   Bits   Use Case                    Trade-off
---------   ----   -------------------------   -----------------------------------
uint16      16     Integer operations          Standard integer precision
uint32      32     Integer operations          Higher integer precision
float32     32     High precision float        Standard accuracy, more memory
bfloat16    16     Neural networks             Good accuracy, 2x memory savings
bfloat8_b   8      Inference, large models     4x memory savings, reduced accuracy
bfloat4_b   4      Ultra-efficient inference   8x memory savings, lowest accuracy
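As a rough illustration of the memory column (ignoring the small shared-exponent overhead of the block formats bfloat8_b and bfloat4_b):

# Approximate memory footprint of a 1024x1024 tensor at each precision
elems = 1024 * 1024
for name, bits in [("float32", 32), ("bfloat16", 16), ("bfloat8_b", 8), ("bfloat4_b", 4)]:
    print(f"{name:>10}: {elems * bits // 8 // 1024} KiB")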

Performance vs. Accuracy Trade-offs

  • Lower precision formats (bfloat8_b, bfloat4_b) provide:

    • Better memory bandwidth and computational efficiency

    • Faster operations due to reduced data movement

    • Reduced numerical accuracy - may impact model quality

  • Higher precision formats (float32, bfloat16) provide:

    • Higher accuracy for numerical computations

    • More memory usage and potentially slower operations

[ ]:
# Create a tensor with bfloat16 precision (common for neural networks)
x_bf16 = ttnn.rand([1000, 1000], device=device, dtype=ttnn.bfloat16)
print(f"BFloat16 tensor: {x_bf16.dtype}, Shape: {x_bf16.shape}")

# Convert to different data types using ttnn.typecast()
print("\n=== Data Type Conversions ===")

# Convert to float32 (higher precision)
x_float32 = ttnn.typecast(x_bf16, ttnn.float32)
print(f"Float32 tensor: {x_float32.dtype}")

# Convert to uint16 (integer type)
x_uint16 = ttnn.typecast(x_bf16, ttnn.uint16)
print(f"UInt16 tensor: {x_uint16.dtype}")

# Convert to bfloat8_b (reduced precision for efficiency)
x_bf8_b = ttnn.typecast(x_bf16, ttnn.bfloat8_b)
print(f"BFloat8_b tensor: {x_bf8_b.dtype}")

# Convert to bfloat4_b (ultra-low precision)
x_bf4_b = ttnn.typecast(x_bf16, ttnn.bfloat4_b)
print(f"BFloat4_b tensor: {x_bf4_b.dtype}")

print("\nTip: Use lower precision types for inference to maximize throughput!")

6. Basic Tensor Operations

TT-NN provides a comprehensive set of tensor operations similar to PyTorch, but optimized for Tenstorrent hardware. Most operations are performed on device tensors for maximum performance.

Important Operation Requirements

  • Device-only operations: Most TT-NN operations are only supported on device tensors, not host tensors

  • Layout considerations: Many operations perform better on TILE_LAYOUT tensors

  • Matrix multiplication: For advanced control over math fidelity and performance, see the Matrix Engine documentation

Creating Test Data

Let’s create some tensors for demonstrating operations:

[ ]:
# Create a range tensor from 0 to 99, then normalize it to [0, 1]
x = ttnn.arange(start=0, end=100, device=device, layout=ttnn.TILE_LAYOUT)
print(f"Created range tensor: shape={x.shape}, layout={x.layout}")

# Normalize to range [0, 1] by dividing by 100
x = ttnn.divide(x, 100)
print(f"Normalized tensor to [0, 1] range")

# Reshape to a row vector for operations
x = x.reshape([1, 100])
print(f"Reshaped to: {x.shape}")
print(f"Values range from ~0 to ~1")
[ ]:
# Create a second random tensor for binary operations
y = ttnn.rand([1, 100], device=device)
print(f"Created random tensor y: shape={y.shape}")
print(f"Ready for element-wise operations!")
[ ]:
# Arithmetic Operations (Element-wise)
print("=== Arithmetic Operations ===")

# Addition - both operators work
result_add = x + y  # Operator overloading
print(f"Addition (x + y): shape={result_add.shape}")

# Multiplication
result_mul = x * y
print(f"Multiplication (x * y): shape={result_mul.shape}")

# Subtraction
result_sub = x - y
print(f"Subtraction (x - y): shape={result_sub.shape}")

# Division - using function call
result_div = ttnn.divide(x, y)
print(f"Division ttnn.divide(x, y): shape={result_div.shape}")

print("\nAll arithmetic operations completed successfully!")
[ ]:
# Mathematical Functions (Unary operations)
print("=== Mathematical Functions ===")

# Trigonometric functions
sin_x = ttnn.sin(x)
cos_x = ttnn.cos(x)
print(f"sin(x) and cos(x): computed")

# Exponential and logarithmic functions
exp_x = ttnn.exp(x)    # e^x
log_x = ttnn.log(x)    # natural logarithm
print(f"exp(x) and log(x): computed")

# Power and root functions
sqrt_x = ttnn.sqrt(x)        # square root
pow_x = ttnn.pow(x, 2)       # x^2 (square)
print(f"sqrt(x) and pow(x, 2): computed")

print(f"\nAll mathematical functions applied to tensor of shape {x.shape}")
print("These functions work element-wise on the entire tensor")
[ ]:
# Data movement functions
# ttnn.sort returns a (sorted_values, indices) pair, similar to torch.sort
ttnn.sort(y)
[ ]:
# Tensor manipulation functions
ttnn.concat([x, y], dim=1)

Tensor slicing is also supported:

[ ]:
x[:, 50:100]

The full set of supported operations is available in the TT-NN API documentation.

Neural Network Operations

TT-NN provides neural network operations as pure functions (similar to torch.nn.functional), giving you flexibility in structuring your model classes:

[ ]:
input_ids = ttnn.from_torch(
    torch.randint(0, 1000, (2, 32)), dtype=ttnn.uint32, device=device
)
emb_weight = ttnn.rand((1, 1, 1000, 512), dtype=ttnn.bfloat16, device=device)

x = ttnn.embedding(input_ids, emb_weight, layout=ttnn.TILE_LAYOUT)  # [2, 32, 512]
x = ttnn.reshape(x, (2, 1, 32, 512))

# LayerNorm
x = ttnn.layer_norm(x, epsilon=1e-5)

# Linear: 512 -> 2048 -> 512
w1 = ttnn.rand(
    (1, 1, 512, 2048), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)
x = ttnn.relu(ttnn.linear(x, w1))
w2 = ttnn.rand(
    (1, 1, 2048, 512), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)
x = ttnn.linear(x, w2)

For a comprehensive list of neural network operations, refer to the TT-NN API documentation.

7. Just-In-Time Compilation and Caching

TT-NN uses just-in-time (JIT) compilation to generate optimized kernels for Tenstorrent hardware. This means:

First Run vs. Subsequent Runs

  • First execution: Slow due to kernel compilation

  • Subsequent executions: Fast using cached compiled kernels

What Affects Compilation?

  • Tensor shapes: Different shapes trigger new compilation

  • Operation types: Each operation type needs compilation

  • Data types: Different precisions require different kernels

  • Memory layouts: ROW_MAJOR vs TILE_LAYOUT use different kernels

Let’s demonstrate this compilation behavior:

[ ]:
import time

# Create a test tensor
x = ttnn.rand([1000, 1000], device=device)
print(f"Testing compilation with tensor shape: {x.shape}")

# === FIRST EXECUTION (includes compilation time) ===
print("\n=== First Execution (Compilation + Execution) ===")
start = time.time()
y = ttnn.softmax(x, dim=1)
# IMPORTANT: ttnn.synchronize_device() ensures the operation completes
# Without it, we only measure dispatch time, not actual execution time
ttnn.synchronize_device(device)
first_time = time.time() - start
print(f"Time: {first_time:.4f} seconds (includes compilation)")

# === SECOND EXECUTION (cached, no compilation) ===
print("\n=== Second Execution (Cached) ===")
start = time.time()
y = ttnn.softmax(x, dim=1)
ttnn.synchronize_device(device)
cached_time = time.time() - start
print(f"Time: {cached_time:.4f} seconds (cached)")

# Show the speedup from caching
speedup = first_time / cached_time if cached_time > 0 else float('inf')
print(f"\nSpeedup from caching: {speedup:.1f}x faster!")
print(f"Compilation overhead: {(first_time - cached_time)*1000:.1f}ms")

The compilation cache is tied to compile-time parameters such as tensor shape. When these parameters change, a new compilation is triggered:

[ ]:
# Same operation, different shape
x = ttnn.rand([1337, 1337], device=device)
start = time.time()
y = ttnn.softmax(x, dim=1)
ttnn.synchronize_device(device)
end = time.time()
print(f"First iteration: {end - start} seconds")
start = time.time()
y = ttnn.softmax(x, dim=1)
ttnn.synchronize_device(device)
end = time.time()
print(f"Time taken: {end - start} seconds")

Direct SRAM (L1) control

TT-Metal and TT-NN provide explicit control over tensor placement in the device memory hierarchy, allowing you to optimize data movement between slower DRAM and faster SRAM (L1).

Available SRAM per device:

  • Wormhole n150: 108 MB

  • Wormhole n300: 192 MB

  • Blackhole p100a: 180 MB

  • Blackhole p150a: 210 MB
[ ]:
dram_tensor = ttnn.rand([4096, 4096], device=device)
dram_tensor.memory_config()
[ ]:
sram_tensor = ttnn.to_memory_config(dram_tensor, ttnn.L1_MEMORY_CONFIG)
sram_tensor.memory_config()
[ ]:
# warmup, compilation
ttnn.sum(dram_tensor, dim=0)
ttnn.sum(sram_tensor, dim=0)
ttnn.synchronize_device(device)
start = time.time()
for _ in range(10):
    ttnn.sum(dram_tensor, dim=0)
ttnn.synchronize_device(device)
end = time.time()
print(f"DRAM Time taken: {end - start} seconds")
start = time.time()
for _ in range(10):
    ttnn.sum(sram_tensor, dim=0)
ttnn.synchronize_device(device)
end = time.time()
print(f"SRAM Time taken: {end - start} seconds")

Memory Management Best Practice:

When performing sequences of operations, manually deallocate intermediate tensors to free memory. This is particularly important for L1 memory due to its limited capacity:

[ ]:
ttnn.deallocate(sram_tensor)

Advanced: Tensor Sharding

For optimal performance, you can shard tensors across compute cores to minimize data movement. This keeps data closer to the cores processing it.

Learn more:

  • Tensor Sharding Documentation

  • Technical Report on Tensor Sharding

[ ]:
# Baseline: interleaved L1 (not sharded)
interleaved_tensor = ttnn.to_memory_config(dram_tensor, ttnn.L1_MEMORY_CONFIG)
ttnn.sum(interleaved_tensor, dim=0)
ttnn.synchronize_device(device)

start = time.time()
for _ in range(10):
    res = ttnn.sum(interleaved_tensor, dim=0)
ttnn.synchronize_device(device)
end = time.time()

interleaved_l1_time = end - start
print(f"Interleaved L1 Time taken: {interleaved_l1_time * 1000} ms")
ttnn.deallocate(interleaved_tensor)

sharded_config = ttnn.create_sharded_memory_config(
    shape=dram_tensor.shape,
    core_grid=ttnn.CoreGrid(x=8, y=8),
    strategy=ttnn.ShardStrategy.WIDTH,
)

sharded_tensor = ttnn.to_memory_config(dram_tensor, sharded_config)
ttnn.sum(sharded_tensor, dim=0)
ttnn.synchronize_device(device)

start = time.time()
for _ in range(10):
    res = ttnn.sum(sharded_tensor, dim=0)
ttnn.synchronize_device(device)
end = time.time()

width_sharded_time = end - start
print(f"Width sharded Time taken: {width_sharded_time * 1000} ms")
ttnn.deallocate(sharded_tensor)

sharded_config = ttnn.create_sharded_memory_config(
    shape=dram_tensor.shape,
    core_grid=ttnn.CoreGrid(x=8, y=8),
    strategy=ttnn.ShardStrategy.HEIGHT,
)
sharded_tensor = ttnn.to_memory_config(dram_tensor, sharded_config)
ttnn.sum(sharded_tensor, dim=0)
ttnn.synchronize_device(device)

start = time.time()
for _ in range(10):
    res = ttnn.sum(sharded_tensor, dim=0)
ttnn.synchronize_device(device)
end = time.time()

height_sharded_time = end - start
print(f"Height sharded Time taken: {height_sharded_time * 1000} ms")
ttnn.deallocate(sharded_tensor)

Preserving Intermediate Results in L1

Explicit L1 control allows you to keep intermediate results in fast memory without fusing operations:

[ ]:
x = ttnn.rand([32, 128], device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
[ ]:
w1 = ttnn.rand([128, 128], device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
w2 = ttnn.rand([128, 128], device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
[ ]:
x1 = ttnn.linear(x, w1, memory_config=ttnn.L1_MEMORY_CONFIG)
print(x1.memory_config())

x2 = ttnn.relu(x1) # automatically maintains L1 config
print(x2.memory_config())

x3 = ttnn.linear(x2, w2, memory_config=ttnn.L1_MEMORY_CONFIG)
print(x3.memory_config())

ttnn.deallocate(x1)
ttnn.deallocate(x2)
ttnn.deallocate(x3)

Inference Focus

TT-NN is optimized for inference workloads and does not include automatic differentiation (autograd).

For training support, see our separate training framework tt-train.

Development Tools

TT-NN includes comprehensive tooling for development and debugging.


8. Exercise: Implement Scaled Dot-Product Attention

Now let’s put your TT-NN knowledge to the test! Implement a composite version of Scaled Dot-Product Attention (the core operation in Transformers) using basic TT-NN operations.

Background

Scaled Dot-Product Attention is defined as:

SDPA(Q, K, V) = softmax((Q × K^T) / √d_k) × V

Where:

  • Q: Query matrix

  • K: Key matrix

  • V: Value matrix

  • d_k: Dimension of the key vectors (for scaling)
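As a sanity check while you work, here is the same math written in plain PyTorch (a reference sketch introduced here, not the TT-NN solution; sdpa_reference is a name chosen for illustration). You can compare ttnn.to_torch(your_output) against its result:

import math
import torch

def sdpa_reference(q, k, v, mask, scale=None):
    # softmax((Q x K^T) * scale + mask) x V
    scale = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
    scores = (q * scale) @ k.transpose(-2, -1) + mask
    return torch.softmax(scores, dim=-1) @ v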

Your Task

Complete the composite_sdpa function below using basic TT-NN operations:

[ ]:
import math

def composite_sdpa(q, k, v, causal_mask, scale=None):
    """
    Implement Scaled Dot-Product Attention using basic TT-NN operations.

    Args:
        q: Query tensor [batch, num_heads, seq_len, head_dim]
        k: Key tensor [batch, num_heads, seq_len, head_dim]
        v: Value tensor [batch, num_heads, seq_len, head_dim]
        causal_mask: Mask tensor for autoregressive attention
        scale: Optional scaling factor (defaults to 1/sqrt(head_dim))

    Returns:
        Attention output tensor [batch, num_heads, seq_len, head_dim]
    """

    # TODO: Implement the following steps:

    # Step 1: Scale the queries (Q × scale)
    # If no scale provided, use 1/sqrt(head_dim)
    if scale is None:
        head_dim = q.shape[-1]
        scale = 1.0 / math.sqrt(head_dim)

    # YOUR CODE HERE: Scale the queries
    q_scaled = ...  # HINT: Use ttnn.multiply()

    # Step 2: Transpose the keys (K^T)
    # YOUR CODE HERE: Transpose the last two dimensions of k
    k_t = ...  # HINT: Use ttnn.permute()

    # Step 3: Compute attention scores (Q_scaled × K^T)
    # YOUR CODE HERE: Matrix multiply q_scaled and k_t
    attn_scores = ...  # HINT: Use ttnn.matmul()

    # Step 4: Apply causal mask (add mask to scores)
    # YOUR CODE HERE: Add the causal_mask to attention scores
    masked_scores = ...  # HINT: Use ttnn.add()

    # Step 5: Apply softmax along the last dimension
    # YOUR CODE HERE: Apply softmax to get attention weights
    attn_weights = ...  # HINT: Use ttnn.softmax()

    # Step 6: Apply attention weights to values (attn_weights × V)
    # YOUR CODE HERE: Matrix multiply attention weights and values
    output = ...  # HINT: Use ttnn.matmul()

    # return output # Replace once your implementation is complete
    return ttnn.rand([1, 32, 1024, 128], device=device)  # Placeholder return

print("SDPA function template ready!")
print("Replace the placeholder operations above to complete the implementation")
[ ]:
import time

batch, num_heads, seq_len, head_dim = 1, 32, 1024, 128
num_iterations, warmup_iterations = 50, 1

print(f"Config: B={batch}, H={num_heads}, S={seq_len}, D={head_dim}, Causal=True")

torch.manual_seed(42)
Q_torch = torch.randn(batch, num_heads, seq_len, head_dim)
K_torch = torch.randn(batch, num_heads, seq_len, head_dim)
V_torch = torch.randn(batch, num_heads, seq_len, head_dim)

Q_tt = ttnn.from_torch(Q_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
K_tt = ttnn.from_torch(K_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
V_tt = ttnn.from_torch(V_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1).unsqueeze(0).unsqueeze(0)
causal_mask_tt = ttnn.from_torch(causal_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

print("\n=== Accuracy Test ===")
output_composite = composite_sdpa(Q_tt, K_tt, V_tt, causal_mask_tt)
output_composite_torch = ttnn.to_torch(output_composite)[:, :, :seq_len, :head_dim]

output_optimized = ttnn.transformer.scaled_dot_product_attention(Q_tt, K_tt, V_tt, is_causal=True)
output_optimized_torch = ttnn.to_torch(output_optimized)[:, :, :seq_len, :head_dim]

output_torch = torch.nn.functional.scaled_dot_product_attention(Q_torch, K_torch, V_torch, is_causal=True)

pcc_composite = torch.corrcoef(torch.stack([output_composite_torch.flatten(), output_torch.flatten()]))[0, 1].item()
pcc_optimized = torch.corrcoef(torch.stack([output_optimized_torch.flatten(), output_torch.flatten()]))[0, 1].item()
rmse_composite = torch.sqrt(((output_composite_torch - output_torch) ** 2).mean()).item()
rmse_optimized = torch.sqrt(((output_optimized_torch - output_torch) ** 2).mean()).item()

print(f"Composite vs PyTorch:  PCC={pcc_composite:.6f}, RMSE={rmse_composite:.6f}")
print(f"Optimized vs PyTorch:  PCC={pcc_optimized:.6f}, RMSE={rmse_optimized:.6f}")
print("\n=== Speed Test ===")
print("Warming up (compiling kernels)...")
for _ in range(warmup_iterations):
    out = composite_sdpa(Q_tt, K_tt, V_tt, causal_mask_tt)
    out = ttnn.transformer.scaled_dot_product_attention(Q_tt, K_tt, V_tt, is_causal=True)
ttnn.synchronize_device(device)  # Ensure warm-up finishes before timing

start = time.perf_counter()
for _ in range(num_iterations):
    output = composite_sdpa(Q_tt, K_tt, V_tt, causal_mask_tt)
ttnn.synchronize_device(device)
composite_time = (time.perf_counter() - start) / num_iterations * 1000

start = time.perf_counter()
for _ in range(num_iterations):
    output = ttnn.transformer.scaled_dot_product_attention(Q_tt, K_tt, V_tt, is_causal=True)
ttnn.synchronize_device(device)
optimized_time = (time.perf_counter() - start) / num_iterations * 1000

speedup = composite_time / optimized_time

print(f"Composite SDPA: {composite_time:.3f} ms")
print(f"Optimized SDPA: {optimized_time:.3f} ms")
print(f"Speedup:        {speedup:.2f}x")

Math Fidelity Control

TT-NN provides fine-grained control over computational precision for performance tuning. The matrix engine supports multiple math fidelity modes that trade accuracy for speed.

Available Fidelity Modes:

  • LoFi - Lowest precision, highest performance

  • HiFi2 - Medium precision with FP32 accumulation

  • HiFi3 - Higher precision

  • HiFi4 - Highest precision with full FP32 accumulation

Additional resources:

  • Matrix Engine Technical Report

  • Data Format Documentation

[ ]:
import torch
import time

M, K, N = 2048, 2048, 2048
print(f"> Matrix dimensions: {M}x{K} @ {K}x{N}")

torch.manual_seed(42)
a = torch.randn((M, K), dtype=torch.bfloat16)
b = torch.randn((K, N), dtype=torch.bfloat16)
reference = torch.matmul(a.float(), b.float())

tt_a = ttnn.from_torch(a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tt_b = ttnn.from_torch(b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

print("\n" + "-" * 80)
print(f"{'Fidelity':<10} {'Time (ms)':<12} {'Mean Error':<12}")
print("-" * 80)

# Test different fidelities
fidelities = [
    (ttnn.MathFidelity.LoFi, "LoFi"),
    (ttnn.MathFidelity.HiFi2, "HiFi2"),
    (ttnn.MathFidelity.HiFi3, "HiFi3"),
    (ttnn.MathFidelity.HiFi4, "HiFi4"),
]

for fidelity, name in fidelities:
    # Configure compute kernel
    # Note: enable FP32 accumulation for the HiFi modes to see their accuracy benefits
    # With BF16 accumulation and large values, LSB corrections can introduce noise
    use_fp32_acc = (fidelity != ttnn.MathFidelity.LoFi)

    config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=fidelity,
        math_approx_mode=False,
        fp32_dest_acc_en=use_fp32_acc,  # FP32 accumulation for all HiFi modes
        packer_l1_acc=use_fp32_acc,     # L1 accumulation for better precision
    )

    # Warm-up (compiles the kernel); wait for it to finish before timing
    _ = ttnn.matmul(tt_a, tt_b, compute_kernel_config=config)
    ttnn.synchronize_device(device)

    # Time the operation
    start = time.time()
    for _ in range(50):
        result_tt = ttnn.matmul(tt_a, tt_b, compute_kernel_config=config)
    ttnn.synchronize_device(device)
    elapsed = (time.time() - start) / 50 * 1000  # Convert to ms

    # Get result
    result = ttnn.to_torch(result_tt).float()

    # Compute errors and PCC
    error = torch.abs(reference - result)
    mean_err = error.mean().item()

    print(f"{name:<10} {elapsed:>10.4f}   {mean_err:>10.8f}")

Metal Trace

Metal trace allows you to record and replay sequences of operations for improved performance:

# Begin recording operations
tid = ttnn.begin_trace_capture(device, cq_id=0)
output = run_model(input)
ttnn.end_trace_capture(device, tid, cq_id=0)

# Replay the traced operations
ttnn.execute_trace(device, tid, cq_id=0)

This is particularly useful for eliminating Python overhead in production inference.
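A slightly fuller sketch of the capture/replay flow (run_model is a stand-in for any fixed sequence of ttnn operations; this assumes the device was opened with a trace region configured and that ttnn.release_trace is available for freeing the capture):

def run_model(x):
    return ttnn.softmax(x, dim=1)

x = ttnn.rand([1024, 1024], device=device)

# Run once outside the trace: capture cannot include kernel compilation
output = run_model(x)

# Record the sequence of operations
tid = ttnn.begin_trace_capture(device, cq_id=0)
output = run_model(x)
ttnn.end_trace_capture(device, tid, cq_id=0)

# Replay as many times as needed, then free the trace buffers
ttnn.execute_trace(device, tid, cq_id=0)
ttnn.release_trace(device, tid)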

Multi-device

TT-NN supports distributed computing across multiple devices using collective communication operations (CCL):

Example: Tensor Sharding Across Devices

# Open a 1x2 mesh of devices
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2))

# Create a torch tensor
torch_tensor = torch.zeros(1, 1, 32, 64)
torch_tensor[..., 0:32] = 1.0
torch_tensor[..., 32:64] = 2.0

# Shard the tensor across devices along dimension 3
mesh_tensor = ttnn.from_torch(
    torch_tensor,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3),
)

# Perform collective operations
output_tensor = ttnn.all_gather(mesh_tensor, dim=3, num_links=1)

This enables efficient model parallelism and data parallelism across multiple Tenstorrent devices.