diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f63532b84e7c..429b1b8c7190 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -35,18 +35,23 @@ convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, + get_adapter_name, + get_peft_kwargs, is_accelerate_available, is_omegaconf_available, is_peft_available, is_transformers_available, logging, recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, ) from .utils.import_utils import BACKENDS_MAPPING if is_transformers_available(): - from transformers import CLIPTextModel, CLIPTextModelWithProjection + from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel if is_accelerate_available(): from accelerate import init_empty_weights @@ -1100,7 +1105,9 @@ class LoraLoaderMixin: num_fused_loras = 0 use_peft_backend = USE_PEFT_BACKEND - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. @@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di See [`~loaders.LoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. """ # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder=self.text_encoder, lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, + adapter_name=adapter_name, _pipeline=self, ) @@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, low_cpu_mem_usage=None, + adapter_name=None, _pipeline=None, ): """ @@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder( tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT @@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder( if cls.use_peft_backend: from peft import LoraConfig - lora_rank = list(rank.values())[0] - # By definition, the scale should be alpha divided by rank. - # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71 - alpha = lora_scale * lora_rank + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict) - target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] - if patch_mlp: - target_modules += ["fc1", "fc2"] + lora_config = LoraConfig(**lora_config_kwargs) - # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873 - lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + # inject LoRA layers and load the state dict + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) is_model_cpu_offload = False is_sequential_cpu_offload = False @@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder): self.num_fused_loras -= 1 + def set_adapter_for_text_encoder( + self, + adapter_names: Union[List[str], str], + text_encoder: Optional[PreTrainedModel] = None, + text_encoder_weights: List[float] = None, + ): + """ + Sets the adapter layers for the text encoder. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + text_encoder_weights (`List[float]`, *optional*): + The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. + """ + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + + def process_weights(adapter_names, weights): + if weights is None: + weights = [1.0] * len(adapter_names) + elif isinstance(weights, float): + weights = [weights] + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" + ) + return weights + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + text_encoder_weights = process_weights(adapter_names, text_encoder_weights) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError( + "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." + ) + set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) + + def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): + """ + Disables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to disable the LoRA layers for. If `None`, it will try to get the + `text_encoder` attribute. + """ + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=False) + + def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): + """ + Enables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(self.text_encoder, enabled=True) + class FromSingleFileMixin: """ diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 07eeae712f71..fa8258fedc86 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -19,7 +19,7 @@ from torch import nn from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules -from ..utils import logging +from ..utils import logging, scale_lora_layers logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -27,11 +27,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): if use_peft_backend: - from peft.tuners.lora import LoraLayer - - for module in text_encoder.modules(): - if isinstance(module, LoraLayer): - module.scaling[module.active_adapter] = lora_scale + scale_lora_layers(text_encoder, weight=lora_scale) else: for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3bc21759caae..2c9edbd58314 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -84,7 +84,14 @@ from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput -from .peft_utils import recurse_remove_peft_layers +from .peft_utils import ( + get_adapter_name, + get_peft_kwargs, + recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, +) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index dd9e8384c61b..253a57a2270e 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -14,6 +14,8 @@ """ PEFT utilities: Utilities related to peft library """ +import collections + from .import_utils import is_torch_available @@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model): torch.cuda.empty_cache() return model + + +def scale_lora_layers(model, weight): + """ + Adjust the weightage given to the LoRA layers of the model. + + Args: + model (`torch.nn.Module`): + The model to scale. + weight (`float`): + The weight to be given to the LoRA layers. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.scale_layer(weight) + + +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict): + rank_pattern = {} + alpha_pattern = {} + r = lora_alpha = list(rank_dict.values())[0] + if len(set(rank_dict.values())) > 1: + # get the rank occuring the most number of times + r = collections.Counter(rank_dict.values()).most_common()[0][0] + + # for modules with rank different from the most occuring rank, add it to the `rank_pattern` + rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} + + if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1: + # get the alpha occuring the most number of times + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + + # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + + # layer names without the Diffusers specific + target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) + + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": target_modules, + } + return lora_config_kwargs + + +def get_adapter_name(model): + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return f"default_{len(module.r)}" + return "default_0" + + +def set_adapter_layers(model, enabled=True): + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + # The recent version of PEFT needs to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + module.disable_adapters = True + + +def set_weights_and_activate_adapters(model, adapter_names, weights): + from peft.tuners.tuners_utils import BaseTunerLayer + + # iterate over each adapter, make it active and set the corresponding scaling weight + for adapter_name, weight in zip(adapter_names, weights): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + else: + module.active_adapter = adapter_name + module.scale_layer(weight) + + # set multiple active adapters + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_names) + else: + module.active_adapter = adapter_names