ttnn.all_broadcast
- ttnn.all_broadcast(input_tensor: ttnn.Tensor, *, cluster_axis: int | None = None, subdevice_id: ttnn.SubDeviceId | None = None, memory_config: ttnn.MemoryConfig | None = None, num_links: int = 1, topology: ttnn.Topology = ttnn.Topology.Linear) → List[ttnn.Tensor]
All-broadcast operation across devices. This operation broadcasts data from every device to every other device in the mesh, returning a list of tensors where each tensor contains the data from the corresponding device in the mesh. If cluster_axis is specified, the output tensors are identical across all devices along that axis; otherwise they are identical across all devices in the mesh.
- Parameters:
input_tensor (ttnn.Tensor) – Input tensor to be broadcast.
- Keyword Arguments:
cluster_axis (int, optional) – The axis of the mesh device to broadcast along. If None, the broadcast spans all devices in the mesh. Defaults to None.
subdevice_id (ttnn.SubDeviceId, optional) – Subdevice id for worker cores. Defaults to None.
memory_config (ttnn.MemoryConfig, optional) – Output memory configuration. Defaults to input tensor memory config.
num_links (int, optional) – The number of links to use for the all-broadcast operation. Defaults to 1.
topology (ttnn.Topology, optional) – Fabric topology. Defaults to ttnn.Topology.Linear.
- Returns:
List[ttnn.Tensor] – A list of tensors, one from each device, where each tensor has the same shape as the input.
Example
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8))
>>> input_tensor = ttnn.from_torch(
...     torch.rand([1, 1, 32, 256]),
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device))
>>> output = ttnn.all_broadcast(input_tensor)
>>> # output is a list of 8 tensors, each with shape [1, 1, 32, 256]
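To make the cluster_axis semantics concrete without hardware, the following is a minimal sketch that simulates the operation's data movement with nested Python lists standing in for a 2D device mesh. The helper name `simulated_all_broadcast` and the mesh layout are illustrative assumptions, not part of the ttnn API.

```python
# Illustrative sketch only: simulates all-broadcast data movement with
# plain Python lists standing in for per-device tensors. Not ttnn code.

def simulated_all_broadcast(mesh, cluster_axis=None):
    """Return, for each device, the list of tensors it would receive.

    mesh is a 2D grid: mesh[row][col] is the tensor held by that device.
    With cluster_axis=None every device receives every device's tensor;
    with cluster_axis=0 devices in the same column exchange tensors, and
    with cluster_axis=1 devices in the same row exchange tensors.
    """
    rows, cols = len(mesh), len(mesh[0])
    out = [[None] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            if cluster_axis is None:
                # Gather from the whole mesh, in row-major device order.
                out[r][c] = [mesh[i][j] for i in range(rows) for j in range(cols)]
            elif cluster_axis == 0:
                # Gather down the column this device belongs to.
                out[r][c] = [mesh[i][c] for i in range(rows)]
            else:
                # Gather along the row this device belongs to.
                out[r][c] = [mesh[r][j] for j in range(cols)]
    return out

# A hypothetical 2x2 mesh where each device holds a distinct payload.
mesh = [["a", "b"], ["c", "d"]]
full = simulated_all_broadcast(mesh)                  # every device: ["a", "b", "c", "d"]
cols = simulated_all_broadcast(mesh, cluster_axis=0)  # device (0, 0): ["a", "c"]
```

Note that the per-device result lists are identical within a broadcast group, which mirrors the documented behavior: the returned List[ttnn.Tensor] is the same on every device along the chosen axis (or across the whole mesh when cluster_axis is None).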