From e98c0996e05954bf9e9f0514ee9d810fe780e8ec Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 7 Aug 2025 21:12:25 -0400 Subject: [PATCH 01/15] Testing scheduling and sampling. --- src/streamdiffusion/config.py | 2 + src/streamdiffusion/pipeline.py | 438 ++++++++++++++++-- .../stream_parameter_updater.py | 3 +- src/streamdiffusion/wrapper.py | 32 ++ 4 files changed, 430 insertions(+), 45 deletions(-) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 74f12931..19dfeba0 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -128,6 +128,8 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), 'normalize_seed_weights': config.get('normalize_seed_weights', True), 'enable_pytorch_fallback': config.get('enable_pytorch_fallback', False), + 'scheduler': config.get('scheduler', 'lcm'), + 'sampler': config.get('sampler', 'normal'), } if 'controlnets' in config and config['controlnets']: param_map['use_controlnet'] = True diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index c605f51f..99899818 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -4,7 +4,15 @@ import numpy as np import PIL.Image import torch -from diffusers import LCMScheduler, StableDiffusionPipeline +from diffusers import ( + LCMScheduler, + StableDiffusionPipeline, + DPMSolverMultistepScheduler, + UniPCMultistepScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + TCDScheduler, +) from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, @@ -33,6 +41,8 @@ def __init__( cfg_type: Literal["none", "full", "self", "initialize"] = "self", normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + scheduler: Literal["lcm", "tcd", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", ) -> None: self.device = pipe.device self.dtype = torch_dtype @@ -48,6 +58,8 @@ def __init__( self.denoising_steps_num = len(t_index_list) self.cfg_type = cfg_type + self.scheduler_type = scheduler + self.sampler_type = sampler # Detect model type detection_result = detect_model(pipe.unet, pipe) @@ -84,7 +96,9 @@ def __init__( self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) - self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + # Initialize scheduler based on configuration + self.scheduler = self._initialize_scheduler(scheduler, sampler, pipe.scheduler.config) + self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.vae = pipe.vae @@ -95,10 +109,75 @@ def __init__( if self.is_sdxl: self.add_text_embeds = None self.add_time_ids = None + logger.log(logging.INFO, f"[PIPELINE] SDXL Detected: Using {scheduler} scheduler with {sampler} sampler") # Initialize parameter updater self._param_updater = StreamParameterUpdater(self, normalize_prompt_weights, normalize_seed_weights) + def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): + """Initialize scheduler based on type and sampler configuration.""" + # Map sampler types to configuration parameters + sampler_config = { + "simple": {"timestep_spacing": "linspace"}, + "sgm uniform": {"timestep_spacing": "trailing"}, # SGM Uniform is typically trailing + "normal": {}, # Default configuration + "ddim": 
{"timestep_spacing": "leading"}, # DDIM default per documentation + "beta": {"beta_schedule": "scaled_linear"}, + "karras": {}, # Karras sigmas will be enabled in scheduler-specific code + } + + # Get sampler-specific configuration + sampler_params = sampler_config.get(sampler_type, {}) + + print(f"Sampler params: {sampler_params}") + print(f"Scheduler type: {scheduler_type}") + + # Create scheduler based on type + if scheduler_type == "lcm": + return LCMScheduler.from_config(config, **sampler_params) + elif scheduler_type == "tcd": + return TCDScheduler.from_config(config, **sampler_params) + elif scheduler_type == "dpm++ 2m": + # DPM++ 2M typically uses solver_order=2 and algorithm_type="dpmsolver++" + return DPMSolverMultistepScheduler.from_config( + config, + solver_order=2, + algorithm_type="dpmsolver++", + use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested + **sampler_params + ) + elif scheduler_type == "uni_pc": + # UniPC: solver_order=2 for guided sampling, solver_type="bh2" by default + return UniPCMultistepScheduler.from_config( + config, + solver_order=2, # Good default for guided sampling + solver_type="bh2", # Default from documentation + disable_corrector=[], # No corrector disabled by default + use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested + **sampler_params + ) + elif scheduler_type == "ddim": + # DDIM defaults to leading timestep spacing, but trailing can be better + return DDIMScheduler.from_config( + config, + set_alpha_to_one=True, # Default per documentation + steps_offset=0, # Default per documentation + prediction_type="epsilon", # Default per documentation + **sampler_params + ) + elif scheduler_type == "euler": + # Euler can use Karras sigmas for improved quality + return EulerDiscreteScheduler.from_config( + config, + use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested + prediction_type="epsilon", # Default per documentation + **sampler_params + ) + else: + # Default to LCM + logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") + return LCMScheduler.from_config(config, **sampler_params) + def load_lcm_lora( self, pretrained_model_name_or_path_or_dict: Union[ @@ -273,12 +352,42 @@ def prepare( # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list self.sub_timesteps = [] + max_timestep_index = len(self.timesteps) - 1 + for t in self.t_list: - self.sub_timesteps.append(self.timesteps[t]) - - sub_timesteps_tensor = torch.tensor( - self.sub_timesteps, dtype=torch.long, device=self.device - ) + # Clamp t_index to valid range to prevent index out of bounds + if t > max_timestep_index: + logger.warning(f"t_index {t} is out of bounds for scheduler with {len(self.timesteps)} timesteps. Clamping to {max_timestep_index}") + t = max_timestep_index + elif t < 0: + logger.warning(f"t_index {t} is negative. 
Clamping to 0") + t = 0 + + timestep_value = self.timesteps[t] + # Convert tensor timesteps to scalar values for indexing operations + if isinstance(timestep_value, torch.Tensor): + timestep_scalar = timestep_value.cpu().item() + else: + timestep_scalar = timestep_value + self.sub_timesteps.append(timestep_scalar) + + # Create tensor version for UNet calls + # Handle both integer and floating-point timesteps from different schedulers + # Some schedulers like Euler may return floating-point timesteps + if len(self.sub_timesteps) > 0: + # Always create the tensor from scalar values to avoid device issues + try: + # Try integer first for compatibility + sub_timesteps_tensor = torch.tensor( + self.sub_timesteps, dtype=torch.long, device=self.device + ) + except (TypeError, ValueError): + # Fallback for floating-point values + sub_timesteps_tensor = torch.tensor( + self.sub_timesteps, dtype=torch.float32, device=self.device + ) + else: + sub_timesteps_tensor = torch.tensor([], dtype=torch.long, device=self.device) self.sub_timesteps_tensor = torch.repeat_interleave( sub_timesteps_tensor, repeats=self.frame_bff_size if self.use_denoising_batch else 1, @@ -292,12 +401,11 @@ def prepare( self.stock_noise = torch.zeros_like(self.init_noise) + # Handle scheduler-specific scaling calculations c_skip_list = [] c_out_list = [] for timestep in self.sub_timesteps: - c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete( - timestep - ) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) @@ -315,8 +423,25 @@ def prepare( alpha_prod_t_sqrt_list = [] beta_prod_t_sqrt_list = [] for timestep in self.sub_timesteps: - alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() - beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() + # Convert floating-point timesteps to integers for tensor indexing + if isinstance(timestep, float): + timestep_idx = int(round(timestep)) + else: + timestep_idx = timestep + + # Ensure timestep_idx is within bounds + max_idx = len(self.scheduler.alphas_cumprod) - 1 + if timestep_idx > max_idx: + logger.warning(f"Timestep index {timestep_idx} out of bounds for alphas_cumprod (max: {max_idx}). Clamping to {max_idx}") + timestep_idx = max_idx + elif timestep_idx < 0: + logger.warning(f"Timestep index {timestep_idx} is negative. Clamping to 0") + timestep_idx = 0 + + # Access scheduler tensors and move to device as needed + alpha_cumprod = self.scheduler.alphas_cumprod[timestep_idx].to(device=self.device, dtype=self.dtype) + alpha_prod_t_sqrt = alpha_cumprod.sqrt() + beta_prod_t_sqrt = (1 - alpha_cumprod).sqrt() alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) alpha_prod_t_sqrt = ( @@ -342,7 +467,8 @@ def prepare( #NOTE: this is a hack. Pipeline needs a major refactor along with stream parameter updater. self.update_prompt(prompt) - if not self.use_denoising_batch: + # Only collapse tensors to a single element for non-batched LCM path. + if (not self.use_denoising_batch) and self._uses_lcm_logic(): self.sub_timesteps_tensor = self.sub_timesteps_tensor[0] self.alpha_prod_t_sqrt = self.alpha_prod_t_sqrt[0] self.beta_prod_t_sqrt = self.beta_prod_t_sqrt[0] @@ -351,6 +477,31 @@ def prepare( self.c_skip = self.c_skip.to(self.device) self.c_out = self.c_out.to(self.device) + def _get_scheduler_scalings(self, timestep): + """ + Get LCM-specific scaling factors for boundary conditions. 
+ Only used for LCMScheduler - other schedulers handle scaling in their step() method. + """ + if isinstance(self.scheduler, LCMScheduler): + c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + # Ensure returned values are tensors on the correct device + if not isinstance(c_skip, torch.Tensor): + c_skip = torch.tensor(c_skip, device=self.device, dtype=self.dtype) + else: + c_skip = c_skip.to(device=self.device, dtype=self.dtype) + if not isinstance(c_out, torch.Tensor): + c_out = torch.tensor(c_out, device=self.device, dtype=self.dtype) + else: + c_out = c_out.to(device=self.device, dtype=self.dtype) + return c_skip, c_out + else: + # For non-LCM schedulers, we don't use boundary condition scaling + # Their step() method handles all the necessary scaling internally + logger.debug(f"Scheduler {type(self.scheduler)} doesn't use boundary condition scaling") + c_skip = torch.tensor(1.0, device=self.device, dtype=self.dtype) + c_out = torch.tensor(1.0, device=self.device, dtype=self.dtype) + return c_skip, c_out + @torch.no_grad() def update_prompt(self, prompt: str) -> None: self._param_updater.update_stream_params( @@ -433,7 +584,55 @@ def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self._param_updater.get_normalize_seed_weights() + def set_scheduler( + self, + scheduler: Literal["lcm", "tcd", "dpm++ 2m", "uni_pc", "ddim", "euler"] = None, + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None, + ) -> None: + """ + Change the scheduler and/or sampler configuration at runtime. + + Parameters + ---------- + scheduler : str, optional + The scheduler type to use. If None, keeps current scheduler. + sampler : str, optional + The sampler type to use. If None, keeps current sampler. + """ + if scheduler is not None: + self.scheduler_type = scheduler + if sampler is not None: + self.sampler_type = sampler + + # Re-initialize scheduler with new configuration + self.scheduler = self._initialize_scheduler( + self.scheduler_type, + self.sampler_type, + self.pipe.scheduler.config + ) + + logger.info(f"Scheduler changed to {self.scheduler_type} with {self.sampler_type} sampler") + + + def _uses_lcm_logic(self) -> bool: + """Return True if scheduler uses consistency boundary-condition math (LCM/TCD).""" + try: + # Use isinstance checks for more reliable detection + return isinstance(self.scheduler, LCMScheduler) + except Exception: + return False + + def _warned_cfg_mode_fallback(self) -> bool: + return getattr(self, "_cfg_mode_warning_emitted", False) + + def _emit_cfg_mode_warning_once(self) -> None: + if not self._warned_cfg_mode_fallback(): + logger.warning( + "Non-LCM scheduler in use: falling back to standard CFG ('full') semantics. " + "Custom cfg_type values 'self'/'initialize' are ignored for correctness." + ) + setattr(self, "_cfg_mode_warning_emitted", True) def add_noise( self, @@ -453,19 +652,36 @@ def scheduler_step_batch( x_t_latent_batch: torch.Tensor, idx: Optional[int] = None, ) -> torch.Tensor: - # TODO: use t_list to select beta_prod_t_sqrt - if idx is None: - F_theta = ( - x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch - ) / self.alpha_prod_t_sqrt - denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch + """ + Simplified scheduler integration that works with StreamDiffusion's architecture. + For now, we'll use a hybrid approach until we can properly refactor the pipeline. 
+ """ + # For LCM, use boundary condition scaling as before + if self._uses_lcm_logic(): + if idx is None: + F_theta = ( + x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch + ) / self.alpha_prod_t_sqrt + denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch + else: + F_theta = ( + x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch + ) / self.alpha_prod_t_sqrt[idx] + denoised_batch = ( + self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch + ) else: - F_theta = ( - x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch - ) / self.alpha_prod_t_sqrt[idx] - denoised_batch = ( - self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch - ) + # For other schedulers, use simple epsilon denoising + # This is what works reliably with StreamDiffusion's current architecture + if idx is not None and idx < len(self.alpha_prod_t_sqrt): + denoised_batch = ( + x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch + ) / self.alpha_prod_t_sqrt[idx] + else: + # Fallback to first timestep if idx is out of bounds + denoised_batch = ( + x_t_latent_batch - self.beta_prod_t_sqrt[0] * model_pred_batch + ) / self.alpha_prod_t_sqrt[0] return denoised_batch @@ -475,6 +691,11 @@ def unet_step( t_list: Union[torch.Tensor, list[int]], idx: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Original StreamDiffusion UNet call that returns a denoised latent batch using + LCM math or a simplified epsilon inversion. For non-LCM schedulers we will + prefer the scheduler.step() path elsewhere; this function is kept for LCM. + """ if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) t_list = torch.concat([t_list[0:1], t_list], dim=0) @@ -635,6 +856,122 @@ def unet_step( return denoised_batch, model_pred + def _call_unet( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + """Call the UNet, handling SDXL kwargs and TensorRT engine calling convention.""" + added_cond_kwargs = added_cond_kwargs or {} + if self.is_sdxl: + try: + # Detect TensorRT engine vs PyTorch UNet + is_tensorrt_engine = hasattr(self.unet, 'engine') and hasattr(self.unet, 'stream') + if is_tensorrt_engine: + out = self.unet( + sample, + timestep, + encoder_hidden_states, + **added_cond_kwargs, + )[0] + else: + out = self.unet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + except Exception as e: + logger.error(f"[PIPELINE] _call_unet: SDXL UNet call failed: {e}") + import traceback + traceback.print_exc() + raise + else: + out = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + )[0] + return out + + def _unet_predict_noise_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + cfg_mode: Literal["none", "full", "self", "initialize"], + ) -> torch.Tensor: + """ + Compute noise prediction from UNet with classifier-free guidance applied. + This function does not apply any scheduler math; it only returns the guided noise. + + For non-LCM schedulers, custom cfg_mode values 'self'/'initialize' are treated + as 'full' to ensure correctness with scheduler.step(). 
+ """ + effective_cfg = cfg_mode + if not self._uses_lcm_logic() and cfg_mode in ("self", "initialize"): + self._emit_cfg_mode_warning_once() + effective_cfg = "full" + + # Build latent batch for CFG + if self.guidance_scale > 1.0 and effective_cfg == "full": + latent_with_uc = torch.cat([latent_model_input, latent_model_input], dim=0) + elif self.guidance_scale > 1.0 and effective_cfg == "initialize": + # Keep initialize behavior for LCM only; if we reach here, LCM path + latent_with_uc = torch.cat([latent_model_input[0:1], latent_model_input], dim=0) + else: + latent_with_uc = latent_model_input + + # SDXL added conditioning replication to match batch + added_cond_kwargs: Dict[str, torch.Tensor] = {} + if self.is_sdxl and hasattr(self, 'add_text_embeds') and hasattr(self, 'add_time_ids'): + if self.add_text_embeds is not None and self.add_time_ids is not None: + batch_size = latent_with_uc.shape[0] + if self.guidance_scale > 1.0 and effective_cfg == "initialize": + add_text_embeds = torch.cat([ + self.add_text_embeds[0:1], + self.add_text_embeds[1:2].repeat(batch_size - 1, 1), + ], dim=0) + add_time_ids = torch.cat([ + self.add_time_ids[0:1], + self.add_time_ids[1:2].repeat(batch_size - 1, 1), + ], dim=0) + elif self.guidance_scale > 1.0 and effective_cfg == "full": + repeat_factor = batch_size // 2 + add_text_embeds = self.add_text_embeds.repeat(repeat_factor, 1) + add_time_ids = self.add_time_ids.repeat(repeat_factor, 1) + else: + add_text_embeds = ( + self.add_text_embeds[1:2].repeat(batch_size, 1) + if self.add_text_embeds.shape[0] > 1 + else self.add_text_embeds.repeat(batch_size, 1) + ) + add_time_ids = ( + self.add_time_ids[1:2].repeat(batch_size, 1) + if self.add_time_ids.shape[0] > 1 + else self.add_time_ids.repeat(batch_size, 1) + ) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # Call UNet + model_pred = self._call_unet( + sample=latent_with_uc, + timestep=timestep, + encoder_hidden_states=self.prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ) + + # Apply CFG + if self.guidance_scale > 1.0 and effective_cfg == "full": + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + guided = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + return guided + else: + return model_pred + def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: image_tensors = image_tensors.to( device=self.device, @@ -669,7 +1006,8 @@ def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: prev_latent_batch = self.x_t_latent_buffer - if self.use_denoising_batch: + # LCM supports our denoising-batch trick. 
Other schedulers should use step() sequentially + if self.use_denoising_batch and self._uses_lcm_logic(): t_list = self.sub_timesteps_tensor if self.denoising_steps_num > 1: @@ -697,24 +1035,36 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_0_pred_out = x_0_pred_batch self.x_t_latent_buffer = None else: - self.init_noise = x_t_latent - for idx, t in enumerate(self.sub_timesteps_tensor): - t = t.view(1,).repeat(self.frame_bff_size,) - - x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx) - - if idx < len(self.sub_timesteps_tensor) - 1: - if self.do_add_noise: - x_t_latent = self.alpha_prod_t_sqrt[ - idx + 1 - ] * x_0_pred + self.beta_prod_t_sqrt[ - idx + 1 - ] * torch.randn_like( - x_0_pred, device=self.device, dtype=self.dtype - ) - else: - x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred - x_0_pred_out = x_0_pred + # Standard scheduler loop using scale_model_input + scheduler.step() + sample = x_t_latent + for idx, timestep in enumerate(self.sub_timesteps_tensor): + # Ensure timestep tensor on device with correct dtype + if not isinstance(timestep, torch.Tensor): + t = torch.tensor(timestep, device=self.device, dtype=torch.long) + else: + t = timestep.to(self.device) + + # Scale model input per scheduler requirements + model_input = ( + self.scheduler.scale_model_input(sample, t) + if hasattr(self.scheduler, "scale_model_input") + else sample + ) + + # Predict noise with CFG + noise_pred = self._unet_predict_noise_cfg( + latent_model_input=model_input, + timestep=t, + cfg_mode=self.cfg_type, + ) + + # Advance one step + step_out = self.scheduler.step(noise_pred, t, sample) + # diffusers returns a SchedulerOutput; prefer .prev_sample if present + sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) + + # After final step, sample approximates x0 latent + x_0_pred_out = sample return x_0_pred_out diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 1e43fded..a1ab89a8 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -276,6 +276,7 @@ def update_stream_params( seed_list: Optional[List[Tuple[int, float]]] = None, seed_interpolation_method: Literal["linear", "slerp"] = "linear", normalize_seed_weights: Optional[bool] = None, + ipadapter_config: Optional[Dict[str, Any]] = None, ) -> None: """Update streaming parameters efficiently in a single call.""" @@ -676,7 +677,7 @@ def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> Non c_skip_list = [] c_out_list = [] for timestep in self.stream.sub_timesteps: - c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + c_skip, c_out = self.stream._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 96f12a0b..747d5d1b 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -103,6 +103,9 @@ def __init__( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + # Scheduler and sampler options + scheduler: Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", # ControlNet options use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, 
Any]]]] = None, @@ -182,6 +185,10 @@ def __init__( normalize_seed_weights : bool, optional Whether to normalize seed weights in blending to sum to 1, by default True. When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional @@ -251,6 +258,8 @@ def __init__( build_engines_if_missing=build_engines_if_missing, normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, + scheduler=scheduler, + sampler=sampler, use_controlnet=use_controlnet, controlnet_config=controlnet_config, enable_pytorch_fallback=enable_pytorch_fallback, @@ -499,6 +508,24 @@ def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self.stream.get_normalize_seed_weights() + def set_scheduler( + self, + scheduler: Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"] = None, + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None, + ) -> None: + """ + Change the scheduler and/or sampler configuration at runtime. + + Parameters + ---------- + scheduler : str, optional + The scheduler type to use. If None, keeps current scheduler. + sampler : str, optional + The sampler type to use. If None, keeps current sampler. + """ + logger.info(f"Setting scheduler to {scheduler} and sampler to {sampler}") + self.stream.set_scheduler(scheduler=scheduler, sampler=sampler) + def __call__( self, image: Optional[Union[str, Image.Image, torch.Tensor]] = None, @@ -766,6 +793,8 @@ def _load_model( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, + scheduler: Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, enable_pytorch_fallback: bool = False, @@ -935,6 +964,8 @@ def _load_model( cfg_type=cfg_type, normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, + scheduler=scheduler, + sampler=sampler, ) if not self.sd_turbo: if use_lcm_lora: @@ -948,6 +979,7 @@ def _load_model( if lora_dict is not None: for lora_name, lora_scale in lora_dict.items(): + logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}") stream.load_lora(lora_name) stream.fuse_lora(lora_scale=lora_scale) From 44077245be8aa6954a570e238c40245a486e2281 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 12 Aug 2025 15:32:21 -0400 Subject: [PATCH 02/15] Added lora signature to engine name. 
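
Engines fused with different LoRA sets must not collide in the engine cache,
so the engine path now embeds a short signature derived from the LoRA
basenames and weights. A standalone sketch of the scheme, mirroring the
_lora_signature helper added below (the example path and weight are
illustrative, not part of this change):

    import hashlib
    from pathlib import Path

    def lora_signature(lora_dict):
        # Canonical "basename:weight" pairs, sorted so the hash is stable across runs
        parts = [
            f"{Path(str(path)).name}:{weight}"
            for path, weight in sorted(lora_dict.items(), key=lambda x: str(x[0]))
        ]
        digest = hashlib.sha1("|".join(parts).encode("utf-8")).hexdigest()[:10]
        return f"{len(lora_dict)}-{digest}"

    # e.g. lora_signature({"loras/detail_tweaker.safetensors": 0.8}) -> "1-<10 hex chars>",
    # which get_engine_path appends to the engine prefix as "--lora-1-<10 hex chars>"
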
--- .../acceleration/tensorrt/engine_manager.py | 24 ++++++++++++++++++- src/streamdiffusion/pipeline.py | 1 + src/streamdiffusion/wrapper.py | 4 +++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index 8649e303..e9d0ed4b 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -1,4 +1,6 @@ import os +import re +import hashlib from enum import Enum from typing import Any, Optional, Dict from pathlib import Path @@ -67,7 +69,22 @@ def __init__(self, engine_dir: str): ) } } - + + def _lora_signature(self, lora_dict: Dict[str, float]) -> str: + """Create a short, stable signature for a set of LoRAs. + + Uses sorted basenames and weights, hashed to a short hex to avoid + long/invalid paths while keeping cache keys stable across runs. + """ + # Build canonical string of basename:weight pairs + parts = [] + for path, weight in sorted(lora_dict.items(), key=lambda x: str(x[0])): + base = Path(str(path)).name # basename only + parts.append(f"{base}:{weight}") + canon = "|".join(parts) + h = hashlib.sha1(canon.encode("utf-8")).hexdigest()[:10] + return f"{len(lora_dict)}-{h}" + def get_engine_path(self, engine_type: EngineType, model_id_or_path: str, @@ -76,6 +93,7 @@ def get_engine_path(self, mode: str, use_lcm_lora: bool, use_tiny_vae: bool, + lora_dict: Optional[Dict[str, float]] = None, ipadapter_scale: Optional[float] = None, ipadapter_tokens: Optional[int] = None, controlnet_model_id: Optional[str] = None) -> Path: @@ -111,6 +129,10 @@ def get_engine_path(self, prefix += f"--ipa{ipadapter_scale}" if ipadapter_tokens is not None: prefix += f"--tokens{ipadapter_tokens}" + + # Fused Loras - use concise hashed signature to avoid long/invalid paths + if lora_dict is not None and len(lora_dict) > 0: + prefix += f"--lora-{self._lora_signature(lora_dict)}" prefix += f"--mode-{mode}" diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 4126784f..918cc6dd 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -39,6 +39,7 @@ def __init__( use_denoising_batch: bool = True, frame_buffer_size: int = 1, cfg_type: Literal["none", "full", "self", "initialize"] = "self", + lora_dict: Optional[Dict[str, float]] = None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, scheduler: Literal["lcm", "tcd", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 445ad9ae..9f294703 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -979,6 +979,7 @@ def _load_model( frame_buffer_size=self.frame_buffer_size, use_denoising_batch=self.use_denoising_batch, cfg_type=cfg_type, + lora_dict=lora_dict, # We pass this to include loras in engine path names normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, scheduler=scheduler, @@ -1154,7 +1155,8 @@ def _load_model( use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, ipadapter_scale=ipadapter_scale, - ipadapter_tokens=ipadapter_tokens + ipadapter_tokens=ipadapter_tokens, + lora_dict=lora_dict ) vae_encoder_path = engine_manager.get_engine_path( EngineType.VAE_ENCODER, From f79a59cd6b7cefcc356f33df1bc5772d40e381fa Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Sat, 23 Aug 2025 14:51:42 -0400 Subject: [PATCH 03/15] 
Clean up of scheduler/samplers that weren't working, fix to controlnets and SDXL. --- .../acceleration/tensorrt/utilities.py | 23 ++ src/streamdiffusion/pipeline.py | 365 ++++++------------ 2 files changed, 136 insertions(+), 252 deletions(-) diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index 2714d2ca..4dd95120 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -331,6 +331,29 @@ def _can_reuse_buffers(self, shape_dict=None, device="cuda"): return True def infer(self, feed_dict, stream, use_cuda_graph=False): + # Filter inputs to only those the engine actually exposes to avoid binding errors + try: + allowed_inputs = set() + for idx in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(idx) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + allowed_inputs.add(name) + + # Drop any extra keys (e.g., text_embeds/time_ids) that the engine was not built to accept + if allowed_inputs: + filtered_feed_dict = {k: v for k, v in feed_dict.items() if k in allowed_inputs} + if len(filtered_feed_dict) != len(feed_dict): + missing = [k for k in feed_dict.keys() if k not in allowed_inputs] + if missing: + logger.debug( + "TensorRT Engine: filtering unsupported inputs %s (allowed=%s)", + missing, sorted(list(allowed_inputs)) + ) + feed_dict = filtered_feed_dict + except Exception: + # Be permissive if engine query fails; proceed with original dict + pass + for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 6ac81cf1..62c77122 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -4,15 +4,7 @@ import numpy as np import PIL.Image import torch -from diffusers import ( - LCMScheduler, - StableDiffusionPipeline, - DPMSolverMultistepScheduler, - UniPCMultistepScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - TCDScheduler, -) +from diffusers import LCMScheduler, TCDScheduler, StableDiffusionPipeline from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, @@ -43,7 +35,7 @@ def __init__( lora_dict: Optional[Dict[str, float]] = None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, - scheduler: Literal["lcm", "tcd", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + scheduler: Literal["lcm", "tcd"] = "lcm", sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", ) -> None: self.device = pipe.device @@ -70,7 +62,15 @@ def __init__( self.is_turbo = detection_result['is_turbo'] self.detection_confidence = detection_result['confidence'] - if use_denoising_batch: + # TCD scheduler is incompatible with denoising batch optimization due to Strategic Stochastic Sampling + # Force sequential processing for TCD + if scheduler == "tcd": + logger.info("TCD scheduler detected: Disabling denoising batch optimization for compatibility") + self.use_denoising_batch = False + self.batch_size = frame_buffer_size + self.trt_unet_batch_size = frame_buffer_size + elif use_denoising_batch: + self.use_denoising_batch = True self.batch_size = self.denoising_steps_num * frame_buffer_size if self.cfg_type == "initialize": self.trt_unet_batch_size = ( @@ -83,13 +83,12 @@ def __init__( else: self.trt_unet_batch_size = self.denoising_steps_num * 
frame_buffer_size else: + self.use_denoising_batch = False self.trt_unet_batch_size = self.frame_bff_size self.batch_size = frame_buffer_size self.t_list = t_index_list - self.do_add_noise = do_add_noise - self.use_denoising_batch = use_denoising_batch self.similar_image_filter = False self.similar_filter = SimilarImageFilter() @@ -97,8 +96,6 @@ def __init__( self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) - - # Initialize scheduler based on configuration self.scheduler = self._initialize_scheduler(scheduler, sampler, pipe.scheduler.config) self.text_encoder = pipe.text_encoder @@ -111,7 +108,6 @@ def __init__( if self.is_sdxl: self.add_text_embeds = None self.add_time_ids = None - logger.log(logging.INFO, f"[PIPELINE] SDXL Detected: Using {scheduler} scheduler with {sampler} sampler") # Initialize parameter updater self._param_updater = StreamParameterUpdater(self, normalize_prompt_weights, normalize_seed_weights) @@ -131,7 +127,29 @@ def __init__( self._cached_batch_size: Optional[int] = None self._cached_cfg_type: Optional[str] = None self._cached_guidance_scale: Optional[float] = None + + def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): + """Initialize scheduler based on type and sampler configuration.""" + # Map sampler types to configuration parameters + sampler_config = { + "simple": {"timestep_spacing": "linspace"}, + "sgm uniform": {"timestep_spacing": "trailing"}, + "normal": {}, # Default configuration + "ddim": {"timestep_spacing": "leading"}, + "beta": {"beta_schedule": "scaled_linear"}, + "karras": {}, # Karras sigmas handled per scheduler + } + + # Get sampler-specific configuration + sampler_params = sampler_config.get(sampler_type, {}) + if scheduler_type == "lcm": + return LCMScheduler.from_config(config, **sampler_params) + elif scheduler_type == "tcd": + return TCDScheduler.from_config(config, **sampler_params) + else: + logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") + return LCMScheduler.from_config(config, **sampler_params) def _check_unet_tensorrt(self) -> bool: """Cache TensorRT detection to avoid repeated hasattr calls""" @@ -196,69 +214,7 @@ def _build_sdxl_conditioning(self, batch_size: int) -> Dict[str, torch.Tensor]: 'time_ids': add_time_ids } - def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): - """Initialize scheduler based on type and sampler configuration.""" - # Map sampler types to configuration parameters - sampler_config = { - "simple": {"timestep_spacing": "linspace"}, - "sgm uniform": {"timestep_spacing": "trailing"}, # SGM Uniform is typically trailing - "normal": {}, # Default configuration - "ddim": {"timestep_spacing": "leading"}, # DDIM default per documentation - "beta": {"beta_schedule": "scaled_linear"}, - "karras": {}, # Karras sigmas will be enabled in scheduler-specific code - } - - # Get sampler-specific configuration - sampler_params = sampler_config.get(sampler_type, {}) - print(f"Sampler params: {sampler_params}") - print(f"Scheduler type: {scheduler_type}") - - # Create scheduler based on type - if scheduler_type == "lcm": - return LCMScheduler.from_config(config, **sampler_params) - elif scheduler_type == "tcd": - return TCDScheduler.from_config(config, **sampler_params) - elif scheduler_type == "dpm++ 2m": - # DPM++ 2M typically uses solver_order=2 and algorithm_type="dpmsolver++" - return DPMSolverMultistepScheduler.from_config( - config, - solver_order=2, - algorithm_type="dpmsolver++", - 
use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested - **sampler_params - ) - elif scheduler_type == "uni_pc": - # UniPC: solver_order=2 for guided sampling, solver_type="bh2" by default - return UniPCMultistepScheduler.from_config( - config, - solver_order=2, # Good default for guided sampling - solver_type="bh2", # Default from documentation - disable_corrector=[], # No corrector disabled by default - use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested - **sampler_params - ) - elif scheduler_type == "ddim": - # DDIM defaults to leading timestep spacing, but trailing can be better - return DDIMScheduler.from_config( - config, - set_alpha_to_one=True, # Default per documentation - steps_offset=0, # Default per documentation - prediction_type="epsilon", # Default per documentation - **sampler_params - ) - elif scheduler_type == "euler": - # Euler can use Karras sigmas for improved quality - return EulerDiscreteScheduler.from_config( - config, - use_karras_sigmas=(sampler_type == "karras"), # Enable Karras sigmas if requested - prediction_type="epsilon", # Default per documentation - **sampler_params - ) - else: - # Default to LCM - logger.warning(f"Unknown scheduler type '{scheduler_type}', falling back to LCM") - return LCMScheduler.from_config(config, **sampler_params) def load_lcm_lora( self, @@ -495,42 +451,12 @@ def prepare( # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list self.sub_timesteps = [] - max_timestep_index = len(self.timesteps) - 1 - for t in self.t_list: - # Clamp t_index to valid range to prevent index out of bounds - if t > max_timestep_index: - logger.warning(f"t_index {t} is out of bounds for scheduler with {len(self.timesteps)} timesteps. Clamping to {max_timestep_index}") - t = max_timestep_index - elif t < 0: - logger.warning(f"t_index {t} is negative. 
Clamping to 0") - t = 0 - - timestep_value = self.timesteps[t] - # Convert tensor timesteps to scalar values for indexing operations - if isinstance(timestep_value, torch.Tensor): - timestep_scalar = timestep_value.cpu().item() - else: - timestep_scalar = timestep_value - self.sub_timesteps.append(timestep_scalar) - - # Create tensor version for UNet calls - # Handle both integer and floating-point timesteps from different schedulers - # Some schedulers like Euler may return floating-point timesteps - if len(self.sub_timesteps) > 0: - # Always create the tensor from scalar values to avoid device issues - try: - # Try integer first for compatibility - sub_timesteps_tensor = torch.tensor( - self.sub_timesteps, dtype=torch.long, device=self.device - ) - except (TypeError, ValueError): - # Fallback for floating-point values - sub_timesteps_tensor = torch.tensor( - self.sub_timesteps, dtype=torch.float32, device=self.device - ) - else: - sub_timesteps_tensor = torch.tensor([], dtype=torch.long, device=self.device) + self.sub_timesteps.append(self.timesteps[t]) + + sub_timesteps_tensor = torch.tensor( + self.sub_timesteps, dtype=torch.long, device=self.device + ) self.sub_timesteps_tensor = torch.repeat_interleave( sub_timesteps_tensor, repeats=self.frame_bff_size if self.use_denoising_batch else 1, @@ -566,25 +492,8 @@ def prepare( alpha_prod_t_sqrt_list = [] beta_prod_t_sqrt_list = [] for timestep in self.sub_timesteps: - # Convert floating-point timesteps to integers for tensor indexing - if isinstance(timestep, float): - timestep_idx = int(round(timestep)) - else: - timestep_idx = timestep - - # Ensure timestep_idx is within bounds - max_idx = len(self.scheduler.alphas_cumprod) - 1 - if timestep_idx > max_idx: - logger.warning(f"Timestep index {timestep_idx} out of bounds for alphas_cumprod (max: {max_idx}). Clamping to {max_idx}") - timestep_idx = max_idx - elif timestep_idx < 0: - logger.warning(f"Timestep index {timestep_idx} is negative. Clamping to 0") - timestep_idx = 0 - - # Access scheduler tensors and move to device as needed - alpha_cumprod = self.scheduler.alphas_cumprod[timestep_idx].to(device=self.device, dtype=self.dtype) - alpha_prod_t_sqrt = alpha_cumprod.sqrt() - beta_prod_t_sqrt = (1 - alpha_cumprod).sqrt() + alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() + beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) alpha_prod_t_sqrt = ( @@ -610,8 +519,9 @@ def prepare( #NOTE: this is a hack. Pipeline needs a major refactor along with stream parameter updater. self.update_prompt(prompt) - # Only collapse tensors to a single element for non-batched LCM path. - if (not self.use_denoising_batch) and self._uses_lcm_logic(): + # Only collapse tensors to scalars for LCM non-batched mode + # TCD needs to keep tensor dimensions for iteration + if not self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): self.sub_timesteps_tensor = self.sub_timesteps_tensor[0] self.alpha_prod_t_sqrt = self.alpha_prod_t_sqrt[0] self.beta_prod_t_sqrt = self.beta_prod_t_sqrt[0] @@ -621,26 +531,14 @@ def prepare( self.c_out = self.c_out.to(self.device) def _get_scheduler_scalings(self, timestep): - """ - Get LCM-specific scaling factors for boundary conditions. - Only used for LCMScheduler - other schedulers handle scaling in their step() method. 
- """ + """Get LCM/TCD-specific scaling factors for boundary conditions.""" if isinstance(self.scheduler, LCMScheduler): c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep) - # Ensure returned values are tensors on the correct device - if not isinstance(c_skip, torch.Tensor): - c_skip = torch.tensor(c_skip, device=self.device, dtype=self.dtype) - else: - c_skip = c_skip.to(device=self.device, dtype=self.dtype) - if not isinstance(c_out, torch.Tensor): - c_out = torch.tensor(c_out, device=self.device, dtype=self.dtype) - else: - c_out = c_out.to(device=self.device, dtype=self.dtype) return c_skip, c_out else: - # For non-LCM schedulers, we don't use boundary condition scaling - # Their step() method handles all the necessary scaling internally - logger.debug(f"Scheduler {type(self.scheduler)} doesn't use boundary condition scaling") + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() c_skip = torch.tensor(1.0, device=self.device, dtype=self.dtype) c_out = torch.tensor(1.0, device=self.device, dtype=self.dtype) return c_skip, c_out @@ -664,18 +562,18 @@ def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self._param_updater.get_normalize_seed_weights() - def set_scheduler( - self, - scheduler: Literal["lcm", "tcd", "dpm++ 2m", "uni_pc", "ddim", "euler"] = None, - sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None, - ) -> None: + + + + + def set_scheduler(self, scheduler: Literal["lcm", "tcd"] = None, sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = None) -> None: """ - Change the scheduler and/or sampler configuration at runtime. + Change the scheduler and/or sampler at runtime. Parameters ---------- scheduler : str, optional - The scheduler type to use. If None, keeps current scheduler. + The scheduler type to use ("lcm" or "tcd"). If None, keeps current scheduler. sampler : str, optional The sampler type to use. If None, keeps current sampler. """ @@ -684,35 +582,14 @@ def set_scheduler( if sampler is not None: self.sampler_type = sampler - # Re-initialize scheduler with new configuration - self.scheduler = self._initialize_scheduler( - self.scheduler_type, - self.sampler_type, - self.pipe.scheduler.config - ) - + self.scheduler = self._initialize_scheduler(self.scheduler_type, self.sampler_type, self.pipe.scheduler.config) logger.info(f"Scheduler changed to {self.scheduler_type} with {self.sampler_type} sampler") + def _uses_lcm_logic(self) -> bool: + """Return True if scheduler uses LCM-style consistency boundary-condition math.""" + return isinstance(self.scheduler, LCMScheduler) - def _uses_lcm_logic(self) -> bool: - """Return True if scheduler uses consistency boundary-condition math (LCM/TCD).""" - try: - # Use isinstance checks for more reliable detection - return isinstance(self.scheduler, LCMScheduler) - except Exception: - return False - - def _warned_cfg_mode_fallback(self) -> bool: - return getattr(self, "_cfg_mode_warning_emitted", False) - - def _emit_cfg_mode_warning_once(self) -> None: - if not self._warned_cfg_mode_fallback(): - logger.warning( - "Non-LCM scheduler in use: falling back to standard CFG ('full') semantics. " - "Custom cfg_type values 'self'/'initialize' are ignored for correctness." 
- ) - setattr(self, "_cfg_mode_warning_emitted", True) def add_noise( self, @@ -732,37 +609,18 @@ def scheduler_step_batch( x_t_latent_batch: torch.Tensor, idx: Optional[int] = None, ) -> torch.Tensor: - """ - Simplified scheduler integration that works with StreamDiffusion's architecture. - For now, we'll use a hybrid approach until we can properly refactor the pipeline. - """ - # For LCM, use boundary condition scaling as before - if self._uses_lcm_logic(): - if idx is None: - F_theta = ( - x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch - ) / self.alpha_prod_t_sqrt - denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch - else: - F_theta = ( - x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch - ) / self.alpha_prod_t_sqrt[idx] - denoised_batch = ( - self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch - ) + if idx is None: + F_theta = ( + x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch + ) / self.alpha_prod_t_sqrt + denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch else: - # For other schedulers, use simple epsilon denoising - # This is what works reliably with StreamDiffusion's current architecture - if idx is not None and idx < len(self.alpha_prod_t_sqrt): - denoised_batch = ( - x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch - ) / self.alpha_prod_t_sqrt[idx] - else: - # Fallback to first timestep if idx is out of bounds - denoised_batch = ( - x_t_latent_batch - self.beta_prod_t_sqrt[0] * model_pred_batch - ) / self.alpha_prod_t_sqrt[0] - + F_theta = ( + x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch + ) / self.alpha_prod_t_sqrt[idx] + denoised_batch = ( + self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch + ) return denoised_batch def unet_step( @@ -1039,20 +897,11 @@ def _unet_predict_noise_cfg( """ Compute noise prediction from UNet with classifier-free guidance applied. This function does not apply any scheduler math; it only returns the guided noise. - - For non-LCM schedulers, custom cfg_mode values 'self'/'initialize' are treated - as 'full' to ensure correctness with scheduler.step(). 
""" - effective_cfg = cfg_mode - if not self._uses_lcm_logic() and cfg_mode in ("self", "initialize"): - self._emit_cfg_mode_warning_once() - effective_cfg = "full" - # Build latent batch for CFG - if self.guidance_scale > 1.0 and effective_cfg == "full": + if self.guidance_scale > 1.0 and cfg_mode == "full": latent_with_uc = torch.cat([latent_model_input, latent_model_input], dim=0) - elif self.guidance_scale > 1.0 and effective_cfg == "initialize": - # Keep initialize behavior for LCM only; if we reach here, LCM path + elif self.guidance_scale > 1.0 and cfg_mode == "initialize": latent_with_uc = torch.cat([latent_model_input[0:1], latent_model_input], dim=0) else: latent_with_uc = latent_model_input @@ -1062,7 +911,7 @@ def _unet_predict_noise_cfg( if self.is_sdxl and hasattr(self, 'add_text_embeds') and hasattr(self, 'add_time_ids'): if self.add_text_embeds is not None and self.add_time_ids is not None: batch_size = latent_with_uc.shape[0] - if self.guidance_scale > 1.0 and effective_cfg == "initialize": + if self.guidance_scale > 1.0 and cfg_mode == "initialize": add_text_embeds = torch.cat([ self.add_text_embeds[0:1], self.add_text_embeds[1:2].repeat(batch_size - 1, 1), @@ -1071,7 +920,7 @@ def _unet_predict_noise_cfg( self.add_time_ids[0:1], self.add_time_ids[1:2].repeat(batch_size - 1, 1), ], dim=0) - elif self.guidance_scale > 1.0 and effective_cfg == "full": + elif self.guidance_scale > 1.0 and cfg_mode == "full": repeat_factor = batch_size // 2 add_text_embeds = self.add_text_embeds.repeat(repeat_factor, 1) add_time_ids = self.add_time_ids.repeat(repeat_factor, 1) @@ -1097,7 +946,7 @@ def _unet_predict_noise_cfg( ) # Apply CFG - if self.guidance_scale > 1.0 and effective_cfg == "full": + if self.guidance_scale > 1.0 and cfg_mode == "full": noise_pred_uncond, noise_pred_text = model_pred.chunk(2) guided = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) return guided @@ -1128,23 +977,19 @@ def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: prev_latent_batch = self.x_t_latent_buffer - - # LCM supports our denoising-batch trick. Other schedulers should use step() sequentially - if self.use_denoising_batch and self._uses_lcm_logic(): + + # LCM supports our denoising-batch trick. 
TCD must use standard scheduler.step() sequentially + if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) - self.stock_noise = torch.cat( (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 ) - x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) if self.denoising_steps_num > 1: x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: self.x_t_latent_buffer = ( self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] @@ -1158,7 +1003,7 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_0_pred_out = x_0_pred_batch self.x_t_latent_buffer = None else: - # Standard scheduler loop using scale_model_input + scheduler.step() + # Standard scheduler loop for TCD and non-batched LCM sample = x_t_latent for idx, timestep in enumerate(self.sub_timesteps_tensor): # Ensure timestep tensor on device with correct dtype @@ -1167,28 +1012,44 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: else: t = timestep.to(self.device) - # Scale model input per scheduler requirements - model_input = ( - self.scheduler.scale_model_input(sample, t) - if hasattr(self.scheduler, "scale_model_input") - else sample - ) + # For TCD, use the scheduler's step method + if isinstance(self.scheduler, TCDScheduler): + # Scale model input per scheduler requirements + model_input = ( + self.scheduler.scale_model_input(sample, t) + if hasattr(self.scheduler, "scale_model_input") + else sample + ) - # Predict noise with CFG - noise_pred = self._unet_predict_noise_cfg( - latent_model_input=model_input, - timestep=t, - cfg_mode=self.cfg_type, - ) + # Predict noise with CFG + noise_pred = self._unet_predict_noise_cfg( + latent_model_input=model_input, + timestep=t, + cfg_mode=self.cfg_type, + ) - # Advance one step - step_out = self.scheduler.step(noise_pred, t, sample) - # diffusers returns a SchedulerOutput; prefer .prev_sample if present - sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) + # Advance one step using TCD's step method + step_out = self.scheduler.step(noise_pred, t, sample) + sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) + else: + # Original LCM logic for non-batched mode + t = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + if self.do_add_noise: + sample = self.alpha_prod_t_sqrt[ + idx + 1 + ] * x_0_pred + self.beta_prod_t_sqrt[ + idx + 1 + ] * torch.randn_like( + x_0_pred, device=self.device, dtype=self.dtype + ) + else: + sample = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + else: + sample = x_0_pred - # After final step, sample approximates x0 latent x_0_pred_out = sample - return x_0_pred_out @torch.no_grad() From 977afb1efdea8373dce7039a74ac9257333bfffc Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 25 Aug 2025 15:20:38 -0400 Subject: [PATCH 04/15] Fix to lora engine setup, changed requirements in realtime-img2img for windows support. 
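
Threading lora_dict into the TensorRT engine path lookups gives builds with
fused LoRAs their own cache entries instead of silently reusing the LoRA-free
engines; the requirements pins are versions that currently install cleanly on
Windows. A minimal sketch of the intended call (most arguments elided, values
are placeholders):

    unet_path = engine_manager.get_engine_path(
        EngineType.UNET,  # assuming the UNet engine type member
        model_id_or_path=model_id_or_path,
        max_batch_size=self.trt_unet_batch_size,
        min_batch_size=self.trt_unet_batch_size,
        mode=self.mode,
        use_lcm_lora=use_lcm_lora,
        use_tiny_vae=use_tiny_vae,
        lora_dict=lora_dict,  # new: folds the LoRA signature ("--lora-<n>-<hash>") into the cache key
    )
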
--- demo/realtime-img2img/requirements.txt | 8 ++++---- src/streamdiffusion/wrapper.py | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/demo/realtime-img2img/requirements.txt b/demo/realtime-img2img/requirements.txt index a379a58e..dd200a25 100644 --- a/demo/realtime-img2img/requirements.txt +++ b/demo/realtime-img2img/requirements.txt @@ -1,11 +1,11 @@ diffusers==0.35.0 -transformers==4.56.0 -peft==0.18.0 +transformers==4.55.4 +peft==0.17.1 accelerate==1.10.0 -huggingface_hub==0.35.0 +huggingface_hub==0.34.4 fastapi==0.115.0 uvicorn[standard]==0.32.0 -Pillow==10.5.0 +Pillow==10.4.0 compel==2.0.2 controlnet-aux==0.0.7 xformers; sys_platform != 'darwin' or platform_machine != 'arm64' diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 5ff8a374..51b7fea1 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -1113,6 +1113,7 @@ def _load_model( mode=self.mode, use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None @@ -1125,6 +1126,7 @@ def _load_model( mode=self.mode, use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None @@ -1137,6 +1139,7 @@ def _load_model( mode=self.mode, use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, + lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None From 0044a9b89a2e2e1663572ea848a438d8a3fd518d Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 25 Aug 2025 19:35:34 -0400 Subject: [PATCH 05/15] ControlNet TCD. --- src/streamdiffusion/pipeline.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 62c77122..d15beb58 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -66,6 +66,7 @@ def __init__( # Force sequential processing for TCD if scheduler == "tcd": logger.info("TCD scheduler detected: Disabling denoising batch optimization for compatibility") + logger.info("TCD now supports ControlNet through proper hook processing") self.use_denoising_batch = False self.batch_size = frame_buffer_size self.trt_unet_batch_size = frame_buffer_size @@ -979,6 +980,7 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: prev_latent_batch = self.x_t_latent_buffer # LCM supports our denoising-batch trick. 
TCD must use standard scheduler.step() sequentially + # but now properly processes ControlNet hooks through unet_step() if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): t_list = self.sub_timesteps_tensor if self.denoising_steps_num > 1: @@ -1012,24 +1014,14 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: else: t = timestep.to(self.device) - # For TCD, use the scheduler's step method + # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed if isinstance(self.scheduler, TCDScheduler): - # Scale model input per scheduler requirements - model_input = ( - self.scheduler.scale_model_input(sample, t) - if hasattr(self.scheduler, "scale_model_input") - else sample - ) - - # Predict noise with CFG - noise_pred = self._unet_predict_noise_cfg( - latent_model_input=model_input, - timestep=t, - cfg_mode=self.cfg_type, - ) - - # Advance one step using TCD's step method - step_out = self.scheduler.step(noise_pred, t, sample) + # Use unet_step to process ControlNet hooks and get proper noise prediction + t_expanded = t.view(1,).repeat(self.frame_bff_size,) + x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) + + # Apply TCD scheduler step to the guided noise prediction + step_out = self.scheduler.step(model_pred, t, sample) sample = getattr(step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out) else: # Original LCM logic for non-batched mode From 41e51229279e258077fbb2c1bd98a4d9a4eeef24 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 9 Sep 2025 20:30:56 -0400 Subject: [PATCH 06/15] At uvicorn quiet param to help debug issues without unncessary logging. --- demo/realtime-img2img/config.py | 9 +++++++++ demo/realtime-img2img/main.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/demo/realtime-img2img/config.py b/demo/realtime-img2img/config.py index 8d74eda4..56d77404 100644 --- a/demo/realtime-img2img/config.py +++ b/demo/realtime-img2img/config.py @@ -20,6 +20,7 @@ class Args(NamedTuple): controlnet_config: str api_only: bool log_level: str + quiet: bool def pretty_print(self): print("\n") @@ -34,6 +35,7 @@ def pretty_print(self): ENGINE_DIR = os.environ.get("ENGINE_DIR", "engines") ACCELERATION = os.environ.get("ACCELERATION", "xformers") LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") +QUIET = os.environ.get("QUIET", "False").lower() in ("true", "1", "yes", "on") default_host = os.getenv("HOST", "0.0.0.0") default_port = int(os.getenv("PORT", "7860")) @@ -129,5 +131,12 @@ def pretty_print(self): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) +parser.add_argument( + "--quiet", + dest="quiet", + action="store_true", + default=QUIET, + help="Suppress uvicorn INFO messages (server access logs, etc.)", +) config = Args(**vars(parser.parse_args())) config.pretty_print() diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index 4ede7a30..0c2ca2e4 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -97,6 +97,13 @@ def setup_logging(log_level: str = "INFO"): # Initialize logger logger = setup_logging(config.log_level) +# Suppress uvicorn INFO messages +if config.quiet: + uvicorn_logger = logging.getLogger('uvicorn') + uvicorn_logger.setLevel(logging.WARNING) + uvicorn_access_logger = logging.getLogger('uvicorn.access') + uvicorn_access_logger.setLevel(logging.WARNING) + class App: def __init__(self, config: Args): From 
b04f0e8c7804e9649cf60ab6056fa1f9b7dd65f9 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Thu, 11 Sep 2025 11:16:54 -0400 Subject: [PATCH 07/15] Fix to LoRA and IPAdapter conflict. --- src/streamdiffusion/wrapper.py | 220 +++++++++++++++++++++------------ 1 file changed, 139 insertions(+), 81 deletions(-) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 2612da3f..7dbf923d 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -7,7 +7,8 @@ from PIL import Image import torchvision.transforms as T from torchvision.transforms import InterpolationMode -from diffusers import AutoencoderTiny, StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image +from diffusers import AutoencoderTiny, StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image, UNet2DConditionModel +from safetensors.torch import load_file from .pipeline import StreamDiffusion from .model_detection import detect_model @@ -1011,21 +1012,78 @@ def _load_model( scheduler=scheduler, sampler=sampler, ) + # Load and properly merge LoRA weights using the standard diffusers approach if not self.sd_turbo: + lora_adapters_to_merge = [] + lora_scales_to_merge = [] + + # Collect all LoRA adapters and their scales if use_lcm_lora: if lcm_lora_id is not None: - stream.load_lcm_lora( - pretrained_model_name_or_path_or_dict=lcm_lora_id - ) + logger.info(f"_load_model: Loading LCM LoRA from {lcm_lora_id}") + stream.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm_lora") else: - stream.load_lcm_lora() - stream.fuse_lora() + logger.info("_load_model: Loading default LCM LoRA") + # Use appropriate default LCM LoRA based on model type + default_lcm_lora = "latent-consistency/lcm-lora-sdxl" if is_sdxl else "latent-consistency/lcm-lora-sdv1-5" + stream.pipe.load_lora_weights(default_lcm_lora, adapter_name="lcm_lora") + + lora_adapters_to_merge.append("lcm_lora") + lora_scales_to_merge.append(1.0) if lora_dict is not None: - for lora_name, lora_scale in lora_dict.items(): + for i, (lora_name, lora_scale) in enumerate(lora_dict.items()): + adapter_name = f"custom_lora_{i}" logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}") - stream.load_lora(lora_name) - stream.fuse_lora(lora_scale=lora_scale) + + try: + # Load LoRA weights with unique adapter name + stream.pipe.load_lora_weights(lora_name, adapter_name=adapter_name) + lora_adapters_to_merge.append(adapter_name) + lora_scales_to_merge.append(lora_scale) + logger.info(f"Successfully loaded LoRA adapter: {adapter_name}") + except Exception as e: + logger.error(f"Failed to load LoRA {lora_name}: {e}") + # Continue with other LoRAs even if one fails + continue + + # Merge all LoRA adapters using the proper diffusers method + if lora_adapters_to_merge: + try: + logger.info(f"Merging {len(lora_adapters_to_merge)} LoRA adapter(s) with scales: {lora_scales_to_merge}") + + # Use the proper merge_and_unload method from diffusers + # This permanently merges LoRA weights into the base model parameters + stream.pipe.fuse_lora(lora_scale=lora_scales_to_merge, adapter_names=lora_adapters_to_merge) + + # After fusing, unload the LoRA weights to clean up memory and avoid conflicts + stream.pipe.unload_lora_weights() + + logger.info("Successfully merged and unloaded LoRA weights using diffusers merge_and_unload") + + except Exception as e: + logger.error(f"Failed to merge LoRA weights: {e}") + logger.info("Attempting fallback: individual LoRA merging...") + + # Fallback: merge LoRAs 
individually + try: + for adapter_name, scale in zip(lora_adapters_to_merge, lora_scales_to_merge): + logger.info(f"Merging individual LoRA: {adapter_name} with scale {scale}") + stream.pipe.fuse_lora(lora_scale=scale, adapter_names=[adapter_name]) + + # Clean up after individual merging + stream.pipe.unload_lora_weights() + logger.info("Successfully merged LoRAs individually") + + except Exception as fallback_error: + logger.error(f"LoRA merging fallback also failed: {fallback_error}") + logger.warning("Continuing without LoRA merging - LoRAs may not be applied correctly") + + # Clean up any partial state + try: + stream.pipe.unload_lora_weights() + except: + pass if use_tiny_vae: if vae_id is not None: @@ -1034,7 +1092,6 @@ def _load_model( # Use TAESD XL for SDXL models, regular TAESD for SD 1.5 taesd_model = "madebyollin/taesdxl" if is_sdxl else "madebyollin/taesd" stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype) - try: if acceleration == "xformers": @@ -1243,10 +1300,15 @@ def _load_model( except Exception: pass - # If using TensorRT with IP-Adapter, ensure processors and weights are installed BEFORE export - if use_ipadapter_trt and has_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): + # Note: LoRA weights have already been merged permanently during model loading + + # CRITICAL: Install IPAdapter module BEFORE TensorRT compilation to ensure processors are baked into engines + if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): try: from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig + logger.info("Installing IPAdapter module before TensorRT compilation...") + + # Use first config if list provided cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config ip_cfg = IPAdapterConfig( style_image_key=cfg.get('style_image_key') or 'ipadapter_main', @@ -1258,17 +1320,28 @@ def _load_model( is_faceid=(cfg.get('type') == 'faceid' or bool(cfg.get('is_faceid', False))), insightface_model_name=cfg.get('insightface_model_name'), ) - ip_module_for_export = IPAdapterModule(ip_cfg) - ip_module_for_export.install(stream) - setattr(stream, '_ipadapter_module', ip_module_for_export) - try: - logger.info("Installed IP-Adapter processors prior to TensorRT export") - except Exception: - pass + ip_module = IPAdapterModule(ip_cfg) + ip_module.install(stream) + # Expose for later updates + stream._ipadapter_module = ip_module + logger.info("IPAdapter module installed successfully before TensorRT compilation") + + # Cleanup after IPAdapter installation + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + except torch.cuda.OutOfMemoryError as oom_error: + logger.error(f"CUDA Out of Memory during early IPAdapter installation: {oom_error}") + logger.error("Try reducing batch size, using smaller models, or increasing GPU memory") + raise RuntimeError("Insufficient VRAM for IPAdapter installation. Consider using a GPU with more memory or reducing model complexity.") + except Exception: import traceback traceback.print_exc() - logger.error("Failed to pre-install IP-Adapter prior to TensorRT export") + logger.error("Failed to install IPAdapterModule before TensorRT compilation") + raise # NOTE: When IPAdapter is enabled, we must pass num_ip_layers. We cannot know it until after # installing processors in the export wrapper. 
We construct the wrapper first to discover it, @@ -1492,46 +1565,47 @@ def _load_model( logger.error(f"TensorRT VAE engine loading failed (non-OOM): {e}") raise e - safety_checker_path = engine_manager.get_engine_path( - EngineType.SAFETY_CHECKER, - model_id_or_path=safety_checker_model_id, - max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, - min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, - mode=self.mode, - use_lcm_lora=use_lcm_lora, - use_tiny_vae=use_tiny_vae, - ) - safety_checker_engine_exists = os.path.exists(safety_checker_path) - - # Always load the safety checker if the engine exists. The model is really small and may be toggled later. - if self.use_safety_checker or safety_checker_engine_exists: - if not safety_checker_engine_exists: - from transformers import AutoModelForImageClassification - self.safety_checker = AutoModelForImageClassification.from_pretrained(safety_checker_model_id).to("cuda") - - safety_checker_model = NSFWDetector( - device=self.device, - max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, - min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, - ) + # Safety checker engine (TensorRT-specific) + safety_checker_path = engine_manager.get_engine_path( + EngineType.SAFETY_CHECKER, + model_id_or_path=safety_checker_model_id, + max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, + min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, + mode=self.mode, + use_lcm_lora=use_lcm_lora, + use_tiny_vae=use_tiny_vae, + ) + safety_checker_engine_exists = os.path.exists(safety_checker_path) + + # Always load the safety checker if the engine exists. The model is really small and may be toggled later. 
+ if self.use_safety_checker or safety_checker_engine_exists: + if not safety_checker_engine_exists: + from transformers import AutoModelForImageClassification + self.safety_checker = AutoModelForImageClassification.from_pretrained(safety_checker_model_id).to("cuda") + + safety_checker_model = NSFWDetector( + device=self.device, + max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, + min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, + ) - engine_manager.compile_and_load_engine( - EngineType.SAFETY_CHECKER, - safety_checker_path, - model=self.safety_checker, - model_config=safety_checker_model, - batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, - cuda_stream=None, - load_engine=load_engine, - ) - - if load_engine: - self.safety_checker = NSFWDetectorEngine( - safety_checker_path, - cuda_stream, - use_cuda_graph=True, - ) + engine_manager.compile_and_load_engine( + EngineType.SAFETY_CHECKER, + safety_checker_path, + model=self.safety_checker, + model_config=safety_checker_model, + batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, + cuda_stream=None, + load_engine=load_engine, + ) + if load_engine: + self.safety_checker = NSFWDetectorEngine( + safety_checker_path, + cuda_stream, + use_cuda_graph=True, + ) + if acceleration == "sfast": from streamdiffusion.acceleration.sfast import ( accelerate_with_stable_fast, @@ -1613,33 +1687,17 @@ def _load_model( logger.error("Failed to install ControlNetModule") raise + # IPAdapter module installation has been moved to before TensorRT compilation (see lines 1307-1345) + # This ensures processors are properly baked into the TensorRT engines if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): - try: - from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig - # Use first config if list provided - cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config - ip_cfg = IPAdapterConfig( - style_image_key=cfg.get('style_image_key') or 'ipadapter_main', - num_image_tokens=cfg.get('num_image_tokens', 4), - ipadapter_model_path=cfg['ipadapter_model_path'], - image_encoder_path=cfg['image_encoder_path'], - style_image=cfg.get('style_image'), - scale=cfg.get('scale', 1.0), - is_faceid=(cfg.get('type') == 'faceid' or bool(cfg.get('is_faceid', False))), - insightface_model_name=cfg.get('insightface_model_name'), - ) - ip_module = IPAdapterModule(ip_cfg) - ip_module.install(stream) - # Expose for later updates - stream._ipadapter_module = ip_module - except Exception: - import traceback - traceback.print_exc() - logger.error("Failed to install IPAdapterModule") - raise + logger.warning("IPAdapter was not installed during TensorRT compilation phase - this may cause runtime issues") + logger.warning("IPAdapter should have been installed before engine compilation for proper TensorRT integration") + + # Note: LoRA weights have already been merged permanently during model loading return stream + def get_last_processed_image(self, index: int) -> Optional[Image.Image]: """Forward get_last_processed_image call to the underlying ControlNet pipeline""" if not self.use_controlnet: From e2778b601ecdd6596c4a596a9842f7f00bc4e9fe Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 19:56:29 -0400 Subject: [PATCH 08/15] Deprecation of use_lcm_lora. 
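
With this change the LCM LoRA is no longer a special-cased flag: it is loaded through
lora_dict like any other adapter. A minimal usage sketch, assuming the StreamDiffusionWrapper
entry point in wrapper.py (model id, t_index_list and scales are illustrative):

    from streamdiffusion.wrapper import StreamDiffusionWrapper  # assumed import path

    wrapper = StreamDiffusionWrapper(
        model_id_or_path="KBlueLeaf/kohaku-v2.1",
        t_index_list=[32, 45],
        mode="img2img",
        # The LCM LoRA is now just another lora_dict entry instead of use_lcm_lora=True
        lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0},
        use_tiny_vae=True,
    )

SDXL checkpoints would use "latent-consistency/lcm-lora-sdxl" instead.
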
--- configs/sd15_multicontrol.yaml.example | 7 +- configs/sdturbo_multicontrol.yaml.example | 7 +- configs/sdxl_multicontrol.yaml.example | 7 +- .../lib/components/PreprocessorDocs.svelte | 1 - demo/realtime-img2img/img2img.py | 1 - demo/realtime-img2img/main.py | 2 - demo/realtime-txt2img/config.py | 3 +- demo/realtime-txt2img/main.py | 1 - examples/benchmark/multi.py | 6 +- examples/benchmark/single.py | 6 +- examples/optimal-performance/multi.py | 1 - examples/optimal-performance/single.py | 1 - .../acceleration/tensorrt/engine_manager.py | 4 +- src/streamdiffusion/config.py | 2 - src/streamdiffusion/pipeline.py | 15 ---- src/streamdiffusion/wrapper.py | 68 +++++++++---------- 16 files changed, 59 insertions(+), 73 deletions(-) diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example index e95e55ef..6aa4b93d 100644 --- a/configs/sd15_multicontrol.yaml.example +++ b/configs/sd15_multicontrol.yaml.example @@ -32,7 +32,12 @@ seed: 789 frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: true +# LoRA configuration - use lora_dict to load LCM LoRA and other LoRAs +lora_dict: + "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference + # Add other LoRAs here: + # "your_custom_lora": 0.7 + use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" diff --git a/configs/sdturbo_multicontrol.yaml.example b/configs/sdturbo_multicontrol.yaml.example index 5f7b8561..54a7b8a6 100644 --- a/configs/sdturbo_multicontrol.yaml.example +++ b/configs/sdturbo_multicontrol.yaml.example @@ -22,7 +22,12 @@ seed: 789 frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: true # SD-Turbo benefits from LCM LoRA +# LoRA configuration - SD-Turbo can benefit from LCM LoRA +lora_dict: + "latent-consistency/lcm-lora-sdv1-5": 1.0 # LCM LoRA for faster inference + # Add other LoRAs here: + # "your_custom_lora": 0.7 + use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" diff --git a/configs/sdxl_multicontrol.yaml.example b/configs/sdxl_multicontrol.yaml.example index 441acce4..61c482de 100644 --- a/configs/sdxl_multicontrol.yaml.example +++ b/configs/sdxl_multicontrol.yaml.example @@ -31,7 +31,12 @@ seed: 42 # Base seed (used with seed_blending above) frame_buffer_size: 1 delta: 0.7 use_denoising_batch: true -use_lcm_lora: false # SDXL has built-in optimizations +# LoRA configuration - SDXL can use LCM LoRA for faster inference +# lora_dict: +# "latent-consistency/lcm-lora-sdxl": 1.0 # Uncomment to enable LCM LoRA for SDXL +# # Add other LoRAs here: +# # "your_custom_lora": 0.7 + use_taesd: true # Use Tiny AutoEncoder for SDXL use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups diff --git a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte index 830666d7..0c59002b 100644 --- a/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte +++ b/demo/realtime-img2img/frontend/src/lib/components/PreprocessorDocs.svelte @@ -40,7 +40,6 @@ use_denoising_batch: true, delta: 0.7, frame_buffer_size: 1, - use_lcm_lora: true, use_tiny_vae: true, acceleration: "xformers", cfg_type: "self", diff --git a/demo/realtime-img2img/img2img.py b/demo/realtime-img2img/img2img.py index 54b802f1..8308d6b3 100644 --- a/demo/realtime-img2img/img2img.py +++ b/demo/realtime-img2img/img2img.py @@ -182,7 +182,6 @@ def __init__(self, args: 
Args, device: torch.device, torch_dtype: torch.dtype, w frame_buffer_size=1, width=params.width, height=params.height, - use_lcm_lora=False, output_type="pt", warmup=10, vae_id=None, diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index 7ac437d5..a3fb6368 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -625,7 +625,6 @@ async def settings(): 'seed', 'frame_buffer_size', 'use_denoising_batch', - 'use_lcm_lora', 'use_tiny_vae', 'use_taesd', 'cfg_type', @@ -761,7 +760,6 @@ async def upload_controlnet_config(file: UploadFile = File(...)): 'seed', 'frame_buffer_size', 'use_denoising_batch', - 'use_lcm_lora', 'use_tiny_vae', 'use_taesd', 'cfg_type', diff --git a/demo/realtime-txt2img/config.py b/demo/realtime-txt2img/config.py index c0a14ba4..35494148 100644 --- a/demo/realtime-txt2img/config.py +++ b/demo/realtime-txt2img/config.py @@ -29,8 +29,7 @@ class Config: model_id_or_path: str = os.environ.get("MODEL", "KBlueLeaf/kohaku-v2.1") # LoRA dictionary write like field(default_factory=lambda: {'E:/stable-diffusion-webui/models/Lora_1.safetensors' : 1.0 , 'E:/stable-diffusion-webui/models/Lora_2.safetensors' : 0.2}) lora_dict: dict = None - # LCM-LORA model - lcm_lora_id: str = os.environ.get("LORA", "latent-consistency/lcm-lora-sdv1-5") + # LCM-LORA model (use lora_dict instead of lcm_lora_id) # TinyVAE model vae_id: str = os.environ.get("VAE", "madebyollin/taesd") # Device to use diff --git a/demo/realtime-txt2img/main.py b/demo/realtime-txt2img/main.py index 88967c4c..18931f9e 100644 --- a/demo/realtime-txt2img/main.py +++ b/demo/realtime-txt2img/main.py @@ -63,7 +63,6 @@ def __init__(self, config: Config) -> None: mode=config.mode, model_id_or_path=config.model_id_or_path, lora_dict=config.lora_dict, - lcm_lora_id=config.lcm_lora_id, vae_id=config.vae_id, device=config.device, dtype=config.dtype, diff --git a/examples/benchmark/multi.py b/examples/benchmark/multi.py index bfa971cf..f3f879f2 100644 --- a/examples/benchmark/multi.py +++ b/examples/benchmark/multi.py @@ -40,7 +40,7 @@ def run( lora_dict: Optional[Dict[str, float]] = None, prompt: str = "1girl with brown dog hair, thick glasses, smiling", negative_prompt: str = "bad image , bad quality", - use_lcm_lora: bool = True, + lcm_lora: bool = True, use_tiny_vae: bool = True, width: int = 512, height: int = 512, @@ -67,7 +67,7 @@ def run( The prompt to use, by default "1girl with brown dog hair, thick glasses, smiling". negative_prompt : str, optional The negative prompt to use, by default "bad image , bad quality". - use_lcm_lora : bool, optional + lcm_lora : bool, optional Whether to use LCM-LoRA or not, by default True. use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. 
@@ -97,7 +97,7 @@ def run( warmup=warmup, acceleration=acceleration, device_ids=device_ids, - use_lcm_lora=use_lcm_lora, + lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0} if lcm_lora else lora_dict, use_tiny_vae=use_tiny_vae, enable_similar_image_filter=False, similar_image_filter_threshold=0.98, diff --git a/examples/benchmark/single.py b/examples/benchmark/single.py index 5e55fb63..eb868bdd 100644 --- a/examples/benchmark/single.py +++ b/examples/benchmark/single.py @@ -27,7 +27,7 @@ def run( lora_dict: Optional[Dict[str, float]] = None, prompt: str = "1girl with brown dog hair, thick glasses, smiling", negative_prompt: str = "bad image , bad quality", - use_lcm_lora: bool = True, + lcm_lora: bool = True, use_tiny_vae: bool = True, width: int = 512, height: int = 512, @@ -54,7 +54,7 @@ def run( The prompt to use, by default "1girl with brown dog hair, thick glasses, smiling". negative_prompt : str, optional The negative prompt to use, by default "bad image , bad quality". - use_lcm_lora : bool, optional + lcm_lora : bool, optional Whether to use LCM-LoRA or not, by default True. use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. @@ -84,7 +84,7 @@ def run( warmup=warmup, acceleration=acceleration, device_ids=device_ids, - use_lcm_lora=use_lcm_lora, + lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0} if lcm_lora else lora_dict, use_tiny_vae=use_tiny_vae, enable_similar_image_filter=False, similar_image_filter_threshold=0.98, diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index ac2c2a53..791d88b1 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -74,7 +74,6 @@ def image_generation_process( frame_buffer_size=batch_size, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index 4bc08b3f..a8020bb8 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -40,7 +40,6 @@ def image_generation_process( frame_buffer_size=1, warmup=10, acceleration=acceleration, - use_lcm_lora=False, mode="txt2img", cfg_type="none", use_denoising_batch=True, diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index f34e6d34..4f49bd43 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -99,7 +99,6 @@ def get_engine_path(self, max_batch_size: int, min_batch_size: int, mode: str, - use_lcm_lora: bool, use_tiny_vae: bool, lora_dict: Optional[Dict[str, float]] = None, ipadapter_scale: Optional[float] = None, @@ -132,7 +131,7 @@ def get_engine_path(self, base_name = maybe_path.stem if maybe_path.exists() else model_id_or_path # Create prefix (from wrapper.py lines 1005-1013) - prefix = f"{base_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" + prefix = f"{base_name}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}" # IP-Adapter differentiation: add type and (optionally) tokens # Keep scale out of identity for runtime control, but include a type flag to separate caches @@ -309,7 +308,6 @@ def get_or_load_controlnet_engine(self, max_batch_size=max_batch_size, min_batch_size=min_batch_size, mode="", # Not used for ControlNet - 
use_lcm_lora=False, # Not used for ControlNet use_tiny_vae=False, # Not used for ControlNet controlnet_model_id=model_id ) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 979a21cd..5c26a949 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -100,7 +100,6 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'lora_dict': config.get('lora_dict'), 'mode': config.get('mode', 'img2img'), 'output_type': config.get('output_type', 'pil'), - 'lcm_lora_id': config.get('lcm_lora_id'), 'vae_id': config.get('vae_id'), 'device': config.get('device', 'cuda'), 'dtype': _parse_dtype(config.get('dtype', 'float16')), @@ -111,7 +110,6 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'acceleration': config.get('acceleration', 'tensorrt'), 'do_add_noise': config.get('do_add_noise', True), 'device_ids': config.get('device_ids'), - 'use_lcm_lora': config.get('use_lcm_lora', True), 'use_tiny_vae': config.get('use_tiny_vae', True), 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index bdde5998..26b12f71 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -231,21 +231,6 @@ def _build_sdxl_conditioning(self, batch_size: int) -> Dict[str, torch.Tensor]: - def load_lcm_lora( - self, - pretrained_model_name_or_path_or_dict: Union[ - str, Dict[str, torch.Tensor] - ] = "latent-consistency/lcm-lora-sdv1-5", - adapter_name: Optional[Any] = None, - **kwargs, - ) -> None: - # Check for SDXL compatibility - if self.is_sdxl: - return - - self._load_lora_with_offline_fallback( - pretrained_model_name_or_path_or_dict, adapter_name, **kwargs - ) def load_lora( self, diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index ae02c90e..d7c553dd 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -75,7 +75,6 @@ def __init__( lora_dict: Optional[Dict[str, float]] = None, mode: Literal["img2img", "txt2img"] = "img2img", output_type: Literal["pil", "pt", "np", "latent"] = "pil", - lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, device: Literal["cpu", "cuda"] = "cuda", dtype: torch.dtype = torch.float16, @@ -86,7 +85,7 @@ def __init__( acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, device_ids: Optional[List[int]] = None, - use_lcm_lora: bool = True, + use_lcm_lora: Optional[bool] = None, # Backwards compatibility parameter use_tiny_vae: bool = True, enable_similar_image_filter: bool = False, similar_image_filter_threshold: float = 0.98, @@ -135,10 +134,6 @@ def __init__( txt2img or img2img, by default "img2img". output_type : Literal["pil", "pt", "np", "latent"], optional The output type of image, by default "pil". - lcm_lora_id : Optional[str], optional - The lcm_lora_id to load, by default None. - If None, the default LCM-LoRA - ("latent-consistency/lcm-lora-sdv1-5") will be used. vae_id : Optional[str], optional The vae_id to load, by default None. If None, the default TinyVAE @@ -162,8 +157,6 @@ def __init__( by default True. device_ids : Optional[List[int]], optional The device ids to use for DataParallel, by default None. - use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. 
use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. enable_similar_image_filter : bool, optional @@ -208,6 +201,35 @@ def __init__( """ if compile_engines_only: logger.info("compile_engines_only is True, will only compile engines and not load the model") + + # Handle backwards compatibility for use_lcm_lora parameter + if use_lcm_lora is not None: + logger.warning("use_lcm_lora parameter is deprecated. Use lora_dict instead.") + logger.warning("Automatically converting use_lcm_lora to lora_dict for backwards compatibility.") + + if use_lcm_lora and not self.sd_turbo: + # Initialize lora_dict if it doesn't exist + if lora_dict is None: + lora_dict = {} + else: + # Make a copy to avoid modifying the original + lora_dict = lora_dict.copy() + + # Determine which LCM LoRA to use based on model path + model_path_lower = model_id_or_path.lower() + if any(indicator in model_path_lower for indicator in ['sdxl', 'xl', '1024']): + lcm_lora_id = "latent-consistency/lcm-lora-sdxl" + logger.info(f"Detected SDXL model, adding LCM LoRA: {lcm_lora_id}") + else: + lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" + logger.info(f"Detected SD1.5 model, adding LCM LoRA: {lcm_lora_id}") + + # Add LCM LoRA to lora_dict if not already present + if lcm_lora_id not in lora_dict: + lora_dict[lcm_lora_id] = 1.0 + logger.info(f"Added {lcm_lora_id} with scale 1.0 to lora_dict") + else: + logger.info(f"LCM LoRA {lcm_lora_id} already present in lora_dict with scale {lora_dict[lcm_lora_id]}") self.sd_turbo = "turbo" in model_id_or_path self.use_controlnet = use_controlnet @@ -258,12 +280,10 @@ def __init__( self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, lora_dict=lora_dict, - lcm_lora_id=lcm_lora_id, vae_id=vae_id, t_index_list=t_index_list, acceleration=acceleration, do_add_noise=do_add_noise, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, cfg_type=cfg_type, engine_dir=engine_dir, @@ -884,11 +904,9 @@ def _load_model( model_id_or_path: str, t_index_list: List[int], lora_dict: Optional[Dict[str, float]] = None, - lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, - use_lcm_lora: bool = True, use_tiny_vae: bool = True, cfg_type: Literal["none", "full", "self", "initialize"] = "self", engine_dir: Optional[Union[str, Path]] = "engines", @@ -915,7 +933,7 @@ def _load_model( This method does the following: 1. Loads the model from the model_id_or_path. - 2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed. + 2. Loads and fuses LoRA models from lora_dict if provided. 3. Loads the VAE model from the vae_id if needed. 4. Enables acceleration if needed. 5. Prepares the model for inference. @@ -932,8 +950,7 @@ def _load_model( The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} - lcm_lora_id : Optional[str], optional - The lcm_lora_id to load, by default None. + Use this to load LCM LoRA: {'latent-consistency/lcm-lora-sdv1-5': 1.0} vae_id : Optional[str], optional The vae_id to load, by default None. acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional @@ -943,8 +960,6 @@ def _load_model( do_add_noise : bool, optional Whether to add noise for following denoising steps or not, by default True. - use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. 
use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. cfg_type : Literal["none", "full", "self", "initialize"], @@ -1095,20 +1110,7 @@ def _load_model( lora_adapters_to_merge = [] lora_scales_to_merge = [] - # Collect all LoRA adapters and their scales - if use_lcm_lora: - if lcm_lora_id is not None: - logger.info(f"_load_model: Loading LCM LoRA from {lcm_lora_id}") - stream.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm_lora") - else: - logger.info("_load_model: Loading default LCM LoRA") - # Use appropriate default LCM LoRA based on model type - default_lcm_lora = "latent-consistency/lcm-lora-sdxl" if is_sdxl else "latent-consistency/lcm-lora-sdv1-5" - stream.pipe.load_lora_weights(default_lcm_lora, adapter_name="lcm_lora") - - lora_adapters_to_merge.append("lcm_lora") - lora_scales_to_merge.append(1.0) - + # Collect all LoRA adapters and their scales from lora_dict if lora_dict is not None: for i, (lora_name, lora_scale) in enumerate(lora_dict.items()): adapter_name = f"custom_lora_{i}" @@ -1298,7 +1300,6 @@ def _load_model( max_batch_size=self.max_batch_size, min_batch_size=self.min_batch_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, @@ -1311,7 +1312,6 @@ def _load_model( max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, @@ -1324,7 +1324,6 @@ def _load_model( max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, lora_dict=lora_dict, ipadapter_scale=ipadapter_scale, @@ -1650,7 +1649,6 @@ def _load_model( max_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, min_batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size, mode=self.mode, - use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, ) safety_checker_engine_exists = os.path.exists(safety_checker_path) From 53f7d92681097d36ab0c172b4f742375104f25ac Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 20:07:19 -0400 Subject: [PATCH 09/15] Added backwards compatibility for use_lcm_lora. 
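
For reference, a short sketch of the conversion this compatibility path performs inside
_load_model (it mirrors the logic added below; variable names are shortened):

    # Deprecated use_lcm_lora=True is translated into a lora_dict entry
    if use_lcm_lora and not sd_turbo and lora_dict is not None:
        lcm_lora = ("latent-consistency/lcm-lora-sdxl" if is_sdxl
                    else "latent-consistency/lcm-lora-sdv1-5")
        lora_dict.setdefault(lcm_lora, 1.0)  # keep an explicit user-provided scale if present
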
--- src/streamdiffusion/config.py | 1 + src/streamdiffusion/wrapper.py | 61 ++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 5c26a949..ac8b6f20 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -110,6 +110,7 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'acceleration': config.get('acceleration', 'tensorrt'), 'do_add_noise': config.get('do_add_noise', True), 'device_ids': config.get('device_ids'), + 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility 'use_tiny_vae': config.get('use_tiny_vae', True), 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index d7c553dd..3d403097 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -157,6 +157,11 @@ def __init__( by default True. device_ids : Optional[List[int]], optional The device ids to use for DataParallel, by default None. + use_lcm_lora : Optional[bool], optional + DEPRECATED: Use lora_dict instead. For backwards compatibility only. + If True, automatically adds appropriate LCM LoRA to lora_dict based on model type. + SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5". + By default None (ignored). use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. enable_similar_image_filter : bool, optional @@ -202,35 +207,9 @@ def __init__( if compile_engines_only: logger.info("compile_engines_only is True, will only compile engines and not load the model") - # Handle backwards compatibility for use_lcm_lora parameter - if use_lcm_lora is not None: - logger.warning("use_lcm_lora parameter is deprecated. 
Use lora_dict instead.") - logger.warning("Automatically converting use_lcm_lora to lora_dict for backwards compatibility.") - - if use_lcm_lora and not self.sd_turbo: - # Initialize lora_dict if it doesn't exist - if lora_dict is None: - lora_dict = {} - else: - # Make a copy to avoid modifying the original - lora_dict = lora_dict.copy() - - # Determine which LCM LoRA to use based on model path - model_path_lower = model_id_or_path.lower() - if any(indicator in model_path_lower for indicator in ['sdxl', 'xl', '1024']): - lcm_lora_id = "latent-consistency/lcm-lora-sdxl" - logger.info(f"Detected SDXL model, adding LCM LoRA: {lcm_lora_id}") - else: - lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" - logger.info(f"Detected SD1.5 model, adding LCM LoRA: {lcm_lora_id}") - - # Add LCM LoRA to lora_dict if not already present - if lcm_lora_id not in lora_dict: - lora_dict[lcm_lora_id] = 1.0 - logger.info(f"Added {lcm_lora_id} with scale 1.0 to lora_dict") - else: - logger.info(f"LCM LoRA {lcm_lora_id} already present in lora_dict with scale {lora_dict[lcm_lora_id]}") - + # Store use_lcm_lora for backwards compatibility processing in _load_model + self.use_lcm_lora = use_lcm_lora + self.sd_turbo = "turbo" in model_id_or_path self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter @@ -284,6 +263,7 @@ def __init__( t_index_list=t_index_list, acceleration=acceleration, do_add_noise=do_add_noise, + use_lcm_lora=use_lcm_lora, # Deprecated:Backwards compatibility use_tiny_vae=use_tiny_vae, cfg_type=cfg_type, engine_dir=engine_dir, @@ -907,6 +887,7 @@ def _load_model( vae_id: Optional[str] = None, acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, + use_lcm_lora: bool = True, use_tiny_vae: bool = True, cfg_type: Literal["none", "full", "self", "initialize"] = "self", engine_dir: Optional[Union[str, Path]] = "engines", @@ -960,6 +941,8 @@ def _load_model( do_add_noise : bool, optional Whether to add noise for following denoising steps or not, by default True. + use_lcm_lora : bool, optional + Whether to use LCM-LoRA or not, by default True. # DEPRECATED: Backwards compatibility use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. 
cfg_type : Literal["none", "full", "self", "initialize"], @@ -1087,6 +1070,26 @@ def _load_model( self._is_sdxl = is_sdxl logger.info(f"_load_model: Detected model type: {model_type} (confidence: {confidence:.2f})") + + # DEPRECATED: THIS WILL LOAD LCM_LORA IF USE_LCM_LORA IS TRUE + # Validate backwards compatibility LCM LoRA selection using proper model detection + if hasattr(self, 'use_lcm_lora') and self.use_lcm_lora is not None: + if self.use_lcm_lora and not self.sd_turbo and lora_dict is not None: + # Determine correct LCM LoRA based on actual model detection + lcm_lora = "latent-consistency/lcm-lora-sdxl" if is_sdxl else "latent-consistency/lcm-lora-sdv1-5" + + # Add to lora_dict if not already present + if lcm_lora not in lora_dict: + lora_dict[lcm_lora] = 1.0 + logger.info(f"Added {lcm_lora} with scale 1.0 to lora_dict") + else: + logger.info(f"LCM LoRA {lcm_lora} already present in lora_dict with scale {lora_dict[lcm_lora]}") + else: + logger.info(f"LCM LoRA will not be loaded because use_lcm_lora is {self.use_lcm_lora} and sd_turbo is {self.sd_turbo}") + + # Remove use_lcm_lora from self + self.use_lcm_lora = None + logger.info(f"use_lcm_lora has been removed from self") stream = StreamDiffusion( pipe=pipe, From 55a20c924e8f5f1776d84b112e119d0f331d5da7 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 20:14:39 -0400 Subject: [PATCH 10/15] Reverted single/multi scripts for simplicity. --- examples/benchmark/multi.py | 6 +++--- examples/benchmark/single.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/benchmark/multi.py b/examples/benchmark/multi.py index f3f879f2..bfa971cf 100644 --- a/examples/benchmark/multi.py +++ b/examples/benchmark/multi.py @@ -40,7 +40,7 @@ def run( lora_dict: Optional[Dict[str, float]] = None, prompt: str = "1girl with brown dog hair, thick glasses, smiling", negative_prompt: str = "bad image , bad quality", - lcm_lora: bool = True, + use_lcm_lora: bool = True, use_tiny_vae: bool = True, width: int = 512, height: int = 512, @@ -67,7 +67,7 @@ def run( The prompt to use, by default "1girl with brown dog hair, thick glasses, smiling". negative_prompt : str, optional The negative prompt to use, by default "bad image , bad quality". - lcm_lora : bool, optional + use_lcm_lora : bool, optional Whether to use LCM-LoRA or not, by default True. use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. @@ -97,7 +97,7 @@ def run( warmup=warmup, acceleration=acceleration, device_ids=device_ids, - lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0} if lcm_lora else lora_dict, + use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, enable_similar_image_filter=False, similar_image_filter_threshold=0.98, diff --git a/examples/benchmark/single.py b/examples/benchmark/single.py index eb868bdd..5e55fb63 100644 --- a/examples/benchmark/single.py +++ b/examples/benchmark/single.py @@ -27,7 +27,7 @@ def run( lora_dict: Optional[Dict[str, float]] = None, prompt: str = "1girl with brown dog hair, thick glasses, smiling", negative_prompt: str = "bad image , bad quality", - lcm_lora: bool = True, + use_lcm_lora: bool = True, use_tiny_vae: bool = True, width: int = 512, height: int = 512, @@ -54,7 +54,7 @@ def run( The prompt to use, by default "1girl with brown dog hair, thick glasses, smiling". negative_prompt : str, optional The negative prompt to use, by default "bad image , bad quality". - lcm_lora : bool, optional + use_lcm_lora : bool, optional Whether to use LCM-LoRA or not, by default True. 
use_tiny_vae : bool, optional Whether to use TinyVAE or not, by default True. @@ -84,7 +84,7 @@ def run( warmup=warmup, acceleration=acceleration, device_ids=device_ids, - lora_dict={"latent-consistency/lcm-lora-sdv1-5": 1.0} if lcm_lora else lora_dict, + use_lcm_lora=use_lcm_lora, use_tiny_vae=use_tiny_vae, enable_similar_image_filter=False, similar_image_filter_threshold=0.98, From 123ba695e24e60c00cf244d9b051800c778e3660 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 20:35:32 -0400 Subject: [PATCH 11/15] Updated descriptive comments, added tcd support, small cleanup/fixes. --- configs/sd15_multicontrol.yaml.example | 4 + configs/sdturbo_multicontrol.yaml.example | 3 + configs/sdxl_multicontrol.yaml.example | 4 + src/streamdiffusion/pipeline.py | 2 + src/streamdiffusion/wrapper.py | 165 +++++++++++++++++----- 5 files changed, 144 insertions(+), 34 deletions(-) diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example index 6aa4b93d..a5e865e1 100644 --- a/configs/sd15_multicontrol.yaml.example +++ b/configs/sd15_multicontrol.yaml.example @@ -42,12 +42,16 @@ use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + # Engine directory for TensorRT (engines will be built here if not found) engine_dir: "./engines/sd15" # Enable multi-modal conditioning use_controlnet: true use_ipadapter: true +use_ipadapter: false # IPAdapter configuration for style conditioning ipadapters: diff --git a/configs/sdturbo_multicontrol.yaml.example b/configs/sdturbo_multicontrol.yaml.example index 54a7b8a6..a0fe0ce0 100644 --- a/configs/sdturbo_multicontrol.yaml.example +++ b/configs/sdturbo_multicontrol.yaml.example @@ -32,6 +32,9 @@ use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + # Engine directory for TensorRT engine_dir: "./engines/sdturbo" diff --git a/configs/sdxl_multicontrol.yaml.example b/configs/sdxl_multicontrol.yaml.example index 61c482de..0a39c415 100644 --- a/configs/sdxl_multicontrol.yaml.example +++ b/configs/sdxl_multicontrol.yaml.example @@ -41,6 +41,10 @@ use_taesd: true # Use Tiny AutoEncoder for SDXL use_tiny_vae: true acceleration: "tensorrt" # "xformers" for non-TensorRT setups cfg_type: "self" + +scheduler: "lcm" # Supports "lcm" or "tcd" +sampler: "normal" + safety_checker: false # Engine directory for TensorRT diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 26b12f71..ed38f9be 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -145,6 +145,8 @@ def __init__( def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): """Initialize scheduler based on type and sampler configuration.""" + + # TODO: More testing and validation required on samplers. 
# Map sampler types to configuration parameters sampler_config = { "simple": {"timestep_spacing": "linspace"}, diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 3d403097..1c915825 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -85,7 +85,7 @@ def __init__( acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", do_add_noise: bool = True, device_ids: Optional[List[int]] = None, - use_lcm_lora: Optional[bool] = None, # Backwards compatibility parameter + use_lcm_lora: Optional[bool] = None, # DEPRECATED: Backwards compatibility parameter use_tiny_vae: bool = True, enable_similar_image_filter: bool = False, similar_image_filter_threshold: float = 0.98, @@ -101,7 +101,7 @@ def __init__( normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, # Scheduler and sampler options - scheduler: Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + scheduler: Literal["lcm", "tcd"] = "lcm", sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", # ControlNet options use_controlnet: bool = False, @@ -126,6 +126,10 @@ def __init__( The model id or path to load. t_index_list : List[int] The t_index_list to use for inference. + min_batch_size : int, optional + The minimum batch size for inference, by default 1. + max_batch_size : int, optional + The maximum batch size for inference, by default 4. lora_dict : Optional[Dict[str, float]], optional The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. @@ -140,6 +144,8 @@ def __init__( ("madebyollin/taesd") will be used. device : Literal["cpu", "cuda"], optional The device to use for inference, by default "cuda". + device_ids : Optional[List[int]], optional + The device ids to use for DataParallel, by default None. dtype : torch.dtype, optional The dtype for inference, by default torch.float16. frame_buffer_size : int, optional @@ -181,13 +187,19 @@ def __init__( The seed, by default 2. use_safety_checker : bool, optional Whether to use safety checker or not, by default False. + skip_diffusion : bool, optional + Whether to skip diffusion and apply only preprocessing/postprocessing hooks, by default False. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. normalize_prompt_weights : bool, optional Whether to normalize prompt weights in blending to sum to 1, by default True. When False, weights > 1 will amplify embeddings. normalize_seed_weights : bool, optional Whether to normalize seed weights in blending to sum to 1, by default True. When False, weights > 1 will amplify noise. - scheduler : Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"], optional + scheduler : Literal["lcm", "tcd"], optional The scheduler type to use for denoising, by default "lcm". sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional The sampler type to use for noise scheduling, by default "normal". @@ -197,6 +209,19 @@ def __init__( ControlNet configuration(s), by default None. Can be a single config dict or list of config dicts for multiple ControlNets. Each config should contain: model_id, preprocessor (optional), conditioning_scale, etc. + use_ipadapter : bool, optional + Whether to enable IPAdapter support, by default False. 
+ ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. safety_checker_fallback_type : Literal["blank", "previous"], optional Whether to use a blank image or the previous image as a fallback, by default "previous". safety_checker_threshold: float, optional @@ -809,33 +834,57 @@ def postprocess_image( def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - Denormalize image tensor on GPU for efficiency + Denormalize image tensor on GPU for efficiency. + Converts image tensor from diffusion range [-1, 1] to standard image range [0, 1]. - Args: - image_tensor: Input tensor on GPU - + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. - Returns: - Denormalized tensor on GPU, clamped to [0,1] + Returns + ------- + torch.Tensor + Denormalized tensor in range [0, 1], clamped and on GPU. """ return (image_tensor / 2 + 0.5).clamp(0, 1) def _normalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: - """Convert tensor from [0,1] (processor range) back to [-1,1] (diffusion range)""" + """ + Normalize tensor from processor range to diffusion range. + + Converts image tensor from standard image range [0, 1] to diffusion range [-1, 1]. + + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in standard image range [0, 1], expected to be on GPU. + + Returns + ------- + torch.Tensor + Normalized tensor in diffusion range [-1, 1], clamped and on GPU. + """ return (image_tensor * 2 - 1).clamp(-1, 1) def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Image]: """ - Optimized tensor to PIL conversion with minimal CPU transfers + Optimized tensor to PIL conversion with minimal CPU transfers. + Efficiently converts a batch of GPU tensors to PIL Images with minimal + CPU-GPU transfers and memory allocations. - Args: - image_tensor: Input tensor on GPU - + Parameters + ---------- + image_tensor : torch.Tensor + Input tensor in diffusion range [-1, 1], expected to be on GPU. + Shape should be (batch_size, channels, height, width). - Returns: - List of PIL Images + Returns + ------- + List[Image.Image] + List of PIL RGB images, one for each item in the batch. """ # Denormalize on GPU first denormalized = self._denormalize_on_gpu(image_tensor) @@ -873,6 +922,23 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima return pil_images def set_nsfw_fallback_img(self, height: int, width: int) -> None: + """ + Set the NSFW fallback image used when safety checker blocks content. + + Creates a black RGB image of the specified dimensions that will be returned + when the safety checker determines content should be blocked. + + Parameters + ---------- + height : int + Height of the fallback image in pixels. + width : int + Width of the fallback image in pixels. 
+ + Returns + ------- + None + """ self.nsfw_fallback_img = Image.new("RGB", (height, width), (0, 0, 0)) if self.output_type == "pt": self.nsfw_fallback_img = torch.from_numpy(np.array(self.nsfw_fallback_img)).unsqueeze(0) @@ -894,7 +960,7 @@ def _load_model( build_engines_if_missing: bool = True, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True, - scheduler: Literal["lcm", "dpm++ 2m", "uni_pc", "ddim", "euler"] = "lcm", + scheduler: Literal["lcm", "tcd"] = "lcm", sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", use_controlnet: bool = False, controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, @@ -924,41 +990,72 @@ def _load_model( Parameters ---------- model_id_or_path : str - The model id or path to load. + The model id or path to load. Can be a Hugging Face model ID, local path to + safetensors/ckpt file, or directory containing model files. t_index_list : List[int] - The t_index_list to use for inference. + The t_index_list to use for inference. Specifies which denoising timesteps + to use from the diffusion schedule. lora_dict : Optional[Dict[str, float]], optional The lora_dict to load, by default None. Keys are the LoRA names and values are the LoRA scales. Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} Use this to load LCM LoRA: {'latent-consistency/lcm-lora-sdv1-5': 1.0} vae_id : Optional[str], optional - The vae_id to load, by default None. - acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional - The acceleration method, by default "tensorrt". - warmup : int, optional - The number of warmup steps to perform, by default 10. + The vae_id to load, by default None. If None, uses default TinyVAE + ("madebyollin/taesd" for SD1.5, "madebyollin/taesdxl" for SDXL). + acceleration : Literal["none", "xformers", "tensorrt"], optional + The acceleration method, by default "tensorrt". Note: docstring shows + "xfomers" and "sfast" but code uses "xformers". do_add_noise : bool, optional Whether to add noise for following denoising steps or not, by default True. use_lcm_lora : bool, optional - Whether to use LCM-LoRA or not, by default True. # DEPRECATED: Backwards compatibility + DEPRECATED: Use lora_dict instead. For backwards compatibility only. + If True, automatically adds appropriate LCM LoRA to lora_dict based on model type. + SDXL models get "latent-consistency/lcm-lora-sdxl", others get "latent-consistency/lcm-lora-sdv1-5". + By default None (ignored). use_tiny_vae : bool, optional - Whether to use TinyVAE or not, by default True. - cfg_type : Literal["none", "full", "self", "initialize"], - optional + Whether to use TinyVAE or not, by default True. TinyVAE is a distilled, + smaller VAE model that provides faster encoding/decoding with minimal quality loss. + cfg_type : Literal["none", "full", "self", "initialize"], optional The cfg_type for img2img mode, by default "self". You cannot use anything other than "none" for txt2img mode. - seed : int, optional - The seed, by default 2. + engine_dir : Optional[Union[str, Path]], optional + Directory path for storing/loading TensorRT engines, by default "engines". + build_engines_if_missing : bool, optional + Whether to build TensorRT engines if they don't exist, by default True. + normalize_prompt_weights : bool, optional + Whether to normalize prompt weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify embeddings. 
+ normalize_seed_weights : bool, optional + Whether to normalize seed weights in blending to sum to 1, by default True. + When False, weights > 1 will amplify noise. + scheduler : Literal["lcm", "tcd"], optional + The scheduler type to use for denoising, by default "lcm". + sampler : Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"], optional + The sampler type to use for noise scheduling, by default "normal". use_controlnet : bool, optional - Whether to apply ControlNet patch, by default False. + Whether to enable ControlNet support, by default False. controlnet_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - ControlNet configuration(s), by default None. + ControlNet configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple ControlNets. use_ipadapter : bool, optional - Whether to apply IPAdapter patch, by default False. + Whether to enable IPAdapter support, by default False. ipadapter_config : Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional - IPAdapter configuration(s), by default None. + IPAdapter configuration(s), by default None. Can be a single config dict + or list of config dicts for multiple IPAdapters. + image_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image preprocessing hooks, by default None. + image_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for image postprocessing hooks, by default None. + latent_preprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent preprocessing hooks, by default None. + latent_postprocessing_config : Optional[Dict[str, Any]], optional + Configuration for latent postprocessing hooks, by default None. + safety_checker_model_id : Optional[str], optional + Model ID for the safety checker, by default "Freepik/nsfw_image_detector". + compile_engines_only : bool, optional + Whether to only compile engines and not load the model, by default False. Returns ------- From 312811c5aead0ccddadb796f7d3091f334096183 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 20:42:41 -0400 Subject: [PATCH 12/15] Oops. --- configs/sd15_multicontrol.yaml.example | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example index a5e865e1..65f36748 100644 --- a/configs/sd15_multicontrol.yaml.example +++ b/configs/sd15_multicontrol.yaml.example @@ -51,7 +51,6 @@ engine_dir: "./engines/sd15" # Enable multi-modal conditioning use_controlnet: true use_ipadapter: true -use_ipadapter: false # IPAdapter configuration for style conditioning ipadapters: From 54f054642a76942010c467520d58d5e9757fc2cf Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Mon, 15 Sep 2025 21:20:03 -0400 Subject: [PATCH 13/15] Fix for potential xformers issue. 
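
Sketch of the intended device placement, assuming the standard diffusers
enable_xformers_memory_efficient_attention() API: with xformers (or no acceleration) the UNet
and VAE must already sit on the CUDA device, while TensorRT replaces them with engines and
should not pay for the extra copy.

    if acceleration != "tensorrt":
        pipe.unet = pipe.unet.to(device)
        pipe.vae = pipe.vae.to(device)
    if acceleration == "xformers":
        pipe.enable_xformers_memory_efficient_attention()
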
--- src/streamdiffusion/wrapper.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 1c915825..973a5854 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -1149,6 +1149,11 @@ def _load_model( pipe.text_encoder = pipe.text_encoder.to(device=self.device) if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None: pipe.text_encoder_2 = pipe.text_encoder_2.to(device=self.device) + # Move main pipeline components to device, but skip UNet for TensorRT + if hasattr(pipe, "unet") and pipe.unet is not None and acceleration != "tensorrt": + pipe.unet = pipe.unet.to(device=self.device) + if hasattr(pipe, "vae") and pipe.vae is not None and acceleration != "tensorrt": + pipe.vae = pipe.vae.to(device=self.device) # If we get here, the model loaded successfully - break out of retry loop logger.info(f"Model loading succeeded") @@ -1267,11 +1272,15 @@ def _load_model( if use_tiny_vae: if vae_id is not None: - stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype) + stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(dtype=pipe.dtype, device=self.device) else: # Use TAESD XL for SDXL models, regular TAESD for SD 1.5 taesd_model = "madebyollin/taesdxl" if is_sdxl else "madebyollin/taesd" - stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype) + stream.vae = AutoencoderTiny.from_pretrained(taesd_model).to(dtype=pipe.dtype, device=self.device) + elif acceleration != "tensorrt": + # For non-TensorRT acceleration, ensure VAE is on device if it wasn't moved earlier + if hasattr(pipe, "vae") and pipe.vae is not None: + pipe.vae = pipe.vae.to(device=self.device) try: if acceleration == "xformers": @@ -1920,7 +1929,6 @@ def _load_model( return stream - def get_last_processed_image(self, index: int) -> Optional[Image.Image]: """Forward get_last_processed_image call to the underlying ControlNet pipeline""" if not self.use_controlnet: @@ -1945,14 +1953,12 @@ def update_control_image(self, index: int, image: Union[str, Image.Image, torch. else: logger.debug("update_control_image: Skipping ControlNet update in skip diffusion mode") - def update_style_image(self, image: Union[str, Image.Image, torch.Tensor]) -> None: """Update IPAdapter style image""" if not self.use_ipadapter: raise RuntimeError("update_style_image: IPAdapter support not enabled. Set use_ipadapter=True in constructor.") self.stream.update_style_image(image) - def clear_caches(self) -> None: """Clear all cached prompt embeddings and seed noise tensors.""" self.stream._param_updater.clear_caches() From 7e210eadd91143bcc14891909c802707eda99fbd Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Tue, 16 Sep 2025 16:31:58 -0400 Subject: [PATCH 14/15] Fix to TCD update params. 
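
Background for the fix, as a short sketch: LCM applies the consistency-model boundary
scalings outside the scheduler, whereas TCD does the equivalent work inside scheduler.step(),
so the updater only needs neutral placeholders there. The LCM path that the c_skip / c_out
lists feed into looks roughly like this (F_theta is the model's x0 estimate):

    # LCM: boundary-condition scalings applied by the pipeline itself
    c_skip, c_out = scheduler.get_scalings_for_boundary_condition_discrete(t)
    x0_pred = c_out * F_theta + c_skip * x_t

    # TCD and other schedulers: scaling happens inside step(), so return neutral tensors
    c_skip = torch.tensor(1.0, device=device, dtype=dtype)
    c_out = torch.tensor(1.0, device=device, dtype=dtype)
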
--- src/streamdiffusion/stream_parameter_updater.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 81acc89f..0a0bade3 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -674,6 +674,20 @@ def _update_seed(self, seed: int) -> None: # Reset stock_noise to match the new init_noise self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) + def _get_scheduler_scalings(self, timestep): + """Get LCM/TCD-specific scaling factors for boundary conditions.""" + from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): + c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + return c_skip, c_out + else: + # TCD and other schedulers don't use boundary condition scaling like LCM + # They handle scaling internally in their step() method + # Return tensors that are compatible with torch.stack() + c_skip = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + c_out = torch.tensor(1.0, device=self.stream.device, dtype=self.stream.dtype) + return c_skip, c_out + def _update_timestep_calculations(self) -> None: """Update timestep-dependent calculations based on current t_list.""" self.stream.sub_timesteps = [] @@ -692,7 +706,7 @@ def _update_timestep_calculations(self) -> None: c_skip_list = [] c_out_list = [] for timestep in self.stream.sub_timesteps: - c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + c_skip, c_out = self._get_scheduler_scalings(timestep) c_skip_list.append(c_skip) c_out_list.append(c_out) From a0779f4a5ce18ada8373364c396f00ddf7c9eb29 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 17 Sep 2025 13:45:05 -0400 Subject: [PATCH 15/15] Removal of old fuse method. 
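
diffusers' fuse_lora() takes a single float lora_scale, so the old bulk
path that passed a list of scales (with a per-adapter loop kept only as a
fallback) could not apply per-LoRA scales reliably. Keep just the
per-adapter path: load each entry of lora_dict under its own adapter name,
fuse it at its own scale, then unload the adapter weights once they are
merged into the base model.

A condensed sketch of the retained flow (illustrative; the helper name and
the `pipe`/`lora_dict` arguments are stand-ins for the wrapper's objects):

    from diffusers import StableDiffusionPipeline

    def fuse_loras(pipe: StableDiffusionPipeline, lora_dict: dict) -> None:
        adapters, scales = [], []
        # Load every LoRA under a unique adapter name.
        for i, (lora_path, lora_scale) in enumerate(lora_dict.items()):
            name = f"custom_lora_{i}"
            pipe.load_lora_weights(lora_path, adapter_name=name)
            adapters.append(name)
            scales.append(lora_scale)
        # Fuse each adapter at its own scale, then drop the adapter
        # weights now that they are merged into the base model.
        for name, scale in zip(adapters, scales):
            pipe.fuse_lora(lora_scale=scale, adapter_names=[name])
        pipe.unload_lora_weights()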
--- src/streamdiffusion/wrapper.py | 96 ++++++++++++++-------------------- 1 file changed, 40 insertions(+), 56 deletions(-) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 973a5854..2eb83a9f 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -1210,65 +1210,49 @@ def _load_model( scheduler=scheduler, sampler=sampler, ) + + # Load and properly merge LoRA weights using the standard diffusers approach - if not self.sd_turbo: - lora_adapters_to_merge = [] - lora_scales_to_merge = [] - - # Collect all LoRA adapters and their scales from lora_dict - if lora_dict is not None: - for i, (lora_name, lora_scale) in enumerate(lora_dict.items()): - adapter_name = f"custom_lora_{i}" - logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}") - - try: - # Load LoRA weights with unique adapter name - stream.pipe.load_lora_weights(lora_name, adapter_name=adapter_name) - lora_adapters_to_merge.append(adapter_name) - lora_scales_to_merge.append(lora_scale) - logger.info(f"Successfully loaded LoRA adapter: {adapter_name}") - except Exception as e: - logger.error(f"Failed to load LoRA {lora_name}: {e}") - # Continue with other LoRAs even if one fails - continue - - # Merge all LoRA adapters using the proper diffusers method - if lora_adapters_to_merge: + lora_adapters_to_merge = [] + lora_scales_to_merge = [] + + # Collect all LoRA adapters and their scales from lora_dict + if lora_dict is not None: + for i, (lora_name, lora_scale) in enumerate(lora_dict.items()): + adapter_name = f"custom_lora_{i}" + logger.info(f"_load_model: Loading LoRA '{lora_name}' with scale {lora_scale}") + try: - logger.info(f"Merging {len(lora_adapters_to_merge)} LoRA adapter(s) with scales: {lora_scales_to_merge}") - - # Use the proper merge_and_unload method from diffusers - # This permanently merges LoRA weights into the base model parameters - stream.pipe.fuse_lora(lora_scale=lora_scales_to_merge, adapter_names=lora_adapters_to_merge) - - # After fusing, unload the LoRA weights to clean up memory and avoid conflicts - stream.pipe.unload_lora_weights() - - logger.info("Successfully merged and unloaded LoRA weights using diffusers merge_and_unload") - + # Load LoRA weights with unique adapter name + stream.pipe.load_lora_weights(lora_name, adapter_name=adapter_name) + lora_adapters_to_merge.append(adapter_name) + lora_scales_to_merge.append(lora_scale) + logger.info(f"Successfully loaded LoRA adapter: {adapter_name}") except Exception as e: - logger.error(f"Failed to merge LoRA weights: {e}") - logger.info("Attempting fallback: individual LoRA merging...") - - # Fallback: merge LoRAs individually - try: - for adapter_name, scale in zip(lora_adapters_to_merge, lora_scales_to_merge): - logger.info(f"Merging individual LoRA: {adapter_name} with scale {scale}") - stream.pipe.fuse_lora(lora_scale=scale, adapter_names=[adapter_name]) - - # Clean up after individual merging - stream.pipe.unload_lora_weights() - logger.info("Successfully merged LoRAs individually") - - except Exception as fallback_error: - logger.error(f"LoRA merging fallback also failed: {fallback_error}") - logger.warning("Continuing without LoRA merging - LoRAs may not be applied correctly") - - # Clean up any partial state - try: - stream.pipe.unload_lora_weights() - except: - pass + logger.error(f"Failed to load LoRA {lora_name}: {e}") + # Continue with other LoRAs even if one fails + continue + + # Merge all LoRA adapters using the proper diffusers method + if 
lora_adapters_to_merge:
+ try:
+ for adapter_name, scale in zip(lora_adapters_to_merge, lora_scales_to_merge):
+ logger.info(f"Merging individual LoRA: {adapter_name} with scale {scale}")
+ stream.pipe.fuse_lora(lora_scale=scale, adapter_names=[adapter_name])
+ 
+ # Clean up after individual merging
+ stream.pipe.unload_lora_weights()
+ logger.info("Successfully merged LoRAs individually")
+ 
+ except Exception as e:
+ logger.error(f"LoRA merging failed: {e}")
+ logger.warning("Continuing without LoRA merging - LoRAs may not be applied correctly")
+ 
+ # Clean up any partial state
+ try:
+ stream.pipe.unload_lora_weights()
+ except Exception:
+ pass
 if use_tiny_vae:
 if vae_id is not None: