Multi-Head Attention
Multi-Head Attention is a core building block of all Transformer-based models. This tutorial shows how to write it using ttnn and then how to optimize it.
[1]:
import os
[2]:
import time
import torch
import ttnn
torch.manual_seed(0)
device_id = 0
dispatch_core_type = ttnn.device.DispatchCoreType.ETH
if "grayskull" in os.environ.get("ARCH_NAME"):
dispatch_core_type = ttnn.device.DispatchCoreType.WORKER
device = ttnn.open_device(device_id=device_id, l1_small_size=8192, dispatch_core_config=ttnn.device.DispatchCoreConfig(dispatch_core_type))
2024-07-11 18:14:54.821 | DEBUG | ttnn:<module>:136 - Initial ttnn.CONFIG:
{'cache_path': PosixPath('/home/ubuntu/.cache/ttnn'),
'comparison_mode_pcc': 0.9999,
'enable_comparison_mode': False,
'enable_detailed_buffer_report': False,
'enable_detailed_tensor_report': False,
'enable_fast_runtime_mode': True,
'enable_graph_report': False,
'enable_logging': False,
'enable_model_cache': False,
'model_cache_path': PosixPath('/home/ubuntu/.cache/ttnn/models'),
'report_name': None,
'root_report_path': PosixPath('generated/ttnn/reports'),
'throw_exception_on_fallback': False,
'tmp_dir': PosixPath('/tmp/ttnn')}
2024-07-11 18:14:54.907 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.logical_xor be migrated to C++?
2024-07-11 18:14:54.908 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.xlogy be migrated to C++?
2024-07-11 18:14:54.908 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.maximum be migrated to C++?
2024-07-11 18:14:54.909 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.minimum be migrated to C++?
2024-07-11 18:14:54.909 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.atan2 be migrated to C++?
2024-07-11 18:14:54.910 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.hypot be migrated to C++?
2024-07-11 18:14:54.911 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.nextafter be migrated to C++?
2024-07-11 18:14:54.912 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.polyval be migrated to C++?
2024-07-11 18:14:54.912 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.isclose be migrated to C++?
2024-07-11 18:14:54.914 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.all_gather be migrated to C++?
2024-07-11 18:14:54.915 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.pearson_correlation_coefficient be migrated to C++?
2024-07-11 18:14:54.919 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.conv2d be migrated to C++?
2024-07-11 18:14:54.921 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.reshape be migrated to C++?
2024-07-11 18:14:54.921 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.unsqueeze_to_4D be migrated to C++?
2024-07-11 18:14:54.922 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.squeeze be migrated to C++?
2024-07-11 18:14:54.923 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.from_torch be migrated to C++?
2024-07-11 18:14:54.923 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.to_torch be migrated to C++?
2024-07-11 18:14:54.924 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.to_device be migrated to C++?
2024-07-11 18:14:54.925 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.from_device be migrated to C++?
2024-07-11 18:14:54.926 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.allocate_tensor_on_device be migrated to C++?
2024-07-11 18:14:54.926 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.copy_host_to_device_tensor be migrated to C++?
2024-07-11 18:14:54.927 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.deallocate be migrated to C++?
2024-07-11 18:14:54.928 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.clone be migrated to C++?
2024-07-11 18:14:54.929 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.reallocate be migrated to C++?
2024-07-11 18:14:54.929 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.load_tensor be migrated to C++?
2024-07-11 18:14:54.930 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.dump_tensor be migrated to C++?
2024-07-11 18:14:54.931 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.as_tensor be migrated to C++?
2024-07-11 18:14:54.934 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.arange be migrated to C++?
2024-07-11 18:14:54.936 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.mse_loss be migrated to C++?
2024-07-11 18:14:54.936 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.l1_loss be migrated to C++?
2024-07-11 18:14:54.938 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.matmul be migrated to C++?
2024-07-11 18:14:54.939 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.linear be migrated to C++?
2024-07-11 18:14:54.941 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.mac be migrated to C++?
2024-07-11 18:14:54.942 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.addcmul be migrated to C++?
2024-07-11 18:14:54.942 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.addcdiv be migrated to C++?
2024-07-11 18:14:54.943 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.lerp be migrated to C++?
2024-07-11 18:14:54.948 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.logit be migrated to C++?
2024-07-11 18:14:54.949 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.polygamma be migrated to C++?
2024-07-11 18:14:54.949 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.hardshrink be migrated to C++?
2024-07-11 18:14:54.950 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.celu be migrated to C++?
2024-07-11 18:14:54.951 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.softshrink be migrated to C++?
2024-07-11 18:14:54.952 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.clip be migrated to C++?
2024-07-11 18:14:54.952 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.threshold be migrated to C++?
2024-07-11 18:14:54.953 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.glu be migrated to C++?
2024-07-11 18:14:54.954 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.reglu be migrated to C++?
2024-07-11 18:14:54.955 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.swiglu be migrated to C++?
2024-07-11 18:14:54.955 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.geglu be migrated to C++?
2024-07-11 18:14:54.958 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.matmul be migrated to C++?
2024-07-11 18:14:54.958 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.linear be migrated to C++?
2024-07-11 18:14:54.960 | WARNING | ttnn.decorators:operation_decorator:758 - Should ttnn.conv2d be migrated to C++?
Device | INFO | Opening user mode device driver
2024-07-11 18:14:54.976 | INFO | SiliconDriver - Detected 1 PCI device : {0}
2024-07-11 18:14:54.989 | WARNING | SiliconDriver - init_detect_tt_device_numanodes(): Could not determine NumaNodeSet for TT device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:14:54.989 | WARNING | SiliconDriver - Could not find NumaNodeSet for TT Device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-07-11 18:14:54.990 | WARNING | SiliconDriver - bind_area_memory_nodeset(): Unable to determine TT Device to NumaNode mapping for physical_device_id: 0. Skipping membind.
---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
2024-07-11 18:14:55.014 | INFO | SiliconDriver - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 0)
Metal | INFO | Initializing device 0. Program cache is NOT enabled
Metal | INFO | AI CLK for device 0 is: 800 MHz
Enable program cache
[3]:
device.enable_program_cache()
Metal | INFO | Enabling program cache on device 0
Write Multi-Head Attention using ttnn
Multi-head attention can be implemented in torch using just six operations:
torch.matmul
torch.add
torch.reshape
torch.permute
torch.mul
torch.softmax
ttnn provides the exact same APIs, so multi-head attention can be implemented in a very similar fashion; the one difference is that, when using ttnn, the user should be mindful of the tensor layout.
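For reference, here is a minimal torch-only sketch of multi-head attention built from those six operations (illustrative only, not part of the ttnn tutorial); the ttnn version below follows the same structure, with explicit layout conversions added.

import torch

def torch_multi_head_attention(
    hidden_states,    # (batch_size, sequence_size, hidden_size)
    attention_mask,   # (batch_size, 1, 1, sequence_size)
    query_weight, query_bias,
    key_weight, key_bias,
    value_weight, value_bias,
    output_weight, output_bias,
    *,
    num_heads,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    # Project and split into heads: (B, S, hidden) -> (B, num_heads, S, head_size)
    query = torch.add(torch.matmul(hidden_states, query_weight), query_bias)
    query = torch.permute(torch.reshape(query, (batch_size, sequence_size, num_heads, head_size)), (0, 2, 1, 3))
    key = torch.add(torch.matmul(hidden_states, key_weight), key_bias)
    key = torch.permute(torch.reshape(key, (batch_size, sequence_size, num_heads, head_size)), (0, 2, 3, 1))
    value = torch.add(torch.matmul(hidden_states, value_weight), value_bias)
    value = torch.permute(torch.reshape(value, (batch_size, sequence_size, num_heads, head_size)), (0, 2, 1, 3))

    # Scaled dot-product attention with the additive mask
    attention_scores = torch.mul(torch.matmul(query, key), 1 / (head_size**0.5))
    attention_scores = torch.add(attention_scores, attention_mask)
    attention_probs = torch.softmax(attention_scores, dim=-1)

    # Combine heads and apply the output projection
    context_layer = torch.matmul(attention_probs, value)
    context_layer = torch.permute(context_layer, (0, 2, 1, 3))
    context_layer = torch.reshape(context_layer, (batch_size, sequence_size, hidden_size))
    return torch.add(torch.matmul(context_layer, output_weight), output_bias)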
[4]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)

    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
    query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
    key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
    value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_scores += attention_mask
    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)
    context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output
Now that the model is written, let's create input tensors to run and test it.
Configuration
[5]:
batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size
Initialize activations and weights using torch
[6]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)
torch_attention_mask = torch.randn((batch_size, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
Convert activations and weights to ttnn
[7]:
hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device)
attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device)
query_weight = ttnn.from_torch(torch_query_weight, layout=ttnn.TILE_LAYOUT, device=device)
query_bias = ttnn.from_torch(torch_query_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.from_torch(torch_key_weight, layout=ttnn.TILE_LAYOUT, device=device)
key_bias = ttnn.from_torch(torch_key_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.from_torch(torch_value_weight, layout=ttnn.TILE_LAYOUT, device=device)
value_bias = ttnn.from_torch(torch_value_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.from_torch(torch_output_weight, layout=ttnn.TILE_LAYOUT, device=device)
output_bias = ttnn.from_torch(torch_output_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
Run the first iteration of Multi-Head Attention
[8]:
start = time.time()
multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[9]:
print(f"Multi-head attention ran in {duration} seconds for the first iteration")
Multi-head attention ran in 8.00607705116272 seconds for the first iteration
Run a subsequent iteration of Multi-Head Attention
[10]:
start = time.time()
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[11]:
print(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")
Multi-head attention ran in 0.250946044921875 seconds for the subsequent iteration because of the program cache
Write optimized version of Multi-Head Attention
An optimized version of multi-head attention can be written by:
Tilizing all of the tensors ahead of time
Using more performant matmuls that fuse the bias addition and specify the number of cores they execute on
Putting every tensor into L1 memory
Using the bfloat8_b data type
Using custom ttnn.transformer operations instead of ttnn.permute and ttnn.reshape (see the sketch after this list)
ttnn.deallocate calls are needed because otherwise the cores on the device would run out of L1 memory.
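To make the shape bookkeeping of those ttnn.transformer ops concrete, the torch-only sketch below (illustrative, shapes taken from the configuration used later in this tutorial) shows roughly what split_query_key_value_and_split_heads and concatenate_heads accomplish, as they are used here; ttnn.transformer.attention_softmax_ similarly fuses the 1/sqrt(head_size) scaling, the mask addition, and the softmax into a single in-place op. The exact internal data layout handled by the ttnn ops may differ.

import torch

# Illustrative shapes only, matching the configuration above.
batch_size, sequence_size, num_heads, head_size = 8, 384, 16, 64
hidden_size = num_heads * head_size

# split_query_key_value_and_split_heads: one fused (B, S, 3 * hidden) tensor in;
# per-head query (B, H, S, d), pre-transposed key (B, H, d, S), and value (B, H, S, d) out.
fused_qkv_output = torch.randn(batch_size, sequence_size, 3 * hidden_size)
query, key, value = fused_qkv_output.split(hidden_size, dim=-1)
query = query.reshape(batch_size, sequence_size, num_heads, head_size).permute(0, 2, 1, 3)
key = key.reshape(batch_size, sequence_size, num_heads, head_size).permute(0, 2, 3, 1)
value = value.reshape(batch_size, sequence_size, num_heads, head_size).permute(0, 2, 1, 3)

# concatenate_heads: the inverse bookkeeping, (B, H, S, d) back to (B, S, hidden).
context_layer = torch.randn(batch_size, num_heads, sequence_size, head_size)
concatenated = context_layer.permute(0, 2, 1, 3).reshape(batch_size, sequence_size, hidden_size)

print(query.shape, key.shape, value.shape, concatenated.shape)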
[12]:
def optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    batch_size, _, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    attention_probs = ttnn.transformer.attention_softmax_(attention_scores, attention_mask=attention_mask, head_size=head_size)

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(attention_probs)

    context_layer_after_concatenate_heads = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        context_layer_after_concatenate_heads,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(context_layer_after_concatenate_heads)

    return self_output
Pre-process the parameters of the optimized model
Fuse QKV weights and biases (see the sketch after this list)
Reshape and tilize for the optimized operations using preprocess_linear_weight and preprocess_linear_bias
Move to device
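As a quick sanity check on the weight fusion, the torch-only snippet below (illustrative, with small made-up shapes, not part of the tutorial) shows why concatenating the Q, K and V weights and biases along their output dimension lets a single linear produce all three projections at once.

import torch

# Small made-up shapes, for illustration only.
hidden = torch.randn(2, 4, 8)
wq, wk, wv = (torch.randn(8, 8) for _ in range(3))
bq, bk, bv = (torch.randn(8) for _ in range(3))

fused_weight = torch.cat([wq, wk, wv], dim=-1)  # (8, 24)
fused_bias = torch.cat([bq, bk, bv], dim=-1)    # (24,)

fused_output = hidden @ fused_weight + fused_bias
q, k, v = fused_output.split(8, dim=-1)

# The fused projection matches the three separate projections.
assert torch.allclose(q, hidden @ wq + bq)
assert torch.allclose(k, hidden @ wk + bk)
assert torch.allclose(v, hidden @ wv + bv)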
[13]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)
torch_qkv_weight = torch.cat([torch_query_weight, torch_key_weight, torch_value_weight], dim=-1)
torch_qkv_bias = torch.cat([torch_query_bias, torch_key_bias, torch_value_bias], dim=-1)
qkv_weight = preprocess_linear_weight(torch_qkv_weight.T, dtype=ttnn.bfloat16)
qkv_bias = preprocess_linear_bias(torch_qkv_bias, dtype=ttnn.bfloat16)
output_weight = preprocess_linear_weight(torch_output_weight.T, dtype=ttnn.bfloat16)
output_bias = preprocess_linear_bias(torch_output_bias, dtype=ttnn.bfloat16)
qkv_weight = ttnn.to_device(qkv_weight, device)
qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
Run the first iteration of the optimized Multi-Head Attention
[14]:
start = time.time()
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[15]:
print(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")
Optimized multi-head attention ran in 4.474989175796509 seconds for the first iteration
Run a subsequent iteration of the optimized Multi-Head Attention
[16]:
start = time.time()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start
[17]:
print(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")
Optimized multi-head attention ran in 0.020017147064208984 seconds for the subsequent iteration because of the program cache
Note that, once the program cache is warm, the optimized multi-head attention runs roughly 12x faster than the initial version (about 0.25 s vs 0.02 s per iteration); the first, uncached iteration is also about 1.8x faster.
Check that the output of the optimized version matches the output of the original implementation
[18]:
torch_output = ttnn.to_torch(output)
torch_optimized_output = ttnn.to_torch(optimized_output)
assert torch.allclose(torch_output, torch_optimized_output)
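Since the optimized version computes some intermediates in bfloat8_b, a correlation-based comparison is another common way to validate the outputs. The snippet below is an illustrative sketch (not part of the original tutorial); the helper name and the 0.99 threshold are arbitrary examples.

def pearson_correlation_coefficient(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors and compute the Pearson correlation coefficient in float32.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

pcc = pearson_correlation_coefficient(torch_output, torch_optimized_output)
print(f"PCC between the original and optimized outputs: {pcc}")
assert pcc > 0.99  # illustrative threshold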
Close the device
[19]:
ttnn.close_device(device)
Metal | INFO | Closing device 0
Metal | INFO | Disabling and clearing program cache on device 0