ttnn.transformer.chunked_scaled_dot_product_attention

ttnn.transformer.chunked_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, page_table_tensor: ttnn.Tensor, chunk_start_idx: int | None, chunk_start_idx_tensor: ttnn.Tensor | None, *, scale: float = None, memory_config: ttnn.MemoryConfig = None, program_config: SDPAProgramConfig = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None) → ttnn.Tensor

Chunked causal scaled dot product attention for paged KV cache and long sequences. Processes one Q chunk at a time; K/V are provided as paged cache. The page table maps virtual block indices to physical blocks. Two calling conventions:

Legacy (chunk_start_idx as int): Pass chunk_start_idx as an integer. The offset is fixed at dispatch time. Use this when iterating over chunks from Python and passing a new scalar on each call. The program is cached per (config, chunk_start_idx) for the first chunk; later chunks reuse it when possible.

Flexible (chunk_start_idx_tensor): Pass chunk_start_idx_tensor (ttnn.Tensor of shape [1], dtype int32) on device. The kernel reads the start index from device memory at runtime. Use for:

  • Trace capture/replay: capture one SDPA call, then replay with different chunk_start_idx by updating the tensor on device (no recompile). One program handles variable prefix lengths by updating the tensor each step.

The program is compiled once (fixed max page table size); the trace key does not include the runtime offset.
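
As a rough illustration of the legacy convention, the loop below walks a prompt in Q chunks and passes a fresh scalar offset on each call. It is a sketch under stated assumptions: `prefill_in_chunks`, `sdpa_op`, and `chunk_len` are hypothetical names, the op is injected as a callable so the loop can be exercised without a device, and passing `chunk_start_idx` by keyword is assumed from the parameter name in the signature above.

```python
from typing import Any, Callable, List, Sequence


def prefill_in_chunks(
    sdpa_op: Callable[..., Any],
    q_chunks: Sequence[Any],
    k_cache: Any,
    v_cache: Any,
    page_table: Any,
    chunk_len: int,
    **sdpa_kwargs: Any,
) -> List[Any]:
    """Call `sdpa_op` once per Q chunk with an increasing scalar offset.

    Hypothetical helper: `sdpa_op` stands in for
    ttnn.transformer.chunked_scaled_dot_product_attention, and every entry
    of `q_chunks` is assumed to hold `chunk_len` query positions.
    """
    outputs = []
    for i, q_chunk in enumerate(q_chunks):
        # Absolute sequence index of the first token in this chunk.
        chunk_start_idx = i * chunk_len
        outputs.append(
            sdpa_op(
                q_chunk,
                k_cache,
                v_cache,
                page_table,
                chunk_start_idx=chunk_start_idx,
                **sdpa_kwargs,
            )
        )
    return outputs
```

On a device, `sdpa_op` would be `ttnn.transformer.chunked_scaled_dot_product_attention` and `sdpa_kwargs` would carry `program_config` and the other keyword arguments. Choosing `chunk_len` as a common multiple of `q_chunk_size` and `k_chunk_size` keeps every generated offset aligned with the constraints listed under Parameters.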

Parameters:
  • input_tensor_q (ttnn.Tensor) – Q chunk. [b x nqh x chunk_s x dh]

  • input_tensor_k (ttnn.Tensor) – Paged K cache. [max_blocks x nkv x block_s x dh]

  • input_tensor_v (ttnn.Tensor) – Paged V cache. [max_blocks x nkv x block_s x dh]

  • page_table_tensor (ttnn.Tensor) – Page table. [b x num_pages], int32.

  • chunk_start_idx (int, optional) – Legacy: absolute sequence index for this chunk. Must be a multiple of program_config.q_chunk_size and of program_config.k_chunk_size (the latter is a workaround for https://github.com/tenstorrent/tt-metal/issues/35225). Omit when using chunk_start_idx_tensor.

  • chunk_start_idx_tensor (ttnn.Tensor, optional) – Flexible: device tensor [1] int32 holding the chunk start index; read at runtime. Use for trace or prefix caching. The stored value must be a multiple of program_config.q_chunk_size and of program_config.k_chunk_size (the latter is a workaround for https://github.com/tenstorrent/tt-metal/issues/35225).
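
To make the alignment rules and the page table concrete, here is a small host-side sketch in plain Python. The helper names (`check_chunk_start_idx`, `locate_token`) are hypothetical. The lookup assumes the conventional paged-cache layout, in which entry `v` of a page table row is the physical block holding virtual block `v` and each block stores `block_s` consecutive positions; this page states only that the table maps virtual block indices to physical blocks. The non-negativity check is likewise an assumption.

```python
from typing import Sequence, Tuple


def check_chunk_start_idx(chunk_start_idx: int, q_chunk_size: int, k_chunk_size: int) -> None:
    """Raise ValueError unless the offset satisfies both documented multiples."""
    if chunk_start_idx < 0:
        # Assumption: an absolute sequence index cannot be negative.
        raise ValueError("chunk_start_idx must be non-negative")
    if chunk_start_idx % q_chunk_size != 0:
        raise ValueError("chunk_start_idx must be a multiple of q_chunk_size")
    if chunk_start_idx % k_chunk_size != 0:
        raise ValueError("chunk_start_idx must be a multiple of k_chunk_size")


def locate_token(page_table_row: Sequence[int], position: int, block_s: int) -> Tuple[int, int]:
    """Map an absolute sequence position to (physical_block, offset_in_block).

    Assumes page_table_row[v] is the physical block for virtual block v.
    """
    virtual_block, offset = divmod(position, block_s)
    return page_table_row[virtual_block], offset
```

For example, with `block_s = 32` and a page table row `[7, 2, 5]`, position 70 falls in virtual block 2 at offset 6, so `locate_token` returns physical block 5 and offset 6.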

Keyword Arguments:
  • scale (float, optional) – Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • program_config (SDPAProgramConfig, optional) – Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Defaults to None.

Returns:

ttnn.Tensor – the output tensor [b x nqh x s x dh].
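
For intuition about what the output holds, the following pure-Python reference computes causal attention for one head and one Q chunk over a contiguous (unpaged) K/V list. It is an illustrative assumption about the semantics, not the device kernel: it presumes that query row `i` sits at absolute position `chunk_start_idx + i` and attends to key positions 0 through `chunk_start_idx + i`, and it falls back to `1/sqrt(dh)` when `scale` is None, which is the usual SDPA convention rather than something this page states. `reference_chunked_causal_sdpa` is a hypothetical name.

```python
import math
from typing import List, Optional


def reference_chunked_causal_sdpa(
    q_chunk: List[List[float]],
    k: List[List[float]],
    v: List[List[float]],
    chunk_start_idx: int,
    scale: Optional[float] = None,
) -> List[List[float]]:
    """Single-head causal attention for one Q chunk (host-side reference).

    q_chunk is [chunk_s][dh]; k and v are [s][dh] with s >= chunk_start_idx + chunk_s.
    """
    dh = len(q_chunk[0])
    if scale is None:
        # Assumed default; the usual SDPA convention.
        scale = 1.0 / math.sqrt(dh)
    out = []
    for i, q_row in enumerate(q_chunk):
        # Causal mask: row i sees keys up to its own absolute position.
        last_visible = chunk_start_idx + i
        scores = [
            scale * sum(a * b for a, b in zip(q_row, k[j]))
            for j in range(last_visible + 1)
        ]
        # Numerically stable softmax over the visible keys.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * v[j][d] for j, w in enumerate(weights)) / total for d in range(len(v[0]))]
        )
    return out
```

The result has one row per query row in the chunk, so a comparison against the device output would be made per Q chunk under these assumptions.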