
Commit 446fd13

yuxianqlfr-0531 authored and committed
Update default_moe_max_num_tokens.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent b7c4180 commit 446fd13

4 files changed: +18 −12 lines changed


tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 4 deletions
@@ -112,11 +112,10 @@ def __init__(
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
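
The cutlass, vanilla, and wide-EP backends now share the same pattern: the MoE token budget defaults to max_num_tokens * dp_size (instead of scaling by world_size only when attention DP is in use), an explicit moe_max_num_tokens from the config overrides it, and the auxiliary stream for chunking is only set up when the effective budget is below that default. A minimal standalone sketch of the logic (the helper function below is hypothetical, not code from the repo):

    from typing import Optional, Tuple

    def resolve_moe_max_num_tokens(max_num_tokens: int,
                                   dp_size: int,
                                   moe_max_num_tokens: Optional[int]) -> Tuple[int, bool]:
        """Mirror the diff's pattern: return (effective budget, needs_chunking)."""
        default = max_num_tokens * dp_size           # always scales with DP size now
        effective = moe_max_num_tokens or default    # explicit config value wins
        needs_chunking = effective < default         # aux stream only in this case
        return effective, needs_chunking

    # No override: budget grows with dp_size, no chunking needed.
    print(resolve_moe_max_num_tokens(9344, 4, None))    # (37376, False)
    # Override below the default: chunking (and the aux CUDA stream) kicks in.
    print(resolve_moe_max_num_tokens(9344, 4, 18688))   # (18688, True)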

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 10 additions & 0 deletions
@@ -330,6 +330,16 @@ def __init__(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8196, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # It can avoid OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens

         super().__init__(
             routing_method=routing_method,
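
As a sanity check on the 18688 cap, the arithmetic in the new comment can be reproduced directly. Note that the stated intermediate value 9344 only comes out with max_isl = 8192; the comment's 8196 appears to be a typo, since that would round to 9408 instead:

    # Reproducing the default cap described in the diff comment.
    # Assumption: max_isl = 8192 rather than the 8196 written in the comment,
    # because only 8192 yields the stated 9344 after rounding to a multiple of 64.
    max_isl, max_batch_size, mtp = 8192, 1024, 0

    # Round (mtp + 1) * max_batch_size + max_isl + 128 up to the next multiple of 64.
    max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
    default_moe_max_num_tokens = max_num_tokens * 2

    print(max_num_tokens)              # 9344
    print(default_moe_max_num_tokens)  # 18688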

tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py

Lines changed: 2 additions & 5 deletions
@@ -83,11 +83,8 @@ def __init__(
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
-                                   if model_config.moe_max_num_tokens
-                                   is not None else max_num_tokens)
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

         self._weights_created = False
         if not model_config.skip_create_weights_in_init:

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 3 additions & 3 deletions
@@ -152,10 +152,10 @@ def __init__(
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
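
When the effective moe_max_num_tokens ends up smaller than the incoming token count, the MoE layer has to process the batch in chunks, which is what the MoeChunkingOverlap aux stream above is for. A rough illustration of how a token batch could be split under such a cap; this is a hypothetical sketch, not the repo's actual chunking code:

    import math

    def chunk_token_counts(num_tokens: int, moe_max_num_tokens: int) -> list:
        """Split num_tokens into chunks of at most moe_max_num_tokens each."""
        num_chunks = math.ceil(num_tokens / moe_max_num_tokens)
        base, rem = divmod(num_tokens, num_chunks)
        # Spread the remainder so chunk sizes differ by at most one token.
        return [base + (1 if i < rem else 0) for i in range(num_chunks)]

    # Example: four DP ranks' worth of tokens, capped at the 18688 default.
    print(chunk_token_counts(4 * 9344, 18688))   # [18688, 18688]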
