ttnn.all_gather
- ttnn.all_gather(input_tensor: ttnn.Tensor, dim: int, *, cluster_axis: int = None, subdevice_id: ttnn.SubDeviceId | None = None, memory_config: ttnn.MemoryConfig | None = None, output_tensor: ttnn.Tensor | None = None, num_links: int = None, topology: ttnn.Topology = None) → ttnn.Tensor
All-gather operation across devices along a selected dimension and an optional cluster axis. All-gather is a collective operation that gathers data from all participating devices into a new output tensor, concatenated along the specified dim. If cluster_axis is specified, the gather runs across the devices along that mesh axis, so every device along the cluster axis ends up with an identical tensor shard; each group of devices along the other mesh axis performs its own independent all-gather. If cluster_axis is not specified, the gather runs across all devices in the mesh. When the layout is row-major, or the gather dimension has tile padding, a composite all-gather implementation is used that falls back to all-broadcast.
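To make the semantics concrete, here is a minimal host-side sketch using plain PyTorch rather than ttnn; the helper reference_all_gather and the shard shapes are illustrative assumptions, not part of the ttnn API.
>>> import torch
>>> def reference_all_gather(shards, dim):
...     # every device contributes its shard and receives the same concatenation
...     return torch.cat(shards, dim=dim)
>>> shards = [torch.randn(1, 1, 32, 64) for _ in range(4)]  # one shard per device
>>> reference_all_gather(shards, dim=-1).shape
torch.Size([1, 1, 32, 256])
>>> # With cluster_axis on a 2D mesh, each row (or column) of devices runs its
>>> # own independent all-gather:
>>> mesh = [[torch.randn(1, 1, 32, 64) for _ in range(4)] for _ in range(2)]  # (2, 4) mesh
>>> [reference_all_gather(row, dim=-1).shape for row in mesh]
[torch.Size([1, 1, 32, 256]), torch.Size([1, 1, 32, 256])]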
- Parameters:
input_tensor (ttnn.Tensor) – Input tensor to be gathered.
dim (int) – Dimension along which to gather.
- Keyword Arguments:
cluster_axis (int, optional) – The axis on the mesh device to gather across. Defaults to None.
subdevice_id (ttnn.SubDeviceId, optional) – Subdevice id for worker cores.
memory_config (ttnn.MemoryConfig, optional) – Output memory configuration.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor.
num_links (int, optional) – The number of links to use for the all-gather operation. Defaults to None, in which case the number of links is determined automatically.
topology (ttnn.Topology, optional) – Fabric topology. Defaults to None.
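For illustration, a call that sets these keyword arguments explicitly might look like the sketch below. It assumes a mesh-sharded tensor ttnn_tensor already exists (as in the Example further down); the chosen values are arbitrary illustrative assumptions, not defaults.
>>> output = ttnn.all_gather(
...     ttnn_tensor,
...     dim=3,
...     memory_config=ttnn.DRAM_MEMORY_CONFIG,  # place the gathered tensor in DRAM
...     num_links=1,                            # use a single fabric link
...     topology=ttnn.Topology.Ring)            # ring topology across the mesh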
- Returns:
ttnn.Tensor – The gathered tensor. Its shape matches input_shape on every dimension except dim, where output_shape[dim] = input_shape[dim] * num_devices; num_devices is the number of devices along cluster_axis if specified, otherwise the total number of devices in the mesh.
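As a quick illustration of the shape rule (plain Python, no devices involved; the numbers are made up):
>>> input_shape, dim, num_devices = [1, 1, 32, 256], 3, 8
>>> output_shape = list(input_shape)
>>> output_shape[dim] = input_shape[dim] * num_devices
>>> output_shape
[1, 1, 32, 2048]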
Example
>>> import torch
>>> import ttnn
>>> full_tensor = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8))
>>> ttnn_tensor = ttnn.from_torch(
...     full_tensor,
...     dtype=ttnn.bfloat16,          # example dtype/layout/memory_config values
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     memory_config=ttnn.DRAM_MEMORY_CONFIG,
...     mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2)))
>>> output = ttnn.all_gather(ttnn_tensor, dim=0)
>>> print(output.shape)
[8, 1, 32, 256]
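For a 2D mesh, a hedged sketch of gathering along a single cluster axis might look like the following; the (2, 4) mesh shape, tensor shape, and sharding choices are illustrative assumptions, and the call follows the keyword-only signature documented above.
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
>>> full_tensor = torch.randn([1, 1, 128, 128], dtype=torch.bfloat16)
>>> ttnn_tensor = ttnn.from_torch(
...     full_tensor,
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=(2, 4), dims=(-2, -1)))
>>> # Each row of 4 devices runs its own independent all-gather along dim 3.
>>> output = ttnn.all_gather(ttnn_tensor, dim=3, cluster_axis=1)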