ttnn.embedding_bw
- ttnn.embedding_bw(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, output_gradient_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = input tensor memory config, output_tensor: ttnn.Tensor = None, dtype: ttnn.DataType = None) → ttnn.Tensor
Returns the gradients of the embedding weights: the rows of the output gradient tensor are accumulated into a weight-shaped gradient tensor according to the input indices.
- Parameters:
input_tensor (ttnn.Tensor) – the input indices tensor.
weight (ttnn.Tensor) – the embedding weights tensor corresponding to the indices tensor. It is only used to determine the vocabulary size and the shape of the output gradient tensor.
output_gradient_tensor (ttnn.Tensor) – the output gradient tensor from the previous backward op.
- Keyword Arguments:
memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.
output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
dtype (ttnn.DataType, optional) – Data type for the output tensor. Defaults to None.
- Returns:
ttnn.Tensor – the weight gradient tensor.
Note
The input and output gradient tensors must have the same data type.
Example
import torch
import ttnn
from loguru import logger

device = ttnn.open_device(device_id=0)

# Create input, weights, and gradient tensors for embedding backward
batch_size, seq_len, embedding_dim, num_embeddings = 2, 1024, 4096, 3200
input_shape = (batch_size, seq_len)
input_index = torch.randint(0, num_embeddings, input_shape)
input_tensor = ttnn.from_torch(input_index, dtype=ttnn.uint32, device=device)

# Create weights tensor
weights_shape = (num_embeddings, embedding_dim)
weights = torch.randn(weights_shape, requires_grad=True)
weights_ttnn = ttnn.from_torch(weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Create gradient tensor
grad_shape = (1, 1, batch_size * seq_len, embedding_dim)
grad_data = torch.randn(grad_shape, requires_grad=True)
grad_tensor = ttnn.from_torch(grad_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Compute gradients for embedding operation
output = ttnn.embedding_bw(input_tensor, weights_ttnn, grad_tensor, dtype=ttnn.bfloat16)
logger.info(f"Embedding backward result: {output}")
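For intuition, the computation performed by the example above can be sketched in plain Python (this is a reference sketch of embedding-backward semantics, not the TT-NN implementation, and `embedding_bw_reference` is a hypothetical helper name): each row of the output gradient is scatter-added into the weight-gradient row selected by the corresponding input index.

```python
def embedding_bw_reference(indices, output_grad, num_embeddings, embedding_dim):
    # Start from a zero gradient table with the same shape as the weights.
    weight_grad = [[0.0] * embedding_dim for _ in range(num_embeddings)]
    # Scatter-add: repeated indices accumulate their gradient rows.
    for idx, grad_row in zip(indices, output_grad):
        for d in range(embedding_dim):
            weight_grad[idx][d] += grad_row[d]
    return weight_grad

# Index 1 is looked up twice, so its two gradient rows accumulate into row 1.
grads = embedding_bw_reference(
    [1, 1, 0],
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    num_embeddings=3,
    embedding_dim=2,
)
```

Accumulation over repeated indices is why the weights tensor is only needed for its shape: the backward pass never reads the weight values themselves.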