@@ -1,6 +1,6 @@
 import bisect
 import contextlib
-import weakref
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

 import torch
@@ -16,12 +16,36 @@
 from .scheduler import ScheduledRequests

 if TYPE_CHECKING:
-    from .model_engine import PyTorchModelEngine
+    from ..distributed import MPIDist
+    from ..mapping import Mapping
+    from ..speculative import DecodingBaseConfig

 # A large sentinel value used for dummy request IDs to avoid collisions
 CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1


+@dataclass
+class CUDAGraphRunnerConfig:
+    """Configuration for the CUDAGraphRunner, passed from the ModelEngine."""
+    use_cuda_graph: bool
+    cuda_graph_padding_enabled: bool
+    cuda_graph_batch_sizes: list[int]
+    max_cuda_graph_batch_size: int
+    max_beam_width: int
+    max_num_tokens: int
+    spec_config: Optional["DecodingBaseConfig"]
+    cuda_graph_mem_pool: Any
+    use_mrope: bool
+    original_max_draft_len: int
+    is_spec_decode: bool
+    is_draft_model: bool
+    enable_attention_dp: bool
+    batch_size: int
+    mapping: Optional["Mapping"]
+    dist: Optional["MPIDist"]
+    kv_cache_manager_key: Any
+
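+# A minimal construction sketch (illustrative only, not part of this change;
+# it assumes the engine still exposes the attributes the old code read):
+#
+#     config = CUDAGraphRunnerConfig(
+#         use_cuda_graph=engine.pytorch_backend_config.use_cuda_graph,
+#         cuda_graph_batch_sizes=engine._cuda_graph_batch_sizes,
+#         max_cuda_graph_batch_size=engine._max_cuda_graph_batch_size,
+#         ...
+#         kv_cache_manager_key=engine.kv_cache_manager_key)
+#     runner = CUDAGraphRunner(config)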
+
 class CUDAGraphRunner:
     """
     Manages the lifecycle and execution of CUDA graphs for the model engine.
@@ -32,23 +55,22 @@ class CUDAGraphRunner:
3255 """
     WARMUP_STEPS = 2

-    def __init__(self, engine: "PyTorchModelEngine"):
-        self.engine_ref = weakref.ref(engine)
+    def __init__(self, config: CUDAGraphRunnerConfig):
+        self.config = config

-        # High-level configuration
-        config = engine.pytorch_backend_config
+        # High-level configuration from the config object
         self.enabled = config.use_cuda_graph
         self.padding_enabled = config.cuda_graph_padding_enabled
-        self.supported_batch_sizes = engine._cuda_graph_batch_sizes
-        self.max_supported_batch_size = engine._max_cuda_graph_batch_size
-        self.max_beam_width = engine.max_beam_width
-        self.spec_config = engine.spec_config
+        self.supported_batch_sizes = config.cuda_graph_batch_sizes
+        self.max_supported_batch_size = config.max_cuda_graph_batch_size
+        self.max_beam_width = config.max_beam_width
+        self.spec_config = config.spec_config

         self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
         self.graph_outputs: Dict[Tuple[int, int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
-        self.memory_pool = engine._cuda_graph_mem_pool
+        self.memory_pool = config.cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None

         self.shared_static_tensors: Dict[str, torch.Tensor] = {}
@@ -58,12 +80,11 @@ def __init__(self, engine: "PyTorchModelEngine"):

     def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
-        engine = self._get_engine()
-
-        token_per_request = self.max_possible_draft_len + 1
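+        # Each request needs room for one real token plus up to max_draft_len
+        # speculative draft tokens per beam.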
+        max_draft_len = self.config.original_max_draft_len if self.config.is_spec_decode else 0
+        token_per_request = max_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
-        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
+        max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)

         self.shared_static_tensors = {
             "input_ids":
@@ -72,7 +93,7 @@ def _create_shared_static_tensors(self):
                 torch.zeros((1, max_total_tokens), device="cuda",
                             dtype=torch.int32),
         }
-        if engine.use_mrope:
+        if self.config.use_mrope:
             self.shared_static_tensors["position_ids"] = torch.zeros(
                 (3, 1, max_total_tokens), device="cuda", dtype=torch.int32)
             self.shared_static_tensors["multimodal_params"] = [
@@ -86,55 +107,31 @@ def _create_shared_static_tensors(self):
                 }) for _ in range(max_total_tokens)
             ]

-    @property
-    def enable_spec_decode(self):
-        return self._get_engine().enable_spec_decode
-
-    @property
-    def max_possible_draft_len(self):
-        engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
-
     def get_graph_key(
             self,
             batch_size,
+            enable_spec_decode: bool,
             spec_resource_manager: Optional[BaseResourceManager] = None):
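+        # Graphs are keyed by (batch_size, draft_len, is_first_draft); Eagle3
+        # draft models key the first draft pass separately because it runs
+        # with a different draft length.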
-        engine = self._get_engine()
-        if engine.is_draft_model and spec_resource_manager is not None and isinstance(
+        if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
-            draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
+            draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_draft_len if enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key

-    @property
-    def spec_metadata(self):
-        return self._get_engine().spec_metadata
-
-    @property
-    def draft_tokens_cuda(self):
-        return self._get_engine().draft_tokens_cuda
-
-    @property
-    def attn_metadata(self):
-        return self._get_engine().attn_metadata
-
     def __del__(self):
         self.clear()

-    def _get_engine(self) -> "PyTorchModelEngine":
-        """Safely dereferences the weak reference to the engine."""
-        engine = self.engine_ref()
-        if engine is None:
-            raise RuntimeError(
-                "The parent PyTorchModelEngine has been garbage collected.")
-        return engine
-
     def maybe_get_cuda_graph(
             self,
             batch: ScheduledRequests,
+            iter_counter: int,
+            enable_spec_decode: bool,
+            attn_metadata: Any,
+            spec_metadata: Optional[Any],
+            draft_tokens_cuda: torch.Tensor,
             spec_resource_manager: Optional[BaseResourceManager] = None):
139136 """
140137 Determines if the current batch can be run with a CUDA graph.
@@ -145,17 +142,14 @@ def maybe_get_cuda_graph(
         - The spec_metadata for the graph, if applicable.
         - The key for the graph.
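+
+        The runner no longer reads engine state: iter_counter,
+        enable_spec_decode, attn_metadata, spec_metadata, and
+        draft_tokens_cuda are supplied by the caller.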
147144 """
148- engine = self ._get_engine ()
149-
150145 # disable when doing statistic
151- if hasattr (engine , 'iter_counter' ) and ExpertStatistic .set_iter (
152- engine .iter_counter ):
146+ if ExpertStatistic .set_iter (iter_counter ):
153147 return False , None , None , None
154148
155149 can_run_cuda_graph = batch .can_run_cuda_graph
156150 batch_size = batch .batch_size
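+        # With attention DP, all TP ranks must agree to use the graph, so each
+        # rank shares whether its batch is graph-compatible.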
-        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
-            all_can_graph_batch = engine.dist.tp_allgather(
+        if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
+            all_can_graph_batch = self.config.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
             is_all_gen_only = all(all_can_graph[0]
                                   for all_can_graph in all_can_graph_batch)
@@ -168,7 +162,8 @@ def maybe_get_cuda_graph(

         if not self.enabled or not can_run_cuda_graph:
             return False, None, None, None
-        key = self.get_graph_key(batch_size, spec_resource_manager)
+        key = self.get_graph_key(batch_size, enable_spec_decode,
+                                 spec_resource_manager)

         if key in self.graphs:
             return True, self.graph_metadata[key][
@@ -178,29 +173,28 @@ def maybe_get_cuda_graph(
             return False, None, None, None

         num_sequences_in_batch = batch_size * self.max_beam_width
-        attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
+        graph_attn_metadata = attn_metadata.create_cuda_graph_metadata(
             num_sequences_in_batch, False, key[1], self.cuda_graph_meta_buffers)
-        assert attn_metadata.is_cuda_graph
+        assert graph_attn_metadata.is_cuda_graph

-        if self.enable_spec_decode:
-            spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
+        if enable_spec_decode:
+            graph_spec_metadata = spec_metadata.create_cuda_graph_metadata(
                 num_sequences_in_batch)
-            spec_metadata.draft_tokens = self.draft_tokens_cuda
+            graph_spec_metadata.draft_tokens = draft_tokens_cuda
         else:
-            spec_metadata = None
-        return True, attn_metadata, spec_metadata, key
+            graph_spec_metadata = None
+        return True, graph_attn_metadata, graph_spec_metadata, key

     def needs_capture(self, key: Tuple[int, int, int]):
-
         return key not in self.graph_outputs

     def capture(self,
                 key: Tuple[int, int, int],
                 forward_fn: Callable,
                 initial_inputs: Dict[str, Any],
+                enable_spec_decode: bool,
                 postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
-        engine = self._get_engine()
         batch_size = key[0]
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
@@ -217,7 +211,7 @@ def capture(self,
             self.shared_static_tensors["position_ids"]
             [:, :num_tokens_for_capture],
         }
-        if engine.use_mrope:
+        if self.config.use_mrope:
             sliced_static_tensors["position_ids"] = self.shared_static_tensors[
                 "position_ids"][:, :, :num_tokens_for_capture]
             sliced_static_tensors[
@@ -235,12 +229,10 @@ def capture(self,
         def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
                                              forward_fn: Callable,
                                              capture_inputs: Dict[str, Any]):
-            engine = self._get_engine()
-            # for the first inference of draft model, we need to set the use_spec_decoding to True when capture the graph for multiple runs.
             is_first_draft = key[2]
-            needs_kv_cache_recompute = True if engine.enable_spec_decode and engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
+            needs_kv_cache_recompute = True if enable_spec_decode and self.config.spec_config.spec_dec_mode.needs_kv_cache_recompute(
             ) else False
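+            # For the first inference of the draft model, capture the graph
+            # with use_spec_decoding=True so it can be replayed over multiple
+            # runs.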
-            if is_first_draft and engine.is_draft_model and needs_kv_cache_recompute:
+            if is_first_draft and self.config.is_draft_model and needs_kv_cache_recompute:
                 capture_inputs['attn_metadata'].use_spec_decoding = True
             return forward_fn(capture_inputs)

@@ -268,7 +260,6 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
     def replay(self, key: Tuple[int, int, int],
                current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
         """Replays a previously captured graph."""
-        engine = self._get_engine()
         stored_meta = self.graph_metadata[key]
         assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
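+        # Replay reuses the capture-time memory addresses, so the metadata
+        # objects must be the exact instances used during capture.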
         if stored_meta["spec_metadata"] is not None:
@@ -282,7 +273,7 @@ def replay(self, key: Tuple[int, int, int],
         static_tensors["input_ids"][:seqlen].copy_(input_ids)

         position_ids = current_inputs["position_ids"]
-        if engine.use_mrope and current_inputs.get(
+        if self.config.use_mrope and current_inputs.get(
                 'multimodal_params') is not None:
             static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids)
             for i, multimodal_param in enumerate(
@@ -302,16 +293,16 @@ def replay(self, key: Tuple[int, int, int],
         return output_ref

     def _get_padded_batch(self, batch: ScheduledRequests,
-                          resource_manager: ResourceManager) -> int:
-        engine = self._get_engine()
+                          resource_manager: ResourceManager,
+                          runtime_draft_len: int) -> int:
         kv_cache_manager = resource_manager.get_resource_manager(
-            engine.kv_cache_manager_key)
+            self.config.kv_cache_manager_key)
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = batch.batch_size
         new_batch_size = batch_size

-        if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
-            graph_batch_size = engine.dist.tp_allgather(
+        if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
+            graph_batch_size = self.config.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
             all_can_graph = all(graph_batch[0]
                                 for graph_batch in graph_batch_size)
@@ -329,7 +320,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             return 0

         padding_size = padded_batch_size - batch_size
-        if padding_size + batch.batch_size > engine.batch_size:
+        if padding_size + batch.batch_size > self.config.batch_size:
             return 0

         # No padding if it would create too many concurrent requests.
@@ -344,9 +335,9 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
                 [CUDA_GRAPH_DUMMY_REQUEST_ID],
                 is_gen=True,
-                max_num_draft_tokens=engine.runtime_draft_len,
-                use_mrope=engine.use_mrope,
-                max_beam_width=engine.max_beam_width)[0]
+                max_num_draft_tokens=runtime_draft_len,
+                use_mrope=self.config.use_mrope,
+                max_beam_width=self.config.max_beam_width)[0]
             self.padding_dummy_request.is_cuda_graph_dummy = True
             spec_res_mgr = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
@@ -368,11 +359,11 @@ def _round_up_batch_size(self, batch_size: int) -> int:

     @contextlib.contextmanager
     def pad_batch(self, scheduled_requests: ScheduledRequests,
-                  resource_manager: ResourceManager):
+                  resource_manager: ResourceManager, runtime_draft_len: int):
         """Context manager to pad a batch to a graph-compatible size."""
-
         padding_size = self._get_padded_batch(scheduled_requests,
-                                              resource_manager)
+                                              resource_manager,
+                                              runtime_draft_len)
         try:
             yield scheduled_requests
         finally: