ttnn.mesh_partition
- ttnn.mesh_partition(input_tensor: ttnn.Tensor, dim: int, cluster_axis: int = None, *, memory_config: ttnn.MemoryConfig = None) → ttnn.Tensor
Partitions the input tensor across the mesh along dim, so that device i receives the i-th of num_devices equal slices of the input tensor along that dimension. This is the inverse of all_gather. When cluster_axis is specified, the tensor is partitioned only across the devices along that mesh axis; otherwise it is partitioned across all devices in the mesh.
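The slicing described above can be modeled on the host with NumPy. The sketch below is a hypothetical reference model, not part of ttnn: reference_mesh_partition, the 2D mesh_shape tuple, and the row-major device ordering used when cluster_axis is None are all assumptions made for illustration. It also checks the "inverse of all_gather" property by concatenating the slices back together.

```python
import numpy as np


def reference_mesh_partition(x, dim, mesh_shape, cluster_axis=None):
    """Hypothetical host-side model of mesh_partition's slicing semantics.

    Returns a dict mapping a 2D mesh coordinate (row, col) to the slice of
    x that the device at that coordinate would hold.
    """
    rows, cols = mesh_shape
    # Number of devices the tensor is split across.
    num_devices = rows * cols if cluster_axis is None else mesh_shape[cluster_axis]
    # np.split raises ValueError if x.shape[dim] is not evenly divisible.
    chunks = np.split(x, num_devices, axis=dim)
    parts = {}
    for r in range(rows):
        for c in range(cols):
            if cluster_axis is None:
                i = r * cols + c  # assumption: row-major device order
            else:
                i = (r, c)[cluster_axis]  # index along the cluster axis
            parts[(r, c)] = chunks[i]
    return parts


# Usage: split an 8-wide tensor across the 4 columns of a 2x4 mesh.
x = np.arange(16).reshape(2, 8)
parts = reference_mesh_partition(x, dim=1, mesh_shape=(2, 4), cluster_axis=1)
# Gathering the slices of one mesh row along dim recovers the input.
regathered = np.concatenate([parts[(0, c)] for c in range(4)], axis=1)
assert np.array_equal(regathered, x)
```

Devices that share a column index receive the same slice here, since only the position along cluster_axis selects the chunk.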
- Parameters:
input_tensor (ttnn.Tensor) – the input tensor.
dim (int) – the dimension to partition along.
cluster_axis (int, optional) – the mesh axis to partition across. If None, the tensor is partitioned across all devices in the mesh. Defaults to None.
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to None.
- Returns:
ttnn.Tensor – the partitioned tensor. Its shape matches input_shape in every dimension except dim, where output_shape[dim] = input_shape[dim] / num_devices. Here num_devices is the number of devices along cluster_axis if it is specified, otherwise the total number of devices in the mesh.
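The output-shape rule above is simple arithmetic, sketched below as a hypothetical helper (mesh_partition_output_shape is not a ttnn function). It assumes a mesh described by a shape tuple and assumes input_shape[dim] must divide evenly by num_devices; the ValueError is this sketch's choice, not documented ttnn behavior.

```python
from math import prod


def mesh_partition_output_shape(input_shape, dim, mesh_shape, cluster_axis=None):
    """Hypothetical helper: per-device shape after mesh_partition."""
    # Devices along cluster_axis if given, else every device in the mesh.
    num_devices = prod(mesh_shape) if cluster_axis is None else mesh_shape[cluster_axis]
    # Assumption: the partitioned dimension must split evenly.
    if input_shape[dim] % num_devices != 0:
        raise ValueError(
            f"input_shape[{dim}]={input_shape[dim]} is not divisible by {num_devices}"
        )
    out = list(input_shape)
    out[dim] = input_shape[dim] // num_devices
    return tuple(out)


# On a 2x4 mesh, splitting dim 3 (size 256) along cluster axis 1 uses 4 devices.
print(mesh_partition_output_shape((1, 1, 32, 256), 3, (2, 4), cluster_axis=1))
```

With cluster_axis=None the same call would divide by all 8 devices instead.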
Example
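>>> # tt_input_tensors_list[i]: a ttnn.Tensor on a mesh device; output_mem_config: a ttnn.MemoryConfig
>>> tensor = ttnn.mesh_partition(tt_input_tensors_list[i], dim, cluster_axis=1, memory_config=output_mem_config)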