ttnn.transformer.concatenate_heads

ttnn.transformer.concatenate_heads(input_tensor: ttnn.Tensor, *, memory_config=None) → ttnn.Tensor

Takes a tensor of shape [batch_size, num_heads, sequence_size, head_size], concatenates the heads back together along the width (last) dimension, and returns a tensor of shape [batch_size, sequence_size, num_heads * head_size].
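To make the layout change concrete, here is a minimal host-side sketch in plain Python, assuming the standard head-concatenation ordering in which out[b][s][h * head_size + d] equals x[b][h][s][d]. The helper name concatenate_heads_reference is hypothetical and is not part of ttnn; it only illustrates the shape transformation, not the on-device kernel. With PyTorch tensors, the same mapping is usually written as x.permute(0, 2, 1, 3).reshape(batch_size, sequence_size, num_heads * head_size).

```python
from typing import List

Tensor4D = List[List[List[List[float]]]]  # [batch_size][num_heads][sequence_size][head_size]
Tensor3D = List[List[List[float]]]        # [batch_size][sequence_size][num_heads * head_size]


def concatenate_heads_reference(x: Tensor4D) -> Tensor3D:
    """Hypothetical reference for the shape transform (not the ttnn API).

    Assumes out[b][s][h * head_size + d] == x[b][h][s][d].
    """
    out: Tensor3D = []
    for heads in x:  # one batch element: [num_heads][sequence_size][head_size]
        num_heads = len(heads)
        sequence_size = len(heads[0])
        rows = []
        for s in range(sequence_size):
            row: List[float] = []
            for h in range(num_heads):
                # Append head h's vector for position s, so heads sit side by side.
                row.extend(heads[h][s])
            rows.append(row)
        out.append(rows)
    return out


# batch_size=1, num_heads=2, sequence_size=2, head_size=3
x = [[
    [[0, 1, 2], [3, 4, 5]],    # head 0
    [[6, 7, 8], [9, 10, 11]],  # head 1
]]
y = concatenate_heads_reference(x)
print(y)  # [[[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]]
```

Each output row for a sequence position holds head 0's vector followed by head 1's vector, which is why the width becomes num_heads * head_size.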

Parameters:

input_tensor (ttnn.Tensor) – the input tensor.

Keyword Arguments:

memory_config – Memory configuration of the output tensor. If None, it is set to input_tensor.memory_config(). Defaults to None.

Returns:

ttnn.Tensor – the output tensor.