ttnn.broadcast
- ttnn.broadcast(input_tensor: ttnn.Tensor, sender_coord: MeshCoordinate, cluster_axis: int, mesh_device: MeshDevice, *, num_links: int = 1, memory_config: ttnn.MemoryConfig = None, topology: ttnn.Topology = ttnn.Topology.Ring) → ttnn.Tensor
Performs a broadcast operation from a sender device to all other mesh devices across a cluster axis.
- Parameters:
input_tensor (ttnn.Tensor) – Input mesh tensor holding the data to broadcast from the sender device.
sender_coord (MeshCoordinate) – Coordinate of the sender device in the mesh.
cluster_axis (int) – The axis of the MeshDevice along which the broadcast is performed.
mesh_device (MeshDevice) – Device mesh to perform the operation on.
Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md
- Keyword Arguments:
num_links (int, optional) – Number of links to use for the broadcast operation. Defaults to 1.
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.
topology (ttnn.Topology, optional) – The topology configuration to run the operation in. Valid options are Ring and Linear. Defaults to ttnn.Topology.Ring.
- Returns:
ttnn.Tensor containing the broadcast output on every device of the mesh.
Example
>>> import torch
>>> import ttnn
>>> num_devices = 4
>>> cluster_axis = 1
>>> sender_coord_tuple = (0, 0)
>>> sender_tensor = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
>>> device_tensors = []
>>> for device_idx in range(num_devices):
...     if device_idx == sender_coord_tuple[cluster_axis]:
...         device_tensors.append(sender_tensor)
...     else:
...         device_tensors.append(torch.zeros_like(sender_tensor))
>>> mesh_tensor_torch = torch.cat(device_tensors, dim=-1)
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices))
>>> sender_coord = ttnn.MeshCoordinate(sender_coord_tuple)
>>> mesh_mapper_config = ttnn.MeshMapperConfig(
...     [ttnn.PlacementReplicate(), ttnn.PlacementShard(-1)],
...     ttnn.MeshShape(1, num_devices),
... )
>>> ttnn_tensor = ttnn.from_torch(
...     mesh_tensor_torch,
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     mesh_mapper=ttnn.create_mesh_mapper(mesh_device, mesh_mapper_config),
... )
>>> output = ttnn.broadcast(ttnn_tensor, sender_coord)
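The semantics of the operation can be sketched in plain PyTorch without a mesh device: each device's shard along the cluster axis is replaced by a copy of the sender's shard. This is an illustrative emulation of what the collective computes, not the ttnn implementation; `emulate_broadcast` is a hypothetical helper name.

```python
import torch

def emulate_broadcast(shards, sender_index):
    # Emulate broadcast semantics over a list of per-device shards:
    # after the collective, every device holds the sender's shard.
    sender_shard = shards[sender_index]
    return [sender_shard.clone() for _ in shards]

# Four "devices", each starting with a distinct shard.
num_devices = 4
shards = [torch.full((1, 1, 32, 256), float(i)) for i in range(num_devices)]

# Broadcast from device 0: every output shard matches shards[0].
out = emulate_broadcast(shards, sender_index=0)
```

After the call, `out[i]` equals `shards[0]` for every `i`, which mirrors the state of the mesh tensor after `ttnn.broadcast` completes.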