99
1010import tensorrt_llm
1111import tensorrt_llm .bindings .internal .runtime as _tbr
12- from tensorrt_llm ._torch .pyexecutor .cuda_graph_runner import is_graph_capturing
1312from tensorrt_llm .logger import logger
1413from tensorrt_llm .mapping import Mapping
1514
1615from ...distributed import AllReduce
1716from ...utils import EventType
17+ from ..multi_stream_utils import do_multi_stream
1818
1919
2020def _tensor_to_weight (t : torch .Tensor ) -> _tbr .MoeWeight :
@@ -472,7 +472,7 @@ def start_wait_gpu_stage(self):
472472 assert self .func_called_count ["start_wait_gpu_stage" ] == 0
473473 self .func_called_count ["start_wait_gpu_stage" ] += 1
474474 if self .updates_enabled :
475- if is_graph_capturing ():
475+ if do_multi_stream ():
476476 self .event_dict [EventType .Main ].record ()
477477 with torch .cuda .stream (self .aux_stream ):
478478 self .event_dict [EventType .Main ].wait ()
@@ -491,7 +491,7 @@ def done_wait_gpu_stage(self):
491491 assert self .func_called_count ["done_wait_gpu_stage" ] == 0
492492 self .func_called_count ["done_wait_gpu_stage" ] += 1
493493 if self .updates_enabled :
494- if is_graph_capturing ():
494+ if do_multi_stream ():
495495 self .event_dict [EventType .MoeBalancer ].wait ()
496496
497497 def start_set_cpu_stage (self ):
@@ -502,7 +502,7 @@ def start_set_cpu_stage(self):
502502 assert self .func_called_count ["start_set_cpu_stage" ] == 0
503503 self .func_called_count ["start_set_cpu_stage" ] += 1
504504 if self .updates_enabled :
505- if is_graph_capturing ():
505+ if do_multi_stream ():
506506 self .event_dict [EventType .Main ].record ()
507507 with torch .cuda .stream (self .aux_stream ):
508508 self .event_dict [EventType .Main ].wait ()
@@ -522,7 +522,7 @@ def done_set_cpu_stage(self):
522522 self .func_called_count [name ] = 0
523523 self .statistic_flag_tensor = None
524524 if self .updates_enabled :
525- if is_graph_capturing ():
525+ if do_multi_stream ():
526526 self .event_dict [EventType .MoeBalancer ].wait ()
527527
528528 def update_local_statistic (self , local_raw_expert_ids : torch .Tensor ,
@@ -544,7 +544,7 @@ def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
544544 (self .expert_count , ),
545545 dtype = torch .int32 ,
546546 device = torch .device ('cuda' ))
547- if is_graph_capturing ():
547+ if do_multi_stream ():
548548 self .event_dict [EventType .Main ].record ()
549549 with torch .cuda .stream (self .aux_stream ):
550550 self .event_dict [EventType .Main ].wait ()
@@ -569,7 +569,7 @@ def get_local_statistic_tensor(self) -> Optional[torch.Tensor]:
569569 assert self .func_called_count ["update_local_statistic" ] > 0
570570 self .func_called_count ["get_local_statistic_tensor" ] += 1
571571 if self .updates_enabled :
572- if is_graph_capturing ():
572+ if do_multi_stream ():
573573 with torch .cuda .stream (self .aux_stream ):
574574 self .event_dict [EventType .MoeBalancer ].record ()
575575 self .event_dict [EventType .MoeBalancer ].wait ()
@@ -598,7 +598,7 @@ def _update_statistic():
598598 self .single_layer_load_balancer_ptr )
599599
600600 if self .updates_enabled :
601- if is_graph_capturing ():
601+ if do_multi_stream ():
602602 self .event_dict [EventType .Main ].record ()
603603 with torch .cuda .stream (self .aux_stream ):
604604 self .event_dict [EventType .Main ].wait ()
@@ -636,7 +636,7 @@ def _update_statistic():
636636 if self .updates_enabled :
637637 self .update_local_statistic (local_raw_expert_ids , is_first_stage ,
638638 is_last_stage )
639- if is_graph_capturing ():
639+ if do_multi_stream ():
640640 with torch .cuda .stream (self .aux_stream ):
641641 _update_statistic ()
642642 else :
@@ -660,7 +660,7 @@ def update_statistic_with_global_ids(self,
660660 assert self .func_called_count ["update_statistic_with_local_ids" ] == 0
661661 self .func_called_count ["update_statistic_with_global_ids" ] += 1
662662 if self .updates_enabled :
663- if is_graph_capturing ():
663+ if do_multi_stream ():
664664 self .event_dict [EventType .Main ].record ()
665665 with torch .cuda .stream (self .aux_stream ):
666666 self .event_dict [EventType .Main ].wait ()
@@ -851,8 +851,8 @@ def set_warm_up_iter_count(self, iter_count: int):
851851 """
852852 self .load_balancer_impl .set_warm_up_iter_count (iter_count )
853853
854- def set_next_iter_info (self , enable_statistic : Optional [bool ],
855- enable_update_weights : Optional [bool ]):
854+ def set_iter_info (self , enable_statistic : Optional [bool ],
855+ enable_update_weights : Optional [bool ]):
856856 if enable_statistic is not None :
857857 self .enable_statistic = enable_statistic
858858 if enable_update_weights is not None :
@@ -998,8 +998,8 @@ def __enter__(self):
998998 """
999999 if self .moe_load_balancer is not None and not self .moe_load_balancer .is_static_routing (
10001000 ):
1001- self .moe_load_balancer .set_next_iter_info (self .enable_statistic ,
1002- self .enable_updates )
1001+ self .moe_load_balancer .set_iter_info (self .enable_statistic ,
1002+ self .enable_updates )
10031003 self .moe_load_balancer .start_iter ()
10041004 return self
10051005
0 commit comments