ttnn.embedding
- ttnn.embedding = Operation(python_fully_qualified_name='ttnn.embedding', function=<ttnn._ttnn.operations.embedding.embedding_t object>, preprocess_golden_function_inputs=<function default_preprocess_golden_function_inputs>, golden_function=<function _golden_function>, postprocess_golden_function_outputs=<function default_postprocess_golden_function_outputs>, is_cpp_operation=True, is_experimental=False)
Retrieves word embeddings for the indices in input_tensor. input_tensor is a tensor of indices into the embedding matrix weight, and the output contains the corresponding word embeddings, that is, the row of weight selected by each index.
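For intuition, the lookup can be sketched on the CPU with plain Python lists. This is a minimal sketch that assumes the GENERIC semantics match torch.nn.functional.embedding, where each index is replaced by the matching row of weight. embedding_reference is a hypothetical helper, not part of ttnn, and it ignores padding_idx, layouts, and memory configs.

```python
def embedding_reference(indices, weight):
    """Replace every integer in a nested list of indices with the matching row of weight.

    Hypothetical CPU sketch of the GENERIC lookup; not the ttnn device kernel.
    """
    if isinstance(indices, int):
        # Valid indices run from 0 to len(weight) - 1.
        if not 0 <= indices < len(weight):
            raise IndexError(f"index {indices} is out of range for {len(weight)} rows")
        return list(weight[indices])  # copy of the selected embedding row
    # Recurse so that index tensors of any rank are handled.
    return [embedding_reference(i, weight) for i in indices]


# Toy 10 x 4 embedding matrix: row k is [10k, 10k + 1, 10k + 2, 10k + 3].
weight = [[10 * k + j for j in range(4)] for k in range(10)]
indices = [[1, 2, 4, 5], [4, 3, 2, 9]]  # same indices as in the example

out = embedding_reference(indices, weight)
print(out[0][0])  # [10, 11, 12, 13]  (row 1 of weight)
print(out[1][3])  # [90, 91, 92, 93]  (row 9 of weight)
```

Because the lookup is applied to every element of the index tensor, a 2 x 4 index tensor paired with a 10 x 4 weight yields a 2 x 4 x 4 result.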
- Parameters:
input_tensor (ttnn.Tensor) – the input indices tensor.
weight (ttnn.Tensor) – the embedding matrix; row i holds the embedding vector for index i.
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to the memory config of the input tensor.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
queue_id (int, optional) – Command queue ID. Defaults to 0.
padding_idx (int, optional) – the index of the padding token. Defaults to None.
layout (ttnn.Layout) – the layout of the output tensor. Defaults to ttnn.ROW_MAJOR_LAYOUT.
embeddings_type (ttnn.EmbeddingsType) – the type of embeddings. Defaults to ttnn._ttnn.operations.embedding.EmbeddingsType.GENERIC.
dtype (ttnn.DataType, optional) – the data type for the output tensor. Defaults to None.
- Returns:
ttnn.Tensor – the output tensor, in the layout given by layout, or otherwise in the layout of the weight tensor.
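The shape of the returned tensor follows from the lookup: each index becomes one embedding row. The sketch below assumes the torch.nn.functional.embedding rule (index shape followed by the embedding width), which is consistent with the example, where a 2 x 4 index tensor and a 10 x 4 weight print as a 2 x 4 x 4 tensor. embedding_output_shape is a hypothetical helper, not a ttnn function.

```python
def embedding_output_shape(input_shape, weight_shape):
    """Return input_shape followed by the embedding width.

    Assumes weight_shape is (num_embeddings, embedding_dim).
    """
    if len(weight_shape) != 2:
        raise ValueError("weight is expected to be a 2-D (num_embeddings, embedding_dim) matrix")
    # Every index in the input is replaced by a row of length embedding_dim.
    return tuple(input_shape) + (weight_shape[1],)


print(embedding_output_shape((2, 4), (10, 4)))  # (2, 4, 4)
```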
Example
>>> import torch
>>> import ttnn
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> # the indices are converted to uint32 before being moved to the device
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device=device)
>>> # an embedding matrix containing 10 embedding vectors of size 4
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device=device)
>>> output = ttnn.embedding(tensor, weight)
>>> # weight comes from torch.rand, so the printed values differ between runs
>>> print(output)
ttnn.Tensor([ [[1, 0.106445, 0.988281, 0.59375],
               [0.212891, 0.964844, 0.199219, 0.996094],
               [3.78362e-38, 0, 7.89785e-39, 0],
               [8.04479e-38, 0, 1.25815e-38, 0]],
              [[2.71833e-38, 0, 3.59995e-38, 0],
               [7.60398e-38, 0, 1.83671e-38, 0],
               [2.22242e-38, 0, 1.88263e-38, 0],
               [1.35917e-38, 0, 4.49994e-39, 0]]], dtype=bfloat16)
>>> ttnn.close_device(device)