Skip to content

Commit c061ae1

Browse files
disable add on parallel
1 parent b5d3d71 commit c061ae1

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

src/streamdiffusion/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]:
124124
'normalize_prompt_weights': config.get('normalize_prompt_weights', True),
125125
'normalize_seed_weights': config.get('normalize_seed_weights', True),
126126
'enable_pytorch_fallback': config.get('enable_pytorch_fallback', False),
127+
# Concurrency options
128+
'controlnet_max_parallel': config.get('controlnet_max_parallel'),
129+
'controlnet_block_add_when_parallel': config.get('controlnet_block_add_when_parallel', True),
127130
}
128131
if 'controlnets' in config and config['controlnets']:
129132
param_map['use_controlnet'] = True

src/streamdiffusion/stream_parameter_updater.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,22 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non
10021002

10031003
if existing_index is None:
10041004
# Add new controlnet
1005+
# Respect wrapper/init configuration: block adds when parallel enabled
1006+
try:
1007+
block_add = bool(getattr(self.stream, 'controlnet_block_add_when_parallel', True))
1008+
except Exception:
1009+
block_add = True
1010+
concurrency_active = False
1011+
try:
1012+
cn_module = getattr(self.stream, '_controlnet_module', None)
1013+
if cn_module is not None:
1014+
max_par = int(getattr(cn_module, '_max_parallel_controlnets', 0))
1015+
concurrency_active = max_par > 1
1016+
except Exception:
1017+
concurrency_active = False
1018+
if block_add and concurrency_active:
1019+
logger.warning(f"_update_controlnet_config: Add blocked by configuration while parallel ControlNet is active; skipping add for {model_id}")
1020+
continue
10051021
logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}")
10061022
try:
10071023
# Prefer module path: construct ControlNetConfig

src/streamdiffusion/wrapper.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def __init__(
110110
# IPAdapter options
111111
use_ipadapter: bool = False,
112112
ipadapter_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
113+
# Concurrency options
114+
controlnet_max_parallel: Optional[int] = None,
115+
controlnet_block_add_when_parallel: bool = True,
113116
):
114117
"""
115118
Initializes the StreamDiffusionWrapper.
@@ -198,6 +201,9 @@ def __init__(
198201
self.enable_pytorch_fallback = enable_pytorch_fallback
199202
self.use_ipadapter = use_ipadapter
200203
self.ipadapter_config = ipadapter_config
204+
# Concurrency settings
205+
self.controlnet_max_parallel = controlnet_max_parallel
206+
self.controlnet_block_add_when_parallel = controlnet_block_add_when_parallel
201207

202208
if mode == "txt2img":
203209
if cfg_type != "none":
@@ -1482,6 +1488,17 @@ def _load_model(
14821488
from streamdiffusion.modules.controlnet_module import ControlNetModule, ControlNetConfig
14831489
cn_module = ControlNetModule(device=self.device, dtype=self.dtype)
14841490
cn_module.install(stream)
1491+
# Apply configured max parallel if provided
1492+
try:
1493+
if self.controlnet_max_parallel is not None:
1494+
setattr(cn_module, '_max_parallel_controlnets', int(self.controlnet_max_parallel))
1495+
except Exception:
1496+
pass
1497+
# Expose add-blocking policy on stream
1498+
try:
1499+
setattr(stream, 'controlnet_block_add_when_parallel', bool(self.controlnet_block_add_when_parallel))
1500+
except Exception:
1501+
pass
14851502
# Normalize to list of configs
14861503
configs = controlnet_config if isinstance(controlnet_config, list) else [controlnet_config]
14871504
for cfg in configs:

0 commit comments

Comments
 (0)