ttnn.update_cache

ttnn.update_cache() → None

Updates the cache tensor in place with the values from input at the specified update_index. When the cache has a batch size smaller than 32, input is assumed to have its batch padded to 32, and the slice [batch_offset:batch_offset+batch] along dim[-2] of input is used to update the cache.

:param cache: The cache tensor to be written to.
:type cache: ttnn.Tensor
:param input: The token tensor to be written to the cache.
:type input: ttnn.Tensor
:param update_index: The index into the cache tensor.
:type update_index: int
:param batch_offset: The batch_offset into the cache tensor. Default = 0.
:type batch_offset: int

:keyword compute_kernel_config: Optional compute kernel configuration.
:type compute_kernel_config: Optional[DeviceComputeKernelConfig]
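The batch slicing described above can be easier to follow with a small model. The following is a minimal pure-Python sketch, not ttnn code: `update_cache_reference` is a hypothetical helper, and the layouts it uses (cache indexed as [batch][position][feature], input indexed as [padded_batch][feature], with any heads dimension omitted) are assumptions made only for illustration. It assumes update_index selects a position within each cache entry.

```python
def update_cache_reference(cache, input_rows, update_index, batch_offset=0):
    """Hypothetical pure-Python model of the update described above (not ttnn code).

    cache:      nested list indexed as [batch][position][feature] (assumed layout)
    input_rows: nested list indexed as [padded_batch][feature], padded to 32 rows
    """
    batch = len(cache)
    # Select rows [batch_offset : batch_offset + batch] of the padded input.
    selected = input_rows[batch_offset:batch_offset + batch]
    for b, row in enumerate(selected):
        # Overwrite one position per cache entry, in place.
        cache[b][update_index] = list(row)
    # Mirrors the documented signature: nothing is returned.
    return None


# Two cache entries, four positions each, three features per position.
cache = [[[0.0] * 3 for _ in range(4)] for _ in range(2)]
# Input padded to 32 rows; row i holds the value i in every feature.
input_rows = [[float(i)] * 3 for i in range(32)]

update_cache_reference(cache, input_rows, update_index=1, batch_offset=4)
print(cache[0][1])  # [4.0, 4.0, 4.0]
print(cache[1][1])  # [5.0, 5.0, 5.0]
```

In this model only position 1 of each cache entry changes, and the values come from input rows 4 and 5; the device operation may differ in layout and constraints.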

Example

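>>> # Assumes `device` is an already opened ttnn device.
>>> update_index = 0  # illustrative index into the cache
>>> tensor1 = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)  # cache
>>> tensor2 = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)  # input
>>> ttnn.update_cache(tensor1, tensor2, update_index)  # returns None; tensor1 is updated in place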