ttnn.transformer.ring_joint_scaled_dot_product_attention
- ttnn.transformer.ring_joint_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, joint_tensor_q: ttnn.Tensor, joint_tensor_k: ttnn.Tensor, joint_tensor_v: ttnn.Tensor, *, persistent_output_buffer_k: ttnn.Tensor, persistent_output_buffer_v: ttnn.Tensor, joint_strategy: str, logical_n: int, program_config: ttnn.SDPAProgramConfig, scale: float = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None, dim: int, multi_device_global_semaphore: List[ttnn.GlobalSemaphore], num_links: int, cluster_axis: int, mesh_device: ttnn.MeshDevice, topology: ttnn.ccl.Topology, subdevice_id: tt.tt_metal.SubDeviceId | None = None, ccl_core_grid_offset: ttnn.CoreCoord)
-
RingJointAttention operation that efficiently performs non-causal attention over two sets of query, key, and value tensors, where the first set is sharded across devices along the sequence dimension. Internally, the joint tensors are concatenated to the originals along the sequence dimension (joint_strategy = "rear"), and attention is computed once over the combined sequence. The output is then sliced back into two parts: one for the original Q/K/V chunk and one for the joint Q/K/V chunk.
This op handles optional padding via an attention mask, omitting padded tokens from both the original and the joint sequences.
Because the sharded sequence length N must be divisible by the number of devices, the unpadded sequence length must be supplied separately as logical_n.
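As a rough illustration of the divisibility requirement, the helper below rounds a logical sequence length up to a padded N that splits evenly across devices. The function name `padded_seq_len` and the 32-row tile alignment are illustrative assumptions, not part of the ttnn API:

```python
import math

def padded_seq_len(logical_n: int, num_devices: int, tile_size: int = 32) -> int:
    """Round logical_n up so each device's shard is tile-aligned.

    Assumption: each per-device shard is padded to a multiple of
    tile_size rows (32 is typical for tiled layouts, but this is an
    illustration, not a documented constraint of this op).
    """
    per_device = math.ceil(logical_n / (num_devices * tile_size)) * tile_size
    return per_device * num_devices
```

Tokens between logical_n and the padded N are the ones the attention mask excludes.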
- Parameters:
-
input_tensor_q (ttnn.Tensor) – Original queries [b x nh x N/num_devices x dh].
input_tensor_k (ttnn.Tensor) – Original keys [b x nh x N/num_devices x dh].
input_tensor_v (ttnn.Tensor) – Original values [b x nh x N/num_devices x dh].
joint_tensor_q (ttnn.Tensor) – Joint queries [b x nh x L x dh].
joint_tensor_k (ttnn.Tensor) – Joint keys [b x nh x L x dh].
joint_tensor_v (ttnn.Tensor) – Joint values [b x nh x L x dh].
- Keyword Arguments:
-
persistent_output_buffer_k (ttnn.Tensor) – Persistent buffer for gathered K tensor.
persistent_output_buffer_v (ttnn.Tensor) – Persistent buffer for gathered V tensor.
joint_strategy (str) – Strategy for joint attention. Must be “rear”.
logical_n (int) – The logical sequence length N before sharding across devices.
program_config (ttnn.SDPAProgramConfig) – Program configuration for the operation.
scale (float, optional) – Scale factor for QK^T. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Defaults to None.
dim (int) – Dimension along which to perform the ring all-gather operation.
multi_device_global_semaphore (List[ttnn.GlobalSemaphore]) – Global semaphores for multi-device synchronization.
num_links (int) – Number of communication links to use for ring all-gather.
cluster_axis (int) – Axis of the mesh device along which to perform the all-gather.
mesh_device (ttnn.MeshDevice) – Multi-device mesh for distributed computation.
topology (ttnn.ccl.Topology) – Communication topology (Ring or Linear).
subdevice_id (Optional[tt.tt_metal.SubDeviceId]) – Sub-device identifier. Defaults to None.
ccl_core_grid_offset (ttnn.CoreCoord) – Core grid offset for CCL operations.
- Returns:
-
(ttnn.Tensor, ttnn.Tensor, ttnn.Tensor) –
- The attention output for the original Q/K/V, shape [b x nh x N/num_devices x dh].
- The attention output for the joint Q/K/V, shape [b x nh x L x dh].
- The final log-sum-exp of the operation, shape [b x nh x (N/num_devices + L) x 1].
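For reference, the concat-attend-split semantics described above can be sketched on a single device with NumPy. This is a mathematical sketch of the "rear" joint strategy (no ring sharding, no masking, no device execution), and the function name is hypothetical:

```python
import numpy as np

def reference_joint_attention(q, k, v, jq, jk, jv, scale=None):
    """Single-device sketch: rear-concatenate joint Q/K/V, attend once, split.

    q, k, v:    [b, nh, N, dh]   original tensors
    jq, jk, jv: [b, nh, L, dh]   joint tensors
    Returns (original output, joint output, log-sum-exp).
    """
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    # joint_strategy = "rear": joint tokens are appended after the originals
    Q = np.concatenate([q, jq], axis=2)
    K = np.concatenate([k, jk], axis=2)
    V = np.concatenate([v, jv], axis=2)
    scores = scale * (Q @ K.transpose(0, 1, 3, 2))        # [b, nh, N+L, N+L]
    # Numerically stable log-sum-exp over the key dimension
    m = scores.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(scores - m).sum(axis=-1, keepdims=True))
    out = np.exp(scores - lse) @ V                        # softmax(scores) @ V
    n = q.shape[2]
    # Slice the combined output back into original and joint parts
    return out[:, :, :n], out[:, :, n:], lse              # lse: [b, nh, N+L, 1]
```

The output shapes mirror the return values above, with N here playing the role of the per-device N/num_devices.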