|
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "moeTopKFuncs.cuh" |
| 18 | +#include "tensorrt_llm/common/cudaTypeUtils.cuh" |
| 19 | +#include "tensorrt_llm/common/envUtils.h" |
| 20 | +#include "tensorrt_llm/kernels/archCondition.h" |
| 21 | +#include "tensorrt_llm/kernels/customMoeRoutingKernels.h" |
| 22 | +#include <climits> // For INT_MAX |
| 23 | +#include <cooperative_groups.h> |
| 24 | +#include <cooperative_groups/reduce.h> |
| 25 | +#include <cub/cub.cuh> |
| 26 | +#include <cuda/std/limits> // For numeric_limits |
| 27 | +#include <math.h> |
| 28 | + |
| 29 | +namespace cg = cooperative_groups; |
| 30 | +using namespace tensorrt_llm::common; |
| 31 | + |
| 32 | +namespace tensorrt_llm::kernels |
| 33 | +{ |
| 34 | + |
| 35 | +static constexpr int BLOCK_SIZE = 1024; |
| 36 | +static constexpr int WARP_SIZE = 32; |
| 37 | +static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; |
| 38 | + |
| 39 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 40 | + |
| 41 | +template <typename T> |
| 42 | +__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) |
| 43 | +{ |
| 44 | + T maxScore = T{-INFINITY}; |
| 45 | + if (laneIdx < NumTopExperts) |
| 46 | + { |
| 47 | + maxScore = score >= maxScore ? score : maxScore; |
| 48 | + } |
| 49 | + maxScore = cg::reduce(warp, maxScore, cg::greater<T>()); |
| 50 | + |
| 51 | + float sumScore{0.f}; |
| 52 | + float newScore; |
| 53 | + // Get the summation of scores for each token |
| 54 | + if (laneIdx < NumTopExperts) |
| 55 | + { |
| 56 | + newScore = static_cast<float>(score) - static_cast<float>(maxScore); |
| 57 | + newScore = static_cast<float>(exp(newScore)); |
| 58 | + sumScore += newScore; |
| 59 | + } |
| 60 | + sumScore = cg::reduce(warp, sumScore, cg::plus<float>()); |
| 61 | + |
| 62 | + if (laneIdx < NumTopExperts) |
| 63 | + { |
| 64 | + score = static_cast<T>(newScore / sumScore); |
| 65 | + } |
| 66 | + |
| 67 | + return score; |
| 68 | +} |
| 69 | + |
| 70 | +template <typename DataType, int VecSize> |
| 71 | +__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize]) |
| 72 | +{ |
| 73 | + DataType maxScore = DataType{-INFINITY}; |
| 74 | + DataType sumScore = DataType{0.f}; |
| 75 | + |
| 76 | + // Get the max score for each token |
| 77 | +#pragma unroll |
| 78 | + for (int i = 0; i < VecSize; ++i) |
| 79 | + { |
| 80 | + maxScore = scores[i] >= maxScore ? scores[i] : maxScore; |
| 81 | + } |
| 82 | + maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>()); |
| 83 | + |
| 84 | + // Get the summation of scores for each token |
| 85 | +#pragma unroll |
| 86 | + for (int i = 0; i < VecSize; ++i) |
| 87 | + { |
| 88 | + scores[i] = static_cast<DataType>(exp(scores[i] - maxScore)); |
| 89 | + sumScore += scores[i]; |
| 90 | + } |
| 91 | + sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>()); |
| 92 | + |
| 93 | + // Normalize the scores |
| 94 | +#pragma unroll |
| 95 | + for (int i = 0; i < VecSize; ++i) |
| 96 | + { |
| 97 | + scores[i] = static_cast<DataType>(scores[i] / sumScore); |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 102 | + |
| 103 | +template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts, |
| 104 | + bool DoSoftmaxBeforeTopK> |
| 105 | +__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, |
| 106 | + int32_t const numTokens, int32_t const numExperts, int32_t const topK) |
| 107 | +{ |
| 108 | + using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>; |
| 109 | + uint32_t const blockRank = blockIdx.x; |
| 110 | + uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; |
| 111 | + uint32_t const warpIdx = tIdx / WARP_SIZE; |
| 112 | + uint32_t const laneIdx = tIdx % WARP_SIZE; |
| 113 | + uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; |
| 114 | + auto block = cg::this_thread_block(); |
| 115 | + auto warp = cg::tiled_partition<WARP_SIZE>(block); |
| 116 | + |
| 117 | + BaseType minScore = BaseType{-INFINITY}; |
| 118 | + for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) |
| 119 | + { |
| 120 | + auto scoreOffset = tokenId * numExperts; |
| 121 | + auto outputOffset = tokenId * topK; |
| 122 | + |
| 123 | + BaseType inputScore[MaxNumExperts / WARP_SIZE]; |
| 124 | + IdxT inputIndex[MaxNumExperts / WARP_SIZE]; |
| 125 | + |
| 126 | + BaseType warpTopKScore[MaxNumTopExperts]; |
| 127 | + IdxT warpTopKExpertIdx[MaxNumTopExperts]; |
| 128 | + |
| 129 | + // Load scores and indices for this warp |
| 130 | + for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) |
| 131 | + { |
| 132 | + auto expertIdx = i * WARP_SIZE + laneIdx; |
| 133 | + inputScore[i] |
| 134 | + = expertIdx < numExperts ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx]) : minScore; |
| 135 | + inputIndex[i] = expertIdx; |
| 136 | + } |
| 137 | + |
| 138 | + if constexpr (DoSoftmaxBeforeTopK) |
| 139 | + { |
| 140 | + calcSoftmax(warp, inputScore); |
| 141 | + } |
| 142 | + // Reduce topK scores and indices for this warp |
| 143 | + reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); |
| 144 | + |
| 145 | + // Normalize the scores |
| 146 | + if constexpr (DoSoftmaxBeforeTopK) |
| 147 | + { |
| 148 | + if (laneIdx < topK) |
| 149 | + { |
| 150 | + topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]); |
| 151 | + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; |
| 152 | + } |
| 153 | + } |
| 154 | + else |
| 155 | + { |
| 156 | + auto softmaxScore = calcSoftmax(warp, |
| 157 | + laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx, |
| 158 | + topK); |
| 159 | + if (laneIdx < topK) |
| 160 | + { |
| 161 | + topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore); |
| 162 | + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; |
| 163 | + } |
| 164 | + } |
| 165 | + } // end for tokenId |
| 166 | +} |
| 167 | + |
| 168 | +int nextPowerOfTwo(int num) |
| 169 | +{ |
| 170 | + if (num <= 0) |
| 171 | + { |
| 172 | + return 1; // Handle invalid input |
| 173 | + } |
| 174 | + int power = 1; |
| 175 | + while (power < num) |
| 176 | + { |
| 177 | + // Check for overflow before shifting |
| 178 | + if (power > INT_MAX / 2) |
| 179 | + { |
| 180 | + return power; |
| 181 | + } |
| 182 | + power <<= 1; |
| 183 | + } |
| 184 | + return power; |
| 185 | +} |
| 186 | + |
| 187 | +#define CASE(MAX_NUM_EXPERTS) \ |
| 188 | + case MAX_NUM_EXPERTS: \ |
| 189 | + switch (maxNumTopExperts) \ |
| 190 | + { \ |
| 191 | + case 1: \ |
| 192 | + kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>; \ |
| 193 | + break; \ |
| 194 | + case 2: \ |
| 195 | + kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>; \ |
| 196 | + break; \ |
| 197 | + case 4: \ |
| 198 | + kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>; \ |
| 199 | + break; \ |
| 200 | + case 8: \ |
| 201 | + kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>; \ |
| 202 | + break; \ |
| 203 | + default: kernelInstance = nullptr; break; \ |
| 204 | + } \ |
| 205 | + break; |
| 206 | + |
| 207 | +template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK> |
| 208 | +void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, |
| 209 | + int64_t const numExperts, int64_t const topK, cudaStream_t const stream) |
| 210 | +{ |
| 211 | + |
| 212 | + const uint32_t maxNumBlocks = 1024; |
| 213 | + const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); |
| 214 | + |
| 215 | + uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); |
| 216 | + uint32_t maxNumTopExperts = nextPowerOfTwo(topK); |
| 217 | + |
| 218 | + auto* kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8, DoSoftmaxBeforeTopK>; |
| 219 | + |
| 220 | + switch (maxNumExperts) |
| 221 | + { |
| 222 | + CASE(32) |
| 223 | + CASE(64) |
| 224 | + CASE(96) |
| 225 | + CASE(128) |
| 226 | + default: kernelInstance = nullptr; break; |
| 227 | + } |
| 228 | + |
| 229 | + if (kernelInstance == nullptr) |
| 230 | + { |
| 231 | + TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); |
| 232 | + } |
| 233 | + |
| 234 | + dim3 renormMoeRoutingGridDim(numBlocks); |
| 235 | + dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); |
| 236 | + cudaLaunchConfig_t config; |
| 237 | + config.gridDim = renormMoeRoutingGridDim; |
| 238 | + config.blockDim = renormMoeRoutingBlockDim; |
| 239 | + config.dynamicSmemBytes = 0; |
| 240 | + config.stream = stream; |
| 241 | + cudaLaunchAttribute attrs[1]; |
| 242 | + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; |
| 243 | + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); |
| 244 | + config.numAttrs = 1; |
| 245 | + config.attrs = attrs; |
| 246 | + cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens), |
| 247 | + static_cast<int32_t>(numExperts), static_cast<int32_t>(topK)); |
| 248 | + sync_check_cuda_error(stream); |
| 249 | +} |
| 250 | + |
| 251 | +#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \ |
| 252 | + template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits, \ |
| 253 | + OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \ |
| 254 | + int64_t const topK, cudaStream_t const stream); |
| 255 | + |
| 256 | +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false); |
| 257 | +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false); |
| 258 | +#ifdef ENABLE_BF16 |
| 259 | +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false); |
| 260 | +#endif |
| 261 | + |
| 262 | +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true); |
| 263 | +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true); |
| 264 | +#ifdef ENABLE_BF16 |
| 265 | +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true); |
| 266 | +#endif |
| 267 | + |
| 268 | +} // namespace tensorrt_llm::kernels |
0 commit comments