ttnn.transformer.joint_scaled_dot_product_attention
- ttnn.transformer.joint_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, joint_tensor_q: ttnn.Tensor, joint_tensor_k: ttnn.Tensor, joint_tensor_v: ttnn.Tensor, *, joint_strategy: str, program_config: ttnn.SDPAProgramConfig = None, scale: float = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None)
-
Joint attention operation that efficiently performs non-causal attention over two sets of query, key, and value tensors. Internally, the joint tensors are concatenated onto the original tensors along the sequence dimension (joint_strategy = "rear"), and attention is computed once over the combined sequence. The output is then split back into two parts: one for the original Q/K/V chunk and one for the joint Q/K/V chunk.
This op handles optional padding via an attention mask, which omits padded tokens from both the "original" and "joint" sequences.
- Parameters:
-
input_tensor_q (ttnn.Tensor) – Original queries [b x nh x N x dh].
input_tensor_k (ttnn.Tensor) – Original keys [b x nh x N x dh].
input_tensor_v (ttnn.Tensor) – Original values [b x nh x N x dh].
joint_tensor_q (ttnn.Tensor) – Joint queries [b x nh x L x dh].
joint_tensor_k (ttnn.Tensor) – Joint keys [b x nh x L x dh].
joint_tensor_v (ttnn.Tensor) – Joint values [b x nh x L x dh].
- Keyword Arguments:
-
joint_strategy (str) – Strategy for joint attention. Must be "rear".
program_config (ttnn.SDPAProgramConfig, optional) – Program configuration for the operation. Defaults to None.
scale (float, optional) – Scale factor for QK^T. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Defaults to None.
- Returns:
-
(ttnn.Tensor, ttnn.Tensor) –
- The attention output for the original Q/K/V, shape [b x nh x N x dh].
- The attention output for the joint Q/K/V, shape [b x nh x L x dh].
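The semantics above (concatenate along the sequence dimension with the "rear" strategy, run one non-causal attention pass, then split the output) can be sketched as a host-side NumPy reference. This is a minimal illustration of the math, not the ttnn implementation; the function name joint_sdpa_reference and the softmax helper are assumptions introduced here.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def joint_sdpa_reference(q, k, v, jq, jk, jv, scale=None):
    """Reference for joint SDPA with joint_strategy="rear".

    q/k/v:    [b, nh, N, dh] original tensors
    jq/jk/jv: [b, nh, L, dh] joint tensors
    Returns a pair of outputs with shapes [b, nh, N, dh] and [b, nh, L, dh].
    """
    N = q.shape[2]
    # "rear": joint tensors are appended after the original sequence (dim 2).
    q_cat = np.concatenate([q, jq], axis=2)  # [b, nh, N+L, dh]
    k_cat = np.concatenate([k, jk], axis=2)
    v_cat = np.concatenate([v, jv], axis=2)
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])  # default 1/sqrt(dh)
    # Non-causal attention over the combined sequence: softmax(QK^T * scale) V.
    scores = np.einsum("bhqd,bhkd->bhqk", q_cat, k_cat) * scale
    out = np.einsum("bhqk,bhkd->bhqd", softmax(scores), v_cat)
    # Split the output back into the original and joint parts.
    return out[:, :, :N, :], out[:, :, N:, :]
```

A hypothetical usage with random tensors, mirroring the [b x nh x N x dh] / [b x nh x L x dh] layouts from the parameter list:

```python
rng = np.random.default_rng(0)
b, nh, N, L, dh = 1, 2, 16, 4, 8
q, k, v = (rng.standard_normal((b, nh, N, dh)) for _ in range(3))
jq, jk, jv = (rng.standard_normal((b, nh, L, dh)) for _ in range(3))
out, joint_out = joint_sdpa_reference(q, k, v, jq, jk, jv)
# out.shape == (b, nh, N, dh); joint_out.shape == (b, nh, L, dh)
```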