ttnn.embedding
- ttnn.embedding(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig | None = input tensor memory config, output_tensor: ttnn.Tensor | None = None, queue_id: int | None = 0, padding_idx: int | None = None, layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, embeddings_type: ttnn.EmbeddingsType = ttnn._ttnn.operations.embedding.EmbeddingsType.GENERIC, dtype: ttnn.DataType | None = None) → ttnn.Tensor
Retrieves word embeddings for the indices in input_tensor. Each value in input_tensor is an index into the embedding matrix (weight), and the output contains the corresponding rows of weight, i.e. the word embeddings for those indices.
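The lookup itself is a row gather: each index selects one row of the embedding matrix, so the result has the shape of the index tensor with the embedding dimension appended. The sketch below is a host-side NumPy reference under that assumption, not the TT-NN kernel; the helper name embedding_reference is hypothetical, and it ignores padding_idx, layout and dtype handling.

```python
import numpy as np


def embedding_reference(indices: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical host-side reference for an embedding lookup.

    Each integer in `indices` selects one row of `weight`, so the result has
    shape indices.shape + (embedding_dim,). padding_idx is not modelled here.
    """
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError("indices must be an integer array")
    num_embeddings = weight.shape[0]
    if indices.min() < 0 or indices.max() >= num_embeddings:
        raise IndexError("every index must lie in [0, num_embeddings)")
    # Integer-array indexing gathers one row of `weight` per index.
    return weight[indices]


# Same index pattern and weight shape as the example on this page.
indices = np.array([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=np.uint32)
weight = np.random.rand(10, 4).astype(np.float32)
out = embedding_reference(indices, weight)
print(out.shape)  # (2, 4, 4)
```

With a [2, 4] index tensor and a [10, 4] embedding matrix, the output is [2, 4, 4]: one 4-element embedding per index.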
- Parameters:
input_tensor (ttnn.Tensor) – the tensor of indices to look up; each value selects one row of weight.
weight (ttnn.Tensor) – the embedding matrix, with one row (embedding vector) per possible index value.
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
queue_id (int, optional) – Command queue ID. Defaults to 0.
padding_idx (int, optional) – the index of the padding token. Defaults to None.
layout (ttnn.Layout) – the layout of the output tensor. Defaults to ttnn.ROW_MAJOR_LAYOUT.
embeddings_type (ttnn.EmbeddingsType) – the type of embeddings. Defaults to ttnn._ttnn.operations.embedding.EmbeddingsType.GENERIC.
dtype (ttnn.DataType, optional) – the data type for the output tensor. Defaults to None.
- Returns:
ttnn.Tensor – the output tensor, with the layout given by layout, or otherwise the layout of the weight tensor.
Example
>>> import torch
>>> import ttnn
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device=device)
>>> # an embedding matrix with 10 rows, each an embedding vector of size 4
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device=device)
>>> output = ttnn.embedding(tensor, weight)
>>> output
ttnn.Tensor([[[1, 0.106445, 0.988281, 0.59375],
              [0.212891, 0.964844, 0.199219, 0.996094],
              [3.78362e-38, 0, 7.89785e-39, 0],
              [8.04479e-38, 0, 1.25815e-38, 0]],
             [[2.71833e-38, 0, 3.59995e-38, 0],
              [7.60398e-38, 0, 1.83671e-38, 0],
              [2.22242e-38, 0, 1.88263e-38, 0],
              [1.35917e-38, 0, 4.49994e-39, 0]]], dtype=bfloat16)
>>> ttnn.close_device(device)
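One way to sanity-check a device result such as the one in the example is to recompute the lookup on the host with PyTorch's torch.nn.functional.embedding and compare using a tolerance loose enough for bfloat16 weights. This is a minimal sketch under assumptions: the helper names expected_embedding and matches_reference and the tolerance values are illustrative, and the on-device comparison assumes ttnn.to_torch is used to copy the result back and that the original torch weight was kept.

```python
import torch
import torch.nn.functional as F


def expected_embedding(indices: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # F.embedding gathers weight[indices]; it expects int32/int64 indices.
    return F.embedding(indices.to(torch.int64), weight)


def matches_reference(actual: torch.Tensor, indices: torch.Tensor, weight: torch.Tensor,
                      atol: float = 1e-2, rtol: float = 1e-2) -> bool:
    # Hypothetical helper: compare in float32 with a loose tolerance, because
    # the example stores the embedding matrix as bfloat16.
    expected = expected_embedding(indices, weight)
    return actual.shape == expected.shape and torch.allclose(
        actual.to(torch.float32), expected.to(torch.float32), atol=atol, rtol=rtol)


indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
weight = torch.rand(10, 4)
reference = expected_embedding(indices, weight)

# With a device, the check would be (assumes ttnn.to_torch and the
# host-side `weight` used to build the ttnn weight tensor):
#     matches_reference(ttnn.to_torch(output), indices, weight)
# Here a bfloat16 round trip of the reference stands in for the device result.
print(matches_reference(reference.to(torch.bfloat16), indices, weight))  # True
```

Rounding to bfloat16 keeps roughly 8 significant bits, so values in [0, 1) change by less than 0.004, which stays well inside the illustrative tolerance; a result that does not match the gathered rows fails the check.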