From e50a10820874e4ff4fcf90adfb1d0e14cb7fb5c8 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 17:22:20 +0530 Subject: [PATCH 01/18] Add AuraFlowLoraLoaderMixin --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 334 ++++++++++++++++++ .../transformers/auraflow_transformer_2d.py | 25 +- .../pipelines/aura_flow/pipeline_aura_flow.py | 3 +- 4 files changed, 360 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b59150376599..92c701936eb6 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -64,6 +64,7 @@ def text_encoder_attn_modules(text_encoder): "AmusedLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", + "AuraFlowLoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "LTXVideoLoraLoaderMixin", "LoraLoaderMixin", @@ -95,6 +96,7 @@ def text_encoder_attn_modules(text_encoder): Mochi1LoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, + AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b8c44e480093..79ad45bd3b81 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,6 +1641,337 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t super().unfuse_lora(components=components) +class AuraFlowLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`AuraFlowTransformer2DModel`] + Specific to [`AuraFlowPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + + @classmethod + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers): + raise ValueError( + "You must pass `transformer_lora_layers`." + ) + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + class FluxLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`FluxTransformer2DModel`], @@ -1649,6 +1980,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ +# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially +# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. +class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index b3f29e6b6224..ad03f98338b1 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn @@ -32,6 +32,8 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -253,7 +255,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). @@ -451,6 +453,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -463,7 +466,19 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -538,6 +553,10 @@ def custom_forward(*inputs): shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) ) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 8737b219c833..72a7e08f4fb6 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -24,6 +24,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -104,7 +105,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AuraFlowPipeline(DiffusionPipeline): +class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): r""" Args: tokenizer (`T5TokenizerFast`): From 658d0586866e2920002d942b097ef96a9b6a70f9 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:46:20 +0530 Subject: [PATCH 02/18] Add comments, remove qkv fusion --- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 7 ++++++- .../transformers/auraflow_transformer_2d.py | 19 +++++++++---------- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 +- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 92c701936eb6..ec233ecf4a70 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -91,12 +91,12 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, + AuraFlowLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, - AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 79ad45bd3b81..1fd0bdc9c395 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1649,9 +1649,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict @validate_hf_hub_args def lora_state_dict( cls, @@ -1742,6 +1742,7 @@ def lora_state_dict( return state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1788,6 +1789,7 @@ def load_lora_weights( @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1866,6 +1868,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1913,6 +1916,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1956,6 +1960,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index ad03f98338b1..13fe6309ce6e 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,7 +20,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( Attention, @@ -32,8 +33,6 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -344,8 +343,8 @@ def __init__( self.gradient_checkpointing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -453,7 +452,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -466,18 +465,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 72a7e08f4fb6..5c9c803e36e3 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -18,13 +18,13 @@ from transformers import T5Tokenizer, UMT5EncoderModel from ...image_processor import VaeImageProcessor +from ...loaders import AuraFlowLoraLoaderMixin from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 4208d09601afd4afdce33f23d3f6921822703f67 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:50:50 +0530 Subject: [PATCH 03/18] Add Tests --- tests/lora/test_lora_layers_af.py | 90 +++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/lora/test_lora_layers_af.py diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py new file mode 100644 index 000000000000..2b050aa74da9 --- /dev/null +++ b/tests/lora/test_lora_layers_af.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + AuraFlowPipeline, +) +from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device + + +if is_peft_available(): + pass + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = AuraFlowPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + transformer_kwargs = { + "sample_size": 64, + "patch_size": 2, + "in_channels": 4, + "num_mmdit_layers": 4, + "num_single_dit_layers": 32, + "attention_head_dim": 256, + "num_attention_heads": 12, + "joint_attention_dim": 2048, + "caption_projection_dim": 3072, + "out_channels": 4, + "pos_embed_max_size": 1024, + } + vae_kwargs = { + "sample_size": 1024, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "layers_per_block": 2, + "latent_channels": 4, + "norm_num_groups": 32, + "use_quant_conv": True, + "use_post_quant_conv": True, + "shift_factor": None, + "scaling_factor": 0.13025, + } + has_three_text_encoders = False + + @require_torch_gpu + def test_af_lora(self): + """ + Test loading the loras that are saved with the diffusers and peft formats. + Related PR: https://github.com/huggingface/diffusers/pull/8584 + """ + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + lora_model_id = "Warlord-K/gorkem-auraflow-lora" + + lora_filename = "pytorch_lora_weights.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.unload_lora_weights() + + lora_filename = "lora_peft_format.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) From 98b19f66aa67bc193e491a85b6ebb4c59587db8c Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:53:49 +0530 Subject: [PATCH 04/18] Add AuraFlowLoraLoaderMixin to documentation --- docs/source/en/api/loaders/lora.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 5dde55ada562..3c1ab7e5b6eb 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux). - [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox). - [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi). +- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -41,6 +42,7 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +<<<<<<< HEAD ## FluxLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin @@ -52,6 +54,11 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## Mochi1LoraLoaderMixin [[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin +======= +## AuraFlowLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin +>>>>>>> a45d48a56 (Add AuraFlowLoraLoaderMixin to documentation) ## AmusedLoraLoaderMixin From 71f8bace8b616fdbeada27d0e71c0522ad52aa61 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 01:44:56 +0530 Subject: [PATCH 05/18] Add Suggested changes --- src/diffusers/loaders/lora_pipeline.py | 471 +++++++++++++++++- src/diffusers/loaders/peft.py | 1 + .../transformers/auraflow_transformer_2d.py | 2 +- tests/lora/test_lora_layers_af.py | 65 +-- 4 files changed, 496 insertions(+), 43 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1fd0bdc9c395..4d13134e47b1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,6 +22,7 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, @@ -1747,7 +1748,8 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. @@ -1787,6 +1789,473 @@ def load_lora_weights( _pipeline=self, ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + +class AuraFlowLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`AuraFlowTransformer2DModel`] + Specific to [`AuraFlowPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for transformer + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a791a250af08..7a8332bb289a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -55,6 +55,7 @@ "MochiTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, + "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 13fe6309ce6e..5c8e6720ecc2 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -254,7 +254,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 2b050aa74da9..9615249633cd 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,6 +15,8 @@ import sys import unittest +from transformers import AutoTokenizer, T5EncoderModel + from diffusers import ( FlowMatchEulerDiscreteScheduler, AuraFlowPipeline, @@ -31,60 +33,41 @@ @require_peft_backend -class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 64, - "patch_size": 2, + "patch_size": 1, "in_channels": 4, - "num_mmdit_layers": 4, - "num_single_dit_layers": 32, - "attention_head_dim": 256, - "num_attention_heads": 12, - "joint_attention_dim": 2048, - "caption_projection_dim": 3072, + "num_mmdit_layers": 1, + "num_single_dit_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 1024, + "pos_embed_max_size": 32, } + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + vae_kwargs = { - "sample_size": 1024, + "sample_size": 32, "in_channels": 3, "out_channels": 3, - "block_out_channels": [ - 128, - 256, - 512, - 512 - ], - "layers_per_block": 2, + "block_out_channels": (4,), + "layers_per_block": 1, "latent_channels": 4, - "norm_num_groups": 32, - "use_quant_conv": True, - "use_post_quant_conv": True, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, "shift_factor": None, "scaling_factor": 0.13025, } - has_three_text_encoders = False - - @require_torch_gpu - def test_af_lora(self): - """ - Test loading the loras that are saved with the diffusers and peft formats. - Related PR: https://github.com/huggingface/diffusers/pull/8584 - """ - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - lora_model_id = "Warlord-K/gorkem-auraflow-lora" - - lora_filename = "pytorch_lora_weights.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.unload_lora_weights() - lora_filename = "lora_peft_format.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + @property + def output_shape(self): + return (1, 64, 64, 3) \ No newline at end of file From 0eee03eefea320d35b7876eefb6df05994d4290d Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 11:06:37 +0530 Subject: [PATCH 06/18] Change attention_kwargs->joint_attention_kwargs --- src/diffusers/loaders/lora_pipeline.py | 2 +- .../models/transformers/auraflow_transformer_2d.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4d13134e47b1..3ed3ffc71b5c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2125,6 +2125,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -2429,7 +2430,6 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 5c8e6720ecc2..90f4f3fd5d8e 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -452,7 +452,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - attention_kwargs: Optional[Dict[str, Any]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -465,18 +465,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): From 4e4f780e5466c67c1286fed049055490a7dd2cf6 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:17:39 +0100 Subject: [PATCH 07/18] Rebasing derp. --- src/diffusers/loaders/lora_pipeline.py | 476 +------------------------ 1 file changed, 1 insertion(+), 475 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3ed3ffc71b5c..3377ccf10f95 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,479 +1641,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t """ super().unfuse_lora(components=components) - -class AuraFlowLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`AuraFlowTransformer2DModel`] - Specific to [`AuraFlowPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = cls._fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - ) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=None, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`SD3Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. - keys = list(state_dict.keys()) - prefix = cls.text_encoder_name if prefix is None else prefix - - # Safe prefix to check with. - if any(cls.text_encoder_name in key for key in keys): - # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [ - k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix - ] - network_alphas = { - k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer - def fuse_lora( - self, - components: List[str] = ["transformer", "text_encoder"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names - ) - - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - """ - super().unfuse_lora(components=components) - class AuraFlowLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`AuraFlowTransformer2DModel`] @@ -2338,7 +1865,6 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # Unsafe code /> @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -2386,7 +1912,6 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -2446,6 +1971,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components) + class FluxLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`FluxTransformer2DModel`], From c07d1f51858cf9114b0880eff27a9becfd6d0fa8 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:26:26 +0000 Subject: [PATCH 08/18] fix --- src/diffusers/loaders/lora_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3377ccf10f95..1171ea5c76d3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1979,10 +1979,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ - -# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially -# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. -class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME From 1b7f99fb841445fa3c4e5ad91f25f29267aa8c6b Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:27:08 +0000 Subject: [PATCH 09/18] fix --- src/diffusers/loaders/lora_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1171ea5c76d3..376596fe7cb2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1979,6 +1979,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ + _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME From 875a3e0f908fb286014fea6908c582fed488e820 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:27:28 +0100 Subject: [PATCH 10/18] Quality fixes. --- src/diffusers/loaders/lora_pipeline.py | 2 +- tests/lora/test_lora_layers_af.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 376596fe7cb2..b08ed96e99eb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1652,7 +1652,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 9615249633cd..b61e33f7eba5 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -18,10 +18,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( - FlowMatchEulerDiscreteScheduler, AuraFlowPipeline, + FlowMatchEulerDiscreteScheduler, ) -from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import is_peft_available, require_peft_backend if is_peft_available(): @@ -70,4 +70,4 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @property def output_shape(self): - return (1, 64, 64, 3) \ No newline at end of file + return (1, 64, 64, 3) From a242d7ae6915d0de32c42437f22820f502f99cfd Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:29:05 +0000 Subject: [PATCH 11/18] make style --- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 8 +++----- .../models/transformers/auraflow_transformer_2d.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index ec233ecf4a70..70612cdea292 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -89,9 +89,9 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, - AuraFlowLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b08ed96e99eb..a0ae31575327 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,10 +1641,10 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t """ super().unfuse_lora(components=components) + class AuraFlowLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`AuraFlowTransformer2DModel`] - Specific to [`AuraFlowPipeline`]. + Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ _lora_loadable_modules = ["transformer"] @@ -1896,9 +1896,7 @@ def save_lora_weights( state_dict = {} if not (transformer_lora_layers): - raise ValueError( - "You must pass `transformer_lora_layers`." - ) + raise ValueError("You must pass `transformer_lora_layers`.") state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 90f4f3fd5d8e..2fa615a78bb4 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( From a73df6b94929988f45229efe9c2691d43de1d4ed Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:33:32 +0100 Subject: [PATCH 12/18] `make fix-copies` --- src/diffusers/loaders/lora_pipeline.py | 91 ++++++------------- .../transformers/auraflow_transformer_2d.py | 2 +- 2 files changed, 29 insertions(+), 64 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a0ae31575327..e570dc349a4c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1703,7 +1703,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. """ - # Load the main state dict first which has the LoRA layers for transformer + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1724,7 +1725,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1739,6 +1740,12 @@ def lora_state_dict( allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + return state_dict def load_lora_weights( @@ -1787,7 +1794,9 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1801,68 +1810,24 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 2fa615a78bb4..1f93a1fc0d87 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -343,8 +343,8 @@ def __init__( self.gradient_checkpointing = False - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: From 894eac0aa81f27274902e86b9d4f0122cebd3fe3 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:39:08 +0100 Subject: [PATCH 13/18] `ruff check --fix` --- src/diffusers/loaders/lora_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e570dc349a4c..64eb72dd860b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,7 +22,6 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, From 2b364161929981b71349119f891ec5ef8845c16d Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 05:08:54 +0100 Subject: [PATCH 14/18] Attept 1 to fix tests. --- tests/lora/test_lora_layers_af.py | 34 +++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index b61e33f7eba5..1b2835e96427 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,13 +15,19 @@ import sys import unittest +import torch from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( AuraFlowPipeline, + AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler, ) -from diffusers.utils.testing_utils import is_peft_available, require_peft_backend +from diffusers.utils.testing_utils import ( + floats_tensor, + is_peft_available, + require_peft_backend, +) if is_peft_available(): @@ -49,8 +55,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "joint_attention_dim": 32, "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 32, + "pos_embed_max_size": 64, } + transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" @@ -71,3 +78,26 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @property def output_shape(self): return (1, 64, 64, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs From 6b762b800c155dc8cb29822c06973c72d8f7e4d2 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 05:57:02 +0100 Subject: [PATCH 15/18] Attept 2 to fix tests. --- src/diffusers/loaders/lora_pipeline.py | 199 +++++++++++++++++++++++-- tests/lora/test_lora_layers_af.py | 6 +- 2 files changed, 193 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 64eb72dd860b..9d2773293780 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1646,7 +1646,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ - _lora_loadable_modules = ["transformer"] + _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME @classmethod @@ -1747,48 +1747,75 @@ def lora_state_dict( return state_dict + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( + self.load_lora_into_unet( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, + lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -1828,6 +1855,158 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + @classmethod def save_lora_weights( cls, diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 1b2835e96427..62364380eb83 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel from diffusers import ( AuraFlowPipeline, @@ -59,7 +59,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + + text_encoder_target_modules = ["q", "k", "v", "o"] vae_kwargs = { "sample_size": 32, From bc2a4663fd13e304b48b177b8d364186e3e9d527 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 06:45:30 +0100 Subject: [PATCH 16/18] Attept 3 to fix tests. --- src/diffusers/loaders/lora_pipeline.py | 35 ++++++++++---- .../pipelines/aura_flow/pipeline_aura_flow.py | 46 ++++++++++++++++++- tests/lora/test_lora_layers_af.py | 13 +++++- 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9d2773293780..25efcbbc7964 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1651,7 +1651,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1700,10 +1700,11 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. """ # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. + # UNet and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1712,6 +1713,7 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False @@ -1738,16 +1740,32 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1798,10 +1816,9 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_unet( + self.load_lora_into_transformer( state_dict, - network_alphas=network_alphas, - unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 5c9c803e36e3..cfb002ea7764 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers import T5Tokenizer, UMT5EncoderModel @@ -22,7 +22,14 @@ from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging, replace_example_docstring +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -125,6 +132,8 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + transformer_name = "transformer" + text_encoder_name = "text_encoder" def __init__( self, @@ -215,6 +224,7 @@ def encode_prompt( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -241,10 +251,21 @@ def encode_prompt( negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ if device is None: device = self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -402,6 +423,7 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -457,6 +479,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: @@ -479,6 +505,8 @@ def __call__( negative_prompt_attention_mask, ) + self._joint_attention_kwargs = joint_attention_kwargs + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -488,6 +516,9 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -511,6 +542,7 @@ def __call__( prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -551,6 +583,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, + joint_attention_kwargs=self.joint_attention_kwargs, )[0] # perform guidance @@ -579,7 +612,16 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + if self.text_encoder is not None: + if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + if not return_dict: return (image,) return ImagePipelineOutput(images=image) + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 62364380eb83..247e8f83aab2 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from diffusers import ( AuraFlowPipeline, @@ -41,7 +41,7 @@ @require_peft_backend class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} uses_flow_matching = True transformer_kwargs = { @@ -60,6 +60,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + attention_kwargs_name = "joint_attention_kwargs" text_encoder_target_modules = ["q", "k", "v", "o"] @@ -103,3 +104,11 @@ def get_dummy_inputs(self, with_generator=True): pipeline_inputs.update({"generator": generator}) return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in AuraFlow.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in AuraFlow.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass From 1c7909565a8525b4891e01e2ad7feb6fc301e915 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:23:21 +0100 Subject: [PATCH 17/18] Address review comments. --- src/diffusers/loaders/lora_pipeline.py | 46 ++++++++----------- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 - 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 25efcbbc7964..8632e9c8c106 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1648,10 +1648,11 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1700,11 +1701,10 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. + """ # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1713,7 +1713,6 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False @@ -1740,30 +1739,14 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - network_alphas = None - # TODO: replace it with a method from `state_dict_utils` - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. - if unet_config is not None: - # use unet config to remap block numbers - state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - - return state_dict, network_alphas + return state_dict # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer def load_lora_weights( @@ -2025,10 +2008,12 @@ def load_lora_into_text_encoder( # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -2042,6 +2027,9 @@ def save_lora_weights( Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -2055,10 +2043,14 @@ def save_lora_weights( """ state_dict = {} - if not (transformer_lora_layers): - raise ValueError("You must pass `transformer_lora_layers`.") + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) # Save the model cls.write_lora_layers( diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index cfb002ea7764..49c89227a193 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -132,8 +132,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" - transformer_name = "transformer" - text_encoder_name = "text_encoder" def __init__( self, From 9454e845e6aeb828b9c4560946d1b34d8ec761d7 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:25:49 +0100 Subject: [PATCH 18/18] Rebasing derp. --- docs/source/en/api/loaders/lora.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 3c1ab7e5b6eb..34a5416b1ccc 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -42,7 +42,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin -<<<<<<< HEAD ## FluxLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin @@ -54,11 +53,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## Mochi1LoraLoaderMixin [[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin -======= ## AuraFlowLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin ->>>>>>> a45d48a56 (Add AuraFlowLoraLoaderMixin to documentation) ## AmusedLoraLoaderMixin