@@ -55,79 +55,69 @@ class ReducescatterOp
5555 return 0 ;
5656 }
5757
58- torch::Tensor run (torch::Tensor const & input , torch::optional<torch::List<int64_t >> sizes)
58+ std::vector< torch::Tensor> run_list (torch::TensorList input_list , torch::optional<torch::List<int64_t >> sizes)
5959 {
6060 TLLM_CHECK_WITH_INFO (mNcclComm .get () != nullptr , " mNcclComm should be initialized before used" );
61- auto stream = at::cuda::getCurrentCUDAStream (input.get_device ());
62- auto type = tensorrt_llm::runtime::TorchUtils::dataType (input.scalar_type ());
63- std::vector<int64_t > outputShape = input.sizes ().vec ();
61+ bool use_nccl_reducescatter = !sizes.has_value ()
62+ || std::all_of (sizes.value ().begin (), sizes.value ().end (),
63+ [&sizes](int64_t size) { return size == sizes.value ()[0 ]; });
64+ int groupRank = 0 ;
6465 if (sizes.has_value ())
6566 {
6667 auto rank = COMM_SESSION.getRank ();
67- int groupRank = 0 ;
6868 for (auto const & currentRank : mGroup )
6969 {
7070 if (rank == currentRank)
7171 break ;
7272 ++groupRank;
7373 }
7474 TLLM_CHECK (static_cast <size_t >(groupRank) < mGroup .size ());
75- outputShape[0 ] = sizes.value ()[groupRank];
76- }
77- else
78- {
79- outputShape[0 ] = outputShape[0 ] / mGroup .size ();
8075 }
81- auto output = torch::empty (outputShape, input.options ());
82- bool use_nccl_reducescatter = !sizes.has_value ()
83- || std::all_of (sizes.value ().begin (), sizes.value ().end (),
84- [&sizes](int64_t size) { return size == sizes.value ()[0 ]; });
85- if (use_nccl_reducescatter)
86- {
87- NCCLCHECK_THROW (ncclReduceScatter (input.data_ptr (), output.mutable_data_ptr (), output.numel (),
88- (*getDtypeMap ())[type], ncclSum, *mNcclComm , stream));
89- }
90- else
76+ std::vector<torch::Tensor> output_list;
77+ output_list.reserve (input_list.size ());
78+ ncclGroupStart ();
79+ for (auto const & input : input_list)
9180 {
92- size_t numel_base = std::accumulate (outputShape. cbegin () + 1 , outputShape. cend (), 1 , std::multiplies<>{} );
93- int64_t split_offset = 0 ;
94- ncclGroupStart ();
95- for ( int root = 0 ; root < static_cast < int >( mGroup . size ()); ++root )
81+ auto stream = at::cuda::getCurrentCUDAStream (input. get_device () );
82+ auto type = tensorrt_llm::runtime::TorchUtils::dataType (input. scalar_type ()) ;
83+ std::vector< int64_t > outputShape = input. sizes (). vec ();
84+ if (sizes. has_value () )
9685 {
97- auto split_size = sizes.value ()[root];
98- NCCLCHECK_THROW (
86+ outputShape[0 ] = sizes.value ()[groupRank];
87+ }
88+ else
89+ {
90+ outputShape[0 ] = outputShape[0 ] / mGroup .size ();
91+ }
92+ auto output = torch::empty (outputShape, input.options ());
93+ if (use_nccl_reducescatter)
94+ {
95+ ncclReduceScatter (input.data_ptr (), output.mutable_data_ptr (), output.numel (), (*getDtypeMap ())[type],
96+ ncclSum, *mNcclComm , stream);
97+ }
98+ else
99+ {
100+ size_t numel_base
101+ = std::accumulate (outputShape.cbegin () + 1 , outputShape.cend (), 1 , std::multiplies<>{});
102+ int64_t split_offset = 0 ;
103+ for (int root = 0 ; root < static_cast <int >(mGroup .size ()); ++root)
104+ {
105+ auto split_size = sizes.value ()[root];
99106 ncclReduce (input.index ({torch::indexing::Slice (split_offset, torch::indexing::None)}).data_ptr (),
100107 output.mutable_data_ptr (), numel_base * split_size, (*getDtypeMap ())[type], ncclSum, root,
101- *mNcclComm , stream));
102- split_offset += split_size;
108+ *mNcclComm , stream);
109+ split_offset += split_size;
110+ }
103111 }
104- ncclGroupEnd ( );
112+ output_list. push_back (output );
105113 }
106- return output;
114+ NCCLCHECK_THROW (ncclGroupEnd ());
115+ return output_list;
107116 }
108117
109- std::vector<torch::Tensor> run_list (
110- torch::TensorList input_list, torch::optional<torch::List<int64_t >> sizes) noexcept
118+ torch::Tensor run (torch::Tensor const & input, torch::optional<torch::List<int64_t >> sizes)
111119 {
112- std::vector<torch::Tensor> output_list;
113- output_list.reserve (input_list.size ());
114- bool use_nccl_reducescatter = !sizes.has_value ()
115- || std::all_of (sizes.value ().begin (), sizes.value ().end (),
116- [&sizes](int64_t size) { return size == sizes.value ()[0 ]; });
117- if (use_nccl_reducescatter)
118- {
119- ncclGroupStart ();
120- }
121- for (auto const & input : input_list)
122- {
123- auto output = run (input, sizes);
124- output_list.push_back (output);
125- }
126- if (use_nccl_reducescatter)
127- {
128- ncclGroupEnd ();
129- }
130- return output_list;
120+ return run_list ({input}, sizes)[0 ];
131121 }
132122
133123private:
0 commit comments