diff --git a/configs/prompt_blending_demo.yaml.example b/configs/prompt_blending_demo.yaml.example deleted file mode 100644 index 1b076ffc..00000000 --- a/configs/prompt_blending_demo.yaml.example +++ /dev/null @@ -1,54 +0,0 @@ -# StreamDiffusion Configuration for Prompt & Seed Blending Demo -# Simple img2img setup without ControlNets for demonstrating prompt and seed blending - -model_id: "KBlueLeaf/kohaku-v2.1" -t_index_list: [16, 32] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" -mode: img2img - -# Generation parameters (base - will be overridden by blending) -prompt: "a waifu girl cute" -negative_prompt: "blurry, low quality, ugly" -guidance_scale: 1.2 -num_inference_steps: 50 - -# StreamDiffusion parameters -use_denoising_batch: true -delta: 0.7 -frame_buffer_size: 1 - -# Pipeline configuration -pipeline_type: "sd1.5" -use_lcm_lora: true -use_tiny_vae: true -acceleration: "xformers" # Use xformers instead of tensorrt for easier setup -cfg_type: "self" -seed: 42 - -# Warmup iterations for performance -warmup: 5 - -# Prompt blending configuration -# This will override the single 'prompt' above -# prompt_blending: -# prompt_list: -# - ["a waifu girl cute", 1.0] -# - ["a demon from hell", 0.0] -# interpolation_method: "slerp" # or "linear" -# enable_caching: true - -# Seed blending configuration -# This enables blending between different noise patterns -# for added visual variety alongside prompt blending -seed_blending: - seed_list: - - [42, 1.0] # Stable, controlled generation - - [999, 0.0] # More chaotic, varied generation - interpolation_method: "linear" # or "slerp" - enable_caching: true - -# No ControlNets for this demo -# controlnets: [] \ No newline at end of file diff --git a/configs/sd15_canny_depth.yaml.example b/configs/sd15_canny_depth.yaml.example deleted file mode 100644 index 300bfb85..00000000 --- a/configs/sd15_canny_depth.yaml.example +++ /dev/null @@ -1,45 +0,0 @@ -model_id: "stabilityai/sd-turbo" -t_index_list: [0,16] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "an anime render of a girl with purple hair, masterpiece" -negative_prompt: "blurry, low quality, flat, 2d" -guidance_scale: 1.1 -num_inference_steps: 50 - -# Temporal consistency parameters -frame_buffer_size: 1 -delta: 0.7 - -# Advanced parameters -use_lcm_lora: false -use_tiny_vae: true -acceleration: "tensorrt" -cfg_type: "self" -seed: 789 - -# ControlNet configuration with TensorRT Depth Anything -controlnets: - - model_id: "thibaud/controlnet-sd21-depth-diffusers" - conditioning_scale: 0.5 - preprocessor: "depth_tensorrt" - preprocessor_params: - engine_path: "C:\\_dev\\comfy\\ComfyUI\\models\\tensorrt\\depth-anything\\depth_anything_vits14-fp16.engine" - detect_resolution: 518 - image_resolution: 512 - enabled: true - - - model_id: "thibaud/controlnet-sd21-canny-diffusers" - conditioning_scale: 0.5 - preprocessor: "canny" - preprocessor_params: - low_threshold: 50 - high_threshold: 100 - control_image_path: null - enabled: true - control_guidance_start: 0.0 - control_guidance_end: 1.0 \ No newline at end of file diff --git a/configs/sd15_depth_trt_example.yaml.example b/configs/sd15_depth_trt_example.yaml.example deleted file mode 100644 index 7ebcdd7b..00000000 --- a/configs/sd15_depth_trt_example.yaml.example +++ /dev/null @@ -1,42 +0,0 @@ -model_id: "KBlueLeaf/kohaku-v2.1" -t_index_list: [16, 32] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "an anime render of a girl with purple hair, 
masterpiece" -negative_prompt: "blurry, low quality, flat, 2d" -guidance_scale: 1.1 -num_inference_steps: 50 - -use_denoising_batch: true -delta: 0.7 -frame_buffer_size: 1 - -# Advanced parameters -use_lcm_lora: true -use_tiny_vae: true -acceleration: "xformers" -cfg_type: "self" -seed: 789 - -# ControlNet configuration with TensorRT Depth Anything -controlnets: - - model_id: "lllyasviel/control_v11f1p_sd15_depth" - conditioning_scale: 0.28 - preprocessor: "depth_tensorrt" - preprocessor_params: - engine_path: "C:\\_dev\\comfy\\ComfyUI\\models\\tensorrt\\depth-anything\\v2_depth_anything_v2_vits-fp16.engine" - detect_resolution: 518 - image_resolution: 512 - enabled: true - - - model_id: "lllyasviel/control_v11p_sd15_canny" - conditioning_scale: 0.29 - preprocessor: "canny" - preprocessor_params: - low_threshold: 100 - high_threshold: 200 - enabled: true \ No newline at end of file diff --git a/configs/sd15_multicontrol.yaml.example b/configs/sd15_multicontrol.yaml.example new file mode 100644 index 00000000..e95e55ef --- /dev/null +++ b/configs/sd15_multicontrol.yaml.example @@ -0,0 +1,74 @@ +# StreamDiffusion SD1.5 Multi-ControlNet + IPAdapter Configuration +# Demonstrates: TensorRT depth processing, tile with feedback, and IPAdapter integration + +# Base model configuration (use HuggingFace model or local path) +model_id: "KBlueLeaf/kohaku-v2.1" +# model_id: "C:\\_dev\\models\\your_sd15_model.safetensors" + +# StreamDiffusion core parameters +t_index_list: [16, 32] # Denoising timesteps - lower values = less denoising +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +# Generation parameters +# prompt: "masterpiece, high quality, detailed, cinematic lighting" # Overridden by prompt_blending below + +# Prompt blending configuration - interpolates between multiple prompts +prompt_blending: + prompt_list: + - ["masterpiece, studio ghibli style, detailed anime artwork", 1.0] + - ["cyberpunk aesthetic, neon lights, futuristic", 0.3] + interpolation_method: "slerp" # or "linear" + enable_caching: true + +negative_prompt: "blurry, low quality, distorted, 3d render" +guidance_scale: 1.1 +num_inference_steps: 50 +seed: 789 + +# Temporal consistency and optimization +frame_buffer_size: 1 +delta: 0.7 +use_denoising_batch: true +use_lcm_lora: true +use_tiny_vae: true +acceleration: "tensorrt" # "xformers" for non-TensorRT setups +cfg_type: "self" + +# Engine directory for TensorRT (engines will be built here if not found) +engine_dir: "./engines/sd15" + +# Enable multi-modal conditioning +use_controlnet: true +use_ipadapter: true + +# IPAdapter configuration for style conditioning +ipadapters: + - ipadapter_model_path: "h94/IP-Adapter/models/ip-adapter_sd15.safetensors" + image_encoder_path: "h94/IP-Adapter/models/image_encoder" + # style_image: "path/to/your/style/image.jpg" # Optional: specify style image + scale: 0.7 + enabled: true + +# ControlNet configurations +controlnets: + # TensorRT Depth ControlNet (requires TensorRT engine) + - model_id: "lllyasviel/control_v11f1p_sd15_depth" + conditioning_scale: 0.3 + preprocessor: "depth_tensorrt" + preprocessor_params: + engine_path: "C:\\_dev\\models\\tensorrt\\depth_anything_v2_vits-fp16.engine" # REQUIRED: Path to TensorRT engine + detect_resolution: 518 # Must match engine input size + image_resolution: 512 + enabled: true + + # Tile ControlNet with feedback processor for temporal consistency + - model_id: "lllyasviel/control_v11f1e_sd15_tile" + conditioning_scale: 0.2 + preprocessor: "feedback" + preprocessor_params: + 
image_resolution: 512 + feedback_strength: 0.15 # Controls temporal feedback intensity + enabled: true \ No newline at end of file diff --git a/configs/sd15_tile.yaml.example b/configs/sd15_tile.yaml.example deleted file mode 100644 index 50a8e002..00000000 --- a/configs/sd15_tile.yaml.example +++ /dev/null @@ -1,28 +0,0 @@ -model_id: "KBlueLeaf/kohaku-v2.1" -t_index_list: [0,16] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "masterpiece, high quality, detailed" -negative_prompt: "blurry, low quality, distorted" -guidance_scale: 1.1 -num_inference_steps: 50 - -# Advanced parameters -use_lcm_lora: true -use_tiny_vae: true -acceleration: "xformers" -cfg_type: "self" -seed: 456 - -# ControlNet configuration with Tile ControlNet -controlnets: - - model_id: "lllyasviel/control_v11f1e_sd15_tile" - conditioning_scale: 0.2 - preprocessor: "passthrough" - preprocessor_params: - image_resolution: 512 - enabled: true \ No newline at end of file diff --git a/configs/sdturbo_color.yaml.example b/configs/sdturbo_color.yaml.example deleted file mode 100644 index 10873743..00000000 --- a/configs/sdturbo_color.yaml.example +++ /dev/null @@ -1,33 +0,0 @@ -model_id: "stabilityai/sd-turbo" -t_index_list: [0,16] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "an anime render of a girl with purple hair, masterpiece" -negative_prompt: "blurry, low quality, flat, 2d" -guidance_scale: 1.1 -num_inference_steps: 50 - -# Temporal consistency parameters -frame_buffer_size: 1 -use_denoising_batch: true -delta: 0.7 - -# Advanced parameters -use_lcm_lora: true -use_tiny_vae: true -acceleration: "tensorrt" -cfg_type: "self" -seed: 789 - -# ControlNet configuration with TensorRT Depth Anything -controlnets: - - model_id: "thibaud/controlnet-sd21-color-diffusers" - conditioning_scale: 0.2 - preprocessor: "passthrough" - preprocessor_params: - image_resolution: 512 - enabled: true \ No newline at end of file diff --git a/configs/sdturbo_mediapipe_pose_depth_trt.yaml.example b/configs/sdturbo_mediapipe_pose_depth_trt.yaml.example deleted file mode 100644 index dd8167c1..00000000 --- a/configs/sdturbo_mediapipe_pose_depth_trt.yaml.example +++ /dev/null @@ -1,51 +0,0 @@ -model_id: "stabilityai/sd-turbo" -t_index_list: [16,32] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "an anime render of a girl with purple hair, masterpiece" -negative_prompt: "blurry, low quality, flat, 2d" -guidance_scale: 1.041 -num_inference_steps: 50 - -# Temporal consistency parameters -frame_buffer_size: 1 -use_denoising_batch: true -delta: 0.7 - -# Advanced parameters -use_lcm_lora: true -use_tiny_vae: true -acceleration: "tensorrt" -cfg_type: "self" -seed: 789 - -# ControlNet configuration with TensorRT Depth Anything -controlnets: - - model_id: "thibaud/controlnet-sd21-openpose-diffusers" - conditioning_scale: 0.711 - preprocessor: "mediapipe_pose" - preprocessor_params: - detect_resolution: 512 - image_resolution: 512 - min_detection_confidence: 0.5 - min_tracking_confidence: 0.5 - model_complexity: 1 - static_image_mode: false - draw_hands: true - draw_face: false - line_thickness: 2 - circle_radius: 3 - enabled: true - - - model_id: "thibaud/controlnet-sd21-depth-diffusers" - conditioning_scale: 0.15 - preprocessor: "depth_tensorrt" - preprocessor_params: - engine_path: "C:\\_dev\\comfy\\ComfyUI\\models\\tensorrt\\depth-anything\\v2_depth_anything_v2_vits-fp16.engine" - detect_resolution: 518 - 
image_resolution: 512 - enabled: true \ No newline at end of file diff --git a/configs/sdturbo_multicontrol.yaml.example b/configs/sdturbo_multicontrol.yaml.example new file mode 100644 index 00000000..5f7b8561 --- /dev/null +++ b/configs/sdturbo_multicontrol.yaml.example @@ -0,0 +1,57 @@ +# StreamDiffusion SD-Turbo Multi-ControlNet Configuration +# Demonstrates: Fast inference with multiple ControlNet guidance (no IPAdapter for speed) + +# Base model configuration +model_id: "stabilityai/sd-turbo" + +# StreamDiffusion core parameters +t_index_list: [0, 16] # SD-Turbo optimized timesteps +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +# Generation parameters +prompt: "masterpiece, high quality, detailed anime character" +negative_prompt: "blurry, low quality, distorted, 3d render" +guidance_scale: 1.0 # SD-Turbo typically uses lower guidance +num_inference_steps: 4 # SD-Turbo optimized for few steps +seed: 789 + +# Temporal consistency and optimization +frame_buffer_size: 1 +delta: 0.7 +use_denoising_batch: true +use_lcm_lora: true # SD-Turbo benefits from LCM LoRA +use_tiny_vae: true +acceleration: "tensorrt" # "xformers" for non-TensorRT setups +cfg_type: "self" + +# Engine directory for TensorRT +engine_dir: "./engines/sdturbo" + +# Enable ControlNet (no IPAdapter for maximum speed) +use_controlnet: true + +# ControlNet configurations +controlnets: + # Canny edge detection for structural guidance + - model_id: "thibaud/controlnet-sd21-canny-diffusers" + conditioning_scale: 0.5 + preprocessor: "canny" + preprocessor_params: + low_threshold: 100 + high_threshold: 200 + enabled: true + + # Soft edge detection for artistic guidance + - model_id: "thibaud/controlnet-sd21-hed-diffusers" + conditioning_scale: 0.3 + preprocessor: "soft_edge" + preprocessor_params: + image_resolution: 512 + strength: 1.0 + soft_threshold: 0.5 + multi_scale: true + gaussian_sigma: 1.0 + enabled: true \ No newline at end of file diff --git a/configs/sdturbo_yolonas_pose_depth_trt.yaml.example b/configs/sdturbo_yolonas_pose_depth_trt.yaml.example deleted file mode 100644 index f818b09d..00000000 --- a/configs/sdturbo_yolonas_pose_depth_trt.yaml.example +++ /dev/null @@ -1,44 +0,0 @@ -model_id: "stabilityai/sd-turbo" -t_index_list: [16,32] -width: 512 -height: 512 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "an anime render of a girl with purple hair, masterpiece" -negative_prompt: "blurry, low quality, flat, 2d" -guidance_scale: 1.1 -num_inference_steps: 50 - -# Temporal consistency parameters -frame_buffer_size: 1 -use_denoising_batch: true -delta: 0.7 - -# Advanced parameters -use_lcm_lora: true -use_tiny_vae: true -acceleration: "tensorrt" -cfg_type: "self" -seed: 789 - -# ControlNet configuration with TensorRT Depth Anything -controlnets: - - model_id: "thibaud/controlnet-sd21-openpose-diffusers" - conditioning_scale: 0.5 - preprocessor: "pose_tensorrt" - preprocessor_params: - engine_path: "C:\\_dev\\comfy\\ComfyUI\\models\\tensorrt\\yolo-nas-pose\\yolo_nas_pose_l_0.8-fp16.engine" - detect_resolution: 640 - image_resolution: 512 - enabled: true - - - model_id: "thibaud/controlnet-sd21-depth-diffusers" - conditioning_scale: 0.5 - preprocessor: "depth_tensorrt" - preprocessor_params: - engine_path: "C:\\_dev\\comfy\\ComfyUI\\models\\tensorrt\\depth-anything\\v2_depth_anything_v2_vits-fp16.engine" - detect_resolution: 518 - image_resolution: 512 - enabled: true \ No newline at end of file diff --git a/configs/sdxl_multicontrol.yaml.example 
b/configs/sdxl_multicontrol.yaml.example new file mode 100644 index 00000000..441acce4 --- /dev/null +++ b/configs/sdxl_multicontrol.yaml.example @@ -0,0 +1,75 @@ +# StreamDiffusion SDXL Multi-ControlNet + IPAdapter Configuration +# Demonstrates: TensorRT pose processing, canny edge detection, and SDXL IPAdapter integration + +# Base model configuration +model_id: "stabilityai/stable-diffusion-xl-base-1.0" +# model_id: "C:\\_dev\\models\\your_sdxl_model.safetensors" + +# StreamDiffusion core parameters +t_index_list: [20, 32] # SDXL optimized timesteps +width: 1024 # SDXL native resolution +height: 1024 +device: "cuda" +dtype: "float16" + +# Generation parameters +prompt: "masterpiece, highest quality, cinematic lighting, detailed artwork" + +# Seed blending configuration - interpolates between multiple seeds for variation +seed_blending: + seed_list: + - [42, 1.0] # Primary seed with full weight + - [789, 0.4] # Secondary seed with partial weight + - [1337, 0.2] # Tertiary seed with low weight + +negative_prompt: "blurry, low quality, distorted, 3d render, oversaturated" +guidance_scale: 1.0 # SDXL typically uses moderate guidance +num_inference_steps: 25 # SDXL optimized step count +seed: 42 # Base seed (used with seed_blending above) + +# Temporal consistency and optimization +frame_buffer_size: 1 +delta: 0.7 +use_denoising_batch: true +use_lcm_lora: false # SDXL has built-in optimizations +use_taesd: true # Use Tiny AutoEncoder for SDXL +use_tiny_vae: true +acceleration: "tensorrt" # "xformers" for non-TensorRT setups +cfg_type: "self" +safety_checker: false + +# Engine directory for TensorRT +engine_dir: "./engines/sdxl" + +# Enable multi-modal conditioning +use_controlnet: true +use_ipadapter: true + +# IPAdapter configuration for SDXL style conditioning +ipadapters: + - ipadapter_model_path: "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.safetensors" + image_encoder_path: "h94/IP-Adapter/sdxl_models/image_encoder" + # style_image: "path/to/your/style/image.jpg" # Optional: specify style image + scale: 0.7 + enabled: true + +# ControlNet configurations +controlnets: + # TensorRT Pose ControlNet (requires TensorRT engine) + - model_id: "thibaud/controlnet-openpose-sdxl-1.0" + conditioning_scale: 0.4 + preprocessor: "pose_tensorrt" + preprocessor_params: + engine_path: "C:\\_dev\\models\\tensorrt\\yolo_nas_pose_l_0.8-fp16.engine" # REQUIRED: Path to TensorRT engine + detect_resolution: 640 # Must match engine input size + image_resolution: 512 + enabled: true + + # Canny ControlNet for structural guidance + - model_id: "diffusers/controlnet-canny-sdxl-1.0" + conditioning_scale: 0.3 + preprocessor: "canny" + preprocessor_params: + low_threshold: 100 + high_threshold: 200 + enabled: true \ No newline at end of file diff --git a/configs/sdxlturbo_canny.yaml.example b/configs/sdxlturbo_canny.yaml.example deleted file mode 100644 index afba6d53..00000000 --- a/configs/sdxlturbo_canny.yaml.example +++ /dev/null @@ -1,37 +0,0 @@ -model_id: "stabilityai/sdxl-turbo" -t_index_list: [32] # Controls denoising strength - lower values = less denoising, higher = more denoising -width: 1024 -height: 1024 -device: "cuda" -dtype: "float16" - -# Generation parameters -prompt: "a beautiful artwork, highly detailed, masterpiece, cinematic lighting" -negative_prompt: "blurry, low quality, distorted, render, 3D, oversaturated" -guidance_scale: 0.0 # SD-XL Turbo typically uses no guidance -num_inference_steps: 2 # SD-XL Turbo typically uses 2-4 steps -# Note: strength parameter removed - use t_index_list above 
to control denoising amount - -# SD-XL Turbo optimizations -use_taesd: true # Use Tiny AutoEncoder XL for faster decoding -safety_checker: false - -# Advanced parameters -use_lcm_lora: false # SD-XL Turbo already has optimizations built-in -use_tiny_vae: true -acceleration: "none" # Can be "tensorrt" for additional speed -cfg_type: "none" -seed: 42 - -# ControlNet configurations -controlnets: - - model_id: "diffusers/controlnet-canny-sdxl-1.0" # SD-XL ControlNet - conditioning_scale: 0.5 - preprocessor: "canny" - preprocessor_params: - low_threshold: 100 - high_threshold: 200 - control_image_path: null - enabled: true - control_guidance_start: 0.0 - control_guidance_end: 1.0 \ No newline at end of file diff --git a/demo/realtime-img2img/frontend/src/lib/components/IPAdapterConfig.svelte b/demo/realtime-img2img/frontend/src/lib/components/IPAdapterConfig.svelte index 47e6a8fb..0c7de05a 100644 --- a/demo/realtime-img2img/frontend/src/lib/components/IPAdapterConfig.svelte +++ b/demo/realtime-img2img/frontend/src/lib/components/IPAdapterConfig.svelte @@ -4,6 +4,7 @@ export let ipadapterInfo: any = null; export let currentScale: number = 1.0; + export let currentWeightType: string = "linear"; const dispatch = createEventDispatcher(); @@ -16,6 +17,14 @@ // Collapsible section state let showIPAdapter: boolean = true; + // Available weight types + const weightTypes = [ + "linear", "ease in", "ease out", "ease in-out", "reverse in-out", + "weak input", "weak output", "weak middle", "strong middle", + "style transfer", "composition", "strong style transfer", + "style and composition", "style transfer precise", "composition precise" + ]; + async function updateIPAdapterScale(scale: number) { try { const response = await fetch('/api/ipadapter/update-scale', { @@ -47,6 +56,37 @@ updateIPAdapterScale(scale); } + async function updateIPAdapterWeightType(weightType: string) { + try { + const response = await fetch('/api/ipadapter/update-weight-type', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + weight_type: weightType, + }), + }); + + if (!response.ok) { + const result = await response.json(); + console.error('updateIPAdapterWeightType: Failed to update weight type:', result.detail); + } + } catch (error) { + console.error('updateIPAdapterWeightType: Update failed:', error); + } + } + + function handleWeightTypeChange(event: Event) { + const target = event.target as HTMLSelectElement; + const weightType = target.value; + + // Update local state immediately for responsiveness + currentWeightType = weightType; + + updateIPAdapterWeightType(weightType); + } + async function uploadStyleImage() { if (!styleImageFile.files || styleImageFile.files.length === 0) { uploadStatus = 'Please select an image file'; @@ -104,10 +144,13 @@ styleImageFile.click(); } - // Update current scale when prop changes + // Update current scale and weight type when prop changes $: if (ipadapterInfo?.scale !== undefined) { currentScale = ipadapterInfo.scale; } + $: if (ipadapterInfo?.weight_type !== undefined) { + currentWeightType = ipadapterInfo.weight_type; + }
@@ -224,6 +267,25 @@
 
+          <div>
+            <label for="ipadapter-weight-type">Weight Type</label>
+            <select
+              id="ipadapter-weight-type"
+              value={currentWeightType}
+              on:change={handleWeightTypeChange}
+            >
+              {#each weightTypes as weightType}
+                <option value={weightType}>{weightType}</option>
+              {/each}
+            </select>
+            <p>
+              Controls how the IPAdapter influence is distributed across different layers of the model.
+            </p>
+          </div>
+
         {#if ipadapterInfo?.model_path}
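The component above drives the new `/api/ipadapter/update-weight-type` route with a JSON body of the form `{"weight_type": ...}`, and the backend route added later in this diff replies with `status`/`message` fields. A minimal sketch of exercising the same route from a script (the base URL is a placeholder for wherever the realtime-img2img demo server is running, and an IPAdapter-enabled config is assumed to be loaded):

```python
# Hypothetical smoke test for the new weight-type endpoint; BASE_URL is a placeholder.
import requests

BASE_URL = "http://localhost:7860"  # assumption: point this at the running demo server

# Same payload shape that handleWeightTypeChange() sends from the Svelte component.
resp = requests.post(
    f"{BASE_URL}/api/ipadapter/update-weight-type",
    json={"weight_type": "style transfer"},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "success", "message": "Updated IPAdapter weight type to style transfer"}
```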
diff --git a/demo/realtime-img2img/frontend/src/lib/components/InputControl.svelte b/demo/realtime-img2img/frontend/src/lib/components/InputControl.svelte index 5a043307..b0718c8d 100644 --- a/demo/realtime-img2img/frontend/src/lib/components/InputControl.svelte +++ b/demo/realtime-img2img/frontend/src/lib/components/InputControl.svelte @@ -159,6 +159,14 @@ max: 2.0, category: 'ipadapter' }); + + parameters.push({ + value: 'ipadapter_weight_type', + label: 'IPAdapter Weight Type', + min: 0, + max: 14, // 15 weight types (0-14) + category: 'ipadapter' + }); } // ControlNet strength parameters @@ -435,6 +443,23 @@ await updatePromptWeightParameter(control, scaledValue); } else if (control.parameter_name.startsWith('seed_weight_')) { await updateSeedWeightParameter(control, scaledValue); + } else if (control.parameter_name === 'ipadapter_weight_type') { + // Convert numeric value to weight type string + const weightTypes = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out", + "weak input", "weak output", "weak middle", "strong middle", + "style transfer", "composition", "strong style transfer", + "style and composition", "style transfer precise", "composition precise"]; + const index = Math.round(scaledValue) % weightTypes.length; + const weightType = weightTypes[index]; + + const endpoint = getParameterUpdateEndpoint(control.parameter_name); + if (endpoint) { + await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ weight_type: weightType }) + }); + } } else { const endpoint = getParameterUpdateEndpoint(control.parameter_name); if (endpoint) { @@ -574,7 +599,8 @@ 'delta': '/api/update-delta', 'num_inference_steps': '/api/update-num-inference-steps', 'seed': '/api/update-seed', - 'ipadapter_scale': '/api/ipadapter/update-scale' + 'ipadapter_scale': '/api/ipadapter/update-scale', + 'ipadapter_weight_type': '/api/ipadapter/update-weight-type' }; return endpoints[parameterName] || null; } @@ -603,7 +629,8 @@ 'delta': 'delta', 'num_inference_steps': 'num_inference_steps', 'seed': 'seed', - 'ipadapter_scale': 'scale' + 'ipadapter_scale': 'scale', + 'ipadapter_weight_type': 'weight_type' }; return keys[parameterName] || parameterName; } @@ -702,6 +729,23 @@ await updatePromptWeightParameter(control, control.pendingValue); } else if (control.parameter_name.startsWith('seed_weight_')) { await updateSeedWeightParameter(control, control.pendingValue); + } else if (control.parameter_name === 'ipadapter_weight_type') { + // Convert numeric value to weight type string + const weightTypes = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out", + "weak input", "weak output", "weak middle", "strong middle", + "style transfer", "composition", "strong style transfer", + "style and composition", "style transfer precise", "composition precise"]; + const index = Math.round(control.pendingValue) % weightTypes.length; + const weightType = weightTypes[index]; + + const endpoint = getParameterUpdateEndpoint(control.parameter_name); + if (endpoint) { + await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ weight_type: weightType }) + }); + } } else { const endpoint = getParameterUpdateEndpoint(control.parameter_name); if (endpoint) { @@ -744,6 +788,23 @@ await updatePromptWeightParameter(control, control.pendingValue); } else if (control.parameter_name.startsWith('seed_weight_')) { await updateSeedWeightParameter(control, control.pendingValue); + } else if 
(control.parameter_name === 'ipadapter_weight_type') { + // Convert numeric value to weight type string + const weightTypes = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out", + "weak input", "weak output", "weak middle", "strong middle", + "style transfer", "composition", "strong style transfer", + "style and composition", "style transfer precise", "composition precise"]; + const index = Math.round(control.pendingValue) % weightTypes.length; + const weightType = weightTypes[index]; + + const endpoint = getParameterUpdateEndpoint(control.parameter_name); + if (endpoint) { + await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ weight_type: weightType }) + }); + } } else { const endpoint = getParameterUpdateEndpoint(control.parameter_name); if (endpoint) { diff --git a/demo/realtime-img2img/frontend/src/routes/+page.svelte b/demo/realtime-img2img/frontend/src/routes/+page.svelte index fa1945c8..e5f8c520 100644 --- a/demo/realtime-img2img/frontend/src/routes/+page.svelte +++ b/demo/realtime-img2img/frontend/src/routes/+page.svelte @@ -24,6 +24,7 @@ let controlnetInfo: any = null; let ipadapterInfo: any = null; let ipadapterScale: number = 1.0; + let ipadapterWeightType: string = "linear"; let tIndexList: number[] = [35, 45]; let guidanceScale: number = 1.1; let delta: number = 0.7; @@ -107,6 +108,7 @@ controlnetInfo = settings.controlnet || null; ipadapterInfo = settings.ipadapter || null; ipadapterScale = settings.ipadapter?.scale || 1.0; + ipadapterWeightType = settings.ipadapter?.weight_type || "linear"; tIndexList = settings.t_index_list || [35, 45]; guidanceScale = settings.guidance_scale || 1.1; delta = settings.delta || 0.7; @@ -202,7 +204,7 @@ async function handleResolutionUpdate(resolution: string) { try { - const response = await fetch('/api/update-resolution', { + const response = await fetch('/api/params', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -212,11 +214,11 @@ if (response.ok) { const result = await response.json(); - console.log('Resolution updated successfully:', result.detail); + console.log('handleResolutionUpdate: Resolution updated successfully:', result.message); // Show success message - no restart needed for real-time updates - if (result.detail) { - warningMessage = result.detail; + if (result.message) { + warningMessage = result.message; // Clear message after a few seconds setTimeout(() => { warningMessage = ''; @@ -224,11 +226,11 @@ } } else { const result = await response.json(); - console.error('Failed to update resolution:', result.detail); + console.error('handleResolutionUpdate: Failed to update resolution:', result.detail); warningMessage = 'Failed to update resolution: ' + result.detail; } } catch (error: unknown) { - console.error('Failed to update resolution:', error); + console.error('handleResolutionUpdate: Failed to update resolution:', error); warningMessage = 'Failed to update resolution: ' + (error instanceof Error ? error.message : String(error)); } } @@ -669,6 +671,7 @@ {/if}
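The `ipadapter_weight_type` slider added in InputControl.svelte above exposes the weight type as a numeric parameter (0-14) and converts it to a name with `Math.round(...) % weightTypes.length` before posting. A short sketch of the same index-to-name mapping, useful when scripting against the parameter API (names copied verbatim from the component; the modulo wrap mirrors the frontend behaviour):

```python
# Mirror of the numeric-to-string mapping used by the ipadapter_weight_type slider.
WEIGHT_TYPES = [
    "linear", "ease in", "ease out", "ease in-out", "reverse in-out",
    "weak input", "weak output", "weak middle", "strong middle",
    "style transfer", "composition", "strong style transfer",
    "style and composition", "style transfer precise", "composition precise",
]

def slider_value_to_weight_type(value: float) -> str:
    """Round the slider value and wrap it into the 15-entry weight-type table."""
    return WEIGHT_TYPES[round(value) % len(WEIGHT_TYPES)]

assert slider_value_to_weight_type(9.2) == "style transfer"
assert slider_value_to_weight_type(15) == "linear"  # wraps around, like the frontend code
```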
diff --git a/demo/realtime-img2img/img2img.py b/demo/realtime-img2img/img2img.py index 8d772b74..6281e119 100644 --- a/demo/realtime-img2img/img2img.py +++ b/demo/realtime-img2img/img2img.py @@ -216,8 +216,16 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: if self.pipeline_mode == "txt2img": # Text-to-image mode if self.has_controlnet: - # txt2img with ControlNets: need image for control - self.stream.update_control_image_efficient(params.image) + # txt2img with ControlNets: push control image via consolidated API + try: + current_cfg = self.stream.stream._param_updater._get_current_controlnet_config() if hasattr(self.stream, 'stream') else [] + except Exception: + current_cfg = [] + if current_cfg: + # update just the control image for all configured CNs + for i in range(len(current_cfg)): + current_cfg[i]['control_image'] = params.image + self.stream.update_stream_params(controlnet_config=current_cfg) output_image = self.stream(params.image) elif self.has_ipadapter: # txt2img with IPAdapter: no input image needed (style image handled separately) @@ -228,8 +236,15 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: else: # Image-to-image mode: use original logic if self.has_controlnet: - # ControlNet mode: update control image and use PIL image - self.stream.update_control_image_efficient(params.image) + # ControlNet mode: push control image via consolidated API and use PIL image + try: + current_cfg = self.stream.stream._param_updater._get_current_controlnet_config() if hasattr(self.stream, 'stream') else [] + except Exception: + current_cfg = [] + if current_cfg: + for i in range(len(current_cfg)): + current_cfg[i]['control_image'] = params.image + self.stream.update_stream_params(controlnet_config=current_cfg) output_image = self.stream(params.image) elif self.has_ipadapter: # IPAdapter mode: use PIL image for img2img @@ -285,6 +300,24 @@ def update_ipadapter_style_image(self, style_image: Image.Image) -> bool: """Legacy method - use update_ipadapter_config instead""" return self.update_ipadapter_config(style_image=style_image) + def update_ipadapter_weight_type(self, weight_type: str) -> bool: + """Update IPAdapter weight type in real-time""" + if not self.has_ipadapter: + return False + + try: + # Use unified updater on wrapper + if hasattr(self.stream, 'update_stream_params'): + self.stream.update_stream_params(ipadapter_config={ 'weight_type': weight_type }) + return True + # Direct attribute set as last resort + if hasattr(self.stream, 'ipadapter_weight_type'): + self.stream.ipadapter_weight_type = weight_type + return True + return False + except Exception as e: + return False + def get_ipadapter_info(self) -> dict: """ Get current IPAdapter information @@ -295,6 +328,7 @@ def get_ipadapter_info(self) -> dict: info = { "enabled": self.has_ipadapter, "scale": 1.0, + "weight_type": "linear", "model_path": None, "style_image_set": False } @@ -304,15 +338,19 @@ def get_ipadapter_info(self) -> dict: if len(self.config['ipadapters']) > 0: ipadapter_config = self.config['ipadapters'][0] info["scale"] = ipadapter_config.get('scale', 1.0) + info["weight_type"] = ipadapter_config.get('weight_type', 'linear') info["model_path"] = ipadapter_config.get('ipadapter_model_path') info["style_image_set"] = 'style_image' in ipadapter_config - # Try to get current scale from stream if available + # Try to get current scale and weight type from stream if available if hasattr(self.stream, 'scale'): info["scale"] = self.stream.scale elif hasattr(self.stream, 
'ipadapter') and hasattr(self.stream.ipadapter, 'scale'): info["scale"] = self.stream.ipadapter.scale + if hasattr(self.stream, 'ipadapter_weight_type'): + info["weight_type"] = self.stream.ipadapter_weight_type + return info def update_stream_params(self, **kwargs): diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index df7a8465..c0cc9843 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -136,15 +136,23 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N # Map parameter names to pipeline update methods if parameter_name == 'guidance_scale': - self.pipeline.stream.update_stream_params(guidance_scale=value) + self.pipeline.update_stream_params(guidance_scale=value) elif parameter_name == 'delta': - self.pipeline.stream.update_stream_params(delta=value) + self.pipeline.update_stream_params(delta=value) elif parameter_name == 'num_inference_steps': - self.pipeline.stream.update_stream_params(num_inference_steps=int(value)) + self.pipeline.update_stream_params(num_inference_steps=int(value)) elif parameter_name == 'seed': - self.pipeline.stream.update_stream_params(seed=int(value)) + self.pipeline.update_stream_params(seed=int(value)) elif parameter_name == 'ipadapter_scale': - self.pipeline.stream.update_stream_params(ipadapter_scale=value) + self.pipeline.update_stream_params(ipadapter_config={'scale': value}) + elif parameter_name == 'ipadapter_weight_type': + # For weight type, we need to convert the numeric value to a string + weight_types = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out", + "weak input", "weak output", "weak middle", "strong middle", + "style transfer", "composition", "strong style transfer", + "style and composition", "style transfer precise", "composition precise"] + index = int(value) % len(weight_types) + self.pipeline.update_ipadapter_weight_type(weight_types[index]) elif parameter_name.startswith('controlnet_') and parameter_name.endswith('_strength'): # Handle ControlNet strength parameters import re @@ -155,8 +163,8 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N current_config = self._get_current_controlnet_config() if current_config and index < len(current_config): current_config[index]['conditioning_scale'] = float(value) - # Apply the updated config - self.pipeline.stream.apply_controlnet_config(current_config) + # Apply the updated config via unified API + self.pipeline.update_stream_params(controlnet_config=current_config) elif parameter_name.startswith('controlnet_') and '_preprocessor_' in parameter_name: # Handle ControlNet preprocessor parameters match = re.match(r'controlnet_(\d+)_preprocessor_(.+)', parameter_name) @@ -176,27 +184,25 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N match = re.match(r'prompt_weight_(\d+)', parameter_name) if match: index = int(match.group(1)) - # Get current prompt list and update specific weight - current_prompts = self.pipeline.get_current_prompts() + # Get current prompt list from unified state and update specific weight + state = self.pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list', []) if current_prompts and index < len(current_prompts): - # Create updated prompt list with new weight - updated_prompts = current_prompts.copy() - updated_prompts[index] = (updated_prompts[index][0], value) - # Update prompt list with new weights - self.pipeline.update_prompt_weights([weight for _, weight in updated_prompts]) + 
updated_prompts = list(current_prompts) + updated_prompts[index] = (updated_prompts[index][0], float(value)) + self.pipeline.update_stream_params(prompt_list=updated_prompts) elif parameter_name.startswith('seed_weight_'): # Handle seed blending weights match = re.match(r'seed_weight_(\d+)', parameter_name) if match: index = int(match.group(1)) - # Get current seed list and update specific weight - current_seeds = self.pipeline.get_current_seeds() + # Get current seed list from unified state and update specific weight + state = self.pipeline.stream.get_stream_state() + current_seeds = state.get('seed_list', []) if current_seeds and index < len(current_seeds): - # Create updated seed list with new weight - updated_seeds = current_seeds.copy() - updated_seeds[index] = (updated_seeds[index][0], value) - # Update seed list with new weights - self.pipeline.update_seed_weights([weight for _, weight in updated_seeds]) + updated_seeds = list(current_seeds) + updated_seeds[index] = (updated_seeds[index][0], float(value)) + self.pipeline.update_stream_params(seed_list=updated_seeds) else: logger.warning(f"_handle_input_parameter_update: Unknown parameter {parameter_name}") @@ -215,7 +221,7 @@ def _get_controlnet_pipeline(self): stream = self.pipeline.stream - # Check if stream is ControlNet pipeline directly + # Module-aware: module installs expose preprocessors on stream if hasattr(stream, 'preprocessors'): return stream @@ -223,6 +229,9 @@ def _get_controlnet_pipeline(self): if hasattr(stream, 'stream') and hasattr(stream.stream, 'preprocessors'): return stream.stream + # New module path on stream + if hasattr(stream, '_controlnet_module'): + return stream._controlnet_module return None def _get_current_controlnet_config(self): @@ -358,6 +367,27 @@ async def stream(user_id: uuid.UUID, request: Request): logger.info("stream: Creating default pipeline...") self.pipeline = self._create_default_pipeline() logger.info("stream: Pipeline created successfully") + try: + acc = getattr(self.args, 'acceleration', None) + logger.debug(f"stream: acceleration={acc}, use_config={getattr(self.pipeline, 'use_config', False)}") + stream_obj = getattr(self.pipeline, 'stream', None) + unet_obj = getattr(stream_obj, 'unet', None) + is_trt = unet_obj is not None and hasattr(unet_obj, 'engine') and hasattr(unet_obj, 'stream') + logger.debug(f"stream: unet_is_trt={is_trt}, has_ipadapter={getattr(self.pipeline, 'has_ipadapter', False)}") + if is_trt: + logger.debug(f"stream: unet.use_ipadapter={getattr(unet_obj, 'use_ipadapter', None)}, num_ip_layers={getattr(unet_obj, 'num_ip_layers', None)}") + if hasattr(stream_obj, 'ipadapter_scale'): + try: + scale_val = getattr(stream_obj, 'ipadapter_scale') + if hasattr(scale_val, 'shape'): + logger.debug(f"stream: ipadapter_scale tensor shape={tuple(scale_val.shape)}") + else: + logger.debug(f"stream: ipadapter_scale scalar={scale_val}") + except Exception: + pass + logger.debug(f"stream: ipadapter_weight_type={getattr(stream_obj, 'ipadapter_weight_type', None)}") + except Exception: + logger.exception("stream: failed to log pipeline state after creation") # Recreate pipeline if config changed (but not resolution - that's handled separately) elif self.config_needs_reload or (self.uploaded_controlnet_config and not (self.pipeline.use_config and self.pipeline.config and 'controlnets' in self.pipeline.config)) or (self.uploaded_controlnet_config and not self.pipeline.use_config): @@ -394,8 +424,34 @@ async def generate(): continue try: + try: + stream_obj = getattr(self.pipeline, 
'stream', None) + unet_obj = getattr(stream_obj, 'unet', None) + is_trt = unet_obj is not None and hasattr(unet_obj, 'engine') and hasattr(unet_obj, 'stream') + logger.debug(f"generate: calling predict; acceleration={getattr(self.args, 'acceleration', None)}, is_trt={is_trt}, mode={getattr(self.pipeline, 'pipeline_mode', None)}, has_ipadapter={getattr(self.pipeline, 'has_ipadapter', False)}, has_controlnet={(self.pipeline.use_config and self.pipeline.config and 'controlnets' in self.pipeline.config) if getattr(self.pipeline, 'use_config', False) else False}") + img = getattr(params, 'image', None) + if isinstance(img, torch.Tensor): + logger.debug(f"generate: params.image tensor shape={tuple(img.shape)}, dtype={img.dtype}") + else: + logger.debug(f"generate: params.image type={type(img).__name__}") + if is_trt: + logger.debug(f"generate: unet.use_ipadapter={getattr(unet_obj, 'use_ipadapter', None)}, num_ip_layers={getattr(unet_obj, 'num_ip_layers', None)}") + try: + base_scale = getattr(stream_obj, 'ipadapter_scale', None) + if base_scale is not None: + if hasattr(base_scale, 'shape'): + logger.debug(f"generate: base ipadapter_scale shape={tuple(base_scale.shape)}") + else: + logger.debug(f"generate: base ipadapter_scale scalar={base_scale}") + logger.debug(f"generate: ipadapter_weight_type={getattr(stream_obj, 'ipadapter_weight_type', None)}") + except Exception: + pass + except Exception: + logger.exception("generate: pre-predict logging failed") + image = self.pipeline.predict(params) if image is None: + logger.error("generate: predict returned None image; skipping frame") continue # Use appropriate frame conversion based on output type @@ -404,6 +460,7 @@ async def generate(): else: frame = pil_to_frame(image) except Exception as e: + logger.exception(f"generate: predict failed with exception: {e}") continue # Update FPS counter @@ -480,14 +537,12 @@ async def settings(): current_num_inference_steps = DEFAULT_SETTINGS.get('num_inference_steps', 50) current_seed = DEFAULT_SETTINGS.get('seed', 2) - if self.pipeline: - current_guidance_scale = getattr(self.pipeline.stream, 'guidance_scale', DEFAULT_SETTINGS.get('guidance_scale', 1.1)) - current_delta = getattr(self.pipeline.stream, 'delta', DEFAULT_SETTINGS.get('delta', 0.7)) - current_num_inference_steps = getattr(self.pipeline.stream, 'num_inference_steps', DEFAULT_SETTINGS.get('num_inference_steps', 50)) - # Get seed from generator if available - if hasattr(self.pipeline.stream, 'generator') and self.pipeline.stream.generator is not None: - # We can't directly get seed from generator, but we'll use the configured value - current_seed = getattr(self.pipeline.stream, 'current_seed', DEFAULT_SETTINGS.get('seed', 2)) + if self.pipeline and hasattr(self.pipeline.stream, 'get_stream_state'): + state = self.pipeline.stream.get_stream_state() + current_guidance_scale = state.get('guidance_scale', DEFAULT_SETTINGS.get('guidance_scale', 1.1)) + current_delta = state.get('delta', DEFAULT_SETTINGS.get('delta', 0.7)) + current_num_inference_steps = state.get('num_inference_steps', DEFAULT_SETTINGS.get('num_inference_steps', 50)) + current_seed = state.get('current_seed', DEFAULT_SETTINGS.get('seed', 2)) elif self.uploaded_controlnet_config: current_guidance_scale = self.uploaded_controlnet_config.get('guidance_scale', DEFAULT_SETTINGS.get('guidance_scale', 1.1)) current_delta = self.uploaded_controlnet_config.get('delta', DEFAULT_SETTINGS.get('delta', 0.7)) @@ -499,20 +554,14 @@ async def settings(): seed_blending_config = None # First try to get 
from current pipeline if available - if self.pipeline: - try: - current_prompts = self.pipeline.stream.get_current_prompts() - if current_prompts and len(current_prompts) > 0: - prompt_blending_config = current_prompts - except: - pass - - try: - current_seeds = self.pipeline.stream.get_current_seeds() - if current_seeds and len(current_seeds) > 0: - seed_blending_config = current_seeds - except: - pass + if self.pipeline and hasattr(self.pipeline.stream, 'get_stream_state'): + state = self.pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list', []) + current_seeds = state.get('seed_list', []) + if current_prompts: + prompt_blending_config = current_prompts + if current_seeds: + seed_blending_config = current_seeds # If not available from pipeline, get from uploaded config and normalize if not prompt_blending_config: @@ -525,9 +574,10 @@ async def settings(): normalize_prompt_weights = True # default normalize_seed_weights = True # default - if self.pipeline: - normalize_prompt_weights = self.pipeline.stream.get_normalize_prompt_weights() - normalize_seed_weights = self.pipeline.stream.get_normalize_seed_weights() + if self.pipeline and hasattr(self.pipeline.stream, 'get_stream_state'): + state = self.pipeline.stream.get_stream_state() + normalize_prompt_weights = state.get('normalize_prompt_weights', True) + normalize_seed_weights = state.get('normalize_seed_weights', True) elif self.uploaded_controlnet_config: normalize_prompt_weights = self.uploaded_controlnet_config.get('normalize_weights', True) normalize_seed_weights = self.uploaded_controlnet_config.get('normalize_weights', True) @@ -674,51 +724,28 @@ async def get_controlnet_info(): async def get_current_blending_config(): """Get current prompt and seed blending configurations""" try: - # Get normalized configurations (same logic as settings endpoint) - prompt_blending_config = None - seed_blending_config = None - - # First try to get from current pipeline if available - if self.pipeline: - try: - current_prompts = self.pipeline.stream.get_current_prompts() - if current_prompts and len(current_prompts) > 0: - prompt_blending_config = current_prompts - except Exception: - pass - - try: - current_seeds = self.pipeline.stream.get_current_seeds() - if current_seeds and len(current_seeds) > 0: - seed_blending_config = current_seeds - except: - pass - - # If not available from pipeline, get from uploaded config and normalize - if not prompt_blending_config: - prompt_blending_config = self._normalize_prompt_config(self.uploaded_controlnet_config) - - if not seed_blending_config: - seed_blending_config = self._normalize_seed_config(self.uploaded_controlnet_config) - - # Get normalization settings - normalize_prompt_weights = True - normalize_seed_weights = True - - if self.pipeline: - normalize_prompt_weights = self.pipeline.stream.get_normalize_prompt_weights() - normalize_seed_weights = self.pipeline.stream.get_normalize_seed_weights() - elif self.uploaded_controlnet_config: - normalize_prompt_weights = self.uploaded_controlnet_config.get('normalize_weights', True) - normalize_seed_weights = self.uploaded_controlnet_config.get('normalize_weights', True) - + if self.pipeline and hasattr(self.pipeline, 'stream') and hasattr(self.pipeline.stream, 'get_stream_state'): + state = self.pipeline.stream.get_stream_state(include_caches=False) + return JSONResponse({ + "prompt_blending": state.get("prompt_list", []), + "seed_blending": state.get("seed_list", []), + "normalize_prompt_weights": 
state.get("normalize_prompt_weights", True), + "normalize_seed_weights": state.get("normalize_seed_weights", True), + "has_config": self.uploaded_controlnet_config is not None, + "pipeline_active": True + }) + + # Fallback to uploaded config normalization when pipeline not initialized + prompt_blending_config = self._normalize_prompt_config(self.uploaded_controlnet_config) + seed_blending_config = self._normalize_seed_config(self.uploaded_controlnet_config) + normalize_weights = self.uploaded_controlnet_config.get('normalize_weights', True) if self.uploaded_controlnet_config else True return JSONResponse({ "prompt_blending": prompt_blending_config, "seed_blending": seed_blending_config, - "normalize_prompt_weights": normalize_prompt_weights, - "normalize_seed_weights": normalize_seed_weights, + "normalize_prompt_weights": normalize_weights, + "normalize_seed_weights": normalize_weights, "has_config": self.uploaded_controlnet_config is not None, - "pipeline_active": self.pipeline is not None + "pipeline_active": False }) except Exception as e: @@ -988,20 +1015,20 @@ async def upload_style_image(file: UploadFile = File(...)): # Read file content content = await file.read() - # Save temporarily and load as PIL Image - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp: - tmp.write(content) - tmp_path = tmp.name - + tmp_path = None try: + with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp: + tmp.write(content) + tmp_path = tmp.name + # Load and validate image from PIL import Image style_image = Image.open(tmp_path).convert("RGB") - + # Store the uploaded style image persistently FIRST self.uploaded_style_image = style_image print(f"upload_style_image: Stored style image with size: {style_image.size}") - + # If pipeline exists and has IPAdapter, update it immediately pipeline_updated = False if self.pipeline and getattr(self.pipeline, 'has_ipadapter', False): @@ -1010,10 +1037,11 @@ async def upload_style_image(file: UploadFile = File(...)): if success: pipeline_updated = True print("upload_style_image: Successfully applied to existing pipeline") - + # Force prompt re-encoding to apply new style image embeddings try: - current_prompts = self.pipeline.stream.get_current_prompts() + state = self.pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list', []) if current_prompts: print("upload_style_image: Forcing prompt re-encoding to apply new style image") self.pipeline.stream.update_prompt(current_prompts, prompt_interpolation_method="slerp") @@ -1026,7 +1054,7 @@ async def upload_style_image(file: UploadFile = File(...)): print(f"upload_style_image: Pipeline exists but has_ipadapter={getattr(self.pipeline, 'has_ipadapter', False)}") else: print("upload_style_image: No pipeline exists yet") - + # Return success message = "Style image uploaded successfully" if pipeline_updated: @@ -1038,13 +1066,12 @@ async def upload_style_image(file: UploadFile = File(...)): "status": "success", "message": message }) - finally: - # Clean up temp file - try: - os.unlink(tmp_path) - except: - pass + if tmp_path: + try: + os.unlink(tmp_path) + except: + pass except HTTPException: raise @@ -1129,6 +1156,42 @@ async def update_ipadapter_scale(request: Request): logging.error(f"update_ipadapter_scale: Failed to update scale: {e}") raise HTTPException(status_code=500, detail=f"Failed to update scale: {str(e)}") + @self.app.post("/api/ipadapter/update-weight-type") + async def update_ipadapter_weight_type(request: Request): + """Update IPAdapter weight type in 
real-time""" + try: + data = await request.json() + weight_type = data.get("weight_type") + + if weight_type is None: + raise HTTPException(status_code=400, detail="Missing weight_type parameter") + + if not self.pipeline: + raise HTTPException(status_code=400, detail="Pipeline is not initialized") + + # Check if we're using config mode and have ipadapters configured + ipadapter_enabled = (self.pipeline.use_config and + self.pipeline.config and + 'ipadapters' in self.pipeline.config) + + if not ipadapter_enabled: + raise HTTPException(status_code=400, detail="IPAdapter is not enabled") + + # Update IPAdapter weight type in the pipeline + success = self.pipeline.update_ipadapter_weight_type(weight_type) + + if success: + return JSONResponse({ + "status": "success", + "message": f"Updated IPAdapter weight type to {weight_type}" + }) + else: + raise HTTPException(status_code=500, detail="Failed to update weight type in pipeline") + + except Exception as e: + logging.error(f"update_ipadapter_weight_type: Failed to update weight type: {e}") + raise HTTPException(status_code=500, detail=f"Failed to update weight type: {str(e)}") + @self.app.post("/api/params") async def update_params(request: Request): """Update multiple streaming parameters in a single unified call""" @@ -1228,7 +1291,7 @@ async def update_guidance_scale(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - self.pipeline.stream.update_stream_params(guidance_scale=guidance_scale) + self.pipeline.update_stream_params(guidance_scale=guidance_scale) return JSONResponse({ "status": "success", @@ -1250,7 +1313,7 @@ async def update_delta(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - self.pipeline.stream.update_stream_params(delta=delta) + self.pipeline.update_stream_params(delta=delta) return JSONResponse({ "status": "success", @@ -1272,7 +1335,7 @@ async def update_num_inference_steps(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - self.pipeline.stream.update_stream_params(num_inference_steps=num_inference_steps) + self.pipeline.update_stream_params(num_inference_steps=num_inference_steps) return JSONResponse({ "status": "success", @@ -1294,7 +1357,7 @@ async def update_seed(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - self.pipeline.stream.update_stream_params(seed=seed) + self.pipeline.update_stream_params(seed=seed) return JSONResponse({ "status": "success", @@ -1459,24 +1522,37 @@ async def switch_preprocessor(request: Request): # Create new preprocessor instance from src.streamdiffusion.preprocessing.processors import get_preprocessor new_preprocessor_instance = get_preprocessor(new_preprocessor) - + + # Resolve stream object and preprocessor list regardless of module or stream facade + stream_obj = getattr(cn_pipeline, '_stream', None) + if stream_obj is None: + stream_obj = getattr(self.pipeline, 'stream', None) + if stream_obj is None: + raise HTTPException(status_code=500, detail="Pipeline stream not available") + + preproc_list = getattr(cn_pipeline, 'preprocessors', None) + if preproc_list is None: + preproc_list = getattr(stream_obj, 'preprocessors', None) + if preproc_list is None: + raise HTTPException(status_code=500, detail="ControlNet preprocessors not available") + # Set system parameters system_params = { - 'device': cn_pipeline.device, - 
'dtype': cn_pipeline.dtype, - 'image_width': cn_pipeline.stream.width, - 'image_height': cn_pipeline.stream.height, + 'device': stream_obj.device, + 'dtype': stream_obj.dtype, + 'image_width': stream_obj.width, + 'image_height': stream_obj.height, } system_params.update(preprocessor_params) new_preprocessor_instance.params.update(system_params) - + # Set pipeline reference for feedback preprocessor if hasattr(new_preprocessor_instance, 'set_pipeline_ref'): - new_preprocessor_instance.set_pipeline_ref(cn_pipeline.stream) - + new_preprocessor_instance.set_pipeline_ref(stream_obj) + # Replace the preprocessor - old_preprocessor = cn_pipeline.preprocessors[controlnet_index] - cn_pipeline.preprocessors[controlnet_index] = new_preprocessor_instance + old_preprocessor = preproc_list[controlnet_index] + preproc_list[controlnet_index] = new_preprocessor_instance logger.info(f"switch_preprocessor: Successfully switched ControlNet {controlnet_index} from {type(old_preprocessor).__name__ if old_preprocessor else 'None'} to {type(new_preprocessor_instance).__name__}") @@ -1508,22 +1584,25 @@ async def update_preprocessor_params(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - # Update preprocessor parameters using consolidated API - current_config = self._get_current_controlnet_config() - - if not current_config: - raise HTTPException(status_code=400, detail="No ControlNet configuration available") - - if controlnet_index >= len(current_config): - raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range (max: {len(current_config)-1})") - - # Update preprocessor_params for the specified controlnet - if 'preprocessor_params' not in current_config[controlnet_index]: - current_config[controlnet_index]['preprocessor_params'] = {} - current_config[controlnet_index]['preprocessor_params'].update(preprocessor_params) - - # Apply the updated configuration - self.pipeline.update_stream_params(controlnet_config=current_config) + # Fast path: update module preprocessor directly when available + cn_pipeline = self._get_controlnet_pipeline() + preproc_list = getattr(cn_pipeline, 'preprocessors', None) + if preproc_list is None: + raise HTTPException(status_code=400, detail="ControlNet preprocessors not available") + + if controlnet_index >= len(preproc_list): + raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range (max: {len(preproc_list)-1})") + + target_preproc = preproc_list[controlnet_index] + if target_preproc is None: + raise HTTPException(status_code=400, detail="ControlNet preprocessor is not set") + + # Merge params: update both the params map and setattr when attribute exists + if hasattr(target_preproc, 'params') and isinstance(target_preproc.params, dict): + target_preproc.params.update(preprocessor_params) + for name, value in preprocessor_params.items(): + if hasattr(target_preproc, name): + setattr(target_preproc, name, value) return JSONResponse({ "status": "success", @@ -1550,16 +1629,9 @@ async def update_prompt_weight(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - # Get current prompt blending configuration using the same logic as the blending/current endpoint - current_prompts = None - try: - current_prompts = self.pipeline.stream.get_current_prompts() - except Exception: - pass - - # If not available from pipeline, get from uploaded config and normalize - if not current_prompts: - 
current_prompts = self._normalize_prompt_config(self.uploaded_controlnet_config) + # Get current prompt blending configuration via unified getter, fallback to uploaded config + state = self.pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list') or self._normalize_prompt_config(self.uploaded_controlnet_config) if current_prompts and index < len(current_prompts): # Create updated prompt list with new weight @@ -1603,16 +1675,9 @@ async def update_seed_weight(request: Request): if not self.pipeline: raise HTTPException(status_code=400, detail="Pipeline is not initialized") - # Get current seed blending configuration using the same logic as the blending/current endpoint - current_seeds = None - try: - current_seeds = self.pipeline.stream.get_current_seeds() - except Exception: - pass - - # If not available from pipeline, get from uploaded config and normalize - if not current_seeds: - current_seeds = self._normalize_seed_config(self.uploaded_controlnet_config) + # Get current seed blending configuration via unified getter, fallback to uploaded config + state = self.pipeline.stream.get_stream_state() + current_seeds = state.get('seed_list') or self._normalize_seed_config(self.uploaded_controlnet_config) if current_seeds and index < len(current_seeds): # Create updated seed list with new weight @@ -1651,10 +1716,14 @@ async def get_current_preprocessor_params(controlnet_index: int): if not cn_pipeline: raise HTTPException(status_code=400, detail="ControlNet pipeline not found") - if controlnet_index >= len(cn_pipeline.preprocessors): + # Module-aware: allow accessing module's preprocessors list + preprocessors = getattr(cn_pipeline, 'preprocessors', None) + if preprocessors is None: + raise HTTPException(status_code=400, detail="ControlNet preprocessors not available") + if controlnet_index >= len(preprocessors): raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range") - current_preprocessor = cn_pipeline.preprocessors[controlnet_index] + current_preprocessor = preprocessors[controlnet_index] if not current_preprocessor: return JSONResponse({ "preprocessor": None, @@ -1668,7 +1737,7 @@ async def get_current_preprocessor_params(controlnet_index: int): # Extract current values, using defaults if not set current_values = {} for param_name, param_meta in user_param_meta.items(): - if param_name in current_preprocessor.params: + if hasattr(current_preprocessor, 'params') and param_name in current_preprocessor.params: current_values[param_name] = current_preprocessor.params[param_name] else: current_values[param_name] = param_meta.get("default") @@ -2007,7 +2076,8 @@ def _create_pipeline_with_config(self, controlnet_config_path=None): # Force prompt re-encoding to apply style image embeddings try: - current_prompts = new_pipeline.stream.get_current_prompts() + state = new_pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list', []) if current_prompts: print("_create_pipeline_with_config: Forcing prompt re-encoding to apply style image") new_pipeline.stream.update_prompt(current_prompts, prompt_interpolation_method="slerp") @@ -2275,7 +2345,8 @@ def _update_resolution(self, width: int, height: int) -> None: # Force prompt re-encoding to apply style image embeddings try: - current_prompts = new_pipeline.stream.get_current_prompts() + state = new_pipeline.stream.get_stream_state() + current_prompts = state.get('prompt_list', []) if current_prompts: print("_update_resolution: Forcing prompt re-encoding to apply style 
image") new_pipeline.stream.update_prompt(current_prompts, prompt_interpolation_method="slerp") diff --git a/examples/controlnet/controlnet_video_test.py b/examples/controlnet/controlnet_video_test.py index 31448fe0..22ea2607 100644 --- a/examples/controlnet/controlnet_video_test.py +++ b/examples/controlnet/controlnet_video_test.py @@ -95,8 +95,15 @@ def process_video(config_path, input_video, output_dir, engine_only=False): frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB) frame_pil = Image.fromarray(frame_rgb) - # Update control image and generate - wrapper.update_control_image_efficient(frame_pil) + # Update control image and generate via consolidated API + try: + current_cfg = wrapper.stream._param_updater._get_current_controlnet_config() if hasattr(wrapper, 'stream') else [] + except Exception: + current_cfg = [] + if current_cfg: + for i in range(len(current_cfg)): + current_cfg[i]['control_image'] = frame_pil + wrapper.update_stream_params(controlnet_config=current_cfg) output_image = wrapper(frame_pil) # Convert output to display format @@ -222,7 +229,9 @@ def main(): print("main: Video processing completed successfully!") return 0 except Exception as e: + import traceback print(f"main: Error during processing: {e}") + print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}") return 1 diff --git a/examples/controlnet/standalone_controlnet_pipeline.py b/examples/controlnet/standalone_controlnet_pipeline.py deleted file mode 100644 index dccf1164..00000000 --- a/examples/controlnet/standalone_controlnet_pipeline.py +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env python3 -""" -Standalone Multi-ControlNet StreamDiffusion Pipeline - -Self-contained script demonstrating multiple ControlNets + StreamDiffusion integration. -Shows depth + canny edge conditioning. - -Designed for reference for porting into other production systems. -No GUI, no webcam complexity - just core pipeline logic with hardcoded configs. -""" - -import sys -import os -import time -from pathlib import Path -from typing import List, Optional -import argparse - -# Add StreamDiffusion to path -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) - -import torch -from PIL import Image -from streamdiffusion import load_config, create_wrapper_from_config -# ============================================================================ -# PIPELINE IMPLEMENTATION -# ============================================================================ - - -class MultiControlNetStreamDiffusionPipeline: - """ - Multi-ControlNet StreamDiffusion pipeline. 
- """ - - def __init__(self, config_file: str): - self.config_file = config_file - self.wrapper = None - self._setup_pipeline() - - def _setup_pipeline(self): - """Initialize the StreamDiffusion pipeline from config file""" - print("Initializing Multi-ControlNet StreamDiffusion pipeline...") - print(f"Using config file: {self.config_file}") - # Load configuration and create wrapper - config_data = load_config(self.config_file) - self.wrapper = create_wrapper_from_config(config_data) - self.warmup_steps = config_data.get('warmup', 10) - self._warmed_up = False - - print("Pipeline created - warmup will occur with first input image") - - # Check TensorRT status - if hasattr(self.wrapper.stream, 'unet') and hasattr(self.wrapper.stream.unet, 'engine'): - print("TensorRT acceleration active") - else: - print("Running in PyTorch mode") - - - - def process_image(self, image: Image.Image) -> Image.Image: - """ - Process a single image through the multi-ControlNet pipeline. - - This is the core controlnet inference method that would be used in production. - The conditioning for all of the controlnets that are defined in the config will be applied automatically. - """ - # Run warmup with actual input image on first call - if not self._warmed_up and self.warmup_steps > 0: - print(f"Running {self.warmup_steps} warmup iterations with input image...") - for i in range(self.warmup_steps): - if i % 3 == 0: # Print progress every 3 steps - print(f" Warmup step {i+1}/{self.warmup_steps}") - self.wrapper.update_control_image_efficient(image) - _ = self.wrapper(image) - self._warmed_up = True - print("Warmup completed!") - - # Update control image for all ControlNets - self.wrapper.update_control_image_efficient(image) - - # Generate output with multi-ControlNet conditioning - return self.wrapper(image) - - def update_controlnet_strength(self, index: int, strength: float): - """Dynamically update ControlNet strength. This will be required for Product.""" - if hasattr(self.wrapper, 'update_controlnet_scale'): - self.wrapper.update_controlnet_scale(index, strength) - print(f"update_controlnet_strength: Updated ControlNet {index+1} strength to {strength}") - else: - print("update_controlnet_strength: Not supported for this pipeline") - - def update_stream_params(self, guidance_scale: float = None, delta: float = None, num_inference_steps: int = None): - """Dynamically update StreamDiffusion parameters during inference""" - if hasattr(self.wrapper.stream, 'update_stream_params'): - self.wrapper.stream.update_stream_params( - guidance_scale=guidance_scale, - delta=delta, - num_inference_steps=num_inference_steps - ) - else: - print("update_stream_params: Not supported for this pipeline") - - -def load_input_image(image_path: str, target_width: int, target_height: int) -> Image.Image: - """Load and prepare input image""" - print(f"Loading input image: {image_path}") - image = Image.open(image_path).convert("RGB") - - # Resize to target resolution while maintaining aspect ratio - original_size = image.size - print(f"Original size: {original_size}") - - # Resize to target width and height - # TODO: This is a hack to get the image to the correct size for tensorrt development. The resolution of the input image must match the resolution specified for the pipeline for tenosrrt acceleration to work. 
- image = image.resize((target_width, target_height), Image.Resampling.LANCZOS) - print(f"Resized to: {image.size}") - - return image - - -def setup_output_directory(): - """Create output directory next to the script""" - script_dir = Path(__file__).parent - output_dir = script_dir / "standalone_controlnet_pipeline_output" - output_dir.mkdir(exist_ok=True) - return output_dir - - -def run_demo(config_file: str, input_image_path: str, engine_only: bool = False): - """ - Demonstration of the multi-ControlNet pipeline. - Shows how depth + canny ControlNets work together. - """ - # Setup output directory - output_dir = setup_output_directory() - print(f"Output directory: {output_dir}") - - # Validate paths - if not os.path.exists(config_file): - print(f"ERROR: Config file not found at {config_file}") - return False - - if not os.path.exists(input_image_path): - print(f"ERROR: Input image not found at {input_image_path}") - return False - - try: - # Load configuration to get actual dimensions - - config_data = load_config(config_file) - target_width = config_data.get('width', 512) - target_height = config_data.get('height', 512) - - # Initialize pipeline (this will trigger engine building if needed) - pipeline = MultiControlNetStreamDiffusionPipeline(config_file) - - if engine_only: - print("Engine-only mode: TensorRT engines have been built (if needed). Exiting.") - return True - - # Load input image - input_image = load_input_image(input_image_path, target_width, target_height) - - - print("Running multi-ControlNet inference...") - start_time = time.time() - - output_image = pipeline.process_image(input_image) - - inference_time = time.time() - start_time - print(f"Inference completed in {inference_time:.2f}s") - - # Save results to output directory - timestamp = int(time.time()) - input_path = output_dir / f"input_{timestamp}.png" - output_path = output_dir / f"output_{timestamp}.png" - - input_image.save(input_path) - output_image.save(output_path) - - print(f"Results saved:") - print(f" Input: {input_path}") - print(f" Output: {output_path}") - - # Demonstrate combined parameter updates - pipeline.update_controlnet_strength(0, 0.2) # Reduce depth influence - pipeline.update_stream_params(guidance_scale=1.2, delta=1.0) # Adjust guidance and noise - - adjusted_output = pipeline.process_image(input_image) - adjusted_path = output_dir / f"output_adjusted_{timestamp}.png" - adjusted_output.save(adjusted_path) - print(f" Adjusted output: {adjusted_path}") - - return True - - except Exception as e: - print(f"ERROR: {e}") - import traceback - traceback.print_exc() - return False - - -def main(): - """Main entry point""" - parser = argparse.ArgumentParser(description="Standalone Multi-ControlNet StreamDiffusion Pipeline") - parser.add_argument("--config", type=str, required=True, help="Path to ControlNet configuration file") - parser.add_argument("--input-image", type=str, required=True, help="Path to input image file") - parser.add_argument("--engine-only", action="store_true", help="Only build TensorRT engines and exit (no inference)") - args = parser.parse_args() - - print("=" * 70) - print("Standalone Multi-ControlNet StreamDiffusion Pipeline") - print("=" * 70) - print(f"Config: {args.config}") - print(f"Input Image: {args.input_image}") - print("=" * 70) - - success = run_demo( - config_file=args.config, - input_image_path=args.input_image, - engine_only=args.engine_only - ) - - if success: - print("\n✓ Multi-ControlNet demo completed successfully!") - else: - print("\n✗ Demo failed - check 
configuration and dependencies") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/ipadapter/ipadapter_img2img_config_example.py b/examples/ipadapter/ipadapter_img2img_config_example.py deleted file mode 100644 index e0e84a06..00000000 --- a/examples/ipadapter/ipadapter_img2img_config_example.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import sys -import torch -from pathlib import Path -from PIL import Image - -# Add paths to import from parent directories -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) - -from src.streamdiffusion.config import create_wrapper_from_config, load_config - -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -CONFIG_PATH = os.path.join(CURRENT_DIR, "ipadapter_img2img_config_example.yaml") -OUTPUT_DIR = os.path.join(CURRENT_DIR, "..", "..", "output") - -def main(): - """IPAdapter img2img example using configuration system with multiple strength values.""" - - device = "cuda" if torch.cuda.is_available() else "cpu" - os.makedirs(OUTPUT_DIR, exist_ok=True) - - print(f"main: Loading img2img configuration from {CONFIG_PATH}") - - # Load configuration - config = load_config(CONFIG_PATH) - print(f"main: Device: {device}") - print(f"main: Mode: {config.get('mode', 'img2img')}") - - # Get the original scale from config - original_scale = config.get('ipadapters', [{}])[0].get('scale', 1.0) - - # Define strength values to test - strength_values = [original_scale] - - # Load input image for img2img transformation - input_image_path = os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "hand_up512.png") - - # Check if input image exists, if not use alternative paths - if not os.path.exists(input_image_path): - print(f"main: Input image not found at {input_image_path}") - print("main: Please place an input image at the specified path or update the path") - # For demonstration, try alternative paths - alt_paths = [ - os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "input.png"), - os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "style.webp"), - ] - - for alt_path in alt_paths: - if os.path.exists(alt_path): - input_image_path = alt_path - print(f"main: Using alternative input image: {input_image_path}") - break - else: - print("main: No suitable input image found. 
Please provide an input image.") - return - - try: - print(f"main: Testing IPAdapter img2img with strength values: {strength_values}") - - for i, strength in enumerate(strength_values): - print(f"\nmain: Creating pipeline {i+1}/4 with strength {strength}") - - # Create a copy of config for this strength - current_config = config.copy() - if 'ipadapters' in current_config and len(current_config['ipadapters']) > 0: - current_config['ipadapters'][0]['scale'] = strength - - # Create fresh wrapper for this strength (clean slate) - wrapper = create_wrapper_from_config(current_config, device=device) - - # Preprocess the input image - print(f"main: Loading and preprocessing input image from {input_image_path}") - input_image = wrapper.preprocess_image(input_image_path) - - print(f"main: Generating img2img with IPAdapter strength {strength}") - - # Warm up the pipeline - for _ in range(wrapper.batch_size - 1): - wrapper(image=input_image) - - # Generate final image - output_image = wrapper(image=input_image) - - # Save result with strength in filename - if strength == original_scale: - output_path = os.path.join(OUTPUT_DIR, f"ipadapter_img2img_strength_{strength:.2f}_config.png") - else: - output_path = os.path.join(OUTPUT_DIR, f"ipadapter_img2img_strength_{strength:.2f}.png") - - output_image.save(output_path) - print(f"main: IPAdapter img2img image (strength {strength}) saved to: {output_path}") - - # Clean up wrapper to ensure no state interference - del wrapper - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - print(f"\nmain: IPAdapter img2img demonstration completed successfully!") - print(f"main: Generated 4 images with different IPAdapter strengths:") - print(f"main: - 0.00: No IPAdapter influence (pure img2img)") - print(f"main: - 0.50: Balanced IPAdapter and input image") - print(f"main: - 1.00: Strong IPAdapter influence") - print(f"main: - {original_scale:.2f}: Original config value") - - except Exception as e: - print(f"main: Error - {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/ipadapter/ipadapter_img2img_config_example.yaml b/examples/ipadapter/ipadapter_img2img_config_example.yaml deleted file mode 100644 index 3fe0d80b..00000000 --- a/examples/ipadapter/ipadapter_img2img_config_example.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# StreamDiffusion IPAdapter img2img Configuration Example -# This demonstrates how to configure IPAdapter for image-to-image mode - -# Base model configuration -model_id: "C:\\_dev\\comfy\\ComfyUI\\models\\checkpoints\\perfectPhotonPerfect_perfectPhotonV21.safetensors" -device: "cuda" -dtype: "float16" -width: 512 -height: 512 -mode: "img2img" # Changed from txt2img to img2img - -# StreamDiffusion parameters -t_index_list: [25, 32, 45] # Different t_index_list for img2img (closer to img2img examples) -frame_buffer_size: 1 -warmup: 10 -acceleration: "tensorrt" -use_denoising_batch: true # img2img typically uses denoising batch -cfg_type: "self" # img2img uses "self" instead of "none" -seed: 42 - -engine_dir: "C:\\_dev\\comfy\\ComfyUI\\StreamDiffusion\\engines\\ipa" - -# Text prompts (these will be combined with IPAdapter style conditioning) -prompt: "masterpiece, high detail, 8k" -negative_prompt: "blurry, horror, worst quality, low quality" -num_inference_steps: 50 -guidance_scale: 1.0 -delta: 1 # Delta parameter for img2img - -use_controlnet: false - -# IPAdapter configuration -ipadapters: - # Explicit model paths (required) - - ipadapter_model_path: 
"h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors" - image_encoder_path: "h94/IP-Adapter/models/image_encoder" - style_image: "../../images/inputs/gold.jpg" - scale: 0.70 - enabled: true - - # For SDXL models: - # - ipadapter_model_path: "h94/IP-Adapter/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors" - # image_encoder_path: "h94/IP-Adapter/sdxl_models/image_encoder" - # style_image: "../../images/inputs/gold.jpg" - # scale: 0.70 - # enabled: true - -# python ipadapter_img2img_config_example.py --config .\ipadapter_img2img_config_example.yaml --input-image ..\..\..\input\hand_up512.png \ No newline at end of file diff --git a/setup.py b/setup.py index eb36f225..9b2a154c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,9 @@ def deps_list(*pkgs): deps["diffusers"], deps["transformers"], deps["accelerate"], - "diffusers-ipadapter @ git+https://github.com/livepeer/Diffusers_IPAdapter.git@4a61dbb452c024e2df161128595b8af88c662940", + "diffusers-ipadapter @ git+https://github.com/livepeer/Diffusers_IPAdapter.git@aabdc79e8298e7f66700e6fd15923aa9efc21cb1", + # "diffusers-ipadapter @ file:///C:/_dev/projects/StreamDiffusion/Diffusers_IPAdapter", + ] setup( diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index 8649e303..2dd23054 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -107,8 +107,8 @@ def get_engine_path(self, # Create prefix (from wrapper.py lines 1005-1013) prefix = f"{base_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch}--min_batch-{min_batch_size}" - if ipadapter_scale is not None: - prefix += f"--ipa{ipadapter_scale}" + # Do not bake IP-Adapter scale into engine name; strength is now a runtime input + # (ipadapter_scale remains a parameter for backward-compatibility but is ignored here) if ipadapter_tokens is not None: prefix += f"--tokens{ipadapter_tokens}" @@ -238,6 +238,10 @@ def _set_unet_metadata(self, loaded_engine, kwargs: Dict) -> None: if kwargs.get('use_ipadapter_trt', False): setattr(loaded_engine, 'ipadapter_arch', kwargs.get('unet_arch', {})) + # number of IP-attention layers for runtime vector sizing + if 'num_ip_layers' in kwargs and kwargs['num_ip_layers'] is not None: + setattr(loaded_engine, 'num_ip_layers', kwargs['num_ip_layers']) + def get_or_load_controlnet_engine(self, model_id: str, diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py index 65a62ede..74e5e8fa 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py @@ -3,6 +3,8 @@ from typing import Optional, Dict, Any, List from ....model_detection import detect_model, detect_model_from_diffusers_unet +from diffusers_ipadapter.ip_adapter.attention_processor import TRTIPAttnProcessor, TRTIPAttnProcessor2_0 + class IPAdapterUNetExportWrapper(torch.nn.Module): """ @@ -23,6 +25,10 @@ def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tok # Convert to float32 BEFORE installing processors (to avoid resetting them) self.unet = self.unet.to(dtype=torch.float32) + # Track installed TRT processors + self._ip_trt_processors: List[torch.nn.Module] = [] + self.num_ip_layers: int = 0 + # Check if IPAdapter processors are already 
installed (from pre-loading) if self._has_ipadapter_processors(): self._ensure_processor_dtype_consistency() @@ -51,19 +57,47 @@ def _ensure_processor_dtype_consistency(self): try: processors = self.unet.attn_processors updated_processors = {} + self._ip_trt_processors = [] + ip_layer_index = 0 for name, processor in processors.items(): processor_class = processor.__class__.__name__ - if 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: - # Convert IPAdapter processors to float32 for ONNX consistency - # This preserves the weights while updating dtype - updated_processors[name] = processor.to(dtype=torch.float32) + if 'TRTIPAttn' in processor_class: + # Already TRT processors: ensure dtype and record + proc = processor.to(dtype=torch.float32) + proc._scale_index = ip_layer_index + self._ip_trt_processors.append(proc) + ip_layer_index += 1 + updated_processors[name] = proc + elif 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: + # Replace standard processors with TRT variants, preserving weights where applicable + hidden_size = getattr(processor, 'hidden_size', None) + cross_attention_dim = getattr(processor, 'cross_attention_dim', None) + num_tokens = getattr(processor, 'num_tokens', self.num_image_tokens) + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + from diffusers_ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 + else: + from diffusers_ipadapter.ip_adapter.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor + proc = IPProcClass(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) + # Copy IP projection weights if present + if hasattr(processor, 'to_k_ip') and hasattr(processor, 'to_v_ip') and hasattr(proc, 'to_k_ip'): + with torch.no_grad(): + proc.to_k_ip.weight.copy_(processor.to_k_ip.weight.to(dtype=torch.float32)) + proc.to_v_ip.weight.copy_(processor.to_v_ip.weight.to(dtype=torch.float32)) + proc = proc.to(self.unet.device, dtype=torch.float32) + proc._scale_index = ip_layer_index + self._ip_trt_processors.append(proc) + ip_layer_index += 1 + updated_processors[name] = proc else: # Keep standard processors as-is updated_processors[name] = processor # Update all processors to ensure consistency self.unet.set_attn_processor(updated_processors) + self.num_ip_layers = len(self._ip_trt_processors) except Exception as e: print(f"IPAdapterUNetExportWrapper: Error updating processor dtypes: {e}") @@ -78,14 +112,17 @@ def _install_ipadapter_processors(self): # Import IPAdapter attention processors from installed package try: if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - from diffusers_ipadapter.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor + from diffusers_ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 else: - from diffusers_ipadapter.ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor + from diffusers_ipadapter.ip_adapter.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor # Install attention processors with proper configuration processor_names = list(self.unet.attn_processors.keys()) attn_procs = {} + ip_layer_index = 0 for name in processor_names: cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim @@ -107,14 +144,20 @@ def 
_install_ipadapter_processors(self): # Self-attention layers use standard processors attn_procs[name] = AttnProcessor() else: - # Cross-attention layers use IPAdapter processors - attn_procs[name] = IPAttnProcessor( - hidden_size=hidden_size, + # Cross-attention layers use TRTIPAttn processors (runtime scale tensor) + proc = IPProcClass( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, - num_tokens=self.num_image_tokens - ).to(self.unet.device, dtype=torch.float32) # Force float32 for ONNX + num_tokens=self.num_image_tokens, + ).to(self.unet.device, dtype=torch.float32) + # record mapping index + proc._scale_index = ip_layer_index + self._ip_trt_processors.append(proc) + ip_layer_index += 1 + attn_procs[name] = proc self.unet.set_attn_processor(attn_procs) + self.num_ip_layers = len(self._ip_trt_processors) @@ -126,7 +169,30 @@ def _install_ipadapter_processors(self): traceback.print_exc() raise e - def forward(self, sample, timestep, encoder_hidden_states): + def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None: + """Assign per-layer scale tensor to installed TRTIPAttn processors.""" + if not isinstance(ipadapter_scale, torch.Tensor): + import logging + logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + raise TypeError("ipadapter_scale must be a torch.Tensor") + if self.num_ip_layers <= 0 or not self._ip_trt_processors: + raise RuntimeError("No TRTIPAttn processors installed") + if ipadapter_scale.ndim != 1 or ipadapter_scale.shape[0] != self.num_ip_layers: + import logging + logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)") + raise ValueError(f"ipadapter_scale must have shape [{self.num_ip_layers}]") + + # Ensure float32 for ONNX export stability + scale_vec = ipadapter_scale.to(dtype=torch.float32) + try: + import logging + logging.getLogger(__name__).debug(f"IPAdapterUNetExportWrapper: scale_vec min={float(scale_vec.min().item())}, max={float(scale_vec.max().item())}") + except Exception: + pass + for proc in self._ip_trt_processors: + proc._scale_tensor = scale_vec[proc._scale_index] + + def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torch.Tensor = None): """ Forward pass with concatenated embeddings (text + image). 
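Reviewer note: a minimal usage sketch of the new per-layer scale contract introduced above. The names wrapper, sample, timestep, and encoder_hidden_states are assumed to already exist from the export setup and are illustrative only; the only contract shown in this patch is that ipadapter_scale is a 1-D float32 tensor of length num_ip_layers, one element per installed TRTIPAttn processor.

# Hypothetical sketch, assuming `wrapper` is an IPAdapterUNetExportWrapper with processors
# installed, and `sample`, `timestep`, `encoder_hidden_states` are the usual UNet inputs.
import torch

strength = 0.7  # single IP-Adapter strength expanded to one value per IP-attention layer
scale_vec = torch.full((wrapper.num_ip_layers,), strength,
                       dtype=torch.float32, device="cuda")

# Either assign the per-layer scales directly ...
wrapper.set_ipadapter_scale(scale_vec)
# ... or pass the vector through forward, which calls set_ipadapter_scale internally
# before delegating to the wrapped UNet.
noise_pred = wrapper(sample, timestep, encoder_hidden_states, ipadapter_scale=scale_vec)

At engine runtime the same vector is fed as the "ipadapter_scale" engine input (see the UNet model definition and UNet2DConditionModelEngine changes later in this diff), which is why the scale no longer needs to be baked into the engine name and can change at runtime without rebuilding engines.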
@@ -153,6 +219,11 @@ def forward(self, sample, timestep, encoder_hidden_states): # Ensure dtype consistency for ONNX export if encoder_hidden_states.dtype != torch.float32: encoder_hidden_states = encoder_hidden_states.to(torch.float32) + + # Set per-layer scale tensor + if ipadapter_scale is None: + raise RuntimeError("IPAdapterUNetExportWrapper.forward requires ipadapter_scale tensor") + self.set_ipadapter_scale(ipadapter_scale) # Pass concatenated embeddings to UNet with baked-in IPAdapter processors return self.unet( diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py index a4b4247a..a6971ba1 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py @@ -63,6 +63,28 @@ def forward(self, *control_args, **kwargs) -> torch.Tensor: """Forward pass that handles any UNet parameters via **kwargs passthrough""" + # Handle IP-Adapter runtime scale vector as a positional argument placed before control tensors + if self.use_ipadapter and self.ipadapter_wrapper is not None: + # ipadapter_scale is appended as the first extra positional input after the 3 base inputs + if len(control_args) == 0: + import logging + logging.getLogger(__name__).error("UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True") + raise RuntimeError("UnifiedExportWrapper: ipadapter_scale tensor is required when use_ipadapter=True") + ipadapter_scale = control_args[0] + if not isinstance(ipadapter_scale, torch.Tensor): + import logging + logging.getLogger(__name__).error(f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + raise TypeError("ipadapter_scale must be a torch.Tensor") + try: + import logging + logging.getLogger(__name__).debug(f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}") + except Exception: + pass + # assign per-layer scale tensors into processors + self.ipadapter_wrapper.set_ipadapter_scale(ipadapter_scale) + # remove it from control args before passing to controlnet wrapper + control_args = control_args[1:] + if self.controlnet_wrapper: # ControlNet wrapper handles the UNet call with all parameters return self.controlnet_wrapper(sample, timestep, encoder_hidden_states, *control_args, **kwargs) diff --git a/src/streamdiffusion/acceleration/tensorrt/models/models.py b/src/streamdiffusion/acceleration/tensorrt/models/models.py index 39a7fe16..6c94caee 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/models.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/models.py @@ -261,6 +261,7 @@ def __init__( image_width=512, use_ipadapter=False, num_image_tokens=4, + num_ip_layers: int = None, ): super(UNet, self).__init__( fp16=fp16, @@ -279,6 +280,7 @@ def __init__( self.unet_arch = unet_arch or {} self.use_ipadapter = use_ipadapter self.num_image_tokens = num_image_tokens + self.num_ip_layers = num_ip_layers # Baked-in IPAdapter configuration if self.use_ipadapter: @@ -287,6 +289,8 @@ def __init__( # Could use dynamic shapes: min=77 (text only), max=93 (text + 16 tokens) # This would allow a single engine to handle all IPAdapter types instead of separate engines self.text_maxlen = text_maxlen + self.num_image_tokens + if self.num_ip_layers is None: + raise ValueError("UNet model requires num_ip_layers when use_ipadapter=True") if 
self.use_control and self.unet_arch: @@ -389,6 +393,13 @@ def _add_control_inputs(self): def get_input_names(self): """Get input names including ControlNet inputs""" base_names = ["sample", "timestep", "encoder_hidden_states"] + if self.use_ipadapter: + base_names.append("ipadapter_scale") + try: + import logging + logging.getLogger(__name__).debug(f"TRT Models: get_input_names with ipadapter -> {base_names}") + except Exception: + pass if self.use_control and self.control_inputs: control_names = sorted(self.control_inputs.keys()) return base_names + control_names @@ -404,6 +415,13 @@ def get_dynamic_axes(self): "encoder_hidden_states": {0: "2B"}, "latent": {0: "2B", 2: "H", 3: "W"}, } + if self.use_ipadapter: + base_axes["ipadapter_scale"] = {0: "L_ip"} + try: + import logging + logging.getLogger(__name__).debug(f"TRT Models: dynamic axes include ipadapter_scale with L_ip={getattr(self, 'num_ip_layers', None)}") + except Exception: + pass if self.use_control and self.control_inputs: for name, shape_spec in self.control_inputs.items(): @@ -469,6 +487,18 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, (max_batch, self.text_maxlen, self.embedding_dim), ], } + if self.use_ipadapter: + # scalar per-layer vector, length fixed to num_ip_layers + profile["ipadapter_scale"] = [ + (1,), + (self.num_ip_layers,), + (self.num_ip_layers,), + ] + try: + import logging + logging.getLogger(__name__).debug(f"TRT Models: profile ipadapter_scale min/opt/max={(1,),(self.num_ip_layers,),(self.num_ip_layers,)}") + except Exception: + pass if self.use_control and self.control_inputs: # Use the actual calculated spatial dimensions for each ControlNet input @@ -507,6 +537,13 @@ def get_shape_dict(self, batch_size, image_height, image_width): "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), "latent": (2 * batch_size, 4, latent_height, latent_width), } + if self.use_ipadapter: + shape_dict["ipadapter_scale"] = (self.num_ip_layers,) + try: + import logging + logging.getLogger(__name__).debug(f"TRT Models: shape_dict ipadapter_scale={(self.num_ip_layers,)}") + except Exception: + pass if self.use_control and self.control_inputs: # Use the actual calculated spatial dimensions for each ControlNet input @@ -543,8 +580,7 @@ def get_sample_input(self, batch_size, image_height, image_width): torch.randn(2 * export_batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ] - print(f"🔧 UNet ONNX export inputs: sample={base_inputs[0].shape}, timestep={base_inputs[1].shape}, encoder_hidden_states={base_inputs[2].shape}") - print(f" embedding_dim={self.embedding_dim}, expected for model type: {'SDXL=2048' if self.embedding_dim >= 2048 else 'SD1.5=768'}") + if self.use_control and self.control_inputs: control_inputs = [] @@ -570,8 +606,13 @@ def get_sample_input(self, batch_size, image_height, image_width): if len(control_inputs) % 4 == 0: torch.cuda.empty_cache() + # Append ipadapter_scale if needed + if self.use_ipadapter: + base_inputs.append(torch.ones(self.num_ip_layers, dtype=torch.float32, device=self.device)) return tuple(base_inputs + control_inputs) + if self.use_ipadapter: + base_inputs.append(torch.ones(self.num_ip_layers, dtype=torch.float32, device=self.device)) return tuple(base_inputs) diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/controlnet_engine.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/controlnet_engine.py index 82d01027..18f6e2b3 100644 --- 
a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/controlnet_engine.py +++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/controlnet_engine.py @@ -5,6 +5,7 @@ import traceback import logging from typing import List, Optional, Tuple, Dict, Any +import threading from polygraphy import cuda from ..utilities import Engine @@ -22,6 +23,8 @@ def __init__(self, engine_path: str, stream: 'cuda.Stream', use_cuda_graph: bool self.stream = stream self.use_cuda_graph = use_cuda_graph self.model_type = model_type.lower() + # Serialize infer calls per engine/context for safety + self._infer_lock = threading.RLock() self.engine.load() self.engine.activate() @@ -108,15 +111,16 @@ def __call__(self, output_shapes = self._resolve_output_shapes(batch_size, latent_height, latent_width) shape_dict.update(output_shapes) - self.engine.allocate_buffers(shape_dict=shape_dict, device=sample.device) - - outputs = self.engine.infer( - input_dict, - self.stream, - use_cuda_graph=self.use_cuda_graph, - ) - - self.stream.synchronize() + with self._infer_lock: + self.engine.allocate_buffers(shape_dict=shape_dict, device=sample.device) + outputs = self.engine.infer( + input_dict, + self.stream, + use_cuda_graph=self.use_cuda_graph, + ) + # Synchronize to ensure outputs are ready before consumption by other streams + # This preserves correctness when UNet runs on a different CUDA stream. + self.stream.synchronize() down_blocks, mid_block = self._extract_controlnet_outputs(outputs) return down_blocks, mid_block diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py index 86d1963f..1dafd766 100644 --- a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py +++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/unet_engine.py @@ -39,11 +39,11 @@ def __call__( controlnet_conditioning: Optional[Dict[str, List[torch.Tensor]]] = None, **kwargs, ) -> Any: - - - - - + logger.debug("UNet2DConditionModelEngine.__call__: enter") + logger.debug(f"UNet2DConditionModelEngine.__call__: use_ipadapter={getattr(self, 'use_ipadapter', False)}, use_control={self.use_control}") + logger.debug(f"UNet2DConditionModelEngine.__call__: latent_model_input shape={tuple(latent_model_input.shape)}, dtype={latent_model_input.dtype}, device={latent_model_input.device}") + logger.debug(f"UNet2DConditionModelEngine.__call__: timestep shape={tuple(timestep.shape)}, dtype={timestep.dtype}") + logger.debug(f"UNet2DConditionModelEngine.__call__: encoder_hidden_states shape={tuple(encoder_hidden_states.shape)}, dtype={encoder_hidden_states.dtype}") if timestep.dtype != torch.float32: @@ -63,6 +63,22 @@ def __call__( "encoder_hidden_states": encoder_hidden_states, } + + # Handle IP-Adapter runtime scale vector if engine was built with it + if getattr(self, 'use_ipadapter', False): + if 'ipadapter_scale' not in kwargs: + logger.error("UNet2DConditionModelEngine: ipadapter_scale missing but required (use_ipadapter=True)") + raise RuntimeError("UNet2DConditionModelEngine: ipadapter_scale is required for IP-Adapter engines") + ip_scale = kwargs['ipadapter_scale'] + if not isinstance(ip_scale, torch.Tensor): + logger.error(f"UNet2DConditionModelEngine: ipadapter_scale has wrong type: {type(ip_scale)}") + raise TypeError("ipadapter_scale must be a torch.Tensor") + logger.debug(f"UNet2DConditionModelEngine.__call__: ipadapter_scale shape={tuple(ip_scale.shape)}, dtype={ip_scale.dtype}, device={ip_scale.device}, 
min={float(ip_scale.min().item()) if ip_scale.numel()>0 else 'n/a'}, max={float(ip_scale.max().item()) if ip_scale.numel()>0 else 'n/a'}") + shape_dict["ipadapter_scale"] = ip_scale.shape + input_dict["ipadapter_scale"] = ip_scale + + + # Handle ControlNet inputs if provided if controlnet_conditioning is not None: # Option 1: Direct ControlNet conditioning dict (organized by type) @@ -102,17 +118,24 @@ def __call__( allocated_before = torch.cuda.memory_allocated() / 1024**3 logger.debug(f"VRAM before allocation: {allocated_before:.2f}GB") + logger.debug(f"UNet2DConditionModelEngine.__call__: shape_dict={ {k: tuple(v) if hasattr(v,'__iter__') else v for k,v in shape_dict.items()} }") self.engine.allocate_buffers(shape_dict=shape_dict, device=latent_model_input.device) if self.debug_vram: allocated_after = torch.cuda.memory_allocated() / 1024**3 logger.debug(f"VRAM after allocation: {allocated_after:.2f}GB") - outputs = self.engine.infer( - input_dict, - self.stream, - use_cuda_graph=self.use_cuda_graph, - ) + logger.debug(f"UNet2DConditionModelEngine.__call__: input_dict keys={list(input_dict.keys())}") + try: + outputs = self.engine.infer( + input_dict, + self.stream, + use_cuda_graph=self.use_cuda_graph, + ) + except Exception as e: + logger.exception(f"UNet2DConditionModelEngine.__call__: Engine.infer failed: {e}") + raise + if self.debug_vram: allocated_final = torch.cuda.memory_allocated() / 1024**3 @@ -123,6 +146,9 @@ def __call__( raise ValueError("TensorRT engine did not produce expected 'latent' output") noise_pred = outputs["latent"] + logger.debug(f"UNet2DConditionModelEngine.__call__: output shape={tuple(noise_pred.shape)}, dtype={noise_pred.dtype}, device={noise_pred.device}") + if torch.isnan(noise_pred).any() or torch.isinf(noise_pred).any(): + logger.error("UNet2DConditionModelEngine.__call__: output contains NaN/Inf") diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 74f12931..da2def14 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -54,10 +54,6 @@ def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: wrapper_params = _extract_wrapper_params(final_config) wrapper = StreamDiffusionWrapper(**wrapper_params) - # Setup IPAdapter if configured - if 'ipadapters' in final_config and final_config['ipadapters']: - wrapper = _setup_ipadapter_from_config(wrapper, final_config) - prepare_params = _extract_prepare_params(final_config) # Handle prompt configuration with clear precedence @@ -87,7 +83,7 @@ def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: # Apply seed blending if configured and not already handled in prepare if 'seed_blending' in final_config and 'prompt_blending' not in final_config: seed_blend_config = final_config['seed_blending'] - wrapper.update_seed_blending( + wrapper.update_stream_params( seed_list=seed_blend_config.get('seed_list', []), interpolation_method=seed_blend_config.get('interpolation_method', 'linear') ) @@ -128,6 +124,9 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), 'normalize_seed_weights': config.get('normalize_seed_weights', True), 'enable_pytorch_fallback': config.get('enable_pytorch_fallback', False), + # Concurrency options + 'controlnet_max_parallel': config.get('controlnet_max_parallel'), + 'controlnet_block_add_when_parallel': config.get('controlnet_block_add_when_parallel', True), } if 'controlnets' in config and 
config['controlnets']: param_map['use_controlnet'] = True @@ -214,98 +213,6 @@ def _prepare_ipadapter_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: return ipadapter_configs -def _setup_ipadapter_from_config(wrapper, config: Dict[str, Any]): - """Setup IPAdapter pipeline from configuration""" - try: - from .ipadapter import BaseIPAdapterPipeline - - # Create pipeline - device = config.get('device', 'cuda') - dtype = _parse_dtype(config.get('dtype', 'float16')) - pipeline = BaseIPAdapterPipeline(wrapper.stream, device, dtype) - - # Handle preloaded models vs fresh setup - if _has_preloaded_models(wrapper): - _configure_preloaded_pipeline(pipeline, config) - else: - _configure_fresh_pipeline(pipeline, config) - - # Setup pipeline attributes - pipeline.batch_size = getattr(wrapper, 'batch_size', 1) - pipeline._original_wrapper = wrapper - - return pipeline - - except ImportError as e: - raise ImportError(f"_setup_ipadapter_from_config: IPAdapter module not found: {e}") from e - except Exception as e: - print(f"_setup_ipadapter_from_config: Failed to setup IPAdapter: {e}") - raise - - - - - -def _has_preloaded_models(wrapper) -> bool: - """Check if wrapper has preloaded IPAdapter models""" - return (hasattr(wrapper, 'stream') and - hasattr(wrapper.stream, '_preloaded_with_weights') and - wrapper.stream._preloaded_with_weights and - hasattr(wrapper.stream, '_preloaded_ipadapters') and - wrapper.stream._preloaded_ipadapters) - - -def _configure_preloaded_pipeline(pipeline, config: Dict[str, Any]): - """Configure pipeline using preloaded models""" - pipeline.ipadapter = pipeline.stream._preloaded_ipadapters[0] - - ipadapter_configs = _prepare_ipadapter_configs(config) - if ipadapter_configs: - ip_config = ipadapter_configs[0] - if ip_config.get('enabled', True): - _apply_ipadapter_config(pipeline, ip_config) - - # Register enhancer for TensorRT compatibility - pipeline.stream._param_updater.register_embedding_enhancer( - pipeline._enhance_embeddings_with_ipadapter, name="IPAdapter" - ) - - if len(ipadapter_configs) > 1: - print("_setup_ipadapter_from_config: WARNING - Multiple IPAdapters configured but only first one will be used") - - -def _configure_fresh_pipeline(pipeline, config: Dict[str, Any]): - """Configure pipeline with fresh IPAdapter setup""" - ipadapter_configs = _prepare_ipadapter_configs(config) - if ipadapter_configs: - ip_config = ipadapter_configs[0] - if ip_config.get('enabled', True): - pipeline.set_ipadapter( - ipadapter_model_path=ip_config['ipadapter_model_path'], - image_encoder_path=ip_config['image_encoder_path'], - style_image=ip_config.get('style_image'), - scale=ip_config.get('scale', 1.0) - ) - - if len(ipadapter_configs) > 1: - print("_setup_ipadapter_from_config: WARNING - Multiple IPAdapters configured but only first one will be used") - - -def _apply_ipadapter_config(pipeline, ip_config: Dict[str, Any]): - """Apply configuration to existing IPAdapter""" - # Set style image - style_image_path = ip_config.get('style_image') - if style_image_path: - from PIL import Image - pipeline.style_image = Image.open(style_image_path).convert("RGB") - - # Set scale - scale = ip_config.get('scale', 1.0) - pipeline.scale = scale - if pipeline.ipadapter: - pipeline.ipadapter.set_scale(scale) - - def create_prompt_blending_config( base_config: Dict[str, Any], prompt_list: List[Tuple[str, float]], diff --git a/src/streamdiffusion/controlnet/__init__.py b/src/streamdiffusion/controlnet/__init__.py deleted file mode 100644 index dbf1d606..00000000 --- 
a/src/streamdiffusion/controlnet/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base_controlnet_pipeline import BaseControlNetPipeline -from .controlnet_pipeline import ControlNetPipeline -from .controlnet_sdxlturbo_pipeline import SDXLTurboControlNetPipeline - - -__all__ = [ - "BaseControlNetPipeline", - "ControlNetPipeline", - "SDXLTurboControlNetPipeline", -] \ No newline at end of file diff --git a/src/streamdiffusion/controlnet/base_controlnet_pipeline.py b/src/streamdiffusion/controlnet/base_controlnet_pipeline.py deleted file mode 100644 index bd802577..00000000 --- a/src/streamdiffusion/controlnet/base_controlnet_pipeline.py +++ /dev/null @@ -1,741 +0,0 @@ -import torch -import traceback -import threading -import queue -from enum import Enum -from typing import List, Optional, Union, Dict, Any, Tuple -from PIL import Image -import numpy as np -from pathlib import Path -import logging - -from diffusers.models import ControlNetModel -from diffusers.utils import load_image - -from ..pipeline import StreamDiffusion -from ..preprocessing.processors import get_preprocessor -from ..preprocessing.preprocessing_orchestrator import PreprocessingOrchestrator - -# Setup logger for parallel processing -logger = logging.getLogger(__name__) - - -class ControlNetOperation(Enum): - """Types of ControlNet operations for deferred execution""" - ADD = "add" - REMOVE = "remove" - -class BaseControlNetPipeline: - """ - ControlNet-enabled StreamDiffusion pipeline with optional inter-frame pipelining. - - Supports both synchronous and pipelined preprocessing modes: - - Sync mode: Processes each frame completely before moving to the next - - Pipelined mode: Overlaps preprocessing of next frame with current frame processing - """ - - def __init__(self, - stream_diffusion: StreamDiffusion, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - use_pipelined_processing: bool = True): - """ - Initialize ControlNet pipeline. 
- - Args: - stream_diffusion: StreamDiffusion instance to wrap - device: Device to run on ("cuda" or "cpu") - dtype: Tensor dtype for processing - use_pipelined_processing: Enable inter-frame pipelining for better performance - """ - self.stream = stream_diffusion - self.device = device - self.dtype = dtype - self.model_type = getattr(self, 'model_type', 'ControlNet') # Default fallback - self.use_pipelined_processing = use_pipelined_processing - - self.controlnets: List[ControlNetModel] = [] - self.controlnet_images: List[Optional[torch.Tensor]] = [] - self.controlnet_scales: List[float] = [] - self.preprocessors: List[Optional[Any]] = [] - - self._original_unet_step = None - self._is_patched = False - self._has_feedback_preprocessor_cached = False - - # Initialize preprocessing orchestrator - self._preprocessing_orchestrator = PreprocessingOrchestrator( - device=self.device, - dtype=self.dtype, - max_workers=4 - ) - - # Non-blocking operation support - self._operation_queue = queue.Queue() - self._collections_lock = threading.RLock() - self._enable_persistent_patching = True - - # Background thread for processing operations - self._shutdown_event = threading.Event() - self._background_thread = threading.Thread(target=self._process_operations_worker, daemon=True) - self._background_thread.start() - logger.info(f"BaseControlNetPipeline: Started background operations thread") - - def add_controlnet(self, - controlnet_config: Dict[str, Any], - control_image: Optional[Union[str, Image.Image, np.ndarray, torch.Tensor]] = None, - immediate: bool = False) -> Optional[int]: - """Add a ControlNet to the pipeline - - Args: - controlnet_config: ControlNet configuration dict - control_image: Optional control image - immediate: If True, add immediately (blocking). If False, queue for deferred execution. 
- - Returns: - ControlNet index if immediate=True, None if deferred - """ - if immediate: - return self._add_controlnet_immediate(controlnet_config, control_image) - else: - operation = { - 'type': ControlNetOperation.ADD, - 'config': controlnet_config, - 'control_image': control_image - } - self._operation_queue.put(operation) - logger.info(f"add_controlnet: Queued addition of ControlNet {controlnet_config.get('model_id', 'unknown')}") - return None - - def _add_controlnet_immediate(self, - controlnet_config: Dict[str, Any], - control_image: Optional[Union[str, Image.Image, np.ndarray, torch.Tensor]] = None) -> int: - """Add a ControlNet to the pipeline immediately (for internal use during initialization)""" - if not controlnet_config.get('enabled', True): - return -1 - - # Load ControlNet model - controlnet = self._load_controlnet_model(controlnet_config['model_id']) - - # Load preprocessor if specified - preprocessor = None - if controlnet_config.get('preprocessor'): - preprocessor = get_preprocessor(controlnet_config['preprocessor']) - # Set preprocessor parameters including device, dtype, and resolution - preprocessor_params = { - 'device': self.device, - 'dtype': self.dtype, - 'image_width': self.stream.width, # Pass actual width - 'image_height': self.stream.height, # Pass actual height - } - if controlnet_config.get('preprocessor_params'): - preprocessor_params.update(controlnet_config['preprocessor_params']) - preprocessor.params.update(preprocessor_params) - # Update device and dtype directly - if hasattr(preprocessor, 'device'): - preprocessor.device = self.device - if hasattr(preprocessor, 'dtype'): - preprocessor.dtype = self.dtype - - # Set pipeline reference for feedback preprocessor - if hasattr(preprocessor, 'set_pipeline_ref'): - preprocessor.set_pipeline_ref(self.stream) - - # Process control image if provided - processed_image = None - if control_image is not None: - processed_image = self._prepare_control_image(control_image, preprocessor) - elif controlnet_config.get('control_image_path'): - # Load from configured path - control_image = load_image(controlnet_config['control_image_path']) - processed_image = self._prepare_control_image(control_image, preprocessor) - - # Thread-safe addition to collections - with self._collections_lock: - self.controlnets.append(controlnet) - self.controlnet_images.append(processed_image) - self.controlnet_scales.append(controlnet_config.get('conditioning_scale', 1.0)) - self.preprocessors.append(preprocessor) - - # Update feedback preprocessor cache - self._update_feedback_preprocessor_cache() - - # Use persistent patching - patch once and handle empty lists gracefully - if len(self.controlnets) == 1 and self._enable_persistent_patching and not self._is_patched: - self._patch_stream_diffusion() - - return len(self.controlnets) - 1 - - def remove_controlnet(self, index: int, immediate: bool = False) -> None: - """Remove a ControlNet by index - - Args: - index: Index of ControlNet to remove - immediate: If True, remove immediately (blocking). If False, queue for deferred execution. 
- """ - if immediate: - self._remove_controlnet_immediate(index) - else: - operation = { - 'type': ControlNetOperation.REMOVE, - 'index': index - } - self._operation_queue.put(operation) - logger.info(f"remove_controlnet: Queued removal of ControlNet at index {index}") - - def _remove_controlnet_immediate(self, index: int) -> None: - """Remove a ControlNet by index immediately (for internal use)""" - with self._collections_lock: - if 0 <= index < len(self.controlnets): - # Get controlnet reference for cleanup - controlnet_to_remove = self.controlnets[index] - - # Aggressive cleanup - move model to CPU and clear GPU memory - if controlnet_to_remove is not None: - try: - # Move model to CPU first to free GPU memory - if hasattr(controlnet_to_remove, 'to'): - controlnet_to_remove.to('cpu') - - # Clear module parameters - if hasattr(controlnet_to_remove, '_modules'): - for module_name, module in controlnet_to_remove._modules.items(): - if hasattr(module, 'to'): - module.to('cpu') - if hasattr(module, '_parameters'): - for param_name, param in list(module._parameters.items()): - if param is not None: - param.data = param.data.cpu() - del param - module._parameters[param_name] = None - - # Clear TensorRT engine if present - if hasattr(controlnet_to_remove, 'engine') and controlnet_to_remove.engine is not None: - del controlnet_to_remove.engine - controlnet_to_remove.engine = None - - except Exception: - pass # Continue cleanup even if some parts fail - - # Remove from collections - self.controlnets.pop(index) - self.controlnet_images.pop(index) - self.controlnet_scales.pop(index) - self.preprocessors.pop(index) - - # Update feedback preprocessor cache - self._update_feedback_preprocessor_cache() - - # With persistent patching, we don't unpatch when empty - just handle gracefully - if len(self.controlnets) == 0 and not self._enable_persistent_patching: - self._unpatch_stream_diffusion() - - # Cleanup reference and force garbage collection - del controlnet_to_remove - import gc - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - torch.cuda.empty_cache() - - else: - raise IndexError(f"{self.model_type} ControlNet index {index} out of range") - - def _process_operations_worker(self) -> None: - """Background thread worker that processes ControlNet operations without blocking inference""" - logger.info("BaseControlNetPipeline: Background operations worker started") - - while not self._shutdown_event.is_set(): - try: - # Wait for operation with timeout to allow checking shutdown event - operation = self._operation_queue.get(timeout=1.0) - - if operation['type'] == ControlNetOperation.ADD: - self._add_controlnet_immediate(operation['config'], operation.get('control_image')) - logger.info(f"Background worker: Applied ADD operation for {operation['config'].get('model_id', 'unknown')}") - elif operation['type'] == ControlNetOperation.REMOVE: - self._remove_controlnet_immediate(operation['index']) - logger.info(f"Background worker: Applied REMOVE operation for index {operation['index']}") - - # Mark task as done - self._operation_queue.task_done() - - except queue.Empty: - # Timeout occurred, check shutdown event and continue - continue - except Exception as e: - logger.error(f"Background worker: Failed to process operation: {e}") - # Mark task as done even if it failed - try: - self._operation_queue.task_done() - except: - pass - - logger.info("BaseControlNetPipeline: Background operations worker stopped") - - - - def clear_controlnets(self) -> None: - 
"""Remove all ControlNets""" - self.controlnets.clear() - self.controlnet_images.clear() - self.controlnet_scales.clear() - self.preprocessors.clear() - - # Update feedback preprocessor cache - self._update_feedback_preprocessor_cache() - - self._unpatch_stream_diffusion() - - def _update_feedback_preprocessor_cache(self) -> None: - """Update the cached feedback preprocessor detection result""" - self._has_feedback_preprocessor_cached = any( - preprocessor is not None and hasattr(preprocessor, 'set_pipeline_ref') - for preprocessor in self.preprocessors - ) - - def _has_feedback_preprocessor(self) -> bool: - """Check if any preprocessor is a feedback preprocessor (cached result)""" - return self._has_feedback_preprocessor_cached - - def cleanup_controlnets(self) -> None: - """Cleanup resources including thread pool and background worker""" - # Stop background operations worker - if hasattr(self, '_shutdown_event'): - logger.info("BaseControlNetPipeline: Stopping background operations worker...") - self._shutdown_event.set() - - if hasattr(self, '_background_thread') and self._background_thread.is_alive(): - self._background_thread.join(timeout=5.0) - if self._background_thread.is_alive(): - logger.warning("BaseControlNetPipeline: Background thread did not stop within timeout") - else: - logger.info("BaseControlNetPipeline: Background operations worker stopped") - - # Cleanup preprocessing orchestrator - if hasattr(self, '_preprocessing_orchestrator'): - self._preprocessing_orchestrator.cleanup() - - def __del__(self): - """Cleanup on object destruction""" - try: - self.cleanup_controlnets() - except: - pass # Ignore errors during cleanup - - - - - def update_control_image_efficient(self, control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], index: Optional[int] = None) -> None: - """Efficiently update ControlNet(s) with cache-aware preprocessing""" - - # Single ControlNet case - always use sync processing - if index is not None: - processed_images = self._preprocessing_orchestrator.process_control_images_sync( - control_image=control_image, - preprocessors=self.preprocessors, - scales=self.controlnet_scales, - stream_width=self.stream.width, - stream_height=self.stream.height, - index=index - ) - # Multi-ControlNet case - use pipelined or sync based on configuration - elif self.use_pipelined_processing and not self._has_feedback_preprocessor(): - processed_images = self._preprocessing_orchestrator.process_control_images_pipelined( - control_image=control_image, - preprocessors=self.preprocessors, - scales=self.controlnet_scales, - stream_width=self.stream.width, - stream_height=self.stream.height - ) - else: - # Use synchronous processing (required for feedback preprocessors or when pipelining disabled) - processed_images = self._preprocessing_orchestrator.process_control_images_sync( - control_image=control_image, - preprocessors=self.preprocessors, - scales=self.controlnet_scales, - stream_width=self.stream.width, - stream_height=self.stream.height, - index=None - ) - - # If empty list returned, no update needed (same frame detected) - if not processed_images: - return - - # Update controlnet_images with results - for i, processed_image in enumerate(processed_images): - if processed_image is not None: - self.controlnet_images[i] = processed_image - - def update_controlnet_scale(self, index: int, scale: float) -> None: - """Update the conditioning scale for a specific ControlNet""" - if 0 <= index < len(self.controlnets): - self.controlnet_scales[index] = scale - else: - raise 
IndexError(f"{self.model_type} ControlNet index {index} out of range") - - def _load_controlnet_model(self, model_id: str): - """Load a ControlNet model with TensorRT acceleration support""" - # Check if TensorRT engine pool is available - if hasattr(self.stream, 'controlnet_engine_pool'): - model_type = self._detected_model_type - is_sdxl = self._is_sdxl - - logger.info(f"Loading ControlNet {model_id} with TensorRT acceleration") - logger.info(f" Model type: {model_type}, is_sdxl: {is_sdxl}") - - # Get batch size for engine compilation - detected_batch_size = getattr(self.stream, 'trt_unet_batch_size', 1) - - try: - return self.stream.controlnet_engine_pool.load_engine( - model_id=model_id, - model_type=model_type, - batch_size=detected_batch_size - ) - except Exception as e: - pytorch_controlnet = self._load_pytorch_controlnet_model(model_id) - logger.error(f"Failed to load {self.model_type} ControlNet model '{model_id}': {e}") - return self.stream.controlnet_engine_pool.get_or_load_engine( - model_id=model_id, - pytorch_model=pytorch_controlnet, - model_type=model_type, - batch_size=detected_batch_size - ) - else: - # Fallback to PyTorch only - logger.info(f"Loading ControlNet {model_id} (PyTorch only - no TensorRT acceleration)") - return self._load_pytorch_controlnet_model(model_id) - - def _load_pytorch_controlnet_model(self, model_id: str): - """Load a ControlNet model from HuggingFace or local path""" - try: - # Check if it's a local path - if Path(model_id).exists(): - # Local directory - controlnet = ControlNetModel.from_pretrained( - model_id, - torch_dtype=self.dtype, - local_files_only=True - ) - else: - # Try as HuggingFace model ID - if "/" in model_id and model_id.count("/") > 1: - # Handle subfolder case (e.g., "repo/model/subfolder") - parts = model_id.split("/") - repo_id = "/".join(parts[:2]) - subfolder = "/".join(parts[2:]) - controlnet = ControlNetModel.from_pretrained( - repo_id, - subfolder=subfolder, - torch_dtype=self.dtype - ) - else: - controlnet = ControlNetModel.from_pretrained( - model_id, - torch_dtype=self.dtype - ) - - # Move to device - controlnet = controlnet.to(device=self.device, dtype=self.dtype) - return controlnet - - except Exception as e: - logger.error(f"Failed to load {self.model_type} ControlNet model '{model_id}': {e}") - logger.error(f"Full stack trace for model loading failure:") - logger.error(traceback.format_exc()) - raise ValueError(f"Failed to load {self.model_type} ControlNet model '{model_id}': {e}") - - - - def _prepare_control_image(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessor: Optional[Any] = None) -> torch.Tensor: - """Prepare a control image for ControlNet input""" - # Delegate to preprocessing orchestrator - return self._preprocessing_orchestrator.prepare_control_image( - control_image=control_image, - preprocessor=preprocessor, - target_width=self.stream.width, - target_height=self.stream.height - ) - - def _process_cfg_and_predict(self, model_pred: torch.Tensor, x_t_latent: torch.Tensor, idx=None) -> Tuple[torch.Tensor, torch.Tensor]: - """Process CFG logic and scheduler step (shared between TensorRT and PyTorch modes)""" - # CFG processing - if self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "initialize"): - noise_pred_text = model_pred[1:] - self.stream.stock_noise = torch.concat( - [model_pred[0:1], self.stream.stock_noise[1:]], dim=0 - ) - elif self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "full"): - noise_pred_uncond, noise_pred_text = 
model_pred.chunk(2) - else: - noise_pred_text = model_pred - - if self.stream.guidance_scale > 1.0 and ( - self.stream.cfg_type == "self" or self.stream.cfg_type == "initialize" - ): - noise_pred_uncond = self.stream.stock_noise * self.stream.delta - - if self.stream.guidance_scale > 1.0 and self.stream.cfg_type != "none": - model_pred = noise_pred_uncond + self.stream.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - else: - model_pred = noise_pred_text - - # Scheduler step - if self.stream.use_denoising_batch: - denoised_batch = self.stream.scheduler_step_batch(model_pred, x_t_latent, idx) - if self.stream.cfg_type == "self" or self.stream.cfg_type == "initialize": - scaled_noise = self.stream.beta_prod_t_sqrt * self.stream.stock_noise - delta_x = self.stream.scheduler_step_batch(model_pred, scaled_noise, idx) - alpha_next = torch.concat( - [ - self.stream.alpha_prod_t_sqrt[1:], - torch.ones_like(self.stream.alpha_prod_t_sqrt[0:1]), - ], - dim=0, - ) - delta_x = alpha_next * delta_x - beta_next = torch.concat( - [ - self.stream.beta_prod_t_sqrt[1:], - torch.ones_like(self.stream.beta_prod_t_sqrt[0:1]), - ], - dim=0, - ) - delta_x = delta_x / beta_next - init_noise = torch.concat( - [self.stream.init_noise[1:], self.stream.init_noise[0:1]], dim=0 - ) - self.stream.stock_noise = init_noise + delta_x - else: - denoised_batch = self.stream.scheduler_step_batch(model_pred, x_t_latent, idx) - - return denoised_batch, model_pred - - def _get_controlnet_conditioning(self, - x_t_latent: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - **kwargs) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: - """Get ControlNet conditioning with thread safety""" - # Operations are now processed in background thread - no blocking here - - # Thread-safe access to ControlNet collections - with self._collections_lock: - if not self.controlnets: - return None, None - - # Get active ControlNet indices (ControlNets with scale > 0 and valid images) - active_indices = [ - i for i, (controlnet, control_image, scale) in enumerate( - zip(self.controlnets, self.controlnet_images, self.controlnet_scales) - ) if controlnet is not None and control_image is not None and scale > 0 - ] - - if not active_indices: - return None, None - - # Create working copies to avoid holding lock during inference - active_controlnets = [self.controlnets[i] for i in active_indices] - active_images = [self.controlnet_images[i] for i in active_indices] - active_scales = [self.controlnet_scales[i] for i in active_indices] - - # Prepare base kwargs for ControlNet calls - main_batch_size = x_t_latent.shape[0] - base_kwargs = { - 'sample': x_t_latent, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'return_dict': False, - } - base_kwargs.update(self._get_additional_controlnet_kwargs(**kwargs)) - - down_samples_list = [] - mid_samples_list = [] - - for i, (controlnet, control_image, scale) in enumerate(zip(active_controlnets, active_images, active_scales)): - # Optimize batch expansion - do once per ControlNet - current_control_image = control_image - if (hasattr(controlnet, 'engine') and controlnet.engine is not None and - control_image.shape[0] != main_batch_size): - # Only expand if needed for TensorRT and batch sizes don't match - if control_image.dim() == 4: - current_control_image = control_image.repeat(main_batch_size // control_image.shape[0], 1, 1, 1) - else: - current_control_image = control_image.unsqueeze(0).repeat(main_batch_size, 1, 1, 1) - - # Optimized 
kwargs - copy base dict and update specific values - controlnet_kwargs = base_kwargs.copy() - controlnet_kwargs['controlnet_cond'] = current_control_image - controlnet_kwargs['conditioning_scale'] = scale - - # Forward pass through ControlNet - try: - down_samples, mid_sample = controlnet(**controlnet_kwargs) - - down_samples_list.append(down_samples) - mid_samples_list.append(mid_sample) - except Exception as e: - logger.error(f"_get_controlnet_conditioning: ControlNet {i} failed: {e}") - logger.error(f"_get_controlnet_conditioning: Full stack trace for ControlNet {i}:") - logger.error(traceback.format_exc()) - continue - - # Early exit if no outputs - if not down_samples_list: - return None, None - - # Optimized combination - single pass for single ControlNet - if len(down_samples_list) == 1: - return down_samples_list[0], mid_samples_list[0] - - # Vectorized combination for multiple ControlNets - down_block_res_samples = down_samples_list[0] - mid_block_res_sample = mid_samples_list[0] - - # In-place addition for remaining ControlNets - for down_samples, mid_sample in zip(down_samples_list[1:], mid_samples_list[1:]): - for j in range(len(down_block_res_samples)): - down_block_res_samples[j] += down_samples[j] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - def _patch_stream_diffusion(self) -> None: - """Patch StreamDiffusion's unet_step method to include ControlNet conditioning""" - if self._is_patched: - return - - # Store original method - self._original_unet_step = self.stream.unet_step - - # Detect if TensorRT acceleration is being used - is_tensorrt = hasattr(self.stream.unet, 'engine') or hasattr(self.stream.unet, 'use_control') - - if is_tensorrt: - self._patch_tensorrt_mode() - else: - self._patch_pytorch_mode() - - self._is_patched = True - - def _patch_tensorrt_mode(self): - """Patch for TensorRT mode with ControlNet support""" - - def patched_unet_step_tensorrt(x_t_latent, t_list, idx=None): - # Handle CFG expansion (same as original) - if self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "initialize"): - x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) - t_list_expanded = torch.concat([t_list[0:1], t_list], dim=0) - elif self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "full"): - x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) - t_list_expanded = torch.concat([t_list, t_list], dim=0) - else: - x_t_latent_plus_uc = x_t_latent - t_list_expanded = t_list - - # Get pipeline-specific conditioning context - conditioning_context = self._get_conditioning_context(x_t_latent_plus_uc, t_list_expanded) - - # Get ControlNet conditioning - down_block_res_samples, mid_block_res_sample = self._get_controlnet_conditioning( - x_t_latent_plus_uc, t_list_expanded, self.stream.prompt_embeds[:, :77, :], **conditioning_context - ) - - # Call TensorRT engine with ControlNet inputs - model_pred = self.stream.unet( - x_t_latent_plus_uc, - t_list_expanded, - self.stream.prompt_embeds, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # Use shared CFG processing - return self._process_cfg_and_predict(model_pred, x_t_latent, idx) - - # Replace the method - self.stream.unet_step = patched_unet_step_tensorrt - logger.info("ControlNetPipeline: Successfully patched TensorRT unet_step method") - - def _patch_pytorch_mode(self): - """Patch for PyTorch mode with ControlNet support (original implementation)""" - - def 
patched_unet_step_pytorch(x_t_latent, t_list, idx=None): - # Handle CFG expansion - if self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "initialize"): - x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) - t_list_expanded = torch.concat([t_list[0:1], t_list], dim=0) - elif self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "full"): - x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) - t_list_expanded = torch.concat([t_list, t_list], dim=0) - else: - x_t_latent_plus_uc = x_t_latent - t_list_expanded = t_list - - # Get pipeline-specific conditioning context - conditioning_context = self._get_conditioning_context(x_t_latent_plus_uc, t_list_expanded) - - # Get ControlNet conditioning (extract original text embeddings only - no IPAdapter tokens) - down_block_res_samples, mid_block_res_sample = self._get_controlnet_conditioning( - x_t_latent_plus_uc, t_list_expanded, self.stream.prompt_embeds[:, :77, :], **conditioning_context - ) - - # Prepare UNet kwargs - unet_kwargs = { - 'sample': x_t_latent_plus_uc, - 'timestep': t_list_expanded, - 'encoder_hidden_states': self.stream.prompt_embeds, - 'return_dict': False, - } - - # Add ControlNet conditioning - if down_block_res_samples is not None: - unet_kwargs['down_block_additional_residuals'] = down_block_res_samples - if mid_block_res_sample is not None: - unet_kwargs['mid_block_additional_residual'] = mid_block_res_sample - - # Allow subclasses to add additional UNet kwargs (e.g., SDXL added_cond_kwargs) - unet_kwargs.update(self._get_additional_unet_kwargs(**conditioning_context)) - - # Call UNet with ControlNet conditioning - model_pred = self.stream.unet(**unet_kwargs)[0] - - # Use shared CFG processing - return self._process_cfg_and_predict(model_pred, x_t_latent, idx) - - # Replace the method - self.stream.unet_step = patched_unet_step_pytorch - - def _unpatch_stream_diffusion(self) -> None: - """Restore original StreamDiffusion unet_step method""" - if self._is_patched and self._original_unet_step is not None: - self.stream.unet_step = self._original_unet_step - self._is_patched = False - - def __call__(self, *args, **kwargs): - """Forward calls to the underlying StreamDiffusion instance""" - return self.stream(*args, **kwargs) - - def __getattr__(self, name): - """Forward attribute access to the underlying StreamDiffusion instance""" - return getattr(self.stream, name) - - # Hook methods for subclasses to override - def _get_conditioning_context(self, x_t_latent: torch.Tensor, t_list: torch.Tensor) -> Dict[str, Any]: - """Get conditioning context for this pipeline type (hook for subclasses)""" - return {} - - def _get_additional_controlnet_kwargs(self, **kwargs) -> Dict[str, Any]: - """Get additional kwargs for ControlNet calls (hook for subclasses)""" - return {} - - def _get_additional_unet_kwargs(self, **kwargs) -> Dict[str, Any]: - """Get additional kwargs for UNet calls (hook for subclasses)""" - return {} \ No newline at end of file diff --git a/src/streamdiffusion/controlnet/controlnet_pipeline.py b/src/streamdiffusion/controlnet/controlnet_pipeline.py deleted file mode 100644 index 1b771acf..00000000 --- a/src/streamdiffusion/controlnet/controlnet_pipeline.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple -from PIL import Image -import numpy as np - -from ..pipeline import StreamDiffusion -from .base_controlnet_pipeline import BaseControlNetPipeline - - -class ControlNetPipeline(BaseControlNetPipeline): - """ - 
ControlNet-enabled StreamDiffusion pipeline for SD1.5 and SD Turbo with inter-frame parallelism - - This class extends StreamDiffusion with ControlNet support, allowing for - conditioning the generation process with multiple ControlNet models. - Supports both SD1.5 and SD Turbo models with pipelined preprocessing. - """ - - def __init__(self, - stream_diffusion: StreamDiffusion, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - model_type: str = "SD1.5"): - """ - Initialize ControlNet pipeline - - Args: - stream_diffusion: Base StreamDiffusion instance - device: Device to run ControlNets on - dtype: Data type for ControlNet models - model_type: Type of model being used (e.g., "SD1.5", "SD Turbo") - """ - super().__init__(stream_diffusion, device, dtype, use_pipelined_processing=True) - self.model_type = model_type \ No newline at end of file diff --git a/src/streamdiffusion/controlnet/controlnet_sdxlturbo_pipeline.py b/src/streamdiffusion/controlnet/controlnet_sdxlturbo_pipeline.py deleted file mode 100644 index 34a3ddfe..00000000 --- a/src/streamdiffusion/controlnet/controlnet_sdxlturbo_pipeline.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple -from PIL import Image -import numpy as np - -from ..pipeline import StreamDiffusion -from .base_controlnet_pipeline import BaseControlNetPipeline - -class SDXLTurboControlNetPipeline(BaseControlNetPipeline): - """SDXL Turbo ControlNet pipeline using StreamDiffusion with inter-frame parallelism""" - - def __init__(self, - stream_diffusion: StreamDiffusion, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - model_type: str = "SDXL Turbo"): - """Initialize SDXL Turbo ControlNet pipeline""" - super().__init__(stream_diffusion, device, dtype, use_pipelined_processing=True) - self.model_type = model_type - - @property - def controlnet_configs(self) -> List[Dict[str, Any]]: - """Get ControlNet configurations for compatibility with demos""" - configs = [] - for i in range(len(self.controlnets)): - configs.append({ - 'conditioning_scale': self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0, - 'enabled': self.controlnet_scales[i] > 0 if i < len(self.controlnet_scales) else False, - 'preprocessor': type(self.preprocessors[i]).__name__ if i < len(self.preprocessors) and self.preprocessors[i] is not None else None - }) - return configs - - def _post_process_control_image(self, control_tensor: torch.Tensor) -> torch.Tensor: - """Post-process control image tensor with SDXL-specific handling""" - target_size = (self.stream.height, self.stream.width) - current_size = control_tensor.shape[-2:] - if current_size != target_size: - import torch.nn.functional as F - control_tensor = F.interpolate( - control_tensor, - size=target_size, - mode='bilinear', - align_corners=False - ) - - return control_tensor - - def _get_conditioning_context(self, x_t_latent: torch.Tensor, t_list: torch.Tensor) -> Dict[str, Any]: - """Get SDXL-specific conditioning context""" - conditioning_context = {} - - # Use the conditioning that was set up in StreamDiffusion.prepare() - if hasattr(self.stream, 'add_text_embeds') and hasattr(self.stream, 'add_time_ids'): - if self.stream.add_text_embeds is not None and self.stream.add_time_ids is not None: - # Handle batching for CFG - replicate conditioning to match batch size - batch_size = x_t_latent.shape[0] - - # Replicate add_text_embeds and add_time_ids to match the batch size - if self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == 
"initialize"): - # For initialize mode: [uncond, cond, cond, ...] - add_text_embeds = torch.cat([ - self.stream.add_text_embeds[0:1], # uncond - self.stream.add_text_embeds[1:2].repeat(batch_size - 1, 1) # repeat cond - ], dim=0) - add_time_ids = torch.cat([ - self.stream.add_time_ids[0:1], # uncond - self.stream.add_time_ids[1:2].repeat(batch_size - 1, 1) # repeat cond - ], dim=0) - elif self.stream.guidance_scale > 1.0 and (self.stream.cfg_type == "full"): - # For full mode: repeat both uncond and cond for each latent - repeat_factor = batch_size // 2 - add_text_embeds = self.stream.add_text_embeds.repeat(repeat_factor, 1) - add_time_ids = self.stream.add_time_ids.repeat(repeat_factor, 1) - else: - # No CFG: just repeat the conditioning - add_text_embeds = self.stream.add_text_embeds[1:2].repeat(batch_size, 1) if self.stream.add_text_embeds.shape[0] > 1 else self.stream.add_text_embeds.repeat(batch_size, 1) - add_time_ids = self.stream.add_time_ids[1:2].repeat(batch_size, 1) if self.stream.add_time_ids.shape[0] > 1 else self.stream.add_time_ids.repeat(batch_size, 1) - - conditioning_context['text_embeds'] = add_text_embeds - conditioning_context['time_ids'] = add_time_ids - - return conditioning_context - - def _get_additional_controlnet_kwargs(self, **kwargs) -> Dict[str, Any]: - """Get SDXL-specific additional kwargs for ControlNet calls""" - if 'text_embeds' in kwargs or 'time_ids' in kwargs: - return {'added_cond_kwargs': kwargs} - return {} - - def _get_additional_unet_kwargs(self, **kwargs) -> Dict[str, Any]: - """Get SDXL-specific additional kwargs for UNet calls""" - if kwargs: - return {'added_cond_kwargs': kwargs} - return {} - - def __call__(self, - image: Union[str, Image.Image, np.ndarray, torch.Tensor] = None, - num_inference_steps: int = None, - guidance_scale: float = None, - **kwargs) -> torch.Tensor: - """Generate image using SDXL with ControlNet""" - - if image is not None: - self.update_control_image_efficient(image) - return self.stream(image) - else: - return self.stream() \ No newline at end of file diff --git a/src/streamdiffusion/hooks.py b/src/streamdiffusion/hooks.py new file mode 100644 index 00000000..c71e3d60 --- /dev/null +++ b/src/streamdiffusion/hooks.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional +import torch + + +@dataclass +class EmbedsCtx: + """Context passed to embedding hooks. + + Fields: + - prompt_embeds: [batch, seq_len, dim] + - negative_prompt_embeds: optional [batch, seq_len, dim] + """ + prompt_embeds: torch.Tensor + negative_prompt_embeds: Optional[torch.Tensor] = None + + +@dataclass +class StepCtx: + """Context passed to UNet hooks for each denoising step. 
+ + Fields: + - x_t_latent: latent tensor (possibly CFG-expanded) + - t_list: timesteps tensor (possibly CFG-expanded) + - step_index: optional int step index within total steps + - guidance_mode: one of {"none","full","self","initialize"} + - sdxl_cond: optional dict with SDXL micro-cond tensors + """ + x_t_latent: torch.Tensor + t_list: torch.Tensor + step_index: Optional[int] + guidance_mode: str + sdxl_cond: Optional[Dict[str, torch.Tensor]] = None + + +@dataclass +class UnetKwargsDelta: + """Delta produced by UNet hooks to augment UNet call kwargs.""" + down_block_additional_residuals: Optional[List[torch.Tensor]] = None + mid_block_additional_residual: Optional[torch.Tensor] = None + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None + # Additional kwargs to pass directly to the UNet call (e.g., ipadapter_scale) + extra_unet_kwargs: Optional[Dict[str, Any]] = None + + +# Type aliases for clarity +EmbeddingHook = Callable[[EmbedsCtx], EmbedsCtx] +UnetHook = Callable[[StepCtx], UnetKwargsDelta] + diff --git a/src/streamdiffusion/ipadapter/__init__.py b/src/streamdiffusion/ipadapter/__init__.py deleted file mode 100644 index 7068cdbe..00000000 --- a/src/streamdiffusion/ipadapter/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -try: - from .base_ipadapter_pipeline import BaseIPAdapterPipeline -except Exception as e: - print(f"ipadapter.__init__: Failed to import BaseIPAdapterPipeline: {e}") - raise - -__all__ = [ - "BaseIPAdapterPipeline", -] \ No newline at end of file diff --git a/src/streamdiffusion/ipadapter/base_ipadapter_pipeline.py b/src/streamdiffusion/ipadapter/base_ipadapter_pipeline.py deleted file mode 100644 index e1ea91b5..00000000 --- a/src/streamdiffusion/ipadapter/base_ipadapter_pipeline.py +++ /dev/null @@ -1,420 +0,0 @@ -import torch -import sys -import os -from typing import List, Optional, Union, Dict, Any, Tuple -from PIL import Image -import numpy as np -from pathlib import Path - -# Using relative import - no sys.path modification needed - -try: - from diffusers_ipadapter import IPAdapter -except Exception as e: - print(f"base_ipadapter_pipeline: Failed to import IPAdapter: {e}") - raise - -try: - from ..pipeline import StreamDiffusion -except Exception as e: - print(f"base_ipadapter_pipeline: Failed to import StreamDiffusion: {e}") - raise - -try: - from ..preprocessing.processors.ipadapter_embedding import IPAdapterEmbeddingPreprocessor -except Exception as e: - print(f"base_ipadapter_pipeline: Failed to import IPAdapterEmbeddingPreprocessor: {e}") - raise - -class BaseIPAdapterPipeline: - """ - Base IPAdapter-enabled StreamDiffusion pipeline - - This class integrates the existing Diffusers_IPAdapter implementation - with StreamDiffusion following the same pattern as ControlNet. 
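The new src/streamdiffusion/hooks.py above only declares the data contracts (EmbedsCtx, StepCtx, UnetKwargsDelta and the EmbeddingHook/UnetHook aliases); the code that actually consumes them is not shown in this hunk. A minimal, hypothetical driver sketch of how a pipeline could chain embedding hooks and fold UNet-hook deltas into one set of UNet kwargs (the helper names here are illustrative, not part of this diff):

# Hypothetical sketch of hook composition; relies only on the dataclasses
# added in src/streamdiffusion/hooks.py by this change.
from typing import Any, Dict, Optional
import torch

from streamdiffusion.hooks import EmbedsCtx, StepCtx, UnetKwargsDelta


def apply_embedding_hooks(hooks, prompt_embeds: torch.Tensor,
                          negative_prompt_embeds: Optional[torch.Tensor] = None):
    # Each hook takes and returns an EmbedsCtx, so hooks compose by simple chaining.
    ctx = EmbedsCtx(prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds)
    for hook in hooks:
        ctx = hook(ctx)
    return ctx.prompt_embeds, ctx.negative_prompt_embeds


def collect_unet_kwargs(hooks, x_t_latent: torch.Tensor, t_list: torch.Tensor,
                        step_index: Optional[int], guidance_mode: str,
                        sdxl_cond=None) -> Dict[str, Any]:
    # UNet hooks return deltas; merging them yields the extra kwargs for the UNet call.
    ctx = StepCtx(x_t_latent=x_t_latent, t_list=t_list, step_index=step_index,
                  guidance_mode=guidance_mode, sdxl_cond=sdxl_cond)
    unet_kwargs: Dict[str, Any] = {}
    for hook in hooks:
        delta: UnetKwargsDelta = hook(ctx)
        if delta.down_block_additional_residuals is not None:
            unet_kwargs['down_block_additional_residuals'] = delta.down_block_additional_residuals
        if delta.mid_block_additional_residual is not None:
            unet_kwargs['mid_block_additional_residual'] = delta.mid_block_additional_residual
        if delta.added_cond_kwargs is not None:
            unet_kwargs.setdefault('added_cond_kwargs', {}).update(delta.added_cond_kwargs)
        if delta.extra_unet_kwargs is not None:
            unet_kwargs.update(delta.extra_unet_kwargs)
    return unet_kwargs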
- """ - - def __init__(self, - stream_diffusion: StreamDiffusion, - device: str = "cuda", - dtype: torch.dtype = torch.float16): - """ - Initialize base IPAdapter pipeline - - Args: - stream_diffusion: Base StreamDiffusion instance - device: Device to run IPAdapter on - dtype: Data type for IPAdapter models - """ - self.stream = stream_diffusion - self.device = device - self.dtype = dtype - - # IPAdapter storage (single IPAdapter for now) - # TODO: Add support for multiple IPAdapters and multiple style images in future phase - self.ipadapter: Optional[IPAdapter] = None - self.style_image: Optional[Image.Image] = None - self.scale: float = 1.0 - - # Style image key for embedding preprocessing - self._style_image_key = "ipadapter_main" - - # No caching needed - StreamParameterUpdater handles that - - # No patching needed - we use direct embedding assignment like the working script - - def set_ipadapter(self, - ipadapter_model_path: str, - image_encoder_path: str, - style_image: Optional[Union[str, Image.Image]] = None, - scale: float = 1.0) -> None: - """ - Set the IPAdapter for the pipeline (replaces any existing IPAdapter) - - Args: - ipadapter_model_path: Full path to IPAdapter weights file (local path or HuggingFace repo/file path) - image_encoder_path: Full path to CLIP image encoder (local path or HuggingFace repo/file path) - style_image: Style image for conditioning (optional) - scale: Conditioning scale - """ - # Clear any existing IPAdapter first - self.clear_ipadapter() - - # Resolve model paths (download if HuggingFace paths) - resolved_ipadapter_path = self._resolve_model_path(ipadapter_model_path) - resolved_encoder_path = self._resolve_model_path(image_encoder_path) - - # Create IPAdapter instance using existing code - self.ipadapter = IPAdapter( - pipe=self.stream.pipe, - ipadapter_ckpt_path=resolved_ipadapter_path, - image_encoder_path=resolved_encoder_path, - device=self.device, - dtype=self.dtype - ) - - # Create embedding preprocessor for parallel processing (if not already registered) - if not self._has_registered_preprocessor(): - embedding_preprocessor = IPAdapterEmbeddingPreprocessor( - ipadapter=self.ipadapter, - device=self.device, - dtype=self.dtype - ) - - # Register with StreamParameterUpdater for integrated processing - self.stream._param_updater.register_embedding_preprocessor( - embedding_preprocessor, - self._style_image_key - ) - - # Process style image if provided - if style_image is not None: - if isinstance(style_image, str): - self.style_image = Image.open(style_image).convert("RGB") - else: - self.style_image = style_image - - # Immediately process embeddings synchronously to ensure they're cached - self.stream._param_updater.update_style_image( - self._style_image_key, - self.style_image - ) - else: - self.style_image = None - - # Set scale - self.scale = scale - self.ipadapter.set_scale(scale) - - # Register IPAdapter enhancer with StreamParameterUpdater - self.stream._param_updater.register_embedding_enhancer( - self._enhance_embeddings_with_ipadapter, - name="IPAdapter" - ) - - def _has_registered_preprocessor(self) -> bool: - """Check if an embedding preprocessor is already registered for our style image key""" - if not hasattr(self.stream._param_updater, '_embedding_preprocessors'): - return False - - for preprocessor, key in self.stream._param_updater._embedding_preprocessors: - if key == self._style_image_key: - return True - return False - - def clear_ipadapter(self) -> None: - """Remove the IPAdapter""" - # Unregister enhancer from 
StreamParameterUpdater - if hasattr(self, '_enhance_embeddings_with_ipadapter'): - self.stream._param_updater.unregister_embedding_enhancer( - self._enhance_embeddings_with_ipadapter - ) - - # Unregister embedding preprocessor from StreamParameterUpdater - self.stream._param_updater.unregister_embedding_preprocessor(self._style_image_key) - - self.ipadapter = None - self.style_image = None - self.scale = 1.0 - - def update_style_image(self, style_image: Union[str, Image.Image]) -> None: - """ - Update style image for the IPAdapter - - Args: - style_image: New style image - """ - if isinstance(style_image, str): - self.style_image = Image.open(style_image).convert("RGB") - else: - self.style_image = style_image - - # Trigger parallel embedding preprocessing via StreamParameterUpdater - if self.style_image is not None: - self.stream._param_updater.update_style_image( - self._style_image_key, - self.style_image - ) - - def update_scale(self, scale: float) -> None: - """ - Update the conditioning scale for the IPAdapter - - Args: - scale: New conditioning scale - """ - if self.ipadapter is not None: - self.scale = scale - self.ipadapter.set_scale(scale) - - def _resolve_model_path(self, model_path: str) -> str: - """ - Resolve model path - download from HuggingFace if it's a repo/file path, or use local path - - Args: - model_path: Either a local file path or HuggingFace repo/file path (e.g. "h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors") - - Returns: - Resolved local path to the model - """ - from huggingface_hub import hf_hub_download, snapshot_download - - print(f"_resolve_model_path: Resolving path: {model_path}") - - # Check if it's a local path that exists - if os.path.exists(model_path): - print(f"_resolve_model_path: Using local path: {model_path}") - return model_path - - # Check if it looks like a HuggingFace repo/file path - if "/" in model_path and not os.path.isabs(model_path): - parts = model_path.split("/") - if len(parts) >= 3: - # Format: "repo/owner/path/to/file.bin" or "repo/owner/directory" - repo_id = "/".join(parts[:2]) # "h94/IP-Adapter" - file_path = "/".join(parts[2:]) # "models/ip-adapter-plus_sd15.bin" or "models/image_encoder" - - # Check if it's a file (has extension) or directory - if "." in parts[-1]: - # It's a file - print(f"_resolve_model_path: Downloading file {file_path} from {repo_id}") - try: - downloaded_path = hf_hub_download(repo_id=repo_id, filename=file_path) - print(f"_resolve_model_path: Downloaded to: {downloaded_path}") - return downloaded_path - except Exception as e: - raise ValueError(f"_resolve_model_path: Could not download {file_path} from {repo_id}: {e}") - else: - # It's a directory - print(f"_resolve_model_path: Downloading directory {file_path} from {repo_id}") - try: - repo_path = snapshot_download( - repo_id=repo_id, - allow_patterns=[f"{file_path}/*"] - ) - full_path = os.path.join(repo_path, file_path) - print(f"_resolve_model_path: Downloaded directory to: {full_path}") - return full_path - except Exception as e: - raise ValueError(f"_resolve_model_path: Could not download directory {file_path} from {repo_id}: {e}") - - # If we get here, it's neither a valid local path nor a valid HuggingFace path - raise ValueError(f"_resolve_model_path: Invalid model path: {model_path}. Must be either a local path or HuggingFace repo/file path (e.g. 
'h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors').") - - def preload_models_for_tensorrt(self, ipadapter_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None) -> None: - """ - Pre-load IPAdapter models and install processors with weights before TensorRT compilation. - - This ensures that when TensorRT compilation occurs, the UNet already has IPAdapter - processors with actual model weights installed. - - Args: - ipadapter_config: Optional IPAdapter configuration - """ - - - try: - # Use the config if provided, otherwise use default h94/IP-Adapter - if ipadapter_config: - if isinstance(ipadapter_config, list): - config = ipadapter_config[0] # Use first IPAdapter config - else: - config = ipadapter_config - - model_path = config.get('ipadapter_model_path', 'h94/IP-Adapter/models/ip-adapter-plus_sd15.bin') - encoder_path = config.get('image_encoder_path', 'h94/IP-Adapter/models/image_encoder') - scale = config.get('scale', 1.0) - else: - # Default configuration - model_path = 'h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors' - encoder_path = 'h94/IP-Adapter/models/image_encoder' - scale = 1.0 - - # Resolve model paths using existing resolution logic - resolved_ipadapter_path = self._resolve_model_path(model_path) - resolved_encoder_path = self._resolve_model_path(encoder_path) - - - - # Create IPAdapter instance - this will install processors with weights - self.ipadapter = IPAdapter( - pipe=self.stream.pipe, - ipadapter_ckpt_path=resolved_ipadapter_path, - image_encoder_path=resolved_encoder_path, - device=self.device, - dtype=self.dtype, - ) - - # Set the correct scale from config BEFORE TensorRT compilation - self.ipadapter.set_scale(scale) - - # Create and register embedding preprocessor for parallel processing (if not already registered) - if not self._has_registered_preprocessor(): - embedding_preprocessor = IPAdapterEmbeddingPreprocessor( - ipadapter=self.ipadapter, - device=self.device, - dtype=self.dtype - ) - - # Register with StreamParameterUpdater for integrated processing - self.stream._param_updater.register_embedding_preprocessor( - embedding_preprocessor, - self._style_image_key - ) - - # Store reference to pre-loaded IPAdapter for later use - if not hasattr(self.stream, '_preloaded_ipadapters'): - self.stream._preloaded_ipadapters = [] - self.stream._preloaded_ipadapters.append(self.ipadapter) - - # Set our own properties - self.style_image = None # No style image during preload - self.scale = scale - - # Mark that stream was pre-loaded with weights - self.stream._preloaded_with_weights = True - - - - except Exception as e: - raise RuntimeError(f"Failed to load IPAdapter models: {e}. Check model paths and file formats.") - - def get_tensorrt_info(self) -> Dict[str, Any]: - """ - Get information needed for TensorRT compilation. 
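The deleted _resolve_model_path above (and the ipadapter_model_path / image_encoder_path keys in the example configs) rely on a combined "owner/repo/sub/path" convention: the first two segments are the HuggingFace repo id, the rest is the file path inside it. A small, hypothetical sketch of that split for the single-file case only (the real method also handles directories via snapshot_download), assuming huggingface_hub is installed:

# Minimal sketch of the repo/file split performed by _resolve_model_path;
# the example path is illustrative.
import os
from huggingface_hub import hf_hub_download


def resolve(model_path: str) -> str:
    if os.path.exists(model_path):
        return model_path                      # already a local file or directory
    parts = model_path.split("/")
    repo_id = "/".join(parts[:2])              # e.g. "h94/IP-Adapter"
    filename = "/".join(parts[2:])             # e.g. "models/ip-adapter-plus_sd15.bin"
    return hf_hub_download(repo_id=repo_id, filename=filename)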
- - Returns: - Dictionary with TensorRT-relevant IPAdapter information - """ - tensorrt_info = { - 'has_preloaded_models': getattr(self.stream, '_preloaded_with_weights', False), - 'num_image_tokens': 4, # Default - 'scale': 1.0, # Default - 'cross_attention_dim': None - } - - if self.ipadapter is not None: - tensorrt_info['num_image_tokens'] = getattr(self.ipadapter, 'num_tokens', 4) - tensorrt_info['scale'] = self.scale - - # Get cross attention dimension - if hasattr(self.stream, 'unet') and hasattr(self.stream.unet, 'config'): - tensorrt_info['cross_attention_dim'] = self.stream.unet.config.cross_attention_dim - - return tensorrt_info - - def _enhance_embeddings_with_ipadapter(self, prompt_embeds: torch.Tensor, negative_prompt_embeds: Optional[torch.Tensor]) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Enhance embeddings with IPAdapter conditioning using the hook system. - - This method integrates IPAdapter image conditioning with text embeddings, - maintaining compatibility with both single prompts and prompt blending. - Now uses cached parallel-processed embeddings when available. - - Args: - prompt_embeds: Text prompt embeddings from StreamParameterUpdater - negative_prompt_embeds: Negative prompt embeddings (may be None) - - Returns: - Tuple of (enhanced_prompt_embeds, enhanced_negative_prompt_embeds) - """ - # If no IPAdapter or style image, return original embeddings - if self.ipadapter is None or self.style_image is None: - return prompt_embeds, negative_prompt_embeds - - # Get cached embeddings from StreamParameterUpdater (must be available) - cached_embeddings = self.stream._param_updater.get_cached_embeddings(self._style_image_key) - if cached_embeddings is None: - raise RuntimeError(f"_enhance_embeddings_with_ipadapter: No cached embeddings found for key '{self._style_image_key}'. 
Embedding preprocessing must complete before enhancement.") - - image_prompt_embeds, negative_image_prompt_embeds = cached_embeddings - - # Ensure image embeddings have the same batch size as text embeddings - batch_size = prompt_embeds.shape[0] - if image_prompt_embeds.shape[0] == 1 and batch_size > 1: - image_prompt_embeds = image_prompt_embeds.repeat(batch_size, 1, 1) - negative_image_prompt_embeds = negative_image_prompt_embeds.repeat(batch_size, 1, 1) - - # Concatenate text and image embeddings along sequence dimension (dim=1) - # This is how IPAdapter works - text tokens + image tokens - enhanced_prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) - - if negative_prompt_embeds is not None: - enhanced_negative_prompt_embeds = torch.cat([negative_prompt_embeds, negative_image_prompt_embeds], dim=1) - else: - # Create negative embeddings if none provided - enhanced_negative_prompt_embeds = torch.cat([prompt_embeds, negative_image_prompt_embeds], dim=1) - - # Update token count for attention processors - old_tokens = getattr(self.ipadapter, '_current_tokens', None) - new_tokens = image_prompt_embeds.shape[0] * self.ipadapter.num_tokens - - if old_tokens != new_tokens: - self.ipadapter.set_tokens(new_tokens) - self.ipadapter._current_tokens = new_tokens - - return enhanced_prompt_embeds, enhanced_negative_prompt_embeds - - def prepare(self, *args, **kwargs): - """Forward prepare calls to the underlying StreamDiffusion""" - return self._original_wrapper.prepare(*args, **kwargs) - - - - def __call__(self, *args, **kwargs): - """Forward calls to the original wrapper, IPAdapter enhancement happens automatically via hook system""" - # If we have the original wrapper, use its __call__ method (handles image= parameter correctly) - if hasattr(self, '_original_wrapper'): - return self._original_wrapper(*args, **kwargs) - - # Fallback to underlying stream - return self.stream(*args, **kwargs) - - def __getattr__(self, name): - """Forward attribute access to the original wrapper first, then to the underlying StreamDiffusion""" - # Try original wrapper first (for methods like preprocess_image) - if hasattr(self, '_original_wrapper') and hasattr(self._original_wrapper, name): - return getattr(self._original_wrapper, name) - - # Fallback to underlying stream - return getattr(self.stream, name) \ No newline at end of file diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py new file mode 100644 index 00000000..34577ae7 --- /dev/null +++ b/src/streamdiffusion/modules/controlnet_module.py @@ -0,0 +1,750 @@ +from __future__ import annotations + +import threading +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers.models import ControlNetModel +import logging + +from streamdiffusion.hooks import StepCtx, UnetKwargsDelta, UnetHook +from streamdiffusion.preprocessing.preprocessing_orchestrator import ( + PreprocessingOrchestrator, +) +from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser + + +@dataclass +class ControlNetConfig: + model_id: str + preprocessor: Optional[str] = None + conditioning_scale: float = 1.0 + enabled: bool = True + preprocessor_params: Optional[Dict[str, Any]] = None + + +class ControlNetModule(OrchestratorUser): + """ControlNet module that provides a UNet hook for residual conditioning. 
+ + Responsibilities in this step (3): + - Manage a collection of ControlNet models, their scales, and current images + - Provide a UNet hook that computes down/mid residuals for active ControlNets + - Reuse the existing preprocessing orchestrator for control images + - Do not alter the wrapper or pipeline call sites (registration happens via install()) + """ + + def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) -> None: + self.device = device + self.dtype = dtype + + self.controlnets: List[Optional[ControlNetModel]] = [] + self.controlnet_images: List[Optional[torch.Tensor]] = [] + self.controlnet_scales: List[float] = [] + self.preprocessors: List[Optional[Any]] = [] + self.enabled_list: List[bool] = [] + + self._collections_lock = threading.RLock() + self._preprocessing_orchestrator: Optional[PreprocessingOrchestrator] = None + + self._stream = None # set in install + # Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats + self._prepared_tensors: List[Optional[torch.Tensor]] = [] + self._prepared_device: Optional[torch.device] = None + self._prepared_dtype: Optional[torch.dtype] = None + self._prepared_batch: Optional[int] = None + self._images_version: int = 0 + + # Pre-allocated CUDA streams for PyTorch ControlNets (indexed to controlnets) + self._pt_cn_streams: List[Optional[torch.cuda.Stream]] = [] + # Cache max parallel setting once + try: + import os + self._max_parallel_controlnets = int(os.getenv('STREAMDIFFUSION_CN_MAX_PAR', '0')) + except Exception: + self._max_parallel_controlnets = 0 + # Persistent thread pool to avoid per-step creation cost + self._executor = None + self._executor_workers = 0 + + # ---------- Public API (used by wrapper in a later step) ---------- + def install(self, stream) -> None: + self._stream = stream + self.device = stream.device + self.dtype = stream.dtype + if self._preprocessing_orchestrator is None: + # Enforce shared orchestrator via base helper (raises if missing) + self.attach_orchestrator(stream) + # Register UNet hook + stream.unet_hooks.append(self.build_unet_hook()) + # Expose controlnet collections so existing updater can find them + setattr(stream, 'controlnets', self.controlnets) + setattr(stream, 'controlnet_scales', self.controlnet_scales) + setattr(stream, 'preprocessors', self.preprocessors) + # Reset prepared tensors on install + self._prepared_tensors = [] + self._prepared_device = None + self._prepared_dtype = None + self._prepared_batch = None + # Reset PT CN streams on install + self._pt_cn_streams = [] + + def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None: + model = self._load_pytorch_controlnet_model(cfg.model_id) + model = model.to(device=self.device, dtype=self.dtype) + + preproc = None + if cfg.preprocessor: + from streamdiffusion.preprocessing.processors import get_preprocessor + preproc = get_preprocessor(cfg.preprocessor) + # Apply provided parameters to the preprocessor instance + if cfg.preprocessor_params: + params = cfg.preprocessor_params or {} + # If the preprocessor exposes a 'params' dict, update it + if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): + preproc.params.update(params) + # Also set attributes directly when they exist + for name, value in params.items(): + try: + if hasattr(preproc, name): + setattr(preproc, name, value) + except Exception: + pass + + # Provide pipeline reference for preprocessors that need it (e.g., FeedbackPreprocessor) 
+ try: + if hasattr(preproc, 'set_pipeline_ref'): + preproc.set_pipeline_ref(self._stream) + except Exception: + pass + + # Align preprocessor target size with stream resolution once (avoid double-resize later) + try: + if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): + preproc.params['image_width'] = int(self._stream.width) + preproc.params['image_height'] = int(self._stream.height) + if hasattr(preproc, 'image_width'): + setattr(preproc, 'image_width', int(self._stream.width)) + if hasattr(preproc, 'image_height'): + setattr(preproc, 'image_height', int(self._stream.height)) + except Exception: + pass + + image_tensor: Optional[torch.Tensor] = None + if control_image is not None and self._preprocessing_orchestrator is not None: + image_tensor = self._prepare_control_image(control_image, preproc) + + with self._collections_lock: + self.controlnets.append(model) + self.controlnet_images.append(image_tensor) + self.controlnet_scales.append(float(cfg.conditioning_scale)) + self.preprocessors.append(preproc) + self.enabled_list.append(bool(cfg.enabled)) + # Invalidate prepared tensors and bump version when graph changes + self._prepared_tensors = [] + self._images_version += 1 + # Maintain stream slots for PyTorch ControlNets + self._pt_cn_streams.append(None) + # Initialize/update engine map if present + try: + if hasattr(self._stream, 'controlnet_engines'): + for eng in list(getattr(self._stream, 'controlnet_engines') or []): + if not hasattr(eng, 'model_id'): + try: + setattr(eng, 'model_id', cfg.model_id) + except Exception: + pass + except Exception: + pass + + def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None: + if self._preprocessing_orchestrator is None: + return + with self._collections_lock: + if not self.controlnets: + return + total = len(self.controlnets) + # Build active scales, respecting enabled_list if present + scales = [ + (self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) + for i in range(total) + ] + if hasattr(self, 'enabled_list') and self.enabled_list and len(self.enabled_list) == total: + scales = [sc if bool(self.enabled_list[i]) else 0.0 for i, sc in enumerate(scales)] + preprocessors = [self.preprocessors[i] if i < len(self.preprocessors) else None for i in range(total)] + + # Single-index fast path + if index is not None: + results = self._preprocessing_orchestrator.process_control_images_sync( + control_image=control_image, + preprocessors=preprocessors, + scales=scales, + stream_width=self._stream.width, + stream_height=self._stream.height, + index=index, + ) + processed = results[index] if results and len(results) > index else None + with self._collections_lock: + if processed is not None and index < len(self.controlnet_images): + self.controlnet_images[index] = processed + # Invalidate prepared tensors and bump version for per-frame reuse + self._prepared_tensors = [] + self._images_version += 1 + # Pre-prepare tensors if we know the target specs + if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'): + # Use default batch size of 1 for now, will be adjusted on first use + self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1) + return + + # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync) + processed_images = self._preprocessing_orchestrator.process_control_images_pipelined( + control_image=control_image, + preprocessors=preprocessors, + 
scales=scales, + stream_width=self._stream.width, + stream_height=self._stream.height, + ) + + # If orchestrator returns empty list, it indicates no update needed for this frame + if processed_images is None or (isinstance(processed_images, list) and len(processed_images) == 0): + return + + # Assign results + with self._collections_lock: + for i, img in enumerate(processed_images): + if img is not None and i < len(self.controlnet_images): + self.controlnet_images[i] = img + # Invalidate prepared cache and bump version after bulk update + self._prepared_tensors = [] + self._images_version += 1 + # Pre-prepare tensors if we know the target specs + if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'): + # Use default batch size of 1 for now, will be adjusted on first use + self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1) + + def update_controlnet_scale(self, index: int, scale: float) -> None: + with self._collections_lock: + if 0 <= index < len(self.controlnet_scales): + self.controlnet_scales[index] = float(scale) + + def update_controlnet_enabled(self, index: int, enabled: bool) -> None: + with self._collections_lock: + if 0 <= index < len(self.enabled_list): + self.enabled_list[index] = bool(enabled) + + def remove_controlnet(self, index: int) -> None: + with self._collections_lock: + if 0 <= index < len(self.controlnets): + del self.controlnets[index] + if index < len(self.controlnet_images): + del self.controlnet_images[index] + if index < len(self.controlnet_scales): + del self.controlnet_scales[index] + if index < len(self.preprocessors): + del self.preprocessors[index] + if index < len(self.enabled_list): + del self.enabled_list[index] + # Invalidate prepared tensors and bump version + self._prepared_tensors = [] + self._images_version += 1 + if index < len(self._pt_cn_streams): + del self._pt_cn_streams[index] + + def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None: + """Reorder internal collections to match the desired model_id order. + + Any controlnet whose model_id is not present in desired_model_ids retains its + relative order after those that are specified. 
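Taken together, install / add_controlnet / update_control_image_efficient / update_controlnet_scale / update_controlnet_enabled form the public surface of the new ControlNetModule. A brief, hypothetical usage sketch (the stream argument stands in for an already-prepared StreamDiffusion instance, and the model id and canny thresholds simply mirror the example configs):

# Hypothetical usage of ControlNetModule; 'stream' and 'frame' are placeholders
# supplied by the caller, not defined in this diff.
import torch
from streamdiffusion.modules.controlnet_module import ControlNetModule, ControlNetConfig


def attach_canny_controlnet(stream, frame):
    module = ControlNetModule(device="cuda", dtype=torch.float16)
    module.install(stream)  # registers the UNet hook on stream.unet_hooks

    module.add_controlnet(
        ControlNetConfig(
            model_id="lllyasviel/control_v11p_sd15_canny",
            preprocessor="canny",
            conditioning_scale=0.5,
            preprocessor_params={"low_threshold": 100, "high_threshold": 200},
        )
    )

    # Per frame: push the new control image, then tweak strength or enablement live.
    module.update_control_image_efficient(frame)
    module.update_controlnet_scale(0, 0.45)
    module.update_controlnet_enabled(0, True)
    return module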
+ """ + with self._collections_lock: + # Build current mapping from model_id to index + current_ids: List[str] = [] + for i, cn in enumerate(self.controlnets): + model_id = getattr(cn, 'model_id', f'controlnet_{i}') + current_ids.append(model_id) + + # Compute new index order + picked = set() + new_order: List[int] = [] + for mid in desired_model_ids: + if mid in current_ids: + idx = current_ids.index(mid) + new_order.append(idx) + picked.add(idx) + # Append remaining indices (not specified) preserving order + for i in range(len(self.controlnets)): + if i not in picked: + new_order.append(i) + + if new_order == list(range(len(self.controlnets))): + return # Already in desired order + + def reindex(lst: List[Any]) -> List[Any]: + return [lst[i] for i in new_order] + + self.controlnets = reindex(self.controlnets) + self.controlnet_images = reindex(self.controlnet_images) + self.controlnet_scales = reindex(self.controlnet_scales) + self.preprocessors = reindex(self.preprocessors) + self.enabled_list = reindex(self.enabled_list) + + def get_current_config(self) -> List[Dict[str, Any]]: + cfg: List[Dict[str, Any]] = [] + with self._collections_lock: + for i, cn in enumerate(self.controlnets): + model_id = getattr(cn, 'model_id', f'controlnet_{i}') + scale = self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0 + preproc_params = getattr(self.preprocessors[i], 'params', {}) if i < len(self.preprocessors) and self.preprocessors[i] else {} + cfg.append({ + 'model_id': model_id, + 'conditioning_scale': scale, + 'preprocessor_params': preproc_params, + 'enabled': (self.enabled_list[i] if i < len(self.enabled_list) else True), + }) + return cfg + + def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None: + """Prepare control image tensors for the current frame. + + This method is called once per frame to prepare all control images with the correct + device, dtype, and batch size. This avoids redundant operations during each denoising step. 
+ + Args: + device: Target device for tensors + dtype: Target dtype for tensors + batch_size: Target batch size + """ + with self._collections_lock: + # Check if we need to re-prepare tensors + cache_valid = ( + self._prepared_device == device and + self._prepared_dtype == dtype and + self._prepared_batch == batch_size and + len(self._prepared_tensors) == len(self.controlnet_images) + ) + + if cache_valid: + return + + # Prepare tensors for current frame + self._prepared_tensors = [] + for img in self.controlnet_images: + if img is None: + self._prepared_tensors.append(None) + continue + + # Prepare tensor with correct batch size + prepared = img + if prepared.dim() == 4 and prepared.shape[0] != batch_size: + if prepared.shape[0] == 1: + prepared = prepared.repeat(batch_size, 1, 1, 1) + else: + repeat_factor = max(1, batch_size // prepared.shape[0]) + prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size] + + # Move to correct device and dtype + prepared = prepared.to(device=device, dtype=dtype) + self._prepared_tensors.append(prepared) + + # Update cache state + self._prepared_device = device + self._prepared_dtype = dtype + self._prepared_batch = batch_size + + # ---------- Internal helpers ---------- + def build_unet_hook(self) -> UnetHook: + def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: + # Compute residuals under lock, using only original text tokens for ControlNet encoding + x_t = ctx.x_t_latent + t_list = ctx.t_list + + with self._collections_lock: + if not self.controlnets: + return UnetKwargsDelta() + + active_indices = [ + i + for i, (cn, img, scale, enabled) in enumerate( + zip( + self.controlnets, + self.controlnet_images, + self.controlnet_scales, + self.enabled_list if len(self.enabled_list) == len(self.controlnets) else [True] * len(self.controlnets), + ) + ) + if cn is not None and img is not None and scale > 0 and bool(enabled) + ] + + if not active_indices: + return UnetKwargsDelta() + + active_controlnets = [self.controlnets[i] for i in active_indices] + active_images = [self.controlnet_images[i] for i in active_indices] + active_scales = [self.controlnet_scales[i] for i in active_indices] + + # Prefer TRT engines when available by model_id + engines_by_id: Dict[str, Any] = {} + try: + if hasattr(self._stream, 'controlnet_engines') and isinstance(self._stream.controlnet_engines, list): + for eng in self._stream.controlnet_engines: + mid = getattr(eng, 'model_id', None) + if mid: + engines_by_id[mid] = eng + except Exception: + pass + + # Use original text token window only for ControlNet encoding + # Detect expected text length from UNet config if available; fallback to 77 + expected_text_len = 77 + try: + if hasattr(self._stream.unet, 'config') and hasattr(self._stream.unet.config, 'cross_attention_dim'): + # For SDXL TRT with IPAdapter baked, engine may expect 77+num_image_tokens for encoder_hidden_states + # However, ControlNet expects just the text portion. Slice accordingly. 
+ expected_text_len = 77 + except Exception: + pass + encoder_hidden_states = self._stream.prompt_embeds[:, :expected_text_len, :] + + base_kwargs: Dict[str, Any] = { + 'sample': x_t, + 'timestep': t_list, + 'encoder_hidden_states': encoder_hidden_states, + 'return_dict': False, + } + + down_samples_list: List[List[torch.Tensor]] = [] + mid_samples_list: List[torch.Tensor] = [] + + # Optionally prepare tensors for this frame (used by other code paths) + try: + if (self._prepared_device != x_t.device or + self._prepared_dtype != x_t.dtype or + self._prepared_batch != x_t.shape[0]): + self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0]) + except Exception: + pass + + # Helper: run sequentially (baseline, safest for PyTorch/xformers) + def run_sequential(): + local_down: List[List[torch.Tensor]] = [] + local_mid: List[torch.Tensor] = [] + for cn, img, scale in zip(active_controlnets, active_images, active_scales): + # Swap to TRT engine if compiled and available for this model_id + try: + model_id = getattr(cn, 'model_id', None) + if model_id and model_id in engines_by_id: + cn = engines_by_id[model_id] + except Exception: + pass + current_img = img + if current_img is None: + continue + try: + main_batch = x_t.shape[0] + if current_img.dim() == 4 and current_img.shape[0] != main_batch: + if current_img.shape[0] == 1: + current_img = current_img.repeat(main_batch, 1, 1, 1) + else: + repeat_factor = max(1, main_batch // current_img.shape[0]) + current_img = current_img.repeat(repeat_factor, 1, 1, 1) + current_img = current_img.to(device=x_t.device, dtype=x_t.dtype) + except Exception: + pass + kwargs = base_kwargs.copy() + kwargs['controlnet_cond'] = current_img + kwargs['conditioning_scale'] = float(scale) + try: + if getattr(self._stream, 'is_sdxl', False) and ctx.sdxl_cond is not None: + kwargs['added_cond_kwargs'] = ctx.sdxl_cond + except Exception: + pass + try: + if hasattr(cn, 'engine') and hasattr(cn, 'stream'): + ds, ms = cn( + sample=kwargs['sample'], + timestep=kwargs['timestep'], + encoder_hidden_states=kwargs['encoder_hidden_states'], + controlnet_cond=kwargs['controlnet_cond'], + conditioning_scale=float(scale), + **({} if 'added_cond_kwargs' not in kwargs else kwargs['added_cond_kwargs']) + ) + else: + ds, ms = cn(**kwargs) + local_down.append(ds) + local_mid.append(ms) + except Exception: + continue + return local_down, local_mid + + # Bounded parallelism for ControlNet forwards + from concurrent.futures import ThreadPoolExecutor, as_completed + + # Cap parallel CN based on cached env override or count + max_par = self._max_parallel_controlnets if self._max_parallel_controlnets > 0 else len(active_controlnets) + + # Fast path: single active ControlNet → run inline, no thread pool or extra CUDA stream creation + if len(active_controlnets) == 1: + cn = active_controlnets[0] + current_img = active_images[0] + scale = active_scales[0] + try: + model_id = getattr(cn, 'model_id', None) + if model_id and model_id in engines_by_id: + cn = engines_by_id[model_id] + except Exception: + pass + if current_img is not None: + try: + main_batch = x_t.shape[0] + if current_img.dim() == 4 and current_img.shape[0] != main_batch: + if current_img.shape[0] == 1: + current_img = current_img.repeat(main_batch, 1, 1, 1) + else: + repeat_factor = max(1, main_batch // current_img.shape[0]) + current_img = current_img.repeat(repeat_factor, 1, 1, 1) + current_img = current_img.to(device=x_t.device, dtype=x_t.dtype) + except Exception: + pass + kwargs = base_kwargs.copy() + 
kwargs['controlnet_cond'] = current_img + kwargs['conditioning_scale'] = float(scale) + try: + if getattr(self._stream, 'is_sdxl', False) and ctx.sdxl_cond is not None: + kwargs['added_cond_kwargs'] = ctx.sdxl_cond + except Exception: + pass + try: + if hasattr(cn, 'engine') and hasattr(cn, 'stream'): + ds, ms = cn( + sample=kwargs['sample'], + timestep=kwargs['timestep'], + encoder_hidden_states=kwargs['encoder_hidden_states'], + controlnet_cond=kwargs['controlnet_cond'], + conditioning_scale=float(scale), + **({} if 'added_cond_kwargs' not in kwargs else kwargs['added_cond_kwargs']) + ) + else: + ds, ms = cn(**kwargs) + down_samples_list.append(ds) + mid_samples_list.append(ms) + except Exception: + pass + # Build delta (handles empty gracefully below) + if not down_samples_list: + return UnetKwargsDelta() + return UnetKwargsDelta( + down_block_additional_residuals=down_samples_list[0], + mid_block_additional_residual=mid_samples_list[0], + ) + + # If any active CN is PyTorch (no engine.stream), prefer sequential for correctness on xformers + try: + all_trt = True + for cn in active_controlnets: + mid = getattr(cn, 'model_id', None) + if mid and mid in engines_by_id: + cn = engines_by_id[mid] + if not (hasattr(cn, 'engine') and hasattr(cn, 'stream')): + all_trt = False + break + except Exception: + all_trt = False + + if not all_trt: + seq_down, seq_mid = run_sequential() + down_samples_list.extend(seq_down) + mid_samples_list.extend(seq_mid) + if not down_samples_list: + return UnetKwargsDelta() + if len(down_samples_list) == 1: + return UnetKwargsDelta( + down_block_additional_residuals=down_samples_list[0], + mid_block_additional_residual=mid_samples_list[0], + ) + merged_down = down_samples_list[0] + merged_mid = mid_samples_list[0] + for ds, ms in zip(down_samples_list[1:], mid_samples_list[1:]): + for j in range(len(merged_down)): + merged_down[j] = merged_down[j] + ds[j] + merged_mid = merged_mid + ms + return UnetKwargsDelta( + down_block_additional_residuals=merged_down, + mid_block_additional_residual=merged_mid, + ) + + tasks = [] + results: List[Tuple[int, Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[torch.cuda.Stream]]] = [] + + def run_one(idx: int, global_idx: int, cn_model: Any, img_tensor: torch.Tensor, scale_val: float): + cn_local = cn_model + # Swap to TRT engine if compiled and available for this model_id + try: + model_id_local = getattr(cn_local, 'model_id', None) + if model_id_local and model_id_local in engines_by_id: + cn_local = engines_by_id[model_id_local] + except Exception: + pass + + current_img_local = img_tensor + if current_img_local is None: + return (idx, None, None) + + # Ensure control image batch matches latent batch; match device/dtype + try: + main_batch = x_t.shape[0] + if current_img_local.dim() == 4 and current_img_local.shape[0] != main_batch: + if current_img_local.shape[0] == 1: + current_img_local = current_img_local.repeat(main_batch, 1, 1, 1) + else: + repeat_factor = max(1, main_batch // current_img_local.shape[0]) + current_img_local = current_img_local.repeat(repeat_factor, 1, 1, 1) + current_img_local = current_img_local.to(device=x_t.device, dtype=x_t.dtype) + except Exception: + pass + + local_kwargs = base_kwargs.copy() + local_kwargs['controlnet_cond'] = current_img_local + local_kwargs['conditioning_scale'] = float(scale_val) + try: + if getattr(self._stream, 'is_sdxl', False) and ctx.sdxl_cond is not None: + local_kwargs['added_cond_kwargs'] = ctx.sdxl_cond + except Exception: + pass + + try: + if 
hasattr(cn_local, 'engine') and hasattr(cn_local, 'stream'): + # TRT engine path: engine has its own CUDA stream; just call + down_s, mid_s = cn_local( + sample=local_kwargs['sample'], + timestep=local_kwargs['timestep'], + encoder_hidden_states=local_kwargs['encoder_hidden_states'], + controlnet_cond=local_kwargs['controlnet_cond'], + conditioning_scale=float(scale_val), + **({} if 'added_cond_kwargs' not in local_kwargs else local_kwargs['added_cond_kwargs']) + ) + # Engine call synchronizes internally; no stream to wait on + return (idx, down_s, mid_s, None) + else: + # PyTorch path: use a per-call CUDA stream for concurrency + # Lazily create/reuse a dedicated stream for this controlnet index + stream_obj = self._pt_cn_streams[global_idx] + if stream_obj is None: + stream_obj = torch.cuda.Stream(device=x_t.device) + self._pt_cn_streams[global_idx] = stream_obj + with torch.cuda.stream(stream_obj): + down_s, mid_s = cn_local(**local_kwargs) + # Do not synchronize here; main thread will wait on stream before use + return (idx, down_s, mid_s, stream_obj) + except Exception as e: + import traceback + __import__('logging').getLogger(__name__).error("ControlNetModule: run_one forward failed: %s", e) + __import__('logging').getLogger(__name__).error(traceback.format_exc()) + return (idx, None, None, None) + + # Submit tasks in bounded thread pool + desired_workers = max(1, min(max_par, len(active_controlnets))) + # (Re)create persistent executor only when worker count changes + if self._executor is None or self._executor_workers != desired_workers: + if self._executor is not None: + try: + self._executor.shutdown(wait=False, cancel_futures=True) + except Exception: + pass + self._executor = ThreadPoolExecutor(max_workers=desired_workers) + self._executor_workers = desired_workers + + ex = self._executor + # Map active sub-index to global controlnet index to reuse per-cn streams + for sub_i, (cn_i, img_i, sc_i) in enumerate(zip(active_controlnets, active_images, active_scales)): + global_i = active_indices[sub_i] + tasks.append(ex.submit(run_one, sub_i, global_i, cn_i, img_i, sc_i)) + for fut in as_completed(tasks): + idx, ds, ms, s = fut.result() + if ds is not None and ms is not None: + results.append((idx, ds, ms, s)) + + if not results: + return UnetKwargsDelta() + + # Restore original order + results.sort(key=lambda x: x[0]) + # Ensure default stream waits on any per-CN PyTorch streams before using tensors + default_stream = torch.cuda.current_stream(device=x_t.device) + for _, ds, ms, s in results: + if isinstance(s, torch.cuda.Stream): + default_stream.wait_stream(s) + down_samples_list.append(ds) # type: ignore[arg-type] + mid_samples_list.append(ms) # type: ignore[arg-type] + + if not down_samples_list: + return UnetKwargsDelta() + + if len(down_samples_list) == 1: + return UnetKwargsDelta( + down_block_additional_residuals=down_samples_list[0], + mid_block_additional_residual=mid_samples_list[0], + ) + + # Merge multiple ControlNet residuals + merged_down = down_samples_list[0] + merged_mid = mid_samples_list[0] + for ds, ms in zip(down_samples_list[1:], mid_samples_list[1:]): + for j in range(len(merged_down)): + merged_down[j] = merged_down[j] + ds[j] + merged_mid = merged_mid + ms + + return UnetKwargsDelta( + down_block_additional_residuals=merged_down, + mid_block_additional_residual=merged_mid, + ) + + return _unet_hook + + def _prepare_control_image(self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any]) -> torch.Tensor: + if 
self._preprocessing_orchestrator is None: + raise RuntimeError("ControlNetModule: preprocessing orchestrator is not initialized") + # Reuse orchestrator API used by BaseControlNetPipeline + images = self._preprocessing_orchestrator.process_control_images_sync( + control_image=control_image, + preprocessors=[preprocessor], + scales=[1.0], + stream_width=self._stream.width, + stream_height=self._stream.height, + index=0, + ) + # API returns a list; pick first if present + return images[0] if images else None + + def _load_pytorch_controlnet_model(self, model_id: str) -> ControlNetModel: + from pathlib import Path + try: + if Path(model_id).exists(): + controlnet = ControlNetModel.from_pretrained( + model_id, torch_dtype=self.dtype, local_files_only=True + ) + else: + if "/" in model_id and model_id.count("/") > 1: + parts = model_id.split("/") + repo_id = "/".join(parts[:2]) + subfolder = "/".join(parts[2:]) + controlnet = ControlNetModel.from_pretrained( + repo_id, subfolder=subfolder, torch_dtype=self.dtype + ) + else: + controlnet = ControlNetModel.from_pretrained( + model_id, torch_dtype=self.dtype + ) + controlnet = controlnet.to(device=self.device, dtype=self.dtype) + # Track model_id for updater diffing + try: + setattr(controlnet, 'model_id', model_id) + except Exception: + pass + return controlnet + except Exception as e: + import logging, traceback + logger = logging.getLogger(__name__) + logger.error(f"ControlNetModule: failed to load model '{model_id}': {e}") + logger.error(traceback.format_exc()) + raise + diff --git a/src/streamdiffusion/modules/ipadapter_module.py b/src/streamdiffusion/modules/ipadapter_module.py new file mode 100644 index 00000000..399dcdfb --- /dev/null +++ b/src/streamdiffusion/modules/ipadapter_module.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Any +import torch + +from streamdiffusion.hooks import EmbedsCtx, EmbeddingHook, StepCtx, UnetKwargsDelta, UnetHook +import os +from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser + + +@dataclass +class IPAdapterConfig: + """Minimal config for constructing an IP-Adapter module instance. + + This module focuses only on embedding composition (step 2 of migration). + Runtime installation and wrapper wiring will come in later steps. + """ + style_image_key: Optional[str] = None + num_image_tokens: int = 4 # e.g., 4 for standard, 16 for plus + ipadapter_model_path: Optional[str] = None + image_encoder_path: Optional[str] = None + style_image: Optional[Any] = None + scale: float = 1.0 + + +class IPAdapterModule(OrchestratorUser): + """IP-Adapter embedding hook provider. + + Produces an embedding hook that concatenates cached image tokens (from + StreamParameterUpdater's embedding cache) to the current text embeddings. 
+ """ + + def __init__(self, config: IPAdapterConfig) -> None: + self.config = config + self.ipadapter: Optional[Any] = None + + def build_embedding_hook(self, stream) -> EmbeddingHook: + style_key = self.config.style_image_key or "default" + num_tokens = int(self.config.num_image_tokens) + + def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: + # Fetch cached image token embeddings (prompt, negative) + cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings(style_key) + image_prompt_tokens: Optional[torch.Tensor] = None + image_negative_tokens: Optional[torch.Tensor] = None + if cached is not None: + image_prompt_tokens, image_negative_tokens = cached + + # Validate or synthesize tokens when missing to satisfy engine shape (e.g., TRT expects 77+num_tokens) + hidden_dim = ctx.prompt_embeds.shape[2] + batch_size = ctx.prompt_embeds.shape[0] + if image_prompt_tokens is None: + image_prompt_tokens = torch.zeros( + (batch_size, num_tokens, hidden_dim), dtype=ctx.prompt_embeds.dtype, device=ctx.prompt_embeds.device + ) + else: + if image_prompt_tokens.shape[1] != num_tokens: + raise ValueError( + f"IPAdapterModule: Expected {num_tokens} image tokens, got {image_prompt_tokens.shape[1]}" + ) + + # Concatenate image tokens to the right of text tokens + prompt_with_image = ctx.prompt_embeds + if image_prompt_tokens is not None: + # Repeat to match batch size if needed + if image_prompt_tokens.shape[0] != prompt_with_image.shape[0]: + image_prompt_tokens = image_prompt_tokens.repeat_interleave( + repeats=prompt_with_image.shape[0] // max(image_prompt_tokens.shape[0], 1), dim=0 + ) + prompt_with_image = torch.cat([prompt_with_image, image_prompt_tokens], dim=1) + + neg_with_image = ctx.negative_prompt_embeds + if neg_with_image is not None: + if image_negative_tokens is None: + image_negative_tokens = torch.zeros( + (neg_with_image.shape[0], num_tokens, hidden_dim), dtype=neg_with_image.dtype, device=neg_with_image.device + ) + else: + if image_negative_tokens.shape[0] != neg_with_image.shape[0]: + image_negative_tokens = image_negative_tokens.repeat_interleave( + repeats=neg_with_image.shape[0] // max(image_negative_tokens.shape[0], 1), dim=0 + ) + neg_with_image = torch.cat([neg_with_image, image_negative_tokens], dim=1) + + return EmbedsCtx(prompt_embeds=prompt_with_image, negative_prompt_embeds=neg_with_image) + + return _embedding_hook + + def install(self, stream) -> None: + """Install IP-Adapter processors and register embedding hook and preprocessor. 
+ + - Instantiates IP-Adapter with model and encoder paths + - Registers IPAdapterEmbeddingPreprocessor with StreamParameterUpdater using style_image_key + - Optionally processes provided style image to populate the embedding cache + - Registers the embedding hook onto stream.embedding_hooks + - Sets the initial scale and mirrors it onto stream.ipadapter_scale + """ + logger = __import__('logging').getLogger(__name__) + style_key = self.config.style_image_key or "ipadapter_main" + + # Attach shared orchestrator to ensure consistent reuse across modules + self.attach_orchestrator(stream) + + # Validate required paths + if not self.config.ipadapter_model_path or not self.config.image_encoder_path: + raise ValueError("IPAdapterModule.install: ipadapter_model_path and image_encoder_path are required") + + # Lazy import to avoid hard dependency unless used + try: + from diffusers_ipadapter import IPAdapter # type: ignore + except Exception as e: + logger.error(f"IPAdapterModule.install: Failed to import IPAdapter: {e}") + raise + try: + from streamdiffusion.preprocessing.processors.ipadapter_embedding import IPAdapterEmbeddingPreprocessor + except Exception as e: + logger.error(f"IPAdapterModule.install: Failed to import IPAdapterEmbeddingPreprocessor: {e}") + raise + + # Resolve model paths (HF repo file or local path) + resolved_ip_path = self._resolve_model_path(self.config.ipadapter_model_path) + resolved_encoder_path = self._resolve_model_path(self.config.image_encoder_path) + + # Create IP-Adapter and install processors into UNet + ipadapter = IPAdapter( + pipe=stream.pipe, + ipadapter_ckpt_path=resolved_ip_path, + image_encoder_path=resolved_encoder_path, + device=stream.device, + dtype=stream.dtype, + ) + self.ipadapter = ipadapter + + # Register embedding preprocessor for this style key + embedding_preprocessor = IPAdapterEmbeddingPreprocessor( + ipadapter=ipadapter, + device=stream.device, + dtype=stream.dtype, + ) + stream._param_updater.register_embedding_preprocessor(embedding_preprocessor, style_key) + + # Process initial style image if provided + if self.config.style_image is not None: + try: + stream._param_updater.update_style_image(style_key, self.config.style_image, is_stream=False) + except Exception as e: + logger.error(f"IPAdapterModule.install: Failed to process style image: {e}") + raise + + # Set initial scale and mirror onto stream for TRT runtime vector if needed + try: + ipadapter.set_scale(float(self.config.scale)) + setattr(stream, 'ipadapter_scale', float(self.config.scale)) + except Exception: + pass + + # Compatibility: expose expected attributes/methods used by StreamParameterUpdater + try: + setattr(stream, 'ipadapter', ipadapter) + setattr(stream, 'scale', float(self.config.scale)) + def _update_scale(new_scale: float) -> None: + ipadapter.set_scale(float(new_scale)) + setattr(stream, 'ipadapter_scale', float(new_scale)) + try: + setattr(stream, 'scale', float(new_scale)) + except Exception: + pass + def _update_style_image(style_image) -> None: + stream._param_updater.update_style_image(style_key, style_image, is_stream=False) + setattr(stream, 'update_scale', _update_scale) + setattr(stream, 'update_style_image', _update_style_image) + except Exception: + pass + + # Register embedding hook for concatenation of image tokens + stream.embedding_hooks.append(self.build_embedding_hook(stream)) + + # Register UNet hook to supply per-step IP-Adapter scale via extra kwargs + stream.unet_hooks.append(self.build_unet_hook(stream)) + + def _resolve_model_path(self, 
model_path: Optional[str]) -> str: + """Resolve a model path. + + Accepts either a local filesystem path or a Hugging Face repo/file spec like + "h94/IP-Adapter/models/ip-adapter-plus_sd15.safetensors" or a directory path + such as "h94/IP-Adapter/models/image_encoder". + """ + if not model_path: + raise ValueError("IPAdapterModule._resolve_model_path: model_path is required") + + if os.path.exists(model_path): + return model_path + + # Treat as HF repo path + try: + from huggingface_hub import hf_hub_download, snapshot_download + except Exception as e: + import logging + logging.getLogger(__name__).error(f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}") + raise + + parts = model_path.split("/") + if len(parts) < 3: + raise ValueError(f"IPAdapterModule._resolve_model_path: Invalid Hugging Face spec: '{model_path}'") + + repo_id = "/".join(parts[:2]) + subpath = "/".join(parts[2:]) + + # File if last component has an extension; otherwise treat as directory + if "." in parts[-1]: + # File download + local_path = hf_hub_download(repo_id=repo_id, filename=subpath) + return local_path + else: + # Directory download + repo_root = snapshot_download(repo_id=repo_id, allow_patterns=[f"{subpath}/*"]) + full_path = os.path.join(repo_root, subpath) + if not os.path.exists(full_path): + raise FileNotFoundError(f"IPAdapterModule._resolve_model_path: Downloaded path not found: {full_path}") + return full_path + + def build_unet_hook(self, stream) -> UnetHook: + """Provide per-step ipadapter_scale vector via UNet hook extra kwargs. + + - For TensorRT UNet engines compiled with IP-Adapter, pass a per-layer vector in extra kwargs + - For PyTorch UNet with installed IP processors, modulate per-layer processor scale by time factor + """ + def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: + # If no IP-Adapter installed, do nothing + if not hasattr(stream, 'ipadapter') or stream.ipadapter is None: + return UnetKwargsDelta() + + # Read base weight and weight type from stream + try: + base_weight = float(getattr(stream, 'ipadapter_scale', getattr(self, 'config', None).scale if hasattr(self, 'config') else 1.0)) + except Exception: + base_weight = 1.0 + weight_type = getattr(stream, 'ipadapter_weight_type', None) + + # Determine total steps and current step index for time scheduling + total_steps = None + try: + if hasattr(stream, 'denoising_steps_num') and isinstance(stream.denoising_steps_num, int): + total_steps = int(stream.denoising_steps_num) + elif hasattr(stream, 't_list') and stream.t_list is not None: + total_steps = len(stream.t_list) + except Exception: + total_steps = None + + time_factor = 1.0 + if total_steps is not None and ctx.step_index is not None: + try: + from diffusers_ipadapter.ip_adapter.attention_processor import build_time_weight_factor + time_factor = float(build_time_weight_factor(weight_type, int(ctx.step_index), int(total_steps))) + except Exception: + # Do not add fallback mechanisms + pass + + # TensorRT engine path: supply ipadapter_scale vector via extra kwargs + try: + is_trt_unet = hasattr(stream, 'unet') and hasattr(stream.unet, 'engine') and hasattr(stream.unet, 'stream') + except Exception: + is_trt_unet = False + + if is_trt_unet and getattr(stream.unet, 'use_ipadapter', False): + try: + from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + except Exception: + # If helper unavailable, do not construct weights here + build_layer_weights = None # type: ignore + + num_ip_layers = getattr(stream.unet, 'num_ip_layers', 
None) + if isinstance(num_ip_layers, int) and num_ip_layers > 0: + weights_tensor = None + try: + if build_layer_weights is not None: + weights_tensor = build_layer_weights(num_ip_layers, float(base_weight), weight_type) + except Exception: + weights_tensor = None + if weights_tensor is None: + import torch as _torch + weights_tensor = _torch.full((num_ip_layers,), float(base_weight), dtype=_torch.float32, device=stream.device) + # Apply per-step time factor + try: + weights_tensor = weights_tensor * float(time_factor) + except Exception: + pass + return UnetKwargsDelta(extra_unet_kwargs={'ipadapter_scale': weights_tensor}) + + # PyTorch UNet path: modulate installed processor scales by time factor + try: + if time_factor != 1.0 and hasattr(stream.pipe, 'unet') and hasattr(stream.pipe.unet, 'attn_processors'): + for proc in stream.pipe.unet.attn_processors.values(): + if hasattr(proc, 'scale') and hasattr(proc, '_ip_layer_index'): + base_val = getattr(proc, '_base_scale', proc.scale) + proc.scale = float(base_val) * float(time_factor) + except Exception: + pass + + return UnetKwargsDelta() + + return _unet_hook + diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index e954b33c..ec2df2f6 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -11,6 +11,7 @@ ) from streamdiffusion.model_detection import detect_model +from streamdiffusion.hooks import EmbedsCtx, StepCtx, UnetKwargsDelta, EmbeddingHook, UnetHook from streamdiffusion.image_filter import SimilarImageFilter from streamdiffusion.stream_parameter_updater import StreamParameterUpdater @@ -98,6 +99,13 @@ def __init__( # Initialize parameter updater self._param_updater = StreamParameterUpdater(self, normalize_prompt_weights, normalize_seed_weights) + # Default IP-Adapter runtime weight mode (None = uniform). Can be set to strings like + # "ease in", "ease out", "ease in-out", "reverse in-out", "style transfer precise", "composition precise". 
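# Illustrative sketch only, not part of the patch: at runtime the weight mode and base
# scale are expected to reach the pipeline through update_stream_params' ipadapter_config
# dict, which StreamParameterUpdater routes to _update_ipadapter_config. The 'weight_type'
# and 'style_image' keys appear in that handler; the 'scale' key and the wrapper variable
# name are assumptions here.
#
#     wrapper.update_stream_params(
#         ipadapter_config={"scale": 0.6, "weight_type": "ease in-out"},
#     )
#
# With a TensorRT UNet the IPAdapterModule UNet hook then supplies a per-layer vector via
# extra_unet_kwargs['ipadapter_scale']; with a PyTorch UNet the installed attention
# processors are rescaled in place each step.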
+ self.ipadapter_weight_type = None + + # Hook containers (step 1: introduced but initially no-op) + self.embedding_hooks: List[EmbeddingHook] = [] + self.unet_hooks: List[UnetHook] = [] def load_lcm_lora( self, @@ -209,8 +217,8 @@ def prepare( if len(encoder_output) >= 4: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encoder_output[:4] - # Set up prompt embeddings for the UNet - self.prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) + # Set up prompt embeddings for the UNet (base before hooks) + base_prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) # Handle CFG for prompt embeddings if self.use_denoising_batch and self.cfg_type == "full": @@ -221,8 +229,8 @@ def prepare( if self.guidance_scale > 1.0 and ( self.cfg_type == "initialize" or self.cfg_type == "full" ): - self.prompt_embeds = torch.cat( - [uncond_prompt_embeds, self.prompt_embeds], dim=0 + base_prompt_embeds = torch.cat( + [uncond_prompt_embeds, base_prompt_embeds], dim=0 ) # Set up SDXL-specific conditioning (added_cond_kwargs) @@ -245,6 +253,15 @@ def prepare( self.add_time_ids = add_time_ids else: raise ValueError(f"SDXL encode_prompt returned {len(encoder_output)} outputs, expected at least 4") + # Run embedding hooks (no-op unless modules register) + embeds_ctx = EmbedsCtx(prompt_embeds=base_prompt_embeds, negative_prompt_embeds=None) + for hook in self.embedding_hooks: + try: + embeds_ctx = hook(embeds_ctx) + except Exception as e: + logger.error(f"prepare: embedding hook failed: {e}") + raise + self.prompt_embeds = embeds_ctx.prompt_embeds else: # SD1.5/SD2.1 encode_prompt returns 2 values: (prompt_embeds, negative_prompt_embeds) encoder_output = self.pipe.encode_prompt( @@ -254,7 +271,7 @@ def prepare( do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, ) - self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) + base_prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) if self.use_denoising_batch and self.cfg_type == "full": uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1) @@ -264,10 +281,20 @@ def prepare( if self.guidance_scale > 1.0 and ( self.cfg_type == "initialize" or self.cfg_type == "full" ): - self.prompt_embeds = torch.cat( - [uncond_prompt_embeds, self.prompt_embeds], dim=0 + base_prompt_embeds = torch.cat( + [uncond_prompt_embeds, base_prompt_embeds], dim=0 ) + # Run embedding hooks (no-op unless modules register) + embeds_ctx = EmbedsCtx(prompt_embeds=base_prompt_embeds, negative_prompt_embeds=None) + for hook in self.embedding_hooks: + try: + embeds_ctx = hook(embeds_ctx) + except Exception as e: + logger.error(f"prepare: embedding hook failed: {e}") + raise + self.prompt_embeds = embeds_ctx.prompt_embeds + self.scheduler.set_timesteps(num_inference_steps, self.device) self.timesteps = self.scheduler.timesteps.to(self.device) @@ -358,75 +385,7 @@ def update_prompt(self, prompt: str) -> None: prompt_interpolation_method="linear" ) - @torch.no_grad() - def update_stream_params( - self, - num_inference_steps: Optional[int] = None, - guidance_scale: Optional[float] = None, - delta: Optional[float] = None, - t_index_list: Optional[List[int]] = None, - seed: Optional[int] = None, - # Prompt blending parameters - prompt_list: Optional[List[Tuple[str, float]]] = None, - negative_prompt: Optional[str] = None, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", - normalize_prompt_weights: Optional[bool] = None, - # Seed blending parameters - 
seed_list: Optional[List[Tuple[int, float]]] = None, - seed_interpolation_method: Literal["linear", "slerp"] = "linear", - normalize_seed_weights: Optional[bool] = None, - # IPAdapter parameters - ipadapter_config: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Update streaming parameters efficiently in a single call. - - Parameters - ---------- - num_inference_steps : Optional[int] - The number of inference steps to perform. - guidance_scale : Optional[float] - The guidance scale to use for CFG. - delta : Optional[float] - The delta multiplier of virtual residual noise. - t_index_list : Optional[List[int]] - The t_index_list to use for inference. - seed : Optional[int] - The random seed to use for noise generation. - prompt_list : Optional[List[Tuple[str, float]]] - List of prompts with weights for blending. - negative_prompt : Optional[str] - The negative prompt to apply to all blended prompts. - prompt_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between prompt embeddings. - normalize_prompt_weights : Optional[bool] - Whether to normalize prompt weights in blending to sum to 1, by default None (no change). - When False, weights > 1 will amplify embeddings. - seed_list : Optional[List[Tuple[int, float]]] - List of seeds with weights for blending. - seed_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between seed noise tensors. - normalize_seed_weights : Optional[bool] - Whether to normalize seed weights in blending to sum to 1, by default None (no change). - When False, weights > 1 will amplify noise. - ipadapter_config : Optional[Dict[str, Any]] - IPAdapter configuration dict containing scale, style_image, etc. - """ - self._param_updater.update_stream_params( - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - delta=delta, - t_index_list=t_index_list, - seed=seed, - prompt_list=prompt_list, - negative_prompt=negative_prompt, - prompt_interpolation_method=prompt_interpolation_method, - seed_list=seed_list, - seed_interpolation_method=seed_interpolation_method, - normalize_prompt_weights=normalize_prompt_weights, - normalize_seed_weights=normalize_seed_weights, - ipadapter_config=ipadapter_config, - ) + @@ -529,6 +488,46 @@ def unet_step( 'time_ids': add_time_ids } + # Allow modules to contribute additional UNet kwargs via hooks + try: + step_ctx = StepCtx( + x_t_latent=x_t_latent_plus_uc, + t_list=t_list, + step_index=idx if isinstance(idx, int) else (int(idx) if idx is not None else None), + guidance_mode=self.cfg_type if self.guidance_scale > 1.0 else "none", + sdxl_cond=unet_kwargs.get('added_cond_kwargs', None) + ) + extra_from_hooks = {} + for hook in self.unet_hooks: + delta: UnetKwargsDelta = hook(step_ctx) + if delta is None: + continue + if delta.down_block_additional_residuals is not None: + unet_kwargs['down_block_additional_residuals'] = delta.down_block_additional_residuals + if delta.mid_block_additional_residual is not None: + unet_kwargs['mid_block_additional_residual'] = delta.mid_block_additional_residual + if delta.added_cond_kwargs is not None: + # Merge SDXL cond if both exist + base_added = unet_kwargs.get('added_cond_kwargs', {}) + base_added.update(delta.added_cond_kwargs) + unet_kwargs['added_cond_kwargs'] = base_added + if getattr(delta, 'extra_unet_kwargs', None): + # Merge extra kwargs from hooks (e.g., ipadapter_scale) + try: + extra_from_hooks.update(delta.extra_unet_kwargs) + except Exception: + pass + if extra_from_hooks: + unet_kwargs['extra_unet_kwargs'] = 
extra_from_hooks + except Exception as e: + logger.error(f"unet_step: unet hook failed: {e}") + raise + + # Extract potential ControlNet residual kwargs and generic extra kwargs (e.g., ipadapter_scale) + hook_down_res = unet_kwargs.get('down_block_additional_residuals', None) + hook_mid_res = unet_kwargs.get('mid_block_additional_residual', None) + hook_extra_kwargs = unet_kwargs.get('extra_unet_kwargs', None) if 'extra_unet_kwargs' in unet_kwargs else None + # Call UNet with appropriate conditioning if self.is_sdxl: try: @@ -543,22 +542,43 @@ def unet_step( is_tensorrt_engine = hasattr(self.unet, 'engine') and hasattr(self.unet, 'stream') if is_tensorrt_engine: - # TensorRT engine expects positional args + kwargs + # TensorRT engine expects positional args + kwargs. IP-Adapter scale vector, if any, is provided by hooks via extra_unet_kwargs + extra_kwargs = {} + if isinstance(hook_extra_kwargs, dict): + extra_kwargs.update(hook_extra_kwargs) + + # Include ControlNet residuals if provided by hooks + if hook_down_res is not None: + extra_kwargs['down_block_additional_residuals'] = hook_down_res + if hook_mid_res is not None: + extra_kwargs['mid_block_additional_residual'] = hook_mid_res + + logger.debug(f"pipeline.unet_step: Calling TRT SDXL UNet with extra_kwargs keys={list(extra_kwargs.keys())}") model_pred = self.unet( unet_kwargs['sample'], # latent_model_input (positional) unet_kwargs['timestep'], # timestep (positional) unet_kwargs['encoder_hidden_states'], # encoder_hidden_states (positional) + **extra_kwargs, + # For TRT engines, ensure SDXL cond shapes match engine builds; if engine expects 81 tokens (77+4), append dummy image tokens when none **added_cond_kwargs # SDXL conditioning as kwargs )[0] else: - # PyTorch UNet expects diffusers-style named arguments - model_pred = self.unet( + # PyTorch UNet expects diffusers-style named arguments. 
Any processor scaling is handled by IP-Adapter hook + + call_kwargs = dict( sample=unet_kwargs['sample'], timestep=unet_kwargs['timestep'], encoder_hidden_states=unet_kwargs['encoder_hidden_states'], added_cond_kwargs=added_cond_kwargs, return_dict=False, - )[0] + ) + # Include ControlNet residuals if present + if hook_down_res is not None: + call_kwargs['down_block_additional_residuals'] = hook_down_res + if hook_mid_res is not None: + call_kwargs['mid_block_additional_residual'] = hook_mid_res + model_pred = self.unet(**call_kwargs)[0] + # No restoration for per-layer scale; next step will set again via updater/time factor except Exception as e: logger.error(f"[PIPELINE] unet_step: *** ERROR: SDXL UNet call failed: {e} ***") @@ -567,11 +587,27 @@ def unet_step( raise else: # For SD1.5/SD2.1, use the old calling convention for compatibility + # Build kwargs from hooks and include residuals + ip_scale_kw = {} + is_tensorrt_engine = hasattr(self.unet, 'engine') and hasattr(self.unet, 'stream') + if isinstance(hook_extra_kwargs, dict): + ip_scale_kw.update(hook_extra_kwargs) + + # PyTorch processor time scaling is handled by the IP-Adapter hook + + # Include ControlNet residuals if present + if hook_down_res is not None: + ip_scale_kw['down_block_additional_residuals'] = hook_down_res + if hook_mid_res is not None: + ip_scale_kw['mid_block_additional_residual'] = hook_mid_res + + logger.debug(f"pipeline.unet_step: Calling TRT SD1.5 UNet with keys={list(ip_scale_kw.keys())}") model_pred = self.unet( x_t_latent_plus_uc, t_list, encoder_hidden_states=self.prompt_embeds, return_dict=False, + **ip_scale_kw, )[0] # Check for problematic values in model prediction @@ -583,6 +619,7 @@ def unet_step( logger.error(f"[PIPELINE] unet_step: *** ERROR: {inf_count} Inf values in model_pred! ***") if (model_pred == 0).all(): logger.error(f"[PIPELINE] unet_step: *** ERROR: All model_pred values are zero! ***") + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): noise_pred_text = model_pred[1:] diff --git a/src/streamdiffusion/preprocessing/orchestrator_user.py b/src/streamdiffusion/preprocessing/orchestrator_user.py new file mode 100644 index 00000000..7457cd67 --- /dev/null +++ b/src/streamdiffusion/preprocessing/orchestrator_user.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Optional + +from .preprocessing_orchestrator import PreprocessingOrchestrator + + +class OrchestratorUser: + """ + Minimal base class to attach a shared PreprocessingOrchestrator from the stream. + No convenience methods; strictly enforces presence of a shared orchestrator on stream. 
+ """ + + _preprocessing_orchestrator: Optional[PreprocessingOrchestrator] = None + + def attach_orchestrator(self, stream) -> None: + orchestrator = getattr(stream, 'preprocessing_orchestrator', None) + if orchestrator is None: + # Lazy-create on stream once, on first user that needs it + orchestrator = PreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4) + setattr(stream, 'preprocessing_orchestrator', orchestrator) + self._preprocessing_orchestrator = orchestrator + + diff --git a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py index 9422c43d..4605def8 100644 --- a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py @@ -35,6 +35,10 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max # Pipeline state for embedding preprocessing self._next_embedding_future = None self._next_embedding_result = None + + # Cache pipelining decision to avoid hot path checks + self._preprocessors_cache_key = None + self._has_feedback_cache = False def cleanup(self) -> None: """Cleanup thread pool resources""" @@ -85,20 +89,31 @@ def process_control_images_pipelined(self, stream_width: int, stream_height: int) -> List[Optional[torch.Tensor]]: """ - Process control images with inter-frame pipelining for improved performance. + Process control images with intelligent pipelining. + + Automatically falls back to sync processing when feedback preprocessors are detected + to avoid temporal artifacts, otherwise uses pipelined processing for performance. Returns: List of processed tensors for each ControlNet """ - # Wait for previous frame preprocessing + # Check for feedback preprocessors that require sync processing (cached) + has_feedback = self._check_feedback_cached(preprocessors) + if has_feedback: + return self.process_control_images_sync( + control_image, preprocessors, scales, stream_width, stream_height + ) + + # No feedback preprocessors detected - use pipelined processing + # Wait for previous frame preprocessing; non-blocking with short timeout self._wait_for_previous_preprocessing() - # Start next frame preprocessing in background + # Start next frame preprocessing in background using intraframe parallelism self._start_next_frame_preprocessing( control_image, preprocessors, scales, stream_width, stream_height ) - # Apply current frame preprocessing results + # Apply current frame preprocessing results if available; otherwise signal no update return self._apply_current_frame_preprocessing(preprocessors, scales) def prepare_control_image(self, @@ -493,8 +508,13 @@ def _process_tensor_input(self, if preprocessor is not None and hasattr(preprocessor, 'process_tensor'): try: processed_tensor = preprocessor.process_tensor(control_tensor) + # Ensure NCHW shape if processed_tensor.dim() == 3: processed_tensor = processed_tensor.unsqueeze(0) + # Resize to target spatial resolution if needed to match stream dimensions + processed_tensor = self._resize_tensor_if_needed( + processed_tensor, target_width, target_height + ) return processed_tensor.to(device=self.device, dtype=self.dtype) except Exception: pass # Fall through to standard processing @@ -734,14 +754,46 @@ def _process_single_preprocessor_optimized(self, logger.error(f"PreprocessingOrchestrator: Preprocessor {preprocessor_key} failed: {e}") return None + def _check_feedback_cached(self, preprocessors: List[Optional[Any]]) -> bool: + 
""" + _check_feedback_cached: Efficiently check for feedback preprocessors using caching + + Only performs expensive isinstance checks when preprocessor list actually changes. + """ + # Create cache key from preprocessor identities + cache_key = tuple(id(p) for p in preprocessors) + + # Return cached result if preprocessors haven't changed + if cache_key == self._preprocessors_cache_key: + return self._has_feedback_cache + + # Preprocessors changed - recompute and cache + self._preprocessors_cache_key = cache_key + self._has_feedback_cache = False + + try: + from .processors.feedback import FeedbackPreprocessor + for prep in preprocessors: + if isinstance(prep, FeedbackPreprocessor): + self._has_feedback_cache = True + break + except Exception: + # Fallback on class name check without importing + for prep in preprocessors: + if prep is not None and prep.__class__.__name__.lower().startswith('feedback'): + self._has_feedback_cache = True + break + + return self._has_feedback_cache + def _wait_for_previous_preprocessing(self) -> None: """Wait for previous frame preprocessing with optimized timeout""" if hasattr(self, '_next_frame_future') and self._next_frame_future is not None: try: - # Reduced timeout: 50ms for real-time performance - self._next_frame_result = self._next_frame_future.result(timeout=0.05) + # Reduced timeout: 10ms for lower latency in real-time + self._next_frame_result = self._next_frame_future.result(timeout=0.01) except concurrent.futures.TimeoutError: - logger.warning("PreprocessingOrchestrator: Preprocessing timeout - using previous results") + # Non-blocking: skip applying results this frame self._next_frame_result = None except Exception as e: logger.error(f"PreprocessingOrchestrator: Preprocessing error: {e}") diff --git a/src/streamdiffusion/preprocessing/processors/canny.py b/src/streamdiffusion/preprocessing/processors/canny.py index 7bb9edbf..7c25e9ab 100644 --- a/src/streamdiffusion/preprocessing/processors/canny.py +++ b/src/streamdiffusion/preprocessing/processors/canny.py @@ -5,7 +5,7 @@ from typing import Union from .base import BasePreprocessor - +#TODO provide gpu native edge detection class CannyPreprocessor(BasePreprocessor): """ Canny edge detection preprocessor for ControlNet diff --git a/src/streamdiffusion/preprocessing/processors/lineart.py b/src/streamdiffusion/preprocessing/processors/lineart.py index 7bab15d6..3af6011b 100644 --- a/src/streamdiffusion/preprocessing/processors/lineart.py +++ b/src/streamdiffusion/preprocessing/processors/lineart.py @@ -12,7 +12,7 @@ CONTROLNET_AUX_AVAILABLE = False raise ImportError("LineartPreprocessor: controlnet_aux is required for real-time optimization. 
Install with: pip install controlnet_aux") - +#TODO provide gpu native lineart detection class LineartPreprocessor(BasePreprocessor): """ Real-time optimized Lineart detection preprocessor for ControlNet diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index d58eea92..fc7e9f4e 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -1,10 +1,12 @@ from typing import List, Optional, Dict, Tuple, Literal, Any, Callable +import threading import torch import torch.nn.functional as F import gc import logging logger = logging.getLogger(__name__) +from .preprocessing.orchestrator_user import OrchestratorUser class CacheStats: """Helper class to track cache statistics""" @@ -19,12 +21,14 @@ def record_miss(self): self.misses += 1 -class StreamParameterUpdater: +class StreamParameterUpdater(OrchestratorUser): def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True): self.stream = stream_diffusion self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure self.normalize_prompt_weights = normalize_prompt_weights self.normalize_seed_weights = normalize_seed_weights + # Atomic update lock for deterministic, thread-safe runtime updates + self._update_lock = threading.RLock() # Prompt blending caches self._prompt_cache: Dict[int, Dict] = {} self._current_prompt_list: List[Tuple[str, float]] = [] @@ -36,14 +40,16 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._current_seed_list: List[Tuple[int, float]] = [] self._seed_cache_stats = CacheStats() - # Enhancement hooks (e.g., for IPAdapter) - self._embedding_enhancers = [] + # Attach shared orchestrator once (lazy-creates on stream if absent) + self.attach_orchestrator(self.stream) + # IPAdapter embedding preprocessing self._embedding_preprocessors = [] self._embedding_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} self._current_style_images: Dict[str, Any] = {} - self._embedding_orchestrator = None + # Use the shared orchestrator attached via OrchestratorUser + self._embedding_orchestrator = self._preprocessing_orchestrator def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses @@ -88,44 +94,7 @@ def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self.normalize_seed_weights - def register_embedding_enhancer(self, enhancer_func, name: str = "unknown") -> None: - """ - Register an embedding enhancer function that will be called after prompt blending. 
- - The enhancer function should have signature: - enhancer_func(prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] - - Args: - enhancer_func: Function that takes (prompt_embeds, negative_prompt_embeds) and returns enhanced versions - name: Optional name for the enhancer (for debugging) - """ - self._embedding_enhancers.append((enhancer_func, name)) - # IMMEDIATELY apply enhancer to existing embeddings if they exist (fixes TensorRT timing issue) - if hasattr(self.stream, 'prompt_embeds') and self.stream.prompt_embeds is not None: - try: - current_negative_embeds = getattr(self.stream, 'negative_prompt_embeds', None) - enhanced_prompt_embeds, enhanced_negative_embeds = enhancer_func( - self.stream.prompt_embeds, current_negative_embeds - ) - self.stream.prompt_embeds = enhanced_prompt_embeds - if enhanced_negative_embeds is not None: - self.stream.negative_prompt_embeds = enhanced_negative_embeds - except Exception as e: - print(f"register_embedding_enhancer: Error applying '{name}' enhancer immediately: {e}") - import traceback - traceback.print_exc() - - def unregister_embedding_enhancer(self, enhancer_func) -> None: - """Unregister a specific embedding enhancer function.""" - original_length = len(self._embedding_enhancers) - self._embedding_enhancers = [(func, name) for func, name in self._embedding_enhancers if func != enhancer_func] - removed_count = original_length - len(self._embedding_enhancers) - - - def clear_embedding_enhancers(self) -> None: - """Clear all embedding enhancers.""" - enhancer_count = len(self._embedding_enhancers) - self._embedding_enhancers.clear() + # Deprecated enhancer registration removed; embedding composition is handled via stream.embedding_hooks def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: str) -> None: """ @@ -136,12 +105,9 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st style_image_key: Unique key for the style image this preprocessor handles """ if self._embedding_orchestrator is None: - from .preprocessing.preprocessing_orchestrator import PreprocessingOrchestrator - self._embedding_orchestrator = PreprocessingOrchestrator( - device=self.stream.device, - dtype=self.stream.dtype, - max_workers=4 - ) + # Ensure orchestrator is present + self.attach_orchestrator(self.stream) + self._embedding_orchestrator = self._preprocessing_orchestrator self._embedding_preprocessors.append((preprocessor, style_image_key)) @@ -282,60 +248,61 @@ def update_stream_params( ) -> None: """Update streaming parameters efficiently in a single call.""" - if num_inference_steps is not None: - self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device) - self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device) + with self._update_lock: + if num_inference_steps is not None: + self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device) + self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device) - if num_inference_steps is not None and t_index_list is None: - max_step = num_inference_steps - 1 - t_index_list = [min(t, max_step) for t in self.stream.t_list] + if num_inference_steps is not None and t_index_list is None: + max_step = num_inference_steps - 1 + t_index_list = [min(t, max_step) for t in self.stream.t_list] - if guidance_scale is not None: - if self.stream.cfg_type == "none" and guidance_scale > 1.0: - logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with 
cfg_type='none' will have no effect") - self.stream.guidance_scale = guidance_scale + if guidance_scale is not None: + if self.stream.cfg_type == "none" and guidance_scale > 1.0: + logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect") + self.stream.guidance_scale = guidance_scale - if delta is not None: - self.stream.delta = delta + if delta is not None: + self.stream.delta = delta - if seed is not None: - self._update_seed(seed) - - if normalize_prompt_weights is not None: - self.normalize_prompt_weights = normalize_prompt_weights - logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}") - - if normalize_seed_weights is not None: - self.normalize_seed_weights = normalize_seed_weights - logger.info(f"update_stream_params: Seed weight normalization set to {normalize_seed_weights}") - - # Handle prompt blending if prompt_list is provided - if prompt_list is not None: - self._update_blended_prompts( - prompt_list=prompt_list, - negative_prompt=negative_prompt or self._current_negative_prompt, - prompt_interpolation_method=prompt_interpolation_method - ) + if seed is not None: + self._update_seed(seed) + + if normalize_prompt_weights is not None: + self.normalize_prompt_weights = normalize_prompt_weights + logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}") + + if normalize_seed_weights is not None: + self.normalize_seed_weights = normalize_seed_weights + logger.info(f"update_stream_params: Seed weight normalization set to {normalize_seed_weights}") + + # Handle prompt blending if prompt_list is provided + if prompt_list is not None: + self._update_blended_prompts( + prompt_list=prompt_list, + negative_prompt=negative_prompt or self._current_negative_prompt, + prompt_interpolation_method=prompt_interpolation_method + ) - # Handle seed blending if seed_list is provided - if seed_list is not None: - self._update_blended_seeds( - seed_list=seed_list, - interpolation_method=seed_interpolation_method - ) + # Handle seed blending if seed_list is provided + if seed_list is not None: + self._update_blended_seeds( + seed_list=seed_list, + interpolation_method=seed_interpolation_method + ) - if t_index_list is not None: - self._recalculate_timestep_dependent_params(t_index_list) + if t_index_list is not None: + self._recalculate_timestep_dependent_params(t_index_list) - # Handle ControlNet configuration updates - if controlnet_config is not None: - logger.info(f"update_stream_params: Updating ControlNet configuration with {len(controlnet_config)} controlnets") - self._update_controlnet_config(controlnet_config) - - # Handle IPAdapter configuration updates - if ipadapter_config is not None: - logger.info(f"update_stream_params: Updating IPAdapter configuration") - self._update_ipadapter_config(ipadapter_config) + # Handle ControlNet configuration updates + if controlnet_config is not None: + #TODO: happy path for control images + self._update_controlnet_config(controlnet_config) + + # Handle IPAdapter configuration updates + if ipadapter_config is not None: + logger.info(f"update_stream_params: Updating IPAdapter configuration") + self._update_ipadapter_config(ipadapter_config) @torch.no_grad() def update_prompt_weights( @@ -486,20 +453,23 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", final_prompt_embeds = combined_embeds.repeat(self.stream.batch_size, 1, 1) final_negative_embeds = None # Will be set by enhancers if 
needed - # Apply embedding enhancers (e.g., IPAdapter) - if self._embedding_enhancers: - for enhancer_func, enhancer_name in self._embedding_enhancers: - try: - enhanced_prompt_embeds, enhanced_negative_embeds = enhancer_func( - final_prompt_embeds, final_negative_embeds - ) - final_prompt_embeds = enhanced_prompt_embeds - if enhanced_negative_embeds is not None: - final_negative_embeds = enhanced_negative_embeds - except Exception as e: - print(f"_apply_prompt_blending: Error in enhancer '{enhancer_name}': {e}") - import traceback - traceback.print_exc() + # Enhancer mechanism removed in favor of embedding_hooks + + # Run embedding hooks to compose final embeddings (e.g., append IP-Adapter tokens) + try: + if hasattr(self.stream, 'embedding_hooks') and self.stream.embedding_hooks: + from .hooks import EmbedsCtx # local import to avoid cycles + embeds_ctx = EmbedsCtx( + prompt_embeds=final_prompt_embeds, + negative_prompt_embeds=final_negative_embeds, + ) + for hook in self.stream.embedding_hooks: + embeds_ctx = hook(embeds_ctx) + final_prompt_embeds = embeds_ctx.prompt_embeds + final_negative_embeds = embeds_ctx.negative_prompt_embeds + except Exception as e: + import logging + logging.getLogger(__name__).error(f"_apply_prompt_blending: embedding hook failed: {e}") # Set final embeddings on stream self.stream.prompt_embeds = final_prompt_embeds @@ -994,7 +964,7 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non desired_config: Complete ControlNet configuration list defining the desired state. Each dict contains: model_id, preprocessor, conditioning_scale, enabled, etc. """ - # Find the ControlNet pipeline (might be nested in IPAdapter) + # Find the ControlNet pipeline/module (module-aware) controlnet_pipeline = self._get_controlnet_pipeline() if not controlnet_pipeline: logger.warning(f"_update_controlnet_config: No ControlNet pipeline found") @@ -1006,12 +976,26 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} desired_models = {cfg['model_id']: cfg for cfg in desired_config} + # Reorder to match desired order (module supports stable reordering) + try: + desired_order = [cfg['model_id'] for cfg in desired_config if 'model_id' in cfg] + if hasattr(controlnet_pipeline, 'reorder_controlnets_by_model_ids'): + controlnet_pipeline.reorder_controlnets_by_model_ids(desired_order) + except Exception: + pass + + # Recompute current models after potential reorder + current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} + # Remove controlnets not in desired config for i in reversed(range(len(controlnet_pipeline.controlnets))): model_id = current_models.get(i, f'controlnet_{i}') if model_id not in desired_models: logger.info(f"_update_controlnet_config: Removing ControlNet {model_id}") - controlnet_pipeline.remove_controlnet(i, immediate=False) + try: + controlnet_pipeline.remove_controlnet(i) + except Exception: + raise # Add new controlnets and update existing ones for desired_cfg in desired_config: @@ -1020,8 +1004,40 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non if existing_index is None: # Add new controlnet + # Respect wrapper/init configuration: block adds when parallel enabled + try: + block_add = bool(getattr(self.stream, 'controlnet_block_add_when_parallel', True)) + except Exception: + block_add = True 
+ concurrency_active = False + try: + cn_module = getattr(self.stream, '_controlnet_module', None) + if cn_module is not None: + max_par = int(getattr(cn_module, '_max_parallel_controlnets', 0)) + concurrency_active = max_par > 1 + except Exception: + concurrency_active = False + if block_add and concurrency_active: + logger.warning(f"_update_controlnet_config: Add blocked by configuration while parallel ControlNet is active; skipping add for {model_id}") + continue logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}") - controlnet_pipeline.add_controlnet(desired_cfg, desired_cfg.get('control_image'), immediate=False) + try: + # Prefer module path: construct ControlNetConfig + try: + from .modules.controlnet_module import ControlNetConfig # type: ignore + cn_cfg = ControlNetConfig( + model_id=desired_cfg.get('model_id'), + preprocessor=desired_cfg.get('preprocessor'), + conditioning_scale=desired_cfg.get('conditioning_scale', 1.0), + enabled=desired_cfg.get('enabled', True), + preprocessor_params=desired_cfg.get('preprocessor_params'), + ) + controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get('control_image')) + except Exception: + # No fallback + raise + except Exception as e: + logger.error(f"_update_controlnet_config: add_controlnet failed for {model_id}: {e}") else: # Update existing controlnet if 'conditioning_scale' in desired_cfg: @@ -1030,8 +1046,14 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non if current_scale != desired_scale: logger.info(f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}") - controlnet_pipeline.update_controlnet_scale(existing_index, desired_scale) + if hasattr(controlnet_pipeline, 'controlnet_scales') and 0 <= existing_index < len(controlnet_pipeline.controlnet_scales): + controlnet_pipeline.controlnet_scales[existing_index] = float(desired_scale) + # Enable/disable toggle + if 'enabled' in desired_cfg and hasattr(controlnet_pipeline, 'enabled_list'): + if 0 <= existing_index < len(controlnet_pipeline.enabled_list): + controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg['enabled']) + if 'preprocessor_params' in desired_cfg and hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[existing_index]: preprocessor = controlnet_pipeline.preprocessors[existing_index] preprocessor.params.update(desired_cfg['preprocessor_params']) @@ -1039,28 +1061,41 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non if hasattr(preprocessor, param_name): setattr(preprocessor, param_name, param_value) + # Efficient control image update when provided + if 'control_image' in desired_cfg and desired_cfg['control_image'] is not None: + try: + # Route through module helper if available + if hasattr(controlnet_pipeline, 'update_control_image_efficient'): + controlnet_pipeline.update_control_image_efficient(desired_cfg['control_image'], index=existing_index) + else: + # Fallback to orchestrator-based processing if present on module + if hasattr(controlnet_pipeline, '_prepare_control_image') and hasattr(controlnet_pipeline, 'preprocessors') and hasattr(controlnet_pipeline, 'controlnet_images'): + preproc = controlnet_pipeline.preprocessors[existing_index] if existing_index < len(controlnet_pipeline.preprocessors) else None + processed = controlnet_pipeline._prepare_control_image(desired_cfg['control_image'], preproc) + if existing_index < len(controlnet_pipeline.controlnet_images): + 
controlnet_pipeline.controlnet_images[existing_index] = processed + except Exception: + raise + def _get_controlnet_pipeline(self): """ - Get the ControlNet pipeline from the pipeline structure (handles IPAdapter wrapping). - - Returns: - ControlNet pipeline object or None if not found + Get the ControlNet module or legacy pipeline from the structure (module-aware). """ - # Check if stream is ControlNet pipeline directly + # Module-installed path + if hasattr(self.stream, '_controlnet_module'): + return self.stream._controlnet_module + # Legacy paths if hasattr(self.stream, 'controlnets'): return self.stream - - # Check if stream has nested stream (IPAdapter wrapper) if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'controlnets'): return self.stream.stream - - # Check if we have a wrapper reference and can access through it if self.wrapper and hasattr(self.wrapper, 'stream'): + if hasattr(self.wrapper.stream, '_controlnet_module'): + return self.wrapper.stream._controlnet_module if hasattr(self.wrapper.stream, 'controlnets'): return self.wrapper.stream - elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'): + if hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'): return self.wrapper.stream.stream - return None def _get_current_controlnet_config(self) -> List[Dict[str, Any]]: @@ -1078,11 +1113,17 @@ def _get_current_controlnet_config(self) -> List[Dict[str, Any]]: for i, controlnet in enumerate(controlnet_pipeline.controlnets): model_id = getattr(controlnet, 'model_id', f'controlnet_{i}') scale = controlnet_pipeline.controlnet_scales[i] if hasattr(controlnet_pipeline, 'controlnet_scales') and i < len(controlnet_pipeline.controlnet_scales) else 1.0 - + enabled_val = True + try: + if hasattr(controlnet_pipeline, 'enabled_list') and i < len(controlnet_pipeline.enabled_list): + enabled_val = bool(controlnet_pipeline.enabled_list[i]) + except Exception: + enabled_val = True config = { 'model_id': model_id, 'conditioning_scale': scale, - 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {} + 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {}, + 'enabled': enabled_val, } current_config.append(config) @@ -1110,7 +1151,30 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: if current_scale != desired_scale: logger.info(f"_update_ipadapter_config: Updating scale: {current_scale} → {desired_scale}") - ipadapter_pipeline.update_scale(desired_scale) + # If a weight_type is active, apply per-layer vector at the new base scale + try: + weight_type = getattr(self.stream, 'ipadapter_weight_type', None) + if weight_type is not None and hasattr(ipadapter_pipeline, 'ipadapter') and ipadapter_pipeline.ipadapter is not None: + from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + num_layers = len(ip_procs) + weights = build_layer_weights(num_layers, float(desired_scale), weight_type) + if weights is not None: + ipadapter_pipeline.ipadapter.set_scale(weights) + else: + ipadapter_pipeline.ipadapter.set_scale(float(desired_scale)) + # Keep pipeline/stream scales in sync + ipadapter_pipeline.scale 
= float(desired_scale) + try: + setattr(self.stream, 'ipadapter_scale', float(desired_scale)) + except Exception: + pass + else: + # No weight_type: uniform scale + ipadapter_pipeline.update_scale(desired_scale) + except Exception: + # Do not introduce fallback mechanisms + raise # Update style image if provided if 'style_image' in desired_config: @@ -1119,6 +1183,33 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: logger.info(f"_update_ipadapter_config: Updating style image") ipadapter_pipeline.update_style_image(style_image) + # Update weight type if provided (affects per-layer distribution and/or per-step factor) + if 'weight_type' in desired_config: + weight_type = desired_config['weight_type'] + try: + setattr(self.stream, 'ipadapter_weight_type', weight_type) + except Exception: + pass + # For PyTorch UNet, immediately apply a per-layer scale vector so layers reflect selection types + try: + is_tensorrt_engine = hasattr(self.stream.unet, 'engine') and hasattr(self.stream.unet, 'stream') + if not is_tensorrt_engine and hasattr(ipadapter_pipeline, 'ipadapter') and ipadapter_pipeline.ipadapter is not None: + # Compute per-layer vector using Diffusers_IPAdapter helper + from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + # Count installed IP layers by scanning processors with _ip_layer_index + ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + num_layers = len(ip_procs) + base_weight = float(getattr(self.stream, 'ipadapter_scale', getattr(ipadapter_pipeline, 'scale', 1.0))) + weights = build_layer_weights(num_layers, base_weight, weight_type) + # If None, keep uniform base scale; else set per-layer vector + if weights is not None: + ipadapter_pipeline.ipadapter.set_scale(weights) + else: + ipadapter_pipeline.ipadapter.set_scale(base_weight) + except Exception: + # Do not add fallback mechanisms + raise + def _get_ipadapter_pipeline(self): """ Get the IPAdapter pipeline from the pipeline structure (following ControlNet pattern). diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index c0bcb973..45bd8478 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -110,6 +110,9 @@ def __init__( # IPAdapter options use_ipadapter: bool = False, ipadapter_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + # Concurrency options + controlnet_max_parallel: Optional[int] = None, + controlnet_block_add_when_parallel: bool = True, ): """ Initializes the StreamDiffusionWrapper. @@ -198,6 +201,9 @@ def __init__( self.enable_pytorch_fallback = enable_pytorch_fallback self.use_ipadapter = use_ipadapter self.ipadapter_config = ipadapter_config + # Concurrency settings + self.controlnet_max_parallel = controlnet_max_parallel + self.controlnet_block_add_when_parallel = controlnet_block_add_when_parallel if mode == "txt2img": if cfg_type != "none": @@ -375,6 +381,8 @@ def update_prompt( Supports both single prompts and prompt blending based on the prompt parameter type. 
+ This is for legacy compatibility, use update_stream_params instead + Parameters ---------- prompt : Union[str, List[Tuple[str, float]]] @@ -508,16 +516,6 @@ def update_stream_params( ipadapter_config=ipadapter_config, ) - def get_normalize_prompt_weights(self) -> bool: - """Get the current prompt weight normalization setting.""" - return self.stream.get_normalize_prompt_weights() - - def get_normalize_seed_weights(self) -> bool: - """Get the current seed weight normalization setting.""" - return self.stream.get_normalize_seed_weights() - - - def __call__( self, image: Optional[Union[str, Image.Image, torch.Tensor]] = None, @@ -700,7 +698,6 @@ def postprocess_image( else: return postprocess_image(image_tensor.cpu(), output_type=output_type)[0] - def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Denormalize image tensor on GPU for efficiency @@ -715,7 +712,6 @@ def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: """ return (image_tensor / 2 + 0.5).clamp(0, 1) - def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Image]: """ Optimized tensor to PIL conversion with minimal CPU transfers @@ -763,8 +759,6 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima return pil_images - - def _load_model( self, model_id_or_path: str, @@ -1016,20 +1010,9 @@ def _load_model( # Use the explicit use_ipadapter parameter has_ipadapter = use_ipadapter - # Create IPAdapter pipeline and pre-load models for TensorRT if needed - ipadapter_pipeline = None - if has_ipadapter: - try: - from streamdiffusion.ipadapter import BaseIPAdapterPipeline - ipadapter_pipeline = BaseIPAdapterPipeline( - stream_diffusion=stream, - device=self.device, - dtype=self.dtype - ) - ipadapter_pipeline.preload_models_for_tensorrt(ipadapter_config) - except Exception as e: - print(f"_load_model: Error creating IPAdapter pipeline: {e}") - has_ipadapter = False + # Determine IP-Adapter presence and token count directly from config (no legacy pipeline) + if has_ipadapter and not ipadapter_config: + has_ipadapter = False try: # Use model detection results already computed during model loading @@ -1102,18 +1085,14 @@ def _load_model( # Use the engine_dir parameter passed to this function, with fallback to instance variable engine_dir = engine_dir if engine_dir else getattr(self, '_engine_dir', 'engines') - # Get IPAdapter information from pipeline if available + # Resolve IP-Adapter runtime params from config + # Strength is now a runtime input, so we do NOT bake scale into engine identity ipadapter_scale = None ipadapter_tokens = None - if use_ipadapter_trt and ipadapter_pipeline: - tensorrt_info = ipadapter_pipeline.get_tensorrt_info() - ipadapter_scale = tensorrt_info.get('scale', 1.0) - - # Read token count from loaded IPAdapter instance - if hasattr(ipadapter_pipeline, 'ipadapter') and ipadapter_pipeline.ipadapter: - ipadapter_tokens = getattr(ipadapter_pipeline.ipadapter, 'num_tokens', 4) - else: - ipadapter_tokens = 4 # Default fallback + if use_ipadapter_trt and has_ipadapter and ipadapter_config: + cfg0 = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config + # scale omitted from engine naming; runtime will pass ipadapter_scale vector + ipadapter_tokens = cfg0.get('num_image_tokens', 4) # Generate engine paths using EngineManager unet_path = engine_manager.get_engine_path( EngineType.UNET, @@ -1185,13 +1164,67 @@ def _load_model( num_tokens = 4 # Default for non-IPAdapter mode if use_ipadapter_trt: - if 
not (ipadapter_pipeline and hasattr(ipadapter_pipeline, 'ipadapter') and ipadapter_pipeline.ipadapter): - raise RuntimeError("IPAdapter TensorRT enabled but IPAdapter failed to load. Cannot proceed without proper IPAdapter instance.") - num_tokens = getattr(ipadapter_pipeline.ipadapter, 'num_tokens', 4) + # Use token count resolved from configuration (default to 4) + num_tokens = ipadapter_tokens if isinstance(ipadapter_tokens, int) else 4 # Compile UNet engine using EngineManager logger.info(f"compile_and_load_engine: Compiling UNet engine for image size: {self.width}x{self.height}") + try: + logger.debug(f"compile_and_load_engine: use_ipadapter_trt={use_ipadapter_trt}, num_ip_layers={num_ip_layers}, tokens={num_tokens}") + except Exception: + pass + # If using TensorRT with IP-Adapter, ensure processors and weights are installed BEFORE export + if use_ipadapter_trt and has_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): + try: + from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig + cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config + ip_cfg = IPAdapterConfig( + style_image_key=cfg.get('style_image_key') or 'ipadapter_main', + num_image_tokens=cfg.get('num_image_tokens', 4), + ipadapter_model_path=cfg['ipadapter_model_path'], + image_encoder_path=cfg['image_encoder_path'], + style_image=cfg.get('style_image'), + scale=cfg.get('scale', 1.0), + ) + ip_module_for_export = IPAdapterModule(ip_cfg) + ip_module_for_export.install(stream) + setattr(stream, '_ipadapter_module', ip_module_for_export) + try: + logger.info("Installed IP-Adapter processors prior to TensorRT export") + except Exception: + pass + except Exception: + import traceback + traceback.print_exc() + logger.error("Failed to pre-install IP-Adapter prior to TensorRT export") + + # NOTE: When IPAdapter is enabled, we must pass num_ip_layers. We cannot know it until after + # installing processors in the export wrapper. We construct the wrapper first to discover it, + # then construct UNet model with that value. 
+ + # Build a temporary unified wrapper to install processors and discover num_ip_layers + from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_unified_export import UnifiedExportWrapper + temp_wrapped_unet = UnifiedExportWrapper( + stream.unet, + use_controlnet=use_controlnet_trt, + use_ipadapter=use_ipadapter_trt, + control_input_names=None, + num_tokens=num_tokens + ) + + num_ip_layers = None + if use_ipadapter_trt: + # Access underlying IPAdapter wrapper + if hasattr(temp_wrapped_unet, 'ipadapter_wrapper') and temp_wrapped_unet.ipadapter_wrapper: + num_ip_layers = getattr(temp_wrapped_unet.ipadapter_wrapper, 'num_ip_layers', None) + if not isinstance(num_ip_layers, int) or num_ip_layers <= 0: + raise RuntimeError("Failed to determine num_ip_layers for IP-Adapter") + try: + logger.info(f"compile_and_load_engine: discovered num_ip_layers={num_ip_layers}") + except Exception: + pass + unet_model = UNet( fp16=True, device=stream.device, @@ -1203,17 +1236,19 @@ def _load_model( unet_arch=unet_arch if use_controlnet_trt else None, use_ipadapter=use_ipadapter_trt, num_image_tokens=num_tokens, + num_ip_layers=num_ip_layers if use_ipadapter_trt else None, image_height=self.height, image_width=self.width, ) # Use ControlNet wrapper if ControlNet support is enabled if use_controlnet_trt: - control_input_names = unet_model.get_input_names() - - # Unified compilation path - from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_unified_export import UnifiedExportWrapper + # Build control_input_names excluding ipadapter_scale so indices align to 3-base offset + all_input_names = unet_model.get_input_names() + control_input_names = [name for name in all_input_names if name != 'ipadapter_scale'] + # Unified compilation path + # Recreate wrapped_unet with control input names if needed (after unet_model is ready) wrapped_unet = UnifiedExportWrapper( stream.unet, use_controlnet=use_controlnet_trt, @@ -1290,6 +1325,7 @@ def _load_model( use_controlnet_trt=use_controlnet_trt, use_ipadapter_trt=use_ipadapter_trt, unet_arch=unet_arch, + num_ip_layers=num_ip_layers if use_ipadapter_trt else None, engine_build_options={ 'opt_image_height': self.height, 'opt_image_width': self.width, @@ -1446,146 +1482,112 @@ def _load_model( # Use stream's current resolution for fallback image self.nsfw_fallback_img = Image.new("RGB", (stream.height, stream.width), (0, 0, 0)) - # Apply ControlNet patch if needed + # Install modules via hooks instead of patching (wrapper keeps forwarding updates only) if use_controlnet and controlnet_config: - # Pass engine_manager and cuda_stream if TensorRT is being used - if acceleration == "tensorrt": - stream = self._apply_controlnet_patch(stream, controlnet_config, acceleration, engine_dir, self._detected_model_type, self._is_sdxl, engine_manager, cuda_stream) - else: - stream = self._apply_controlnet_patch(stream, controlnet_config, acceleration, engine_dir, self._detected_model_type, self._is_sdxl) - - # Apply IPAdapter patch if needed (after ControlNet) - if use_ipadapter and ipadapter_config: - self._apply_ipadapter_patch(stream, ipadapter_config) - - return stream - - def _apply_controlnet_patch(self, stream: StreamDiffusion, controlnet_config: Union[Dict[str, Any], List[Dict[str, Any]]], acceleration: str = "none", engine_dir: str = "engines", model_type: str = "SD15", is_sdxl: bool = False, engine_manager = None, cuda_stream = None) -> Any: - """ - Apply ControlNet patch to StreamDiffusion using detected model type - - Args: - stream: Base StreamDiffusion 
instance - controlnet_config: ControlNet configuration(s) - model_type: Detected model type from original UNet - - Returns: - ControlNet-enabled pipeline (ControlNetPipeline or SDXLTurboControlNetPipeline) - """ - # Use provided model type (detected before TensorRT conversion) - if is_sdxl: - from streamdiffusion.controlnet.controlnet_sdxlturbo_pipeline import SDXLTurboControlNetPipeline - controlnet_pipeline = SDXLTurboControlNetPipeline(stream, self.device, self.dtype) - else: # SD15, SD21, etc. all use same ControlNet pipeline - from streamdiffusion.controlnet.controlnet_pipeline import ControlNetPipeline - controlnet_pipeline = ControlNetPipeline(stream, self.device, self.dtype) - - # Check if we should use TensorRT ControlNet acceleration - use_controlnet_tensorrt = (acceleration == "tensorrt") - - # Set the detected model type to avoid re-detection from TensorRT engine - controlnet_pipeline._detected_model_type = model_type - controlnet_pipeline._is_sdxl = is_sdxl - - # Initialize ControlNet engine management if using TensorRT acceleration - if use_controlnet_tensorrt and engine_manager is not None: - from streamdiffusion.acceleration.tensorrt.engine_manager import EngineType - # Use the same unified EngineManager for ControlNet engines - # Create a ControlNet-specific subdirectory for organization - controlnet_engine_dir = os.path.join(engine_dir, "controlnet") - - # Store unified engine manager on the pipeline for later use - controlnet_pipeline._engine_manager = engine_manager - controlnet_pipeline._controlnet_engine_dir = controlnet_engine_dir - controlnet_pipeline._use_tensorrt = True - - # Also set engine manager on stream where ControlNet pipeline expects to find it - # Create a wrapper that provides the old interface but uses EngineManager internally - class ControlNetEnginePoolWrapper: - def __init__(self, engine_manager, controlnet_engine_dir, cuda_stream): - self.engine_manager = engine_manager - self.engine_dir = controlnet_engine_dir - self.cuda_stream = cuda_stream - - def load_engine(self, model_id, model_type="sd15", batch_size=1): - engine_path = self.engine_manager.get_engine_path( - EngineType.CONTROLNET, - model_id_or_path="", # Not used for ControlNet - max_batch=batch_size, - min_batch_size=1, - mode="", # Not used for ControlNet - use_lcm_lora=False, # Not used for ControlNet - use_tiny_vae=False, # Not used for ControlNet - controlnet_model_id=model_id - ) - if not os.path.exists(engine_path): - raise FileNotFoundError(f"ControlNet engine not found at {engine_path}") - return self.engine_manager.load_engine( - EngineType.CONTROLNET, - engine_path, - model_type=model_type, - batch_size=batch_size, - cuda_stream=self.cuda_stream, - use_cuda_graph=True - ) - - def get_or_load_engine(self, model_id, pytorch_model, model_type="sd15", batch_size=1): - """get_or_load_engine: Compatibility wrapper for ControlNet pipeline""" - return self.engine_manager.get_or_load_controlnet_engine( - model_id=model_id, - pytorch_model=pytorch_model, - model_type=model_type, - batch_size=batch_size, - cuda_stream=self.cuda_stream, - use_cuda_graph=True + try: + from streamdiffusion.modules.controlnet_module import ControlNetModule, ControlNetConfig + cn_module = ControlNetModule(device=self.device, dtype=self.dtype) + cn_module.install(stream) + # Apply configured max parallel if provided + try: + if self.controlnet_max_parallel is not None: + setattr(cn_module, '_max_parallel_controlnets', int(self.controlnet_max_parallel)) + except Exception: + pass + # Expose add-blocking policy on 
stream + try: + setattr(stream, 'controlnet_block_add_when_parallel', bool(self.controlnet_block_add_when_parallel)) + except Exception: + pass + # Normalize to list of configs + configs = controlnet_config if isinstance(controlnet_config, list) else [controlnet_config] + for cfg in configs: + if not cfg.get('model_id'): + continue + cn_cfg = ControlNetConfig( + model_id=cfg['model_id'], + preprocessor=cfg.get('preprocessor'), + conditioning_scale=cfg.get('conditioning_scale', 1.0), + enabled=cfg.get('enabled', True), + preprocessor_params=cfg.get('preprocessor_params'), ) - - stream.controlnet_engine_pool = ControlNetEnginePoolWrapper(engine_manager, controlnet_engine_dir, cuda_stream) - logger.info("get_or_load_controlnet_engine: Initialized ControlNet TensorRT engine management with unified EngineManager") - else: - controlnet_pipeline._use_tensorrt = False - logger.info("Loading ControlNet in PyTorch mode (no TensorRT acceleration)") - - - # Setup ControlNets from config - if not isinstance(controlnet_config, list): - controlnet_config = [controlnet_config] - + cn_module.add_controlnet(cn_cfg, control_image=cfg.get('control_image')) + # Expose for later updates if needed by caller code + stream._controlnet_module = cn_module - for config in controlnet_config: - model_id = config.get('model_id') - if not model_id: - continue - - preprocessor = config.get('preprocessor', None) - conditioning_scale = config.get('conditioning_scale', 1.0) - enabled = config.get('enabled', True) - preprocessor_params = config.get('preprocessor_params', None) - control_image = config.get('control_image', None) # Extract control image from config + # If TensorRT UNet is active, proactively compile/load ControlNet TRT engines for each model + #TODO: make unet cnet trt acceleration independent and configurable + try: + use_trt_unet = hasattr(stream, 'unet') and hasattr(stream.unet, 'engine') + except Exception: + use_trt_unet = False + if use_trt_unet: + try: + compiled_cn_engines = [] + for cfg, cn_model in zip(configs, cn_module.controlnets): + if not cfg or not cfg.get('model_id') or cn_model is None: + continue + try: + # Assign a unique CUDA stream per ControlNet engine to enable concurrent inference + cn_cuda_stream = cuda.Stream() + engine = engine_manager.get_or_load_controlnet_engine( + model_id=cfg['model_id'], + pytorch_model=cn_model, + model_type=model_type, + batch_size=stream.trt_unet_batch_size, + cuda_stream=cn_cuda_stream, + use_cuda_graph=False, + unet=None, + model_path=cfg['model_id'] + ) + try: + setattr(engine, 'model_id', cfg['model_id']) + except Exception: + pass + compiled_cn_engines.append(engine) + except Exception as e: + logger.warning(f"Failed to compile/load ControlNet engine for {cfg.get('model_id')}: {e}") + if compiled_cn_engines: + # Replace existing engines atomically to avoid mixed state + setattr(stream, 'controlnet_engines', compiled_cn_engines) + try: + logger.info(f"Compiled/loaded {len(compiled_cn_engines)} ControlNet TensorRT engine(s)") + except Exception: + pass + except Exception: + import traceback + traceback.print_exc() + logger.warning("ControlNet TensorRT engine build step encountered an issue; continuing with PyTorch ControlNet") + except Exception: + import traceback + traceback.print_exc() + logger.error("Failed to install ControlNetModule") + raise + if use_ipadapter and ipadapter_config and not hasattr(stream, '_ipadapter_module'): try: - # Pass config dictionary directly - cn_config = { - 'model_id': model_id, - 'preprocessor': preprocessor, - 
'conditioning_scale': conditioning_scale, - 'enabled': enabled, - 'preprocessor_params': preprocessor_params or {} - } - - # Add ControlNet with control image if provided (immediate during initialization) - controlnet_pipeline.add_controlnet(cn_config, control_image, immediate=True) - logger.info(f"_apply_controlnet_patch: Successfully added ControlNet: {model_id}") - except Exception as e: - logger.error(f"_apply_controlnet_patch: Failed to add ControlNet {model_id}: {e}") + from streamdiffusion.modules.ipadapter_module import IPAdapterModule, IPAdapterConfig + # Use first config if list provided + cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config + ip_cfg = IPAdapterConfig( + style_image_key=cfg.get('style_image_key') or 'ipadapter_main', + num_image_tokens=cfg.get('num_image_tokens', 4), + ipadapter_model_path=cfg['ipadapter_model_path'], + image_encoder_path=cfg['image_encoder_path'], + style_image=cfg.get('style_image'), + scale=cfg.get('scale', 1.0), + ) + ip_module = IPAdapterModule(ip_cfg) + ip_module.install(stream) + # Expose for later updates + stream._ipadapter_module = ip_module + except Exception: import traceback traceback.print_exc() + logger.error("Failed to install IPAdapterModule") + raise - return controlnet_pipeline - - - - + return stream def get_last_processed_image(self, index: int) -> Optional[Image.Image]: """Forward get_last_processed_image call to the underlying ControlNet pipeline""" @@ -1593,15 +1595,7 @@ def get_last_processed_image(self, index: int) -> Optional[Image.Image]: raise RuntimeError("get_last_processed_image: ControlNet support not enabled. Set use_controlnet=True in constructor.") return self.stream.get_last_processed_image(index) - - def update_control_image_efficient(self, control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], index: Optional[int] = None) -> None: - """Forward update_control_image_efficient call to the underlying ControlNet pipeline""" - if not self.use_controlnet: - raise RuntimeError("update_control_image_efficient: ControlNet support not enabled. Set use_controlnet=True in constructor.") - - return self.stream.update_control_image_efficient(control_image, index) - - + def cleanup_controlnets(self) -> None: """Cleanup ControlNet resources including background threads and VRAM""" if not self.use_controlnet: @@ -1610,97 +1604,99 @@ def cleanup_controlnets(self) -> None: if hasattr(self, 'stream') and self.stream and hasattr(self.stream, 'cleanup'): self.stream.cleanup_controlnets() - def update_seed_blending( - self, - seed_list: List[Tuple[int, float]], - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """ - Update seed blending with multiple weighted seeds. - - Parameters - ---------- - seed_list : List[Tuple[int, float]] - List of seeds with weights. Each tuple contains (seed_value, weight). - Example: [(123, 0.6), (456, 0.4)] - interpolation_method : Literal["linear", "slerp"] - Method for interpolating between seed noise tensors, by default "linear". - """ - self.stream._param_updater.update_stream_params( - seed_list=seed_list, - seed_interpolation_method=interpolation_method - ) - - def update_prompt_weights( - self, - prompt_weights: List[float], - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" - ) -> None: - """ - Update weights for current prompt list without re-encoding prompts. - - Parameters - ---------- - prompt_weights : List[float] - New weights for the current prompt list. 
- prompt_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between prompt embeddings, by default "slerp". - """ - self.stream._param_updater.update_prompt_weights(prompt_weights, prompt_interpolation_method) - - def update_seed_weights( - self, - seed_weights: List[float], - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """ - Update weights for current seed list without regenerating noise. - - Parameters - ---------- - seed_weights : List[float] - New weights for the current seed list. - interpolation_method : Literal["linear", "slerp"] - Method for interpolating between seed noise tensors, by default "linear". - """ - self.stream._param_updater.update_seed_weights(seed_weights, interpolation_method) - - def get_current_prompts(self) -> List[Tuple[str, float]]: - """ - Get the current prompt list with weights. - - Returns - ------- - List[Tuple[str, float]] - Current prompt list with weights. - """ - return self.stream._param_updater.get_current_prompts() + def clear_caches(self) -> None: + """Clear all cached prompt embeddings and seed noise tensors.""" + self.stream._param_updater.clear_caches() - def get_current_seeds(self) -> List[Tuple[int, float]]: - """ - Get the current seed list with weights. + def get_stream_state(self, include_caches: bool = False) -> Dict[str, Any]: + """Get a unified snapshot of the current stream state. - Returns - ------- - List[Tuple[int, float]] - Current seed list with weights. - """ - return self.stream._param_updater.get_current_seeds() - - def get_cache_info(self) -> Dict: - """ - Get cache statistics for prompt and seed blending. + Args: + include_caches: When True, include cache statistics in the response - Returns - ------- - Dict - Cache information including hits, misses, and cache sizes. + Returns: + Dict[str, Any]: Consolidated state including prompts/seeds, runtime settings, + module configs, and basic pipeline info. 
""" - return self.stream._param_updater.get_cache_info() + stream = self.stream + updater = stream._param_updater + + # Prompts / Seeds + prompts = updater.get_current_prompts() + seeds = updater.get_current_seeds() + + # Normalization flags + normalize_prompt_weights = updater.get_normalize_prompt_weights() + normalize_seed_weights = updater.get_normalize_seed_weights() + + # Core runtime params + guidance_scale = getattr(stream, 'guidance_scale', None) + delta = getattr(stream, 'delta', None) + t_index_list = list(getattr(stream, 't_list', [])) + current_seed = getattr(stream, 'current_seed', None) + num_inference_steps = None + try: + if hasattr(stream, 'timesteps') and stream.timesteps is not None: + num_inference_steps = int(len(stream.timesteps)) + except Exception: + pass + + # Resolution and model/pipeline info + state: Dict[str, Any] = { + 'width': getattr(stream, 'width', None), + 'height': getattr(stream, 'height', None), + 'latent_width': getattr(stream, 'latent_width', None), + 'latent_height': getattr(stream, 'latent_height', None), + 'device': getattr(stream, 'device', None).type if hasattr(getattr(stream, 'device', None), 'type') else getattr(stream, 'device', None), + 'dtype': str(getattr(stream, 'dtype', None)), + 'model_type': getattr(stream, 'model_type', None), + 'is_sdxl': getattr(stream, 'is_sdxl', None), + 'is_turbo': getattr(stream, 'is_turbo', None), + 'cfg_type': getattr(stream, 'cfg_type', None), + 'use_denoising_batch': getattr(stream, 'use_denoising_batch', None), + 'batch_size': getattr(stream, 'batch_size', None), + } + + # Blending state + state.update({ + 'prompt_list': prompts, + 'seed_list': seeds, + 'normalize_prompt_weights': normalize_prompt_weights, + 'normalize_seed_weights': normalize_seed_weights, + 'negative_prompt': getattr(updater, '_current_negative_prompt', ""), + }) + + # Core runtime knobs + state.update({ + 'guidance_scale': guidance_scale, + 'delta': delta, + 't_index_list': t_index_list, + 'current_seed': current_seed, + 'num_inference_steps': num_inference_steps, + }) + + # Module configs (ControlNet, IP-Adapter) + try: + controlnet_config = updater._get_current_controlnet_config() + except Exception: + controlnet_config = [] + try: + ipadapter_config = updater._get_current_ipadapter_config() + except Exception: + ipadapter_config = None + state.update({ + 'controlnet_config': controlnet_config, + 'ipadapter_config': ipadapter_config, + }) + + # Optional caches + if include_caches: + try: + state['caches'] = updater.get_cache_info() + except Exception: + state['caches'] = None - def clear_caches(self) -> None: - """Clear all cached prompt embeddings and seed noise tensors.""" - self.stream._param_updater.clear_caches() + return state def cleanup_gpu_memory(self) -> None: """Comprehensive GPU memory cleanup for model switching.""" @@ -1908,162 +1904,3 @@ def cleanup_engines_and_rebuild(self, reduce_batch_size: bool = True, reduce_res logger.info(f" Reduced resolution: {old_width}x{old_height} -> {self.width}x{self.height}") logger.info(" Next model load will rebuild engines with these smaller settings") - - def update_prompt_at_index( - self, - index: int, - new_prompt: str, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" - ) -> None: - """ - Update a specific prompt by index without changing other prompts. - - Parameters - ---------- - index : int - Index of the prompt to update. - new_prompt : str - New prompt text. 
- prompt_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between prompt embeddings, by default "slerp". - """ - self.stream._param_updater.update_prompt_at_index(index, new_prompt, prompt_interpolation_method) - - def add_prompt( - self, - prompt: str, - weight: float = 1.0, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" - ) -> None: - """ - Add a new prompt to the current blending configuration. - - Parameters - ---------- - prompt : str - Prompt text to add. - weight : float - Weight for the new prompt, by default 1.0. - prompt_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between prompt embeddings, by default "slerp". - """ - self.stream._param_updater.add_prompt(prompt, weight, prompt_interpolation_method) - - def remove_prompt_at_index( - self, - index: int, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" - ) -> None: - """ - Remove a prompt from the current blending configuration by index. - - Parameters - ---------- - index : int - Index of the prompt to remove. - prompt_interpolation_method : Literal["linear", "slerp"] - Method for interpolating between remaining prompt embeddings, by default "slerp". - """ - self.stream._param_updater.remove_prompt_at_index(index, prompt_interpolation_method) - - def update_seed_at_index( - self, - index: int, - new_seed: int, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """ - Update a specific seed by index without changing other seeds. - - Parameters - ---------- - index : int - Index of the seed to update. - new_seed : int - New seed value. - interpolation_method : Literal["linear", "slerp"] - Method for interpolating between seed noise tensors, by default "linear". - """ - self.stream._param_updater.update_seed_at_index(index, new_seed, interpolation_method) - - def add_seed( - self, - seed: int, - weight: float = 1.0, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """ - Add a new seed to the current blending configuration. - - Parameters - ---------- - seed : int - Seed value to add. - weight : float - Weight for the new seed, by default 1.0. - interpolation_method : Literal["linear", "slerp"] - Method for interpolating between seed noise tensors, by default "linear". - """ - self.stream._param_updater.add_seed(seed, weight, interpolation_method) - - def remove_seed_at_index( - self, - index: int, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """ - Remove a seed from the current blending configuration by index. - - Parameters - ---------- - index : int - Index of the seed to remove. - interpolation_method : Literal["linear", "slerp"] - Method for interpolating between remaining seed noise tensors, by default "linear". 
- """ - self.stream._param_updater.remove_seed_at_index(index, interpolation_method) - - def _apply_ipadapter_patch(self, stream, ipadapter_config: Union[Dict[str, Any], List[Dict[str, Any]]]): - """ - Apply IPAdapter functionality to existing stream (add attributes instead of wrapping) - - Args: - stream: Existing StreamDiffusion or ControlNet pipeline - ipadapter_config: IPAdapter configuration - """ - from streamdiffusion.ipadapter import BaseIPAdapterPipeline - - # Get the underlying StreamDiffusion object - underlying_stream = stream.stream if hasattr(stream, 'stream') else stream - - # Create IPAdapter pipeline for the functionality - ipadapter_pipeline = BaseIPAdapterPipeline( - stream_diffusion=underlying_stream, - device=self.device, - dtype=self.dtype - ) - - # Add IPAdapter functionality to the existing stream by setting attributes - stream.ipadapter = None # Will be set when configured - stream.update_scale = ipadapter_pipeline.update_scale - stream.update_style_image = ipadapter_pipeline.update_style_image - - # Configure the IPAdapter if config provided - if ipadapter_config: - if isinstance(ipadapter_config, list): - # Use first config if multiple provided - config = ipadapter_config[0] - else: - config = ipadapter_config - - if config.get('enabled', True): - ipadapter_pipeline.set_ipadapter( - ipadapter_model_path=config['ipadapter_model_path'], - image_encoder_path=config['image_encoder_path'], - style_image=config.get('style_image'), - scale=config.get('scale', 1.0) - ) - # Copy the configured IPAdapter to the stream - stream.ipadapter = ipadapter_pipeline.ipadapter - stream.scale = ipadapter_pipeline.scale - stream.style_image = ipadapter_pipeline.style_image