diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example
index 7505d157..236375e5 100644
--- a/configs/sd15_multicontrol.yaml.example
+++ b/configs/sd15_multicontrol.yaml.example
@@ -32,11 +32,19 @@ seed: 789
 frame_buffer_size: 1
 delta: 0.7
 use_denoising_batch: true
-use_lcm_lora: true
+# LoRA configuration - use lora_dict to load LCM LoRA and other LoRAs
+lora_dict:
+  "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference
+  # Add other LoRAs here:
+  # "your_custom_lora": 0.7
+
 use_tiny_vae: true
 acceleration: "tensorrt" # "xformers" for non-TensorRT setups
 cfg_type: "self"
+scheduler: "lcm" # Supports "lcm" or "tcd"
+sampler: "normal"
+
 # Engine directory for TensorRT (engines will be built here if not found)
 engine_dir: "./engines/sd15"
diff --git a/configs/sdturbo_multicontrol.yaml.example b/configs/sdturbo_multicontrol.yaml.example
index 5f7b8561..a0fe0ce0 100644
--- a/configs/sdturbo_multicontrol.yaml.example
+++ b/configs/sdturbo_multicontrol.yaml.example
@@ -22,11 +22,19 @@ seed: 789
 frame_buffer_size: 1
 delta: 0.7
 use_denoising_batch: true
-use_lcm_lora: true # SD-Turbo benefits from LCM LoRA
+# LoRA configuration - SD-Turbo can benefit from LCM LoRA
+lora_dict:
+  "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference
+  # Add other LoRAs here:
+  # "your_custom_lora": 0.7
+
 use_tiny_vae: true
 acceleration: "tensorrt" # "xformers" for non-TensorRT setups
 cfg_type: "self"
+scheduler: "lcm" # Supports "lcm" or "tcd"
+sampler: "normal"
+
 # Engine directory for TensorRT
 engine_dir: "./engines/sdturbo"
diff --git a/configs/sdxl_multicontrol.yaml.example b/configs/sdxl_multicontrol.yaml.example
index f32f4dab..59116ffd 100644
--- a/configs/sdxl_multicontrol.yaml.example
+++ b/configs/sdxl_multicontrol.yaml.example
@@ -31,11 +31,20 @@ seed: 42 # Base seed (used with seed_blending above)
 frame_buffer_size: 1
 delta: 0.7
 use_denoising_batch: true
-use_lcm_lora: false # SDXL has built-in optimizations
+# LoRA configuration - SDXL can use LCM LoRA for faster inference
+# lora_dict:
+#   "latent-consistency/lcm-lora-sdxl": 1.0 # Uncomment to enable LCM LoRA for SDXL
+#   # Add other LoRAs here:
+#   # "your_custom_lora": 0.7
+
 use_taesd: true # Use Tiny AutoEncoder for SDXL
 use_tiny_vae: true
 acceleration: "tensorrt" # "xformers" for non-TensorRT setups
 cfg_type: "self"
+
+scheduler: "lcm" # Supports "lcm" or "tcd"
+sampler: "normal"
+
 safety_checker: false
 
 # Engine directory for TensorRT
diff --git a/demo/realtime-img2img/config.py b/demo/realtime-img2img/config.py
index 8d74eda4..56d77404 100644
--- a/demo/realtime-img2img/config.py
+++ b/demo/realtime-img2img/config.py
@@ -20,6 +20,7 @@ class Args(NamedTuple):
     controlnet_config: str
     api_only: bool
     log_level: str
+    quiet: bool
 
     def pretty_print(self):
         print("\n")
@@ -34,6 +35,7 @@ def pretty_print(self):
 ENGINE_DIR = os.environ.get("ENGINE_DIR", "engines")
 ACCELERATION = os.environ.get("ACCELERATION", "xformers")
 LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
+QUIET = os.environ.get("QUIET", "False").lower() in ("true", "1", "yes", "on")
 
 default_host = os.getenv("HOST", "0.0.0.0")
 default_port = int(os.getenv("PORT", "7860"))
@@ -129,5 +131,12 @@
     choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
     help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
 )
+parser.add_argument(
+    "--quiet",
+    dest="quiet",
+    action="store_true",
+    default=QUIET,
+    help="Suppress uvicorn INFO messages (server access logs, etc.)",
+)
 config =
Args(**vars(parser.parse_args())) config.pretty_print() diff --git a/demo/realtime-img2img/frontend/package-lock.json b/demo/realtime-img2img/frontend/package-lock.json index aef6d66b..89eb9c49 100644 --- a/demo/realtime-img2img/frontend/package-lock.json +++ b/demo/realtime-img2img/frontend/package-lock.json @@ -3842,9 +3842,9 @@ } }, "node_modules/svelte-check/node_modules/picomatch": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", - "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", "optional": true, @@ -4238,20 +4238,6 @@ "node": ">=18" } }, - "node_modules/yaml": { - "version": "2.8.0", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.0.tgz", - "integrity": "sha512-4lLa/EcQCB0cJkyts+FpIRx5G/llPxfP6VQU5KByHEhLxY3IJCH0f0Hy1MHI8sClTvsIb8qwRJ6R/ZdlDJ/leQ==", - "license": "ISC", - "optional": true, - "peer": true, - "bin": { - "yaml": "bin.mjs" - }, - "engines": { - "node": ">= 14.6" - } - }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte index 830666d7..0c59002b 100644 --- a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte +++ b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte @@ -40,7 +40,6 @@ use_denoising_batch: true, delta: 0.7, frame_buffer_size: 1, - use_lcm_lora: true, use_tiny_vae: true, acceleration: "xformers", cfg_type: "self", diff --git a/demo/realtime-img2img/frontend/src/routes/+page.svelte b/demo/realtime-img2img/frontend/src/routes/+page.svelte index 11442df2..4eb886cb 100644 --- a/demo/realtime-img2img/frontend/src/routes/+page.svelte +++ b/demo/realtime-img2img/frontend/src/routes/+page.svelte @@ -1026,7 +1026,7 @@ diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index 932808d7..bbcac9d6 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -59,6 +59,13 @@ def setup_logging(log_level: str = "INFO"): # Initialize logger logger = setup_logging(config.log_level) +# Suppress uvicorn INFO messages +if config.quiet: + uvicorn_logger = logging.getLogger('uvicorn') + uvicorn_logger.setLevel(logging.WARNING) + uvicorn_access_logger = logging.getLogger('uvicorn.access') + uvicorn_access_logger.setLevel(logging.WARNING) + class AppState: """Centralized application state management - SINGLE SOURCE OF TRUTH""" diff --git a/demo/realtime-img2img/requirements.txt b/demo/realtime-img2img/requirements.txt index a379a58e..dd200a25 100644 --- a/demo/realtime-img2img/requirements.txt +++ b/demo/realtime-img2img/requirements.txt @@ -1,11 +1,11 @@ diffusers==0.35.0 -transformers==4.56.0 -peft==0.18.0 +transformers==4.55.4 +peft==0.17.1 accelerate==1.10.0 -huggingface_hub==0.35.0 +huggingface_hub==0.34.4 fastapi==0.115.0 uvicorn[standard]==0.32.0 -Pillow==10.5.0 +Pillow==10.4.0 compel==2.0.2 controlnet-aux==0.0.7 xformers; sys_platform != 'darwin' or platform_machine != 'arm64' diff --git a/demo/realtime-txt2img/config.py b/demo/realtime-txt2img/config.py index c0a14ba4..35494148 100644 --- 
a/demo/realtime-txt2img/config.py +++ b/demo/realtime-txt2img/config.py @@ -29,8 +29,7 @@ class Config: model_id_or_path: str = os.environ.get("MODEL", "KBlueLeaf/kohaku-v2.1") # LoRA dictionary write like field(default_factory=lambda: {'E:/stable-diffusion-webui/models/Lora_1.safetensors' : 1.0 , 'E:/stable-diffusion-webui/models/Lora_2.safetensors' : 0.2}) lora_dict: dict = None - # LCM-LORA model - lcm_lora_id: str = os.environ.get("LORA", "latent-consistency/lcm-lora-sdv1-5") + # LCM-LORA model (use lora_dict instead of lcm_lora_id) # TinyVAE model vae_id: str = os.environ.get("VAE", "madebyollin/taesd") # Device to use diff --git a/demo/realtime-txt2img/main.py b/demo/realtime-txt2img/main.py index 88967c4c..18931f9e 100644 --- a/demo/realtime-txt2img/main.py +++ b/demo/realtime-txt2img/main.py @@ -63,7 +63,6 @@ def __init__(self, config: Config) -> None: mode=config.mode, model_id_or_path=config.model_id_or_path, lora_dict=config.lora_dict, - lcm_lora_id=config.lcm_lora_id, vae_id=config.vae_id, device=config.device, dtype=config.dtype, diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index ac2c2a53..791d88b1 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -74,7 +74,6 @@ def image_generation_process( frame_buffer_size=batch_size, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index 4bc08b3f..a8020bb8 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -40,7 +40,6 @@ def image_generation_process( frame_buffer_size=1, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index 85123c59..4f49bd43 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -1,3 +1,5 @@ + +import hashlib import logging from enum import Enum from pathlib import Path @@ -75,15 +77,30 @@ def __init__(self, engine_dir: str): 'loader': lambda path, cuda_stream, **kwargs: str(path) } } - + + def _lora_signature(self, lora_dict: Dict[str, float]) -> str: + """Create a short, stable signature for a set of LoRAs. + + Uses sorted basenames and weights, hashed to a short hex to avoid + long/invalid paths while keeping cache keys stable across runs. 
+ """ + # Build canonical string of basename:weight pairs + parts = [] + for path, weight in sorted(lora_dict.items(), key=lambda x: str(x[0])): + base = Path(str(path)).name # basename only + parts.append(f"{base}:{weight}") + canon = "|".join(parts) + h = hashlib.sha1(canon.encode("utf-8")).hexdigest()[:10] + return f"{len(lora_dict)}-{h}" + def get_engine_path(self, engine_type: EngineType, model_id_or_path: str, max_batch_size: int, min_batch_size: int, mode: str, - use_lcm_lora: bool, use_tiny_vae: bool, + lora_dict: Optional[Dict[str, float]] = None, ipadapter_scale: Optional[float] = None, ipadapter_tokens: Optional[int] = None, controlnet_model_id: Optional[str] = None, @@ -114,7 +131,7 @@ def get_engine_path(self, base_name = maybe_path.stem if maybe_path.exists() else model_id_or_path # Create prefix (from wrapper.py lines 1005-1013) - prefix = f"{base_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" + prefix = f"{base_name}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" # IP-Adapter differentiation: add type and (optionally) tokens # Keep scale out of identity for runtime control, but include a type flag to separate caches @@ -122,6 +139,10 @@ def get_engine_path(self, prefix += f"--fid" if ipadapter_tokens is not None: prefix += f"--tokens{ipadapter_tokens}" + + # Fused Loras - use concise hashed signature to avoid long/invalid paths + if lora_dict is not None and len(lora_dict) > 0: + prefix += f"--lora-{self._lora_signature(lora_dict)}" prefix += f"--mode-{mode}" @@ -287,7 +308,6 @@ def get_or_load_controlnet_engine(self, max_batch_size=max_batch_size, min_batch_size=min_batch_size, mode="", # Not used for ControlNet - use_lcm_lora=False, # Not used for ControlNet use_tiny_vae=False, # Not used for ControlNet controlnet_model_id=model_id ) diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index ce1124df..3f0a3f69 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -360,6 +360,29 @@ def reset_cuda_graph(self): self.graph = None def infer(self, feed_dict, stream, use_cuda_graph=False): + # Filter inputs to only those the engine actually exposes to avoid binding errors + try: + allowed_inputs = set() + for idx in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(idx) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + allowed_inputs.add(name) + + # Drop any extra keys (e.g., text_embeds/time_ids) that the engine was not built to accept + if allowed_inputs: + filtered_feed_dict = {k: v for k, v in feed_dict.items() if k in allowed_inputs} + if len(filtered_feed_dict) != len(feed_dict): + missing = [k for k in feed_dict.keys() if k not in allowed_inputs] + if missing: + logger.debug( + "TensorRT Engine: filtering unsupported inputs %s (allowed=%s)", + missing, sorted(list(allowed_inputs)) + ) + feed_dict = filtered_feed_dict + except Exception: + # Be permissive if engine query fails; proceed with original dict + pass + for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 3ff5cb90..ac8b6f20 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -100,7 +100,6 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'lora_dict': config.get('lora_dict'), 'mode': 
config.get('mode', 'img2img'), 'output_type': config.get('output_type', 'pil'), - 'lcm_lora_id': config.get('lcm_lora_id'), 'vae_id': config.get('vae_id'), 'device': config.get('device', 'cuda'), 'dtype': _parse_dtype(config.get('dtype', 'float16')), @@ -111,7 +110,7 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'acceleration': config.get('acceleration', 'tensorrt'), 'do_add_noise': config.get('do_add_noise', True), 'device_ids': config.get('device_ids'), - 'use_lcm_lora': config.get('use_lcm_lora', True), + 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility 'use_tiny_vae': config.get('use_tiny_vae', True), 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), @@ -124,6 +123,8 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'engine_dir': config.get('engine_dir', 'engines'), 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), 'normalize_seed_weights': config.get('normalize_seed_weights', True), + 'scheduler': config.get('scheduler', 'lcm'), + 'sampler': config.get('sampler', 'normal'), 'compile_engines_only': config.get('compile_engines_only', False), } if 'controlnets' in config and config['controlnets']: diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 1bca0bc2..0c372533 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -4,7 +4,7 @@ import numpy as np import PIL.Image import torch -from diffusers import LCMScheduler, StableDiffusionPipeline +from diffusers import LCMScheduler, TCDScheduler, StableDiffusionPipeline from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, @@ -36,8 +36,11 @@ def __init__( use_denoising_batch: bool = True, frame_buffer_size: int = 1, cfg_type: Literal["none", "full", "self", "initialize"] = "self", + lora_dict: Optional[Dict[str, float]] = None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + scheduler: Literal["lcm", "tcd"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", ) -> None: self.device = torch.device(device) self.dtype = torch_dtype @@ -53,6 +56,8 @@ def __init__( self.denoising_steps_num = len(t_index_list) self.cfg_type = cfg_type + self.scheduler_type = scheduler + self.sampler_type = sampler # Detect model type detection_result = detect_model(pipe.unet, pipe) @@ -61,7 +66,16 @@ def __init__( self.is_turbo = detection_result['is_turbo'] self.detection_confidence = detection_result['confidence'] - if use_denoising_batch: + # TCD scheduler is incompatible with denoising batch optimization due to Strategic Stochastic Sampling + # Force sequential processing for TCD + if scheduler == "tcd": + logger.info("TCD scheduler detected: Disabling denoising batch optimization for compatibility") + logger.info("TCD now supports ControlNet through proper hook processing") + self.use_denoising_batch = False + self.batch_size = frame_buffer_size + self.trt_unet_batch_size = frame_buffer_size + elif use_denoising_batch: + self.use_denoising_batch = True self.batch_size = self.denoising_steps_num * frame_buffer_size if self.cfg_type == "initialize": self.trt_unet_batch_size = ( @@ -74,13 +88,12 @@ def __init__( else: self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size else: + 
self.use_denoising_batch = False self.trt_unet_batch_size = self.frame_bff_size self.batch_size = frame_buffer_size self.t_list = t_index_list - self.do_add_noise = do_add_noise - self.use_denoising_batch = use_denoising_batch self.similar_image_filter = False self.similar_filter = SimilarImageFilter() @@ -89,8 +102,8 @@ def __init__( self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) - - self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.scheduler = self._initialize_scheduler(scheduler, sampler, pipe.scheduler.config) + self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.vae = pipe.vae @@ -126,7 +139,31 @@ def __init__( self._cached_batch_size: Optional[int] = None self._cached_cfg_type: Optional[str] = None self._cached_guidance_scale: Optional[float] = None + + def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): + """Initialize scheduler based on type and sampler configuration.""" + + # TODO: More testing and validation required on samplers. + # Map sampler types to configuration parameters + sampler_config = { + "simple": {"timestep_spacing": "linspace"}, + "sgm uniform": {"timestep_spacing": "trailing"}, + "normal": {}, # Default configuration + "ddim": {"timestep_spacing": "leading"}, + "beta": {"beta_schedule": "scaled_linear"}, + "karras": {}, # Karras sigmas handled per scheduler + } + + # Get sampler-specific configuration + sampler_params = sampler_config.get(sampler_type, {}) + if scheduler_type == "lcm": + return LCMScheduler.from_config(config, **sampler_params) + elif scheduler_type == "tcd": + return TCDScheduler.from_config(config, **sampler_params) + else: + logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") + return LCMScheduler.from_config(config, **sampler_params) def _check_unet_tensorrt(self) -> bool: """Cache TensorRT detection to avoid repeated hasattr calls""" @@ -191,21 +228,8 @@ def _build_sdxl_conditioning(self, batch_size: int) -> Dict[str, torch.Tensor]: 'time_ids': add_time_ids } - def load_lcm_lora( - self, - pretrained_model_name_or_path_or_dict: Union[ - str, Dict[str, torch.Tensor] - ] = "latent-consistency/lcm-lora-sdv1-5", - adapter_name: Optional[Any] = None, - **kwargs, - ) -> None: - # Check for SDXL compatibility - if self.is_sdxl: - return - - self._load_lora_with_offline_fallback( - pretrained_model_name_or_path_or_dict, adapter_name, **kwargs - ) + + def load_lora( self, @@ -445,12 +469,11 @@ def prepare( self.stock_noise = torch.zeros_like(self.init_noise) + # Handle scheduler-specific scaling calculations c_skip_list = [] c_out_list = [] for timestep in self.sub_timesteps: - c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete( - timestep - ) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) @@ -495,7 +518,9 @@ def prepare( #NOTE: this is a hack. Pipeline needs a major refactor along with stream parameter updater. 
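For reference, the sampler names wired up in _initialize_scheduler above only choose construction kwargs for the underlying diffusers scheduler. A minimal standalone sketch of the same mapping (assumes only diffusers is installed; the pipeline ID is illustrative):

from diffusers import LCMScheduler, TCDScheduler, StableDiffusionPipeline

SAMPLER_KWARGS = {
    "simple": {"timestep_spacing": "linspace"},
    "sgm uniform": {"timestep_spacing": "trailing"},
    "normal": {},
    "ddim": {"timestep_spacing": "leading"},
    "beta": {"beta_schedule": "scaled_linear"},
    "karras": {},
}

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# "tcd" selects TCDScheduler; anything else falls back to LCMScheduler
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config, **SAMPLER_KWARGS["sgm uniform"])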
self.update_prompt(prompt) - if not self.use_denoising_batch: + # Only collapse tensors to scalars for LCM non-batched mode + # TCD needs to keep tensor dimensions for iteration + if not self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): self.sub_timesteps_tensor = self.sub_timesteps_tensor[0] self.alpha_prod_t_sqrt = self.alpha_prod_t_sqrt[0] self.beta_prod_t_sqrt = self.beta_prod_t_sqrt[0] @@ -504,6 +529,19 @@ def prepare( self.c_skip = self.c_skip.to(self.device) self.c_out = self.c_out.to(self.device) + def _get_scheduler_scalings(self, timestep): + """Get LCM/TCD-specific scaling factors for boundary conditions.""" + if isinstance(self.scheduler, LCMScheduler): + c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + return c_skip, c_out + else: + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() + c_skip = torch.tensor(1.0, device=self.device, dtype=self.dtype) + c_out = torch.tensor(1.0, device=self.device, dtype=self.dtype) + return c_skip, c_out + @torch.no_grad() def update_prompt(self, prompt: str) -> None: self._param_updater.update_stream_params( @@ -525,6 +563,33 @@ def get_normalize_seed_weights(self) -> bool: + + + def set_scheduler(self, scheduler: Literal["lcm", "tcd"] = None, sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None) -> None: + """ + Change the scheduler and/or sampler at runtime. + + Parameters + ---------- + scheduler : str, optional + The scheduler type to use ("lcm" or "tcd"). If None, keeps current scheduler. + sampler : str, optional + The sampler type to use. If None, keeps current sampler. + """ + if scheduler is not None: + self.scheduler_type = scheduler + if sampler is not None: + self.sampler_type = sampler + + self.scheduler = self._initialize_scheduler(self.scheduler_type, self.sampler_type, self.pipe.scheduler.config) + logger.info(f"Scheduler changed to {self.scheduler_type} with {self.sampler_type} sampler") + + def _uses_lcm_logic(self) -> bool: + """Return True if scheduler uses LCM-style consistency boundary-condition math.""" + return isinstance(self.scheduler, LCMScheduler) + + + def add_noise( self, original_samples: torch.Tensor, @@ -543,7 +608,6 @@ def scheduler_step_batch( x_t_latent_batch: torch.Tensor, idx: Optional[int] = None, ) -> torch.Tensor: - # TODO: use t_list to select beta_prod_t_sqrt if idx is None: F_theta = ( x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch @@ -556,7 +620,6 @@ def scheduler_step_batch( denoised_batch = ( self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch ) - return denoised_batch def unet_step( @@ -565,7 +628,6 @@ def unet_step( t_list: Union[torch.Tensor, list[int]], idx: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) t_list = torch.concat([t_list[0:1], t_list], dim=0) @@ -783,6 +845,113 @@ def unet_step( return denoised_batch, model_pred + def _call_unet( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + """Call the UNet, handling SDXL kwargs and TensorRT engine calling convention.""" + added_cond_kwargs = added_cond_kwargs or {} + if self.is_sdxl: + try: + # Detect TensorRT 
engine vs PyTorch UNet + is_tensorrt_engine = hasattr(self.unet, 'engine') and hasattr(self.unet, 'stream') + if is_tensorrt_engine: + out = self.unet( + sample, + timestep, + encoder_hidden_states, + **added_cond_kwargs, + )[0] + else: + out = self.unet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + except Exception as e: + logger.error(f"[PIPELINE] _call_unet: SDXL UNet call failed: {e}") + import traceback + traceback.print_exc() + raise + else: + out = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + )[0] + return out + + def _unet_predict_noise_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + cfg_mode: Literal["none", "full", "self", "initialize"], + ) -> torch.Tensor: + """ + Compute noise prediction from UNet with classifier-free guidance applied. + This function does not apply any scheduler math; it only returns the guided noise. + """ + # Build latent batch for CFG + if self.guidance_scale > 1.0 and cfg_mode == "full": + latent_with_uc = torch.cat([latent_model_input, latent_model_input], dim=0) + elif self.guidance_scale > 1.0 and cfg_mode == "initialize": + latent_with_uc = torch.cat([latent_model_input[0:1], latent_model_input], dim=0) + else: + latent_with_uc = latent_model_input + + # SDXL added conditioning replication to match batch + added_cond_kwargs: Dict[str, torch.Tensor] = {} + if self.is_sdxl and hasattr(self, 'add_text_embeds') and hasattr(self, 'add_time_ids'): + if self.add_text_embeds is not None and self.add_time_ids is not None: + batch_size = latent_with_uc.shape[0] + if self.guidance_scale > 1.0 and cfg_mode == "initialize": + add_text_embeds = torch.cat([ + self.add_text_embeds[0:1], + self.add_text_embeds[1:2].repeat(batch_size - 1, 1), + ], dim=0) + add_time_ids = torch.cat([ + self.add_time_ids[0:1], + self.add_time_ids[1:2].repeat(batch_size - 1, 1), + ], dim=0) + elif self.guidance_scale > 1.0 and cfg_mode == "full": + repeat_factor = batch_size // 2 + add_text_embeds = self.add_text_embeds.repeat(repeat_factor, 1) + add_time_ids = self.add_time_ids.repeat(repeat_factor, 1) + else: + add_text_embeds = ( + self.add_text_embeds[1:2].repeat(batch_size, 1) + if self.add_text_embeds.shape[0] > 1 + else self.add_text_embeds.repeat(batch_size, 1) + ) + add_time_ids = ( + self.add_time_ids[1:2].repeat(batch_size, 1) + if self.add_time_ids.shape[0] > 1 + else self.add_time_ids.repeat(batch_size, 1) + ) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # Call UNet + model_pred = self._call_unet( + sample=latent_with_uc, + timestep=timestep, + encoder_hidden_states=self.prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ) + + # Apply CFG + if self.guidance_scale > 1.0 and cfg_mode == "full": + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + guided = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + return guided + else: + return model_pred + def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: image_tensors = image_tensors.to( device=self.device, @@ -808,23 +977,19 @@ def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: prev_latent_batch = self.x_t_latent_buffer - - - if self.use_denoising_batch: + # LCM supports our denoising-batch trick. 
TCD must use standard scheduler.step() sequentially + # but now properly processes ControlNet hooks through unet_step() + if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) - self.stock_noise = torch.cat( (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 ) - x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) if self.denoising_steps_num > 1: x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: self.x_t_latent_buffer = ( self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] @@ -838,25 +1003,43 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_0_pred_out = x_0_pred_batch self.x_t_latent_buffer = None else: - self.init_noise = x_t_latent - for idx, t in enumerate(self.sub_timesteps_tensor): - t = t.view(1,).repeat(self.frame_bff_size,) - - x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx) - - if idx < len(self.sub_timesteps_tensor) - 1: - if self.do_add_noise: - x_t_latent = self.alpha_prod_t_sqrt[ - idx + 1 - ] * x_0_pred + self.beta_prod_t_sqrt[ - idx + 1 - ] * torch.randn_like( - x_0_pred, device=self.device, dtype=self.dtype - ) + # Standard scheduler loop for TCD and non-batched LCM + sample = x_t_latent + for idx, timestep in enumerate(self.sub_timesteps_tensor): + # Ensure timestep tensor on device with correct dtype + if not isinstance(timestep, torch.Tensor): + t = torch.tensor(timestep, device=self.device, dtype=torch.long) + else: + t = timestep.to(self.device) + + # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed + if isinstance(self.scheduler, TCDScheduler): + # Use unet_step to process ControlNet hooks and get proper noise prediction + t_expanded = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) + + # Apply TCD scheduler step to the guided noise prediction + step_out = self.scheduler.step(model_pred, t, sample) + sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) + else: + # Original LCM logic for non-batched mode + t = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + if self.do_add_noise: + sample = self.alpha_prod_t_sqrt[ + idx + 1 + ] * x_0_pred + self.beta_prod_t_sqrt[ + idx + 1 + ] * torch.randn_like( + x_0_pred, device=self.device, dtype=self.dtype + ) + else: + sample = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred else: - x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred - x_0_pred_out = x_0_pred + sample = x_0_pred + x_0_pred_out = sample return x_0_pred_out @torch.no_grad() diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index cc901606..8c8b055e 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -674,6 +674,20 @@ def _update_seed(self, seed: int) -> None: # Reset stock_noise to match the new init_noise self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) + def _get_scheduler_scalings(self, timestep): + """Get LCM/TCD-specific scaling factors for boundary conditions.""" + from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): + c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + return 
c_skip, c_out + else: + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() + c_skip = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + c_out = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + return c_skip, c_out + def _update_timestep_calculations(self) -> None: """Update timestep-dependent calculations based on current t_list.""" self.stream.sub_timesteps = [] @@ -692,7 +706,7 @@ def _update_timestep_calculations(self) -> None: c_skip_list = [] c_out_list = [] for timestep in self.stream.sub_timesteps: - c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 6a997b12..a0b7c59c 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -75,7 +75,6 @@ def __init__( lora_dict: Optional[Dict[str, float]] = None, mode: Literal["img2img", "txt2img"] = "img2img", output_type: Literal["pil", "pt", "np", "latent"] = "pil", - lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, device: Literal["cpu", "cuda"] = "cuda", dtype: torch.dtype = torch.float16, @@ -86,7 +85,7 @@ def __init__( acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, device_ids: Optional[List[int]] = None, - use_lcm_lora: bool = True, + use_lcm_lora: Optional[bool] = None, # DEPRECATED: Backwards compatibility parameter use_tiny_vae: bool = True, enable_similar_image_filter: bool = False, similar_image_filter_threshold: float = 0.98, @@ -101,6 +100,9 @@ def __init__( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + # Scheduler and sampler options + scheduler: Literal["lcm", "tcd"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", # ControlNet options use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, @@ -124,6 +126,10 @@ def __init__( The model id or path to load. t_index_list : List[int] The t_index_list to use for inference. + min_batch_size : int, optional + The minimum batch size for inference, by default 1. + max_batch_size : int, optional + The maximum batch size for inference, by default 4. lora_dict : Optional[Dict[str, float]], optional The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. @@ -132,16 +138,14 @@ def __init__( txt2img or img2img, by default "img2img". output_type : Literal["pil", "pt", "np", "latent"], optional The output type of image, by default "pil". - lcm_lora_id : Optional[str], optional - The lcm_lora_id to load, by default None. - If None, the default LCM-LoRA - ("latent-consistency/lcm-lora-sdv1-5") will be used. vae_id : Optional[str], optional The vae_id to load, by default None. If None, the default TinyVAE ("madebyollin/taesd") will be used. device : Literal["cpu", "cuda"], optional The device to use for inference, by default "cuda". + device_ids : Optional[List[int]], optional + The device ids to use for DataParallel, by default None. dtype : torch.dtype, optional The dtype for inference, by default torch.float16. 
frame_buffer_size : int, optional @@ -159,8 +163,11 @@ def __init__( by default True. device_ids : Optional[List[int]], optional The device ids to use for DataParallel, by default None. - use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. + use_lcm_lora : Optional[bool], optional + DEPRECATED: Use lora_dict instead. For backwards compatibility only. + If True, automatically adds appropriate LCM LoRA to lora_dict based on model type. + SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5". + By default None (ignored). use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. enable_similar_image_filter : bool, optional @@ -179,19 +186,42 @@ def __init__( seed : int, optional The seed, by default 2. use_safety_checker : bool, optional - Whether to use safety checker or not, by default False. Only supported for TensorRT acceleration. + Whether to use safety checker or not, by default False. + skip_diffusion : bool, optional + Whether to skip diffusion and apply only preprocessing/postprocessing hooks, by default False. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. normalize_prompt_weights : bool, optional Whether to normalize prompt weights in blending to sum to 1, by default True. When False, weights > 1 will amplify embeddings. normalize_seed_weights : bool, optional Whether to normalize seed weights in blending to sum to 1, by default True. When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "tcd"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional ControlNet configuration(s), by default None. Can be a single config dict or list of config dicts for multiple ControlNets. Each config should contain: model_id, preprocessor (optional), conditioning_scale, etc. + use_ipadapter : bool, optional + Whether to enable IPAdapter support, by default False. + ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. safety_checker_fallback_type : Literal["blank", "previous"], optional Whether to use a blank image or the previous image as a fallback, by default "previous". 
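As the deprecation note above implies, a legacy use_lcm_lora / lcm_lora_id setup maps onto an explicit lora_dict entry; a small illustrative sketch (LoRA IDs as named in this diff, the custom path is a placeholder):

# Before (deprecated): use_lcm_lora=True, lcm_lora_id=None
# After: declare the LCM LoRA explicitly, alongside any custom LoRAs.
lora_dict = {
    "latent-consistency/lcm-lora-sdv1-5": 1.0,   # SD 1.5-family models
    # "latent-consistency/lcm-lora-sdxl": 1.0,   # use this entry instead for SDXL models
    # "path/to/custom_style.safetensors": 0.7,   # placeholder for an additional LoRA
}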
safety_checker_threshold: float, optional @@ -201,7 +231,10 @@ def __init__( """ if compile_engines_only: logger.info("compile_engines_only is True, will only compile engines and not load the model") - + + # Store use_lcm_lora for backwards compatibility processing in _load_model + self.use_lcm_lora = use_lcm_lora + self.sd_turbo = "turbo" in model_id_or_path self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter @@ -255,18 +288,19 @@ def __init__( self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, lora_dict=lora_dict, - lcm_lora_id=lcm_lora_id, vae_id=vae_id, t_index_list=t_index_list, acceleration=acceleration, do_add_noise=do_add_noise, - use_lcm_lora=use_lcm_lora, + use_lcm_lora=use_lcm_lora, # Deprecated:Backwards compatibility use_tiny_vae=use_tiny_vae, cfg_type=cfg_type, engine_dir=engine_dir, build_engines_if_missing=build_engines_if_missing, normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, + scheduler=scheduler, + sampler=sampler, use_controlnet=use_controlnet, controlnet_config=controlnet_config, use_ipadapter=use_ipadapter, @@ -802,33 +836,57 @@ def postprocess_image( def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - Denormalize image tensor on GPU for efficiency - + Denormalize image tensor on GPU for efficiency. - Args: - image_tensor: Input tensor on GPU + Converts image tensor from diffusion range [-1, 1] to standard image range [0, 1]. + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. - Returns: - Denormalized tensor on GPU, clamped to [0,1] + Returns + ------- + torch.Tensor + Denormalized tensor in range [0, 1], clamped and on GPU. """ return (image_tensor / 2 + 0.5).clamp(0, 1) def _normalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: - """Convert tensor from [0,1] (processor range) back to [-1,1] (diffusion range)""" + """ + Normalize tensor from processor range to diffusion range. + + Converts image tensor from standard image range [0, 1] to diffusion range [-1, 1]. + + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in standard image range [0, 1], expected to be on GPU. + + Returns + ------- + torch.Tensor + Normalized tensor in diffusion range [-1, 1], clamped and on GPU. + """ return (image_tensor * 2 - 1).clamp(-1, 1) def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Image]: """ - Optimized tensor to PIL conversion with minimal CPU transfers + Optimized tensor to PIL conversion with minimal CPU transfers. + Efficiently converts a batch of GPU tensors to PIL Images with minimal + CPU-GPU transfers and memory allocations. - Args: - image_tensor: Input tensor on GPU - + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. + Shape should be (batch_size, channels, height, width). - Returns: - List of PIL Images + Returns + ------- + List[Image.Image] + List of PIL RGB images, one for each item in the batch. """ # Denormalize on GPU first denormalized = self._denormalize_on_gpu(image_tensor) @@ -866,6 +924,23 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima return pil_images def set_nsfw_fallback_img(self, height: int, width: int) -> None: + """ + Set the NSFW fallback image used when safety checker blocks content. 
+
+        Creates a black RGB image of the specified dimensions that will be returned
+        when the safety checker determines content should be blocked.
+
+        Parameters
+        ----------
+        height : int
+            Height of the fallback image in pixels.
+        width : int
+            Width of the fallback image in pixels.
+
+        Returns
+        -------
+        None
+        """
         self.nsfw_fallback_img = Image.new("RGB", (height, width), (0, 0, 0))
         if self.output_type == "pt":
             self.nsfw_fallback_img = torch.from_numpy(np.array(self.nsfw_fallback_img)).unsqueeze(0)
@@ -877,7 +952,6 @@ def _load_model(
         model_id_or_path: str,
         t_index_list: List[int],
         lora_dict: Optional[Dict[str, float]] = None,
-        lcm_lora_id: Optional[str] = None,
         vae_id: Optional[str] = None,
         acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt",
         do_add_noise: bool = True,
@@ -888,6 +962,8 @@
         build_engines_if_missing: bool = True,
         normalize_prompt_weights: bool = True,
         normalize_seed_weights: bool = True,
+        scheduler: Literal["lcm", "tcd"] = "lcm",
+        sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal",
         use_controlnet: bool = False,
         controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         use_ipadapter: bool = False,
@@ -906,7 +982,7 @@
         This method does the following:
         1. Loads the model from the model_id_or_path.
-        2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
+        2. Loads and fuses LoRA models from lora_dict if provided.
         3. Loads the VAE model from the vae_id if needed.
         4. Enables acceleration if needed.
         5. Prepares the model for inference.
@@ -916,42 +992,72 @@
         Parameters
         ----------
         model_id_or_path : str
-            The model id or path to load.
+            The model id or path to load. Can be a Hugging Face model ID, local path to
+            safetensors/ckpt file, or directory containing model files.
         t_index_list : List[int]
-            The t_index_list to use for inference.
+            The t_index_list to use for inference. Specifies which denoising timesteps
+            to use from the diffusion schedule.
         lora_dict : Optional[Dict[str, float]], optional
             The lora_dict to load, by default None.
             Keys are the LoRA names and values are the LoRA scales.
             Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...}
-        lcm_lora_id : Optional[str], optional
-            The lcm_lora_id to load, by default None.
+            Use this to load LCM LoRA: {'latent-consistency/lcm-lora-sdv1-5': 1.0}
         vae_id : Optional[str], optional
-            The vae_id to load, by default None.
-        acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional
-            The acceleration method, by default "tensorrt".
-        warmup : int, optional
-            The number of warmup steps to perform, by default 10.
+            The vae_id to load, by default None. If None, uses default TinyVAE
+            ("madebyollin/taesd" for SD1.5, "madebyollin/taesdxl" for SDXL).
+        acceleration : Literal["none", "xformers", "tensorrt"], optional
+            The acceleration method, by default "tensorrt".
         do_add_noise : bool, optional
             Whether to add noise for following denoising steps or not, by default True.
-        use_lcm_lora : bool, optional
-            Whether to use LCM-LoRA or not, by default True.
+        use_lcm_lora : Optional[bool], optional
+            DEPRECATED: Use lora_dict instead. For backwards compatibility only.
+            If True, automatically adds the appropriate LCM LoRA to lora_dict based on model type.
+            SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5".
+            By default None (ignored).
         use_tiny_vae : bool, optional
-            Whether to use TinyVAE or not, by default True.
- cfg_type : Literal["none", "full", "self", "initialize"], - optional + Whether to use TinyVAE or not, by default True. TinyVAE is a distilled, + smaller VAE model that provides faster encoding/decoding with minimal quality loss. + cfg_type : Literal["none", "full", "self", "initialize"], optional The cfg_type for img2img mode, by default "self". You cannot use anything other than "none" for txt2img mode. - seed : int, optional - The seed, by default 2. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. + normalize_prompt_weights : bool, optional + Whether to normalize prompt weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify embeddings. + normalize_seed_weights : bool, optional + Whether to normalize seed weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "tcd"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional - Whether to apply ControlNet patch, by default False. + Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - ControlNet configuration(s), by default None. + ControlNet configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple ControlNets. use_ipadapter : bool, optional - Whether to apply IPAdapter patch, by default False. + Whether to enable IPAdapter support, by default False. ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - IPAdapter configuration(s), by default None. + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. + safety_checker_model_id : Optional[str], optional + Model ID for the safety checker, by default "Freepik/nsfw_image_detector". + compile_engines_only : bool, optional + Whether to only compile engines and not load the model, by default False. 
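The lora_dict handed to _load_model also feeds the TensorRT engine cache key through the EngineManager._lora_signature helper added earlier; a standalone sketch mirroring that helper (same hashing scheme, example input only):

import hashlib
from pathlib import Path

def lora_signature(lora_dict):
    # Sorted basename:weight pairs, hashed to a short hex so engine paths stay short and stable
    parts = [f"{Path(str(p)).name}:{w}" for p, w in sorted(lora_dict.items(), key=lambda x: str(x[0]))]
    digest = hashlib.sha1("|".join(parts).encode("utf-8")).hexdigest()[:10]
    return f"{len(lora_dict)}-{digest}"

print(lora_signature({"latent-consistency/lcm-lora-sdv1-5": 1.0}))  # e.g. "1-<short hash>"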
         Returns
         -------
@@ -1045,6 +1151,11 @@
                 pipe.text_encoder = pipe.text_encoder.to(device=self.device)
                 if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
                     pipe.text_encoder_2 = pipe.text_encoder_2.to(device=self.device)
+                # Move main pipeline components to device, but skip UNet/VAE for TensorRT
+                if hasattr(pipe, "unet") and pipe.unet is not None and acceleration != "tensorrt":
+                    pipe.unet = pipe.unet.to(device=self.device)
+                if hasattr(pipe, "vae") and pipe.vae is not None and acceleration != "tensorrt":
+                    pipe.vae = pipe.vae.to(device=self.device)
 
                 # If we get here, the model loaded successfully - break out of retry loop
                 logger.info(f"Model loading succeeded")
@@ -1063,6 +1174,26 @@
         self._is_sdxl = is_sdxl
         logger.info(f"_load_model: Detected model type: {model_type} (confidence: {confidence:.2f})")
+
+        # DEPRECATED path: honour use_lcm_lora by folding the matching LCM LoRA into lora_dict
+        # (backwards compatibility; model detection decides which LCM LoRA applies)
+        if hasattr(self, 'use_lcm_lora') and self.use_lcm_lora is not None:
+            if self.use_lcm_lora and not self.sd_turbo and lora_dict is not None:
+                # Determine correct LCM LoRA based on actual model detection
+                lcm_lora = "latent-consistency/lcm-lora-sdxl" if is_sdxl else "latent-consistency/lcm-lora-sdv1-5"
+
+                # Add to lora_dict if not already present
+                if lcm_lora not in lora_dict:
+                    lora_dict[lcm_lora] = 1.0
+                    logger.info(f"Added {lcm_lora} with scale 1.0 to lora_dict")
+                else:
+                    logger.info(f"LCM LoRA {lcm_lora} already present in lora_dict with scale {lora_dict[lcm_lora]}")
+            else:
+                logger.info(f"LCM LoRA not auto-added (use_lcm_lora={self.use_lcm_lora}, sd_turbo={self.sd_turbo}, lora_dict provided={lora_dict is not None})")
+
+            # Clear the flag now that it has been folded into lora_dict
+            self.use_lcm_lora = None
+            logger.info("use_lcm_lora cleared; lora_dict is now the single source of truth")
 
         stream = StreamDiffusion(
             pipe=pipe,
@@ -1075,32 +1206,67 @@
             frame_buffer_size=self.frame_buffer_size,
             use_denoising_batch=self.use_denoising_batch,
             cfg_type=cfg_type,
+            lora_dict=lora_dict,  # Passed through so fused LoRAs are reflected in engine path names
             normalize_prompt_weights=normalize_prompt_weights,
             normalize_seed_weights=normalize_seed_weights,
+            scheduler=scheduler,
+            sampler=sampler,
         )
-        if not self.sd_turbo:
-            if use_lcm_lora:
-                if lcm_lora_id is not None:
-                    stream.load_lcm_lora(
-                        pretrained_model_name_or_path_or_dict=lcm_lora_id
-                    )
-                else:
-                    stream.load_lcm_lora()
-                stream.fuse_lora()
-            if lora_dict is not None:
-                for lora_name, lora_scale in lora_dict.items():
-                    stream.load_lora(lora_name)
-                    stream.fuse_lora(lora_scale=lora_scale)
+
+        # Load and merge LoRA weights using the standard diffusers approach
+        lora_adapters_to_merge = []
+        lora_scales_to_merge = []
+
+        # Collect all LoRA adapters and their scales from lora_dict
+        if lora_dict is not None:
+            for i, (lora_name, lora_scale) in enumerate(lora_dict.items()):
+                adapter_name = f"custom_lora_{i}"
+                logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}")
+
+                try:
+                    # Load LoRA weights with a unique adapter name
+                    stream.pipe.load_lora_weights(lora_name, adapter_name=adapter_name)
+                    lora_adapters_to_merge.append(adapter_name)
+                    lora_scales_to_merge.append(lora_scale)
+                    logger.info(f"Successfully loaded LoRA adapter: {adapter_name}")
+                except Exception as e:
+                    logger.error(f"Failed to load LoRA {lora_name}: {e}")
+                    # Continue with other LoRAs even if one fails
+                    continue
+
+        # Merge all LoRA adapters using the standard diffusers fuse_lora API
+        if lora_adapters_to_merge:
+            try:
+                for adapter_name, scale in zip(lora_adapters_to_merge, lora_scales_to_merge):
+                    logger.info(f"Merging individual LoRA: {adapter_name} with scale {scale}")
+                    stream.pipe.fuse_lora(lora_scale=scale, adapter_names=[adapter_name])
+
+                # Clean up after individual merging
+                stream.pipe.unload_lora_weights()
+                logger.info("Successfully merged LoRAs individually")
+
+            except Exception as merge_error:
+                logger.error(f"LoRA merging failed: {merge_error}")
+                logger.warning("Continuing without LoRA merging - LoRAs may not be applied correctly")
+
+                # Clean up any partial state
+                try:
+                    stream.pipe.unload_lora_weights()
+                except Exception:
+                    pass
 
         if use_tiny_vae:
             if vae_id is not None:
-                stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype)
+                stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype, device=self.device)
             else:
                 # Use TAESD XL for SDXL models, regular TAESD for SD 1.5
                 taesd_model = "madebyollin/taesdxl" if is_sdxl else "madebyollin/taesd"
-                stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype)
-
+                stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype, device=self.device)
+        elif acceleration != "tensorrt":
+            # For non-TensorRT acceleration, ensure VAE is on device if it wasn't moved earlier
+            if hasattr(pipe, "vae") and pipe.vae is not None:
+                pipe.vae = pipe.vae.to(device=self.device)
 
         try:
             if acceleration == "xformers":
@@ -1229,8 +1395,8 @@
                     max_batch_size=self.max_batch_size,
                     min_batch_size=self.min_batch_size,
                     mode=self.mode,
-                    use_lcm_lora=use_lcm_lora,
                     use_tiny_vae=use_tiny_vae,
+                    lora_dict=lora_dict,
                     ipadapter_scale=ipadapter_scale,
                     ipadapter_tokens=ipadapter_tokens,
                     is_faceid=is_faceid if use_ipadapter_trt else None
@@ -1241,8 +1407,8 @@
                     max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                     min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                     mode=self.mode,
-                    use_lcm_lora=use_lcm_lora,
                     use_tiny_vae=use_tiny_vae,
+                    lora_dict=lora_dict,
                     ipadapter_scale=ipadapter_scale,
                     ipadapter_tokens=ipadapter_tokens,
                     is_faceid=is_faceid if use_ipadapter_trt else None
@@ -1253,8 +1419,8 @@
                     max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                     min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
                     mode=self.mode,
-                    use_lcm_lora=use_lcm_lora,
                     use_tiny_vae=use_tiny_vae,
+                    lora_dict=lora_dict,
                     ipadapter_scale=ipadapter_scale,
                     ipadapter_tokens=ipadapter_tokens,
                     is_faceid=is_faceid if use_ipadapter_trt else None
@@ -1306,10 +1472,15 @@
                 except Exception:
                     pass
 
-            # If using TensorRT with IP-Adapter, ensure processors and weights are installed BEFORE export
-            if use_ipadapter_trt and has_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'):
+            # Note: LoRA weights have already been merged permanently during model loading
+
+            # CRITICAL: Install IPAdapter module BEFORE TensorRT compilation to ensure processors are baked into engines
+            if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'):
                 try:
                     from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig, IPAdapterType
+                    logger.info("Installing IPAdapter module before TensorRT compilation...")
+
+                    # Use first config if list provided
                     cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config
                     ip_cfg = IPAdapterConfig(
                         style_image_key=cfg.get('style_image_key') or 'ipadapter_main',
@@
-1321,17 +1492,28 @@ def _load_model( type=IPAdapterType(cfg.get('type', "regular")), insightface_model_name=cfg.get('insightface_model_name'), ) - ip_module_for_export = IPAdapterModule(ip_cfg) - ip_module_for_export.install(stream) - setattr(stream, '_ipadapter_module', ip_module_for_export) - try: - logger.info("Installed IP-Adapter processors prior to TensorRT export") - except Exception: - pass + ip_module = IPAdapterModule(ip_cfg) + ip_module.install(stream) + # Expose for later updates + stream._ipadapter_module = ip_module + logger.info("IPAdapter module installed successfully before TensorRT compilation") + + # Cleanup after IPAdapter installation + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + except torch.cuda.OutOfMemoryError as oom_error: + logger.error(f"CUDA Out of Memory during early IPAdapter installation: {oom_error}") + logger.error("Try reducing batch size, using smaller models, or increasing GPU memory") + raise RuntimeError("Insufficient VRAM for IPAdapter installation. Consider using a GPU with more memory or reducing model complexity.") + except Exception: import traceback traceback.print_exc() - logger.error("Failed to pre-install IP-Adapter prior to TensorRT export") + logger.error("Failed to install IPAdapterModule before TensorRT compilation") + raise # NOTE: When IPAdapter is enabled, we must pass num_ip_layers. We cannot know it until after # installing processors in the export wrapper. We construct the wrapper first to discover it, @@ -1555,13 +1737,13 @@ def _load_model( logger.error(f"TensorRT VAE engine loading failed (non-OOM): {e}") raise e + # Safety checker engine (TensorRT-specific) safety_checker_path = engine_manager.get_engine_path( EngineType.SAFETY_CHECKER, model_id_or_path=safety_checker_model_id, max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, ) safety_checker_engine_exists = os.path.exists(safety_checker_path) @@ -1594,7 +1776,7 @@ def _load_model( cuda_stream, use_cuda_graph=True, ) - + if acceleration == "sfast": from streamdiffusion.acceleration.sfast import ( accelerate_with_stable_fast, @@ -1678,6 +1860,8 @@ def _load_model( logger.error("Failed to install ControlNetModule") raise + # IPAdapter module installation has been moved to before TensorRT compilation (see lines 1307-1345) + # This ensures processors are properly baked into the TensorRT engines if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): try: from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig, IPAdapterType @@ -1707,6 +1891,8 @@ def _load_model( logger.error("Failed to install IPAdapterModule") raise + # Note: LoRA weights have already been merged permanently during model loading + # Install pipeline hook modules (Phase 4: Configuration Integration) if image_preprocessing_config and image_preprocessing_config.get('enabled', True): try: @@ -1778,7 +1964,6 @@ def update_control_image(self, index: int, image: Union[str, Image.Image, torch. 
         else:
             logger.debug("update_control_image: Skipping ControlNet update in skip diffusion mode")
 
-
     def update_style_image(self, image: Union[str, Image.Image, torch.Tensor], is_stream: bool = False, style_key = "ipadapter_main") -> None:
         """Update IPAdapter style image"""
         if not self.use_ipadapter:
@@ -1791,7 +1976,6 @@ def update_style_image(self, image: Union[str, Image.Image, torch.Tensor], is_st
 
 
-
     def clear_caches(self) -> None:
         """Clear all cached prompt embeddings and seed noise tensors."""
         self.stream._param_updater.clear_caches()
@@ -2114,8 +2298,3 @@ def cleanup_engines_and_rebuild(self, reduce_batch_size: bool = True, reduce_res
         logger.info(f"  Reduced resolution: {old_width}x{old_height} -> {self.width}x{self.height}")
         logger.info("  Next model load will rebuild engines with these smaller settings")
-
-
-
-
-
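Taken together, a hedged end-to-end sketch of the options this diff introduces (class name and import path assumed from src/streamdiffusion/wrapper.py; model ID, t_index_list, and LoRA scales are placeholders, not recommendations):

from streamdiffusion.wrapper import StreamDiffusionWrapper  # import path assumed

wrapper = StreamDiffusionWrapper(
    model_id_or_path="runwayml/stable-diffusion-v1-5",
    t_index_list=[32, 45],
    mode="img2img",
    acceleration="xformers",
    lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0},  # replaces use_lcm_lora / lcm_lora_id
    scheduler="lcm",    # "lcm" or "tcd"
    sampler="normal",   # "simple", "sgm uniform", "normal", "ddim", "beta", "karras"
)

# Runtime switching via the new StreamDiffusion.set_scheduler(); note that TCD
# disables the denoising-batch optimization, as described in pipeline.py above.
wrapper.stream.set_scheduler(scheduler="tcd", sampler="sgm uniform")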