@@ -40,14 +40,42 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config
 
+        self.max_possible_draft_len = (self.spec_config.max_draft_len
+                                       if self.enable_spec_decode else 0)
+
         self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
-        self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {}
         self.graph_outputs: Dict[Tuple[int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
         self.memory_pool = engine._cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
+        self.shared_static_tensors: Dict[str, torch.Tensor] = {}
+        if self.enabled:
+            self._create_shared_static_tensors()
+
+    def _create_shared_static_tensors(self):
+        """Allocates static tensors sized for the largest possible batch."""
+        engine = self._get_engine()
+
+        token_per_request = self.max_possible_draft_len + 1
+        max_total_tokens = (self.max_supported_batch_size *
+                            self.max_beam_width * token_per_request)
+        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
+
+        self.shared_static_tensors = {
+            "input_ids":
+            torch.ones((max_total_tokens, ), device="cuda", dtype=torch.int32),
+            "position_ids":
+            torch.zeros((1, max_total_tokens), device="cuda",
+                        dtype=torch.int32),
+        }
+        if engine.use_mrope:
+            self.shared_static_tensors["mrope_position_deltas"] = torch.zeros(
+                (self.max_supported_batch_size, 1),
+                device="cuda",
+                dtype=torch.int32)
+
     @property
     def enable_spec_decode(self):
         return self._get_engine().is_spec_decode
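The core change: instead of allocating fresh `input_ids`/`position_ids` buffers per captured `(batch_size, draft_len)` key, the runner now makes one worst-case allocation up front and hands out views of it. A minimal standalone sketch of that pattern (the sizes are illustrative placeholders, not the engine's real values):

```python
import torch

# Hypothetical stand-ins for max_supported_batch_size, max_beam_width,
# max_draft_len and engine.max_num_tokens.
MAX_BATCH, BEAM, DRAFT_LEN, MAX_NUM_TOKENS = 128, 1, 3, 8192

tokens_per_request = DRAFT_LEN + 1
max_total_tokens = min(MAX_BATCH * BEAM * tokens_per_request, MAX_NUM_TOKENS)

# One worst-case allocation; every captured graph later uses views of these.
shared = {
    "input_ids": torch.ones(max_total_tokens, device="cuda", dtype=torch.int32),
    "position_ids": torch.zeros(1, max_total_tokens, device="cuda", dtype=torch.int32),
}
```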
@@ -139,38 +167,32 @@ def needs_capture(self, batch_size: int):
     def capture(self, batch_size: int, forward_fn: Callable,
                 initial_inputs: Dict[str, Any]):
         """Captures the forward pass for a given batch size."""
-        engine = self._get_engine()
         key = (batch_size, self.draft_len)
-        spec_metadata = initial_inputs.get("spec_metadata", None)
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
+        token_per_request = self.max_possible_draft_len + 1
+        num_tokens_for_capture = (batch_size * self.max_beam_width *
+                                  token_per_request)
 
-        static_tensors = {
+        sliced_static_tensors = {
             "input_ids":
-            torch.ones((batch_size * self.max_beam_width * token_per_request, ),
-                       device="cuda",
-                       dtype=torch.int32),
+            self.shared_static_tensors["input_ids"][:num_tokens_for_capture],
             "position_ids":
-            torch.zeros((
-                1,
-                batch_size * self.max_beam_width * token_per_request,
-            ),
-                        device="cuda",
-                        dtype=torch.int32),
+            self.shared_static_tensors["position_ids"]
+            [:, :num_tokens_for_capture],
         }
-        if engine.use_mrope:
-            static_tensors["mrope_position_deltas"] = torch.zeros(
-                (batch_size, 1), device="cuda", dtype=torch.int32)
-        self.static_inputs[key] = static_tensors
+        if "mrope_position_deltas" in self.shared_static_tensors:
+            sliced_static_tensors["mrope_position_deltas"] = \
+                self.shared_static_tensors["mrope_position_deltas"][:batch_size]
 
+        # Use the sliced tensors for capture
         capture_inputs = initial_inputs.copy()
-        capture_inputs.update(static_tensors)
+        capture_inputs.update(sliced_static_tensors)
 
         self.graph_metadata[key] = {
             "attn_metadata": initial_inputs["attn_metadata"],
-            "spec_metadata": spec_metadata,
+            "spec_metadata": initial_inputs.get("spec_metadata", None),
         }
 
         # We have to do warm up runs to initialize PyTorch's
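The slicing in `capture` works because basic indexing returns a view over the same CUDA storage: the addresses the graph records during capture belong to the shared allocation, so one host-side copy into that allocation feeds every graph. A self-contained sketch of capturing against a view with the public `torch.cuda.CUDAGraph` API (the warm-up on a side stream follows the PyTorch CUDA graphs docs; the names are illustrative):

```python
import torch

static_buf = torch.zeros(64, device="cuda")
view = static_buf[:16]  # a view: same storage as the full buffer

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out = view * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = view * 2  # capture records the view's (shared) addresses

# Replay sees whatever was last written into the shared buffer.
static_buf[:16].copy_(torch.arange(16., device="cuda"))
g.replay()
assert torch.equal(out, static_buf[:16] * 2)
```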
@@ -198,7 +220,7 @@ def replay(self, batch_size: int,
             assert current_inputs.get(
                 "spec_metadata") is stored_meta["spec_metadata"]
 
-        static_tensors = self.static_inputs[key]
+        static_tensors = self.shared_static_tensors
 
         input_ids = current_inputs["input_ids"]
         seqlen = input_ids.shape[0]
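On the replay side the diff only swaps where the static tensors come from; the per-step copies (outside this hunk) still write the live batch into the head of those buffers before replaying. A hedged sketch of that contract, with `replay_into_shared` as a hypothetical stand-in for the surrounding code, not the engine's actual method:

```python
def replay_into_shared(graph, shared, current_inputs):
    # Copy the live inputs into the head of the shared static buffers;
    # the captured graph reads from views over this same storage.
    input_ids = current_inputs["input_ids"]
    seqlen = input_ids.shape[0]
    shared["input_ids"][:seqlen].copy_(input_ids)
    shared["position_ids"][:, :seqlen].copy_(current_inputs["position_ids"])
    graph.replay()
```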
@@ -301,7 +323,6 @@ def clear(self):
         for graph in self.graphs.values():
             graph.reset()
         self.graphs.clear()
-        self.static_inputs.clear()
         self.graph_outputs.clear()
         self.graph_metadata.clear()
         self.padding_dummy_request = None
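With `static_inputs` gone, `clear()` no longer owns per-key buffers; a single worst-case allocation is amortized across every captured batch size. A back-of-envelope comparison under assumed sizes (illustrative numbers only, nothing here is measured):

```python
ITEMSIZE = 4                      # int32
MAX_NUM_TOKENS = 8192             # assumed engine.max_num_tokens
TOKENS_PER_REQUEST = 4            # e.g. max_draft_len = 3 -> 3 + 1
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]  # hypothetical capture list

# Before: each captured key allocated its own input_ids + position_ids.
per_key = sum(min(bs * TOKENS_PER_REQUEST, MAX_NUM_TOKENS) * 2 * ITEMSIZE
              for bs in BATCH_SIZES)
# After: one worst-case allocation shared by all keys.
shared = min(max(BATCH_SIZES) * TOKENS_PER_REQUEST, MAX_NUM_TOKENS) * 2 * ITEMSIZE

print(per_key, shared)  # 8160 B vs 4096 B for these toy sizes
```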