ttnn.embedding_bw

ttnn.embedding_bw = FastOperation(python_fully_qualified_name='ttnn.embedding_bw', function=<ttnn._ttnn.operations.embedding_backward.embedding_bw_t object>, preprocess_golden_function_inputs=None, golden_function=None, postprocess_golden_function_outputs=None, is_cpp_operation=True, is_experimental=False)

Computes the backward pass of the embedding lookup: the output gradient tensor is propagated back through the input indices to produce the gradient for the embedding table.

Args:

input_tensor (ttnn.Tensor): the input indices tensor.
weight (ttnn.Tensor): the embeddings tensor that corresponds to the indices tensor. This tensor is only used to extract the vocabulary size.
output_gradient_tensor (ttnn.Tensor): the output gradient tensor from the previous backwards op.

Keyword args:

memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to input tensor memory config.
output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to None.
queue_id (int, optional): command queue id. Defaults to 0.
dtype (ttnn.DataType, optional): the data type for the output tensor. Defaults to None.

Returns:

ttnn.Tensor: the output tensor.

Note:

The input and the output gradient tensors must have the same datatype.

Example:
>>> import torch
>>> import ttnn
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> batch_size, seq_len, embedding_dim, num_embeddings = 2, 1024, 4096, 3200
>>> # Indices into the embedding table, one per token.
>>> input_shape = (batch_size, seq_len)
>>> input_index = torch.randint(0, num_embeddings, input_shape)
>>> input_tensor = ttnn.from_torch(input_index, dtype=ttnn.uint32, device=device)
>>> # Embedding table; only its vocabulary size is used by the op.
>>> weights_shape = (num_embeddings, embedding_dim)
>>> weights = torch.randn(weights_shape, requires_grad=True)
>>> weights_ttnn = ttnn.from_torch(weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
>>> # Output gradient, flattened to one row per (batch, position) pair.
>>> grad_shape = (1, 1, batch_size * seq_len, embedding_dim)
>>> grad_data = torch.randn(grad_shape, requires_grad=True)
>>> grad_tensor = ttnn.from_torch(grad_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
>>> output = ttnn.embedding_bw(input_tensor, weights_ttnn, grad_tensor, dtype=ttnn.bfloat16)
>>> ttnn.close_device(device)
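Reference sketch:

The arithmetic behind an embedding backward pass can be illustrated on the host in plain Python. This is a minimal sketch, not the ttnn kernel: it assumes the operation follows the usual embedding-backward semantics (each gradient row is added into the row selected by its index, and repeated indices accumulate), which the description above implies but does not state outright. The function name `embedding_backward_reference` and its list-based arguments are hypothetical and exist only for this illustration.

```python
def embedding_backward_reference(indices, grad_rows, num_embeddings):
    """Scatter-add each gradient row into the row of its embedding index.

    indices:        flat list of ints, one per looked-up token
    grad_rows:      list of gradient rows; grad_rows[i] belongs to indices[i]
    num_embeddings: vocabulary size (assumed to be all that `weight` provides)
    """
    if len(indices) != len(grad_rows):
        raise ValueError("need exactly one gradient row per index")
    embedding_dim = len(grad_rows[0]) if grad_rows else 0
    # Rows for indices that never occur stay at zero.
    weight_grad = [[0.0] * embedding_dim for _ in range(num_embeddings)]
    for idx, row in zip(indices, grad_rows):
        if not 0 <= idx < num_embeddings:
            raise IndexError(f"index {idx} is outside the vocabulary")
        for col, value in enumerate(row):
            weight_grad[idx][col] += value
    return weight_grad


# Index 2 appears twice, so its two gradient rows are summed.
indices = [2, 0, 2]
grad_rows = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
weight_grad = embedding_backward_reference(indices, grad_rows, num_embeddings=4)
```

Under these assumptions the result has one row per vocabulary entry, which is consistent with `weight` being needed only for its vocabulary size. A host-side reference like this can serve as a rough check on device output, allowing for bfloat16 rounding.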