@@ -284,6 +284,15 @@ def deepgemm_fp8_group_blockwise_gemm(
284284 return
285285
286286
def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
    """Carve a (g, m, k) strided view out of a flat 1-D workspace tensor.

    Takes the first ``g * m * k`` elements of ``workspace`` and exposes them
    as a contiguous 3-D view of shape (groups, rows, cols) without copying.

    Args:
        workspace: Flat 1-D backing tensor; must hold at least g*m*k elements.
        g: Number of groups (leading dimension).
        m: Rows per group.
        k: Columns per row.

    Returns:
        A torch.Tensor view of shape (g, m, k) with row-major strides
        (m*k, k, 1) over the front of ``workspace``.
    """
    flat = workspace[:g * m * k]
    return flat.as_strided(size=(g, m, k), stride=(m * k, k, 1))
294+
295+
287296class DeepGemmFusedMoE (CutlassFusedMoE ):
288297 """
289298 Python Flow of Fused Mixture of Experts (MoE) Layer.
@@ -337,28 +346,26 @@ def __init__(
337346 )
338347
339348 def get_workspace (self , m_max : int , group_size : int ):
340- hidden_size_0 = max (self .hidden_size , self .w3_w1_weight .shape [1 ] // 2 )
341- workspace_0 = torch .empty (
342- (self .expert_size_per_partition * m_max * hidden_size_0 ),
343- dtype = torch .float8_e4m3fn ,
344- device = 'cuda' )
345-
346- max (self .w3_w1_weight .shape [1 ], self .w2_weight .shape [1 ])
349+ hidden_size = self .hidden_size
350+ intermediate_size = self .intermediate_size
351+ num_experts = self .expert_size_per_partition
352+
353+ # create workspace
354+ fp8_dim = max (hidden_size , intermediate_size )
355+ workspace_0 = torch .empty ((num_experts * m_max * fp8_dim ),
356+ dtype = torch .float8_e4m3fn ,
357+ device = 'cuda' )
347358 workspace_1 = torch .empty (
348- (self . expert_size_per_partition * m_max * self . hidden_size ),
359+ (num_experts * m_max * max ( intermediate_size * 2 , hidden_size ) ),
349360 dtype = torch .bfloat16 ,
350361 device = 'cuda' )
351362
352- alignment = 4
353- scale_dim = (self .hidden_size + group_size - 1 ) // group_size
354- padded_dim_size = (scale_dim + alignment - 1 ) // alignment * alignment
355- padded_col_size = (m_max + alignment - 1 ) // alignment * alignment
356- scale_k = (self .w3_w1_weight .shape [1 ] // 2 + group_size -
357- 1 ) // group_size
358- scale_k_padded = (scale_k + alignment - 1 ) // alignment * alignment
359- row_size = max (padded_dim_size // 4 , scale_k_padded // 4 )
363+ # create workspace for scaling factors
364+ m_padded = fp8_utils .align (m_max , 4 )
365+ scale_k = fp8_utils .ceil_div (fp8_dim , group_size )
366+ scale_k_padded = fp8_utils .align (scale_k , 4 )
360367 workspace_sf = torch .empty (
361- (self . expert_size_per_partition * row_size * padded_col_size ),
368+ (num_experts * ( scale_k_padded // 4 ) * m_padded ),
362369 dtype = torch .int32 ,
363370 device = 'cuda' )
364371
@@ -468,30 +475,20 @@ def forward_chunk(
468475 expected_m = (token_selected_experts .numel () +
469476 self .expert_size_per_partition -
470477 1 ) // self .expert_size_per_partition
471- # prepare workspace
472- m_max = (x .shape [0 ] + 127 ) // 128 * 128
473- act_input_fp8 = workspace ["workspace_0" ][0 :self .
474- expert_size_per_partition *
475- m_max * self .hidden_size ]
476- # act_input_fp8.view(self.expert_size_per_partition, m_max, self.hidden_size)
477- act_input_fp8 = act_input_fp8 .as_strided (
478- size = (self .expert_size_per_partition , m_max , self .hidden_size ),
479- stride = (m_max * self .hidden_size , self .hidden_size , 1 ),
480- )
481- alignment = 4
482- scale_dim = (self .hidden_size + 128 - 1 ) // 128
483- padded_dim_size = (scale_dim + alignment - 1 ) // alignment * alignment
484- padded_col_size = (m_max + alignment - 1 ) // alignment * alignment
485- act_input_sf = workspace ["workspace_sf" ][0 :self .
486- expert_size_per_partition *
487- padded_dim_size // 4 *
488- padded_col_size ]
489- # act_input_sf.view(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size)
490- act_input_sf = act_input_sf .as_strided (
491- size = (self .expert_size_per_partition , padded_dim_size // 4 ,
492- padded_col_size ),
493- stride = (padded_dim_size // 4 * padded_col_size , padded_col_size , 1 ),
494- )
478+
479+ # padding and quantization
480+ m_max = fp8_utils .align (x .shape [0 ], 128 )
481+ act_input_fp8 = set_strides (workspace ["workspace_0" ],
482+ self .expert_size_per_partition , m_max ,
483+ self .hidden_size )
484+
485+ m_padded = fp8_utils .align (m_max , 4 )
486+ scale_k = fp8_utils .ceil_div (self .hidden_size , 128 )
487+ scale_k_padded = fp8_utils .align (scale_k , 4 )
488+ act_input_sf = set_strides (workspace ["workspace_sf" ],
489+ self .expert_size_per_partition ,
490+ scale_k_padded // 4 , m_padded )
491+
495492 act_input_sf = masked_index_copy_group_quant_fp8 (
496493 act_input_fp8 ,
497494 act_input_sf ,
@@ -500,16 +497,11 @@ def forward_chunk(
500497 token_to_expert_map ,
501498 group_size = 128 )
502499
503- # prepare workspace
504- h1 = workspace ["workspace_1" ][0 :self .expert_size_per_partition * m_max *
505- self .w3_w1_weight .shape [1 ]]
506- # h1.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1])
507- h1 = h1 .as_strided (
508- size = (self .expert_size_per_partition , m_max ,
509- self .w3_w1_weight .shape [1 ]),
510- stride = (m_max * self .w3_w1_weight .shape [1 ],
511- self .w3_w1_weight .shape [1 ], 1 ),
512- )
500+ # grouped gemm 1
501+ h1 = set_strides (workspace ["workspace_1" ],
502+ self .expert_size_per_partition , m_max ,
503+ self .intermediate_size * 2 )
504+
513505 deepgemm_fp8_group_blockwise_gemm (
514506 d = h1 ,
515507 a = act_input_fp8 ,
@@ -520,54 +512,41 @@ def forward_chunk(
520512 expected_m = expected_m ,
521513 )
522514
523- # prepare workspace
524- h2 = workspace ["workspace_0" ][0 :self .expert_size_per_partition * m_max *
525- self .w3_w1_weight .shape [1 ] // 2 ]
526- # h2.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2)
527- h2 = h2 .as_strided (
528- size = (self .expert_size_per_partition , m_max ,
529- self .w3_w1_weight .shape [1 ] // 2 ),
530- stride = (m_max * self .w3_w1_weight .shape [1 ] // 2 ,
531- self .w3_w1_weight .shape [1 ] // 2 , 1 ),
532- )
533- scale_k = (self .w3_w1_weight .shape [1 ] // 2 + 128 - 1 ) // 128
534- scale_k_padded = (scale_k + alignment - 1 ) // alignment * alignment
535- h2_sf = workspace ["workspace_sf" ][0 :self .expert_size_per_partition *
536- scale_k_padded // 4 * padded_col_size ]
537- # h2_sf.view(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size)
538- h2_sf = h2_sf .as_strided (
539- size = (self .expert_size_per_partition , scale_k_padded // 4 ,
540- padded_col_size ),
541- stride = (scale_k_padded // 4 * padded_col_size , padded_col_size , 1 ),
542- )
515+ # activation and quantization
516+ act_input_fp8 = set_strides (workspace ["workspace_0" ],
517+ self .expert_size_per_partition , m_max ,
518+ self .intermediate_size )
519+
520+ scale_k = fp8_utils .ceil_div (self .intermediate_size , 128 )
521+ scale_k_padded = fp8_utils .align (scale_k , 4 )
522+ act_input_sf = set_strides (workspace ["workspace_sf" ],
523+ self .expert_size_per_partition ,
524+ scale_k_padded // 4 , m_padded )
525+
543526 act_input_sf = fp8_utils .silu_and_mul_masked_post_quant_fwd (
544- output = h2 ,
545- output_scale = h2_sf ,
527+ output = act_input_fp8 ,
528+ output_scale = act_input_sf ,
546529 input = h1 ,
547530 quant_group_size = 128 ,
548531 masked_m = masked_m ,
549532 scale_ue8m0 = True )
550533
551- # prepare workspace
552- h3 = workspace ["workspace_1" ][0 :self .expert_size_per_partition * m_max *
553- self .w2_weight .shape [1 ]]
554- # h3.view(self.expert_size_per_partition, m_max, self.w2_weight.shape[1])
555- h3 = h3 .as_strided (
556- size = (self .expert_size_per_partition , m_max ,
557- self .w2_weight .shape [1 ]),
558- stride = (m_max * self .w2_weight .shape [1 ], self .w2_weight .shape [1 ],
559- 1 ),
560- )
534+ # grouped gemm 2
535+ h3 = set_strides (workspace ["workspace_1" ],
536+ self .expert_size_per_partition , m_max ,
537+ self .hidden_size )
538+
561539 deepgemm_fp8_group_blockwise_gemm (
562540 d = h3 ,
563- a = h2 ,
541+ a = act_input_fp8 ,
564542 b = self .w2_weight ,
565543 sfa = act_input_sf ,
566544 sfb = self .quant_scales [1 ],
567545 masked_m = masked_m ,
568546 expected_m = expected_m ,
569547 )
570548
549+ # gather and finalize
571550 triton_masked_index_gather (permuted_data_tensor , h3 ,
572551 expert_first_token_offset_tensor ,
573552 token_to_expert_map )
@@ -626,7 +605,7 @@ def forward(
626605 num_rows = x .shape [0 ]
627606 if self .use_dp :
628607 num_rows = sum (all_rank_num_tokens_padded )
629- m_max = (num_rows + 127 ) // 128 * 128
608+ m_max = fp8_utils . align (num_rows , 128 )
630609 workspace = self .get_workspace (m_max , 128 )
631610 outputs = self .forward_chunk (
632611 x ,
@@ -656,11 +635,11 @@ def forward(
656635 # create workspace
657636 chunk_size_0 = sum (all_rank_num_tokens_list [0 ]
658637 ) if self .use_dp else chunk_size_list [0 ]
659- workspace_0 = self .get_workspace ((chunk_size_0 + 127 ) // 128 * 128 ,
660- 128 )
661638 chunk_size_1 = sum (all_rank_num_tokens_list [1 ]
662639 ) if self .use_dp else chunk_size_list [1 ]
663- workspace_1 = self .get_workspace ((chunk_size_1 + 127 ) // 128 * 128 ,
640+ workspace_0 = self .get_workspace (fp8_utils .align (chunk_size_0 , 128 ),
641+ 128 )
642+ workspace_1 = self .get_workspace (fp8_utils .align (chunk_size_1 , 128 ),
664643 128 )
665644
666645 x_list = x .split (chunk_size_list )
0 commit comments