
from ...distributed import allgather
from ...model_config import ModelConfig
- from ...utils import AuxStreamType, Fp4QuantizedTensor
+ from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(

def masked_index_copy_group_quant_fp8(
        output: torch.Tensor,
+         output_s: torch.Tensor,
        input: torch.Tensor,
        start_offsets: torch.Tensor,
        row_indices: torch.Tensor,
@@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
    col_size = output.shape[1]
    dim_size = output.shape[2]

-     # create padded output_s
    alignment = 4
    scale_dim = (dim_size + group_size - 1) // group_size
    padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
    padded_col_size = (col_size + alignment - 1) // alignment * alignment
-     output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
-                            dtype=torch.int32,
-                            device='cuda')

    # get block/grid/stage/warp
    num_groups = (dim_size + group_size - 1) // group_size
@@ -247,17 +244,14 @@ def preprocess_after_permute(expert_first_token_offset_tensor,

@nvtx_range("[DG]")
def deepgemm_fp8_group_blockwise_gemm(
+     d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    sfa: torch.Tensor,
    sfb: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
) -> torch.Tensor:
-     d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
-                     device=b.device,
-                     dtype=torch.bfloat16)
-
    # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
    assert a.stride(-1) == 1
    assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
        masked_m,
        expected_m,
        disable_ue8m0_cast=True)
-     return d
+     return
+
+
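+ # View a flat workspace buffer as a [g, m, k] tensor; avoids a fresh allocation per call.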
+ def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+     workspace = workspace[0:g * m * k]
+     workspace = workspace.as_strided(
+         size=(g, m, k),
+         stride=(m * k, k, 1),
+     )
+     return workspace


class DeepGemmFusedMoE(CutlassFusedMoE):
@@ -327,6 +330,18 @@ def __init__(
        apply_router_weight_on_input: bool = False,
        layer_idx: Optional[int] = None,
    ):
+         if model_config.moe_max_num_tokens is None:
+             moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+             # The default moe_max_num_tokens is derived from the following formula:
+             #   max_isl = 8192, max_batch_size = 1024, mtp = 0
+             #   max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64 = 9344
+             #   moe_max_num_tokens = max_num_tokens * 2 = 18688
+             # This cap avoids OOM for 8k-ISL / 1k-batch cases.
+             default_moe_max_num_tokens = 18688
+             if moe_max_num_tokens > default_moe_max_num_tokens:
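+                 # Temporarily unfreeze the config so the clamped value can be written back.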
+                 model_config._frozen = False
+                 model_config.moe_max_num_tokens = default_moe_max_num_tokens
+                 model_config._frozen = True

        super().__init__(
            routing_method=routing_method,
@@ -342,6 +357,37 @@ def __init__(
            layer_idx=layer_idx,
        )

+     def get_workspace(self, m_max: int, group_size: int):
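+         # Preallocate flat buffers shared by both grouped GEMMs of a chunk;
+         # set_strides() later views them as [num_experts, m, k] without reallocating.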
+         hidden_size = self.hidden_size
+         intermediate_size = self.intermediate_size
+         num_experts = self.expert_size_per_partition
+
+         # create workspace
+         fp8_dim = max(hidden_size, intermediate_size)
+         workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                   dtype=torch.float8_e4m3fn,
+                                   device='cuda')
+         workspace_1 = torch.empty(
+             (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+             dtype=torch.bfloat16,
+             device='cuda')
+
+         # create workspace for scaling factors
+         m_padded = fp8_utils.align(m_max, 4)
+         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+         scale_k_padded = fp8_utils.align(scale_k, 4)
+         workspace_sf = torch.empty(
+             (num_experts * (scale_k_padded // 4) * m_padded),
+             dtype=torch.int32,
+             device='cuda')
+
+         workspace = {
+             "workspace_0": workspace_0,
+             "workspace_1": workspace_1,
+             "workspace_sf": workspace_sf,
+         }
+         return workspace
+
    def _get_quant_method(self):
        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                exclude_kv_cache=True):
@@ -362,6 +408,7 @@ def forward_chunk(
        output_dtype: Optional[torch.dtype] = None,
        all_rank_num_tokens: Optional[List[int]] = None,
        use_dp_padding: Optional[bool] = None,
+         workspace: Optional[dict] = None,
    ) -> torch.Tensor:
        if isinstance(x, Fp4QuantizedTensor):
            assert output_dtype is not None
@@ -437,32 +484,72 @@ def forward_chunk(
        masked_m, token_to_expert_map = preprocess_after_permute(
            expert_first_token_offset_tensor, permuted_data_tensor)

-         m_max = (x.shape[0] + 127) // 128 * 128
        expected_m = (token_selected_experts.numel() +
                      self.expert_size_per_partition -
                      1) // self.expert_size_per_partition
-         act_input_fp8 = torch.empty(
-             (self.expert_size_per_partition, m_max, self.hidden_size),
-             dtype=torch.float8_e4m3fn,
-             device='cuda')
+
+         # padding and quantization
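+         # The FP8 activations and their scale factors are strided views into the preallocated workspace, not fresh allocations.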
+         m_max = fp8_utils.align(x.shape[0], 128)
+         act_input_fp8 = set_strides(workspace["workspace_0"],
+                                     self.expert_size_per_partition, m_max,
+                                     self.hidden_size)
+
+         m_padded = fp8_utils.align(m_max, 4)
+         scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+         scale_k_padded = fp8_utils.align(scale_k, 4)
+         act_input_sf = set_strides(workspace["workspace_sf"],
+                                    self.expert_size_per_partition,
+                                    scale_k_padded // 4, m_padded)
+
        act_input_sf = masked_index_copy_group_quant_fp8(
            act_input_fp8,
+             act_input_sf,
            permuted_data_tensor,
            expert_first_token_offset_tensor,
            token_to_expert_map,
            group_size=128)

-         h1 = deepgemm_fp8_group_blockwise_gemm(
+         # grouped gemm 1
+         h1 = set_strides(workspace["workspace_1"],
+                          self.expert_size_per_partition, m_max,
+                          self.intermediate_size * 2)
+
+         deepgemm_fp8_group_blockwise_gemm(
+             d=h1,
            a=act_input_fp8,
            b=self.w3_w1_weight,
            sfa=act_input_sf,
            sfb=self.quant_scales[0],
            masked_m=masked_m,
            expected_m=expected_m,
        )
-         act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-             input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
-         h3 = deepgemm_fp8_group_blockwise_gemm(
+
+         # activation and quantization
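+         # workspace_0 can be reused here: GEMM 1 has already consumed act_input_fp8, and its output h1 lives in workspace_1.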
+         act_input_fp8 = set_strides(workspace["workspace_0"],
+                                     self.expert_size_per_partition, m_max,
+                                     self.intermediate_size)
+
+         scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+         scale_k_padded = fp8_utils.align(scale_k, 4)
+         act_input_sf = set_strides(workspace["workspace_sf"],
+                                    self.expert_size_per_partition,
+                                    scale_k_padded // 4, m_padded)
+
+         act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
+             output=act_input_fp8,
+             output_scale=act_input_sf,
+             input=h1,
+             quant_group_size=128,
+             masked_m=masked_m,
+             scale_ue8m0=True)
+
+         # grouped gemm 2
+         h3 = set_strides(workspace["workspace_1"],
+                          self.expert_size_per_partition, m_max,
+                          self.hidden_size)
+
+         deepgemm_fp8_group_blockwise_gemm(
+             d=h3,
            a=act_input_fp8,
            b=self.w2_weight,
            sfa=act_input_sf,
@@ -471,6 +558,7 @@ def forward_chunk(
            expected_m=expected_m,
        )

+         # gather and finalize
        triton_masked_index_gather(permuted_data_tensor, h3,
                                   expert_first_token_offset_tensor,
                                   token_to_expert_map)
@@ -495,3 +583,137 @@ def forward_chunk(
        )

        return final_hidden_states
+
+     def forward(
+         self,
+         x: Union[torch.Tensor, Fp4QuantizedTensor],
+         router_logits: torch.Tensor,
+         do_finalize: bool = True,  # used by other MoE backends
+         output_dtype: Optional[torch.dtype] = None,
+         all_rank_num_tokens: Optional[List[int]] = None,
+         all_rank_max_num_tokens: Optional[int] = None,
+         use_dp_padding: Optional[bool] = None,
+     ) -> torch.Tensor:
+         assert do_finalize, "DeepGemmFusedMoE does not support do_finalize=False"
+         if self.use_dp and self.parallel_size > 1:
+             assert all_rank_num_tokens is not None
+             assert use_dp_padding is not None
+             num_rows = sum(all_rank_num_tokens)
+         else:
+             num_rows = x.shape[0]
+
+         # If num_rows exceeds moe_max_num_tokens * 2, split the input into multiple chunks,
+         # since chunked MoE runs on two streams and preallocates two workspaces.
+         num_chunks = 1
+         if num_rows > self.moe_max_num_tokens * 2:
+             num_chunks = (num_rows + self.moe_max_num_tokens -
+                           1) // self.moe_max_num_tokens
+
+         if use_dp_padding:
+             all_rank_num_tokens_padded = [all_rank_max_num_tokens
+                                           ] * len(all_rank_num_tokens)
+         else:
+             all_rank_num_tokens_padded = all_rank_num_tokens
+
+         if num_chunks == 1:
+             # create workspace
+             num_rows = x.shape[0]
+             if self.use_dp:
+                 num_rows = sum(all_rank_num_tokens_padded)
+             m_max = fp8_utils.align(num_rows, 128)
+             workspace = self.get_workspace(m_max, 128)
+             outputs = self.forward_chunk(
+                 x,
+                 router_logits,
+                 output_dtype,
+                 all_rank_num_tokens=all_rank_num_tokens_padded,
+                 use_dp_padding=use_dp_padding,
+                 workspace=workspace)
+             outputs = self.reducescatter_or_allreduce(
+                 outputs,
+                 all_rank_num_tokens=all_rank_num_tokens_padded,
+                 use_dp_padding=use_dp_padding)
+         else:
+             if self.use_dp:
+                 all_rank_chunk_size_list = [
+                     self.split_chunk(val, num_chunks)
+                     for val in all_rank_num_tokens_padded
+                 ]
+                 all_rank_num_tokens_list = [[
+                     val[idx_chunk] for val in all_rank_chunk_size_list
+                 ] for idx_chunk in range(num_chunks)]
+                 chunk_size_list = all_rank_chunk_size_list[self.rank]
+             else:
+                 all_rank_num_tokens_list = [None] * num_chunks
+                 chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
+
+             # create workspace
+             chunk_size_0 = sum(all_rank_num_tokens_list[0]
+                                ) if self.use_dp else chunk_size_list[0]
+             chunk_size_1 = sum(all_rank_num_tokens_list[1]
+                                ) if self.use_dp else chunk_size_list[1]
+             workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                              128)
+             workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
+                                              128)
+
+             x_list = x.split(chunk_size_list)
+             router_logits_list = router_logits.split(chunk_size_list)
+
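+             # Record on the main stream and make the aux stream wait, so chunk 0's inputs are ready before work starts on the aux stream.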
+             self.event_dict[EventType.Main].record()
+             with torch.cuda.stream(self.aux_stream):
+                 self.event_dict[EventType.Main].wait()
+
+             def _forward_chunk(x_, router_logits_, idx, workspace):
+                 return self.forward_chunk(
+                     x_,
+                     router_logits_,
+                     all_rank_num_tokens=all_rank_num_tokens_list[idx]
+                     if self.use_dp else None,
+                     use_dp_padding=use_dp_padding,
+                     workspace=workspace)
+
+             def _reducescatter_or_allreduce(x_, idx):
+                 return self.reducescatter_or_allreduce(
+                     x_,
+                     all_rank_num_tokens=all_rank_num_tokens_list[idx],
+                     use_dp_padding=use_dp_padding)
+
+             outputs_list = []
+             # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
+             for idx_chunk, (x, router_logits) in enumerate(
+                     zip(x_list, router_logits_list)):
+
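+                 # Even-indexed chunks run on the aux stream with workspace_0; odd-indexed chunks
+                 # run on the main stream with workspace_1, so the two in-flight chunks never share buffers.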
+                 if idx_chunk % 2 == 0:
+                     with torch.cuda.stream(self.aux_stream):
+                         outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                                  workspace_0)
+                     if idx_chunk > 0:
+                         outputs_list[-1] = _reducescatter_or_allreduce(
+                             outputs_list[-1], idx_chunk - 1)
+                 else:
+                     outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                              workspace_1)
+                     with torch.cuda.stream(self.aux_stream):
+                         outputs_list[-1] = _reducescatter_or_allreduce(
+                             outputs_list[-1], idx_chunk - 1)
+
+                 outputs_list.append(outputs)
+
+             if num_chunks % 2 == 0:
+                 outputs_list[-1] = _reducescatter_or_allreduce(
+                     outputs_list[-1], -1)
+             else:
+                 with torch.cuda.stream(self.aux_stream):
+                     outputs_list[-1] = _reducescatter_or_allreduce(
+                         outputs_list[-1], -1)
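+             # Record on the aux stream and wait on the main stream so every chunk's output is ready before the concat below.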
+             with torch.cuda.stream(self.aux_stream):
+                 self.event_dict[EventType.MoeChunkingOverlap].record()
+             self.event_dict[EventType.MoeChunkingOverlap].wait()
+
+             outputs = torch.cat(outputs_list)
+
+         if self.use_dp and self.parallel_size > 1:
+             rank = self.mapping.tp_rank
+             outputs = outputs[:all_rank_num_tokens[rank]]
+         return outputs