ttnn.reduce_scatter
ttnn.reduce_scatter(input_tensor: ttnn.Tensor, dim: int, *, cluster_axis: int | None = None, subdevice_id: ttnn.SubDeviceId | None = None, memory_config: ttnn.MemoryConfig | None = None, output_tensor: ttnn.Tensor | None = None, num_links: int | None = None, topology: ttnn.Topology | None = None) → ttnn.Tensor
Reduce-scatter operation across devices, along a selected dimension and an optional cluster axis. The operation element-wise reduces the mesh tensor across the devices in the mesh, then scatters the reduced tensor back to those devices along the specified dimension, so that each device ends up with one shard of the result. When cluster_axis is specified, the reduce and scatter run along that cluster axis only; when it is not, they span all devices in the mesh. When the layout is row-major, or the scatter would break tiles apart, a composite reduce-scatter implementation is used that falls back to all-broadcast.
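To make the semantics concrete, here is a minimal host-side reference sketch of what a sum reduce-scatter computes; the helper reduce_scatter_reference is hypothetical and models only the math, ignoring layout, tiling, and the composite all-broadcast fallback.

    import numpy as np

    def reduce_scatter_reference(shards, dim):
        # Reduce: element-wise sum of the per-device shards.
        reduced = np.sum(np.stack(shards), axis=0)
        # Scatter: split the reduced tensor along `dim`, one piece per device.
        return np.split(reduced, len(shards), axis=dim)

    # Each of 8 "devices" holds a full [1, 8, 32, 256] tensor.
    shards = [np.random.rand(1, 8, 32, 256).astype(np.float32) for _ in range(8)]
    outputs = reduce_scatter_reference(shards, dim=1)
    print(outputs[0].shape)  # (1, 1, 32, 256): input_shape[dim] / num_devices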
Parameters:
input_tensor (ttnn.Tensor) – Input tensor to be reduced and scattered.
dim (int) – Dimension along which the reduced tensor is scattered.
Keyword Arguments:
cluster_axis (int, optional) – The cluster axis to reduce across. Defaults to None.
subdevice_id (ttnn.SubDeviceId, optional) – Subdevice id for worker cores.
memory_config (ttnn.MemoryConfig, optional) – Output memory configuration.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor.
num_links (int, optional) – The number of links to use for the reduce-scatter operation. Defaults to None, for which the number of links is determined automatically.
topology (ttnn.Topology, optional) – Fabric topology. Defaults to None.
Returns:
ttnn.Tensor – The reduced and scattered tensor, with output_shape = input_shape for all unspecified dimensions, and output_shape[dim] = input_shape[dim] / num_devices, where num_devices is the number of devices along cluster_axis if specified, otherwise the total number of devices in the mesh.
Example
>>> # ttnn_tensor shape is [1, 8, 32, 256]
>>> # num_devices along cluster_axis is 8
>>> output = ttnn.reduce_scatter(ttnn_tensor, dim=1, cluster_axis=1)
>>> print(output.shape)
[1, 1, 32, 256]
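For a more end-to-end picture, here is a sketch of how the call might be wired up on a 2D device mesh. The 2x8 mesh shape and the ttnn.open_mesh_device / ttnn.from_torch mesh-mapper flow are assumptions based on common ttnn multi-device usage, not details specified on this page.

    import torch
    import ttnn

    # Open a 2x8 device mesh (assumed shape; adjust to your hardware).
    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 8))

    # Replicate a host tensor onto every device so each holds a full copy.
    torch_tensor = torch.rand(1, 8, 32, 256, dtype=torch.bfloat16)
    ttnn_tensor = ttnn.from_torch(
        torch_tensor,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
    )

    # Reduce-scatter along cluster_axis=1 (the 8-wide mesh axis): each of
    # the 8 devices in a row ends up with a [1, 1, 32, 256] shard of the sum.
    output = ttnn.reduce_scatter(ttnn_tensor, dim=1, cluster_axis=1)

    ttnn.close_mesh_device(mesh_device)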