ttnn.moe_routing_remap
- ttnn.moe_routing_remap(routing_weights_tensor: ttnn.Tensor, non_zero_weight_size: int, expert_parallel_size: int, cluster_axis: int, *, memory_config: ttnn.MemoryConfig = None, optional_output_tensor: ttnn.Tensor | None = None) → ttnn.Tensor
-
Remaps MoE routing weights to per-device local routing weights. Partitions the non-zero weights, which may be non-uniformly distributed across the tensor, into groups and maps each group to a device along the specified cluster axis.
For example, with non_zero_weight_size=8, expert_parallel_size=4, and total_experts=32:
routing_weights_tensor [1, total_experts]: [0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 4, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0]
Each device will have 2 non-zero values in its output:
cluster device 0: [0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 1: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 3: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0]
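The per-device outputs above can be reproduced with a plain-Python sketch of the partitioning rule. This is only an illustration of the documented behavior, not the device implementation, and the function name is hypothetical:

```python
def moe_routing_remap_reference(routing_weights, non_zero_weight_size, expert_parallel_size):
    """Split the non-zero weights into expert_parallel_size contiguous groups
    and emit one masked copy of the weights per device."""
    group_size = non_zero_weight_size // expert_parallel_size  # non-zeros per device

    # Positions of the non-zero weights, in order of appearance.
    nz_positions = [i for i, w in enumerate(routing_weights) if w != 0]
    assert len(nz_positions) == non_zero_weight_size

    outputs = []
    for device in range(expert_parallel_size):
        # This device keeps only its contiguous slice of the non-zero positions.
        keep = set(nz_positions[device * group_size:(device + 1) * group_size])
        outputs.append([w if i in keep else 0 for i, w in enumerate(routing_weights)])
    return outputs


weights = [0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 4, 5, 0, 0, 0, 6, 0,
           0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0]
for device, local in enumerate(moe_routing_remap_reference(weights, 8, 4)):
    print(f"cluster device {device}: {local}")
```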
- Parameters:
-
routing_weights_tensor (ttnn.Tensor) – Tensor of weights for the selected experts, of shape [1, total_experts], replicated on all devices.
non_zero_weight_size (int) – Total number of selected experts, i.e. the number of non-zero weights in routing_weights_tensor.
expert_parallel_size (int) – Number of devices in the cluster.
cluster_axis (int) – Device mesh axis of the cluster (0: columns, 1: rows).
- Keyword Arguments:
-
memory_config (ttnn.MemoryConfig, optional) – Optional memory configuration for the output. Defaults to None.
optional_output_tensor (ttnn.Tensor, optional) – Optional preallocated output tensor to write the result into. Defaults to None.
- Returns:
-
ttnn.Tensor – Tensor containing the device-partitioned local routing weights, of shape [1, total_experts] on each device.
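A minimal usage sketch matching the example above, assuming routing_weights is already a [1, total_experts] ttnn.Tensor replicated on every device of an open 4-device mesh (mesh setup not shown):

```python
import ttnn

# Illustrative call; `routing_weights` is assumed to be a [1, 32] tensor
# replicated across a 4-device cluster.
local_weights = ttnn.moe_routing_remap(
    routing_weights,
    non_zero_weight_size=8,   # 8 selected experts in total
    expert_parallel_size=4,   # partitioned across 4 devices
    cluster_axis=0,           # partition along mesh columns
)
```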