ttnn.moe

ttnn.moe(input_tensor: ttnn.Tensor, expert_mask_tensor: ttnn.Tensor, topk_mask_tensor: ttnn.Tensor, k: number, *, memory_config: ttnn.MemoryConfig | None, output_tensor: ttnn.Tensor | None) → ttnn.Tensor

Returns the softmax weight of the zeroth MoE expert after top-k selection. The weight is 0 wherever expert 0 is not among the top k.

Note

This is equivalent to the following PyTorch code:

val, ind = torch.topk(input_tensor + expert_mask_tensor, k)
return torch.sum(torch.softmax(val+topk_mask_tensor, dim=-1)*(ind==0), dim=-1)
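
To make the two PyTorch lines concrete, here is a minimal single-row sketch in plain Python (standard library only). The function name `moe_row_reference` is hypothetical and not part of ttnn, and the sketch accepts a small k purely for illustration, whereas ttnn.moe itself requires k to be 32.

```python
import math


def moe_row_reference(scores, expert_mask, topk_mask, k):
    """Plain-Python mirror of the PyTorch snippet for one row (illustrative only)."""
    # Add the expert mask, then select the k largest entries and their indices.
    masked = [s + m for s, m in zip(scores, expert_mask)]
    order = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)[:k]
    # Add the top-k mask to the selected values.
    vals = [masked[i] + t for i, t in zip(order, topk_mask)]
    # Numerically stable softmax over the selected values.
    peak = max(vals)
    exps = [math.exp(v - peak) for v in vals]
    total = sum(exps)
    # Keep only the probability assigned to expert index 0.
    return sum((e / total for e, i in zip(exps, order) if i == 0), 0.0)


# Expert 0 has the highest score, so it is selected and receives a weight.
weight = moe_row_reference([2.0, 1.0, 0.0, -1.0], [0.0] * 4, [0.0] * 2, k=2)

# Masking expert 0 with -inf removes it from the top k, so its weight is 0.
masked_out = moe_row_reference(
    [2.0, 1.0, 0.0, -1.0], [-math.inf, 0.0, 0.0, 0.0], [0.0] * 2, k=2
)
```

In this sketch the expert mask decides which experts can be selected at all, while the top-k mask can further suppress some of the k selected slots before the softmax.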
Parameters:
  • input_tensor (ttnn.Tensor) – Input tensor of expert scores.

  • expert_mask_tensor (ttnn.Tensor) – Mask added to input_tensor to exclude invalid experts.

  • topk_mask_tensor (ttnn.Tensor) – Mask added to the top-k values to exclude everything except the desired top-k entries.

  • k (number) – Number of top elements to select.

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration of the output tensor.

  • output_tensor (ttnn.Tensor, optional) – Preallocated output tensor.

Returns:

ttnn.Tensor – the output tensor.

Note

The input_tensor, expert_mask_tensor, and topk_mask_tensor must match the following data type and layout:

dtype      layout
BFLOAT16   TILE

The output tensor is in TILE layout with BFLOAT16 dtype.

Memory Support:
  • Interleaved: DRAM and L1

Limitations:
  • Tensors must be 4D with shape [N, C, H, W], and must be located on the device.

  • For the input_tensor, N*C*H must be a multiple of 32. The last dimension must be a power of two and ≥64.

  • k must be exactly 32.

  • For the topk_mask_tensor, H must be 32 and W must match k (i.e. 32).

  • For the expert_mask_tensor, H must be 32 and W must match W of the input_tensor.

  • All of the shape validations are performed on padded shapes.

  • Sharding is not supported for this operation.
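
The shape limitations above can be sketched as a small pre-flight check in plain Python. Both `pad_to_tile` and `check_moe_shapes` are hypothetical helpers, not ttnn functions, and the padding rule (rounding H and W up to a multiple of 32 for TILE layout) is an assumption made for this sketch. The check covers shapes and k only, not dtype, layout, device placement, or sharding.

```python
TILE = 32  # assumed tile edge length for TILE layout


def pad_to_tile(shape):
    """Round the last two dims up to a multiple of TILE (assumed padding rule)."""
    *lead, h, w = shape

    def round_up(x):
        return -(-x // TILE) * TILE

    return [*lead, round_up(h), round_up(w)]


def check_moe_shapes(input_shape, expert_mask_shape, topk_mask_shape, k):
    """Return a list of violated limitations; an empty list means all checks pass."""
    shapes = (input_shape, expert_mask_shape, topk_mask_shape)
    if any(len(s) != 4 for s in shapes):
        return ["all tensors must be 4D [N, C, H, W]"]

    # Validation is performed on padded shapes.
    inp, expert, topk = (pad_to_tile(s) for s in shapes)
    n, c, h, w = inp
    errors = []
    if (n * c * h) % 32 != 0:
        errors.append("input_tensor: N*C*H must be a multiple of 32")
    if w < 64 or (w & (w - 1)) != 0:
        errors.append("input_tensor: W must be a power of two and >= 64")
    if k != 32:
        errors.append("k must be exactly 32")
    if topk[2] != 32 or topk[3] != k:
        errors.append("topk_mask_tensor: H must be 32 and W must match k")
    if expert[2] != 32 or expert[3] != w:
        errors.append("expert_mask_tensor: H must be 32 and W must match input W")
    return errors


# Masks with H=1 pass because the check runs on padded shapes.
ok = check_moe_shapes([1, 1, 32, 64], [1, 1, 1, 64], [1, 1, 1, 32], k=32)
bad = check_moe_shapes([1, 1, 32, 96], [1, 1, 1, 64], [1, 1, 1, 32], k=16)
```

Under the assumed padding rule, a mask created with H=1 in TILE layout has a padded H of 32, so it satisfies the H requirement.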

Example

import ttnn
from loguru import logger

# device_id=0 is an assumption for a single-device setup.
device = ttnn.open_device(device_id=0)

N, C, H, W = 1, 1, 32, 64
k = 32  # k must be exactly 32

input_tensor = ttnn.rand([N, C, H, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
expert_mask = ttnn.zeros([N, C, 1, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
topE_mask = ttnn.zeros([N, C, 1, k], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

tensor_output = ttnn.moe(input_tensor, expert_mask, topE_mask, k)
logger.info(f"MOE result: {tensor_output}")
ttnn.close_device(device)