ttnn.reduce_to_root
- ttnn.reduce_to_root() ttnn.Tensor output_tensor_l
-
Reduce-to-root operation. Performs an SDPA (scaled dot-product attention) tree reduction across 4 devices and stores the output on the root device only.
- Args:
-
input_tensor_l: Input tensor holding the SDPA values l.
input_tensor_s: Input tensor holding the SDPA state s.
input_tensor_m: Input tensor holding the SDPA state m.
root_coord (ttnn.MeshCoordinate): Coordinate of the root device. Should be (1, 0) for a 4-device setup.
- Keyword Args:
-
topology (ttnn.Topology): Fabric topology.
output_tensor (ttnn.Tensor, optional): Optional output tensor.
intermediate_tensor (ttnn.Tensor, optional): Optional intermediate tensor.
- Returns:
-
ttnn.Tensor output_tensor_l: the output tensor for values.
ttnn.Tensor output_tensor_s: the output tensor for sum.
ttnn.Tensor output_tensor_m: the output tensor for max.
Example:
>>> input_tensor_torch_l = torch.zeros((8, 128), dtype=dtype)
>>> input_tensor_torch_s = torch.zeros((8, 32), dtype=dtype)
>>> input_tensor_torch_m = torch.zeros((8, 32), dtype=dtype)
>>> input_tensor_l = ttnn.from_torch(
...     input_tensor_torch_l, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> input_tensor_s = ttnn.from_torch(
...     input_tensor_torch_s, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> input_tensor_m = ttnn.from_torch(
...     input_tensor_torch_m, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> root_coord = ttnn.MeshCoordinate((1, 0))
>>> output_tensor_l, output_tensor_s, output_tensor_m = ttnn.reduce_to_root(
...     input_tensor_l,
...     input_tensor_s,
...     input_tensor_m,
...     root_coord,
...     scale_fp32=1.0,
...     topology=ttnn.Topology.Linear,
... )
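To give a feel for what a tree reduction over SDPA partial results (l, s, m) involves, here is a minimal host-side sketch in plain Python. It assumes the standard online-softmax merge rule: take the larger of the two maxima m, then rescale each side's sum s and values l by exp(m_side - m) before adding. This page does not state that ttnn.reduce_to_root uses exactly this rule, how devices are paired, whether l is returned normalized, or what scale_fp32 does, so treat those as open assumptions. The helpers `partial_state`, `merge_partials`, and `tree_reduce` are hypothetical names for illustration and are not part of ttnn.

```python
import math


def partial_state(scores, values):
    """Build an (l, s, m) partial for one query row from a chunk of keys.

    m is the max score, s the sum of exp(score - m), and l the
    exp-weighted, unnormalized sum of the value vectors.
    """
    m = max(scores)
    weights = [math.exp(x - m) for x in scores]
    s = sum(weights)
    dim = len(values[0])
    l_vec = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return l_vec, s, m


def merge_partials(a, b):
    """Merge two partials with the online-softmax rule (assumed, see lead-in)."""
    l_a, s_a, m_a = a
    l_b, s_b, m_b = b
    m = max(m_a, m_b)
    w_a = math.exp(m_a - m)  # rescale factor for side a
    w_b = math.exp(m_b - m)  # rescale factor for side b
    l_vec = [w_a * x + w_b * y for x, y in zip(l_a, l_b)]
    s = w_a * s_a + w_b * s_b
    return l_vec, s, m


def tree_reduce(partials):
    """Pairwise-merge neighbouring partials until a single one remains."""
    level = list(partials)
    while len(level) > 1:
        nxt = [merge_partials(level[i], level[i + 1])
               for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # carry an unpaired partial to the next level
            nxt.append(level[-1])
        level = nxt
    return level[0]


# Illustrative data: 8 keys split into 4 chunks, standing in for 4 devices.
scores = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 3.0, 0.25]
values = [[float(i), float(i * i)] for i in range(8)]
partials = [partial_state(scores[i:i + 2], values[i:i + 2]) for i in range(0, 8, 2)]

l_vec, s, m = tree_reduce(partials)
out = [x / s for x in l_vec]  # normalized attention output for this row

# Reference: the same quantities computed over all 8 keys at once.
full_l, full_s, full_m = partial_state(scores, values)
ref = [x / full_s for x in full_l]
```

Under this rule the merge is associative, so the pairing order inside the tree does not change the result up to floating-point rounding; `out` and `ref` should agree to within a small tolerance.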