ttnn.transformer.concatenate_heads
- ttnn.transformer.concatenate_heads(input_tensor: ttnn.Tensor, *, memory_config=None) → ttnn.Tensor
Takes in a tensor of shape [batch_size, num_heads, sequence_size, head_size], concatenates the heads back along the width dimension, and returns a tensor of shape [batch_size, sequence_size, num_heads * head_size].
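The shape change described above amounts to moving the heads axis next to head_size and then merging those two axes. The NumPy sketch below illustrates only these host-side shape semantics. It is not ttnn's implementation, `concatenate_heads_reference` is a hypothetical helper name, and it assumes the heads are concatenated in order, so head i occupies width columns [i * head_size, (i + 1) * head_size).

```python
import numpy as np


def concatenate_heads_reference(x: np.ndarray) -> np.ndarray:
    """Hypothetical host reference: [b, h, s, d] -> [b, s, h * d].

    Assumes head i ends up in width columns [i * d, (i + 1) * d).
    """
    if x.ndim != 4:
        raise ValueError(f"expected a 4-D array, got shape {x.shape}")
    batch_size, num_heads, sequence_size, head_size = x.shape
    # Move the heads axis next to head_size: [b, h, s, d] -> [b, s, h, d]
    x = np.transpose(x, (0, 2, 1, 3))
    # Merge (heads, head_size) into a single width dimension of size h * d
    return x.reshape(batch_size, sequence_size, num_heads * head_size)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # b=2, h=3, s=4, d=5
y = concatenate_heads_reference(x)
print(y.shape)  # (2, 4, 15)
```

Under this assumption, element `x[b, h, s, d]` lands at `y[b, s, h * head_size + d]`. A reference like this can be compared against the device result after converting it back to the host.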
- Parameters:
input_tensor (ttnn.Tensor) – the input tensor, of shape [batch_size, num_heads, sequence_size, head_size].
- Keyword Arguments:
memory_config – memory configuration of the output tensor. If None, the output uses input_tensor.memory_config(). Defaults to None.
- Returns:
ttnn.Tensor – the output tensor, of shape [batch_size, sequence_size, num_heads * head_size].