ttnn.transformer.split_query_key_value_and_split_heads

ttnn.transformer.split_query_key_value_and_split_heads(input_tensor: ttnn.Tensor, kv_input_tensor: ttnn.Tensor = None, *, num_heads: int, num_kv_heads: int | None = None, transpose_key: bool = True, memory_config: ttnn.MemoryConfig | None = None) → Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]

Splits input_tensor of shape [batch_size, sequence_size, 3 * hidden_size] into three tensors (Query, Key, Value), each of shape [batch_size, sequence_size, hidden_size], then reshapes and permutes them into the layout expected when computing attention scores.

If kv_input_tensor is passed in, then input_tensor of shape [batch_size, sequence_size, hidden_size] is only used for Query, and kv_input_tensor of shape [batch_size, sequence_size, 2 * hidden_size] is used for Key and Value.

For the sharded implementation, the input query, key and value are expected to be concatenated such that the heads are interleaved (q1 k1 v1…qn kn vn).
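
As an illustration of that interleaved layout (this is not ttnn code), the following NumPy sketch builds a fused tensor whose last dimension orders the heads as q1 k1 v1 … qn kn vn. All names and shapes here are assumptions chosen for the example, and it assumes num_kv_heads equals num_heads:

```python
import numpy as np

# Hypothetical sizes for illustration only.
batch, seq, num_heads, head_size = 1, 8, 4, 16

# Separate per-head Query, Key, Value tensors: [batch, seq, num_heads, head_size].
q = np.random.rand(batch, seq, num_heads, head_size)
k = np.random.rand(batch, seq, num_heads, head_size)
v = np.random.rand(batch, seq, num_heads, head_size)

# Stacking along a new axis after the heads axis groups (q_i, k_i, v_i) per head i;
# flattening the last three axes then yields q1 k1 v1 ... qn kn vn along the last dim.
interleaved = np.stack([q, k, v], axis=3).reshape(batch, seq, num_heads * 3 * head_size)
```

The first head_size columns of the fused width hold q1, the next head_size hold k1, and so on.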

Equivalent PyTorch code:

if kv_input_tensor is not None:
    input_tensor = torch.cat([input_tensor, kv_input_tensor], dim=-1)

if num_kv_heads is None:
    num_kv_heads = num_heads

batch_size, sequence_size, hidden_size = input_tensor.shape
# The fused width covers num_heads query heads plus num_kv_heads key heads
# and num_kv_heads value heads, each of size head_size
head_size = hidden_size // (num_heads + num_kv_heads * 2)
tensor = torch.reshape(input_tensor, (batch_size, sequence_size, num_heads + num_kv_heads * 2, head_size))
query, key, value = (
    tensor[..., :num_heads, :],
    tensor[..., num_heads:num_heads + num_kv_heads, :],
    tensor[..., num_heads + num_kv_heads:, :],
)

# The slices above already have shape [batch_size, sequence_size, heads, head_size]

query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()
if transpose_key:
    key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()

return query, key, value
Parameters:
  • input_tensor (ttnn.Tensor) – Input Tensor for Query, Key and Value. If kv_input_tensor is not None, then input_tensor is only used for Query.

  • kv_input_tensor (ttnn.Tensor) – Input Tensor for Key and Value. If provided, input_tensor is used only for Query. Defaults to None.

Keyword Arguments:
  • num_heads (int) – Number of heads to split the Query into.

  • num_kv_heads (int, optional) – Number of heads for Key and for Value. If not passed in, num_kv_heads is set to num_heads. Defaults to None.

  • transpose_key (bool) – Whether to transpose the Key tensor over its last two dimensions. Defaults to True.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

Returns:

Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor] – the Query, Key, and Value tensors.