Multi-Head Attention

Multi-Head Attention is a core building block of Transformer-based models. This tutorial shows how to implement it with TT-NN and then optimize it.

[ ]:
import os
import time
import torch
import ttnn
from loguru import logger

torch.manual_seed(0)

device_id = 0
device = ttnn.open_device(device_id=device_id)

Write Multi-Head Attention with TT-NN

Multi-head attention can be implemented in torch using the following six operations:

  1. torch.matmul

  2. torch.add

  3. torch.reshape

  4. torch.permute

  5. torch.mul

  6. torch.softmax

TT-NN provides the same APIs for these operations, so Multi-Head Attention can be implemented in much the same way. The main difference is that you need to be mindful of the tensor layout when using TT-NN: the code below switches to ROW_MAJOR_LAYOUT for reshapes and back to TILE_LAYOUT for matmuls and the other operations.
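
For reference, here is a minimal torch sketch of the same computation, built from only the six operations listed above and using the same weight shapes as the TT-NN implementation below. This is an illustration, not one of the tutorial's runnable cells:

def torch_multi_head_attention(
    hidden_states,    # (batch_size, sequence_size, hidden_size)
    attention_mask,   # (batch_size, 1, 1, sequence_size)
    query_weight, query_bias,
    key_weight, key_bias,
    value_weight, value_bias,
    output_weight, output_bias,
    *,
    num_heads,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    def split_heads(x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        x = torch.reshape(x, (batch_size, sequence_size, num_heads, head_size))
        return torch.permute(x, (0, 2, 1, 3))

    query = split_heads(torch.add(torch.matmul(hidden_states, query_weight), query_bias))
    key = split_heads(torch.add(torch.matmul(hidden_states, key_weight), key_bias))
    value = split_heads(torch.add(torch.matmul(hidden_states, value_weight), value_bias))

    attention_scores = torch.matmul(query, torch.permute(key, (0, 1, 3, 2)))
    attention_scores = torch.mul(attention_scores, 1 / (head_size**0.5))
    attention_scores = torch.add(attention_scores, attention_mask)
    attention_probs = torch.softmax(attention_scores, dim=-1)

    context_layer = torch.matmul(attention_probs, value)
    context_layer = torch.permute(context_layer, (0, 2, 1, 3))
    context_layer = torch.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    return torch.add(torch.matmul(context_layer, output_weight), output_bias)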

[ ]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
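    # ttnn.get_fallback_function wraps a TT-NN op with its host (torch) implementation;
    # it is used here because the on-device reshape does not support every shape.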
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)

    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
    query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
    key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
    value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_scores += attention_mask
    attention_probs = ttnn.softmax(attention_scores, dim=-1, numeric_stable=False)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)
    context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

Now that the model is written, create input tensors to run and test it.

Configuration

[ ]:
batch_size = 6
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size

Initialize activations and weights with Torch:

[ ]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)
torch_attention_mask = torch.randn((batch_size, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)

Convert activations and weights to TT-NN:

[ ]:
hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device)
attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device)
query_weight = ttnn.from_torch(torch_query_weight, layout=ttnn.TILE_LAYOUT, device=device)
query_bias = ttnn.from_torch(torch_query_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.from_torch(torch_key_weight, layout=ttnn.TILE_LAYOUT, device=device)
key_bias = ttnn.from_torch(torch_key_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.from_torch(torch_value_weight, layout=ttnn.TILE_LAYOUT, device=device)
value_bias = ttnn.from_torch(torch_value_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.from_torch(torch_output_weight, layout=ttnn.TILE_LAYOUT, device=device)
output_bias = ttnn.from_torch(torch_output_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

Run the first iteration of Multi-Head Attention:

[ ]:
start = time.time()
multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[ ]:
logger.info(f"Multi-head attention ran in {duration} seconds for the first iteration")

Run a subsequent iteration of Multi-Head Attention:

The first iteration of Multi-Head Attention takes several seconds to run. Because TT-NN configures and compiles device code for each operation on the fly, most of that time is spent compiling rather than executing.

Fortunately, the configuration and compiled device code are stored in a program cache, so subsequent iterations run significantly faster, roughly two orders of magnitude faster than the first iteration.
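
One way to see this compile-then-cache behavior is to wrap the call in a small timing helper, for example (a generic Python sketch, not a TT-NN API; the helper name benchmark is made up for illustration):

def benchmark(fn, *args, warmup=1, iterations=5, **kwargs):
    # The first call(s) compile device code and populate the program cache.
    for _ in range(warmup):
        fn(*args, **kwargs)
    # Later calls reuse the cached programs, so they reflect steady-state performance.
    start = time.time()
    for _ in range(iterations):
        fn(*args, **kwargs)
    return (time.time() - start) / iterations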

[ ]:
start = time.time()
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[ ]:
logger.info(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Write an optimized version of Multi-Head Attention:

The optimized version of Multi-Head Attention can be written by:

  • Tilizing all of the tensors ahead of time.

  • Using more performant matmuls (ttnn.linear) that fuse the bias add and let you specify the grid of cores they execute on.

  • Putting every tensor into L1.

  • Using the bfloat8_b data type for some intermediate tensors.

  • Using custom ttnn.transformer operations instead of ttnn.permute and ttnn.reshape operations.

ttnn.deallocate calls are required; otherwise, the cores on the device will run out of L1 memory.
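
In isolation, the L1-plus-deallocate pattern used throughout the optimized function looks roughly like this (a minimal sketch, assuming a, b, and c are TILE_LAYOUT tensors already on the device):

# Keep the intermediate in fast on-core L1 memory instead of DRAM...
intermediate = ttnn.matmul(a, b, memory_config=ttnn.L1_MEMORY_CONFIG)
result = ttnn.matmul(intermediate, c, memory_config=ttnn.L1_MEMORY_CONFIG)
# ...and free it as soon as it is no longer needed, so later ops do not run out of L1.
ttnn.deallocate(intermediate)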

[ ]:
def optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    batch_size, _, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads
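    # Every linear and matmul below runs on a batch_size x num_cores_x grid of cores
    # and keeps its output in L1 memory.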

    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
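    # query: (batch_size, num_heads, sequence_size, head_size)
    # key:   (batch_size, num_heads, head_size, sequence_size), pre-transposed for the matmul below
    # value: (batch_size, num_heads, sequence_size, head_size)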
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    # attention_softmax_ operates in place, fusing the 1/sqrt(head_size) scaling,
    # the attention mask add, and the softmax into a single operation.
    attention_probs = ttnn.transformer.attention_softmax_(
        attention_scores,
        attention_mask=attention_mask,
        head_size=head_size,
    )

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(attention_probs)

    context_layer_after_concatenate_heads = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        context_layer_after_concatenate_heads,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(context_layer_after_concatenate_heads)

    return self_output

Pre-process the parameters of the optimized model:

  1. Fuse QKV weights and biases (see the sketch after this list).

  2. Reshape and tilize for optimized operations using preprocess_linear_weight and preprocess_linear_bias.

  3. Move to device.
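
The fusion in step 1 works because concatenating the Q, K, and V weight matrices along the output dimension and applying a single matmul is equivalent to running the three projections separately and concatenating the results. A quick torch check of this equivalence (illustrative only, computed in float32 to avoid precision noise):

x = torch.randn((2, 4, hidden_size), dtype=torch.float32)
wq, wk, wv = (w.float() for w in (torch_query_weight, torch_key_weight, torch_value_weight))

fused = x @ torch.cat([wq, wk, wv], dim=-1)              # one matmul against the fused weight
separate = torch.cat([x @ wq, x @ wk, x @ wv], dim=-1)   # three separate projections

assert torch.allclose(fused, separate, rtol=1e-4, atol=1e-4)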

[ ]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)

torch_qkv_weight = torch.cat([torch_query_weight, torch_key_weight, torch_value_weight], dim=-1)
torch_qkv_bias = torch.cat([torch_query_bias, torch_key_bias, torch_value_bias], dim=-1)

qkv_weight = preprocess_linear_weight(torch_qkv_weight.T, dtype=ttnn.bfloat16)
qkv_bias = preprocess_linear_bias(torch_qkv_bias, dtype=ttnn.bfloat16)
output_weight = preprocess_linear_weight(torch_output_weight.T, dtype=ttnn.bfloat16)
output_bias = preprocess_linear_bias(torch_output_bias, dtype=ttnn.bfloat16)

qkv_weight = ttnn.to_device(qkv_weight, device)
qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)

Run the first iteration of optimized Multi-Head Attention:

[ ]:
start = time.time()
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[ ]:
logger.info(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")

Run a subsequent iteration of optimized Multi-Head Attention:

[ ]:
start = time.time()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[ ]:
logger.info(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Note that, once the program cache is warm, the optimized multi-head attention is significantly faster than the initial version (roughly 20x in the sample run shown at the end of this tutorial).

Check that the output of the optimized version matches the output of the original implementation:

[ ]:
torch_output = ttnn.to_torch(output)
torch_optimized_output = ttnn.to_torch(optimized_output)

assert torch.allclose(torch_output, torch_optimized_output)
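
If the strict torch.allclose check proves too tight for a given precision configuration (the optimized path uses bfloat8_b for some intermediates), a looser similarity metric such as the Pearson correlation coefficient is often used when validating TT-NN models. A sketch of that check (the pearson_correlation helper is illustrative, not a TT-NN API):

def pearson_correlation(expected, actual):
    # Flatten both tensors and compute their correlation in float32.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

assert pearson_correlation(torch_output, torch_optimized_output) > 0.99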

Close the device:

[ ]:
ttnn.close_device(device)

Full Example and Output

Let's put everything together in a complete example that can be run directly:

ttnn_multihead_attention.py

Run the following script to generate the output:

$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_multihead_attention.py
2025-07-07 13:06:38.768 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.769 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.776 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:190)
2025-07-07 13:06:38.776 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.777 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.783 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.784 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.790 | info     |   SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)
2025-07-07 13:06:38.887 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)
2025-07-07 13:06:38.931 | info     |   SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)
2025-07-07 13:06:38.942 | info     |   SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)
2025-07-07 13:06:39.027 | info     |           Metal | AI CLK for device 0 is:   1000 MHz (metal_context.cpp:128)
2025-07-07 13:06:39.603 | info     |           Metal | Initializing device 0. Program cache is enabled (device.cpp:428)
2025-07-07 13:06:39.605 | warning  |           Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)
2025-07-07 13:06:51.001 | INFO     | __main__:main:132 - Multi-head attention ran in 9.265338897705078 seconds for the first iteration
2025-07-07 13:06:51.056 | INFO     | __main__:main:151 - Multi-head attention ran in 0.05480194091796875 seconds for the subsequent iteration because of the program cache
2025-07-07 13:06:55.363 | INFO     | __main__:main:259 - Optimized multi-head attention ran in 4.2866740226745605 seconds for the first iteration
2025-07-07 13:06:55.366 | INFO     | __main__:main:274 - Optimized multi-head attention ran in 0.002416849136352539 seconds for the subsequent iteration because of the program cache
2025-07-07 13:06:55.417 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)
2025-07-07 13:06:55.418 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)
2025-07-07 13:06:55.418 | info     |           Metal | Closing device 0 (device.cpp:468)
2025-07-07 13:06:55.418 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:783)
2025-07-07 13:06:55.460 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)