ttnn.moe
- ttnn.moe(input_tensor: ttnn.Tensor, expert_mask_tensor: ttnn.Tensor, topk_mask_tensor: ttnn.Tensor, k: number, *, memory_config: ttnn.MemoryConfig | None, output_tensor: ttnn.Tensor | None) → ttnn.Tensor
Returns the weight of the zero-th MoE expert.
Note
This is equivalent to the following PyTorch code:
    val, ind = torch.topk(input_tensor + expert_mask_tensor, k)
    return torch.sum(torch.softmax(val + topk_mask_tensor, dim=-1) * (ind == 0), dim=-1)
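As a rough sketch, the equivalent PyTorch code above can be wrapped in a standalone host-side reference function for checking results. The name moe_reference and the float32 test tensors are assumptions made for illustration and are not part of ttnn; the mask shapes follow the Example on this page.

```python
import torch


def moe_reference(input_tensor, expert_mask_tensor, topk_mask_tensor, k):
    """Hypothetical host-side reference for the documented PyTorch equivalent."""
    # Pick the k largest scores along the last dimension after applying the expert mask.
    val, ind = torch.topk(input_tensor + expert_mask_tensor, k)
    # Normalize the selected scores (plus the top-k mask) with a softmax.
    weights = torch.softmax(val + topk_mask_tensor, dim=-1)
    # Keep only the weight that belongs to expert index 0, if it was selected.
    return torch.sum(weights * (ind == 0), dim=-1)


k = 32
scores = torch.zeros(1, 1, 32, 64)
scores[..., 0] = 10.0  # give expert 0 a clearly dominant score
expert_mask = torch.zeros(1, 1, 1, 64)
topk_mask = torch.zeros(1, 1, 1, k)

out = moe_reference(scores, expert_mask, topk_mask, k)

# Masking expert 0 with -inf removes it from the top-k selection.
masked = expert_mask.clone()
masked[..., 0] = float("-inf")
out_masked = moe_reference(scores, masked, topk_mask, k)
```

In this sketch, out holds one weight per row of the input, and out_masked is zero everywhere because expert 0 is never among the selected indices.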
- Parameters:
input_tensor (ttnn.Tensor) – Input Tensor for moe.
expert_mask_tensor (ttnn.Tensor) – Expert Mask Tensor to mask out invalid experts.
topk_mask_tensor (ttnn.Tensor) – Topk Mask Tensor to mask out everything except topk.
k (number) – The number of top elements to select. Must be 32 (see Limitations).
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration of the output tensor.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor.
- Returns:
ttnn.Tensor – the output tensor.
Note
The input_tensor, expert_mask_tensor, and topk_mask_tensor must match the following data type and layout:
dtype      layout
BFLOAT16   TILE
The output tensor will be in TILE layout and BFLOAT16.
- Memory Support:
Interleaved: DRAM and L1
- Limitations:
Tensors must be 4D with shape [N, C, H, W], and must be located on the device.
For the input_tensor, N*C*H must be a multiple of 32. The last dimension must be a power of two and ≥64.
k must be exactly 32.
For the topk_mask_tensor, H must be 32 and W must match k (i.e. 32).
For the expert_mask_tensor, H must be 32 and W must match W of the input_tensor.
All of the shape validations are performed on padded shapes.
Sharding is not supported for this operation.
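The shape limitations above can be sketched as a small host-side check. The helpers pad_to_tile and check_moe_shapes are hypothetical and not part of ttnn, and the padding rule (rounding the last two dimensions up to a multiple of 32) is an assumption about how padded shapes are formed in TILE layout.

```python
TILE = 32  # assumed tile edge length used for padding in TILE layout


def round_up(value, multiple):
    """Round value up to the nearest multiple."""
    return -(-value // multiple) * multiple


def pad_to_tile(shape):
    """Round the last two dims of a 4D shape up to a multiple of TILE (assumed rule)."""
    n, c, h, w = shape
    return (n, c, round_up(h, TILE), round_up(w, TILE))


def check_moe_shapes(input_shape, expert_mask_shape, topk_mask_shape, k):
    """Return a list of violated limitations; an empty list means the shapes pass."""
    shapes = (input_shape, expert_mask_shape, topk_mask_shape)
    if any(len(s) != 4 for s in shapes):
        return ["all tensors must be 4D with shape [N, C, H, W]"]

    # The limitations are stated on padded shapes, so pad before checking.
    inp, expert, topk = (pad_to_tile(s) for s in shapes)
    n, c, h, w = inp

    problems = []
    if (n * c * h) % 32 != 0:
        problems.append("input_tensor: N*C*H must be a multiple of 32")
    if w < 64 or (w & (w - 1)) != 0:
        problems.append("input_tensor: W must be a power of two and >= 64")
    if k != 32:
        problems.append("k must be exactly 32")
    if topk[2] != 32 or topk[3] != k:
        problems.append("topk_mask_tensor: H must be 32 and W must match k")
    if expert[2] != 32 or expert[3] != w:
        problems.append("expert_mask_tensor: H must be 32 and W must match input W")
    return problems


# Shapes taken from the Example on this page: masks with H = 1 pad up to H = 32.
ok = check_moe_shapes((1, 1, 32, 64), (1, 1, 1, 64), (1, 1, 1, 32), 32)
bad_k = check_moe_shapes((1, 1, 32, 64), (1, 1, 1, 64), (1, 1, 1, 16), 16)
bad_w = check_moe_shapes((1, 1, 32, 96), (1, 1, 1, 96), (1, 1, 1, 32), 32)
```

Under the assumed padding rule, mask tensors created with H = 1 satisfy the H = 32 requirement, which is consistent with the shapes used in the Example.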
Example
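    # Setup lines below are assumed; the original snippet expects `device` and `logger` to exist.
    import ttnn
    from loguru import logger

    device = ttnn.open_device(device_id=0)

    N, C, H, W = 1, 1, 32, 64
    k = 32

    # Masks are created with H = 1; shape checks are performed on the padded (tiled) shapes.
    input_tensor = ttnn.rand([N, C, H, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    expert_mask = ttnn.zeros([N, C, 1, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    topE_mask = ttnn.zeros([N, C, 1, k], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    tensor_output = ttnn.moe(input_tensor, expert_mask, topE_mask, k)
    logger.info(f"MOE result: {tensor_output}")

    ttnn.close_device(device)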