From 1c7909565a8525b4891e01e2ad7feb6fc301e915 Mon Sep 17 00:00:00 2001
From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com>
Date: Thu, 19 Dec 2024 10:23:21 +0100
Subject: [PATCH] Address review comments.

---
 src/diffusers/loaders/lora_pipeline.py        | 46 ++++++++-----------
 .../pipelines/aura_flow/pipeline_aura_flow.py |  2 -
 2 files changed, 19 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 25efcbbc7964..8632e9c8c106 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1648,10 +1648,11 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
 
     _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
+    text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1700,11 +1701,10 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-            weight_name (`str`, *optional*, defaults to None):
-                Name of the serialized state dict file.
+
         """
         # Load the main state dict first which has the LoRA layers for either of
-        # UNet and text encoder or both.
+        # transformer and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -1713,7 +1713,6 @@ def lora_state_dict(
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
-        unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
         allow_pickle = False
@@ -1740,30 +1739,14 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+
         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
-        network_alphas = None
-        # TODO: replace it with a method from `state_dict_utils`
-        if all(
-            (
-                k.startswith("lora_te_")
-                or k.startswith("lora_unet_")
-                or k.startswith("lora_te1_")
-                or k.startswith("lora_te2_")
-            )
-            for k in state_dict.keys()
-        ):
-            # Map SDXL blocks correctly.
-            if unet_config is not None:
-                # use unet config to remap block numbers
-                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
-
-        return state_dict, network_alphas
+        return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer
     def load_lora_weights(
@@ -2025,10 +2008,12 @@ def load_lora_into_text_encoder(
         # Unsafe code />
 
     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -2042,6 +2027,9 @@ def save_lora_weights(
                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
             transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `transformer`.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -2055,10 +2043,14 @@ def save_lora_weights(
         """
         state_dict = {}
 
-        if not (transformer_lora_layers):
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        if not (transformer_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
 
-        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if text_encoder_lora_layers:
+            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         # Save the model
         cls.write_lora_layers(
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index cfb002ea7764..49c89227a193 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -132,8 +132,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
 
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
-    transformer_name = "transformer"
-    text_encoder_name = "text_encoder"
 
     def __init__(
         self,
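Not part of the patch: a minimal usage sketch of the API touched above, assuming the published "fal/AuraFlow" base checkpoint; the LoRA repo id and output path are placeholders, and the state dict passed to save_lora_weights follows the pattern used in diffusers training scripts (peft.utils.get_peft_model_state_dict).

# Usage sketch (illustrative only, not part of the patch).
# Assumptions: "fal/AuraFlow" base checkpoint; LoRA repo id and output path are placeholders.
import torch
from diffusers import AuraFlowPipeline
from peft.utils import get_peft_model_state_dict

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# After this patch, AuraFlowLoraLoaderMixin.lora_state_dict() returns only the state dict
# (no network_alphas), and load_lora_weights() consumes it as before.
pipe.load_lora_weights("your-namespace/your-auraflow-lora")  # placeholder repo id

# save_lora_weights() now also accepts a text-encoder LoRA state dict; at least one of the
# two dicts must be passed. Here the transformer adapter loaded above is re-serialized.
AuraFlowPipeline.save_lora_weights(
    save_directory="./auraflow-lora",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)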