ttnn.reduce_to_root

ttnn.reduce_to_root(input_tensor_l, input_tensor_s, input_tensor_m, root_coord, **kwargs) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]

Reduce-to-root operation. Performs an SDPA (scaled dot-product attention) tree reduction across 4 devices and stores the result on the root device only.
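The tree reduction can be pictured with a small pure-Python model. This is a sketch under assumptions, not the ttnn kernel: it assumes each device holds the standard flash-attention partial state for one query row (an unnormalized output l, a sum s of exp(score - m), and a max score m) and that two states are merged by rescaling both to the larger max. `SdpaState`, `merge` and `tree_reduce` are hypothetical names; `scale_fp32`, tensor layouts, device placement and fabric communication are not modelled.

```python
import math
from typing import List, NamedTuple


class SdpaState(NamedTuple):
    """Hypothetical per-device partial state for one query row."""
    l: List[float]  # unnormalized, exp-weighted sum of value vectors
    s: float        # sum of exp(score - m)
    m: float        # max score seen so far


def merge(a: SdpaState, b: SdpaState) -> SdpaState:
    """Combine two partial states by rescaling both to the larger max."""
    m = max(a.m, b.m)
    ca = math.exp(a.m - m)  # rescale factor for a, in (0, 1]
    cb = math.exp(b.m - m)  # rescale factor for b, in (0, 1]
    l = [ca * la + cb * lb for la, lb in zip(a.l, b.l)]
    return SdpaState(l=l, s=ca * a.s + cb * b.s, m=m)


def tree_reduce(states: List[SdpaState]) -> SdpaState:
    """Merge neighbours level by level: 4 states -> 2 -> 1."""
    while len(states) > 1:
        nxt = [merge(states[i], states[i + 1]) for i in range(0, len(states) - 1, 2)]
        if len(states) % 2 == 1:  # odd count: carry the last state up unchanged
            nxt.append(states[-1])
        states = nxt
    return states[0]


# Four "devices", each holding one key: scores (0, ln 3, 0, ln 3), 1-D values (2, 5, 1, 1).
states = [
    SdpaState(l=[2.0], s=1.0, m=0.0),
    SdpaState(l=[5.0], s=1.0, m=math.log(3.0)),
    SdpaState(l=[1.0], s=1.0, m=0.0),
    SdpaState(l=[1.0], s=1.0, m=math.log(3.0)),
]
root = tree_reduce(states)
# Softmax weights are 1:3:1:3, so the attention output is (2 + 15 + 1 + 3) / 8 = 21/8.
print(round(root.l[0] / root.s, 6))  # 2.625
```

Because this merge is associative and commutative (up to floating-point rounding), a pairwise tree gives the same result as merging the four states one at a time; the tree only shortens the chain of dependent merges to two levels for four devices.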

Args:

input_tensor_l (ttnn.Tensor): input tensor holding the SDPA values vector l.
input_tensor_s (ttnn.Tensor): input tensor holding the SDPA state vector s (sum).
input_tensor_m (ttnn.Tensor): input tensor holding the SDPA state vector m (max).
root_coord (ttnn.MeshCoordinate): coordinate of the root device. Should be (1,0) for the 4-device setup.

Keyword Args:

topology (ttnn.Topology): fabric topology.
output_tensor (ttnn.Tensor, optional): optional output tensor.
intermediate_tensor (ttnn.Tensor, optional): optional intermediate tensor.

Returns:

A tuple of three tensors:

ttnn.Tensor output_tensor_l: the output tensor for values.
ttnn.Tensor output_tensor_s: the output tensor for sum.
ttnn.Tensor output_tensor_m: the output tensor for max.
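The values/sum/max triple is easiest to read for a single query row. The sketch below assumes the usual flash-attention convention, which this reference does not spell out: for one device's shard of keys, m is the max score, s is the sum of exp(score - m), and l is the exp-weighted sum of value vectors, not yet divided by s. Under that assumption the normalized attention output is l / s. `local_sdpa_state` is a hypothetical helper, not part of ttnn.

```python
import math
from typing import List, Tuple


def local_sdpa_state(
    scores: List[float], values: List[List[float]]
) -> Tuple[List[float], float, float]:
    """Return (l, s, m) for one query row over one device's shard of keys."""
    m = max(scores)                        # max: subtracting it keeps exp() from overflowing
    w = [math.exp(x - m) for x in scores]  # unnormalized softmax weights
    s = sum(w)                             # sum of the weights
    dim = len(values[0])
    # l: weight-scaled sum of the value vectors, not yet divided by s
    l = [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(dim)]
    return l, s, m


# One shard with two keys: scores (0, ln 3) and 2-D values.
l, s, m = local_sdpa_state([0.0, math.log(3.0)], [[2.0, 0.0], [5.0, 1.0]])
out = [x / s for x in l]  # normalized attention output for this shard alone
print([round(x, 6) for x in out])  # [4.25, 0.75]
```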

Example:

>>> import torch
>>> import ttnn
>>> # Assumes `mesh_device` is an already-opened 4-device mesh (setup not shown).
>>> dtype = torch.bfloat16  # assumed dtype for this example
>>> input_tensor_torch_l = torch.zeros((8, 128), dtype=dtype)
>>> input_tensor_torch_s = torch.zeros((8, 32), dtype=dtype)
>>> input_tensor_torch_m = torch.zeros((8, 32), dtype=dtype)
>>> # Shard each tensor along dim 0 across the devices of the mesh.
>>> input_tensor_l = ttnn.from_torch(
...     input_tensor_torch_l, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> input_tensor_s = ttnn.from_torch(
...     input_tensor_torch_s, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> input_tensor_m = ttnn.from_torch(
...     input_tensor_torch_m, device=mesh_device, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)
... )
>>> root_coord = ttnn.MeshCoordinate((1, 0))
>>> output_tensor_l, output_tensor_s, output_tensor_m = ttnn.reduce_to_root(
...     input_tensor_l,
...     input_tensor_s,
...     input_tensor_m,
...     root_coord,
...     scale_fp32=1.0,
...     topology=ttnn.Topology.Linear,
... )