ttnn.moe_routing_remap

ttnn.moe_routing_remap(routing_weights_tensor: ttnn.Tensor, non_zero_weight_size: integer, expert_parallel_size: integer, cluster_axis: integer, *, memory_config: ttnn.MemoryConfig = None, optional_output_tensor: ttnn.Tensor | None) -> ttnn.Tensor

Remap MoE routing weights, replicated on every device, to local per-device routing weights. Partitions the non-zero weights, which may be non-uniformly distributed across experts, into groups and maps each group to one device along the specified cluster axis.

For example, given:

non_zero_weight_size = 8, expert_parallel_size = 4, total_experts = 32

routing_weights_tensor (1, total_experts): [0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 4, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0]

Each device receives non_zero_weight_size / expert_parallel_size = 8 / 4 = 2 non-zero values in its output; all other entries are zero:

cluster device 0: [0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 1: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cluster device 3: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0]
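The partitioning in the example can be reproduced with a small pure-Python reference. This is a hypothetical sketch (moe_routing_remap_reference is not part of ttnn) that infers the rule from the example: non-zero weights are taken in ascending expert-index order and split into expert_parallel_size contiguous groups of equal size, with device d keeping group d. It also assumes the split is even; this page does not say how the operation handles other cases.

```python
from typing import List, Sequence


def moe_routing_remap_reference(
    routing_weights: Sequence[float],
    non_zero_weight_size: int,
    expert_parallel_size: int,
) -> List[List[float]]:
    """Hypothetical reference, inferred from the documented example only."""
    # Expert indices that carry a non-zero weight, in ascending order.
    nonzero_idx = [i for i, w in enumerate(routing_weights) if w != 0]
    if len(nonzero_idx) != non_zero_weight_size:
        raise ValueError("non_zero_weight_size does not match the tensor")
    # Assumption: the non-zero weights split evenly across devices.
    if non_zero_weight_size % expert_parallel_size != 0:
        raise ValueError("non-zero weights do not split evenly across devices")
    per_device = non_zero_weight_size // expert_parallel_size

    outputs: List[List[float]] = []
    for device in range(expert_parallel_size):
        local = [0.0] * len(routing_weights)  # every other expert is zeroed
        # Device d keeps the d-th contiguous group of non-zero weights.
        for i in nonzero_idx[device * per_device:(device + 1) * per_device]:
            local[i] = routing_weights[i]
        outputs.append(local)
    return outputs


# Rebuild the example above: 32 experts, 8 non-zero weights, 4 devices.
weights = [0.0] * 32
for expert, w in zip([2, 4, 10, 13, 14, 18, 22, 24], range(1, 9)):
    weights[expert] = float(w)

per_device_weights = moe_routing_remap_reference(weights, 8, 4)
for device, row in enumerate(per_device_weights):
    # Show only the non-zero entries each device keeps.
    print(device, {i: w for i, w in enumerate(row) if w != 0})
```

Device 0 ends up with experts 2 and 4, device 1 with 10 and 13, and so on, matching the listing above even though the non-zero weights are unevenly spaced across the expert indices.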

Parameters:
  • routing_weights_tensor (ttnn.Tensor) – Routing weights for the selected experts, with shape [1, total_experts], replicated on all devices.

  • non_zero_weight_size (integer) – Total number of selected experts, i.e. the number of non-zero weights in routing_weights_tensor.

  • expert_parallel_size (integer) – Number of devices in the cluster along cluster_axis, across which the non-zero weights are partitioned.

  • cluster_axis (integer) – Mesh axis along which the cluster's devices are arranged (0: columns, 1: rows).

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Optional memory configuration for the output. Defaults to None.

  • optional_output_tensor (ttnn.Tensor, optional) – Optional preallocated tensor to write the output into.

Returns:

ttnn.Tensor – Tensor containing each device's partitioned local routing weights, [devices/devices, total_experts].
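One way to read the [devices/devices, total_experts] return shape (our interpretation, not stated explicitly on this page) is as a conceptual global [devices, total_experts] tensor with one [1, total_experts] row held on each device. The pure-Python sketch below builds that global view from the example's per-device non-zero entries. Because each expert is non-zero on at most one device, summing the rows over the device axis recovers the replicated input. The names TOTAL_EXPERTS, device_nonzeros and densify are hypothetical helpers for illustration.

```python
from typing import Dict, List

TOTAL_EXPERTS = 32

# Non-zero entries each device holds in the example above (expert index -> weight).
device_nonzeros: List[Dict[int, float]] = [
    {2: 1.0, 4: 2.0},
    {10: 3.0, 13: 4.0},
    {14: 5.0, 18: 6.0},
    {22: 7.0, 24: 8.0},
]


def densify(nonzeros: Dict[int, float], total_experts: int) -> List[float]:
    """Expand a sparse {expert_index: weight} map into a dense [total_experts] row."""
    row = [0.0] * total_experts
    for idx, w in nonzeros.items():
        row[idx] = w
    return row


# Conceptual global view: one [1, total_experts] row per device -> [devices, total_experts].
global_view = [densify(nz, TOTAL_EXPERTS) for nz in device_nonzeros]

# Check that no expert is non-zero on more than one device.
disjoint = all(
    sum(1 for row in global_view if row[e] != 0) <= 1 for e in range(TOTAL_EXPERTS)
)

# Summing over the device axis recovers the replicated input routing weights.
recovered = [sum(col) for col in zip(*global_view)]
print(disjoint, [i for i, w in enumerate(recovered) if w != 0])
```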