ttnn.transformer.ring_distributed_scaled_dot_product_attention
- ttnn.transformer.ring_distributed_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, ring_size: uint32_t, ring_id: uint32_t | None = None, *, scale: float = None, memory_config: ttnn.MemoryConfig = None, program_config: SDPAProgramConfig = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None, page_table: ttnn.Tensor = None, chunk_start_idx: int | None = None, queue_id: int = 0) → ttnn.Tensor
Ring-distributed causal scaled dot product attention for multi-device execution. Query computation is distributed across the devices of a ring topology: each device computes only a subset of the queries, which reduces the redundant computation caused by causal masking. Under a causal mask, a query late in the sequence attends to more keys than an early one, so each device is assigned two query chunks (one early, one late) to balance the computational load across devices.
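To make the load-balancing argument concrete, the following pure-Python sketch counts the (query, key) pairs each device evaluates under a causal mask. It assumes a "zigzag" pairing in which device i takes chunks i and 2*ring_size - 1 - i out of 2*ring_size equal chunks; that pairing rule and the helper names assigned_chunks and causal_work are hypothetical illustrations, not taken from the ttnn implementation.

```python
# Hypothetical sketch of "one early, one late" query-chunk assignment.
# ASSUMPTION: device i owns chunks (i, 2*ring_size - 1 - i); this pairing
# rule is illustrative and not confirmed from ttnn's source.


def assigned_chunks(ring_size: int, ring_id: int) -> tuple[int, int]:
    """Return the (early, late) query-chunk indices for one device."""
    num_chunks = 2 * ring_size
    return ring_id, num_chunks - 1 - ring_id


def causal_work(seq_len: int, ring_size: int, ring_id: int) -> int:
    """Count (query, key) pairs one device evaluates under a causal mask.

    A query at position q attends to keys 0..q, i.e. q + 1 keys.
    """
    chunk = seq_len // (2 * ring_size)
    total = 0
    for c in assigned_chunks(ring_size, ring_id):
        start = c * chunk
        total += sum(q + 1 for q in range(start, start + chunk))
    return total


if __name__ == "__main__":
    seq_len, ring_size = 1024, 4
    for rid in range(ring_size):
        # Every device reports the same amount of work.
        print(rid, assigned_chunks(ring_size, rid), causal_work(seq_len, ring_size, rid))
```

With s = 1024 and ring_size = 4, every device evaluates 131,200 pairs, whereas a naive contiguous split would give the last device roughly seven times the work of the first.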
This operation is CAUSAL-ONLY and generates causal masks internally for each device’s non-contiguous query assignments. Custom attention masks are not supported.
Note: The output contains only this device’s assigned queries, stored contiguously. Model-level code must all-gather the per-device outputs and reorder the chunks to restore the original sequence order.
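The reshuffling step can be sketched in pure Python, using sequence positions as stand-ins for output rows. This sketch assumes the same zigzag pairing as above (device i holds chunks i and 2*ring_size - 1 - i) and that each device stores its early chunk before its late chunk; both are assumptions, and local_positions and restore_order are hypothetical helper names.

```python
# Hypothetical sketch: undo the per-device chunk layout after an all-gather.
# ASSUMPTIONS: device i holds chunks (i, 2*ring_size - 1 - i), stored
# early-then-late. Neither detail is confirmed from ttnn's source.


def local_positions(seq_len: int, ring_size: int, ring_id: int) -> list[int]:
    """Sequence positions a device holds, in its local (contiguous) order."""
    chunk = seq_len // (2 * ring_size)
    early, late = ring_id, 2 * ring_size - 1 - ring_id
    return (list(range(early * chunk, (early + 1) * chunk))
            + list(range(late * chunk, (late + 1) * chunk)))


def restore_order(gathered: list[list], ring_size: int) -> list:
    """gathered[i] is device i's local output rows (length local_s)."""
    num_chunks = 2 * ring_size
    chunk = len(gathered[0]) // 2
    ordered = [None] * num_chunks
    for ring_id, rows in enumerate(gathered):
        early, late = ring_id, num_chunks - 1 - ring_id
        ordered[early] = rows[:chunk]   # first half is the early chunk
        ordered[late] = rows[chunk:]    # second half is the late chunk
    # Concatenate chunks back into one sequence in global order.
    return [row for c in ordered for row in c]


if __name__ == "__main__":
    ring_size, seq_len = 4, 16
    gathered = [local_positions(seq_len, ring_size, r) for r in range(ring_size)]
    print(gathered)
    print(restore_order(gathered, ring_size))
```

In a real model, the same index logic would be applied along the sequence dimension of the gathered tensors rather than to Python lists.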
- Parameters:
input_tensor_q (ttnn.Tensor) – the query tensor, shape [b x nqh x s x dh]. The sequence length ‘s’ must be divisible by 2*ring_size. For proper tile alignment, ‘s’ should also be divisible by TILE_HEIGHT * 2 * ring_size (typically 256 for ring_size=4).
input_tensor_k (ttnn.Tensor) – the key tensor, shape [b x nkv x s x dh]. When a paged KV cache is used (page_table is provided), this tensor instead holds the paged KV cache blocks, with shape [max_num_blocks x nkv x block_size x dh], where block_size is the page block size.
input_tensor_v (ttnn.Tensor) – the value tensor, shape [b x nkv x s x dh]. When a paged KV cache is used (page_table is provided), this tensor instead holds the paged KV cache blocks, with shape [max_num_blocks x nkv x block_size x dh], where block_size is the page block size.
ring_size (uint32_t) – Number of devices in the ring topology.
ring_id (uint32_t, optional) – This device’s position in the ring (0 to ring_size-1). If None, it is inferred from the device coordinate. Defaults to None.
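The sequence-length constraints on input_tensor_q can be checked before calling the operation. This sketch assumes TILE_HEIGHT = 32, which matches the quoted figure of 256 for ring_size=4 (32 * 2 * 4); the function name check_ring_seq_len is hypothetical. It also returns the per-device query length local_s, since each device keeps two chunks of s / (2*ring_size) queries.

```python
# Hypothetical pre-flight check for ring-distributed SDPA sequence lengths.
TILE_HEIGHT = 32  # ASSUMPTION: standard 32x32 tiles


def check_ring_seq_len(seq_len: int, ring_size: int) -> int:
    """Validate seq_len against the documented constraints; return local_s."""
    num_chunks = 2 * ring_size
    if seq_len % num_chunks != 0:
        raise ValueError(f"s={seq_len} must be divisible by 2*ring_size={num_chunks}")
    tile_multiple = TILE_HEIGHT * num_chunks
    # The docs say "should" here; this sketch treats misalignment as an error.
    if seq_len % tile_multiple != 0:
        raise ValueError(
            f"s={seq_len} should be divisible by TILE_HEIGHT*2*ring_size={tile_multiple}"
        )
    # Each device keeps two chunks of s / (2*ring_size) queries.
    return 2 * (seq_len // num_chunks)


if __name__ == "__main__":
    print(check_ring_seq_len(4096, 4))  # local_s for a 4-device ring
```

Rejecting misaligned lengths early gives a clearer error than a failure inside the device kernels.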
- Keyword Arguments:
scale (float, optional) – Attention scaling factor. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
program_config (SDPAProgramConfig, optional) – Program configuration. Defaults to None.
compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Compute kernel configuration. Defaults to None.
page_table (ttnn.Tensor, optional) – Page table tensor for paged KV cache access [b x num_pages]. Defaults to None.
chunk_start_idx (int, optional) – Absolute position in the sequence where this chunk starts (for prefix caching). Must be a multiple of program_config.q_chunk_size. Defaults to None.
queue_id (int, optional) – Command queue ID. Defaults to 0.
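The page_table and chunk_start_idx arguments can be illustrated with a toy lookup. This sketch assumes the common paged-attention convention in which page_table[b][i] holds the physical block index of logical block i for batch entry b, so a logical KV position p lives at block page_table[b][p // block_size], offset p % block_size. That convention, and the helper names physical_slot and check_chunk_start, are assumptions rather than confirmed ttnn internals.

```python
# Hypothetical illustration of paged-KV addressing and the chunk_start_idx rule.
# ASSUMPTION: page_table[b][i] = physical block holding logical block i.


def physical_slot(page_table: list[list[int]], batch: int, pos: int,
                  block_size: int) -> tuple[int, int]:
    """Map a logical KV position to (physical block index, offset in block)."""
    logical_block = pos // block_size
    return page_table[batch][logical_block], pos % block_size


def check_chunk_start(chunk_start_idx: int, q_chunk_size: int) -> None:
    """chunk_start_idx must be a multiple of program_config.q_chunk_size."""
    if chunk_start_idx % q_chunk_size != 0:
        raise ValueError(
            f"chunk_start_idx={chunk_start_idx} is not a multiple of "
            f"q_chunk_size={q_chunk_size}"
        )


if __name__ == "__main__":
    page_table = [[5, 2, 9], [0, 1, 3]]  # toy [b x num_pages] table, b = 2
    print(physical_slot(page_table, 0, 70, 32))  # logical block 2 -> physical block 9
    check_chunk_start(256, 128)  # valid: 256 is a multiple of 128
```

The indirection through the page table is what lets the K and V tensors be laid out as [max_num_blocks x nkv x block_size x dh] rather than one contiguous sequence per batch entry.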
- Returns:
ttnn.Tensor – the output tensor containing results for this device’s assigned queries, shape [b x nqh x local_s x dh], where local_s = s / ring_size (two chunks of s / (2*ring_size) queries each).