ttnn.embedding_bw

ttnn.embedding_bw(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, output_gradient_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = input tensor memory config, output_tensor: ttnn.Tensor = None, dtype: ttnn.DataType = None) → ttnn.Tensor

Returns the input gradients for the embedding operation, computed from the output gradient tensor and the input indices.

Parameters:
  • input_tensor (ttnn.Tensor) – The input indices tensor.

  • weight (ttnn.Tensor) – The embeddings tensor that corresponds to the indices tensor. It is used only to extract the vocabulary size.

  • output_gradient_tensor (ttnn.Tensor) – The output gradient tensor from the previous backward op.

Keyword Arguments:
  • memory_config (ttnn.MemoryConfig, optional) – Memory configuration for the operation. Defaults to input tensor memory config.

  • output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.

  • dtype (ttnn.DataType, optional) – The data type of the output tensor. Defaults to None.

Returns:

ttnn.Tensor – the output tensor.

Note

The input gradient tensor (the returned result) and the output gradient tensor must have the same data type.
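
To illustrate what an embedding backward pass computes, here is a minimal pure-Python sketch. It assumes the usual embedding-backward semantics (as in PyTorch's embedding gradient), where the result has one row per vocabulary entry and each output gradient row is added into the row selected by its index. The helper name embedding_backward_reference is hypothetical and not part of ttnn, and the exact shape and layout of the tensor that ttnn.embedding_bw returns are not confirmed by this page.

```python
def embedding_backward_reference(indices, vocab_size, grad_rows):
    """Scatter-add each gradient row into the row selected by its index.

    indices    : flat list of token ids, one per gradient row
    vocab_size : number of rows in the embedding table (weight.shape[0])
    grad_rows  : list of gradient rows, each of length embedding_dim
    """
    embedding_dim = len(grad_rows[0])
    # Start from zeros: vocabulary entries that never appear get no gradient.
    grad_weight = [[0.0] * embedding_dim for _ in range(vocab_size)]
    for token_id, grad_row in zip(indices, grad_rows):
        for col, value in enumerate(grad_row):
            # Repeated token ids accumulate rather than overwrite.
            grad_weight[token_id][col] += value
    return grad_weight


# Tiny example: 4 positions, vocabulary of 3, embedding_dim of 2.
indices = [0, 2, 0, 1]
grad_rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
grad_weight = embedding_backward_reference(indices, 3, grad_rows)
print(grad_weight)  # [[6.0, 8.0], [7.0, 8.0], [3.0, 4.0]]
```

Token id 0 appears twice, so its row is the sum of the first and third gradient rows. Under these assumptions, this is also why only the vocabulary size is needed from the weight tensor: the values of the embeddings do not enter the computation.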

Example

import torch
import ttnn
from loguru import logger  # assumption: the logger used below is loguru's

# Open a device (device_id=0 is an assumption; adjust for your system)
device = ttnn.open_device(device_id=0)

# Create the input indices tensor for embedding backward
batch_size, seq_len, embedding_dim, num_embeddings = 2, 1024, 4096, 3200
input_shape = (batch_size, seq_len)
input_index = torch.randint(0, num_embeddings, input_shape)
input_tensor = ttnn.from_torch(input_index, dtype=ttnn.uint32, device=device)

# Create the weights tensor (only its vocabulary size is used by the op)
weights_shape = (num_embeddings, embedding_dim)
weights = torch.randn(weights_shape, requires_grad=True)
weights_ttnn = ttnn.from_torch(weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Create the output gradient tensor, with one row per input index
grad_shape = (1, 1, batch_size * seq_len, embedding_dim)
grad_data = torch.randn(grad_shape, requires_grad=True)
grad_tensor = ttnn.from_torch(grad_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Compute gradients for the embedding operation
output = ttnn.embedding_bw(input_tensor, weights_ttnn, grad_tensor, dtype=ttnn.bfloat16)
logger.info(f"Embedding backward result: {output}")

# Release the device when done
ttnn.close_device(device)