ttnn.embedding_bw
Computes the backward pass of an embedding lookup: returns the gradient with respect to the embedding weights, accumulated from the output gradient tensor according to the input indices.
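The accumulation can be sketched with a plain-PyTorch reference (a hypothetical helper, not part of ttnn): each row of the output gradient is scatter-added into the weight-gradient row selected by the corresponding index, so repeated indices accumulate their contributions.

```python
import torch

def embedding_bw_reference(indices, num_embeddings, output_grad):
    """Reference semantics of an embedding backward pass:
    scatter-add each row of output_grad into the weight-gradient
    row selected by the corresponding input index."""
    embedding_dim = output_grad.shape[-1]
    weight_grad = torch.zeros(num_embeddings, embedding_dim)
    flat_idx = indices.reshape(-1)
    flat_grad = output_grad.reshape(-1, embedding_dim)
    # Rows with the same index accumulate into the same weight-grad row.
    weight_grad.index_add_(0, flat_idx, flat_grad)
    return weight_grad

# Tiny example: indices 0 and 0 share a row, so their gradients sum.
idx = torch.tensor([0, 2, 0])
grad = torch.ones(3, 4)
wg = embedding_bw_reference(idx, 3, grad)
# Row 0 receives two contributions, row 2 one, row 1 none.
```

This is only an illustration of the math; the ttnn op performs the same accumulation on device and, per the Args below, uses the weight tensor only to determine the vocabulary size.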
Args:

    input_tensor (ttnn.Tensor): the input indices tensor.
    weight (ttnn.Tensor): the embedding weights tensor that corresponds to the indices tensor. This tensor is only used to extract the vocabulary size.
    output_gradient_tensor (ttnn.Tensor): the output gradient tensor from the previous backward op.
Keyword args:

    memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to the input tensor's memory config.
    output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to None.
    queue_id (int, optional): command queue id. Defaults to 0.
    dtype (ttnn.DataType, optional): the data type for the output tensor. Defaults to None.
Returns:

    ttnn.Tensor: the output tensor.
Note:

    The input and output gradient tensors must have the same datatype.
Example:

    >>> import torch
    >>> import ttnn
    >>> device_id = 0
    >>> device = ttnn.open_device(device_id=device_id)
    >>> batch_size, seq_len, embedding_dim, num_embeddings = 2, 1024, 4096, 3200
    >>> input_shape = (batch_size, seq_len)
    >>> input_index = torch.randint(0, num_embeddings, input_shape)
    >>> input_tensor = ttnn.from_torch(input_index, dtype=ttnn.uint32, device=device)
    >>> weights_shape = (num_embeddings, embedding_dim)
    >>> weights = torch.randn(weights_shape, requires_grad=True)
    >>> weights_ttnn = ttnn.from_torch(weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    >>> grad_shape = (1, 1, batch_size * seq_len, embedding_dim)
    >>> grad_data = torch.randn(grad_shape, requires_grad=True)
    >>> grad_tensor = ttnn.from_torch(grad_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    >>> output = ttnn.embedding_bw(input_tensor, weights_ttnn, grad_tensor, dtype=ttnn.bfloat16)