ttnn.all_broadcast

ttnn.all_broadcast(input_tensor: ttnn.Tensor, *, cluster_axis: int | None = None, subdevice_id: ttnn.SubDeviceId | None = None, memory_config: ttnn.MemoryConfig | None = None, num_links: int = 1, topology: ttnn.Topology = ttnn.Topology.Linear) → List[ttnn.Tensor]

All-broadcast operation across devices. Each device sends its local tensor to every other device, so every device ends up holding a copy of every device's data. The operation returns a vector of tensors in which the i-th tensor contains the data that originated on the i-th device. If cluster_axis is specified, the broadcast runs independently along that mesh axis and the output tensors are identical across the devices on each line of that axis; otherwise the output tensors are identical across all devices in the mesh.
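The semantics can be illustrated with a plain-Python sketch that needs no mesh hardware. Here the mesh is just a list of per-device values, and `all_broadcast_sim` is a hypothetical helper written for illustration, not part of the ttnn API:

```python
# Plain-Python sketch of the all-broadcast semantics (an illustration,
# not the ttnn implementation): after the op, every device holds a copy
# of every device's tensor, in device order.
def all_broadcast_sim(per_device_tensors):
    # Each device receives the full list of all devices' data.
    return [list(per_device_tensors) for _ in per_device_tensors]

# Four "devices", each holding a distinct local value.
mesh = ["d0", "d1", "d2", "d3"]
outputs = all_broadcast_sim(mesh)

# Every device sees the identical list of 4 tensors.
assert all(out == ["d0", "d1", "d2", "d3"] for out in outputs)
```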

Parameters:

input_tensor (ttnn.Tensor) – Input tensor to be broadcast.

Keyword Arguments:
  • cluster_axis (int, optional) – The axis on the mesh device to broadcast across. Defaults to None.

  • subdevice_id (ttnn.SubDeviceId, optional) – Subdevice id whose worker cores run the operation. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – Output memory configuration. Defaults to input tensor memory config.

  • num_links (int, optional) – The number of links to use for the all-broadcast operation. Defaults to 1.

  • topology (ttnn.Topology, optional) – Fabric topology. Defaults to ttnn.Topology.Linear.

Returns:

List[ttnn.Tensor] – A list of tensors, one per device participating in the broadcast, where each tensor has the same shape as the input.

Example

>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8))
>>> import torch
>>> input_tensor = ttnn.from_torch(
...     torch.rand([1, 1, 32, 256]),
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device))
>>> output = ttnn.all_broadcast(input_tensor)
>>> # output is a list of 8 tensors, each with shape [1, 1, 32, 256]
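
On a 2D mesh, passing cluster_axis restricts the broadcast to one mesh dimension. The sketch below assumes a 2×4 mesh and reuses the tensor shape from the example above; the mesh shape and the resulting list length are illustrative assumptions, not taken from the original text:

>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
>>> input_tensor = ttnn.from_torch(
...     torch.rand([1, 1, 32, 256]),
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device))
>>> output = ttnn.all_broadcast(input_tensor, cluster_axis=1)
>>> # Broadcast runs along mesh axis 1 (rows of 4 devices), so output is
>>> # a list of 4 tensors, identical across the devices of each row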