Commit a875e50

[https://nvbugs/5392414] [fix] For release 1.0 cherry pick. Add customized default routing method (NVIDIA#7068)
Signed-off-by: Christina Zhang <[email protected]>
1 parent caf73f5 commit a875e50

File tree

11 files changed (+742, -433 lines)

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "moeTopKFuncs.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/archCondition.h"
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/std/limits> // For numeric_limits
#include <math.h>

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

namespace tensorrt_llm::kernels
{

static constexpr int BLOCK_SIZE = 1024;
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
{
    T maxScore = T{-INFINITY};
    if (laneIdx < NumTopExperts)
    {
        maxScore = score >= maxScore ? score : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<T>());

    float sumScore{0.f};
    float newScore;
    // Get the summation of scores for each token
    if (laneIdx < NumTopExperts)
    {
        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
        newScore = static_cast<float>(exp(newScore));
        sumScore += newScore;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    if (laneIdx < NumTopExperts)
    {
        score = static_cast<T>(newScore / sumScore);
    }

    return score;
}

template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
{
    DataType maxScore = DataType{-INFINITY};
    DataType sumScore = DataType{0.f};

    // Get the max score for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());

    // Get the summation of scores for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
        sumScore += scores[i];
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());

    // Normalize the scores
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        scores[i] = static_cast<DataType>(scores[i] / sumScore);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts,
    bool DoSoftmaxBeforeTopK>
__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
    int32_t const numTokens, int32_t const numExperts, int32_t const topK)
{
    using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>;
    uint32_t const blockRank = blockIdx.x;
    uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
    uint32_t const warpIdx = tIdx / WARP_SIZE;
    uint32_t const laneIdx = tIdx % WARP_SIZE;
    uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WARP_SIZE>(block);

    BaseType minScore = BaseType{-INFINITY};
    for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
    {
        auto scoreOffset = tokenId * numExperts;
        auto outputOffset = tokenId * topK;

        BaseType inputScore[MaxNumExperts / WARP_SIZE];
        IdxT inputIndex[MaxNumExperts / WARP_SIZE];

        BaseType warpTopKScore[MaxNumTopExperts];
        IdxT warpTopKExpertIdx[MaxNumTopExperts];

        // Load scores and indices for this warp
        for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
        {
            auto expertIdx = i * WARP_SIZE + laneIdx;
            inputScore[i]
                = expertIdx < numExperts ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx]) : minScore;
            inputIndex[i] = expertIdx;
        }

        if constexpr (DoSoftmaxBeforeTopK)
        {
            calcSoftmax(warp, inputScore);
        }
        // Reduce topK scores and indices for this warp
        reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);

        // Normalize the scores
        if constexpr (DoSoftmaxBeforeTopK)
        {
            if (laneIdx < topK)
            {
                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
                topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
            }
        }
        else
        {
            auto softmaxScore = calcSoftmax(warp,
                laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx,
                topK);
            if (laneIdx < topK)
            {
                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore);
                topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
            }
        }
    } // end for tokenId
}

int nextPowerOfTwo(int num)
{
    if (num <= 0)
    {
        return 1; // Handle invalid input
    }
    int power = 1;
    while (power < num)
    {
        // Check for overflow before shifting
        if (power > INT_MAX / 2)
        {
            return power;
        }
        power <<= 1;
    }
    return power;
}

#define CASE(MAX_NUM_EXPERTS)                                                                                          \
    case MAX_NUM_EXPERTS:                                                                                              \
        switch (maxNumTopExperts)                                                                                      \
        {                                                                                                              \
        case 1:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 2:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 4:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 8:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        default: kernelInstance = nullptr; break;                                                                      \
        }                                                                                                              \
        break;

template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
    int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{

    const uint32_t maxNumBlocks = 1024;
    const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);

    uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
    uint32_t maxNumTopExperts = nextPowerOfTwo(topK);

    auto* kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8, DoSoftmaxBeforeTopK>;

    switch (maxNumExperts)
    {
        CASE(32)
        CASE(64)
        CASE(96)
        CASE(128)
    default: kernelInstance = nullptr; break;
    }

    if (kernelInstance == nullptr)
    {
        TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance.");
    }

    dim3 renormMoeRoutingGridDim(numBlocks);
    dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
    cudaLaunchConfig_t config;
    config.gridDim = renormMoeRoutingGridDim;
    config.blockDim = renormMoeRoutingBlockDim;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
        static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
    sync_check_cuda_error(stream);
}

#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK)                                     \
    template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits,            \
        OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts,                   \
        int64_t const topK, cudaStream_t const stream);

INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
#endif

INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
#endif

} // namespace tensorrt_llm::kernels
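
For reference, a minimal sketch of the per-token math the kernel implements, assuming x_1, ..., x_E are the router logits over the E experts and K = topK:

DoSoftmaxBeforeTopK = true (softmax over all experts, then select top-K):
\[ p_i = \frac{e^{x_i - \max_j x_j}}{\sum_{j=1}^{E} e^{x_j - \max_j x_j}}, \qquad \text{outputs} = \operatorname{top-}K(p_1, \dots, p_E) \]

DoSoftmaxBeforeTopK = false (select top-K logits, then renormalize over the selected set S):
\[ S = \operatorname{top-}K(x_1, \dots, x_E), \qquad p_s = \frac{e^{x_s - \max_{t \in S} x_t}}{\sum_{t \in S} e^{x_t - \max_{t \in S} x_t}} \quad \text{for } s \in S \]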

cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h renamed to cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@

 namespace tensorrt_llm::kernels
 {
-template <typename InputT, typename OutputT, typename IdxT>
+template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
 void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
     int64_t const numExperts, int64_t const topK, cudaStream_t const stream);
 } // namespace tensorrt_llm::kernels
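
A minimal call-site sketch for the updated declaration, assuming device buffers shaped [numTokens, numExperts] for the logits and [numTokens, topK] for the outputs; the wrapper name routeTokens and the buffer names are illustrative, not from the commit:

#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include <cuda_runtime.h>

// Top-K-then-renormalize routing for float logits, matching the
// INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false) instantiation above.
void routeTokens(float* dLogits, float* dTopkVals, int32_t* dTopkIdx,
    int64_t numTokens, int64_t numExperts, int64_t topK, cudaStream_t stream)
{
    // dLogits:   [numTokens, numExperts] router logits (device memory)
    // dTopkVals: [numTokens, topK] normalized top-K scores, written by the kernel
    // dTopkIdx:  [numTokens, topK] selected expert indices, written by the kernel
    tensorrt_llm::kernels::invokeRenormMoeRouting<float, float, int32_t, /*DoSoftmaxBeforeTopK=*/false>(
        dLogits, dTopkVals, dTopkIdx, numTokens, numExperts, topK, stream);
}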
