ttnn.to_torch

ttnn.to_torch(tensor: ttnn.Tensor, dtype: torch.dtype | None = None, *, torch_rank: int | None = None, mesh_composer: ttnn.MeshToTensor | None = None, device: ttnn.Device | None = None, cq_id: int | None = 0) -> torch.Tensor

Converts the ttnn.Tensor tensor into a torch.Tensor. For bfloat8_b and bfloat4_b tensors, it does not call to_layout, because the conversion to tile layout now happens inside tensor.to_torch().

Parameters:
  • tensor (ttnn.Tensor) – the input tensor.

  • dtype (torch.dtype, optional) – the desired torch data type of the returned tensor. Defaults to None.

Keyword Arguments:
  • torch_rank (int, optional) – Desired rank of the returned torch.Tensor. Defaults to None. Dimensions are removed with torch.squeeze until the desired rank is reached; if this is not possible, an error is raised.

  • mesh_composer (ttnn.MeshToTensor, optional) – The mesh composer used to combine the per-device shards of a multi-device tensor into a single torch.Tensor. Defaults to None.

  • device (ttnn.Device, optional) – The ttnn device of the input tensor. Defaults to None.

  • cq_id (int, optional) – The command queue ID to use. Defaults to 0.
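The torch_rank rule can be sketched on plain shape tuples. The helper squeeze_to_rank_shape below is hypothetical and not part of ttnn; it assumes, for illustration only, that dimensions are dropped from the front one at a time (as with tensor.squeeze(0)) and that a leading dimension other than 1 makes the target rank unreachable.

```python
def squeeze_to_rank_shape(shape, torch_rank):
    """Hypothetical helper mimicking the assumed torch_rank rule on a shape.

    Assumption: leading size-1 dimensions are removed one at a time, like
    tensor.squeeze(0), until the rank matches torch_rank. If a leading
    dimension is not 1 before that point, the rank cannot be reached.
    """
    shape = tuple(shape)
    if torch_rank is None:
        # No rank requested: the shape is returned unchanged.
        return shape
    while len(shape) > torch_rank:
        if shape[0] != 1:
            raise RuntimeError(f"cannot squeeze shape {shape} to rank {torch_rank}")
        shape = shape[1:]  # drop one leading size-1 dimension
    # This sketch never adds dimensions if torch_rank exceeds the current rank.
    return shape


# A 4-D shape [1, 1, 32, 64] reduced to a 2-D matrix shape.
print(squeeze_to_rank_shape((1, 1, 32, 64), 2))  # (32, 64)

# A leading dimension of size 2 cannot be squeezed away, so this raises.
try:
    squeeze_to_rank_shape((2, 32, 64), 2)
except RuntimeError as err:
    print("error:", err)
```

Checking only the leading dimension keeps the resulting shape predictable: the trailing dimensions, which hold the actual data layout, are never touched.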

Returns:

torch.Tensor – The converted torch tensor.

Example

>>> import torch
>>> import ttnn
>>> ttnn_tensor = ttnn.from_torch(torch.randn((2, 3)), dtype=ttnn.bfloat16)
>>> torch_tensor = ttnn.to_torch(ttnn_tensor)
>>> print(torch_tensor)
tensor([[-0.3008, -0.8438,  0.3242],
        [ 0.9023, -0.5820,  0.5312]], dtype=torch.bfloat16)