4 files changed (+18 -12) in tensorrt_llm/_torch/modules/fused_moe

@@ -112,11 +112,10 @@ def __init__(
 
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
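The rewritten sizing drops the `if self.use_dp` branch: the MoE token budget is now always the per-rank `max_num_tokens` scaled by `mapping.dp_size` (rather than `world_size`), with an explicitly configured `moe_max_num_tokens` taking precedence. A minimal sketch of that logic, using a hypothetical standalone helper rather than the TensorRT-LLM API:

    from typing import Optional

    # Hypothetical helper, not part of TensorRT-LLM.
    def resolve_moe_max_num_tokens(max_num_tokens: int, dp_size: int,
                                   configured: Optional[int]) -> int:
        # With attention DP, each DP rank can route up to max_num_tokens
        # tokens into the same MoE layer, so the worst case scales by dp_size.
        upper_bound = max_num_tokens * dp_size
        # `or` falls back to the computed bound when the config value is None.
        return configured or upper_bound

Note that the `x or y` idiom also falls back when `configured` is 0; that is harmless here, since a zero token budget would be meaningless.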
@@ -330,6 +330,16 @@ def __init__(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8192, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # This avoids OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens
 
         super().__init__(
             routing_method=routing_method,
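Checking the default-budget arithmetic from the comment above (values taken from the diff; the +128 padding and the round-up to a multiple of 64 are taken as given):

    max_isl, max_batch_size, mtp = 8192, 1024, 0
    max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
    assert max_num_tokens == 9344
    default_moe_max_num_tokens = max_num_tokens * 2
    assert default_moe_max_num_tokens == 18688

So any DP-scaled bound above 18688 is clamped back to 18688 before it reaches the parent constructor.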
@@ -83,11 +83,8 @@ def __init__(
 
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
-                                   if model_config.moe_max_num_tokens
-                                   is not None else max_num_tokens)
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
@@ -152,10 +152,10 @@ def __init__(
 
         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
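When the configured or clamped `moe_max_num_tokens` ends up below the DP-scaled bound, the constructor above provisions an auxiliary CUDA stream (`AuxStreamType.MoeChunkingOverlap`) so chunked MoE executions can overlap. A hypothetical sketch of how a token batch could be split under that cap; the function and its balancing policy are illustrative, not the TensorRT-LLM internals:

    import math

    # Illustrative only: TensorRT-LLM's actual chunking logic may differ.
    def plan_moe_chunks(num_tokens: int, moe_max_num_tokens: int) -> list:
        # Use the fewest chunks such that no single MoE invocation sees
        # more than moe_max_num_tokens tokens.
        num_chunks = math.ceil(num_tokens / moe_max_num_tokens)
        base, rem = divmod(num_tokens, num_chunks)
        # Spread the remainder so chunk sizes differ by at most one token.
        return [base + 1 if i < rem else base for i in range(num_chunks)]

    print(plan_moe_chunks(20000, 18688))  # -> [10000, 10000]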