ttnn.mesh_partition

ttnn.mesh_partition(input_tensor: ttnn.Tensor, dim: int, cluster_axis: int = None, *, memory_config: ttnn.MemoryConfig = None) → ttnn.Tensor

Partitions the input tensor across the mesh so that device i holds the i-th of num_devices equal partitions of the input tensor along the specified dimension. This is the inverse of all_gather. When cluster_axis is specified, the tensor is partitioned across that axis of the mesh.

Parameters:
  • input_tensor (ttnn.Tensor) – the input tensor.

  • dim (int) – the dimension to partition along.

  • cluster_axis (int, optional) – the cluster axis on the mesh. If None, the tensor is partitioned across all devices in the mesh. Defaults to None.

Keyword Arguments:

  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.

Returns:

ttnn.Tensor – The partitioned tensor. output_shape matches input_shape in all unspecified dimensions, and output_shape[dim] = input_shape[dim] / num_devices, where num_devices is the number of devices along cluster_axis if specified, otherwise the total number of devices in the mesh.
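The output shape rule above can be sketched in plain Python. The helper below is illustrative only (it is not part of the ttnn API); it assumes, as the operation requires, that the partitioned dimension divides evenly by the device count:

```python
def mesh_partition_shape(input_shape, dim, num_devices):
    # Each device receives an equal slice along `dim`, so
    # input_shape[dim] must be divisible by num_devices.
    assert input_shape[dim] % num_devices == 0, "dim not evenly divisible"
    output_shape = list(input_shape)
    output_shape[dim] = input_shape[dim] // num_devices
    return output_shape

# A [1, 1, 32, 256] tensor partitioned on dim=3 across 8 devices
# yields a per-device shape of [1, 1, 32, 32].
print(mesh_partition_shape([1, 1, 32, 256], 3, 8))
```

Gathering the per-device shards back with all_gather along the same dim and axis restores the original shape.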

Example

>>> output_tensor = ttnn.mesh_partition(
...     tt_input_tensors_list[i],
...     dim,
...     cluster_axis=1,
...     memory_config=output_mem_config,
... )