Commit d49374b

[TRTLLM-7408][feat] Wrap MOE with custom op. (#7277)
Signed-off-by: Jin Li <[email protected]>
1 parent a0e1604 commit d49374b

17 files changed: +272 additions, -118 deletions

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 16 deletions

@@ -121,12 +121,7 @@ class AttentionMetadata:
         default_factory=AttentionRuntimeFeatures)

     # The number of tokens in each rank.
-    _all_rank_num_tokens: Optional[List[int]] = field(init=False,
-                                                       default=None,
-                                                       repr=False)
-    all_rank_num_tokens: Optional[List[int]]
-    # The max number of tokens among all ranks.
-    all_rank_max_num_tokens: Optional[int] = None
+    all_rank_num_tokens: Optional[List[int]] = None

     # These fields are set when changing seq_lens and _num_contexts to avoid computation
     # during execution. If the calculation happens during execution, torch compile treats it
@@ -167,16 +162,6 @@ def on_update(self):
         elif self._seq_lens is not None:
             self._num_tokens = self._seq_lens.sum().item()

-    @property
-    def all_rank_num_tokens(self) -> Optional[List[int]]:
-        return self._all_rank_num_tokens
-
-    @all_rank_num_tokens.setter
-    def all_rank_num_tokens(self, value: Optional[List[int]]):
-        value = value if value is not AttentionMetadata.all_rank_num_tokens else None
-        self._all_rank_num_tokens = value
-        self.all_rank_max_num_tokens = max(value) if value is not None else None
-
     @property
     def seq_lens(self) -> Optional[torch.Tensor]:
         return self._seq_lens
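
Note (not part of the diff): the deleted property/setter pair cached all_rank_max_num_tokens whenever all_rank_num_tokens was assigned. After this change the metadata keeps only the per-rank token counts, and callers that need the maximum compute it on the spot, as the fused MoE backends below do with max(all_rank_num_tokens). A minimal illustrative sketch of the simplified shape (abbreviated names, not code from this commit):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class MetadataSketch:
        # The number of tokens in each rank.
        all_rank_num_tokens: Optional[List[int]] = None

    meta = MetadataSketch(all_rank_num_tokens=[7, 5, 9, 6])
    max_tokens = max(meta.all_rank_num_tokens)  # 9, derived where needed instead of cached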

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 35 additions & 17 deletions

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch

-from tensorrt_llm._torch.utils import (fp4_utils,
+from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
                                        last_positive_power_of_2,
                                        next_positive_power_of_2)
@@ -269,6 +269,31 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
     return kernel_runner(inputs, tactic=best_tactic)


+def fp4_block_scale_fake_output_without_finalize(
+        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
+        num_experts: int,
+        top_k: int,
+        routing_bias: Optional[torch.Tensor],
+):
+    num_tokens = hidden_states.shape[0]
+    hidden_size = hidden_states.shape[1] * (2 if isinstance(
+        hidden_states, Fp4QuantizedTensor) else 1)
+
+    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
+
+    expanded_row_count = num_tokens * top_k
+    max_padding_required = (tile_tokens_dim - 1) * num_experts
+    max_num_padded_tokens = fp4_utils.pad_up(
+        expanded_row_count + max_padding_required, tile_tokens_dim)
+    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
+    return [
+        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
+                                dtype=torch.bfloat16),
+        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
+        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
+    ]
+
+
 @fp4_block_scale_moe_runner.register_fake
 def _(
     routing_logits,
@@ -293,27 +318,20 @@ def _(
     routing_method_type,
     do_finalize,
 ) -> List[torch.Tensor]:
-    num_tokens = hidden_states.shape[0]
-    hidden_size = hidden_states.shape[1] * 2
     if do_finalize:
+        num_tokens = hidden_states.shape[0]
+        hidden_size = hidden_states.shape[1] * 2
         return [
             hidden_states.new_empty((num_tokens, hidden_size),
                                     dtype=torch.bfloat16)
         ]

-    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
-
-    expanded_row_count = num_tokens * top_k
-    max_padding_required = (tile_tokens_dim - 1) * num_experts
-    max_num_padded_tokens = fp4_utils.pad_up(
-        expanded_row_count + max_padding_required, tile_tokens_dim)
-    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
-    return [
-        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
-                                dtype=torch.bfloat16),
-        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
-        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
-    ]
+    return fp4_block_scale_fake_output_without_finalize(
+        hidden_states,
+        num_experts,
+        top_k,
+        routing_bias,
+    )


 @dataclass(frozen=True)
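
Background note (not from the diff): a register_fake implementation is the shape-and-dtype-only "meta" version of a custom op that PyTorch uses while tracing and under torch.compile, so it must allocate outputs of the sizes the real kernel would produce without launching anything. The refactor above factors the non-finalized output shapes into a helper so other call sites can reuse them. A minimal sketch of the custom-op + register_fake pattern, with hypothetical op names:

    import torch

    # Hypothetical op, used only to illustrate the pattern.
    @torch.library.custom_op("example::scale_rows", mutates_args=())
    def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    @scale_rows.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Only shapes and dtypes matter here; no real computation runs.
        return x.new_empty(x.shape)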

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 8 deletions

@@ -548,8 +548,7 @@ def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
             f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)

     def compute_routed_output(self, hidden_states, hidden_states_fp4,
-                              all_rank_num_tokens, all_rank_max_num_tokens,
-                              do_finalize):
+                              all_rank_num_tokens, do_finalize):
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
@@ -568,7 +567,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
         )

@@ -579,7 +577,6 @@ def forward(
         hidden_states: torch.Tensor,
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
         all_rank_num_tokens: Optional[list[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
@@ -598,7 +595,6 @@ def _compute_routed_output():
             routed_output = self.compute_routed_output(hidden_states,
                                                        hidden_states_fp4,
                                                        all_rank_num_tokens,
-                                                       all_rank_max_num_tokens,
                                                        do_finalize)
             return routed_output

@@ -840,7 +836,6 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
                 hidden_states,
                 hidden_states_fp4,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-                all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
                 final_all_reduce_params=AllReduceParams(
                     enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                           or self.mapping.tp_size == 1)),
@@ -1028,7 +1023,6 @@ def forward(
         embed_tokens: Embedding,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -1087,7 +1081,6 @@ def norm_hidden():
             hidden_states = self.mlp(
                 hidden_states,
                 all_rank_num_tokens=all_rank_num_tokens,
-                all_rank_max_num_tokens=all_rank_max_num_tokens,
                 final_all_reduce_params=AllReduceParams(
                     enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                           or self.mapping.tp_size == 1)),

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 4 additions & 7 deletions

@@ -258,7 +258,6 @@ def forward_attn_dp(

         # Get attention_dp parameters
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens

         if self.mapping.tp_size > 1 and all_rank_num_tokens is not None:
             if (isinstance(self.experts, (TRTLLMGenFusedMoE, TritonFusedMoE))):
@@ -276,12 +275,10 @@ def forward_attn_dp(

         # Let CutlassFusedMoE handle allgather internally
         # Pass the normalized tensor (t) as input to experts, not x
-        expert_output = self.experts(
-            x=t,
-            router_logits=g,
-            all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
-            use_dp_padding=False)
+        expert_output = self.experts(x=t,
+                                     router_logits=g,
+                                     all_rank_num_tokens=all_rank_num_tokens,
+                                     use_dp_padding=False)

         expert_output = expert_output.view(orig_shape)
         return expert_output, residual

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 12 deletions

@@ -315,32 +315,27 @@ def __init__(
         self.aux_stream = aux_stream

     def compute_routed_output(self, hidden_states, all_rank_num_tokens,
-                              all_rank_max_num_tokens,
                               cutlass_min_latency_mode):
         router_logits = self.router(hidden_states)
-        routed_output = self.experts(
-            hidden_states,
-            router_logits,
-            do_finalize=not cutlass_min_latency_mode,
-            all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
-            use_dp_padding=False)
+        routed_output = self.experts(hidden_states,
+                                     router_logits,
+                                     do_finalize=not cutlass_min_latency_mode,
+                                     all_rank_num_tokens=all_rank_num_tokens,
+                                     use_dp_padding=False)
         return routed_output

     def forward(
         self,
         hidden_states: torch.Tensor,
         all_rank_num_tokens=None,
-        all_rank_max_num_tokens=None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         cutlass_min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
         # Only enable multi-stream for cuda graph since switch stream has extra host overhead
         # This design is mainly for low latency use case. Need to improve for max throughput use case.
         fn0 = lambda: self.shared_expert(hidden_states)
         fn1 = lambda: self.compute_routed_output(
-            hidden_states, all_rank_num_tokens, all_rank_max_num_tokens,
-            cutlass_min_latency_mode)
+            hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
         shared_output, routed_output = maybe_execute_in_parallel(
             fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
         if cutlass_min_latency_mode:
@@ -542,7 +537,6 @@ def forward(
             hidden_states = self.feed_forward(
                 hidden_states,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-                all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
                 final_all_reduce_params=AllReduceParams(
                     enable_allreduce=not self.disable_feed_forward_allreduce),
                 cutlass_min_latency_mode=cutlass_min_latency_mode,

tensorrt_llm/_torch/models/modeling_mixtral.py

Lines changed: 0 additions & 2 deletions

@@ -62,13 +62,11 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
         router_logits = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=False)
         return final_hidden_states

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 0 additions & 2 deletions

@@ -127,7 +127,6 @@ def forward(
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         use_dp_padding = False
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens

         if not do_finalize:
             assert not self.enable_attention_dp
@@ -144,7 +143,6 @@
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
             do_finalize=do_finalize,
         )

tensorrt_llm/_torch/models/modeling_qwen_moe.py

Lines changed: 0 additions & 2 deletions

@@ -84,13 +84,11 @@ def forward(
         hidden_states = hidden_states.view(-1, self.hidden_dim)

         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
         router_logits = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=False)

         shared_expert_output = self.shared_expert(hidden_states)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 5 additions & 5 deletions

@@ -85,6 +85,7 @@ def __init__(
             swiglu_alpha=swiglu_alpha,
             swiglu_beta=swiglu_beta,
             swiglu_limit=swiglu_limit,
+            layer_idx=layer_idx,
         )

         # Store original hidden size before any potential padding
@@ -96,8 +97,6 @@ def __init__(
             self.intermediate_size_per_partition = (
                 (self.intermediate_size_per_partition + 127) // 128) * 128

-        self.layer_idx = layer_idx
-
         self.num_slots = self.num_experts
         self.expert_size_per_partition = self.num_experts // self.ep_size
         self.initial_global_assignments = [
@@ -449,15 +448,16 @@ def split_chunk(self, split_token_num: int, split_num_chunks: int):
                                          split_num_chunks - val_mod)
         return split_chunk_size_list

-    def forward(
+    def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        *,
         do_finalize: bool = True,  # used by other MoE backends
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
         if self.use_dp and self.parallel_size > 1:
@@ -472,7 +472,7 @@
                           1) // self.moe_max_num_tokens

         if use_dp_padding:
-            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+            all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
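
The rename from forward to forward_impl here (and in the DeepGEMM backend below), together with the keyword-only marker and **kwargs, lines up with the commit title: presumably a shared MoE base class now exposes a forward that dispatches through a registered custom op into the backend's forward_impl, though that base-class change is not part of this excerpt. A rough, hypothetical sketch of such a wrapping, not the actual TensorRT-LLM implementation:

    import torch

    # Hypothetical registry keyed by layer index so the op can find its module.
    _MOE_MODULES: dict = {}

    @torch.library.custom_op("example::moe_forward", mutates_args=())
    def moe_forward(x: torch.Tensor, router_logits: torch.Tensor,
                    layer_idx: int) -> torch.Tensor:
        # The real work stays in the module; the op is a graph-friendly wrapper.
        return _MOE_MODULES[layer_idx].forward_impl(x, router_logits)

    @moe_forward.register_fake
    def _(x: torch.Tensor, router_logits: torch.Tensor,
          layer_idx: int) -> torch.Tensor:
        return x.new_empty(x.shape)

    class MoEWrapperSketch(torch.nn.Module):

        def __init__(self, layer_idx: int):
            super().__init__()
            self.layer_idx = layer_idx
            _MOE_MODULES[layer_idx] = self

        def forward_impl(self, x, router_logits, **kwargs):
            # Placeholder for routing + grouped GEMMs; must not alias the input.
            return x.clone()

        def forward(self, x, router_logits, **kwargs):
            # The whole MoE shows up as one custom-op node to torch.compile / CUDA graphs.
            return moe_forward(x, router_logits, self.layer_idx)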

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 4 additions & 3 deletions

@@ -637,15 +637,16 @@ def forward_chunk(

         return final_hidden_states

-    def forward(
+    def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        *,
         do_finalize: bool = True,  # used by other MoE backends
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
         if self.use_dp and self.parallel_size > 1:
@@ -663,7 +664,7 @@
                           1) // self.moe_max_num_tokens

         if use_dp_padding:
-            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+            all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
