ttnn.all_reduce

ttnn.all_reduce(input_tensor: ttnn.Tensor, *, cluster_axis: int = None, subdevice_id: ttnn.SubDeviceId | None, memory_config: ttnn.MemoryConfig | None, num_links: int = None, topology: ttnn.Topology = None) -> ttnn.Tensor

All-reduce is a collective operation that sums the data held by all participating devices and returns the result to every one of them. If cluster_axis is specified, the reduction is performed separately within each group of devices along that mesh axis, so every device along the axis ends up with an identical tensor shard. If cluster_axis is not specified, the reduction spans all devices in the mesh.
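The Sum semantics described above can be modeled on the host with plain Python. This is a minimal reference sketch, not ttnn code: the function all_reduce_reference and the nested-list mesh representation are hypothetical, and it assumes that cluster_axis indexes the mesh shape (0 groups devices that share a column, 1 groups devices that share a row).

```python
from functools import reduce


def _elementwise_sum(vectors):
    """Sum a list of equal-length lists element by element."""
    return reduce(lambda a, b: [x + y for x, y in zip(a, b)], vectors)


def all_reduce_reference(shards, cluster_axis=None):
    """Hypothetical host-side model of a Sum all-reduce over a 2D mesh.

    shards[r][c] is the flat list of values held by the device at mesh
    coordinate (r, c). Returns a new nested list of the same layout in which
    every device holds the sum for its reduction group.
    Assumption: cluster_axis indexes the mesh shape (0 = rows, 1 = columns).
    """
    rows, cols = len(shards), len(shards[0])
    if cluster_axis is None:
        # No axis given: every device in the mesh contributes to one sum.
        total = _elementwise_sum([shards[r][c] for r in range(rows) for c in range(cols)])
        return [[list(total) for _ in range(cols)] for _ in range(rows)]
    if cluster_axis == 0:
        # Devices in the same column (differing only in row index) reduce together.
        col_sums = [_elementwise_sum([shards[r][c] for r in range(rows)]) for c in range(cols)]
        return [[list(col_sums[c]) for c in range(cols)] for _ in range(rows)]
    if cluster_axis == 1:
        # Devices in the same row (differing only in column index) reduce together.
        row_sums = [_elementwise_sum([shards[r][c] for c in range(cols)]) for r in range(rows)]
        return [[list(row_sums[r]) for _ in range(cols)] for r in range(rows)]
    raise ValueError("cluster_axis must be None, 0, or 1")


# A 2x2 mesh where device (r, c) holds two copies of the value r * 2 + c.
mesh = [[[float(r * 2 + c)] * 2 for c in range(2)] for r in range(2)]
out = all_reduce_reference(mesh, cluster_axis=1)
print(out[0][0], out[1][1])  # [1.0, 1.0] [5.0, 5.0]
```

Each output shard keeps the shape of the input shard, and devices in the same group hold identical results, which mirrors the return-value description in this reference.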

Parameters:

input_tensor (ttnn.Tensor) – Input tensor to be reduced.

Keyword Arguments:
  • cluster_axis (int, optional) – The mesh axis along which to reduce. Defaults to None, which reduces across all devices in the mesh.

  • subdevice_id (ttnn.SubDeviceId, optional) – Sub-device ID for the worker cores.

  • memory_config (ttnn.MemoryConfig, optional) – Output memory configuration.

  • num_links (int, optional) – Number of links to use for the all_reduce operation. Defaults to None.

  • topology (ttnn.Topology, optional) – Fabric topology. Defaults to None.
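To see which devices are summed together for a given cluster_axis, the grouping can be enumerated on the host. This is an illustrative sketch only: reduction_groups is a hypothetical helper, not part of ttnn, and it assumes cluster_axis indexes the mesh shape (0 = rows, 1 = columns).

```python
def reduction_groups(mesh_shape, cluster_axis=None):
    """Hypothetical helper: list the mesh coordinates reduced together.

    Assumption: cluster_axis indexes mesh_shape, so 0 groups devices that
    share a column and 1 groups devices that share a row.
    """
    rows, cols = mesh_shape
    if cluster_axis is None:
        # A single group containing every device in the mesh.
        return [[(r, c) for r in range(rows) for c in range(cols)]]
    if cluster_axis == 0:
        # One group per column; members differ only in their row index.
        return [[(r, c) for r in range(rows)] for c in range(cols)]
    if cluster_axis == 1:
        # One group per row; members differ only in their column index.
        return [[(r, c) for c in range(cols)] for r in range(rows)]
    raise ValueError("cluster_axis must be None, 0, or 1")


print(len(reduction_groups((1, 8), cluster_axis=1)))  # 1
print(len(reduction_groups((1, 8), cluster_axis=0)))  # 8
```

Under this assumption, on a 1x8 mesh cluster_axis=1 and cluster_axis=None both produce one group of eight devices, while cluster_axis=0 produces eight single-device groups, so the reduction would leave each shard unchanged.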

Returns:

ttnn.Tensor – The reduced tensor, with the same shape as the input tensor. If cluster_axis is specified, the output is identical across all devices along that axis; otherwise it is identical across all devices in the mesh.

Example

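>>> import torch
>>> import ttnn
>>> full_tensor = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8))
>>> # Example settings; substitute the dtype, layout, and memory config you need.
>>> input_dtype = ttnn.bfloat16
>>> layout = ttnn.TILE_LAYOUT
>>> mem_config = ttnn.DRAM_MEMORY_CONFIG
>>> ttnn_tensor = ttnn.from_torch(
...     full_tensor,
...     dtype=input_dtype,
...     device=mesh_device,
...     layout=layout,
...     memory_config=mem_config,
...     mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2)),
... )
>>> output = ttnn.all_reduce(ttnn_tensor)
>>> print(output.shape)
[1, 1, 32, 256]
>>> ttnn.close_mesh_device(mesh_device)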