ttnn.indexed_fill

ttnn.indexed_fill(batch_id: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = None, dim: int = 0) → ttnn.Tensor

Replaces the batches of input_tensor_a selected by batch_id with the corresponding batches of input_tensor_b.
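The description above amounts to a per-batch selection along dimension 0. The pure-Python sketch below models it under an assumption the page does not spell out: output batch `i` is `input_tensor_b[j]` when `batch_id[j] == i`, and `input_tensor_a[i]` otherwise. The name `indexed_fill_reference` is hypothetical, and batches are plain Python values rather than `ttnn.Tensor` objects, so this is a host-side illustration only, not the device implementation.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def indexed_fill_reference(batch_id: Sequence[int], input_a: Sequence[T], input_b: Sequence[T]) -> list[T]:
    """Hypothetical host-side model of indexed_fill along dim 0.

    Assumption: batch_id[j] names the batch of input_a that input_b[j] replaces.
    """
    if len(batch_id) != len(input_b):
        raise ValueError("input_b must supply exactly one batch per entry in batch_id")
    if len(set(batch_id)) != len(batch_id):
        # The page does not say which entry wins on duplicate ids, so this sketch rejects them.
        raise ValueError("batch_id entries must be unique in this sketch")
    # Map each target batch index of input_a to the row of input_b that replaces it.
    replacement: dict[int, int] = {}
    for j, idx in enumerate(batch_id):
        if not 0 <= idx < len(input_a):
            raise IndexError(f"batch id {idx} is out of range for {len(input_a)} batches")
        replacement[idx] = j
    # Batches not named in batch_id are copied through from input_a unchanged.
    return [input_b[replacement[i]] if i in replacement else input_a[i] for i in range(len(input_a))]


a = ["a0", "a1", "a2", "a3"]
b = ["b0", "b1"]
print(indexed_fill_reference([3, 1], a, b))  # ['a0', 'b1', 'a2', 'b0']
```

Rejecting duplicate ids is a deliberate choice in the sketch: it avoids guessing a tie-breaking rule that the documentation does not define.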

Parameters:
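  • batch_id (ttnn.Tensor) – uint32 tensor holding the indices of the batches of input_tensor_a to replace. In the example below it has shape (1, 1, 1, N).

  • input_tensor_a (ttnn.Tensor) – the tensor whose batches are replaced. Batches not listed in batch_id are kept.

  • input_tensor_b (ttnn.Tensor) – the tensor supplying the replacement batches, one for each entry of batch_id.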

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

  • dim (int, optional) – the dimension along which batches are indexed. Defaults to 0.

Returns:

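ttnn.Tensor – the output tensor, with the shape of input_tensor_a, in which each batch selected by batch_id holds the corresponding batch of input_tensor_b.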
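The example at the end of this page pairs a batch_id of shape (1, 1, 1, 6) with an input_tensor_b of 6 batches and an input_tensor_a of 32 batches whose remaining dimensions (1, 1, 4) match. The sketch below encodes those two relationships as a host-side pre-check. The rules are inferred from that example rather than taken from the ttnn validation code, and `check_indexed_fill_shapes` is a hypothetical helper.

```python
from math import prod
from typing import Sequence


def check_indexed_fill_shapes(batch_id_shape: Sequence[int], a_shape: Sequence[int], b_shape: Sequence[int]) -> None:
    """Hypothetical pre-check; both rules are inferred from the documented example."""
    num_ids = prod(batch_id_shape)  # e.g. (1, 1, 1, 6) holds 6 batch ids
    # Assumed rule 1: input_tensor_b supplies one batch per batch id.
    if b_shape[0] != num_ids:
        raise ValueError(f"input_tensor_b has {b_shape[0]} batches but batch_id holds {num_ids} ids")
    # Assumed rule 2: every non-batch dimension of the two inputs matches.
    if tuple(a_shape[1:]) != tuple(b_shape[1:]):
        raise ValueError(f"non-batch dims differ: {tuple(a_shape[1:])} vs {tuple(b_shape[1:])}")


check_indexed_fill_shapes((1, 1, 1, 6), (32, 1, 1, 4), (6, 1, 1, 4))  # passes silently
```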

Example

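import torch
import ttnn
from loguru import logger

# Open the first device
device = ttnn.open_device(device_id=0)

# Define shapes for input tensors
input_a_shape = (32, 1, 1, 4)
input_b_shape = (6, 1, 1, 4)

# Create the tensors for the indexed fill: 6 batch ids in [0, 32), one per batch of input_tensor_b
# (torch.randint's upper bound is exclusive, so every batch of input_tensor_a can be selected)
batch_id = torch.randint(0, input_a_shape[0], (1, 1, 1, input_b_shape[0]))
batch_id_ttnn = ttnn.Tensor(batch_id, ttnn.uint32).to(
    device, ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)
)
input_tensor_a = ttnn.rand(input_a_shape, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
input_tensor_b = ttnn.rand(input_b_shape, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Perform indexed fill
output_tensor = ttnn.indexed_fill(batch_id_ttnn, input_tensor_a, input_tensor_b)
# loguru formats messages with str.format, so interpolate the shape into the string itself
logger.info(f"Indexed Fill Output Tensor Shape: {output_tensor.shape}")

ttnn.close_device(device)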