Skip to content

Commit 9d8abf0

Browse files
reuse prepro orchestrator
1 parent b0fcdd2 commit 9d8abf0

File tree

4 files changed

+44
-14
lines changed

4 files changed

+44
-14
lines changed

src/streamdiffusion/modules/controlnet_module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from streamdiffusion.preprocessing.preprocessing_orchestrator import (
1313
PreprocessingOrchestrator,
1414
)
15+
from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser
1516

1617

1718
@dataclass
@@ -23,7 +24,7 @@ class ControlNetConfig:
2324
preprocessor_params: Optional[Dict[str, Any]] = None
2425

2526

26-
class ControlNetModule:
27+
class ControlNetModule(OrchestratorUser):
2728
"""ControlNet module that provides a UNet hook for residual conditioning.
2829
2930
Responsibilities in this step (3):
@@ -57,9 +58,8 @@ def install(self, stream) -> None:
5758
self.device = stream.device
5859
self.dtype = stream.dtype
5960
if self._preprocessing_orchestrator is None:
60-
self._preprocessing_orchestrator = PreprocessingOrchestrator(
61-
device=self.device, dtype=self.dtype, max_workers=4
62-
)
61+
# Enforce shared orchestrator via base helper (raises if missing)
62+
self.attach_orchestrator(stream)
6363
# Register UNet hook
6464
stream.unet_hooks.append(self.build_unet_hook())
6565
# Expose controlnet collections so existing updater can find them

src/streamdiffusion/modules/ipadapter_module.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from streamdiffusion.hooks import EmbedsCtx, EmbeddingHook, StepCtx, UnetKwargsDelta, UnetHook
88
import os
9+
from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser
910

1011

1112
@dataclass
@@ -23,7 +24,7 @@ class IPAdapterConfig:
2324
scale: float = 1.0
2425

2526

26-
class IPAdapterModule:
27+
class IPAdapterModule(OrchestratorUser):
2728
"""IP-Adapter embedding hook provider.
2829
2930
Produces an embedding hook that concatenates cached image tokens (from
@@ -98,6 +99,9 @@ def install(self, stream) -> None:
9899
logger = __import__('logging').getLogger(__name__)
99100
style_key = self.config.style_image_key or "ipadapter_main"
100101

102+
# Attach shared orchestrator to ensure consistent reuse across modules
103+
self.attach_orchestrator(stream)
104+
101105
# Validate required paths
102106
if not self.config.ipadapter_model_path or not self.config.image_encoder_path:
103107
raise ValueError("IPAdapterModule.install: ipadapter_model_path and image_encoder_path are required")
@@ -128,7 +132,7 @@ def install(self, stream) -> None:
128132
)
129133
self.ipadapter = ipadapter
130134

131-
# Register embedding preprocessor for this style key
135+
# Register embedding preprocessor for this style key
132136
embedding_preprocessor = IPAdapterEmbeddingPreprocessor(
133137
ipadapter=ipadapter,
134138
device=stream.device,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from .preprocessing_orchestrator import PreprocessingOrchestrator
6+
7+
8+
class OrchestratorUser:
9+
"""
10+
Minimal base class to attach a shared PreprocessingOrchestrator from the stream.
11+
No convenience methods; strictly enforces presence of a shared orchestrator on stream.
12+
"""
13+
14+
_preprocessing_orchestrator: Optional[PreprocessingOrchestrator] = None
15+
16+
def attach_orchestrator(self, stream) -> None:
17+
orchestrator = getattr(stream, 'preprocessing_orchestrator', None)
18+
if orchestrator is None:
19+
# Lazy-create on stream once, on first user that needs it
20+
orchestrator = PreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4)
21+
setattr(stream, 'preprocessing_orchestrator', orchestrator)
22+
self._preprocessing_orchestrator = orchestrator
23+
24+

src/streamdiffusion/stream_parameter_updater.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import logging
88
logger = logging.getLogger(__name__)
9+
from .preprocessing.orchestrator_user import OrchestratorUser
910

1011
class CacheStats:
1112
"""Helper class to track cache statistics"""
@@ -20,7 +21,7 @@ def record_miss(self):
2021
self.misses += 1
2122

2223

23-
class StreamParameterUpdater:
24+
class StreamParameterUpdater(OrchestratorUser):
2425
def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True):
2526
self.stream = stream_diffusion
2627
self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure
@@ -40,11 +41,15 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo
4041
self._seed_cache_stats = CacheStats()
4142

4243

44+
# Attach shared orchestrator once (lazy-creates on stream if absent)
45+
self.attach_orchestrator(self.stream)
46+
4347
# IPAdapter embedding preprocessing
4448
self._embedding_preprocessors = []
4549
self._embedding_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
4650
self._current_style_images: Dict[str, Any] = {}
47-
self._embedding_orchestrator = None
51+
# Use the shared orchestrator attached via OrchestratorUser
52+
self._embedding_orchestrator = self._preprocessing_orchestrator
4853
def get_cache_info(self) -> Dict:
4954
"""Get cache statistics for monitoring performance."""
5055
total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses
@@ -100,12 +105,9 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st
100105
style_image_key: Unique key for the style image this preprocessor handles
101106
"""
102107
if self._embedding_orchestrator is None:
103-
from .preprocessing.preprocessing_orchestrator import PreprocessingOrchestrator
104-
self._embedding_orchestrator = PreprocessingOrchestrator(
105-
device=self.stream.device,
106-
dtype=self.stream.dtype,
107-
max_workers=4
108-
)
108+
# Ensure orchestrator is present
109+
self.attach_orchestrator(self.stream)
110+
self._embedding_orchestrator = self._preprocessing_orchestrator
109111

110112
self._embedding_preprocessors.append((preprocessor, style_image_key))
111113

0 commit comments

Comments
 (0)