
Commit b0fcdd2

cache control image per frame
1 parent 06fd221 commit b0fcdd2

File tree

1 file changed: +67 -15 lines changed


src/streamdiffusion/modules/controlnet_module.py

Lines changed: 67 additions & 15 deletions
@@ -47,6 +47,9 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
         self._preprocessing_orchestrator: Optional[PreprocessingOrchestrator] = None

         self._stream = None  # set in install
+        # Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
+        self._prepared_cache: Optional[Dict[str, Any]] = None
+        self._images_version: int = 0

     # ---------- Public API (used by wrapper in a later step) ----------
     def install(self, stream) -> None:
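The two fields added above form the cache key and its invalidation handle: `_prepared_cache` holds tensors prepared for a specific device/dtype/batch, and `_images_version` is bumped whenever the control images or the ControlNet list change. A minimal, self-contained sketch of the same pattern is below; the `FrameCache` class and its method names are illustrative, not part of `controlnet_module.py`, and the sketch assumes 1×C×H×W control tensors.

```python
# Illustrative sketch of the per-frame cache pattern used in this commit.
# FrameCache and its method names are hypothetical, not part of the module.
from typing import Any, Dict, List, Optional

import torch


class FrameCache:
    """Per-frame cache keyed on (device, dtype, batch, version)."""

    def __init__(self) -> None:
        self._cache: Optional[Dict[str, Any]] = None
        self._version: int = 0  # bumped whenever the source control images change

    def invalidate(self) -> None:
        # Mutators (add/update/remove) call this so the next frame rebuilds.
        self._cache = None
        self._version += 1

    def get_prepared(
        self,
        images: List[Optional[torch.Tensor]],  # assumed 1xCxHxW control tensors
        device: torch.device,
        dtype: torch.dtype,
        batch: int,
    ) -> List[Optional[torch.Tensor]]:
        cache_ok = (
            isinstance(self._cache, dict)
            and self._cache.get('device') == device
            and self._cache.get('dtype') == dtype
            and self._cache.get('batch') == batch
            and self._cache.get('version') == self._version
        )
        if not cache_ok:
            prepared = [
                img.repeat(batch, 1, 1, 1).to(device=device, dtype=dtype)
                if img is not None else None
                for img in images
            ]
            self._cache = {'device': device, 'dtype': dtype, 'batch': batch,
                           'version': self._version, 'prepared': prepared}
        return self._cache['prepared']


cache = FrameCache()
imgs = [torch.rand(1, 3, 64, 64), None]
a = cache.get_prepared(imgs, torch.device('cpu'), torch.float32, batch=2)  # builds
b = cache.get_prepared(imgs, torch.device('cpu'), torch.float32, batch=2)  # cache hit
assert a is b
cache.invalidate()  # e.g. after a control image update
c = cache.get_prepared(imgs, torch.device('cpu'), torch.float32, batch=2)  # rebuilt
assert c is not a
```

Keying on a version counter instead of hashing tensor contents keeps the cache check O(1) per step.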
@@ -63,6 +66,8 @@ def install(self, stream) -> None:
         setattr(stream, 'controlnets', self.controlnets)
         setattr(stream, 'controlnet_scales', self.controlnet_scales)
         setattr(stream, 'preprocessors', self.preprocessors)
+        # Reset caches on install
+        self._prepared_cache = None

     def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
         model = self._load_pytorch_controlnet_model(cfg.model_id)
@@ -93,6 +98,18 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
        except Exception:
            pass

+        # Align preprocessor target size with stream resolution once (avoid double-resize later)
+        try:
+            if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict):
+                preproc.params['image_width'] = int(self._stream.width)
+                preproc.params['image_height'] = int(self._stream.height)
+            if hasattr(preproc, 'image_width'):
+                setattr(preproc, 'image_width', int(self._stream.width))
+            if hasattr(preproc, 'image_height'):
+                setattr(preproc, 'image_height', int(self._stream.height))
+        except Exception:
+            pass
+
        image_tensor: Optional[torch.Tensor] = None
        if control_image is not None and self._preprocessing_orchestrator is not None:
            image_tensor = self._prepare_control_image(control_image, preproc)
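Writing the stream resolution into the preprocessor at registration time means the control image is resized to its final size once, rather than being resized by the preprocessor and again by the pipeline. A hedged sketch of that effect follows; `DepthPreprocessor`, its `process` method, and the stream resolution are hypothetical stand-ins chosen only to show the duck-typed `params` update, not names from the repository.

```python
# Hypothetical stand-in for a preprocessor; only the duck-typed attributes matter.
import torch
import torch.nn.functional as F


class DepthPreprocessor:
    def __init__(self) -> None:
        self.params = {'image_width': 512, 'image_height': 512}  # library defaults

    def process(self, frame: torch.Tensor) -> torch.Tensor:
        # Resize to the configured target once; if this already matches the
        # stream resolution, no second resize is needed downstream.
        h, w = self.params['image_height'], self.params['image_width']
        return F.interpolate(frame, size=(h, w), mode='bilinear', align_corners=False)


preproc = DepthPreprocessor()
# Mirror what add_controlnet() now does: align the target size with the stream once.
stream_width, stream_height = 768, 432  # example stream resolution
if hasattr(preproc, 'params') and isinstance(preproc.params, dict):
    preproc.params['image_width'] = int(stream_width)
    preproc.params['image_height'] = int(stream_height)

out = preproc.process(torch.rand(1, 3, 480, 640))
print(out.shape)  # torch.Size([1, 3, 432, 768])
```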
@@ -103,6 +120,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
            self.controlnet_scales.append(float(cfg.conditioning_scale))
            self.preprocessors.append(preproc)
            self.enabled_list.append(bool(cfg.enabled))
+            # Invalidate prepared cache and bump version when graph changes
+            self._prepared_cache = None
+            self._images_version += 1

    def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None:
        if self._preprocessing_orchestrator is None:
@@ -134,6 +154,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
            with self._collections_lock:
                if processed is not None and index < len(self.controlnet_images):
                    self.controlnet_images[index] = processed
+                    # Invalidate prepared cache and bump version for per-frame reuse
+                    self._prepared_cache = None
+                    self._images_version += 1
            return

        # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
@@ -154,6 +177,9 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
            for i, img in enumerate(processed_images):
                if img is not None and i < len(self.controlnet_images):
                    self.controlnet_images[i] = img
+                    # Invalidate prepared cache and bump version after bulk update
+                    self._prepared_cache = None
+                    self._images_version += 1

    def update_controlnet_scale(self, index: int, scale: float) -> None:
        with self._collections_lock:
@@ -177,6 +203,9 @@ def remove_controlnet(self, index: int) -> None:
            del self.preprocessors[index]
            if index < len(self.enabled_list):
                del self.enabled_list[index]
+            # Invalidate prepared cache and bump version
+            self._prepared_cache = None
+            self._images_version += 1

    def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None:
        """Reorder internal collections to match the desired model_id order.
@@ -295,7 +324,42 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
            down_samples_list: List[List[torch.Tensor]] = []
            mid_samples_list: List[torch.Tensor] = []

-            for cn, img, scale in zip(active_controlnets, active_images, active_scales):
+            # Prepare control images once per frame for current device/dtype/batch
+            try:
+                main_batch = x_t.shape[0]
+                cache_ok = (
+                    isinstance(self._prepared_cache, dict)
+                    and self._prepared_cache.get('device') == x_t.device
+                    and self._prepared_cache.get('dtype') == x_t.dtype
+                    and self._prepared_cache.get('batch') == main_batch
+                    and self._prepared_cache.get('version') == self._images_version
+                )
+                if not cache_ok:
+                    prepared: List[Optional[torch.Tensor]] = [None] * len(self.controlnet_images)
+                    for i, base_img in enumerate(self.controlnet_images):
+                        if base_img is None:
+                            continue
+                        cur = base_img
+                        if cur.dim() == 4 and cur.shape[0] != main_batch:
+                            if cur.shape[0] == 1:
+                                cur = cur.repeat(main_batch, 1, 1, 1)
+                            else:
+                                repeat_factor = max(1, main_batch // cur.shape[0])
+                                cur = cur.repeat(repeat_factor, 1, 1, 1)
+                        cur = cur.to(device=x_t.device, dtype=x_t.dtype)
+                        prepared[i] = cur
+                    self._prepared_cache = {
+                        'device': x_t.device,
+                        'dtype': x_t.dtype,
+                        'batch': main_batch,
+                        'version': self._images_version,
+                        'prepared': prepared,
+                    }
+                prepared_images: List[Optional[torch.Tensor]] = self._prepared_cache['prepared'] if self._prepared_cache else [None] * len(self.controlnet_images)
+            except Exception:
+                prepared_images = active_images  # Fallback to per-step path if cache prep fails
+
+            for cn, img, scale, idx_i in zip(active_controlnets, active_images, active_scales, active_indices):
                # Swap to TRT engine if compiled and available for this model_id
                try:
                    model_id = getattr(cn, 'model_id', None)
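The block above hoists the alignment work that previously ran inside the per-step loop: each control tensor is repeated to the latent batch (which includes CFG duplication) and cast to the latent's device/dtype once per frame, then reused until the cache is invalidated. A standalone example of that alignment is sketched below, with illustrative shapes and CPU/float32 standing in for the real device/dtype.

```python
# Worked example of the once-per-frame batch/device/dtype alignment.
# Shapes are illustrative; x_t stands for the latent batch seen by the UNet hook.
import torch

x_t = torch.randn(4, 4, 64, 64, device='cpu', dtype=torch.float32)  # e.g. 2 steps x CFG
control = torch.rand(1, 3, 512, 512)  # single prepared control image

main_batch = x_t.shape[0]
if control.dim() == 4 and control.shape[0] != main_batch:
    if control.shape[0] == 1:
        control = control.repeat(main_batch, 1, 1, 1)
    else:
        # If the image already carries a batch, repeat it a whole number of times.
        control = control.repeat(max(1, main_batch // control.shape[0]), 1, 1, 1)
control = control.to(device=x_t.device, dtype=x_t.dtype)

print(control.shape)  # torch.Size([4, 3, 512, 512])
print(control.dtype)  # matches x_t.dtype
```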
@@ -304,22 +368,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
                    # Swapped to TRT engine
                except Exception:
                    pass
-                current_img = img
+                # Pull from prepared cache if available
+                current_img = prepared_images[idx_i] if 'prepared_images' in locals() and prepared_images and idx_i < len(prepared_images) and prepared_images[idx_i] is not None else img
                if current_img is None:
                    continue
-                # Ensure control image batch matches latent batch for TRT engines
-                try:
-                    main_batch = x_t.shape[0]
-                    if current_img.dim() == 4 and current_img.shape[0] != main_batch:
-                        if current_img.shape[0] == 1:
-                            current_img = current_img.repeat(main_batch, 1, 1, 1)
-                        else:
-                            repeat_factor = max(1, main_batch // current_img.shape[0])
-                            current_img = current_img.repeat(repeat_factor, 1, 1, 1)
-                    # Align device/dtype with latent for engine inputs
-                    current_img = current_img.to(device=x_t.device, dtype=x_t.dtype)
-                except Exception:
-                    pass
                kwargs = base_kwargs.copy()
                kwargs['controlnet_cond'] = current_img
                kwargs['conditioning_scale'] = float(scale)
