
Commit b7c4180

code clean.
Signed-off-by: Fanrong Li <[email protected]>
1 parent 22665c2 commit b7c4180


1 file changed: 67 additions, 88 deletions


tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 67 additions & 88 deletions
@@ -284,6 +284,15 @@ def deepgemm_fp8_group_blockwise_gemm(
     return
 
 
+def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+    workspace = workspace[0:g * m * k]
+    workspace = workspace.as_strided(
+        size=(g, m, k),
+        stride=(m * k, k, 1),
+    )
+    return workspace
+
+
 class DeepGemmFusedMoE(CutlassFusedMoE):
     """
     Python Flow of Fused Mixture of Experts (MoE) Layer.
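For context, the new set_strides helper carves a zero-copy (g, m, k) view out of a flat workspace buffer; it replaces the repeated slice-plus-as_strided boilerplate removed further down. A minimal usage sketch (CPU tensor and toy sizes, purely illustrative):

import torch

def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
    # Trim the flat buffer to g*m*k elements and view it as (g, m, k) without copying.
    workspace = workspace[0:g * m * k]
    return workspace.as_strided(size=(g, m, k), stride=(m * k, k, 1))

buf = torch.empty(4096, dtype=torch.bfloat16)   # flat workspace (CPU here, just for illustration)
view = set_strides(buf, g=2, m=16, k=64)        # 2 * 16 * 64 = 2048 elements
assert view.shape == (2, 16, 64)
assert view.data_ptr() == buf.data_ptr()        # same storage: a view, not a copy

Because as_strided only reinterprets existing storage, these calls do not allocate new memory on the hot path.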
@@ -337,28 +346,26 @@ def __init__(
         )
 
     def get_workspace(self, m_max: int, group_size: int):
-        hidden_size_0 = max(self.hidden_size, self.w3_w1_weight.shape[1] // 2)
-        workspace_0 = torch.empty(
-            (self.expert_size_per_partition * m_max * hidden_size_0),
-            dtype=torch.float8_e4m3fn,
-            device='cuda')
-
-        max(self.w3_w1_weight.shape[1], self.w2_weight.shape[1])
+        hidden_size = self.hidden_size
+        intermediate_size = self.intermediate_size
+        num_experts = self.expert_size_per_partition
+
+        # create workspace
+        fp8_dim = max(hidden_size, intermediate_size)
+        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                  dtype=torch.float8_e4m3fn,
+                                  device='cuda')
         workspace_1 = torch.empty(
-            (self.expert_size_per_partition * m_max * self.hidden_size),
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
             dtype=torch.bfloat16,
             device='cuda')
 
-        alignment = 4
-        scale_dim = (self.hidden_size + group_size - 1) // group_size
-        padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
-        padded_col_size = (m_max + alignment - 1) // alignment * alignment
-        scale_k = (self.w3_w1_weight.shape[1] // 2 + group_size -
-                   1) // group_size
-        scale_k_padded = (scale_k + alignment - 1) // alignment * alignment
-        row_size = max(padded_dim_size // 4, scale_k_padded // 4)
+        # create workspace for scaling factors
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
         workspace_sf = torch.empty(
-            (self.expert_size_per_partition * row_size * padded_col_size),
+            (num_experts * (scale_k_padded // 4) * m_padded),
             dtype=torch.int32,
             device='cuda')
 
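The hand-rolled "(x + y - 1) // y" and round-up-to-multiple arithmetic is replaced by fp8_utils.ceil_div and fp8_utils.align. Assuming those helpers implement the usual ceiling division and alignment (an inference from the call sites in this diff, not a quote of fp8_utils), the math they stand in for is:

def ceil_div(x: int, y: int) -> int:
    # same as the old "(x + y - 1) // y" pattern
    return (x + y - 1) // y

def align(x: int, alignment: int) -> int:
    # round x up to the next multiple of `alignment`
    return ceil_div(x, alignment) * alignment

assert align(300, 128) == 384     # e.g. rounding a token count up to the 128-token tile
assert ceil_div(7168, 128) == 56  # e.g. number of 128-wide scale groups for a 7168 dim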
@@ -468,30 +475,20 @@ def forward_chunk(
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
-        # prepare workspace
-        m_max = (x.shape[0] + 127) // 128 * 128
-        act_input_fp8 = workspace["workspace_0"][0:self.
-                                                 expert_size_per_partition *
-                                                 m_max * self.hidden_size]
-        # act_input_fp8.view(self.expert_size_per_partition, m_max, self.hidden_size)
-        act_input_fp8 = act_input_fp8.as_strided(
-            size=(self.expert_size_per_partition, m_max, self.hidden_size),
-            stride=(m_max * self.hidden_size, self.hidden_size, 1),
-        )
-        alignment = 4
-        scale_dim = (self.hidden_size + 128 - 1) // 128
-        padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
-        padded_col_size = (m_max + alignment - 1) // alignment * alignment
-        act_input_sf = workspace["workspace_sf"][0:self.
-                                                 expert_size_per_partition *
-                                                 padded_dim_size // 4 *
-                                                 padded_col_size]
-        # act_input_sf.view(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size)
-        act_input_sf = act_input_sf.as_strided(
-            size=(self.expert_size_per_partition, padded_dim_size // 4,
-                  padded_col_size),
-            stride=(padded_dim_size // 4 * padded_col_size, padded_col_size, 1),
-        )
+
+        # padding and quantization
+        m_max = fp8_utils.align(x.shape[0], 128)
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.hidden_size)
+
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
         act_input_sf = masked_index_copy_group_quant_fp8(
             act_input_fp8,
             act_input_sf,
@@ -500,16 +497,11 @@ def forward_chunk(
             token_to_expert_map,
             group_size=128)
 
-        # prepare workspace
-        h1 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max *
-                                      self.w3_w1_weight.shape[1]]
-        # h1.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1])
-        h1 = h1.as_strided(
-            size=(self.expert_size_per_partition, m_max,
-                  self.w3_w1_weight.shape[1]),
-            stride=(m_max * self.w3_w1_weight.shape[1],
-                    self.w3_w1_weight.shape[1], 1),
-        )
+        # grouped gemm 1
+        h1 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.intermediate_size * 2)
+
         deepgemm_fp8_group_blockwise_gemm(
             d=h1,
             a=act_input_fp8,
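For readers unfamiliar with the masked grouped GEMM used here, below is a rough pure-PyTorch reference for the semantics implied by these call sites, with shapes inferred from this diff; it is not the DeepGEMM kernel and it ignores the FP8 block scales (sfa/sfb):

import torch

def masked_grouped_gemm_ref(a: torch.Tensor, b: torch.Tensor,
                            masked_m: torch.Tensor) -> torch.Tensor:
    # a: (num_experts, m_max, k) activations; b: (num_experts, n, k) weights.
    # Only the first masked_m[g] rows of each expert's slice hold real tokens.
    g, m_max, _ = a.shape
    n = b.shape[1]
    d = torch.zeros(g, m_max, n, dtype=a.dtype, device=a.device)
    for i in range(g):
        rows = int(masked_m[i])
        d[i, :rows] = a[i, :rows] @ b[i].transpose(0, 1)
    return d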
@@ -520,54 +512,41 @@ def forward_chunk(
             expected_m=expected_m,
         )
 
-        # prepare workspace
-        h2 = workspace["workspace_0"][0:self.expert_size_per_partition * m_max *
-                                      self.w3_w1_weight.shape[1] // 2]
-        # h2.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2)
-        h2 = h2.as_strided(
-            size=(self.expert_size_per_partition, m_max,
-                  self.w3_w1_weight.shape[1] // 2),
-            stride=(m_max * self.w3_w1_weight.shape[1] // 2,
-                    self.w3_w1_weight.shape[1] // 2, 1),
-        )
-        scale_k = (self.w3_w1_weight.shape[1] // 2 + 128 - 1) // 128
-        scale_k_padded = (scale_k + alignment - 1) // alignment * alignment
-        h2_sf = workspace["workspace_sf"][0:self.expert_size_per_partition *
-                                          scale_k_padded // 4 * padded_col_size]
-        # h2_sf.view(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size)
-        h2_sf = h2_sf.as_strided(
-            size=(self.expert_size_per_partition, scale_k_padded // 4,
-                  padded_col_size),
-            stride=(scale_k_padded // 4 * padded_col_size, padded_col_size, 1),
-        )
+        # activation and quantization
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.intermediate_size)
+
+        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
         act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-            output=h2,
-            output_scale=h2_sf,
+            output=act_input_fp8,
+            output_scale=act_input_sf,
             input=h1,
             quant_group_size=128,
             masked_m=masked_m,
             scale_ue8m0=True)
 
-        # prepare workspace
-        h3 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max *
-                                      self.w2_weight.shape[1]]
-        # h3.view(self.expert_size_per_partition, m_max, self.w2_weight.shape[1])
-        h3 = h3.as_strided(
-            size=(self.expert_size_per_partition, m_max,
-                  self.w2_weight.shape[1]),
-            stride=(m_max * self.w2_weight.shape[1], self.w2_weight.shape[1],
-                    1),
-        )
+        # grouped gemm 2
+        h3 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.hidden_size)
+
         deepgemm_fp8_group_blockwise_gemm(
             d=h3,
-            a=h2,
+            a=act_input_fp8,
             b=self.w2_weight,
             sfa=act_input_sf,
             sfb=self.quant_scales[1],
             masked_m=masked_m,
             expected_m=expected_m,
         )
 
+        # gather and finalize
         triton_masked_index_gather(permuted_data_tensor, h3,
                                    expert_first_token_offset_tensor,
                                    token_to_expert_map)
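Note the buffer reuse the rewrite makes explicit: workspace_0 holds the quantized GEMM-1 input and is then reused for the quantized SiLU output, while workspace_1 holds h1 and is later reused for h3. A small sketch of that zero-copy reuse with toy sizes (not the real model dimensions):

import torch

# One flat buffer serves two successive tensors of different shape.
buf = torch.empty(2 * 128 * 256, dtype=torch.bfloat16)  # stands in for one workspace
h1 = buf[:2 * 128 * 256].as_strided((2, 128, 256), (128 * 256, 256, 1))  # gemm-1 output view
h3 = buf[:2 * 128 * 128].as_strided((2, 128, 128), (128 * 128, 128, 1))  # later reused for gemm-2 output
assert h1.data_ptr() == h3.data_ptr()   # same storage; h3 overwrites part of h1's memory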
@@ -626,7 +605,7 @@ def forward(
         num_rows = x.shape[0]
         if self.use_dp:
             num_rows = sum(all_rank_num_tokens_padded)
-        m_max = (num_rows + 127) // 128 * 128
+        m_max = fp8_utils.align(num_rows, 128)
         workspace = self.get_workspace(m_max, 128)
         outputs = self.forward_chunk(
             x,
@@ -656,11 +635,11 @@ def forward(
         # create workspace
         chunk_size_0 = sum(all_rank_num_tokens_list[0]
                            ) if self.use_dp else chunk_size_list[0]
-        workspace_0 = self.get_workspace((chunk_size_0 + 127) // 128 * 128,
-                                         128)
         chunk_size_1 = sum(all_rank_num_tokens_list[1]
                            ) if self.use_dp else chunk_size_list[1]
-        workspace_1 = self.get_workspace((chunk_size_1 + 127) // 128 * 128,
+        workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                         128)
+        workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
                                          128)
 
         x_list = x.split(chunk_size_list)
