diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py
index d0a5fe58..ce74fd74 100644
--- a/src/streamdiffusion/modules/controlnet_module.py
+++ b/src/streamdiffusion/modules/controlnet_module.py
@@ -98,20 +98,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
         preproc = None
         if cfg.preprocessor:
             from streamdiffusion.preprocessing.processors import get_preprocessor
-            preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream)
-            # Apply provided parameters to the preprocessor instance
-            if cfg.preprocessor_params:
-                params = cfg.preprocessor_params or {}
-                # If the preprocessor exposes a 'params' dict, update it
-                if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict):
-                    preproc.params.update(params)
-                # Also set attributes directly when they exist
-                for name, value in params.items():
-                    try:
-                        if hasattr(preproc, name):
-                            setattr(preproc, name, value)
-                    except Exception:
-                        pass
+            # Pass all preprocessor params as constructor kwargs
+            preprocessor_params = cfg.preprocessor_params or {}
+            preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, **preprocessor_params)
 
         # Align preprocessor target size with stream resolution once (avoid double-resize later)
 
diff --git a/src/streamdiffusion/modules/image_processing_module.py b/src/streamdiffusion/modules/image_processing_module.py
index e3d7be66..ff21e24f 100644
--- a/src/streamdiffusion/modules/image_processing_module.py
+++ b/src/streamdiffusion/modules/image_processing_module.py
@@ -42,20 +42,9 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None:
         # Check if processor is enabled (default to True, same as ControlNet)
         enabled = proc_config.get('enabled', True)
 
-        # Create processor using existing registry (same as ControlNet)
-        processor = get_preprocessor(processor_type, pipeline_ref=getattr(self, '_stream', None))
-
-        # Apply parameters (same pattern as ControlNet)
+        # Pass all processor params as constructor kwargs
         processor_params = proc_config.get('params', {})
-        if processor_params:
-            if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict):
-                processor.params.update(processor_params)
-            for name, value in processor_params.items():
-                try:
-                    if hasattr(processor, name):
-                        setattr(processor, name, value)
-                except Exception:
-                    pass
+        processor = get_preprocessor(processor_type, pipeline_ref=getattr(self, '_stream', None), **processor_params)
 
         # Set order for sequential execution
         order = proc_config.get('order', len(self.processors))
diff --git a/src/streamdiffusion/modules/latent_processing_module.py b/src/streamdiffusion/modules/latent_processing_module.py
index f68cd512..7a4eebbc 100644
--- a/src/streamdiffusion/modules/latent_processing_module.py
+++ b/src/streamdiffusion/modules/latent_processing_module.py
@@ -41,20 +41,9 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None:
         # Check if processor is enabled (default to True, same as ControlNet)
         enabled = proc_config.get('enabled', True)
 
-        # Create processor using existing registry (same as ControlNet)
-        processor = get_preprocessor(processor_type, pipeline_ref=self._stream)
-
-        # Apply parameters (same pattern as ControlNet)
+        # Pass all processor params as constructor kwargs
        processor_params = proc_config.get('params', {})
-        if processor_params:
-            if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict):
-                processor.params.update(processor_params)
-            for name, value in processor_params.items():
-                try:
-                    if hasattr(processor, name):
-                        setattr(processor, name, value)
-                except Exception:
-                    pass
+        processor = get_preprocessor(processor_type, pipeline_ref=self._stream, **processor_params)
 
         # Set order for sequential execution
         order = proc_config.get('order', len(self.processors))
diff --git a/src/streamdiffusion/preprocessing/processors/__init__.py b/src/streamdiffusion/preprocessing/processors/__init__.py
index 2df6508e..91b8cb0b 100644
--- a/src/streamdiffusion/preprocessing/processors/__init__.py
+++ b/src/streamdiffusion/preprocessing/processors/__init__.py
@@ -111,13 +111,14 @@ def get_preprocessor_class(name: str) -> type:
     return _preprocessor_registry[name]
 
 
-def get_preprocessor(name: str, pipeline_ref: Any = None) -> BasePreprocessor:
+def get_preprocessor(name: str, pipeline_ref: Any = None, **constructor_kwargs) -> BasePreprocessor:
     """
     Get a preprocessor by name
 
     Args:
         name: Name of the preprocessor
         pipeline_ref: Pipeline reference for pipeline-aware processors (required for some processors)
+        **constructor_kwargs: Additional keyword arguments to pass to the processor constructor
 
     Returns:
         Preprocessor instance
@@ -131,9 +132,9 @@ def get_preprocessor(name: str, pipeline_ref: Any = None) -> BasePreprocessor:
     if hasattr(processor_class, 'requires_sync_processing') and processor_class.requires_sync_processing:
         if pipeline_ref is None:
             raise ValueError(f"Processor '{name}' requires a pipeline_ref")
-        return processor_class(pipeline_ref=pipeline_ref, _registry_name=name)
+        return processor_class(pipeline_ref=pipeline_ref, _registry_name=name, **constructor_kwargs)
     else:
-        return processor_class(_registry_name=name)
+        return processor_class(_registry_name=name, **constructor_kwargs)
 
 
 def register_preprocessor(name: str, preprocessor_class):
diff --git a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
index e7009274..003cf382 100644
--- a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
+++ b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
@@ -126,17 +126,27 @@ def infer(self, feed_dict, stream=None):
 
 
 class RealESRGANProcessor(BasePreprocessor):
     """
-    RealESRGAN 2x upscaling processor with automatic model download, ONNX export, and TensorRT acceleration.
+    RealESRGAN upscaling processor with automatic model download, ONNX export, and TensorRT acceleration.
+    Supports both 2x and 4x upscaling models.
     """
-    MODEL_URL = "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true"
+    MODEL_URLS = {
+        2: "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true",
+        4: "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth?download=true"
+    }
 
     @classmethod
     def get_preprocessor_metadata(cls):
         return {
-            "display_name": "RealESRGAN 2x",
-            "description": "High-quality 2x image upscaling using RealESRGAN with TensorRT acceleration",
+            "display_name": "RealESRGAN",
+            "description": "High-quality image upscaling using RealESRGAN with TensorRT acceleration",
             "parameters": {
+                "scale_factor": {
+                    "type": "int",
+                    "default": 2,
+                    "options": [2, 4],
+                    "description": "Upscaling factor (2x or 4x)"
+                },
                 "enable_tensorrt": {
                     "type": "bool",
                     "default": True,
@@ -148,21 +158,27 @@ def get_preprocessor_metadata(cls):
                     "description": "Force rebuild TensorRT engine even if it exists"
                 }
             },
-            "use_cases": ["High-quality upscaling", "Real-time 2x enlargement", "Image enhancement"]
+            "use_cases": ["High-quality upscaling", "Real-time enlargement", "Image enhancement"]
         }
 
-    def __init__(self, enable_tensorrt: bool = True, force_rebuild: bool = False, **kwargs):
-        super().__init__(enable_tensorrt=enable_tensorrt, force_rebuild=force_rebuild, **kwargs)
+    def __init__(self, scale_factor: int = 2, enable_tensorrt: bool = True, force_rebuild: bool = False, **kwargs):
+        super().__init__(scale_factor=scale_factor, enable_tensorrt=enable_tensorrt, force_rebuild=force_rebuild, **kwargs)
         self.enable_tensorrt = enable_tensorrt and TRT_AVAILABLE
         self.force_rebuild = force_rebuild
-        self.scale_factor = 2 # RealESRGAN 2x model
 
-        # Model paths
+        # Validate scale factor
+        if scale_factor not in self.MODEL_URLS:
+            available_scales = list(self.MODEL_URLS.keys())
+            raise ValueError(f"__init__: Unsupported scale_factor {scale_factor}. Available: {available_scales}")
+
+        self.scale_factor = scale_factor
+
+        # Model paths (scale-factor specific)
         self.models_dir = Path("models") / "realesrgan"
         self.models_dir.mkdir(parents=True, exist_ok=True)
-        self.model_path = self.models_dir / "RealESRGAN_x2.pth"
-        self.onnx_path = self.models_dir / "RealESRGAN_x2.onnx"
-        self.engine_path = self.models_dir / f"RealESRGAN_x2_{trt.__version__ if TRT_AVAILABLE else 'notrt'}.trt"
+        self.model_path = self.models_dir / f"RealESRGAN_x{scale_factor}.pth"
+        self.onnx_path = self.models_dir / f"RealESRGAN_x{scale_factor}.onnx"
+        self.engine_path = self.models_dir / f"RealESRGAN_x{scale_factor}_{trt.__version__ if TRT_AVAILABLE else 'notrt'}.trt"
 
         # Model state
         self.pytorch_model = None
@@ -218,7 +234,8 @@ def _ensure_model_ready(self):
         """Ensure PyTorch model is downloaded and loaded"""
         # Download model if needed
         if not self.model_path.exists():
-            self._download_file(self.MODEL_URL, self.model_path)
+            model_url = self.MODEL_URLS[self.scale_factor]
+            self._download_file(model_url, self.model_path)
 
         # Load PyTorch model
         if self.pytorch_model is None:
@@ -470,7 +487,7 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
         return output_tensor
 
     def get_target_dimensions(self) -> Tuple[int, int]:
-        """Get target output dimensions (width, height) - 2x upscaled"""
+        """Get target output dimensions (width, height) - upscaled by scale_factor"""
         width = self.params.get('image_width')
         height = self.params.get('image_height')