ttnn.broadcast

ttnn.broadcast(input_tensor: ttnn.Tensor, sender_coord: MeshCoordinate, cluster_axis: int, mesh_device: MeshDevice, *, num_links: int = 1, memory_config: ttnn.MemoryConfig = None, topology: ttnn.Topology = ttnn.Topology.Ring) → ttnn.Tensor

Performs a broadcast operation from a sender device to all other mesh devices across a cluster axis.
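Conceptually, every device along the chosen cluster axis ends up holding a copy of the sender device's local tensor. A minimal pure-Python sketch of that semantics (using plain lists in place of device tensors; `broadcast_along_axis` is a hypothetical helper, not part of the ttnn API):

```python
def broadcast_along_axis(shards, sender_index):
    """Model of a mesh broadcast: return the per-device shards after every
    device along the cluster axis has received the sender's shard."""
    sender_shard = shards[sender_index]
    # Each device's local data is replaced by a copy of the sender's data.
    return [list(sender_shard) for _ in shards]

# Before the broadcast, only device 0 (the sender) holds the payload.
before = [[1, 2, 3], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
after = broadcast_along_axis(before, sender_index=0)
# after == [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]
```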

Parameters:
  • input_tensor (ttnn.Tensor) – Input mesh tensor; the sender device's shard is broadcast to all other devices along the cluster axis.

  • sender_coord (MeshCoordinate) – Coordinate of the sender device in the mesh.

  • cluster_axis (int) – The axis of the MeshDevice along which the broadcast is performed.

  • mesh_device (MeshDevice) – Device mesh to perform the operation on.

Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md

Keyword Arguments:
  • num_links (int, optional) – Number of links to use for the broadcast operation. Defaults to 1.

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.

  • topology (ttnn.Topology, optional) – The topology configuration to run the operation in. Valid options are Ring and Linear. Defaults to ttnn.Topology.Ring.

Returns:

ttnn.Tensor of the output on the mesh device.

Example

>>> num_devices = 4
>>> cluster_axis = 1
>>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices))
>>> sender_coord_tuple = (0, 0)
>>> sender_coord = ttnn.MeshCoordinate(sender_coord_tuple)
>>> sender_tensor = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
>>> device_tensors = []
>>> for device_idx in range(num_devices):
...     if device_idx == sender_coord_tuple[cluster_axis]:
...         device_tensors.append(sender_tensor)
...     else:
...         device_tensors.append(torch.zeros_like(sender_tensor))
>>> mesh_tensor_torch = torch.cat(device_tensors, dim=-1)
>>> mesh_mapper_config = ttnn.MeshMapperConfig(
...     [ttnn.PlacementReplicate(), ttnn.PlacementShard(-1)], ttnn.MeshShape(1, num_devices)
... )
>>> ttnn_tensor = ttnn.from_torch(
...     mesh_tensor_torch,
...     dtype=ttnn.bfloat16,
...     device=mesh_device,
...     layout=ttnn.TILE_LAYOUT,
...     memory_config=ttnn.DRAM_MEMORY_CONFIG,
...     mesh_mapper=ttnn.create_mesh_mapper(mesh_device, mesh_mapper_config),
... )
>>> output = ttnn.broadcast(ttnn_tensor, sender_coord, cluster_axis, mesh_device)