ttnn.from_torch

ttnn.from_torch(tensor: torch.Tensor | None, dtype: ttnn.DataType = None, *, spec: ttnn.TensorSpec = None, tile: ttnn.Tile = None, pad_value: float = None, layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, device: ttnn.MeshDevice = None, memory_config: ttnn.MemoryConfig = None, mesh_mapper: ttnn.TensorToMesh = None, cq_id: int = 0) → ttnn.Tensor | None

Converts the torch.Tensor tensor into a ttnn.Tensor. If tensor is None, the function returns None.

For the bfloat8_b and bfloat4_b formats, the function is called twice. The first call runs in bfloat16 format and calls to_layout to convert the tensor from row-major layout to tile layout, which pads the input if it is not already tile-aligned. The second call runs in the requested format and does not call to_layout, because for bfloat8_b and bfloat4_b the conversion to tile layout happens during tensor creation (ttnn.Tensor).
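The padding that the first (bfloat16) pass takes care of can be illustrated with plain Python. The sketch below assumes the default 32×32 tile (a ttnn.Tile can describe other tile shapes) and uses a hypothetical helper, tile_padded_shape, which is not part of ttnn. It only computes the shape that tile layout would require; it does not reproduce the library's conversion.

```python
# Hypothetical helper, NOT part of ttnn: computes the shape a tensor would be
# padded to for tile layout, assuming a 32x32 tile unless told otherwise.
def tile_padded_shape(shape, tile_height=32, tile_width=32):
    """Round the last two dimensions of `shape` up to multiples of the tile size."""
    if len(shape) < 2:
        # This sketch only handles tensors of rank >= 2.
        raise ValueError("tile layout sketch needs at least two dimensions")
    *batch, height, width = shape
    # Ceiling division, then scale back up to a whole number of tiles.
    padded_height = -(-height // tile_height) * tile_height
    padded_width = -(-width // tile_width) * tile_width
    return (*batch, padded_height, padded_width)


print(tile_padded_shape((4, 5)))          # (32, 32)
print(tile_padded_shape((2, 3, 64, 33)))  # (2, 3, 64, 64)
```

Under this assumption, the (4, 5) tensor from the example on this page would occupy a single padded 32×32 tile.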

Parameters:
  • tensor (torch.Tensor | None) – the input tensor. If tensor is None, the function returns None.

  • dtype (ttnn.DataType, optional) – the desired ttnn data type. Defaults to None.

Keyword Arguments:
  • spec (ttnn.TensorSpec, optional) – the desired ttnn tensor spec. Defaults to None.

  • tile (ttnn.Tile, optional) – the desired tiling configuration for the tensor. Defaults to None.

  • pad_value (float, optional) – the desired padding value for tiling. Only used if layout is ttnn.TILE_LAYOUT. Defaults to None.

  • layout (ttnn.Layout, optional) – the desired ttnn layout. Defaults to ttnn.ROW_MAJOR_LAYOUT.

  • device (ttnn.MeshDevice, optional) – the desired ttnn device. If None, the tensor is created on the host. Defaults to None.

  • memory_config (ttnn.MemoryConfig, optional) – the desired ttnn memory configuration. Defaults to None.

  • mesh_mapper (ttnn.TensorToMesh, optional) – the desired ttnn mesh mapper. Defaults to None.

  • cq_id (int, optional) – the command queue ID to use. Defaults to 0.
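To make pad_value concrete: when layout is TILE_LAYOUT and the input is not tile-aligned, the extra elements added by padding have to hold some value. The sketch below is a plain-Python illustration using a hypothetical pad_to_tile helper (not a ttnn function) on a row-major matrix stored as a list of rows. It always takes an explicit pad_value rather than modelling the None default, and the usage uses a 4×4 tile only to keep the output short.

```python
# Hypothetical helper, NOT part of ttnn: pads a row-major 2D matrix (a list of
# equal-length rows) so both dimensions are multiples of the tile size,
# filling every newly added element with pad_value.
def pad_to_tile(rows, pad_value, tile_height=32, tile_width=32):
    height = len(rows)
    width = len(rows[0]) if rows else 0
    # Round each dimension up to a whole number of tiles.
    padded_height = -(-height // tile_height) * tile_height
    padded_width = -(-width // tile_width) * tile_width
    # Extend each existing row on the right, then append full padding rows below.
    padded = [list(row) + [pad_value] * (padded_width - width) for row in rows]
    padded += [[pad_value] * padded_width for _ in range(padded_height - height)]
    return padded


matrix = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
for row in pad_to_tile(matrix, pad_value=0.0, tile_height=4, tile_width=4):
    print(row)
# [1.0, 2.0, 3.0, 0.0]
# [4.0, 5.0, 6.0, 0.0]
# [0.0, 0.0, 0.0, 0.0]
# [0.0, 0.0, 0.0, 0.0]
```

The choice of pad value matters whenever a later operation reads the padded region; for instance, a value of 0.0 leaves a sum unchanged, while it could distort a max over negative data.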

Returns:

ttnn.Tensor | None – A ttnn.Tensor created from the input torch.Tensor, or None if tensor is None.

Example

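import torch
import ttnn
from loguru import logger

# Open a device, create a Torch tensor, and convert it to a TT-NN tensor on that device
device = ttnn.open_device(device_id=0)
torch_tensor = torch.randn((4, 5), dtype=torch.bfloat16)
ttnn_tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, device=device)
logger.info(f"TT-NN tensor shape: {ttnn_tensor.shape}")
ttnn.close_device(device)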