ttnn.embedding

ttnn.embedding(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig | None = input tensor memory config, output_tensor: ttnn.Tensor | None = None, queue_id: int | None = 0, padding_idx: int | None = None, layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, embeddings_type: ttnn.EmbeddingsType = ttnn._ttnn.operations.embedding.EmbeddingsType.GENERIC, dtype: ttnn.DataType | None = None) → ttnn.Tensor

Retrieves word embeddings using input_tensor. The input_tensor holds indices into the embedding matrix weight, and the output contains, for each index, the corresponding row of weight (its word embedding).
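The lookup described above is a row gather: each index selects one row of weight. The following is a minimal pure-Python sketch of those semantics for a 2-D batch of indices, not the ttnn implementation; the function name embedding_reference is hypothetical, and the IndexError check is a choice of this sketch only (the device kernel's handling of out-of-range indices is not described here).

```python
from typing import List, Sequence


def embedding_reference(
    indices: Sequence[Sequence[int]],
    weight: Sequence[Sequence[float]],
) -> List[List[List[float]]]:
    """Hypothetical reference: out[b][s] == weight[indices[b][s]]."""
    num_embeddings = len(weight)
    out: List[List[List[float]]] = []
    for row in indices:
        out_row: List[List[float]] = []
        for idx in row:
            # Sketch-only bounds check; not a statement about ttnn behavior.
            if not 0 <= idx < num_embeddings:
                raise IndexError(f"index {idx} out of range for {num_embeddings} embeddings")
            # Copy the selected row so the result does not alias weight.
            out_row.append(list(weight[idx]))
        out.append(out_row)
    return out


# 10 embedding vectors of size 2; row i is [i, i + 0.5] so lookups are easy to check.
weight = [[float(i), float(i) + 0.5] for i in range(10)]
out = embedding_reference([[1, 2, 4, 5], [4, 3, 2, 9]], weight)
print(out[0][0])  # [1.0, 1.5]
print(out[1][3])  # [9.0, 9.5]
```

The result keeps the index tensor's shape and appends one trailing dimension of size equal to the embedding width.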

Parameters:
  • input_tensor (ttnn.Tensor) – the input indices tensor.

  • weight (ttnn.Tensor) – the embedding matrix; row i is the embedding vector for index i.

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.

  • output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.

  • queue_id (int, optional) – command queue id. Defaults to 0.

  • padding_idx (int, optional) – the index of the padding token. Defaults to None.

  • layout (ttnn.Layout, optional) – the layout of the output tensor. Defaults to ttnn.ROW_MAJOR_LAYOUT.

  • embeddings_type (ttnn.EmbeddingsType, optional) – the type of embeddings. Defaults to ttnn._ttnn.operations.embedding.EmbeddingsType.GENERIC.

  • dtype (ttnn.DataType, optional) – the data type for the output tensor. Defaults to None.
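The exact effect of padding_idx is not spelled out above. The sketch below assumes, as an illustration only, that positions whose index equals padding_idx produce an all-zero vector instead of the stored row; verify this against the ttnn implementation before relying on it. The function name embedding_with_padding is hypothetical.

```python
from typing import List, Optional, Sequence


def embedding_with_padding(
    indices: Sequence[Sequence[int]],
    weight: Sequence[Sequence[float]],
    padding_idx: Optional[int] = None,
) -> List[List[List[float]]]:
    """Hypothetical lookup where padding positions are zero-filled (an assumption)."""
    dim = len(weight[0])
    out: List[List[List[float]]] = []
    for row in indices:
        out_row: List[List[float]] = []
        for idx in row:
            if padding_idx is not None and idx == padding_idx:
                # ASSUMPTION: padding positions yield zeros rather than weight[padding_idx].
                out_row.append([0.0] * dim)
            else:
                out_row.append(list(weight[idx]))
        out.append(out_row)
    return out


# Row i of weight is [i, i, i], so a zero row marks a padded position.
weight = [[float(i)] * 3 for i in range(5)]
print(embedding_with_padding([[1, 2, 3]], weight, padding_idx=2))
# [[[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [3.0, 3.0, 3.0]]]
```

With padding_idx left as None, the function reduces to a plain row gather.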

Returns:

ttnn.Tensor – the output tensor, in the layout given by layout or, otherwise, in the layout of the weight tensor.

Example

>>> import torch
>>> import ttnn
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> # the indices are stored as a uint32 tensor on the device
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device=device)
>>> # an embedding matrix containing 10 embedding vectors of size 4
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device=device)
>>> output = ttnn.embedding(tensor, weight)
>>> # output holds one 4-element row of weight per index; the values vary between runs because weight comes from torch.rand
>>> ttnn.close_device(device)