ttnn.transformer.concatenate_heads
- ttnn.transformer.concatenate_heads = FastOperation(python_fully_qualified_name='ttnn.transformer.concatenate_heads', function=<ttnn._ttnn.operations.transformer.concatenate_heads_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)
Takes a tensor of shape [batch_size, num_heads, sequence_size, head_size], concatenates the heads back along the width dimension, and returns a tensor of shape [batch_size, sequence_size, num_heads * head_size].
Args:
    input_tensor (ttnn.Tensor): the input tensor.

Keyword Args:
    memory_config: memory configuration of the output tensor. If None, the output uses input_tensor.memory_config(). Defaults to None.

Returns:
    ttnn.Tensor: the output tensor.
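
The shape transformation described above can be illustrated on the host with plain NumPy. This is only a sketch of the semantics (move the head axis next to head_size, then merge the two), not the device implementation. The helper name `concatenate_heads_reference` is hypothetical and is not part of ttnn.

```python
import numpy as np


def concatenate_heads_reference(x: np.ndarray) -> np.ndarray:
    """Host-side sketch of the head-concatenation semantics (hypothetical helper).

    Input:  [batch_size, num_heads, sequence_size, head_size]
    Output: [batch_size, sequence_size, num_heads * head_size]
    """
    batch_size, num_heads, sequence_size, head_size = x.shape
    # Bring the head axis next to head_size: [batch, seq, heads, head_size]
    x = np.transpose(x, (0, 2, 1, 3))
    # Merge the last two axes so each head occupies a contiguous slice of the width
    return x.reshape(batch_size, sequence_size, num_heads * head_size)


# Example: batch_size=2, num_heads=4, sequence_size=3, head_size=8
x = np.arange(2 * 4 * 3 * 8, dtype=np.float32).reshape(2, 4, 3, 8)
out = concatenate_heads_reference(x)
print(out.shape)  # (2, 3, 32)
```

In this sketch, head h at a given sequence position ends up in columns h * head_size through (h + 1) * head_size - 1 of the output row for that position.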