 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 
-from tensorrt_llm._torch.utils import (fp4_utils,
+from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
                                        last_positive_power_of_2,
                                        next_positive_power_of_2)
@@ -269,6 +269,31 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
     return kernel_runner(inputs, tactic=best_tactic)
 
 
+def fp4_block_scale_fake_output_without_finalize(
+        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
+        num_experts: int,
+        top_k: int,
+        routing_bias: Optional[torch.Tensor],
+):
+    num_tokens = hidden_states.shape[0]
+    hidden_size = hidden_states.shape[1] * (2 if isinstance(
+        hidden_states, Fp4QuantizedTensor) else 1)
+
+    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
+
+    expanded_row_count = num_tokens * top_k
+    max_padding_required = (tile_tokens_dim - 1) * num_experts
+    max_num_padded_tokens = fp4_utils.pad_up(
+        expanded_row_count + max_padding_required, tile_tokens_dim)
+    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
+    return [
+        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
+                                dtype=torch.bfloat16),
+        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
+        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
+    ]
+
+
 @fp4_block_scale_moe_runner.register_fake
 def _(
     routing_logits,
@@ -293,27 +318,20 @@ def _(
     routing_method_type,
     do_finalize,
 ) -> List[torch.Tensor]:
-    num_tokens = hidden_states.shape[0]
-    hidden_size = hidden_states.shape[1] * 2
     if do_finalize:
+        num_tokens = hidden_states.shape[0]
+        hidden_size = hidden_states.shape[1] * 2
         return [
             hidden_states.new_empty((num_tokens, hidden_size),
                                     dtype=torch.bfloat16)
         ]
 
-    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
-
-    expanded_row_count = num_tokens * top_k
-    max_padding_required = (tile_tokens_dim - 1) * num_experts
-    max_num_padded_tokens = fp4_utils.pad_up(
-        expanded_row_count + max_padding_required, tile_tokens_dim)
-    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
-    return [
-        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
-                                dtype=torch.bfloat16),
-        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
-        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
-    ]
+    return fp4_block_scale_fake_output_without_finalize(
+        hidden_states,
+        num_experts,
+        top_k,
+        routing_bias,
+    )
 
 
 @dataclass(frozen=True)
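
Note: the extracted `fp4_block_scale_fake_output_without_finalize` helper is a fake (meta) kernel body: it only allocates empty tensors with the shapes and dtypes the real MoE kernel would produce when `do_finalize=False`, so tracing and `torch.compile` can propagate shapes without running the kernel. As an illustrative sketch of that pattern only (the `demo::scale_rows` op and its names are hypothetical, not part of this commit or of TensorRT-LLM):

```python
import torch


@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: does actual work on real tensors.
    return x * factor


@scale_rows.register_fake
def _(x, factor):
    # Fake implementation: returns a shape/dtype-only placeholder,
    # analogous to the new_empty outputs returned by the registered
    # fake for fp4_block_scale_moe_runner above.
    return x.new_empty(x.shape)
```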