ttnn.transformer.chunked_flash_mla_prefill

ttnn.transformer.chunked_flash_mla_prefill(input_tensor_q: ttnn.Tensor, input_tensor_k: ttnn.Tensor, page_table_tensor: ttnn.Tensor, chunk_start_idx: int, head_dim_v: int, *, scale: float = None, memory_config: ttnn.MemoryConfig = None, program_config: SDPAProgramConfig = None, compute_kernel_config: ttnn.DeviceComputeKernelConfig = None) → ttnn.Tensor

Chunked causal scaled dot product attention for processing long sequences in chunks. This variant allows processing of sequences longer than the maximum supported length by splitting the input into chunks and maintaining KV cache state across calls. The KV cache is page-based: the page table tensor maps each batch's logical page indices to physical blocks in the KV cache.
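To make the page-table indirection concrete, here is a minimal NumPy sketch of how a paged KV cache is gathered back into a contiguous sequence. This is an illustration of the general paged-attention layout, not the device kernel; the names `gather_paged_kv`, `page_size`, and the single-batch, single-head layout are assumptions for the example.

```python
import numpy as np

def gather_paged_kv(kv_cache, page_table, seq_len, page_size):
    """Gather a contiguous [seq_len, dh] KV slice from a paged cache.

    kv_cache:   [num_blocks, page_size, dh] physical storage
    page_table: [num_pages] logical-page -> physical-block ids
    """
    num_pages = -(-seq_len // page_size)  # ceil division
    # Select the physical blocks for the pages this sequence occupies,
    # then flatten and trim any padding in the final page.
    pages = kv_cache[page_table[:num_pages]]         # [num_pages, page_size, dh]
    return pages.reshape(-1, kv_cache.shape[-1])[:seq_len]
```

The attention kernel itself never materializes this contiguous view; it reads blocks through the page table on the fly, which is what lets the cache grow without reallocation.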

Parameters:
  • input_tensor_q (ttnn.Tensor) – the query tensor. [b x nqh x s x dh]

  • input_tensor_k (ttnn.Tensor) – the key tensor; in MLA, the value portion is read from the first head_dim_v elements of each K head. [b x nkv x s x dh]

  • page_table_tensor (ttnn.Tensor) – the page table tensor mapping logical page indices to physical KV cache blocks, per batch. [b x num_pages]

  • chunk_start_idx (int) – Absolute position in the sequence where this chunk starts. Must be a multiple of program_config.q_chunk_size.

  • head_dim_v (int) – the head dimension of V.

Keyword Arguments:
  • scale (float, optional) – Scale factor applied to the QK^T logits. Defaults to None, in which case 1/sqrt(dh) is used.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • program_config (SDPAProgramConfig, optional) – Program configuration for the operation (e.g. q_chunk_size, k_chunk_size). Defaults to None.

  • compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional) – Compute kernel configuration for the operation. Defaults to None.

Returns:

ttnn.Tensor – the output tensor [b x nqh x s x head_dim_v].
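The chunking semantics can be sketched in plain NumPy: a query chunk starting at absolute position chunk_start_idx attends causally to all cached keys up to and including its own absolute position. This is a single-head reference model of the masking behavior only, not the device implementation; the function name and layout are assumptions for the example.

```python
import numpy as np

def chunked_causal_attention(q_chunk, k, v, chunk_start_idx, scale=None):
    """Reference chunked causal attention.

    q_chunk: [s_chunk, dh] queries for one chunk
    k:       [s_kv, dh]    all keys cached so far
    v:       [s_kv, dhv]   all values cached so far
    chunk_start_idx: absolute position of q_chunk[0] in the full sequence
    """
    s_chunk, dh = q_chunk.shape
    s_kv = k.shape[0]
    if scale is None:
        scale = 1.0 / np.sqrt(dh)
    logits = scale * (q_chunk @ k.T)                       # [s_chunk, s_kv]
    # Causal mask in absolute positions: query i may see keys j <= i.
    q_pos = chunk_start_idx + np.arange(s_chunk)[:, None]
    k_pos = np.arange(s_kv)[None, :]
    logits = np.where(k_pos <= q_pos, logits, -np.inf)
    # Numerically stable softmax over the key axis.
    logits -= logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v                                           # [s_chunk, dhv]
```

Processing a long sequence chunk by chunk with this mask reproduces full causal attention exactly, which is why chunk_start_idx must track the absolute position (and, on device, be a multiple of program_config.q_chunk_size).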