ttnn.transformer.joint_scaled_dot_product_attention

ttnn.transformer.joint_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, joint_tensor_q: ttnn.Tensor, joint_tensor_k: ttnn.Tensor, joint_tensor_v: ttnn.Tensor, *, joint_strategy: str, program_config: ttnn.SDPAProgramConfig = None, scale: float = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None)

Joint attention operation that efficiently performs non-causal attention over two sets of query, key, and value tensors. With joint_strategy = "rear", the joint tensors are concatenated after the original tensors along the sequence dimension, attention is computed once over the combined sequence, and the output is split back into two parts: one for the original Q/K/V chunk and one for the joint Q/K/V chunk.

This op supports optional padding via an attention mask, which excludes padded tokens in both the "original" and "joint" sequences from attention.

Parameters:
  • input_tensor_q (ttnn.Tensor) – Original queries [b x nh x N x dh].

  • input_tensor_k (ttnn.Tensor) – Original keys [b x nh x N x dh].

  • input_tensor_v (ttnn.Tensor) – Original values [b x nh x N x dh].

  • joint_tensor_q (ttnn.Tensor) – Joint queries [b x nh x L x dh].

  • joint_tensor_k (ttnn.Tensor) – Joint keys [b x nh x L x dh].

  • joint_tensor_v (ttnn.Tensor) – Joint values [b x nh x L x dh].

Keyword Arguments:
  • joint_strategy (str) – Strategy for joint attention. Must be "rear".

  • program_config (ttnn.SDPAProgramConfig, optional) – Program configuration for the operation. Defaults to None.

  • scale (float, optional) – Scale factor applied to QK^T. Defaults to None, in which case 1/sqrt(dh) is used.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation. Defaults to None.

Returns:

(ttnn.Tensor, ttnn.Tensor) –
  • The attention output for the original Q/K/V, shape [b x nh x N x dh].
  • The attention output for the joint Q/K/V, shape [b x nh x L x dh].
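The "rear" concatenate-attend-split semantics can be illustrated with a plain NumPy reference. This is an illustrative sketch of the math, not the device implementation, and the default scale of 1/sqrt(dh) when scale is None is an assumption:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def joint_sdpa_reference(q, k, v, jq, jk, jv, scale=None):
    """Reference semantics of joint SDPA with joint_strategy="rear".

    q, k, v:    [b, nh, N, dh]  original tensors
    jq, jk, jv: [b, nh, L, dh]  joint tensors
    Returns a pair of outputs shaped [b, nh, N, dh] and [b, nh, L, dh].
    """
    N = q.shape[2]  # original sequence length
    # "rear": joint tensors go after the originals in the sequence dim.
    q_cat = np.concatenate([q, jq], axis=2)  # [b, nh, N+L, dh]
    k_cat = np.concatenate([k, jk], axis=2)
    v_cat = np.concatenate([v, jv], axis=2)
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])  # assumed default: 1/sqrt(dh)
    # One non-causal attention pass over the combined sequence.
    attn = softmax(q_cat @ k_cat.transpose(0, 1, 3, 2) * scale)
    out = attn @ v_cat  # [b, nh, N+L, dh]
    # Split the output back into the original and joint parts.
    return out[:, :, :N, :], out[:, :, N:, :]
```

Note that every query (original or joint) attends to all N+L keys; the split only partitions the output rows, not the attention span.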