From bc2a4663fd13e304b48b177b8d364186e3e9d527 Mon Sep 17 00:00:00 2001
From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com>
Date: Sun, 15 Dec 2024 06:45:30 +0100
Subject: [PATCH] Attempt 3 to fix tests.

---
 src/diffusers/loaders/lora_pipeline.py        | 35 ++++++++++----
 .../pipelines/aura_flow/pipeline_aura_flow.py | 46 ++++++++++++++++++-
 tests/lora/test_lora_layers_af.py             | 13 +++++-
 3 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 9d2773293780..25efcbbc7964 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1651,7 +1651,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1700,10 +1700,11 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
         """
         # Load the main state dict first which has the LoRA layers for either of
-        # transformer and text encoder or both.
+        # UNet and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -1712,6 +1713,7 @@ def lora_state_dict(
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
+        unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

         allow_pickle = False
@@ -1738,16 +1740,32 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
-
         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        network_alphas = None
+        # TODO: replace it with a method from `state_dict_utils`
+        if all(
+            (
+                k.startswith("lora_te_")
+                or k.startswith("lora_unet_")
+                or k.startswith("lora_te1_")
+                or k.startswith("lora_te2_")
+            )
+            for k in state_dict.keys()
+        ):
+            # Map SDXL blocks correctly.
+            if unet_config is not None:
+                # use unet config to remap block numbers
+                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+        return state_dict, network_alphas

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
@@ -1798,10 +1816,9 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_unet(
+        self.load_lora_into_transformer(
             state_dict,
-            network_alphas=network_alphas,
-            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 5c9c803e36e3..cfb002ea7764 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel
@@ -22,7 +22,14 @@
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -125,6 +132,8 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):

     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    transformer_name = "transformer"
+    text_encoder_name = "text_encoder"

     def __init__(
         self,
@@ -215,6 +224,7 @@ def encode_prompt(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         max_sequence_length: int = 256,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -241,10 +251,21 @@ def encode_prompt(
             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for negative text embeddings.
             max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
""" if device is None: device = self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -402,6 +423,7 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -457,6 +479,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: @@ -479,6 +505,8 @@ def __call__( negative_prompt_attention_mask, ) + self._joint_attention_kwargs = joint_attention_kwargs + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -488,6 +516,9 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` @@ -511,6 +542,7 @@ def __call__( prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -551,6 +583,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, + joint_attention_kwargs=self.joint_attention_kwargs, )[0] # perform guidance @@ -579,7 +612,16 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + if self.text_encoder is not None: + if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + if not return_dict: return (image,) return ImagePipelineOutput(images=image) + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 62364380eb83..247e8f83aab2 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from diffusers import ( AuraFlowPipeline, @@ -41,7 +41,7 @@ @require_peft_backend class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} uses_flow_matching = True transformer_kwargs = { @@ -60,6 +60,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + attention_kwargs_name = "joint_attention_kwargs" text_encoder_target_modules = ["q", "k", "v", "o"] @@ -103,3 +104,11 @@ def get_dummy_inputs(self, with_generator=True): pipeline_inputs.update({"generator": generator}) return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in AuraFlow.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in AuraFlow.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass
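
A minimal usage sketch of the code path these hunks wire up: loading a LoRA into AuraFlowPipeline and passing a text-encoder scale through the new joint_attention_kwargs argument. The LoRA repository id and weight file below are placeholders, not real artifacts, and the exact base checkpoint and scale value are assumptions for illustration only.

import torch

from diffusers import AuraFlowPipeline

# Base AuraFlow checkpoint; swap in whichever AuraFlow model you actually use.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Hypothetical LoRA checkpoint; substitute a real repo id and weight_name.
pipe.load_lora_weights("your-username/your-auraflow-lora", weight_name="pytorch_lora_weights.safetensors")

# "scale" is read back in __call__ via self.joint_attention_kwargs.get("scale", None),
# forwarded to encode_prompt() as lora_scale, and undone after generation with
# unscale_lora_layers(), per the hunks above.
image = pipe(
    "a watercolor painting of a fox",
    num_inference_steps=28,
    joint_attention_kwargs={"scale": 0.8},
).images[0]
image.save("auraflow_lora.png")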