@@ -49,7 +49,10 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
49
49
50
50
self ._stream = None # set in install
51
51
# Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
52
- self ._prepared_cache : Optional [Dict [str , Any ]] = None
52
+ self ._prepared_tensors : List [Optional [torch .Tensor ]] = []
53
+ self ._prepared_device : Optional [torch .device ] = None
54
+ self ._prepared_dtype : Optional [torch .dtype ] = None
55
+ self ._prepared_batch : Optional [int ] = None
53
56
self ._images_version : int = 0
54
57
55
58
# ---------- Public API (used by wrapper in a later step) ----------
@@ -66,8 +69,11 @@ def install(self, stream) -> None:
66
69
setattr (stream , 'controlnets' , self .controlnets )
67
70
setattr (stream , 'controlnet_scales' , self .controlnet_scales )
68
71
setattr (stream , 'preprocessors' , self .preprocessors )
69
- # Reset caches on install
70
- self ._prepared_cache = None
72
+ # Reset prepared tensors on install
73
+ self ._prepared_tensors = []
74
+ self ._prepared_device = None
75
+ self ._prepared_dtype = None
76
+ self ._prepared_batch = None
71
77
72
78
def add_controlnet (self , cfg : ControlNetConfig , control_image : Optional [Union [str , Any , torch .Tensor ]] = None ) -> None :
73
79
model = self ._load_pytorch_controlnet_model (cfg .model_id )
@@ -120,8 +126,8 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
120
126
self .controlnet_scales .append (float (cfg .conditioning_scale ))
121
127
self .preprocessors .append (preproc )
122
128
self .enabled_list .append (bool (cfg .enabled ))
123
- # Invalidate prepared cache and bump version when graph changes
124
- self ._prepared_cache = None
129
+ # Invalidate prepared tensors and bump version when graph changes
130
+ self ._prepared_tensors = []
125
131
self ._images_version += 1
126
132
127
133
def update_control_image_efficient (self , control_image : Union [str , Any , torch .Tensor ], index : Optional [int ] = None ) -> None :
@@ -154,9 +160,13 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
154
160
with self ._collections_lock :
155
161
if processed is not None and index < len (self .controlnet_images ):
156
162
self .controlnet_images [index ] = processed
157
- # Invalidate prepared cache and bump version for per-frame reuse
158
- self ._prepared_cache = None
163
+ # Invalidate prepared tensors and bump version for per-frame reuse
164
+ self ._prepared_tensors = []
159
165
self ._images_version += 1
166
+ # Pre-prepare tensors if we know the target specs
167
+ if self ._stream and hasattr (self ._stream , 'device' ) and hasattr (self ._stream , 'dtype' ):
168
+ # Use default batch size of 1 for now, will be adjusted on first use
169
+ self .prepare_frame_tensors (self ._stream .device , self ._stream .dtype , 1 )
160
170
return
161
171
162
172
# Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
@@ -178,8 +188,12 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
178
188
if img is not None and i < len (self .controlnet_images ):
179
189
self .controlnet_images [i ] = img
180
190
# Invalidate prepared cache and bump version after bulk update
181
- self ._prepared_cache = None
191
+ self ._prepared_tensors = []
182
192
self ._images_version += 1
193
+ # Pre-prepare tensors if we know the target specs
194
+ if self ._stream and hasattr (self ._stream , 'device' ) and hasattr (self ._stream , 'dtype' ):
195
+ # Use default batch size of 1 for now, will be adjusted on first use
196
+ self .prepare_frame_tensors (self ._stream .device , self ._stream .dtype , 1 )
183
197
184
198
def update_controlnet_scale (self , index : int , scale : float ) -> None :
185
199
with self ._collections_lock :
@@ -203,8 +217,8 @@ def remove_controlnet(self, index: int) -> None:
203
217
del self .preprocessors [index ]
204
218
if index < len (self .enabled_list ):
205
219
del self .enabled_list [index ]
206
- # Invalidate prepared cache and bump version
207
- self ._prepared_cache = None
220
+ # Invalidate prepared tensors and bump version
221
+ self ._prepared_tensors = []
208
222
self ._images_version += 1
209
223
210
224
def reorder_controlnets_by_model_ids (self , desired_model_ids : List [str ]) -> None :
@@ -260,6 +274,54 @@ def get_current_config(self) -> List[Dict[str, Any]]:
260
274
})
261
275
return cfg
262
276
277
def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None:
    """Prepare control image tensors for the current frame.

    Called once per frame to align every control image with the target
    device, dtype, and batch size, so that each denoising step can reuse
    the cached tensors instead of repeating the alignment work.

    Args:
        device: Target device for the prepared tensors.
        dtype: Target dtype for the prepared tensors.
        batch_size: Target batch dimension (expected >= 1).
    """
    with self._collections_lock:
        # Fast path: cache already matches the requested specs and covers
        # every control image. Invalidation elsewhere clears
        # _prepared_tensors, which makes the length comparison fail here.
        cache_valid = (
            self._prepared_device == device
            and self._prepared_dtype == dtype
            and self._prepared_batch == batch_size
            and len(self._prepared_tensors) == len(self.controlnet_images)
        )
        if cache_valid:
            return

        # Build into a local list and publish once at the end, so an
        # exception mid-loop cannot leave a partially-populated cache.
        prepared_tensors: List[Optional[torch.Tensor]] = []
        for img in self.controlnet_images:
            if img is None:
                prepared_tensors.append(None)
                continue

            prepared = img
            # Align the batch dimension for 4-D (presumably NCHW) tensors;
            # other ranks are passed through untouched.
            if prepared.dim() == 4 and prepared.shape[0] != batch_size:
                if prepared.shape[0] == 1:
                    prepared = prepared.repeat(batch_size, 1, 1, 1)
                else:
                    # Ceiling division so the repeated tensor always has at
                    # least batch_size rows before slicing. Floor division
                    # (the previous behavior) came up short whenever
                    # batch_size was not a multiple of the source batch,
                    # e.g. batch 5 from source batch 2 produced only 4 rows.
                    repeat_factor = -(-batch_size // prepared.shape[0])
                    prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size]

            # Move to the target device/dtype last, after batch alignment.
            prepared_tensors.append(prepared.to(device=device, dtype=dtype))

        # Publish the prepared tensors and record the specs they match.
        self._prepared_tensors = prepared_tensors
        self._prepared_device = device
        self._prepared_dtype = dtype
        self._prepared_batch = batch_size
263
325
# ---------- Internal helpers ----------
264
326
def build_unet_hook (self ) -> UnetHook :
265
327
def _unet_hook (ctx : StepCtx ) -> UnetKwargsDelta :
@@ -324,40 +386,15 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
324
386
down_samples_list : List [List [torch .Tensor ]] = []
325
387
mid_samples_list : List [torch .Tensor ] = []
326
388
327
- # Prepare control images once per frame for current device/dtype/batch
328
- try :
329
- main_batch = x_t .shape [0 ]
330
- cache_ok = (
331
- isinstance (self ._prepared_cache , dict )
332
- and self ._prepared_cache .get ('device' ) == x_t .device
333
- and self ._prepared_cache .get ('dtype' ) == x_t .dtype
334
- and self ._prepared_cache .get ('batch' ) == main_batch
335
- and self ._prepared_cache .get ('version' ) == self ._images_version
336
- )
337
- if not cache_ok :
338
- prepared : List [Optional [torch .Tensor ]] = [None ] * len (self .controlnet_images )
339
- for i , base_img in enumerate (self .controlnet_images ):
340
- if base_img is None :
341
- continue
342
- cur = base_img
343
- if cur .dim () == 4 and cur .shape [0 ] != main_batch :
344
- if cur .shape [0 ] == 1 :
345
- cur = cur .repeat (main_batch , 1 , 1 , 1 )
346
- else :
347
- repeat_factor = max (1 , main_batch // cur .shape [0 ])
348
- cur = cur .repeat (repeat_factor , 1 , 1 , 1 )
349
- cur = cur .to (device = x_t .device , dtype = x_t .dtype )
350
- prepared [i ] = cur
351
- self ._prepared_cache = {
352
- 'device' : x_t .device ,
353
- 'dtype' : x_t .dtype ,
354
- 'batch' : main_batch ,
355
- 'version' : self ._images_version ,
356
- 'prepared' : prepared ,
357
- }
358
- prepared_images : List [Optional [torch .Tensor ]] = self ._prepared_cache ['prepared' ] if self ._prepared_cache else [None ] * len (self .controlnet_images )
359
- except Exception :
360
- prepared_images = active_images # Fallback to per-step path if cache prep fails
389
+ # Ensure tensors are prepared for this frame
390
+ # This should have been called earlier, but we call it here as a safety net
391
+ if (self ._prepared_device != x_t .device or
392
+ self ._prepared_dtype != x_t .dtype or
393
+ self ._prepared_batch != x_t .shape [0 ]):
394
+ self .prepare_frame_tensors (x_t .device , x_t .dtype , x_t .shape [0 ])
395
+
396
+ # Use pre-prepared tensors
397
+ prepared_images = self ._prepared_tensors
361
398
362
399
for cn , img , scale , idx_i in zip (active_controlnets , active_images , active_scales , active_indices ):
363
400
# Swap to TRT engine if compiled and available for this model_id
@@ -368,8 +405,8 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
368
405
# Swapped to TRT engine
369
406
except Exception :
370
407
pass
371
- # Pull from prepared cache if available
372
- current_img = prepared_images [idx_i ] if 'prepared_images' in locals () and prepared_images and idx_i < len (prepared_images ) and prepared_images [ idx_i ] is not None else img
408
+ # Use pre-prepared tensor
409
+ current_img = prepared_images [idx_i ] if idx_i < len (prepared_images ) else img
373
410
if current_img is None :
374
411
continue
375
412
kwargs = base_kwargs .copy ()
0 commit comments