ttnn.typecast
- ttnn.typecast(input_tensor: ttnn.Tensor, dtype: ttnn.DataType, *, memory_config: ttnn.MemoryConfig | None, output_tensor: ttnn.Tensor | None) → ttnn.Tensor
Casts the elements of a tensor, on the host or on the device, to the desired dtype.
- Parameters:
input_tensor (ttnn.Tensor) – input tensor to be typecast (can be on the host or device).
dtype (ttnn.DataType) – data type to cast the tensor elements to.
- Keyword Arguments:
memory_config (Optional[ttnn.MemoryConfig]) – Memory configuration for the operation.
output_tensor (Optional[ttnn.Tensor]) – Preallocated tensor to store the output.
- Returns:
ttnn.Tensor – The tensor with the updated data type. The output tensor will be in the same layout as the input tensor and have the given data type.
Note
This operation supports input tensors with the following data types and layouts:
input_tensor dtype - layout
BFLOAT16, BFLOAT8_B, BFLOAT4_B, FLOAT32, UINT32, INT32, UINT16, UINT8 - TILE
BFLOAT16, FLOAT32, UINT32, INT32, UINT16, UINT8 - ROW_MAJOR
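The support table above can be mirrored in a small pre-flight check. The following is only a sketch in plain Python: it does not import or call ttnn, the names `SUPPORTED_INPUT_DTYPES` and `is_supported_input` are hypothetical, and dtypes and layouts are passed as strings purely for illustration.

```python
# Hypothetical helper that encodes the dtype/layout table above.
# It is NOT part of ttnn; names and string-based arguments are assumptions.
SUPPORTED_INPUT_DTYPES = {
    "TILE": {
        "BFLOAT16", "BFLOAT8_B", "BFLOAT4_B", "FLOAT32",
        "UINT32", "INT32", "UINT16", "UINT8",
    },
    "ROW_MAJOR": {
        "BFLOAT16", "FLOAT32", "UINT32", "INT32", "UINT16", "UINT8",
    },
}


def is_supported_input(dtype_name: str, layout_name: str) -> bool:
    """Return True if the table lists this input dtype for this layout."""
    allowed = SUPPORTED_INPUT_DTYPES.get(layout_name.upper(), set())
    return dtype_name.upper() in allowed


# The block-floating-point formats appear only in the TILE row of the table.
print(is_supported_input("bfloat8_b", "TILE"))       # True
print(is_supported_input("bfloat8_b", "ROW_MAJOR"))  # False
```

A check like this could run before calling the operation so that an unsupported combination fails early with a clear message.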
- Memory Support:
Interleaved: DRAM and L1
Height, Width, and Block Sharded: DRAM and L1
- Limitations:
ND Sharded tensors are not supported.
If a preallocated output tensor is used, it must match the input tensor’s shape and layout.
Example
# Assumes `import ttnn`, a configured `logger`, and an already opened `device`.
# Create a TT-NN tensor on the device and typecast it to a different data type
tensor = ttnn.typecast(
    ttnn.rand((10, 3, 32, 32), dtype=ttnn.bfloat16, device=device),
    dtype=ttnn.uint16,
)
assert tensor.dtype == ttnn.uint16
assert tensor.shape == (10, 3, 32, 32)
# Format the shape into the message so that it is actually logged
logger.info(f"TT-NN tensor shape after typecasting to uint16: {tensor.shape}")