ttnn.embedding_bw
- ttnn.embedding_bw(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, output_gradient_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig | None = input tensor memory config, output_tensor: ttnn.Tensor | None = None, queue_id: int | None = 0, dtype: ttnn.DataType | None = None) → ttnn.Tensor
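Returns the gradient with respect to the embedding weights. Each row of the output gradient tensor is accumulated into the weight-gradient row selected by the corresponding index in the input tensor, so indices that appear more than once have their gradient rows summed.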
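To make the semantics concrete, here is a minimal CPU sketch in plain Python. It assumes the standard embedding-backward definition (a scatter-add of gradient rows into the rows named by the indices) and works on flat lists rather than ttnn tensors. `embedding_backward_reference` is a hypothetical helper for illustration, not part of ttnn. The result has one row per vocabulary entry, which is why only the vocabulary size of `weight` matters.

```python
from typing import List, Sequence


def embedding_backward_reference(
    indices: Sequence[int],
    grad_rows: Sequence[Sequence[float]],
    num_embeddings: int,
) -> List[List[float]]:
    """Hypothetical reference (not a ttnn API): scatter-add each gradient row
    into the weight-gradient row selected by its index."""
    if len(indices) != len(grad_rows):
        raise ValueError("need exactly one gradient row per index")
    embedding_dim = len(grad_rows[0]) if grad_rows else 0
    # Vocabulary rows that no index refers to keep a zero gradient.
    grad_weight = [[0.0] * embedding_dim for _ in range(num_embeddings)]
    for idx, row in zip(indices, grad_rows):
        if not 0 <= idx < num_embeddings:
            raise IndexError(f"index {idx} out of range for {num_embeddings} embeddings")
        target = grad_weight[idx]
        for j, g in enumerate(row):
            # Repeated indices accumulate instead of overwriting.
            target[j] += g
    return grad_weight


# Indices from a flattened (batch, seq) tensor; index 2 appears twice.
indices = [2, 0, 2]
grad_rows = [[1.0, 2.0], [0.5, 0.5], [3.0, -1.0]]
print(embedding_backward_reference(indices, grad_rows, num_embeddings=4))
# [[0.5, 0.5], [0.0, 0.0], [4.0, 1.0], [0.0, 0.0]]
```

Row 2 receives the sum of the first and third gradient rows, while rows 1 and 3 stay zero because no index points at them.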
- Parameters:
input_tensor (ttnn.Tensor) – the input indices tensor.
weight (ttnn.Tensor) – the embedding weights tensor indexed by input_tensor. Only its vocabulary size (number of embeddings) is used.
output_gradient_tensor (ttnn.Tensor) – the output gradient tensor from the previous backward op.
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
queue_id (int, optional) – Command queue ID. Defaults to 0.
dtype (ttnn.DataType, optional) – the data type for the output tensor. Defaults to None.
- Returns:
ttnn.Tensor – the output tensor.
Note
The input and the output gradient tensors must have the same datatype.
Example
>>> import torch
>>> import ttnn
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> batch_size, seq_len, embedding_dim, num_embeddings = 2, 1024, 4096, 3200
>>> input_shape = (batch_size, seq_len)
>>> input_index = torch.randint(0, num_embeddings, input_shape)
>>> input_tensor = ttnn.from_torch(input_index, dtype=ttnn.uint32, device=device)
>>> weights_shape = (num_embeddings, embedding_dim)
>>> weights = torch.randn(weights_shape, requires_grad=True)
>>> weights_ttnn = ttnn.from_torch(weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
>>> grad_shape = (1, 1, batch_size * seq_len, embedding_dim)
>>> grad_data = torch.randn(grad_shape, requires_grad=True)
>>> grad_tensor = ttnn.from_torch(grad_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
>>> output = ttnn.embedding_bw(input_tensor, weights_ttnn, grad_tensor, dtype=ttnn.bfloat16)
>>> ttnn.close_device(device)
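To sanity-check a result like the one above, one option is to compute the same weight gradient on the CPU with PyTorch autograd. The sketch below assumes the gradient layout used in the example (one row per flattened index, shaped (1, 1, batch_size * seq_len, embedding_dim)). `torch_embedding_bw` is a hypothetical helper, not a ttnn or PyTorch API; it runs `torch.nn.functional.embedding` forward and backward and returns the dense weight gradient.

```python
import torch
import torch.nn.functional as F


def torch_embedding_bw(indices: torch.Tensor, weight: torch.Tensor,
                       grad: torch.Tensor) -> torch.Tensor:
    """Hypothetical checking helper: weight gradient via PyTorch autograd."""
    # Fresh leaf tensor so the caller's weight and its .grad are untouched.
    w = weight.detach().clone().requires_grad_(True)
    # One embedding row per flattened index: shape (num_indices, embedding_dim).
    out = F.embedding(indices.reshape(-1), w)
    # The example's gradient is (1, 1, batch*seq, dim); view it as (num_indices, dim).
    out.backward(grad.reshape(out.shape))
    return w.grad


# Small CPU-only example; index 2 appears twice in the flattened indices.
torch.manual_seed(0)
indices = torch.tensor([[2, 0], [2, 3]])
weight = torch.randn(4, 3)
grad = torch.randn(1, 1, 4, 3)
ref = torch_embedding_bw(indices, weight, grad)

# The same result via an explicit scatter-add: index_add_ sums rows sharing an index.
manual = torch.zeros_like(weight).index_add_(0, indices.reshape(-1), grad.reshape(-1, 3))
assert torch.allclose(ref, manual)
```

To compare against the device output, one could convert it with `ttnn.to_torch(output)`, reshape it to (num_embeddings, embedding_dim) on the assumption that it holds one row per vocabulary entry, and compare with a tolerance loose enough for bfloat16.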