@@ -47,6 +47,9 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
47
47
self ._preprocessing_orchestrator : Optional [PreprocessingOrchestrator ] = None
48
48
49
49
self ._stream = None # set in install
50
+ # Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
51
+ self ._prepared_cache : Optional [Dict [str , Any ]] = None
52
+ self ._images_version : int = 0
50
53
51
54
# ---------- Public API (used by wrapper in a later step) ----------
52
55
def install (self , stream ) -> None :
@@ -63,6 +66,8 @@ def install(self, stream) -> None:
63
66
setattr (stream , 'controlnets' , self .controlnets )
64
67
setattr (stream , 'controlnet_scales' , self .controlnet_scales )
65
68
setattr (stream , 'preprocessors' , self .preprocessors )
69
+ # Reset caches on install
70
+ self ._prepared_cache = None
66
71
67
72
def add_controlnet (self , cfg : ControlNetConfig , control_image : Optional [Union [str , Any , torch .Tensor ]] = None ) -> None :
68
73
model = self ._load_pytorch_controlnet_model (cfg .model_id )
@@ -93,6 +98,18 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
93
98
except Exception :
94
99
pass
95
100
101
+ # Align preprocessor target size with stream resolution once (avoid double-resize later)
102
+ try :
103
+ if hasattr (preproc , 'params' ) and isinstance (getattr (preproc , 'params' ), dict ):
104
+ preproc .params ['image_width' ] = int (self ._stream .width )
105
+ preproc .params ['image_height' ] = int (self ._stream .height )
106
+ if hasattr (preproc , 'image_width' ):
107
+ setattr (preproc , 'image_width' , int (self ._stream .width ))
108
+ if hasattr (preproc , 'image_height' ):
109
+ setattr (preproc , 'image_height' , int (self ._stream .height ))
110
+ except Exception :
111
+ pass
112
+
96
113
image_tensor : Optional [torch .Tensor ] = None
97
114
if control_image is not None and self ._preprocessing_orchestrator is not None :
98
115
image_tensor = self ._prepare_control_image (control_image , preproc )
@@ -103,6 +120,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
103
120
self .controlnet_scales .append (float (cfg .conditioning_scale ))
104
121
self .preprocessors .append (preproc )
105
122
self .enabled_list .append (bool (cfg .enabled ))
123
+ # Invalidate prepared cache and bump version when graph changes
124
+ self ._prepared_cache = None
125
+ self ._images_version += 1
106
126
107
127
def update_control_image_efficient (self , control_image : Union [str , Any , torch .Tensor ], index : Optional [int ] = None ) -> None :
108
128
if self ._preprocessing_orchestrator is None :
@@ -134,6 +154,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
134
154
with self ._collections_lock :
135
155
if processed is not None and index < len (self .controlnet_images ):
136
156
self .controlnet_images [index ] = processed
157
+ # Invalidate prepared cache and bump version for per-frame reuse
158
+ self ._prepared_cache = None
159
+ self ._images_version += 1
137
160
return
138
161
139
162
# Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
@@ -154,6 +177,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
154
177
for i , img in enumerate (processed_images ):
155
178
if img is not None and i < len (self .controlnet_images ):
156
179
self .controlnet_images [i ] = img
180
+ # Invalidate prepared cache and bump version after bulk update
181
+ self ._prepared_cache = None
182
+ self ._images_version += 1
157
183
158
184
def update_controlnet_scale (self , index : int , scale : float ) -> None :
159
185
with self ._collections_lock :
@@ -177,6 +203,9 @@ def remove_controlnet(self, index: int) -> None:
177
203
del self .preprocessors [index ]
178
204
if index < len (self .enabled_list ):
179
205
del self .enabled_list [index ]
206
+ # Invalidate prepared cache and bump version
207
+ self ._prepared_cache = None
208
+ self ._images_version += 1
180
209
181
210
def reorder_controlnets_by_model_ids (self , desired_model_ids : List [str ]) -> None :
182
211
"""Reorder internal collections to match the desired model_id order.
@@ -295,7 +324,42 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
295
324
down_samples_list : List [List [torch .Tensor ]] = []
296
325
mid_samples_list : List [torch .Tensor ] = []
297
326
298
- for cn , img , scale in zip (active_controlnets , active_images , active_scales ):
327
+ # Prepare control images once per frame for current device/dtype/batch
328
+ try :
329
+ main_batch = x_t .shape [0 ]
330
+ cache_ok = (
331
+ isinstance (self ._prepared_cache , dict )
332
+ and self ._prepared_cache .get ('device' ) == x_t .device
333
+ and self ._prepared_cache .get ('dtype' ) == x_t .dtype
334
+ and self ._prepared_cache .get ('batch' ) == main_batch
335
+ and self ._prepared_cache .get ('version' ) == self ._images_version
336
+ )
337
+ if not cache_ok :
338
+ prepared : List [Optional [torch .Tensor ]] = [None ] * len (self .controlnet_images )
339
+ for i , base_img in enumerate (self .controlnet_images ):
340
+ if base_img is None :
341
+ continue
342
+ cur = base_img
343
+ if cur .dim () == 4 and cur .shape [0 ] != main_batch :
344
+ if cur .shape [0 ] == 1 :
345
+ cur = cur .repeat (main_batch , 1 , 1 , 1 )
346
+ else :
347
+ repeat_factor = max (1 , main_batch // cur .shape [0 ])
348
+ cur = cur .repeat (repeat_factor , 1 , 1 , 1 )
349
+ cur = cur .to (device = x_t .device , dtype = x_t .dtype )
350
+ prepared [i ] = cur
351
+ self ._prepared_cache = {
352
+ 'device' : x_t .device ,
353
+ 'dtype' : x_t .dtype ,
354
+ 'batch' : main_batch ,
355
+ 'version' : self ._images_version ,
356
+ 'prepared' : prepared ,
357
+ }
358
+ prepared_images : List [Optional [torch .Tensor ]] = self ._prepared_cache ['prepared' ] if self ._prepared_cache else [None ] * len (self .controlnet_images )
359
+ except Exception :
360
+ prepared_images = active_images # Fallback to per-step path if cache prep fails
361
+
362
+ for cn , img , scale , idx_i in zip (active_controlnets , active_images , active_scales , active_indices ):
299
363
# Swap to TRT engine if compiled and available for this model_id
300
364
try :
301
365
model_id = getattr (cn , 'model_id' , None )
@@ -304,22 +368,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
304
368
# Swapped to TRT engine
305
369
except Exception :
306
370
pass
307
- current_img = img
371
+ # Pull from prepared cache if available
372
+ current_img = prepared_images [idx_i ] if 'prepared_images' in locals () and prepared_images and idx_i < len (prepared_images ) and prepared_images [idx_i ] is not None else img
308
373
if current_img is None :
309
374
continue
310
- # Ensure control image batch matches latent batch for TRT engines
311
- try :
312
- main_batch = x_t .shape [0 ]
313
- if current_img .dim () == 4 and current_img .shape [0 ] != main_batch :
314
- if current_img .shape [0 ] == 1 :
315
- current_img = current_img .repeat (main_batch , 1 , 1 , 1 )
316
- else :
317
- repeat_factor = max (1 , main_batch // current_img .shape [0 ])
318
- current_img = current_img .repeat (repeat_factor , 1 , 1 , 1 )
319
- # Align device/dtype with latent for engine inputs
320
- current_img = current_img .to (device = x_t .device , dtype = x_t .dtype )
321
- except Exception :
322
- pass
323
375
kwargs = base_kwargs .copy ()
324
376
kwargs ['controlnet_cond' ] = current_img
325
377
kwargs ['conditioning_scale' ] = float (scale )
0 commit comments