
Commit 9428414

opensource: Opensource MOE MXFP8-MXFP4 implementation (#5222)

Signed-off-by: Daniel Stokes <[email protected]>

1 parent: e9cd810

File tree: 55 files changed (+1492, -4690 lines)


3rdparty/cutlass

Submodule cutlass updated 257 files

cpp/include/tensorrt_llm/common/cudaUtils.h (2 additions, 2 deletions)

@@ -1339,10 +1339,10 @@ struct ConstExprWrapper
 };
 
 template <int VALUE>
-using Int = ConstExprWrapper<int, VALUE>;
+using ConstInt = ConstExprWrapper<int, VALUE>;
 
 template <bool VALUE>
-using Bool = ConstExprWrapper<bool, VALUE>;
+using ConstBool = ConstExprWrapper<bool, VALUE>;
 
 template <typename T>
 struct TmaDescType;
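The hunk above renames the compile-time constant aliases Int and Bool to the less collision-prone ConstInt and ConstBool, presumably to avoid clashing with identically named identifiers pulled in alongside the updated CUTLASS submodule. The sketch below shows how a wrapper of this kind is typically defined and used; the ConstExprWrapper body and the clampNonNegative overloads are assumptions for illustration, since the commit only touches the two alias lines.

// Illustrative sketch only: the ConstExprWrapper body here is an assumption,
// because the diff above shows just the alias declarations.
template <typename T, T VALUE>
struct ConstExprWrapper
{
    static constexpr T value = VALUE; // carries a constant in the type system
};

template <int VALUE>
using ConstInt = ConstExprWrapper<int, VALUE>;

template <bool VALUE>
using ConstBool = ConstExprWrapper<bool, VALUE>;

// Hypothetical use: tag dispatch on a compile-time flag, no runtime branch.
template <typename T>
constexpr T clampNonNegative(T x, ConstBool<true>)
{
    return x < T{0} ? T{0} : x;
}

template <typename T>
constexpr T clampNonNegative(T x, ConstBool<false>)
{
    return x;
}

static_assert(ConstInt<4>::value == 4, "alias still exposes the wrapped constant");
static_assert(clampNonNegative(-3, ConstBool<true>{}) == 0, "clamp branch chosen at compile time");

Because the constant lives in the type, both overloads stay fully type-checked while the compiler discards the unused branch.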

cpp/micro_benchmarks/CMakeLists.txt (0 additions, 7 deletions)

@@ -50,12 +50,5 @@ function(add_benchmark test_name test_src)
   add_dependencies(micro_benchmarks ${test_name})
 endfunction()
 
-# currently only support internal-cutlass lib version
 add_benchmark(mixtureOfExpertsBackendBenchmark
               mixtureOfExpertsBackendBenchmarkLauncher.cu)
-# Temporary opend-sourced version. Will be daleted when open-sourced moe_gemm
-# support MXFP4
-if(USING_OSS_CUTLASS_MOE_GEMM)
-  add_benchmark(mixtureOfExpertsBackendBenchmarkOss
-                mixtureOfExpertsBackendBenchmarkLauncherOss.cu)
-endif()

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h (29 additions, 2 deletions)
@@ -20,7 +20,14 @@
 
 #include <nlohmann/json.hpp>
 
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
+#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
+#else
 #include "moe_kernels.h"
+#endif
+
+#include "tensorrt_llm/kernels/cutlass_kernels/include/cutlass_kernel_selector.h"
+
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/nvtxUtils.h"
@@ -42,6 +49,12 @@ using namespace tensorrt_llm::common;
 using namespace tensorrt_llm::runtime;
 using namespace tensorrt_llm::cutlass_extensions;
 
+using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
+using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
+using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner;
+using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
+using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation;
+
 static BufferManager::CudaStreamPtr streamPtr;
 static std::unique_ptr<BufferManager> bufferManager;
 static int deviceCount;
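The CUTLASS_MOE_GEMM_NAMESPACE and CUTLASS_MOE_GEMM_KERNELS_NAMESPACE macros let the fixture bind to either the open-source or the internal MoE GEMM implementation without rewriting call sites. The snippet below is a hypothetical sketch of that selector pattern; the placeholder namespaces oss_moe and internal_moe and the enum members are invented for illustration and are not the actual contents of cutlass_kernel_selector.h.

// Hypothetical selector sketch: only the macro names match the diff above;
// the namespaces and ActivationType members are placeholders, not the real
// definitions from cutlass_kernel_selector.h.
namespace oss_moe
{
enum class ActivationType { Relu, Geglu }; // placeholder members
}

namespace internal_moe
{
enum class ActivationType { Relu, Geglu }; // placeholder members
}

#ifdef USING_OSS_CUTLASS_MOE_GEMM
#define CUTLASS_MOE_GEMM_NAMESPACE oss_moe
#else
#define CUTLASS_MOE_GEMM_NAMESPACE internal_moe
#endif

// The benchmark then imports names once and uses them unqualified everywhere.
using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;

static constexpr ActivationType kDefaultAct = ActivationType::Relu;

With this indirection, only the includes above and the two guarded call sites in the later hunks need an explicit #ifdef.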
@@ -485,7 +498,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     bool mIsGated = false;
     int mGatedMultiplier = 1;
 
-    tensorrt_llm::ActivationType mActType = tensorrt_llm::ActivationType::Relu;
+    ActivationType mActType = ActivationType::Relu;
 
     QuantParams mQuantParams{};
     bool mUseLora = false;
@@ -650,9 +663,15 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             "Tactic Profiling GEMM " + std::to_string(static_cast<int>(gemm_to_profile)));
 
         GemmProfilerBackend profiler;
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
+        profiler.init(mMoERunner, gemm_to_profile, typeToDtypeID<DataType>(), typeToDtypeID<WeightType>(),
+            typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize, mGroupSize, mActType, mUseBias,
+            mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, parallelism_config, /*enable_alltoall=*/false);
+#else
         profiler.init(mMoERunner, gemm_to_profile, typeToDtypeID<DataType>(), typeToDtypeID<WeightType>(),
             typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize, mGroupSize, mActType, mUseBias,
             mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, parallelism_config);
+#endif
         auto workspace_size = profiler.getWorkspaceSize(mTotalTokens);
         auto workspace = bufferManager->gpu(workspace_size);
 
@@ -760,11 +779,19 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     {
         auto stream = streamPtr->get();
         MoeMinLatencyParams min_latency_params;
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
         mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
             mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
             mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, mUseLora, mLoraParams,
+            parallelism_config, /*enable_alltoall=*/false, mUseLora, mLoraParams,
             /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
+#else
+        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
+            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
+            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
+            parallelism_config, mUseLora, mLoraParams, /*use_deepseek_fp8_block_scale=*/false,
+            /*min_latency_mode=*/false, min_latency_params, stream);
+#endif
     }
 
     void runBenchmark(benchmark::State& state);
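Both guarded call sites above differ only in the extra enable_alltoall argument that the open-source runMoe and profiler init signatures take. The sketch below mirrors that pattern with MoeRunnerStub and dispatchRunMoe as invented stand-ins rather than TensorRT-LLM APIs, showing how a single logical call can compile against either signature.

// Invented stand-ins for illustration; only the #ifdef structure mirrors the
// fixture above, and the types and argument lists are simplified placeholders.
#include <cstdio>

struct MoeRunnerStub
{
#ifdef USING_OSS_CUTLASS_MOE_GEMM
    // OSS-style signature: takes the extra enable_alltoall flag.
    void runMoe(int total_tokens, bool enable_alltoall, bool use_lora)
    {
        std::printf("oss runMoe: tokens=%d alltoall=%d lora=%d\n", total_tokens, enable_alltoall, use_lora);
    }
#else
    // Internal-style signature: no enable_alltoall parameter.
    void runMoe(int total_tokens, bool use_lora)
    {
        std::printf("internal runMoe: tokens=%d lora=%d\n", total_tokens, use_lora);
    }
#endif
};

// Single call site; the preprocessor picks the matching argument list.
inline void dispatchRunMoe(MoeRunnerStub& runner, int total_tokens, bool use_lora)
{
#ifdef USING_OSS_CUTLASS_MOE_GEMM
    runner.runMoe(total_tokens, /*enable_alltoall=*/false, use_lora);
#else
    runner.runMoe(total_tokens, use_lora);
#endif
}

int main()
{
    MoeRunnerStub runner;
    dispatchRunMoe(runner, /*total_tokens=*/128, /*use_lora=*/false);
    return 0;
}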
