ttnn.indexed_fill
ttnn.indexed_fill(batch_id: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = None, dim: int = 0) → ttnn.Tensor
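Replaces the batches of input_tensor_a selected by batch_id with the corresponding batches of input_tensor_b: batch i of input_tensor_b is written to batch batch_id[i] of the output, and all other batches are copied from input_tensor_a.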
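The replacement rule above can be modeled on the host as a small reference function. This is a minimal pure-Python sketch, not the ttnn implementation, assuming dim=0, that batch_id holds exactly one destination index per batch of input_tensor_b, and that tensors can be represented as nested lists indexed by batch. The name indexed_fill_reference is hypothetical, and how the device kernel resolves duplicate indices is not stated in this page; the sketch simply lets the later entry win.

```python
from copy import deepcopy
from typing import Any, List, Sequence


def indexed_fill_reference(
    batch_ids: Sequence[int],
    input_a: List[Any],
    input_b: List[Any],
) -> List[Any]:
    """Hypothetical host-side model of indexed_fill along dim 0 (assumed semantics).

    Batch i of input_b replaces batch batch_ids[i] of input_a; every other
    batch of input_a is copied through unchanged. Inputs are not modified.
    """
    if len(batch_ids) != len(input_b):
        raise ValueError("need exactly one batch id per batch of input_b")
    output = deepcopy(input_a)
    for i, batch in enumerate(batch_ids):
        if not 0 <= batch < len(input_a):
            raise IndexError(f"batch id {batch} is out of range for input_a")
        # Assumption: with duplicate ids, the later entry overwrites the earlier one.
        output[batch] = deepcopy(input_b[i])
    return output


# Usage: input_a has 4 batches, input_b supplies 2 replacement batches.
a = [[0, 0], [1, 1], [2, 2], [3, 3]]
b = [[9, 9], [8, 8]]
print(indexed_fill_reference([2, 0], a, b))  # [[8, 8], [1, 1], [9, 9], [3, 3]]
```

Copying input_a first and then overwriting only the indexed batches keeps the untouched batches identical to the input, which is the behavior the description implies.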
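Parameters:
batch_id (ttnn.Tensor) – uint32 tensor of batch indices into input_tensor_a, one index per batch of input_tensor_b (shape (1, 1, 1, N) in the example).
input_tensor_a (ttnn.Tensor) – the tensor whose batches are selectively replaced.
input_tensor_b (ttnn.Tensor) – the tensor that supplies the replacement batches.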
Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
dim (int, optional) – Dimension along which batches are replaced. Defaults to 0.
Returns:
ttnn.Tensor – a copy of input_tensor_a in which the batches selected by batch_id have been replaced by the batches of input_tensor_b.
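The shape relationships used in the example can be checked on the host before launching the operation. This is a sketch under assumptions inferred only from the example shapes, not constraints confirmed by the ttnn source: all tensors are 4-D, batch_id is shaped (1, 1, 1, N), input_tensor_b has N batches along dim 0, the two inputs agree on every dimension except dim 0, and the output takes the shape of input_tensor_a. The helper name check_indexed_fill_shapes is hypothetical.

```python
from typing import Sequence, Tuple


def check_indexed_fill_shapes(
    batch_id_shape: Sequence[int],
    a_shape: Sequence[int],
    b_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Hypothetical pre-flight shape check for ttnn.indexed_fill with dim=0.

    Every rule below is an assumption inferred from the documented example.
    Returns the assumed output shape (that of input_tensor_a).
    """
    # Assumption: all three tensors are 4-D.
    if len(batch_id_shape) != 4 or len(a_shape) != 4 or len(b_shape) != 4:
        raise ValueError("expected 4-D shapes")
    # Assumption: the batch ids live in the last dimension, shaped (1, 1, 1, N).
    if tuple(batch_id_shape[:3]) != (1, 1, 1):
        raise ValueError("batch_id is expected to have shape (1, 1, 1, N)")
    num_ids = batch_id_shape[3]
    # One batch id per batch of input_tensor_b.
    if b_shape[0] != num_ids:
        raise ValueError(
            f"input_tensor_b has {b_shape[0]} batches but {num_ids} ids were given"
        )
    # Whole batches are copied, so the per-batch shapes must match.
    if tuple(a_shape[1:]) != tuple(b_shape[1:]):
        raise ValueError("input tensors must agree on all dimensions except dim 0")
    # Assumption: the output has the shape of input_tensor_a.
    return tuple(a_shape)


# Shapes from the example: 6 replacement batches written into a 32-batch tensor.
print(check_indexed_fill_shapes((1, 1, 1, 6), (32, 1, 1, 4), (6, 1, 1, 4)))  # (32, 1, 1, 4)
```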
Example
import torch
import ttnn
from loguru import logger

device = ttnn.open_device(device_id=0)

# Define shapes for the input tensors
input_a_shape = (32, 1, 1, 4)
input_b_shape = (6, 1, 1, 4)

# Create the batch indices: one index into input_tensor_a per batch of input_tensor_b
batch_id = torch.randint(0, (32 - 1), (1, 1, 1, 6))
batch_id_ttnn = ttnn.from_torch(
    batch_id, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)

# Create the destination and source tensors
input_tensor_a = ttnn.rand(input_a_shape, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
input_tensor_b = ttnn.rand(input_b_shape, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Perform the indexed fill
output_tensor = ttnn.indexed_fill(batch_id_ttnn, input_tensor_a, input_tensor_b)
# loguru formats with {} placeholders, so interpolate the shape explicitly
logger.info(f"Indexed Fill Output Tensor Shape: {output_tensor.shape}")

ttnn.close_device(device)