ttnn.moe
ttnn.moe(input_tensor, expert_mask_tensor, topk_mask_tensor, k, *, memory_config=None, output_tensor=None) → ttnn.Tensor
Returns the weight of the zero-th MoE expert.
Note
This is equivalent to the following PyTorch code:

    val, ind = torch.topk(input_tensor + expert_mask_tensor, k)
    return torch.sum(torch.softmax(val + topk_mask_tensor, dim=-1) * (ind == 0), dim=-1)
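The reference computation in the note above can be sketched numerically. The following is an illustrative NumPy version (NumPy stands in for PyTorch here; `moe_reference` is a made-up helper name, not part of ttnn or torch) that mirrors the same steps: mask the scores, take top-k, softmax the masked top-k values, and keep only the weight routed to expert 0:

```python
import numpy as np

def moe_reference(input_tensor, expert_mask, topk_mask, k):
    # Add the expert mask, then take the top-k values and indices along the last dim
    scores = input_tensor + expert_mask
    ind = np.argsort(-scores, axis=-1)[..., :k]
    val = np.take_along_axis(scores, ind, axis=-1)
    # Softmax over the top-k values after applying the top-k mask
    logits = val + topk_mask
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    softmax = e / e.sum(axis=-1, keepdims=True)
    # Keep only the softmax weight assigned to expert 0
    return np.sum(softmax * (ind == 0), axis=-1)
```

With zero masks and a clear maximum at position 0, the result is simply the softmax weight of that position among the top-k entries.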
Parameters:
- input_tensor (ttnn.Tensor) – input tensor for MoE.
- expert_mask_tensor (ttnn.Tensor) – mask tensor used to mask out invalid experts.
- topk_mask_tensor (ttnn.Tensor) – mask tensor used to mask out everything except the top-k values.
- k (int) – the number of top elements to look for.

Keyword Arguments:
- memory_config (ttnn.MemoryConfig, optional) – memory config of the output tensor.
- output_tensor (ttnn.Tensor, optional) – preallocated output tensor.

Returns:
    ttnn.Tensor – the output tensor.
Note
The input_tensor, expert_mask_tensor, and topk_mask_tensor must match the following data type and layout:

    dtype     layout
    BFLOAT16  TILE

The output tensor will be in BFLOAT16 data type and TILE layout.
Memory Support:
- Interleaved: DRAM and L1
Limitations:
- Tensors must be 4D with shape [N, C, H, W] and must be located on the device.
- For the input_tensor, N*C*H must be a multiple of 32, the last dimension (W) must be a power of two and ≥ 64, and k must be exactly 32.
- For the topk_mask_tensor, H must be 32 and W must match k (i.e. 32).
- For the expert_mask_tensor, H must be 32 and W must match the W of the input_tensor.
- All shape validations are performed on padded shapes.
- Sharding is not supported for this operation.
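The limitations above can be collected into a small pre-flight check. This is an illustrative sketch (`validate_moe_shapes` is not a ttnn API) that assumes it is handed the padded 4D shapes on which ttnn performs its validation:

```python
def validate_moe_shapes(input_shape, expert_mask_shape, topk_mask_shape, k):
    """Check the documented ttnn.moe constraints on (padded) 4D shapes."""
    for name, shape in [("input_tensor", input_shape),
                        ("expert_mask_tensor", expert_mask_shape),
                        ("topk_mask_tensor", topk_mask_shape)]:
        if len(shape) != 4:
            raise ValueError(f"{name} must be 4D [N, C, H, W], got {shape}")
    n, c, h, w = input_shape
    if (n * c * h) % 32 != 0:
        raise ValueError("input_tensor: N*C*H must be a multiple of 32")
    if w < 64 or (w & (w - 1)) != 0:
        raise ValueError("input_tensor: W must be a power of two and >= 64")
    if k != 32:
        raise ValueError("k must be exactly 32")
    if topk_mask_shape[2] != 32 or topk_mask_shape[3] != k:
        raise ValueError("topk_mask_tensor: H must be 32 and W must equal k")
    if expert_mask_shape[2] != 32 or expert_mask_shape[3] != w:
        raise ValueError("expert_mask_tensor: H must be 32 and W must match input W")
```

Running this before calling ttnn.moe turns an on-device shape failure into an early, descriptive Python error.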
Example
N, C, H, W = 1, 1, 32, 64
k = 32
input_tensor = ttnn.rand([N, C, H, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
expert_mask = ttnn.zeros([N, C, 1, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
topE_mask = ttnn.zeros([N, C, 1, k], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tensor_output = ttnn.moe(input_tensor, expert_mask, topE_mask, k)
logger.info(f"MOE result: {tensor_output}")