Multi-Head Attention

Multi-Head Attention is a core building block of all Transformer-based models. This tutorial shows how to implement it with ttnn and then how to optimize it.

[1]:
import os

[2]:
import time
import torch
import ttnn

torch.manual_seed(0)

device_id = 0
dispatch_core_type = ttnn.device.DispatchCoreType.ETH
if "grayskull" in os.environ.get("ARCH_NAME"):
    dispatch_core_type = ttnn.device.DispatchCoreType.WORKER
device = ttnn.open_device(device_id=device_id, l1_small_size=8192, dispatch_core_config=ttnn.device.DispatchCoreConfig(dispatch_core_type))
2024-07-11 18:14:54.821 | DEBUG    | ttnn:<module>:136 - Initial ttnn.CONFIG:
{'cache_path': PosixPath('/home/ubuntu/.cache/ttnn'),
 'comparison_mode_pcc': 0.9999,
 'enable_comparison_mode': False,
 'enable_detailed_buffer_report': False,
 'enable_detailed_tensor_report': False,
 'enable_fast_runtime_mode': True,
 'enable_graph_report': False,
 'enable_logging': False,
 'enable_model_cache': False,
 'model_cache_path': PosixPath('/home/ubuntu/.cache/ttnn/models'),
 'report_name': None,
 'root_report_path': PosixPath('generated/ttnn/reports'),
 'throw_exception_on_fallback': False,
 'tmp_dir': PosixPath('/tmp/ttnn')}
                 Device | INFO     | Opening user mode device driver

2024-07-11 18:14:54.976 | INFO     | SiliconDriver   - Detected 1 PCI device : {0}
2024-07-11 18:14:54.989 | WARNING  | SiliconDriver   - init_detect_tt_device_numanodes(): Could not determine NumaNodeSet for TT device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:14:54.989 | WARNING  | SiliconDriver   - Could not find NumaNodeSet for TT Device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:14:54.990 | WARNING  | SiliconDriver   - bind_area_memory_nodeset(): Unable to determine TT Device to NumaNode mapping for physical_device_id: 0. Skipping membind.
---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
2024-07-11 18:14:55.014 | INFO     | SiliconDriver   - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 0)
                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   800 MHz

Enable program cache

The program cache allows compiled programs to be reused across invocations, which is why the second iteration of each model below runs much faster than the first.

[3]:
device.enable_program_cache()
                  Metal | INFO     | Enabling program cache on device 0

Write Multi-Head Attention using ttnn

Multi-head attention can be implemented in torch using just six operations:

  1. torch.matmul

  2. torch.add

  3. torch.reshape

  4. torch.permute

  5. torch.mul

  6. torch.softmax

ttnn provides equivalent APIs, so multi-head attention can be implemented in a very similar fashion. The one caveat is that, when using ttnn, the user should be mindful of the tensor layout: the implementation below switches to ROW_MAJOR_LAYOUT around the reshape calls and back to TILE_LAYOUT for the compute operations.
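
For comparison, a plain PyTorch sketch of the same six operations might look like this (illustrative only; it mirrors the tensor shapes used by the ttnn implementation below, and the function name is an assumption):

def torch_multi_head_attention(
    hidden_states,     # (batch_size, sequence_size, hidden_size)
    attention_mask,    # (batch_size, 1, 1, sequence_size)
    query_weight, query_bias,
    key_weight, key_bias,
    value_weight, value_bias,
    output_weight, output_bias,
    *,
    num_heads,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    def project_and_split(weight, bias, permutation):
        # matmul + add for the projection, then reshape + permute to split the heads
        x = torch.matmul(hidden_states, weight) + bias
        x = torch.reshape(x, (batch_size, sequence_size, num_heads, head_size))
        return torch.permute(x, permutation)

    query = project_and_split(query_weight, query_bias, (0, 2, 1, 3))  # (batch, heads, seq, head_size)
    key = project_and_split(key_weight, key_bias, (0, 2, 3, 1))        # (batch, heads, head_size, seq)
    value = project_and_split(value_weight, value_bias, (0, 2, 1, 3))  # (batch, heads, seq, head_size)

    attention_scores = torch.matmul(query, key)
    attention_scores = torch.mul(attention_scores, 1 / (head_size**0.5)) + attention_mask
    attention_probs = torch.softmax(attention_scores, dim=-1)

    context_layer = torch.matmul(attention_probs, value)
    context_layer = torch.permute(context_layer, (0, 2, 1, 3))
    context_layer = torch.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    return torch.matmul(context_layer, output_weight) + output_bias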

[4]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)

    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
    query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
    key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
    value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_scores += attention_mask
    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)
    context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

Now that the model is written, let's create input tensors to run and test it.

Configuration

[5]:
batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size

Initialize activations and weights using torch

[6]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)
torch_attention_mask = torch.randn((batch_size, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)

Convert activations and weights to ttnn

[7]:
hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device)
attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device)
query_weight = ttnn.from_torch(torch_query_weight, layout=ttnn.TILE_LAYOUT, device=device)
query_bias = ttnn.from_torch(torch_query_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.from_torch(torch_key_weight, layout=ttnn.TILE_LAYOUT, device=device)
key_bias = ttnn.from_torch(torch_key_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.from_torch(torch_value_weight, layout=ttnn.TILE_LAYOUT, device=device)
value_bias = ttnn.from_torch(torch_value_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.from_torch(torch_output_weight, layout=ttnn.TILE_LAYOUT, device=device)
output_bias = ttnn.from_torch(torch_output_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

Run the first iteration of Multi-Head Attention

[8]:
start = time.time()
multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[9]:
print(f"Multi-head attention ran in {duration} seconds for the first iteration")
Multi-head attention ran in 8.00607705116272 seconds for the first iteration

Run a subsequent iteration of Multi-Head Attention

[10]:
start = time.time()
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[11]:
print(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")
Multi-head attention ran in 0.250946044921875 seconds for the subsequent iteration because of the program cache

Write an optimized version of Multi-Head Attention

An optimized version of multi-head attention can be written by:

  • Tilizing all of the tensors ahead of time

  • Using more performant matmuls that fuse bias and specify the number of cores they execute on

  • Putting every tensor into L1

  • Using the bfloat8_b data type

  • Using custom ttnn.transformer operations instead of ttnn.permute and ttnn.reshape

The ttnn.deallocate calls are needed because, otherwise, the cores on the device would run out of L1 memory.
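
As a minimal, hypothetical illustration of that pattern (assuming the device opened earlier in this tutorial and tensor names chosen only for this example), an intermediate L1 tensor is freed as soon as the last operation that reads it has run:

a = ttnn.from_torch(torch.randn(32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
b = ttnn.from_torch(torch.randn(32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
c = ttnn.from_torch(torch.randn(32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

intermediate = ttnn.matmul(a, b, memory_config=ttnn.L1_MEMORY_CONFIG)
result = ttnn.matmul(intermediate, c, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(intermediate)  # free the L1 buffer as soon as no later operation needs it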

[12]:
def optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    batch_size, _, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    attention_probs = ttnn.transformer.attention_softmax_(attention_scores, attention_mask=attention_mask, head_size=head_size)

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(attention_probs)

    context_layer_after_concatenate_heads = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        context_layer_after_concatenate_heads,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(context_layer_after_concatenate_heads)

    return self_output

Pre-process the parameters of the optimized model

  1. Fuse QKV weights and biases

  2. Reshape and tilize for the optimized operations using preprocess_linear_weight and preprocess_linear_bias

  3. Move to device

[13]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)

torch_qkv_weight = torch.cat([torch_query_weight, torch_key_weight, torch_value_weight], dim=-1)
torch_qkv_bias = torch.cat([torch_query_bias, torch_key_bias, torch_value_bias], dim=-1)

qkv_weight = preprocess_linear_weight(torch_qkv_weight.T, dtype=ttnn.bfloat16)
qkv_bias = preprocess_linear_bias(torch_qkv_bias, dtype=ttnn.bfloat16)
output_weight = preprocess_linear_weight(torch_output_weight.T, dtype=ttnn.bfloat16)
output_bias = preprocess_linear_bias(torch_output_bias, dtype=ttnn.bfloat16)

qkv_weight = ttnn.to_device(qkv_weight, device)
qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)

Run the first iteration of the optimized Multi-Head Attention

[14]:
start = time.time()
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[15]:
print(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")
Optimized multi-head attention ran in 4.474989175796509 seconds for the first iteration

Run a subsequent iteration of the optimized Multi-Head Attention

[16]:
start = time.time()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[17]:
print(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")
Optimized multi-head attention ran in 0.020017147064208984 seconds for the subsequent iteration because of the program cache

Note that, with the program cache warm, the optimized multi-head attention runs more than an order of magnitude faster than the initial version (roughly 0.25 s vs 0.02 s per iteration in the runs above)

Check that the output of the optimized version matches the output of the original implementation

[18]:
torch_output = ttnn.to_torch(output)
torch_optimized_output = ttnn.to_torch(optimized_output)

assert torch.allclose(torch_output, torch_optimized_output)
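
Depending on the data types involved (the optimized version uses bfloat8_b for some intermediates), a strict allclose can be too tight. A looser, correlation-based check is one alternative; the sketch below uses plain torch, and the 0.99 threshold is an assumption for illustration, not a ttnn default:

def pearson_correlation_coefficient(expected, actual):
    # Flatten both tensors and take the off-diagonal entry of the 2x2 correlation matrix
    stacked = torch.stack([expected.float().flatten(), actual.float().flatten()])
    return torch.corrcoef(stacked)[0, 1].item()

assert pearson_correlation_coefficient(torch_output, torch_optimized_output) > 0.99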

Close the device

[19]:
ttnn.close_device(device)
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0