
Commit 3a31a67

cleaner controlnet caching
1 parent 5cf7b3e commit 3a31a67

2 files changed: +85 -46 lines changed


examples/controlnet/controlnet_video_test.py

Lines changed: 2 additions & 0 deletions

@@ -229,7 +229,9 @@ def main():
         print("main: Video processing completed successfully!")
         return 0
     except Exception as e:
+        import traceback
         print(f"main: Error during processing: {e}")
+        print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}")
         return 1

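For reference, a minimal, self-contained sketch of the stdlib pattern the test script now uses; the `might_fail` helper is hypothetical, standing in for the video-processing call:

import traceback

def might_fail() -> None:
    raise ValueError("boom")  # placeholder for the real processing call

try:
    might_fail()
except Exception as e:
    # format_tb returns one formatted string per stack frame; join them for logging
    tb_text = "".join(traceback.format_tb(e.__traceback__))
    print(f"main: Error during processing: {e}")
    print(f"main: Traceback:\n{tb_text}")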
src/streamdiffusion/modules/controlnet_module.py

Lines changed: 83 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
4949

5050
self._stream = None # set in install
5151
# Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
52-
self._prepared_cache: Optional[Dict[str, Any]] = None
52+
self._prepared_tensors: List[Optional[torch.Tensor]] = []
53+
self._prepared_device: Optional[torch.device] = None
54+
self._prepared_dtype: Optional[torch.dtype] = None
55+
self._prepared_batch: Optional[int] = None
5356
self._images_version: int = 0
5457

5558
# ---------- Public API (used by wrapper in a later step) ----------
@@ -66,8 +69,11 @@ def install(self, stream) -> None:
         setattr(stream, 'controlnets', self.controlnets)
         setattr(stream, 'controlnet_scales', self.controlnet_scales)
         setattr(stream, 'preprocessors', self.preprocessors)
-        # Reset caches on install
-        self._prepared_cache = None
+        # Reset prepared tensors on install
+        self._prepared_tensors = []
+        self._prepared_device = None
+        self._prepared_dtype = None
+        self._prepared_batch = None

     def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
         model = self._load_pytorch_controlnet_model(cfg.model_id)
@@ -120,8 +126,8 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
         self.controlnet_scales.append(float(cfg.conditioning_scale))
         self.preprocessors.append(preproc)
         self.enabled_list.append(bool(cfg.enabled))
-        # Invalidate prepared cache and bump version when graph changes
-        self._prepared_cache = None
+        # Invalidate prepared tensors and bump version when graph changes
+        self._prepared_tensors = []
         self._images_version += 1

     def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None:
@@ -154,9 +160,13 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
         with self._collections_lock:
             if processed is not None and index < len(self.controlnet_images):
                 self.controlnet_images[index] = processed
-                # Invalidate prepared cache and bump version for per-frame reuse
-                self._prepared_cache = None
+                # Invalidate prepared tensors and bump version for per-frame reuse
+                self._prepared_tensors = []
                 self._images_version += 1
+                # Pre-prepare tensors if we know the target specs
+                if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'):
+                    # Use default batch size of 1 for now, will be adjusted on first use
+                    self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1)
             return

         # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
@@ -178,8 +188,12 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
                 if img is not None and i < len(self.controlnet_images):
                     self.controlnet_images[i] = img
             # Invalidate prepared cache and bump version after bulk update
-            self._prepared_cache = None
+            self._prepared_tensors = []
             self._images_version += 1
+            # Pre-prepare tensors if we know the target specs
+            if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'):
+                # Use default batch size of 1 for now, will be adjusted on first use
+                self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1)

     def update_controlnet_scale(self, index: int, scale: float) -> None:
         with self._collections_lock:
@@ -203,8 +217,8 @@ def remove_controlnet(self, index: int) -> None:
                 del self.preprocessors[index]
             if index < len(self.enabled_list):
                 del self.enabled_list[index]
-            # Invalidate prepared cache and bump version
-            self._prepared_cache = None
+            # Invalidate prepared tensors and bump version
+            self._prepared_tensors = []
             self._images_version += 1

     def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None:
@@ -260,6 +274,54 @@ def get_current_config(self) -> List[Dict[str, Any]]:
             })
         return cfg

+    def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None:
+        """Prepare control image tensors for the current frame.
+
+        This method is called once per frame to prepare all control images with the correct
+        device, dtype, and batch size. This avoids redundant operations during each denoising step.
+
+        Args:
+            device: Target device for tensors
+            dtype: Target dtype for tensors
+            batch_size: Target batch size
+        """
+        with self._collections_lock:
+            # Check if we need to re-prepare tensors
+            cache_valid = (
+                self._prepared_device == device and
+                self._prepared_dtype == dtype and
+                self._prepared_batch == batch_size and
+                len(self._prepared_tensors) == len(self.controlnet_images)
+            )
+
+            if cache_valid:
+                return
+
+            # Prepare tensors for current frame
+            self._prepared_tensors = []
+            for img in self.controlnet_images:
+                if img is None:
+                    self._prepared_tensors.append(None)
+                    continue
+
+                # Prepare tensor with correct batch size
+                prepared = img
+                if prepared.dim() == 4 and prepared.shape[0] != batch_size:
+                    if prepared.shape[0] == 1:
+                        prepared = prepared.repeat(batch_size, 1, 1, 1)
+                    else:
+                        repeat_factor = max(1, batch_size // prepared.shape[0])
+                        prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size]
+
+                # Move to correct device and dtype
+                prepared = prepared.to(device=device, dtype=dtype)
+                self._prepared_tensors.append(prepared)
+
+            # Update cache state
+            self._prepared_device = device
+            self._prepared_dtype = dtype
+            self._prepared_batch = batch_size
+
     # ---------- Internal helpers ----------
     def build_unet_hook(self) -> UnetHook:
         def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
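As a quick aside, the batch-size rule in prepare_frame_tensors can be checked in isolation with plain torch; this sketch assumes 4-D control tensors of shape [B, C, H, W] and uses illustrative sizes only:

import torch

batch_size = 4

single = torch.rand(1, 3, 64, 64)                   # one control image
print(single.repeat(batch_size, 1, 1, 1).shape)     # torch.Size([4, 3, 64, 64])

pair = torch.rand(2, 3, 64, 64)                     # e.g. an already-batched pair
repeat_factor = max(1, batch_size // pair.shape[0])
print(pair.repeat(repeat_factor, 1, 1, 1)[:batch_size].shape)  # torch.Size([4, 3, 64, 64])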
@@ -324,40 +386,15 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
             down_samples_list: List[List[torch.Tensor]] = []
             mid_samples_list: List[torch.Tensor] = []

-            # Prepare control images once per frame for current device/dtype/batch
-            try:
-                main_batch = x_t.shape[0]
-                cache_ok = (
-                    isinstance(self._prepared_cache, dict)
-                    and self._prepared_cache.get('device') == x_t.device
-                    and self._prepared_cache.get('dtype') == x_t.dtype
-                    and self._prepared_cache.get('batch') == main_batch
-                    and self._prepared_cache.get('version') == self._images_version
-                )
-                if not cache_ok:
-                    prepared: List[Optional[torch.Tensor]] = [None] * len(self.controlnet_images)
-                    for i, base_img in enumerate(self.controlnet_images):
-                        if base_img is None:
-                            continue
-                        cur = base_img
-                        if cur.dim() == 4 and cur.shape[0] != main_batch:
-                            if cur.shape[0] == 1:
-                                cur = cur.repeat(main_batch, 1, 1, 1)
-                            else:
-                                repeat_factor = max(1, main_batch // cur.shape[0])
-                                cur = cur.repeat(repeat_factor, 1, 1, 1)
-                        cur = cur.to(device=x_t.device, dtype=x_t.dtype)
-                        prepared[i] = cur
-                    self._prepared_cache = {
-                        'device': x_t.device,
-                        'dtype': x_t.dtype,
-                        'batch': main_batch,
-                        'version': self._images_version,
-                        'prepared': prepared,
-                    }
-                prepared_images: List[Optional[torch.Tensor]] = self._prepared_cache['prepared'] if self._prepared_cache else [None] * len(self.controlnet_images)
-            except Exception:
-                prepared_images = active_images  # Fallback to per-step path if cache prep fails
+            # Ensure tensors are prepared for this frame
+            # This should have been called earlier, but we call it here as a safety net
+            if (self._prepared_device != x_t.device or
+                    self._prepared_dtype != x_t.dtype or
+                    self._prepared_batch != x_t.shape[0]):
+                self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0])
+
+            # Use pre-prepared tensors
+            prepared_images = self._prepared_tensors

             for cn, img, scale, idx_i in zip(active_controlnets, active_images, active_scales, active_indices):
                 # Swap to TRT engine if compiled and available for this model_id
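Taken together, the intended call order looks roughly like the sketch below; `module` stands in for the ControlNet module instance and the frame loop is a placeholder, not code from the repo:

import torch

def process_frame(module, x_t: torch.Tensor, num_steps: int) -> None:
    # Once per frame: align all control images with the latent's device/dtype/batch.
    module.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0])
    for _ in range(num_steps):
        # Each denoising step reuses module._prepared_tensors via the UNet hook;
        # the hook only re-prepares if device/dtype/batch changed mid-frame.
        pass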
@@ -368,8 +405,8 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
                     # Swapped to TRT engine
                 except Exception:
                     pass
-                # Pull from prepared cache if available
-                current_img = prepared_images[idx_i] if 'prepared_images' in locals() and prepared_images and idx_i < len(prepared_images) and prepared_images[idx_i] is not None else img
+                # Use pre-prepared tensor
+                current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img
                 if current_img is None:
                     continue
                 kwargs = base_kwargs.copy()
