ttnn.transformer.split_query_key_value_and_split_heads
- ttnn.transformer.split_query_key_value_and_split_heads(input_tensor: ttnn.Tensor, kv_input_tensor: ttnn.Tensor = None, *, num_heads: int, num_kv_heads: int | None = None, transpose_key: bool = True, memory_config: ttnn.MemoryConfig | None = None) → Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]
-
Splits input_tensor of shape [batch_size, sequence_size, 3 * hidden_size] into 3 tensors (Query, Key, Value) of shape [batch_size, sequence_size, hidden_size]. Then reshapes and permutes the output tensors to make them ready for computing attention scores.

If kv_input_tensor is passed in, then input_tensor of shape [batch_size, sequence_size, hidden_size] is used only for Query, and kv_input_tensor of shape [batch_size, sequence_size, 2 * hidden_size] is used for Key and Value.

For the sharded implementation, the input query, key and value are expected to be concatenated such that the heads are interleaved (q1 k1 v1 … qn kn vn).
Equivalent pytorch code:

    if kv_input_tensor is not None:
        input_tensor = torch.cat([input_tensor, kv_input_tensor], dim=-1)

    if num_kv_heads is None:
        num_kv_heads = num_heads

    batch_size, sequence_size, hidden_size = input_tensor.shape
    # The hidden dimension holds num_heads query heads plus num_kv_heads key
    # heads and num_kv_heads value heads, so divide by the total head count
    head_size = hidden_size // (num_heads + num_kv_heads * 2)

    tensor = torch.reshape(input_tensor, (batch_size, sequence_size, num_heads + num_kv_heads * 2, head_size))
    query, key, value = (
        tensor[..., :num_heads, :],
        tensor[..., num_heads:num_heads + num_kv_heads, :],
        tensor[..., num_heads + num_kv_heads:, :],
    )

    query = torch.reshape(query, (batch_size, sequence_size, num_heads, head_size))
    key = torch.reshape(key, (batch_size, sequence_size, num_kv_heads, head_size))
    value = torch.reshape(value, (batch_size, sequence_size, num_kv_heads, head_size))

    query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
    key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
    value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()

    if transpose_key:
        key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()

    return query, key, value
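As a concrete illustration, the equivalent PyTorch code above can be wrapped into a standalone helper and checked against the expected output shapes. The helper name reference_split_qkv_and_split_heads and the concrete shapes below are ours for illustration only; this is not a ttnn API.

```python
import torch


def reference_split_qkv_and_split_heads(
    input_tensor, kv_input_tensor=None, *, num_heads, num_kv_heads=None, transpose_key=True
):
    # Illustrative helper mirroring the equivalent PyTorch code above
    # (the name is hypothetical, not part of ttnn).
    if kv_input_tensor is not None:
        input_tensor = torch.cat([input_tensor, kv_input_tensor], dim=-1)
    if num_kv_heads is None:
        num_kv_heads = num_heads
    batch_size, sequence_size, hidden_size = input_tensor.shape
    # Hidden dim holds num_heads query heads + 2 * num_kv_heads key/value heads.
    head_size = hidden_size // (num_heads + num_kv_heads * 2)
    tensor = input_tensor.reshape(batch_size, sequence_size, num_heads + num_kv_heads * 2, head_size)
    query = tensor[..., :num_heads, :].permute(0, 2, 1, 3).contiguous()
    key = tensor[..., num_heads:num_heads + num_kv_heads, :].permute(0, 2, 1, 3).contiguous()
    value = tensor[..., num_heads + num_kv_heads:, :].permute(0, 2, 1, 3).contiguous()
    if transpose_key:
        key = key.permute(0, 1, 3, 2).contiguous()
    return query, key, value


# Fused QKV case: hidden_size = 128, num_heads = 8, so head_size = 16.
qkv = torch.randn(2, 32, 3 * 128)
q, k, v = reference_split_qkv_and_split_heads(qkv, num_heads=8)
print(q.shape)  # torch.Size([2, 8, 32, 16])
print(k.shape)  # torch.Size([2, 8, 16, 32])  (last two dims transposed)
print(v.shape)  # torch.Size([2, 8, 32, 16])

# Separate KV case with grouped-query attention: num_kv_heads = 2, head_size = 16.
q_in = torch.randn(2, 32, 8 * 16)
kv_in = torch.randn(2, 32, 2 * 2 * 16)
q, k, v = reference_split_qkv_and_split_heads(q_in, kv_in, num_heads=8, num_kv_heads=2, transpose_key=False)
print(q.shape)  # torch.Size([2, 8, 32, 16])
print(k.shape)  # torch.Size([2, 2, 32, 16])
```

Note that with transpose_key=True the returned Key is already laid out as [batch_size, num_kv_heads, head_size, sequence_size], so attention scores can be computed with a plain matmul against Query.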
- Parameters:
-
input_tensor (ttnn.Tensor) – Input tensor for Query, Key and Value. If kv_input_tensor is not None, then input_tensor is used only for Query.
kv_input_tensor (ttnn.Tensor) – Input tensor for Key and Value. If passed in, input_tensor is used only for Query. Defaults to None.
- Keyword Arguments:
-
num_heads (int) – Number of heads to split into.
num_kv_heads (int, optional) – Number of Key heads and number of Value heads. If not passed in, num_kv_heads is set to num_heads. Defaults to None.
transpose_key (bool) – Whether to transpose the Key tensor on the last two dimensions. Defaults to True.
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
- Returns:
-
Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor] – the Query, Key and Value output tensors.