Skip to content

Commit 4f0f17a

Browse files
authored
feat: Misc Opt for large scale EP (#5374)
Signed-off-by: Dongxu Yang <[email protected]>
1 parent 5d4ab47 commit 4f0f17a

File tree

9 files changed

+363
-77
lines changed

9 files changed

+363
-77
lines changed

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceCommon.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ struct MoeLoadBalanceStatisticInfo
6363

6464
// rawDataWindowSize means the size of the raw data window.
6565
// e.g. how many steps of raw data are kept in the memory.
66-
int rawDataWindowSize = 1;
66+
// Currently we keep only the data of the current iteration; data from previous iterations should be summed into expertLoadFactor.
67+
static constexpr int rawDataWindowSize = 1;
6768

6869
// decayFactor means the decay factor of the raw data per step.
69-
// e.g. if decayFactor is 0.9, then the raw data of expert i will be decayed by 0.9 for each step.
70-
float decayFactor = 0.9f;
70+
// e.g. if decayFactor is 0.95, then the raw data of expert i will be decayed by 0.95 for each step.
71+
float decayFactor = 0.95f;
7172
};
7273

7374
// The placement information for GPU

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,19 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
128128
signal->stepAndOwner += MoeLoadBalanceSingleLayerSignal::kCPU;
129129
}
130130

131+
template <typename TYPE>
132+
__global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount)
133+
{
134+
if (*enabled == 0)
135+
{
136+
return;
137+
}
138+
TYPE oldExpertTokenCount = {0};
139+
int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
140+
TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
141+
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
142+
}
143+
131144
template <typename TYPE>
132145
__global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount)
133146
{
@@ -151,8 +164,8 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
151164
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
152165
}
153166

154-
__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
155-
int totalEltCount, int* const enabled, int* const gatheredRawExpertIds)
167+
__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertTokenCount, int totalEltCount,
168+
int* const enabled, int* const gatheredRawExpertIds)
156169
{
157170
extern __shared__ int sharedExpertCount[];
158171
if (*enabled == 0)
@@ -175,19 +188,19 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceS
175188
__syncthreads();
176189
for (int i = threadIdx.x; i < metaInfo.expertCount; i += blockDim.x)
177190
{
178-
atomicAdd_system(&statisticInfo.expertTokenCount[i], sharedExpertCount[i]);
191+
atomicAdd_system(&expertTokenCount[i], sharedExpertCount[i]);
179192
}
180193
}
181194

182-
__global__ void updateLoadFactorKernel(
183-
MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int* const enabled)
195+
__global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
196+
int* expertTokenCountPtr, int* const enabled)
184197
{
185198
if (*enabled == 0)
186199
{
187200
return;
188201
}
189202
int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
190-
int expertTokenCount = statisticInfo.expertTokenCount[expertIdx];
203+
int expertTokenCount = expertTokenCountPtr[expertIdx];
191204
float* loadFactor = statisticInfo.expertLoadFactor;
192205
loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
193206
}
@@ -233,16 +246,71 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
233246
}
234247
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
235248
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
236-
metaInfo, statisticInfo, totalEltCount, enabled, gatheredRawExpertIds);
249+
metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
237250
}
238251

239252
if (isLastStage)
240253
{
241254
// only the last stage needs to update the load factor.
242255
int threadCount = 128;
243256
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
244-
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(metaInfo, statisticInfo, enabled);
257+
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
258+
metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
259+
}
260+
}
261+
262+
void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens,
263+
int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds,
264+
cudaStream_t stream)
265+
{
266+
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
267+
if (isFirstStage)
268+
{
269+
// zero expertTokenCount (no window shift is performed on this path)
270+
// only the first stage needs to zero the token counts.
271+
int threadCount = metaInfo.expertCount;
272+
auto* kernelFunc = zeroExpertTokenCountKernel<int>;
273+
if (threadCount % 4 == 0)
274+
{
275+
threadCount /= 4;
276+
kernelFunc = zeroExpertTokenCountKernel<int4>;
277+
}
278+
else if (threadCount % 2 == 0)
279+
{
280+
threadCount /= 2;
281+
kernelFunc = zeroExpertTokenCountKernel<int2>;
282+
}
283+
dim3 gridDim(1);
284+
dim3 blockDim(threadCount);
285+
void* args[]
286+
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
287+
TLLM_CHECK_WITH_INFO(
288+
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
289+
TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
245290
}
291+
292+
{
293+
// do the statistic into expertTokenCount and maybe also expertLoadFactor;
294+
int threadCount = 1024;
295+
int totalEltCount = numTotalTokens * metaInfo.topK;
296+
int blockCount = (totalEltCount + threadCount - 1) / threadCount;
297+
if (blockCount > smCount)
298+
{
299+
blockCount = smCount;
300+
}
301+
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
302+
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
303+
metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
304+
}
305+
}
306+
307+
void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
308+
int* globalExpertTokenCount, int* const enabled, cudaStream_t stream)
309+
{
310+
int threadCount = 128;
311+
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
312+
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
313+
metaInfo, statisticInfo, globalExpertTokenCount, enabled);
246314
}
247315

248316
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,32 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal);
7070
void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int numTotalTokens,
7171
int* const enabled, bool isFirstStage, bool isLastStage, int* const gatheredRawExpertIds, cudaStream_t stream);
7272

73+
// @brief do the statistic based on local device's data
74+
//
75+
// This function is used to launch a kernel to do the statistic for local tokens.
76+
//
77+
// @param metaInfo: the meta info
78+
// @param numTotalTokens: the total number of tokens in localRawExpertIds
79+
// @param localExpertTokenCount: the token count that each expert has for local tokens.
80+
// @param enabled: flag on device memory to indicate if the statistic is enabled
81+
// @param isFirstStage: whether the current stage is the first stage (only the first stage zeroes localExpertTokenCount)
82+
// @param isLastStage: whether the current stage is the last stage (currently not used by this function; the load factor is updated by moeHierarchicalStatisticUpdate)
83+
// @param localRawExpertIds: the local raw expert ids, should have shape [numTotalTokens, metaInfo.topK]
84+
void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens,
85+
int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds,
86+
cudaStream_t stream);
87+
88+
// @brief update the statistic info based on global info
89+
//
90+
// This function is used to launch a kernel to update the statistic info per iteration.
91+
//
92+
// @param metaInfo: the meta info
93+
// @param statisticInfo: the statistic info
94+
// @param globalExpertTokenCount: the global expert token count, should have shape [metaInfo.expertCount]
95+
// @param enabled: flag on device memory to indicate if the statistic is enabled
96+
void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
97+
int* globalExpertTokenCount, int* const enabled, cudaStream_t stream);
98+
7399
// @brief compute the route
74100
//
75101
// This function is used to launch a kernel to compute the route based on the token selected experts and the placement

cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,60 @@ void moeLoadBalanceStatistic(torch::Tensor gatheredRawExpertIds, torch::Tensor e
8585
static_cast<bool>(isFirstStage), static_cast<bool>(isLastStage), gatheredRawExpertIds.data_ptr<int>(), stream);
8686
}
8787

88+
void moeHierarchicalStatisticLocalDevice(torch::Tensor localRawExpertIds, torch::Tensor localExpertTokenCount,
89+
torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr, int64_t isFirstStage, int64_t isLastStage)
90+
{
91+
CHECK_INPUT(localRawExpertIds, torch::kInt32);
92+
CHECK_INPUT(localExpertTokenCount, torch::kInt32);
93+
CHECK_INPUT(enabled, torch::kInt32);
94+
TORCH_CHECK(localRawExpertIds.dim() == 2, "localRawExpertIds must be a 2D tensor");
95+
TORCH_CHECK(localExpertTokenCount.dim() == 1, "localExpertTokenCount must be a 1D tensor");
96+
int topK = localRawExpertIds.size(1);
97+
TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
98+
TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
99+
TORCH_CHECK(isFirstStage == 0 || isFirstStage == 1, "isFirstStage must be 0 or 1");
100+
TORCH_CHECK(isLastStage == 0 || isLastStage == 1, "isLastStage must be 0 or 1");
101+
TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
102+
103+
auto* loadBalancer
104+
= reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
105+
auto stream = at::cuda::getCurrentCUDAStream();
106+
107+
tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
108+
109+
TORCH_CHECK(localExpertTokenCount.size(0) == metaInfo.expertCount, "localExpertTokenCount should have shape (%d,)",
110+
metaInfo.expertCount);
111+
TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK");
112+
113+
int numTotalTokens = localRawExpertIds.size(0);
114+
115+
tensorrt_llm::kernels::moeHierarchicalStatisticLocalDevice(metaInfo, numTotalTokens,
116+
localExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), static_cast<bool>(isFirstStage),
117+
static_cast<bool>(isLastStage), localRawExpertIds.data_ptr<int>(), stream);
118+
}
119+
120+
void moeHierarchicalStatisticUpdate(
121+
torch::Tensor globalExpertTokenCount, torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr)
122+
{
123+
CHECK_INPUT(globalExpertTokenCount, torch::kInt32);
124+
CHECK_INPUT(enabled, torch::kInt32);
125+
TORCH_CHECK(globalExpertTokenCount.dim() == 1, "globalExpertTokenCount must be a 1D tensor");
126+
TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
127+
TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
128+
TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
129+
auto* loadBalancer
130+
= reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
131+
auto stream = at::cuda::getCurrentCUDAStream();
132+
133+
tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
134+
auto statisticInfo = loadBalancer->getStatisticInfo();
135+
136+
TORCH_CHECK(globalExpertTokenCount.size(0) == metaInfo.expertCount,
137+
"globalExpertTokenCount should have shape (%d,)", metaInfo.expertCount);
138+
tensorrt_llm::kernels::moeHierarchicalStatisticUpdate(
139+
metaInfo, *statisticInfo, globalExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), stream);
140+
}
141+
88142
torch::Tensor moeLoadBalanceRouting(
89143
torch::Tensor tokenSelectedExperts, bool offsetByEpRank, int64_t singleLayerLoadBalancerPtr)
90144
{
@@ -182,6 +236,31 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
182236
m.impl("moe_load_balance_statistic", &torch_ext::moeLoadBalanceStatistic);
183237
}
184238

239+
TORCH_LIBRARY_FRAGMENT(trtllm, m)
240+
{
241+
m.def(
242+
"moe_hierarchical_statistic_local_device(Tensor local_raw_expert_ids, Tensor local_expert_token_count, Tensor "
243+
"enabled, int "
244+
"single_layer_load_balancer_ptr, int is_first_stage, int is_last_stage) -> ()");
245+
}
246+
247+
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
248+
{
249+
m.impl("moe_hierarchical_statistic_local_device", &torch_ext::moeHierarchicalStatisticLocalDevice);
250+
}
251+
252+
TORCH_LIBRARY_FRAGMENT(trtllm, m)
253+
{
254+
m.def(
255+
"moe_hierarchical_statistic_update(Tensor global_expert_token_count, Tensor enabled, int "
256+
"single_layer_load_balancer_ptr) -> ()");
257+
}
258+
259+
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
260+
{
261+
m.impl("moe_hierarchical_statistic_update", &torch_ext::moeHierarchicalStatisticUpdate);
262+
}
263+
185264
TORCH_LIBRARY_FRAGMENT(trtllm, m)
186265
{
187266
m.def(

cpp/tests/kernels/moeLoadBalanceKernelTest.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ struct MoeLoadBalanceTestParam
110110
bool isFirstStage;
111111
bool isLastStage;
112112
float decayFactor;
113-
int rawDataWindowSize;
114113
};
115114

116115
class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoadBalanceTestParam>
@@ -126,14 +125,13 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa
126125
mMetaInfo.epSize = param.epSize;
127126
mMetaInfo.slotCountPerRank = param.slotCountPerRank;
128127

129-
mStatisticInfo.rawDataWindowSize = param.rawDataWindowSize;
130128
mStatisticInfo.decayFactor = param.decayFactor;
131129

132130
ASSERT_EQ(cudaStreamCreate(&mStream), cudaSuccess);
133131

134132
// allocate device memory
135133
size_t expertLoadFactorSize = param.expertCount * sizeof(float);
136-
size_t expertTokenCountSize = param.expertCount * param.rawDataWindowSize * sizeof(int);
134+
size_t expertTokenCountSize = param.expertCount * mStatisticInfo.rawDataWindowSize * sizeof(int);
137135
size_t gatheredIdsSize = param.maxTokenCountPerRank * param.epSize * param.topK * sizeof(int);
138136

139137
ASSERT_EQ(cudaMalloc(&mDeviceEnabled, sizeof(int)), cudaSuccess);
@@ -147,9 +145,9 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa
147145

148146
// allocate host memory for verification
149147
mExpectedLoadFactor.resize(param.expertCount, 0.0f);
150-
mExpectedExpertTokenCount.resize(param.expertCount * param.rawDataWindowSize);
148+
mExpectedExpertTokenCount.resize(param.expertCount * mStatisticInfo.rawDataWindowSize);
151149
mHostExpertLoadFactor.resize(param.expertCount);
152-
mHostExpertTokenCount.resize(param.expertCount * param.rawDataWindowSize);
150+
mHostExpertTokenCount.resize(param.expertCount * mStatisticInfo.rawDataWindowSize);
153151
mHostGatheredIds.resize(param.maxTokenCountPerRank * param.epSize * param.topK);
154152

155153
// initialize the random number generator
@@ -188,7 +186,7 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa
188186
mExpectedExpertTokenCount = mHostExpertTokenCount;
189187
if (param.isFirstStage)
190188
{
191-
for (int windowIdx = param.rawDataWindowSize - 1; windowIdx >= 0; --windowIdx)
189+
for (int windowIdx = mStatisticInfo.rawDataWindowSize - 1; windowIdx >= 0; --windowIdx)
192190
{
193191
if (windowIdx > 0)
194192
{
@@ -305,7 +303,7 @@ TEST_P(MoeLoadBalanceStatisticKernelTest, TestStatistics)
305303
EXPECT_NEAR(mHostExpertLoadFactor[i], mExpectedLoadFactor[i], 1e-6)
306304
<< "Expert " << i << " load factor mismatch";
307305
}
308-
for (int i = 0; i < param.expertCount * param.rawDataWindowSize; ++i)
306+
for (int i = 0; i < param.expertCount * mStatisticInfo.rawDataWindowSize; ++i)
309307
{
310308
EXPECT_EQ(mHostExpertTokenCount[i], mExpectedExpertTokenCount[i]) << "Expert " << i << " token count mismatch";
311309
}
@@ -323,8 +321,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati
323321
/* maxTokenCountPerRank */ 128,
324322
/* isFirstStage */ true,
325323
/* isLastStage */ true,
326-
/* decayFactor */ 0.9f,
327-
/* rawDataWindowSize */ 3},
324+
/* decayFactor */ 0.9f},
328325
// large scale test scenarios
329326
MoeLoadBalanceTestParam{/* expertCount */ 64,
330327
/* topK */ 4,
@@ -334,8 +331,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati
334331
/* maxTokenCountPerRank */ 512,
335332
/* isFirstStage */ false,
336333
/* isLastStage */ true,
337-
/* decayFactor */ 0.95f,
338-
/* rawDataWindowSize */ 5} // can add more test scenarios
334+
/* decayFactor */ 0.95f} // can add more test scenarios
339335
));
340336

341337
class MoeLoadBalanceRouteKernelTest : public ::testing::TestWithParam<MoeLoadBalanceTestParam>
@@ -601,8 +597,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
601597
/* maxTokenCountPerRank */ 128,
602598
/* isFirstStage */ true,
603599
/* isLastStage */ true,
604-
/* decayFactor */ 0.9f,
605-
/* rawDataWindowSize */ 3},
600+
/* decayFactor */ 0.9f},
606601
// large scale test scenarios
607602
MoeLoadBalanceTestParam{/* expertCount */ 256,
608603
/* topK */ 8,
@@ -612,8 +607,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
612607
/* maxTokenCountPerRank */ 5000,
613608
/* isFirstStage */ false,
614609
/* isLastStage */ true,
615-
/* decayFactor */ 0.95f,
616-
/* rawDataWindowSize */ 5},
610+
/* decayFactor */ 0.95f},
617611
// edge case: single rank
618612
MoeLoadBalanceTestParam{/* expertCount */ 16,
619613
/* topK */ 2,
@@ -623,5 +617,4 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
623617
/* maxTokenCountPerRank */ 64,
624618
/* isFirstStage */ true,
625619
/* isLastStage */ true,
626-
/* decayFactor */ 0.9f,
627-
/* rawDataWindowSize */ 1}));
620+
/* decayFactor */ 0.9f}));

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,22 @@ def _(single_layer_load_balancer_ptr: int):
263263
pass
264264

265265
@torch.library.register_fake("trtllm::moe_load_balance_statistic")
266-
def _(single_layer_load_balancer_ptr: int,
267-
gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor,
268-
is_first_stage: bool, is_last_stage: bool):
266+
def _(gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor,
267+
single_layer_load_balancer_ptr: int, is_first_stage: bool,
268+
is_last_stage: bool):
269+
pass
270+
271+
@torch.library.register_fake(
272+
"trtllm::moe_hierarchical_statistic_local_device")
273+
def _(local_raw_expert_ids: torch.Tensor,
274+
local_expert_token_count: torch.Tensor, enabled: torch.Tensor,
275+
single_layer_load_balancer_ptr: int, is_first_stage: bool,
276+
is_last_stage: bool):
277+
pass
278+
279+
@torch.library.register_fake("trtllm::moe_hierarchical_statistic_update")
280+
def _(global_expert_token_count: torch.Tensor, enabled: torch.Tensor,
281+
single_layer_load_balancer_ptr: int):
269282
pass
270283

271284
@torch.library.register_fake("trtllm::moe_load_balance_routing")

0 commit comments

Comments
 (0)