ttnn.to_torch
- ttnn.to_torch(tensor: ttnn.Tensor, dtype: torch.dtype | None = None, *, torch_rank: int | None = None, mesh_composer: ttnn.MeshToTensor | None = None, device: ttnn.Device | None = None, cq_id: int | None = 0) → torch.Tensor
Converts the ttnn.Tensor tensor into a torch.Tensor. to_layout is not called for bfloat8_b or bfloat4_b tensors, because the conversion to tile layout now happens during tensor.to_torch().
- Parameters:
tensor (ttnn.Tensor) – the input tensor.
dtype (torch.dtype, optional) – the desired torch data type of the returned tensor. Defaults to None.
- Keyword Arguments:
torch_rank (int, optional) – Desired rank of the returned torch.Tensor. Defaults to None. When set, dimensions are removed with torch.squeeze until the desired rank is reached; if that is not possible, an error is raised.
mesh_composer (ttnn.MeshToTensor, optional) – The desired ttnn mesh composer. Defaults to None.
device (ttnn.Device, optional) – The ttnn device of the input tensor. Defaults to None.
cq_id (int, optional) – The command queue ID to use. Defaults to 0.
- Returns:
torch.Tensor – The converted torch tensor.
Example
>>> import torch
>>> import ttnn
>>> ttnn_tensor = ttnn.from_torch(torch.randn((2, 3)), dtype=ttnn.bfloat16)
>>> torch_tensor = ttnn.to_torch(ttnn_tensor)
>>> print(torch_tensor)
tensor([[-0.3008, -0.8438,  0.3242],
        [ 0.9023, -0.5820,  0.5312]], dtype=torch.bfloat16)
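The torch_rank rule described under Keyword Arguments can be modeled on plain shape tuples. The following is a minimal sketch, not ttnn's implementation: squeeze_shape_to_rank is a hypothetical helper, and it assumes that dimensions are dropped from the front one at a time and that a RuntimeError is raised when a leading dimension is larger than 1. The documentation above only states that torch.squeeze is used and that an error is raised when the rank cannot be reached.

```python
def squeeze_shape_to_rank(shape, torch_rank):
    """Model the torch_rank rule on a shape tuple.

    Hypothetical helper for illustration only; it is not part of ttnn.
    """
    shape = tuple(shape)
    # Assumption: size-1 dimensions are dropped from the front, one at a time.
    while len(shape) > torch_rank:
        if shape[0] != 1:
            # A leading dimension larger than 1 cannot be squeezed away.
            raise RuntimeError(
                f"cannot squeeze shape {shape} to rank {torch_rank}"
            )
        shape = shape[1:]
    return shape


print(squeeze_shape_to_rank((1, 1, 2, 3), 2))  # (2, 3)
```

Under these assumptions, a tensor of shape (1, 1, 2, 3) converted with torch_rank=2 would come back with shape (2, 3), while a shape such as (2, 1, 3) could not be reduced to rank 2 and would produce an error.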