From a95b4c478137c893a041e4fd1a4696b5edebd685 Mon Sep 17 00:00:00 2001
From: Anatoly Belikov
Date: Tue, 30 Jul 2024 10:48:45 +0300
Subject: [PATCH 1/2] handle clip_skip and lora scale in lpw pipelines

---
 multigen/lpw_stable_diffusion.py    | 46 ++++++++++++++++++-----------
 multigen/lpw_stable_diffusion_xl.py | 33 ++++++++++++---------
 2 files changed, 47 insertions(+), 32 deletions(-)

diff --git a/multigen/lpw_stable_diffusion.py b/multigen/lpw_stable_diffusion.py
index 30bd6d8..ec27acd 100644
--- a/multigen/lpw_stable_diffusion.py
+++ b/multigen/lpw_stable_diffusion.py
@@ -11,25 +11,21 @@
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
-    deprecate,
-    logging,
-)
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
     logging,
-    replace_example_docstring,
     scale_lora_layers,
     unscale_lora_layers,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 
 # ------------------------------------------------------------------------------
@@ -227,8 +223,7 @@ def get_unweighted_text_embeddings(
                 prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
                 text_embedding = prompt_embeds[0]
             else:
-                prompt_embeds = pipe.text_encoder(
-                    text_input_chunk.to(pipe.device), output_hidden_states=True)
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
                 # Access the `hidden_states` first, that contains a tuple of
                 # all the hidden states from the encoder layers. Then index into
                 # the tuple to access the hidden states from the desired layer.
@@ -298,7 +293,7 @@ def get_weighted_text_embeddings(
     """
     # set lora scale so that monkey patched LoRA
     # function of text encoder can correctly access it
-    if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
         pipe._lora_scale = lora_scale
 
     # dynamically adjust the LoRA scale
@@ -372,11 +367,7 @@ def get_weighted_text_embeddings(
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
-        clip_skip=clip_skip
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -385,7 +376,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
-            clip_skip=clip_skip
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
 
@@ -403,7 +394,7 @@ def get_weighted_text_embeddings(
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
     if pipe.text_encoder is not None:
-        if isinstance(pipe, LoraLoaderMixin) and USE_PEFT_BACKEND:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(pipe.text_encoder, lora_scale)
 
@@ -454,7 +445,11 @@ def preprocess_mask(mask, batch_size, scale_factor=8):
 
 
 class StableDiffusionLongPromptWeightingPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -590,6 +585,8 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -639,6 +636,7 @@ def _encode_prompt(
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
             clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         if prompt_embeds is None:
             prompt_embeds = prompt_embeds1
@@ -832,6 +830,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -907,6 +906,9 @@ def __call__(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called.
                 If not specified, the callback will be called at every step.
@@ -945,6 +947,7 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
 
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -956,6 +959,8 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
 
@@ -1086,6 +1091,7 @@ def text2img(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip=None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1143,6 +1149,9 @@ def text2img(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called.
                 If not specified, the callback will be called at every step.
@@ -1177,6 +1186,7 @@ def text2img(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
         )
diff --git a/multigen/lpw_stable_diffusion_xl.py b/multigen/lpw_stable_diffusion_xl.py
index a2a6c1a..13d1e2a 100644
--- a/multigen/lpw_stable_diffusion_xl.py
+++ b/multigen/lpw_stable_diffusion_xl.py
@@ -22,29 +22,31 @@
 from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
-)
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
-    logging,
-    replace_example_docstring,
     scale_lora_layers,
     unscale_lora_layers,
 )
+from diffusers.utils.torch_utils import randn_tensor
+
 
 
 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -263,7 +265,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
-    lora_scale: Optional[int] = None
+    lora_scale: Optional[int] = None,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -580,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
     StableDiffusionMixin,
     FromSingleFileMixin,
     IPAdapterMixin,
-    LoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
@@ -592,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
     The pipeline also inherits the following loading methods:
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
 
     Args:
@@ -774,7 +776,7 @@ def encode_prompt(
 
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
             self._lora_scale = lora_scale
 
         if prompt is not None and isinstance(prompt, str):
@@ -1643,7 +1645,9 @@ def __call__(
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
         # 3. Encode input prompt
-        (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
+        lora_scale = (
+            self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
+        )
 
         negative_prompt = negative_prompt if negative_prompt is not None else ""
 
@@ -1658,6 +1662,7 @@ def __call__(
             neg_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
 

From fe9644a42f1f873181b808de32492caa035224f6 Mon Sep 17 00:00:00 2001
From: Anatoly Belikov
Date: Tue, 30 Jul 2024 11:16:37 +0300
Subject: [PATCH 2/2] StableDiffusionLoraLoaderMixin -> LoraLoaderMixin

---
 multigen/lpw_stable_diffusion.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/multigen/lpw_stable_diffusion.py b/multigen/lpw_stable_diffusion.py
index ec27acd..035a668 100644
--- a/multigen/lpw_stable_diffusion.py
+++ b/multigen/lpw_stable_diffusion.py
@@ -11,7 +11,7 @@
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
@@ -293,7 +293,7 @@ def get_weighted_text_embeddings(
     """
     # set lora scale so that monkey patched LoRA
     # function of text encoder can correctly access it
-    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+    if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
         pipe._lora_scale = lora_scale
 
     # dynamically adjust the LoRA scale
@@ -394,7 +394,7 @@ def get_weighted_text_embeddings(
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
     if pipe.text_encoder is not None:
-        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+        if isinstance(pipe, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(pipe.text_encoder, lora_scale)
 
@@ -448,7 +448,7 @@ class StableDiffusionLongPromptWeightingPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
     TextualInversionLoaderMixin,
-    StableDiffusionLoraLoaderMixin,
+    LoraLoaderMixin,
     FromSingleFileMixin,
 ):
     r"""
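
Usage sketch (illustrative only, not part of either commit): the snippet below shows how the plumbing added above could be exercised from multigen's copy of the LPW pipeline — `clip_skip` is passed straight through to `_encode_prompt`, and the LoRA scale travels via `cross_attention_kwargs["scale"]`. The model ID, LoRA path, prompt, and output filename are placeholders, not values taken from the patches.

import torch

# StableDiffusionLongPromptWeightingPipeline is the class patched above in
# multigen/lpw_stable_diffusion.py; model and LoRA paths are placeholders.
from multigen.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

pipe = StableDiffusionLongPromptWeightingPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora.safetensors")  # via LoraLoaderMixin, see class bases above

image = pipe.text2img(
    "a (highly detailed:1.2) portrait of an astronaut, cinematic lighting",
    negative_prompt="blurry, low quality",
    clip_skip=2,                            # forwarded to _encode_prompt by patch 1
    cross_attention_kwargs={"scale": 0.7},  # read as lora_scale in __call__ by patch 1
    num_inference_steps=30,
).images[0]
image.save("lpw_example.png")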