ttnn.transformer.chunked_scaled_dot_product_attention

ttnn.transformer.chunked_scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, page_table_tensor: ttnn.Tensor, chunk_start_idx: int, *, scale: float = None, memory_config: ttnn.MemoryConfig = None, program_config: SDPAProgramConfig = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None) → ttnn.Tensor

Chunked causal scaled dot product attention for processing long sequences in chunks. This variant allows processing of sequences longer than the maximum supported length by splitting the input into chunks and maintaining KV cache state between calls. The KV cache is page-based, and the page table tensor maps logical page indices to their physical locations in the KV cache.

Parameters:
  • input_tensor_q (ttnn.Tensor) – the query tensor. [b x nqh x s x dh]

  • input_tensor_k (ttnn.Tensor) – the key tensor. [b x nkv x s x dh]

  • input_tensor_v (ttnn.Tensor) – the value tensor. [b x nkv x s x dh]

  • page_table_tensor (ttnn.Tensor) – the page table tensor. [b x num_pages]

  • chunk_start_idx (int) – Absolute position in the sequence where this chunk starts. Must be a multiple of both program_config.q_chunk_size and program_config.k_chunk_size (the latter is a workaround for https://github.com/tenstorrent/tt-metal/issues/35225)

Keyword Arguments:
  • scale (float, optional) – Scale factor applied to the query-key scores. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • program_config (SDPAProgramConfig, optional) – Program configuration for the operation, including q_chunk_size and k_chunk_size. Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation. Defaults to None.

Returns:

ttnn.Tensor – the output tensor [b x nqh x s x dh].
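The semantics can be sketched in pure NumPy: a query chunk starting at chunk_start_idx attends causally to all keys up to its own absolute position, with K/V gathered from a paged cache through the page table. The cache layout ([num_blocks, nkv, block_size, dh]), the softmax scale default of 1/sqrt(dh), and the GQA head broadcasting below are assumptions for illustration, not the exact device implementation.

```python
import numpy as np

def chunked_causal_sdpa_reference(q, k_cache, v_cache, page_table,
                                  chunk_start_idx, scale=None):
    """Pure-NumPy reference for chunked causal SDPA with a paged KV cache.

    Assumed (hypothetical) layouts:
      q:                [b, nqh, s_chunk, dh]            query chunk
      k_cache, v_cache: [num_blocks, nkv, block_size, dh] physical paged cache
      page_table:       [b, num_pages] int               logical page -> block
    """
    b, nqh, s_chunk, dh = q.shape
    _, nkv, block_size, _ = k_cache.shape
    if scale is None:
        scale = 1.0 / np.sqrt(dh)  # assumed default; check the op's docs

    seq_len = chunk_start_idx + s_chunk  # keys visible to this chunk
    out = np.empty_like(q)
    for bi in range(b):
        # Gather this batch's logical K/V via the page table, then
        # flatten pages into one contiguous [nkv, seq_len, dh] view.
        n_pages = -(-seq_len // block_size)  # ceil division
        blocks = page_table[bi, :n_pages]
        k = k_cache[blocks].transpose(1, 0, 2, 3).reshape(nkv, -1, dh)[:, :seq_len]
        v = v_cache[blocks].transpose(1, 0, 2, 3).reshape(nkv, -1, dh)[:, :seq_len]
        # Broadcast KV heads across query heads (GQA-style).
        rep = nqh // nkv
        k = np.repeat(k, rep, axis=0)
        v = np.repeat(v, rep, axis=0)
        scores = (q[bi] @ k.transpose(0, 2, 1)) * scale  # [nqh, s_chunk, seq_len]
        # Causal mask offset by the chunk's absolute start position:
        # local query i attends to absolute key positions <= chunk_start_idx + i.
        qpos = chunk_start_idx + np.arange(s_chunk)[:, None]
        kpos = np.arange(seq_len)[None, :]
        scores = np.where(kpos <= qpos, scores, -np.inf)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        out[bi] = w @ v
    return out
```

A key property this sketch exhibits: processing a later chunk with the correct chunk_start_idx reproduces exactly the corresponding slice of full-sequence causal attention, which is what allows long sequences to be split across calls.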