ttnn.moe

ttnn.moe(input_tensor, expert_mask_tensor, topk_mask_tensor, k, *, memory_config=None, output_tensor=None) → ttnn.Tensor

Returns the weight of the zeroth MoE expert.

Note

This is equivalent to the following PyTorch code:

val, ind = torch.topk(input_tensor + expert_mask_tensor, k)
return torch.sum(torch.softmax(val + topk_mask_tensor, dim=-1) * (ind == 0), dim=-1)
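For readers without PyTorch at hand, the same reduction can be sketched in NumPy. This is a minimal illustration of the math only, not the device implementation; the function name is illustrative:

```python
import numpy as np

def moe_reference(input_tensor, expert_mask, topk_mask, k):
    # Add the expert mask, then take the top-k values and their indices
    # along the last dimension (mirrors torch.topk, descending order).
    scores = input_tensor + expert_mask
    ind = np.argsort(scores, axis=-1)[..., ::-1][..., :k]
    val = np.take_along_axis(scores, ind, axis=-1)
    # Softmax over the masked top-k values.
    masked = val + topk_mask
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    soft = e / e.sum(axis=-1, keepdims=True)
    # Keep only the softmax weight at positions where the winning
    # index is expert 0, and sum it out.
    return np.sum(soft * (ind == 0), axis=-1)
```

With zero masks, this returns the softmax weight that the top-k distribution assigns to expert 0 (or 0.0 if expert 0 did not make the top k).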

:param input_tensor: the input tensor for MoE.
:param expert_mask_tensor: the expert mask tensor, used to mask out invalid experts.
:param topk_mask_tensor: the top-k mask tensor, used to mask out everything except the top k values.
:param k: the number of top elements to look for.

:keyword memory_config: memory config of the output tensor.
:keyword output_tensor: preallocated output tensor.
:kwtype output_tensor: Optional[ttnn.Tensor]

Returns:

ttnn.Tensor – the output tensor.

Note

The input_tensor, expert_mask_tensor, and topk_mask_tensor must match the following data type and layout:

dtype       layout
--------    ------
BFLOAT16    TILE

The output tensor will be in TILE layout and BFLOAT16.

Memory Support:
  • Interleaved: DRAM and L1

Limitations:
  • Tensors must be 4D with shape [N, C, H, W], and must be located on the device.

  • For the input_tensor, N*C*H must be a multiple of 32. The last dimension must be a power of two and ≥64.

  • k must be exactly 32.

  • For the topk_mask_tensor, H must be 32 and W must match k (i.e. 32).

  • For the expert_mask_tensor, H must be 32 and W must match W of the input_tensor.

  • All of the shape validations are performed on padded shapes.

  • Sharding is not supported for this operation.

Example

import ttnn
from loguru import logger

device = ttnn.open_device(device_id=0)

N, C, H, W = 1, 1, 32, 64
k = 32

input_tensor = ttnn.rand([N, C, H, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
expert_mask = ttnn.zeros([N, C, 1, W], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
topE_mask = ttnn.zeros([N, C, 1, k], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

tensor_output = ttnn.moe(input_tensor, expert_mask, topE_mask, k)
logger.info(f"MOE result: {tensor_output}")