From e0bf41123e20350e75871057027cb6d48b4e3f3b Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sat, 13 Sep 2025 14:23:01 -0400 Subject: [PATCH 1/3] lora module --- src/streamdiffusion/config.py | 44 ++ src/streamdiffusion/modules/README_LoRA.md | 291 ++++++++++++ src/streamdiffusion/modules/__init__.py | 2 + src/streamdiffusion/modules/lora_module.py | 438 ++++++++++++++++++ .../stream_parameter_updater.py | 129 ++++++ src/streamdiffusion/wrapper.py | 66 ++- 6 files changed, 945 insertions(+), 25 deletions(-) create mode 100644 src/streamdiffusion/modules/README_LoRA.md create mode 100644 src/streamdiffusion/modules/lora_module.py diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 3ff5cb90..ce9d5d16 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -141,6 +141,14 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: param_map['use_ipadapter'] = config.get('use_ipadapter', False) param_map['ipadapter_config'] = config.get('ipadapter_config') + # Set LoRA usage if LoRAs are configured + if 'loras' in config and config['loras']: + param_map['use_lora'] = True + param_map['lora_config'] = _prepare_lora_configs(config) + else: + param_map['use_lora'] = config.get('use_lora', False) + param_map['lora_config'] = config.get('lora_config') + # Pipeline hook configurations (Phase 4: Configuration Integration) hook_configs = _prepare_pipeline_hook_configs(config) param_map.update(hook_configs) @@ -219,6 +227,24 @@ def _prepare_ipadapter_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: return ipadapter_configs +def _prepare_lora_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: + """Prepare LoRA configurations for wrapper""" + lora_configs = [] + + for lora_config in config['loras']: + lora_config_prepared = { + 'lora_path': lora_config['lora_path'], + 'scale': lora_config.get('scale', 1.0), + 'enabled': lora_config.get('enabled', True), + 'lora_type': lora_config.get('lora_type'), + 'display_name': lora_config.get('display_name'), + 'description': lora_config.get('description'), + } + lora_configs.append(lora_config_prepared) + + return lora_configs + + def _prepare_pipeline_hook_configs(config: Dict[str, Any]) -> Dict[str, Any]: """Prepare pipeline hook configurations for wrapper following ControlNet/IPAdapter pattern""" hook_configs = {} @@ -418,6 +444,24 @@ def _validate_config(config: Dict[str, Any]) -> None: if 'image_encoder_path' not in ipadapter: raise ValueError(f"_validate_config: IPAdapter {i} missing required 'image_encoder_path'") + # Validate loras if present + if 'loras' in config: + if not isinstance(config['loras'], list): + raise ValueError("_validate_config: 'loras' must be a list") + + for i, lora in enumerate(config['loras']): + if not isinstance(lora, dict): + raise ValueError(f"_validate_config: LoRA {i} must be a dictionary") + + if 'lora_path' not in lora: + raise ValueError(f"_validate_config: LoRA {i} missing required 'lora_path'") + + # Validate scale if provided + if 'scale' in lora: + scale = lora['scale'] + if not isinstance(scale, (int, float)) or scale < 0: + raise ValueError(f"_validate_config: LoRA {i} 'scale' must be a non-negative number, got {scale}") + # Validate prompt blending configuration if present if 'prompt_blending' in config: blend_config = config['prompt_blending'] diff --git a/src/streamdiffusion/modules/README_LoRA.md b/src/streamdiffusion/modules/README_LoRA.md new file mode 100644 index 
00000000..a85a51ff
--- /dev/null
+++ b/src/streamdiffusion/modules/README_LoRA.md
@@ -0,0 +1,291 @@
+# LoRA Module for StreamDiffusion
+
+This document describes the LoRA module system for StreamDiffusion, which provides comprehensive LoRA management and hotswapping.
+
+## Overview
+
+The LoRA module replaces the legacy `lora_dict` handling with a module-based approach that supports:
+
+- **Hotswapping**: Add, remove, and modify LoRAs at runtime without restarting the pipeline
+- **State Management**: Track loaded LoRAs, their configurations, and current states
+- **Type Detection**: Automatically detect what a LoRA targets (text encoder, UNet, or both)
+- **Runtime Updates**: Modify LoRA scales and enabled states in real time
+- **Thread Safety**: Safe concurrent access to LoRA collections
+- **Configuration Integration**: Full YAML configuration support
+
+## Architecture
+
+### Core Components
+
+1. **LoRAConfig**: Configuration dataclass for an individual LoRA
+2. **LoRAModule**: Main module class implementing the OrchestratorUser pattern
+
+### Integration Points
+
+- **StreamDiffusionWrapper**: Integrated into the wrapper constructor and `_load_model` method
+- **Pipeline**: Accessible via `stream.lora_module` after installation
+- **Demo Interface**: Full UI controls in the realtime-img2img demo
+
+## Usage
+
+### Basic Configuration
+
+```yaml
+# A top-level `loras` list enables the LoRA module automatically
+loras:
+  - lora_path: "path/to/lora.safetensors"
+    scale: 0.8
+    enabled: true
+    display_name: "My LoRA"
+```
+
+The module inherits its device and dtype from the wrapper.
+
+### Programmatic Usage
+
+```python
+from streamdiffusion import StreamDiffusionWrapper
+
+# Initialize with the LoRA module
+wrapper = StreamDiffusionWrapper(
+    model_id_or_path="runwayml/stable-diffusion-v1-5",
+    t_index_list=[35, 45],
+    use_lora=True,
+    lora_config=[
+        {"lora_path": "path/to/lora.safetensors", "scale": 0.8},
+    ],
+)
+
+# Runtime LoRA management goes through update_stream_params; the list is the
+# complete desired state (entries are added, removed, or updated to match it)
+wrapper.update_stream_params(lora_config=[
+    {"lora_path": "path/to/lora.safetensors", "scale": 1.0},
+    {"lora_path": "new/lora/path.safetensors", "scale": 0.6},
+])
+
+# Get LoRA information
+loras = wrapper.get_stream_state().get("lora_config", [])
+print(f"Loaded {len(loras)} LoRAs")
+```
+
+### Demo Interface
+
+The realtime-img2img demo includes a full LoRA configuration panel with:
+
+- **Add LoRAs**: By path/URL or file upload
+- **Remove LoRAs**: One-click removal
+- **Scale Control**: Real-time scale adjustment sliders
+- **Enable/Disable**: Toggle LoRAs on/off
+- **Status Display**: View loaded LoRAs and their states
+
+## API Reference
+
+### LoRAConfig
+
+```python
+@dataclass
+class LoRAConfig:
+    lora_path: str                      # Path to LoRA file or HuggingFace model ID
+    adapter_name: Optional[str] = None  # Custom adapter name (auto-generated if None)
+    scale: float = 1.0                  # LoRA strength/scale
+    enabled: bool = True                # Whether the LoRA is active
+    lora_type: Optional[Literal["text_encoder", "unet", "both"]] = None  # Auto-detected if None
+    display_name: Optional[str] = None  # Human-readable name
+    description: Optional[str] = None   # Description
+```
+
+### LoRAModule Methods
+
+#### Core Management
+- `add_lora(config: LoRAConfig) -> bool`: Add a new LoRA
+- `remove_lora(index: int) -> bool`: Remove LoRA by index
+- `update_lora_scale(index: int, scale: float) -> bool`: Update LoRA scale
+- `update_lora_enabled(index: int, enabled: bool) -> bool`: Enable/disable LoRA
+
+#### State Access
+- `get_loaded_loras_info() -> List[Dict]`: Get detailed LoRA information
+- `get_lora_state() -> Dict`: Get complete module state
+- `update_config(config: List[Dict]) -> None`: Update from configuration
+
+### Wrapper Integration
+
+#### Constructor Parameters
+- `use_lora: bool = False`: Enable the LoRA module system
+- `lora_config: Optional[List[Dict]] = None`: Initial LoRA configuration
+
+#### Runtime Updates
+- `update_stream_params(lora_config=...)`: Pass the complete desired LoRA list; the parameter updater diffs it against the current state
+- `get_stream_state()`: Returns the current state, including `lora_config`
+
+## Features
+
+### Automatic Type Detection
+
+The module classifies each LoRA as `text_encoder`, `unet`, or `both` by inspecting the weight keys of local files (for example `lora_te`/`text_model` versus `lora_unet`/`diffusion_model` patterns). HuggingFace model IDs that cannot be inspected locally are assumed to target both.
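+A minimal standalone sketch of this key-pattern check (mirroring the module's
+`_detect_lora_type`; the file name below is a placeholder):
+
+```python
+import safetensors.torch
+
+def detect_lora_type(path: str) -> str:
+    """Classify a .safetensors LoRA by its weight key patterns."""
+    keys = safetensors.torch.load_file(path, device="cpu").keys()
+    has_te = any("text_model" in k or "text_encoder" in k or "lora_te" in k for k in keys)
+    has_unet = any("unet" in k or "diffusion_model" in k or "lora_unet" in k for k in keys)
+    if has_te and has_unet:
+        return "both"
+    if has_te:
+        return "text_encoder"
+    if has_unet:
+        return "unet"
+    return "unknown"
+
+print(detect_lora_type("style_lora.safetensors"))  # placeholder path
+```
+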
+### Offline Fallback Support
+
+When HuggingFace is offline, the module tries common weight filenames:
+- `pytorch_lora_weights.safetensors`
+- `pytorch_lora_weights.bin`
+- `diffusion_pytorch_model.safetensors`
+- `adapter_model.safetensors`
+- `lora.safetensors`
+
+### Thread Safety
+
+All LoRA operations are protected by a reentrant lock to ensure safe concurrent access from multiple threads.
+
+### Error Handling
+
+Comprehensive error handling with detailed logging covers:
+- Invalid file paths
+- Loading failures
+- Runtime update errors
+- State inconsistencies
+
+## Migration from Legacy System
+
+### Old System (Deprecated)
+```python
+wrapper = StreamDiffusionWrapper(
+    model_id_or_path="model",
+    lora_dict={"lora1": 0.8, "lora2": 0.6},
+    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
+    use_lcm_lora=True
+)
+```
+
+### New System
+```python
+wrapper = StreamDiffusionWrapper(
+    model_id_or_path="model",
+    use_lora=True,
+    lora_config=[
+        {"lora_path": "lora1", "scale": 0.8},
+        {"lora_path": "lora2", "scale": 0.6},
+    ],
+    # The LCM acceleration LoRA is still handled separately:
+    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
+    use_lcm_lora=True,
+)
+```
+
+## Demo API Endpoints
+
+The realtime-img2img demo exposes these LoRA endpoints:
+
+- `GET /api/lora/list`: Get loaded LoRAs
+- `POST /api/lora/add`: Add new LoRA
+- `POST /api/lora/remove`: Remove LoRA
+- `POST /api/lora/update-scale`: Update LoRA scale
+- `POST /api/lora/update-enabled`: Enable/disable LoRA
+- `POST /api/lora/upload`: Upload LoRA file
+
+## Configuration Examples
+
+### Multiple LoRAs
+```yaml
+loras:
+  - lora_path: "style_lora.safetensors"
+    scale: 0.8
+    display_name: "Style Enhancement"
+
+  - lora_path: "text_lora.safetensors"
+    scale: 0.6
+    display_name: "Text Understanding"
+
+  - lora_path: "combined_lora.safetensors"
+    scale: 1.0
+    display_name: "Combined Enhancement"
+```
+
+### LCM LoRA
+
+LCM acceleration is not configured through the `loras` list; it is loaded via the wrapper's `use_lcm_lora` / `lcm_lora_id` options (see Migration above).
+
+### HuggingFace Model Integration
+```yaml
+loras:
+  - lora_path: "username/awesome-style-lora"
+    scale: 0.7
+    display_name: "Awesome Style"
+    description: "Downloaded from HuggingFace Hub"
+```
+
+## Troubleshooting
+
+### Common Issues
+
+1. **LoRA not loading**: Check file path and permissions +2. **Scale not updating**: Ensure LoRA is enabled and pipeline supports set_adapters +3. **Type detection failing**: Manually specify lora_type in configuration +4. **Memory issues**: Reduce number of loaded LoRAs or lower scales + +### Debug Information + +Enable debug logging to see detailed LoRA operations: +```python +import logging +logging.getLogger('streamdiffusion').setLevel(logging.DEBUG) +``` + +### State Inspection + +Check LoRA module state: +```python +if hasattr(wrapper.stream, 'lora_module'): + state = wrapper.stream.lora_module.get_lora_state() + print(f"LoRA State: {state}") +``` + +## Performance Considerations + +- **Memory Usage**: Each LoRA consumes additional GPU memory +- **Loading Time**: Initial LoRA loading may take several seconds +- **Runtime Updates**: Scale and enabled state updates are fast +- **Hotswapping**: Adding/removing LoRAs may cause brief processing delays + +## Future Enhancements + +Planned improvements include: +- LoRA blending and mixing capabilities +- Advanced scheduling and automation +- Performance optimizations +- Enhanced type detection +- Batch operations support diff --git a/src/streamdiffusion/modules/__init__.py b/src/streamdiffusion/modules/__init__.py index 54954961..7fcb998e 100644 --- a/src/streamdiffusion/modules/__init__.py +++ b/src/streamdiffusion/modules/__init__.py @@ -4,11 +4,13 @@ from .ipadapter_module import IPAdapterModule from .image_processing_module import ImageProcessingModule, ImagePreprocessingModule, ImagePostprocessingModule from .latent_processing_module import LatentProcessingModule, LatentPreprocessingModule, LatentPostprocessingModule +from .lora_module import LoRAModule __all__ = [ # Existing modules 'ControlNetModule', 'IPAdapterModule', + 'LoRAModule', # Pipeline processing base classes 'ImageProcessingModule', diff --git a/src/streamdiffusion/modules/lora_module.py b/src/streamdiffusion/modules/lora_module.py new file mode 100644 index 00000000..811d869f --- /dev/null +++ b/src/streamdiffusion/modules/lora_module.py @@ -0,0 +1,438 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Any, List, Union, Literal +import torch +from pathlib import Path +import threading +import logging + +from ..preprocessing.orchestrator_user import OrchestratorUser + +logger = logging.getLogger(__name__) + + +@dataclass +class LoRAConfig: + """Configuration for a single LoRA.""" + lora_path: str + adapter_name: Optional[str] = None + scale: float = 1.0 + enabled: bool = True + lora_type: Optional[Literal["standard", "lcm"]] = None + # Additional metadata + display_name: Optional[str] = None + description: Optional[str] = None + + +class LoRAModule(OrchestratorUser): + """LoRA module providing comprehensive LoRA management and hotswapping.""" + + def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16): + self.device = device + self.dtype = dtype + + # State management + self.loras: List[LoRAConfig] = [] + self.loaded_adapters: Dict[str, str] = {} # adapter_name -> lora_path + self._collections_lock = threading.RLock() + + # Pipeline reference (set during install) + self._stream = None + self._pipe = None + + # LoRA type detection + self._lora_type_cache: Dict[str, str] = {} + + # Offline fallback support + self._candidate_weight_names = ( + "pytorch_lora_weights.safetensors", + "pytorch_lora_weights.bin", + "diffusion_pytorch_model.safetensors", + "adapter_model.safetensors", + "lora.safetensors", + ) + + def install(self, 
stream) -> None: + """Install LoRA module into the pipeline.""" + self._stream = stream + self._pipe = stream.pipe + + # Attach orchestrator for consistency + self.attach_orchestrator(stream) + + # No hooks needed - LoRAs work transparently after loading + # State management is handled at the module level + logger.info("install: LoRA module installed successfully") + + def _detect_lora_type(self, lora_path: str) -> str: + """Detect LoRA type from file content.""" + if lora_path in self._lora_type_cache: + return self._lora_type_cache[lora_path] + + try: + # Handle both file paths and HuggingFace model IDs + if Path(lora_path).exists(): + # Local file - try to load and inspect + try: + import safetensors.torch + if lora_path.endswith('.safetensors'): + lora_weights = safetensors.torch.load_file(lora_path, device='cpu') + else: + lora_weights = torch.load(lora_path, map_location='cpu') + except Exception: + # If we can't load the weights, assume standard + lora_type = 'standard' + self._lora_type_cache[lora_path] = lora_type + return lora_type + else: + # HuggingFace model ID - use heuristics based on name + if 'lcm' in lora_path.lower(): + lora_type = 'lcm' + self._lora_type_cache[lora_path] = lora_type + return lora_type + else: + lora_type = 'standard' + self._lora_type_cache[lora_path] = lora_type + return lora_type + + # Check for LCM patterns in weights + if any('lcm' in key.lower() for key in lora_weights.keys()): + lora_type = 'lcm' + else: + # Check for text encoder vs unet patterns + text_encoder_keys = [k for k in lora_weights.keys() if 'text_model' in k or 'text_encoder' in k] + unet_keys = [k for k in lora_weights.keys() if 'unet' in k or 'diffusion_model' in k] + + if text_encoder_keys and not unet_keys: + lora_type = 'text_encoder' + elif unet_keys and not text_encoder_keys: + lora_type = 'unet' + else: + lora_type = 'standard' + + except Exception as e: + logger.warning(f"_detect_lora_type: Failed to detect LoRA type for {lora_path}: {e}") + lora_type = 'standard' + + self._lora_type_cache[lora_path] = lora_type + return lora_type + + def _load_lora_with_offline_fallback(self, lora_path: str, adapter_name: Optional[str] = None, **kwargs) -> bool: + """Load LoRA weights with offline fallback support.""" + try: + logger.debug(f"_load_lora_with_offline_fallback: Trying to load {lora_path} with adapter_name={adapter_name}") + self._pipe.load_lora_weights(lora_path, adapter_name=adapter_name, **kwargs) + logger.info(f"_load_lora_with_offline_fallback: Successfully loaded {lora_path}") + return True + except Exception as e: + message = str(e) + logger.debug(f"_load_lora_with_offline_fallback: Initial load failed: {e}") + is_offline_weight_error = isinstance(e, ValueError) and "must specify a `weight_name`" in message + if not is_offline_weight_error: + logger.error(f"_load_lora_with_offline_fallback: Failed to load LoRA {lora_path}: {e}") + return False + + # Try offline fallback with common weight names + logger.debug(f"_load_lora_with_offline_fallback: Trying offline fallback for {lora_path}") + last_err: Optional[Exception] = None + for weight_name in self._candidate_weight_names: + try: + logger.debug(f"_load_lora_with_offline_fallback: Trying weight_name={weight_name}") + self._pipe.load_lora_weights( + lora_path, + adapter_name=adapter_name, + weight_name=weight_name, + **kwargs + ) + logger.info(f"_load_lora_with_offline_fallback: Successfully loaded LoRA {lora_path} with weight_name={weight_name}") + return True + except Exception as e: + 
logger.debug(f"_load_lora_with_offline_fallback: Failed with weight_name={weight_name}: {e}") + last_err = e + continue + + if last_err is not None: + logger.error(f"_load_lora_with_offline_fallback: All fallback attempts failed for {lora_path}: {last_err}") + return False + + def add_lora(self, config: LoRAConfig) -> bool: + """Add a new LoRA to the pipeline.""" + with self._collections_lock: + try: + # 1. Validate LoRA file exists or is valid HF model ID + if not (Path(config.lora_path).exists() or '/' in config.lora_path): + logger.error(f"add_lora: LoRA path does not exist and is not a valid HF model ID: {config.lora_path}") + return False + + # 2. Detect LoRA type if not specified + if config.lora_type is None: + config.lora_type = self._detect_lora_type(config.lora_path) + logger.info(f"add_lora: Detected LoRA type: {config.lora_type} for {config.lora_path}") + + # 3. Generate adapter name if not provided + if config.adapter_name is None: + # Get existing adapter names from pipeline if possible + existing_adapters = set() + if hasattr(self._pipe, 'get_list_adapters'): + try: + existing_adapters.update(self._pipe.get_list_adapters()) + logger.debug(f"add_lora: Found existing adapters in pipeline: {existing_adapters}") + except Exception as e: + logger.debug(f"add_lora: Could not get existing adapters: {e}") + + # Also check our internal tracking + existing_adapters.update(self.loaded_adapters.keys()) + + # Generate unique adapter name + import time + timestamp = int(time.time() * 1000) % 10000 # Last 4 digits of timestamp + base_name = f"lora_{timestamp}" + config.adapter_name = base_name + + # Ensure uniqueness + counter = 0 + while config.adapter_name in existing_adapters: + counter += 1 + config.adapter_name = f"{base_name}_{counter}" + + logger.info(f"add_lora: Generated unique adapter name: {config.adapter_name}") + + # 4. Load LoRA weights using pipe.load_lora_weights + try: + logger.debug(f"add_lora: Loading LoRA weights: {config.lora_path} with adapter_name: {config.adapter_name}") + self._pipe.load_lora_weights(config.lora_path, adapter_name=config.adapter_name) + logger.info(f"add_lora: Successfully loaded LoRA weights") + except Exception as e: + logger.debug(f"add_lora: Failed to load LoRA weights: {e}") + # Try offline fallback + if not self._load_lora_with_offline_fallback(config.lora_path, config.adapter_name): + return False + + # 5. Set adapter scale if supported and enabled + if config.enabled and hasattr(self._pipe, 'set_adapters'): + try: + # Get current adapters + current_adapters = [] + current_scales = [] + + for lora in self.loras: + if lora.enabled: + current_adapters.append(lora.adapter_name) + current_scales.append(lora.scale) + + # Add new adapter + current_adapters.append(config.adapter_name) + current_scales.append(config.scale) + + logger.debug(f"add_lora: Calling set_adapters with adapters={current_adapters}, adapter_weights={current_scales}") + self._pipe.set_adapters(current_adapters, adapter_weights=current_scales) + logger.debug(f"add_lora: Successfully set adapter weights") + logger.info(f"add_lora: Set adapter scales: {dict(zip(current_adapters, current_scales))}") + except Exception as e: + logger.warning(f"add_lora: Failed to set adapter scale: {e}") + + # 6. 
Add to internal state
+                self.loras.append(config)
+                self.loaded_adapters[config.adapter_name] = config.lora_path
+
+                logger.info(f"add_lora: Successfully added LoRA {config.lora_path} as {config.adapter_name}")
+                return True
+
+            except Exception as e:
+                logger.error(f"add_lora: Failed to add LoRA {config.lora_path}: {e}")
+                return False
+
+    def remove_lora(self, index: int) -> bool:
+        """Remove a LoRA from the pipeline."""
+        with self._collections_lock:
+            try:
+                # 1. Validate index
+                if index < 0 or index >= len(self.loras):
+                    logger.error(f"remove_lora: Invalid index {index}, valid range: 0-{len(self.loras)-1}")
+                    return False
+
+                lora_config = self.loras[index]
+
+                # 2. Delete the adapter via pipe.delete_adapters(). Note that
+                # diffusers' unload_lora_weights() takes no adapter name and
+                # would drop ALL loaded LoRAs, so it is not used here.
+                if hasattr(self._pipe, 'delete_adapters'):
+                    try:
+                        self._pipe.delete_adapters(lora_config.adapter_name)
+                        logger.info(f"remove_lora: Deleted LoRA adapter {lora_config.adapter_name}")
+                    except Exception as e:
+                        logger.warning(f"remove_lora: Failed to delete LoRA adapter {lora_config.adapter_name}: {e}")
+
+                # 3. Remove from internal state
+                removed_lora = self.loras.pop(index)
+                if removed_lora.adapter_name in self.loaded_adapters:
+                    del self.loaded_adapters[removed_lora.adapter_name]
+
+                # 4. Update remaining adapter scales
+                if hasattr(self._pipe, 'set_adapters'):
+                    try:
+                        current_adapters = []
+                        current_scales = []
+
+                        for lora in self.loras:
+                            if lora.enabled:
+                                current_adapters.append(lora.adapter_name)
+                                current_scales.append(lora.scale)
+
+                        if current_adapters:
+                            logger.debug(f"remove_lora: Calling set_adapters with adapters={current_adapters}, adapter_weights={current_scales}")
+                            self._pipe.set_adapters(current_adapters, adapter_weights=current_scales)
+                        else:
+                            # Disable all adapters if none remain
+                            logger.debug("remove_lora: Disabling all adapters")
+                            self._pipe.set_adapters([], adapter_weights=[])
+
+                        logger.info("remove_lora: Updated adapter scales after removal")
+                    except Exception as e:
+                        logger.warning(f"remove_lora: Failed to update adapter scales: {e}")
+
+                logger.info(f"remove_lora: Successfully removed LoRA at index {index}")
+                return True
+
+            except Exception as e:
+                logger.error(f"remove_lora: Failed to remove LoRA at index {index}: {e}")
+                return False
+
+    def update_lora_scale(self, index: int, scale: float) -> bool:
+        """Update LoRA scale at runtime."""
+        logger.debug(f"update_lora_scale: Called with index={index}, scale={scale}")
+        with self._collections_lock:
+            try:
+                # 1. Validate index and scale
+                if index < 0 or index >= len(self.loras):
+                    logger.error(f"update_lora_scale: Invalid index {index}, valid range: 0-{len(self.loras)-1}")
+                    return False
+
+                if scale < 0.0:
+                    logger.error(f"update_lora_scale: Invalid scale {scale}, must be >= 0.0")
+                    return False
+
+                # 2. Update internal scale
+                old_scale = self.loras[index].scale
+                self.loras[index].scale = scale
+                logger.debug(f"update_lora_scale: LoRA {index} scale updated internally: {old_scale} -> {scale}")
+
+                # 3.
Apply new scale via pipe.set_adapters() + logger.debug(f"update_lora_scale: Pipeline type: {type(self._pipe)}") + logger.debug(f"update_lora_scale: Pipeline has set_adapters: {hasattr(self._pipe, 'set_adapters')}") + if hasattr(self._pipe, 'set_adapters'): + try: + current_adapters = [] + current_scales = [] + + for i, lora in enumerate(self.loras): + if lora.enabled: + current_adapters.append(lora.adapter_name) + current_scales.append(lora.scale) + logger.debug(f"update_lora_scale: Including LoRA {i} ({lora.adapter_name}) with scale {lora.scale}") + + logger.debug(f"update_lora_scale: Calling set_adapters with adapters={current_adapters}, adapter_weights={current_scales}") + self._pipe.set_adapters(current_adapters, adapter_weights=current_scales) + logger.debug(f"update_lora_scale: set_adapters call completed successfully") + logger.info(f"update_lora_scale: Updated scale for index {index} to {scale}") + except Exception as e: + logger.error(f"update_lora_scale: Exception in set_adapters: {e}") + logger.warning(f"update_lora_scale: Failed to apply scale update: {e}") + return False + else: + logger.warning(f"update_lora_scale: Pipeline does not have set_adapters method") + + return True + + except Exception as e: + logger.error(f"update_lora_scale: Exception in main try block: {e}") + return False + + def update_lora_enabled(self, index: int, enabled: bool) -> bool: + """Enable/disable LoRA at runtime.""" + with self._collections_lock: + try: + # 1. Validate index + if index < 0 or index >= len(self.loras): + logger.error(f"update_lora_enabled: Invalid index {index}, valid range: 0-{len(self.loras)-1}") + return False + + # 2. Update enabled state + self.loras[index].enabled = enabled + + # 3. Apply changes via pipe.set_adapters() + if hasattr(self._pipe, 'set_adapters'): + try: + current_adapters = [] + current_scales = [] + + for lora in self.loras: + if lora.enabled: + current_adapters.append(lora.adapter_name) + current_scales.append(lora.scale) + + logger.debug(f"update_lora_enabled: Calling set_adapters with adapters={current_adapters}, adapter_weights={current_scales}") + self._pipe.set_adapters(current_adapters, adapter_weights=current_scales) + logger.info(f"update_lora_enabled: Updated enabled state for index {index} to {enabled}") + except Exception as e: + logger.warning(f"update_lora_enabled: Failed to apply enabled state update: {e}") + return False + + return True + + except Exception as e: + logger.error(f"update_lora_enabled: Failed to update enabled state for index {index}: {e}") + return False + + def update_config(self, config: List[Dict[str, Any]]) -> None: + """Update LoRA configuration from wrapper.""" + with self._collections_lock: + try: + # Convert dict configs to LoRAConfig objects + desired_configs = [] + for cfg_dict in config: + lora_config = LoRAConfig(**cfg_dict) + desired_configs.append(lora_config) + + # Simple approach: clear all and reload + # More sophisticated diffing could be implemented later + + # Remove all current LoRAs + while self.loras: + self.remove_lora(0) + + # Add all desired LoRAs + for lora_config in desired_configs: + self.add_lora(lora_config) + + logger.info(f"update_config: Updated configuration with {len(desired_configs)} LoRAs") + + except Exception as e: + logger.error(f"update_config: Failed to update configuration: {e}") + + def get_loaded_loras_info(self) -> List[Dict[str, Any]]: + """Get detailed information about loaded LoRAs.""" + with self._collections_lock: + return [ + { + 'index': i, + 'lora_path': lora.lora_path, + 
'adapter_name': lora.adapter_name,
+                    'scale': lora.scale,
+                    'enabled': lora.enabled,
+                    'lora_type': lora.lora_type,
+                    'display_name': lora.display_name,
+                }
+                for i, lora in enumerate(self.loras)
+            ]
+
+    def get_lora_state(self) -> Dict[str, Any]:
+        """Get complete LoRA module state for debugging."""
+        with self._collections_lock:
+            return {
+                'loaded_loras': len(self.loras),
+                'enabled_loras': sum(1 for lora in self.loras if lora.enabled),
+                'total_scales': [lora.scale for lora in self.loras],
+                'lora_types': [lora.lora_type for lora in self.loras],
+                'loaded_adapters': dict(self.loaded_adapters),
+            }
diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py
index 81acc89f..d7a51331 100644
--- a/src/streamdiffusion/stream_parameter_updater.py
+++ b/src/streamdiffusion/stream_parameter_updater.py
@@ -250,6 +250,7 @@ def update_stream_params(
         normalize_seed_weights: Optional[bool] = None,
         controlnet_config: Optional[List[Dict[str, Any]]] = None,
         ipadapter_config: Optional[Dict[str, Any]] = None,
+        lora_config: Optional[List[Dict[str, Any]]] = None,
         image_preprocessing_config: Optional[List[Dict[str, Any]]] = None,
         image_postprocessing_config: Optional[List[Dict[str, Any]]] = None,
         latent_preprocessing_config: Optional[List[Dict[str, Any]]] = None,
@@ -314,6 +315,11 @@ def update_stream_params(
             logger.info(f"update_stream_params: Updating IPAdapter configuration")
             self._update_ipadapter_config(ipadapter_config)
 
+        # Handle LoRA configuration updates
+        if lora_config is not None:
+            logger.info(f"update_stream_params: Updating LoRA configuration")
+            self._update_lora_config(lora_config)
+
         # Handle Hook configuration updates
         if image_preprocessing_config is not None:
             logger.info(f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors")
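The `lora_config` passed to `update_stream_params` is the complete desired state, not a delta. A hedged usage sketch (the `wrapper` instance and LoRA paths are assumptions):

```python
# The updater below diffs this list against the currently loaded LoRAs:
# entries missing from the list are removed, new entries are added, and
# scale/enabled changes on existing entries are applied in place.
wrapper.update_stream_params(lora_config=[
    {"lora_path": "style_lora.safetensors", "scale": 0.5, "enabled": True},
    {"lora_path": "detail_lora.safetensors", "scale": 0.9, "enabled": False},
])
```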
+ """ + logger.debug(f"_update_lora_config: Called with desired_config={desired_config}") + + # Find the LoRA module + lora_module = self._get_lora_module() + if not lora_module: + logger.warning(f"_update_lora_config: No LoRA module found") + return + + logger.debug(f"_update_lora_config: Found LoRA module: {lora_module}") + + current_config = self._get_current_lora_config() + logger.debug(f"_update_lora_config: Current config: {current_config}") + + # Simple approach: detect what changed and apply minimal updates + current_loras = {i: lora.get('lora_path', f'lora_{i}') for i, lora in enumerate(current_config)} + desired_loras = {cfg['lora_path']: cfg for cfg in desired_config} + + # Remove LoRAs not in desired config + for i in reversed(range(len(current_config))): + lora_path = current_loras.get(i, f'lora_{i}') + if lora_path not in desired_loras: + logger.info(f"_update_lora_config: Removing LoRA {lora_path}") + try: + lora_module.remove_lora(i) + except Exception as e: + logger.error(f"_update_lora_config: Failed to remove LoRA at index {i}: {e}") + + # Add new LoRAs and update existing ones + for desired_cfg in desired_config: + lora_path = desired_cfg['lora_path'] + existing_index = next((i for i, path in current_loras.items() if path == lora_path), None) + + if existing_index is None: + # Add new LoRA + logger.info(f"_update_lora_config: Adding LoRA {lora_path}") + try: + from .modules.lora_module import LoRAConfig + lora_config_obj = LoRAConfig(**desired_cfg) + lora_module.add_lora(lora_config_obj) + except Exception as e: + logger.error(f"_update_lora_config: Failed to add LoRA {lora_path}: {e}") + else: + # Update existing LoRA + if 'scale' in desired_cfg: + current_scale = current_config[existing_index].get('scale', 1.0) + desired_scale = desired_cfg['scale'] + + logger.debug(f"_update_lora_config: Comparing scales for {lora_path}: current={current_scale}, desired={desired_scale}") + + if current_scale != desired_scale: + logger.debug(f"_update_lora_config: Scale change detected, updating {lora_path} scale: {current_scale} -> {desired_scale}") + logger.info(f"_update_lora_config: Updating {lora_path} scale: {current_scale} → {desired_scale}") + try: + logger.debug(f"_update_lora_config: Calling lora_module.update_lora_scale({existing_index}, {desired_scale})") + result = lora_module.update_lora_scale(existing_index, desired_scale) + logger.debug(f"_update_lora_config: update_lora_scale returned: {result}") + except Exception as e: + logger.error(f"_update_lora_config: Exception during scale update: {e}") + logger.error(f"_update_lora_config: Failed to update scale for LoRA at index {existing_index}: {e}") + else: + logger.debug(f"_update_lora_config: No scale change needed for {lora_path}") + + # Enable/disable toggle + if 'enabled' in desired_cfg: + current_enabled = current_config[existing_index].get('enabled', True) + desired_enabled = desired_cfg['enabled'] + + if current_enabled != desired_enabled: + logger.info(f"_update_lora_config: {'Enabling' if desired_enabled else 'Disabling'} LoRA {lora_path}") + try: + lora_module.update_lora_enabled(existing_index, desired_enabled) + except Exception as e: + logger.error(f"_update_lora_config: Failed to update enabled state for LoRA at index {existing_index}: {e}") + + def _get_lora_module(self): + """ + Get the LoRA module from the pipeline structure. 
+ + Returns: + LoRA module object or None if not found + """ + # Check if stream has LoRA module + if hasattr(self.stream, 'lora_module'): + return self.stream.lora_module + + # Check if stream has nested stream + if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'lora_module'): + return self.stream.stream.lora_module + + # Check if we have a wrapper reference and can access through it + if self.wrapper and hasattr(self.wrapper, 'stream'): + if hasattr(self.wrapper.stream, 'lora_module'): + return self.wrapper.stream.lora_module + elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'lora_module'): + return self.wrapper.stream.stream.lora_module + + return None + + def _get_current_lora_config(self) -> List[Dict[str, Any]]: + """ + Get current LoRA configuration state. + + Returns: + List of current LoRA configurations + """ + lora_module = self._get_lora_module() + if not lora_module: + return [] + + try: + return lora_module.get_loaded_loras_info() + except Exception as e: + logger.error(f"_get_current_lora_config: Failed to get LoRA info: {e}") + return [] + diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 89dd6d6e..f8b3f3de 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -10,6 +10,7 @@ from .pipeline import StreamDiffusion from .model_detection import detect_model from .image_utils import postprocess_image +from .modules import LoRAModule import logging logger = logging.getLogger(__name__) @@ -72,7 +73,6 @@ def __init__( t_index_list: List[int], min_batch_size: int = 1, max_batch_size: int = 4, - lora_dict: Optional[Dict[str, float]] = None, mode: Literal["img2img", "txt2img"] = "img2img", output_type: Literal["pil", "pt", "np", "latent"] = "pil", lcm_lora_id: Optional[str] = None, @@ -107,6 +107,9 @@ def __init__( # IPAdapter options use_ipadapter: bool = False, ipadapter_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + # LoRA options + use_lora: bool = False, + lora_config: Optional[List[Dict[str, Any]]] = None, # Pipeline hook configurations image_preprocessing_config: Optional[Dict[str, Any]] = None, image_postprocessing_config: Optional[Dict[str, Any]] = None, @@ -124,10 +127,6 @@ def __init__( The model id or path to load. t_index_list : List[int] The t_index_list to use for inference. - lora_dict : Optional[Dict[str, float]], optional - The lora_dict to load, by default None. - Keys are the LoRA names and values are the LoRA scales. - Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} mode : Literal["img2img", "txt2img"], optional txt2img or img2img, by default "img2img". 
output_type : Literal["pil", "pt", "np", "latent"], optional @@ -206,6 +205,8 @@ def __init__( self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter self.ipadapter_config = ipadapter_config + self.use_lora = use_lora + self.lora_config = lora_config # Store pipeline hook configurations self.image_preprocessing_config = image_preprocessing_config @@ -250,7 +251,6 @@ def __init__( self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, - lora_dict=lora_dict, lcm_lora_id=lcm_lora_id, vae_id=vae_id, t_index_list=t_index_list, @@ -267,6 +267,8 @@ def __init__( controlnet_config=controlnet_config, use_ipadapter=use_ipadapter, ipadapter_config=ipadapter_config, + use_lora=use_lora, + lora_config=lora_config, # Pipeline hook configurations image_preprocessing_config=image_preprocessing_config, image_postprocessing_config=image_postprocessing_config, @@ -491,6 +493,8 @@ def update_stream_params( controlnet_config: Optional[List[Dict[str, Any]]] = None, # IPAdapter configuration ipadapter_config: Optional[Dict[str, Any]] = None, + # LoRA configuration + lora_config: Optional[List[Dict[str, Any]]] = None, # Hook configurations image_preprocessing_config: Optional[List[Dict[str, Any]]] = None, image_postprocessing_config: Optional[List[Dict[str, Any]]] = None, @@ -558,6 +562,7 @@ def update_stream_params( normalize_seed_weights=normalize_seed_weights, controlnet_config=controlnet_config, ipadapter_config=ipadapter_config, + lora_config=lora_config, image_preprocessing_config=image_preprocessing_config, image_postprocessing_config=image_postprocessing_config, latent_preprocessing_config=latent_preprocessing_config, @@ -874,7 +879,6 @@ def _load_model( self, model_id_or_path: str, t_index_list: List[int], - lora_dict: Optional[Dict[str, float]] = None, lcm_lora_id: Optional[str] = None, vae_id: Optional[str] = None, acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", @@ -890,6 +894,8 @@ def _load_model( controlnet_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, use_ipadapter: bool = False, ipadapter_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + use_lora: bool = False, + lora_config: Optional[List[Dict[str, Any]]] = None, # Pipeline hook configurations (Phase 4: Configuration Integration) image_preprocessing_config: Optional[Dict[str, Any]] = None, image_postprocessing_config: Optional[Dict[str, Any]] = None, @@ -917,10 +923,6 @@ def _load_model( The model id or path to load. t_index_list : List[int] The t_index_list to use for inference. - lora_dict : Optional[Dict[str, float]], optional - The lora_dict to load, by default None. - Keys are the LoRA names and values are the LoRA scales. - Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} lcm_lora_id : Optional[str], optional The lcm_lora_id to load, by default None. 
vae_id : Optional[str], optional @@ -1076,20 +1078,29 @@ def _load_model( normalize_prompt_weights=normalize_prompt_weights, normalize_seed_weights=normalize_seed_weights, ) - if not self.sd_turbo: - if use_lcm_lora: - if lcm_lora_id is not None: - stream.load_lcm_lora( - pretrained_model_name_or_path_or_dict=lcm_lora_id - ) - else: - stream.load_lcm_lora() - stream.fuse_lora() + # Initialize LoRA module if configured + if use_lora and lora_config: + lora_module = LoRAModule(device=self.device, dtype=self.dtype) + lora_module.install(stream) + stream.lora_module = lora_module + + # Load initial LoRAs from config + from .modules.lora_module import LoRAConfig + for lora_cfg in lora_config: + lora_config_obj = LoRAConfig(**lora_cfg) + lora_module.add_lora(lora_config_obj) + + logger.info(f"_load_model: Initialized LoRA module with {len(lora_config)} LoRAs") - if lora_dict is not None: - for lora_name, lora_scale in lora_dict.items(): - stream.load_lora(lora_name) - stream.fuse_lora(lora_scale=lora_scale) + # Handle LCM LoRA loading (performance LoRA) + if not self.sd_turbo and use_lcm_lora: + if lcm_lora_id is not None: + stream.load_lcm_lora( + pretrained_model_name_or_path_or_dict=lcm_lora_id + ) + else: + stream.load_lcm_lora() + stream.fuse_lora() if use_tiny_vae: if vae_id is not None: @@ -1853,7 +1864,7 @@ def get_stream_state(self, include_caches: bool = False) -> Dict[str, Any]: 'num_inference_steps': num_inference_steps, }) - # Module configs (ControlNet, IP-Adapter) + # Module configs (ControlNet, IP-Adapter, LoRA) try: controlnet_config = updater._get_current_controlnet_config() except Exception: @@ -1862,6 +1873,10 @@ def get_stream_state(self, include_caches: bool = False) -> Dict[str, Any]: ipadapter_config = updater._get_current_ipadapter_config() except Exception: ipadapter_config = None + try: + lora_config = updater._get_current_lora_config() + except Exception: + lora_config = [] # Hook configs try: image_preprocessing_config = updater._get_current_hook_config('image_preprocessing') @@ -1883,6 +1898,7 @@ def get_stream_state(self, include_caches: bool = False) -> Dict[str, Any]: state.update({ 'controlnet_config': controlnet_config, 'ipadapter_config': ipadapter_config, + 'lora_config': lora_config, 'image_preprocessing_config': image_preprocessing_config, 'image_postprocessing_config': image_postprocessing_config, 'latent_preprocessing_config': latent_preprocessing_config, From 7f51acbec9875a05666eb42356d6b6ff85bf38a6 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:06:07 -0400 Subject: [PATCH 2/3] detect lora type with handling --- src/streamdiffusion/modules/lora_module.py | 144 ++++++++++++++++----- 1 file changed, 113 insertions(+), 31 deletions(-) diff --git a/src/streamdiffusion/modules/lora_module.py b/src/streamdiffusion/modules/lora_module.py index 811d869f..3acb8ade 100644 --- a/src/streamdiffusion/modules/lora_module.py +++ b/src/streamdiffusion/modules/lora_module.py @@ -17,7 +17,7 @@ class LoRAConfig: adapter_name: Optional[str] = None scale: float = 1.0 enabled: bool = True - lora_type: Optional[Literal["standard", "lcm"]] = None + lora_type: Optional[Literal["text_encoder", "unet", "both"]] = None # Additional metadata display_name: Optional[str] = None description: Optional[str] = None @@ -64,7 +64,7 @@ def install(self, stream) -> None: logger.info("install: LoRA module installed successfully") def _detect_lora_type(self, lora_path: str) -> str: - """Detect LoRA type from file 
content.""" + """Detect LoRA type from file content - text_encoder, unet, or both.""" if lora_path in self._lora_type_cache: return self._lora_type_cache[lora_path] @@ -78,40 +78,38 @@ def _detect_lora_type(self, lora_path: str) -> str: lora_weights = safetensors.torch.load_file(lora_path, device='cpu') else: lora_weights = torch.load(lora_path, map_location='cpu') - except Exception: - # If we can't load the weights, assume standard - lora_type = 'standard' + except Exception as e: + # If we can't load the weights, assume both + logger.warning(f"_detect_lora_type: Could not load weights from {lora_path}: {e}. Assuming 'both' type.") + lora_type = 'both' self._lora_type_cache[lora_path] = lora_type return lora_type else: - # HuggingFace model ID - use heuristics based on name - if 'lcm' in lora_path.lower(): - lora_type = 'lcm' - self._lora_type_cache[lora_path] = lora_type - return lora_type - else: - lora_type = 'standard' - self._lora_type_cache[lora_path] = lora_type - return lora_type + # HuggingFace model ID - assume both for unknown models + lora_type = 'both' + logger.info(f"_detect_lora_type: Assuming 'both' type for HuggingFace model: {lora_path}") + self._lora_type_cache[lora_path] = lora_type + return lora_type + + # Check for text encoder vs unet patterns + text_encoder_keys = [k for k in lora_weights.keys() if 'text_model' in k or 'text_encoder' in k or 'lora_te' in k] + unet_keys = [k for k in lora_weights.keys() if 'unet' in k or 'diffusion_model' in k or 'lora_unet' in k] - # Check for LCM patterns in weights - if any('lcm' in key.lower() for key in lora_weights.keys()): - lora_type = 'lcm' + if text_encoder_keys and not unet_keys: + lora_type = 'text_encoder' + logger.info(f"_detect_lora_type: Detected text encoder LoRA from weight patterns in {lora_path}") + elif unet_keys and not text_encoder_keys: + lora_type = 'unet' + logger.info(f"_detect_lora_type: Detected UNet LoRA from weight patterns in {lora_path}") + elif unet_keys and text_encoder_keys: + lora_type = 'both' + logger.info(f"_detect_lora_type: Detected both text encoder and UNet LoRA from weight patterns in {lora_path}") else: - # Check for text encoder vs unet patterns - text_encoder_keys = [k for k in lora_weights.keys() if 'text_model' in k or 'text_encoder' in k] - unet_keys = [k for k in lora_weights.keys() if 'unet' in k or 'diffusion_model' in k] - - if text_encoder_keys and not unet_keys: - lora_type = 'text_encoder' - elif unet_keys and not text_encoder_keys: - lora_type = 'unet' - else: - lora_type = 'standard' - + lora_type = 'unknown' + logger.info(f"_detect_lora_type: Detected unknown LoRA from weight patterns in {lora_path}") except Exception as e: - logger.warning(f"_detect_lora_type: Failed to detect LoRA type for {lora_path}: {e}") - lora_type = 'standard' + logger.warning(f"_detect_lora_type: Failed to detect LoRA type for {lora_path}: {e}. Assuming 'both' type.") + lora_type = 'both' self._lora_type_cache[lora_path] = lora_type return lora_type @@ -166,7 +164,26 @@ def add_lora(self, config: LoRAConfig) -> bool: # 2. 
Detect LoRA type if not specified if config.lora_type is None: config.lora_type = self._detect_lora_type(config.lora_path) - logger.info(f"add_lora: Detected LoRA type: {config.lora_type} for {config.lora_path}") + type_description = self._get_type_description(config.lora_type) + logger.info(f"add_lora: Detected LoRA type: {config.lora_type} ({type_description}) for {config.lora_path}") + else: + logger.info(f"add_lora: Using specified LoRA type: {config.lora_type} for {config.lora_path}") + + # 2.5. Check for TensorRT compatibility + is_tensorrt = self._is_tensorrt_acceleration() + logger.info(f"add_lora: TensorRT detection: {is_tensorrt}, LoRA type: {config.lora_type}") + if is_tensorrt and config.lora_type == 'unet': + print("=" * 80) + print("TENSORRT COMPATIBILITY WARNING") + print("=" * 80) + print(f"Pure UNet LoRAs are NOT supported with TensorRT acceleration!") + print(f"LoRA: {config.lora_path}") + print(f"Detected Type: {config.lora_type}") + print(f"Only text_encoder and 'both' type LoRAs are supported with TensorRT pipelines.") + print("=" * 80) + print("This LoRA will NOT be loaded to prevent pipeline errors.") + print("=" * 80) + return False # 3. Generate adapter name if not provided if config.adapter_name is None: @@ -436,3 +453,68 @@ def get_lora_state(self) -> Dict[str, Any]: 'lora_types': [lora.lora_type for lora in self.loras], 'loaded_adapters': dict(self.loaded_adapters), } + + def get_lora_type_info(self, lora_path: str) -> Dict[str, Any]: + """ + Get detailed LoRA type information for a specific LoRA. + + Args: + lora_path: Path to the LoRA file or HuggingFace model ID + + Returns: + Dictionary containing type information and detection details + """ + lora_type = self._detect_lora_type(lora_path) + + # Check if this LoRA is currently loaded + loaded_info = None + with self._collections_lock: + for lora in self.loras: + if lora.lora_path == lora_path: + loaded_info = { + 'is_loaded': True, + 'adapter_name': lora.adapter_name, + 'scale': lora.scale, + 'enabled': lora.enabled, + 'display_name': lora.display_name, + 'description': lora.description + } + break + + if loaded_info is None: + loaded_info = {'is_loaded': False} + + return { + 'lora_path': lora_path, + 'detected_type': lora_type, + 'type_description': self._get_type_description(lora_type), + 'is_cached': lora_path in self._lora_type_cache, + 'loaded_info': loaded_info + } + + def _get_type_description(self, lora_type: str) -> str: + """Get human-readable description of LoRA type.""" + descriptions = { + 'text_encoder': 'Text Encoder LoRA - affects only text processing', + 'unet': 'UNet LoRA - affects only the diffusion model', + 'both': 'Both Text Encoder and UNet LoRA - affects text processing and diffusion model', + 'unknown': 'Unknown LoRA type' + } + return descriptions.get(lora_type, f'Unknown LoRA type: {lora_type}') + + def _is_tensorrt_acceleration(self) -> bool: + """Check if the pipeline is using TensorRT acceleration.""" + if not self._stream: + logger.info("_is_tensorrt_acceleration: No stream available") + return False + + # Check wrapper's acceleration setting + if hasattr(self._stream, '_param_updater') and self._stream._param_updater.wrapper: + wrapper = self._stream._param_updater.wrapper + acceleration = getattr(wrapper, '_acceleration', None) + logger.info(f"_is_tensorrt_acceleration: Wrapper acceleration: {acceleration}") + return acceleration == 'tensorrt' + + logger.info("_is_tensorrt_acceleration: No wrapper available") + return False + \ No newline at end of file From 
0f90b98cf56b2bba634547c7e341f8de5fdaf423 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:06:31 -0400 Subject: [PATCH 3/3] ui support --- .../src/lib/components/LoRAConfig.svelte | 459 ++++++++++++++++++ .../frontend/src/routes/+page.svelte | 8 + demo/realtime-img2img/img2img.py | 2 + demo/realtime-img2img/main.py | 322 ++++++++++++ 4 files changed, 791 insertions(+) create mode 100644 demo/realtime-img2img/frontend/src/lib/components/LoRAConfig.svelte diff --git a/demo/realtime-img2img/frontend/src/lib/components/LoRAConfig.svelte b/demo/realtime-img2img/frontend/src/lib/components/LoRAConfig.svelte new file mode 100644 index 00000000..e3ef969b --- /dev/null +++ b/demo/realtime-img2img/frontend/src/lib/components/LoRAConfig.svelte @@ -0,0 +1,459 @@ + + +
+<!-- LoRAConfig.svelte (markup garbled in this patch view; only text content and
+     template expressions survive). Recoverable structure: a collapsible
+     "LoRA Configuration" panel with a "LoRA Enabled" / "Standard Mode" status
+     badge. For each entry in {#each loraInfo.loras as lora, index} it renders
+     the display name (or the file name from lora.lora_path), an optional
+     {lora.lora_type} badge, the index, an enable checkbox bound to
+     handleEnabledChange(index, e), and a scale slider bound to
+     handleScaleChange(index, e) with a live {lora.scale.toFixed(2)} readout and
+     the hint "Controls LoRA strength. Higher values = stronger effect.", plus
+     the optional description and the lora_path. Empty states: "No LoRAs active.
+     Add one to get started:" with add-by-path and upload actions, or "Load a
+     configuration with LoRA settings to enable LoRA support." An {uploadStatus}
+     message is shown after upload attempts. -->
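The component above drives the demo's LoRA endpoints defined in `main.py` below. A minimal Python sketch of the same calls (the demo's host and port are assumptions here):

```python
import requests

BASE = "http://localhost:7860"  # assumed address of the running demo

# List loaded LoRAs
print(requests.get(f"{BASE}/api/lora/list").json())

# Add a LoRA by local path or HuggingFace model ID
requests.post(f"{BASE}/api/lora/add",
              json={"lora_path": "path/to/lora.safetensors", "scale": 0.8})

# Adjust the first LoRA's scale, toggle it, then remove it
requests.post(f"{BASE}/api/lora/update-scale", json={"index": 0, "scale": 1.0})
requests.post(f"{BASE}/api/lora/update-enabled", json={"index": 0, "enabled": False})
requests.post(f"{BASE}/api/lora/remove", json={"index": 0})
```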
+ + + diff --git a/demo/realtime-img2img/frontend/src/routes/+page.svelte b/demo/realtime-img2img/frontend/src/routes/+page.svelte index 33d451f9..7f3f0597 100644 --- a/demo/realtime-img2img/frontend/src/routes/+page.svelte +++ b/demo/realtime-img2img/frontend/src/routes/+page.svelte @@ -8,6 +8,7 @@ import PipelineOptions from '$lib/components/PipelineOptions.svelte'; import ControlNetConfig from '$lib/components/ControlNetConfig.svelte'; import IPAdapterConfig from '$lib/components/IPAdapterConfig.svelte'; + import LoRAConfig from '$lib/components/LoRAConfig.svelte'; import BlendingControl from '$lib/components/BlendingControl.svelte'; import PipelineHooksConfig from '$lib/components/PipelineHooksConfig.svelte'; import ResolutionPicker from '$lib/components/ResolutionPicker.svelte'; @@ -25,6 +26,7 @@ let pipelineInfo: PipelineInfo; let controlnetInfo: any = null; let ipadapterInfo: any = null; + let loraInfo: any = null; let imagePreprocessingInfo: any = null; let imagePostprocessingInfo: any = null; let latentPreprocessingInfo: any = null; @@ -131,6 +133,7 @@ controlnetInfo = settings.controlnet || null; ipadapterInfo = settings.ipadapter || null; + loraInfo = settings.lora || null; // Load pipeline hooks info try { @@ -1025,6 +1028,11 @@ currentWeightType={ipadapterWeightType} > + + 0 self.has_ipadapter = 'ipadapters' in self.config and len(self.config['ipadapters']) > 0 + self.has_lora = 'loras' in self.config and len(self.config['loras']) > 0 diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index d67c352b..9d1f615e 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -502,6 +502,9 @@ async def settings(): # Add IPAdapter information ipadapter_info = self._get_ipadapter_info() + # Add LoRA information + lora_info = self._get_lora_info() + # Include config prompt if available, otherwise use default config_prompt = None if self.uploaded_controlnet_config and 'prompt' in self.uploaded_controlnet_config: @@ -636,6 +639,7 @@ async def settings(): "pipeline_active": bool(self.pipeline) and hasattr(self.pipeline, 'stream'), "controlnet": controlnet_info, "ipadapter": ipadapter_info, + "lora": lora_info, "config_prompt": config_prompt, "t_index_list": current_t_index_list, "acceleration": current_acceleration, @@ -1272,6 +1276,252 @@ async def update_ipadapter_weight_type(request: Request): logging.error(f"update_ipadapter_weight_type: Failed to update weight type: {e}") raise HTTPException(status_code=500, detail=f"Failed to update weight type: {str(e)}") + # LoRA API endpoints + @self.app.get("/api/lora/list") + async def get_lora_list(): + """Get list of loaded LoRAs""" + try: + return JSONResponse({"lora": self._get_lora_info()}) + except Exception as e: + logging.error(f"get_lora_list: Failed to get LoRA list: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get LoRA list: {str(e)}") + + @self.app.post("/api/lora/add") + async def add_lora(request: Request): + """Add a LoRA by path or HuggingFace model ID""" + try: + data = await request.json() + lora_path = data.get("lora_path") + scale = data.get("scale", 1.0) + + if not lora_path: + raise HTTPException(status_code=400, detail="Missing lora_path parameter") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Add LoRA using update_stream_params + lora_config = [{ + "lora_path": lora_path, + "scale": scale, + "enabled": True + }] + + # Get current LoRA config and append new one + try: + current_state = 
self.pipeline.stream.get_stream_state() + current_lora_config = current_state.get('lora_config', []) + updated_lora_config = current_lora_config + lora_config + + self.pipeline.update_stream_params(lora_config=updated_lora_config) + success = True + except Exception as e: + logger.error(f"add_lora: Failed to update stream params: {e}") + success = False + + if success: + return JSONResponse({ + "status": "success", + "message": f"Added LoRA: {lora_path}", + "lora_info": self._get_lora_info() + }) + else: + raise HTTPException(status_code=500, detail="Failed to add LoRA") + + except Exception as e: + logging.error(f"add_lora: Failed to add LoRA: {e}") + raise HTTPException(status_code=500, detail=f"Failed to add LoRA: {str(e)}") + + @self.app.post("/api/lora/remove") + async def remove_lora(request: Request): + """Remove a LoRA by index""" + try: + data = await request.json() + index = data.get("index") + + if index is None: + raise HTTPException(status_code=400, detail="Missing index parameter") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Remove LoRA using update_stream_params + try: + current_state = self.pipeline.stream.get_stream_state() + current_lora_config = current_state.get('lora_config', []) + + if index < 0 or index >= len(current_lora_config): + raise HTTPException(status_code=400, detail=f"Invalid LoRA index: {index}") + + # Remove LoRA at index + updated_lora_config = current_lora_config[:index] + current_lora_config[index+1:] + + self.pipeline.update_stream_params(lora_config=updated_lora_config) + success = True + except Exception as e: + logger.error(f"remove_lora: Failed to update stream params: {e}") + success = False + + if success: + return JSONResponse({ + "status": "success", + "message": f"Removed LoRA at index {index}", + "lora_info": self._get_lora_info() + }) + else: + raise HTTPException(status_code=500, detail="Failed to remove LoRA") + + except Exception as e: + logging.error(f"remove_lora: Failed to remove LoRA: {e}") + raise HTTPException(status_code=500, detail=f"Failed to remove LoRA: {str(e)}") + + @self.app.post("/api/lora/update-scale") + async def update_lora_scale(request: Request): + """Update LoRA scale in real-time""" + try: + data = await request.json() + index = data.get("index") + scale = data.get("scale") + + if index is None or scale is None: + raise HTTPException(status_code=400, detail="Missing index or scale parameter") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Update LoRA scale using update_stream_params + try: + current_state = self.pipeline.stream.get_stream_state() + current_lora_config = current_state.get('lora_config', []) + + if index < 0 or index >= len(current_lora_config): + raise HTTPException(status_code=400, detail=f"Invalid LoRA index: {index}") + + # Update scale at index + updated_lora_config = current_lora_config.copy() + updated_lora_config[index] = {**updated_lora_config[index], "scale": scale} + + self.pipeline.update_stream_params(lora_config=updated_lora_config) + success = True + except Exception as e: + logger.error(f"update_lora_scale: Failed to update stream params: {e}") + success = False + + if success: + return JSONResponse({ + "status": "success", + "message": f"Updated LoRA scale at index {index} to {scale}" + }) + else: + raise HTTPException(status_code=500, detail="Failed to update LoRA scale") + + except Exception as e: + logging.error(f"update_lora_scale: Failed to update LoRA 
scale: {e}") + raise HTTPException(status_code=500, detail=f"Failed to update LoRA scale: {str(e)}") + + @self.app.post("/api/lora/update-enabled") + async def update_lora_enabled(request: Request): + """Update LoRA enabled state in real-time""" + try: + data = await request.json() + index = data.get("index") + enabled = data.get("enabled") + + if index is None or enabled is None: + raise HTTPException(status_code=400, detail="Missing index or enabled parameter") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Update LoRA enabled state using update_stream_params + try: + current_state = self.pipeline.stream.get_stream_state() + current_lora_config = current_state.get('lora_config', []) + + if index < 0 or index >= len(current_lora_config): + raise HTTPException(status_code=400, detail=f"Invalid LoRA index: {index}") + + # Update enabled state at index + updated_lora_config = current_lora_config.copy() + updated_lora_config[index] = {**updated_lora_config[index], "enabled": enabled} + + self.pipeline.update_stream_params(lora_config=updated_lora_config) + success = True + except Exception as e: + logger.error(f"update_lora_enabled: Failed to update stream params: {e}") + success = False + + if success: + return JSONResponse({ + "status": "success", + "message": f"{'Enabled' if enabled else 'Disabled'} LoRA at index {index}" + }) + else: + raise HTTPException(status_code=500, detail="Failed to update LoRA enabled state") + + except Exception as e: + logging.error(f"update_lora_enabled: Failed to update LoRA enabled state: {e}") + raise HTTPException(status_code=500, detail=f"Failed to update LoRA enabled state: {str(e)}") + + @self.app.post("/api/lora/upload") + async def upload_lora(file: UploadFile = File(...)): + """Upload a LoRA file""" + try: + # Validate file type + if not file.filename or not (file.filename.endswith('.safetensors') or file.filename.endswith('.bin')): + raise HTTPException(status_code=400, detail="File must be a .safetensors or .bin file") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Create uploads directory if it doesn't exist + uploads_dir = Path(__file__).parent / "uploads" / "loras" + uploads_dir.mkdir(parents=True, exist_ok=True) + + # Save uploaded file + file_path = uploads_dir / file.filename + content = await file.read() + + with open(file_path, 'wb') as f: + f.write(content) + + # Add LoRA using update_stream_params + lora_config = [{ + "lora_path": str(file_path), + "scale": 1.0, + "enabled": True, + "display_name": file.filename + }] + + try: + current_state = self.pipeline.stream.get_stream_state() + current_lora_config = current_state.get('lora_config', []) + updated_lora_config = current_lora_config + lora_config + + self.pipeline.update_stream_params(lora_config=updated_lora_config) + success = True + except Exception as e: + logger.error(f"upload_lora: Failed to update stream params: {e}") + success = False + + if success: + return JSONResponse({ + "status": "success", + "message": f"Uploaded and added LoRA: {file.filename}", + "lora_info": self._get_lora_info() + }) + else: + # Clean up file if adding failed + try: + file_path.unlink() + except: + pass + raise HTTPException(status_code=500, detail="Failed to add uploaded LoRA") + + except Exception as e: + logging.error(f"upload_lora: Failed to upload LoRA: {e}") + raise HTTPException(status_code=500, detail=f"Failed to upload LoRA: {str(e)}") + @self.app.post("/api/params") 
+
+        @self.app.post("/api/params")
+        async def update_params(request: Request):
+            """Update multiple streaming parameters in a single unified call"""
@@ -2707,6 +2957,78 @@ def _get_ipadapter_info(self):
 
         return ipadapter_info
 
+    def _get_lora_info(self):
+        """Get LoRA information from uploaded config or active pipeline"""
+        lora_info = {
+            "enabled": False,
+            "config_loaded": False,
+            "loras": []
+        }
+
+        def _to_display(index, lora):
+            """Normalize a LoRA config dict into the display format used by the UI"""
+            return {
+                "index": index,
+                "lora_path": lora.get('lora_path', ''),
+                "scale": lora.get('scale', 1.0),
+                "enabled": lora.get('enabled', True),
+                "lora_type": lora.get('lora_type'),
+                "display_name": lora.get('display_name'),
+                "description": lora.get('description'),
+                "adapter_name": lora.get('adapter_name')
+            }
+
+        # Check uploaded config first
+        if self.uploaded_controlnet_config:
+            if self.uploaded_controlnet_config.get('loras'):
+                lora_info["enabled"] = True
+                lora_info["config_loaded"] = True
+                lora_info["loras"] = [
+                    _to_display(i, lora_config)
+                    for i, lora_config in enumerate(self.uploaded_controlnet_config['loras'])
+                ]
+
+        # Otherwise check active pipeline config
+        elif (self.pipeline and self.pipeline.use_config and self.pipeline.config
+              and self.pipeline.config.get('loras')):
+            lora_info["enabled"] = True
+            lora_info["config_loaded"] = True
+            lora_info["loras"] = [
+                _to_display(i, lora_config)
+                for i, lora_config in enumerate(self.pipeline.config['loras'])
+            ]
+
+        # Prefer the live stream state when available (overrides the config-derived list)
+        if self.pipeline and hasattr(self.pipeline.stream, 'get_stream_state'):
+            try:
+                current_state = self.pipeline.stream.get_stream_state()
+                current_lora_config = current_state.get('lora_config', [])
+                if current_lora_config:
+                    lora_info["enabled"] = True
+                    lora_info["loras"] = [
+                        _to_display(i, lora) for i, lora in enumerate(current_lora_config)
+                    ]
+            except Exception as e:
+                logger.warning(f"_get_lora_info: Failed to get current LoRA state: {e}")
+
+        return lora_info
+
     def _get_hook_module(self, hook_type: str):
         """Get the hook module for a specific hook type using the proper pattern"""
         if not self.pipeline:
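
For reference, the dictionary `_get_lora_info` assembles (and which the endpoints above echo back as `lora_info`) has the shape sketched below; the values are hypothetical, chosen only to illustrate the fields.

```python
# Illustrative _get_lora_info() payload; values are made up for the example.
lora_info = {
    "enabled": True,
    "config_loaded": False,   # True only when the LoRAs came from a config file
    "loras": [
        {
            "index": 0,
            "lora_path": "uploads/loras/style.safetensors",
            "scale": 0.8,
            "enabled": True,
            "lora_type": None,        # filled in once the module auto-detects it
            "display_name": "style.safetensors",
            "description": None,
            "adapter_name": None,
        }
    ],
}
```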