From 493f9529d75b868da58bb16f711ed39c73f39b8c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:03:39 +0200 Subject: [PATCH] [`PEFT` / `LoRA`] PEFT integration - text encoder (#5058) * more fixes * up * up * style * add in setup * oops * more changes * v1 rzfactor CI * Apply suggestions from code review Co-authored-by: Patrick von Platen * few todos * protect torch import * style * fix fuse text encoder * Update src/diffusers/loaders.py Co-authored-by: Sayak Paul * replace with `recurse_replace_peft_layers` * keep old modules for BC * adjustments on `adjust_lora_scale_text_encoder` * nit * move tests * add conversion utils * remove unneeded methods * use class method instead * oops * use `base_version` * fix examples * fix CI * fix weird error with python 3.8 * fix * better fix * style * Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: Patrick von Platen * Apply suggestions from code review Co-authored-by: Patrick von Platen * add comment * Apply suggestions from code review Co-authored-by: Sayak Paul * conv2d support for recurse remove * added docstrings * more docstring * add deprecate * revert * try to fix merge conflicts * v1 tests * add new decorator * add saving utilities test * adapt tests a bit * add save / from_pretrained tests * add saving tests * add scale tests * fix deps tests * fix lora CI * fix tests * add comment * fix * style * add slow tests * slow tests pass * style * Update src/diffusers/utils/import_utils.py Co-authored-by: Benjamin Bossan * Apply suggestions from code review Co-authored-by: Benjamin Bossan * circumvents pattern finding issue * left a todo * Apply suggestions from code review Co-authored-by: Patrick von Platen * update hub path * add lora workflow * fix --------- Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul Co-authored-by: Benjamin Bossan --- .github/workflows/pr_test_peft_backend.yml | 67 +++ src/diffusers/loaders.py | 316 ++++++---- src/diffusers/models/lora.py | 31 +- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../controlnet/pipeline_controlnet.py | 2 +- .../controlnet/pipeline_controlnet_img2img.py | 2 +- .../controlnet/pipeline_controlnet_inpaint.py | 2 +- .../pipeline_controlnet_inpaint_sd_xl.py | 4 +- .../controlnet/pipeline_controlnet_sd_xl.py | 4 +- .../pipeline_controlnet_sd_xl_img2img.py | 4 +- .../pipeline_cycle_diffusion.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- ...line_stable_diffusion_attend_and_excite.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 2 +- .../pipeline_stable_diffusion_diffedit.py | 2 +- .../pipeline_stable_diffusion_gligen.py | 2 +- ...line_stable_diffusion_gligen_text_image.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../pipeline_stable_diffusion_ldm3d.py | 2 +- ...pipeline_stable_diffusion_model_editing.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_paradigms.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 2 +- .../pipeline_stable_diffusion_sag.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../pipeline_stable_diffusion_xl.py | 4 +- .../pipeline_stable_diffusion_xl_img2img.py | 4 +- .../pipeline_stable_diffusion_xl_inpaint.py | 4 
+- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 4 +- .../pipeline_stable_diffusion_adapter.py | 2 +- .../pipeline_stable_diffusion_xl_adapter.py | 4 +- .../pipeline_text_to_video_synth.py | 2 +- .../pipeline_text_to_video_synth_img2img.py | 2 +- src/diffusers/utils/__init__.py | 3 + src/diffusers/utils/import_utils.py | 12 + src/diffusers/utils/peft_utils.py | 71 +++ src/diffusers/utils/state_dict_utils.py | 184 ++++++ src/diffusers/utils/testing_utils.py | 27 + ...ers.py => test_lora_layers_old_backend.py} | 12 +- tests/lora/test_lora_layers_peft.py | 557 ++++++++++++++++++ 46 files changed, 1193 insertions(+), 175 deletions(-) create mode 100644 .github/workflows/pr_test_peft_backend.yml create mode 100644 src/diffusers/utils/peft_utils.py create mode 100644 src/diffusers/utils/state_dict_utils.py rename tests/lora/{test_lora_layers.py => test_lora_layers_old_backend.py} (99%) create mode 100644 tests/lora/test_lora_layers_peft.py diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml new file mode 100644 index 000000000000..f5ff3c4444ab --- /dev/null +++ b/.github/workflows/pr_test_peft_backend.yml @@ -0,0 +1,67 @@ +name: Fast tests for PRs - PEFT backend + +on: + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + DIFFUSERS_IS_CI: yes + OMP_NUM_THREADS: 4 + MKL_NUM_THREADS: 4 + PYTEST_TIMEOUT: 60 + +jobs: + run_fast_tests: + strategy: + fail-fast: false + matrix: + config: + - name: LoRA + framework: lora + runner: docker-cpu + image: diffusers/diffusers-pytorch-cpu + report: torch_cpu_lora + + + name: ${{ matrix.config.name }} + + runs-on: ${{ matrix.config.runner }} + + container: + image: ${{ matrix.config.image }} + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + + defaults: + run: + shell: bash + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Install dependencies + run: | + apt-get update && apt-get install libsndfile1-dev libgl1 -y + python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate.git + python -m pip install -U git+https://github.com/huggingface/transformers.git + python -m pip install -U git+https://github.com/huggingface/peft.git + + - name: Environment + run: | + python utils/print_env.py + + - name: Run fast PyTorch LoRA CPU tests with PEFT backend + if: ${{ matrix.config.framework == 'lora' }} + run: | + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + -s -v \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/lora/test_lora_layers_peft.py \ No newline at end of file diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bea6e21aa64a..bc40cf9a18ea 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib import os import re from collections import defaultdict @@ -23,6 +24,7 @@ import safetensors import torch from huggingface_hub import hf_hub_download, model_info +from packaging import version from torch import nn from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta @@ -30,11 +32,15 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, deprecate, is_accelerate_available, is_omegaconf_available, + is_peft_available, is_transformers_available, logging, + recurse_remove_peft_layers, ) from .utils.import_utils import BACKENDS_MAPPING @@ -61,6 +67,21 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" +# Below should be `True` if the current versions of `peft` and `transformers` are compatible with the +# PEFT backend. The PEFT backend will automatically be used if the correct versions of the libraries are +# available. +# For PEFT it has to be at least 0.6.0 and for transformers it has to be at least 4.33.1. +_required_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version +) > version.parse("0.5") +_required_transformers_version = version.parse( + version.parse(importlib.metadata.version("transformers")).base_version +) > version.parse("4.33") + +USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version +LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT; make sure to install the latest PEFT and transformers packages in the future." + + class PatchedLoraProjection(nn.Module): def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): super().__init__() @@ -1077,6 +1098,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME num_fused_loras = 0 + use_peft_backend = USE_PEFT_BACKEND def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): """ @@ -1268,6 +1290,7 @@ def lora_state_dict( state_dict = pretrained_model_name_or_path_or_dict network_alphas = None + # TODO: replace it with a method from `state_dict_utils` if all( ( k.startswith("lora_te_") @@ -1520,55 +1543,35 @@ def load_lora_into_text_encoder( if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. + if cls.use_peft_backend: + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + rank_key = f"{name}.out_proj.lora_B.weight" + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + patch_mlp = any(".mlp."
in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + else: for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") - - for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" - rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) - - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" - rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" - rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) - rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) + rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" + rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) + + patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" + rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] if network_alphas is not None: alpha_keys = [ @@ -1578,56 +1581,79 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + if cls.use_peft_backend: + from peft import LoraConfig - is_pipeline_offloaded = _pipeline is not None and any( - isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values() - ) - if is_pipeline_offloaded and low_cpu_mem_usage: - low_cpu_mem_usage = True - logger.info( - f"Pipeline {_pipeline.__class__} is offloaded. 
Therefore low cpu mem usage loading is forced." - ) + lora_rank = list(rank.values())[0] + # By definition, the scale should be alpha divided by rank. + # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71 + alpha = lora_scale * lora_rank - if low_cpu_mem_usage: - device = next(iter(text_encoder_lora_state_dict.values())).device - dtype = next(iter(text_encoder_lora_state_dict.values())).dtype - unexpected_keys = load_model_dict_into_meta( - text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype - ) + target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + if patch_mlp: + target_modules += ["fc1", "fc2"] + + # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873 + lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + + text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + + is_model_cpu_offload = False + is_sequential_cpu_offload = False else: - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) - unexpected_keys = load_state_dict_results.unexpected_keys + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + low_cpu_mem_usage=low_cpu_mem_usage, + ) - if len(unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + is_pipeline_offloaded = _pipeline is not None and any( + isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") + for c in _pipeline.components.values() ) + if is_pipeline_offloaded and low_cpu_mem_usage: + low_cpu_mem_usage = True + logger.info( + f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced." + ) + + if low_cpu_mem_usage: + device = next(iter(text_encoder_lora_state_dict.values())).device + dtype = next(iter(text_encoder_lora_state_dict.values())).dtype + unexpected_keys = load_model_dict_into_meta( + text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype + ) + else: + load_state_dict_results = text_encoder.load_state_dict( + text_encoder_lora_state_dict, strict=False + ) + unexpected_keys = load_state_dict_results.unexpected_keys - # float: return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + if self.use_peft_backend: + remove_method = recurse_remove_peft_layers + else: + remove_method = self._remove_text_encoder_monkey_patch_classmethod + + if hasattr(self, "text_encoder"): + remove_method(self.text_encoder) + + if self.use_peft_backend: + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None + if hasattr(self, "text_encoder_2"): + remove_method(self.text_encoder_2) + if self.use_peft_backend: + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None @classmethod def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE) + for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): attn_module.q_proj.lora_linear_layer = None @@ -1675,6 +1718,7 @@ def _modify_text_encoder( r""" Monkey-patches the forward passes of attention modules of the text encoder. 
""" + deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE) def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model @@ -2049,24 +2093,38 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet: self.unet.fuse_lora(lora_scale) - def fuse_text_encoder_lora(text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._fuse_lora(lora_scale) - attn_module.k_proj._fuse_lora(lora_scale) - attn_module.v_proj._fuse_lora(lora_scale) - attn_module.out_proj._fuse_lora(lora_scale) + if self.use_peft_backend: + from peft.tuners.tuners_utils import BaseTunerLayer - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._fuse_lora(lora_scale) - mlp_module.fc2._fuse_lora(lora_scale) + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + + module.merge() + + else: + deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE) + + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._fuse_lora(lora_scale) + attn_module.k_proj._fuse_lora(lora_scale) + attn_module.v_proj._fuse_lora(lora_scale) + attn_module.out_proj._fuse_lora(lora_scale) + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._fuse_lora(lora_scale) + mlp_module.fc2._fuse_lora(lora_scale) if fuse_text_encoder: if hasattr(self, "text_encoder"): - fuse_text_encoder_lora(self.text_encoder) + fuse_text_encoder_lora(self.text_encoder, lora_scale) if hasattr(self, "text_encoder_2"): - fuse_text_encoder_lora(self.text_encoder_2) + fuse_text_encoder_lora(self.text_encoder_2, lora_scale) def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): r""" @@ -2088,18 +2146,29 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True if unfuse_unet: self.unet.unfuse_lora() - def unfuse_text_encoder_lora(text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._unfuse_lora() - attn_module.k_proj._unfuse_lora() - attn_module.v_proj._unfuse_lora() - attn_module.out_proj._unfuse_lora() + if self.use_peft_backend: + from peft.tuners.tuner_utils import BaseTunerLayer + + def unfuse_text_encoder_lora(text_encoder): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + else: + deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE) + + def unfuse_text_encoder_lora(text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._unfuse_lora() + attn_module.k_proj._unfuse_lora() + attn_module.v_proj._unfuse_lora() + attn_module.out_proj._unfuse_lora() - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._unfuse_lora() - 
mlp_module.fc2._unfuse_lora() + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._unfuse_lora() + mlp_module.fc2._unfuse_lora() if unfuse_text_encoder: if hasattr(self, "text_encoder"): @@ -2810,5 +2879,16 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) + if self.use_peft_backend: + recurse_remove_peft_layers(self.text_encoder) + # TODO: @younesbelkada handle this in transformers side + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None + + recurse_remove_peft_layers(self.text_encoder_2) + + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None + else: + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cc8e3e231e2b..07eeae712f71 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -25,18 +25,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj.lora_scale = lora_scale - attn_module.k_proj.lora_scale = lora_scale - attn_module.v_proj.lora_scale = lora_scale - attn_module.out_proj.lora_scale = lora_scale - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1.lora_scale = lora_scale - mlp_module.fc2.lora_scale = lora_scale +def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): + if use_peft_backend: + from peft.tuners.lora import LoraLayer + + for module in text_encoder.modules(): + if isinstance(module, LoraLayer): + module.scaling[module.active_adapter] = lora_scale + else: + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale class LoRALinearLayer(nn.Module): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 3464f8294a33..ea4ec3fbc623 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -303,7 +303,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py 
b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index fee5b7ff2bf5..8c302f8d948f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -301,7 +301,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index e238ac913f97..b899240b0c0e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -291,7 +291,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 9778f288b7ad..85dced1dc9c3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -315,7 +315,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 981cedbec9a5..2065343fe06c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -442,7 +442,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 336b2c20f551..bf1be8407394 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -315,8 +315,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index b7b8699d4af8..85c47f13b300 100644 --- 
a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -288,8 +288,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index f26e338b7f58..7a19a1862ddb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -326,8 +326,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 5a3f73800bff..46c264f5210c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -308,7 +308,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c56fae622c79..61fb1620ac28 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -301,7 +301,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 79d7d9dffaf5..6cd5939d87a9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -332,7 +332,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index fd08b29c49bd..eee91028f6e8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -213,7 +213,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index f9389820d545..97278d06371d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -481,7 +481,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py index 737e6bdff67f..40c058e78001 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py @@ -278,7 +278,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py index 2a0453722bbc..6b9a6761bd34 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py @@ -309,7 +309,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index d46b9026f5f2..d8e7ba3e5f90 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -302,7 +302,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 5ef0b8320a07..290e1eb6fae5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -375,7 +375,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 994abbd2558f..c21ceb169bf3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -297,7 +297,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 2cddf819781d..ff85af14fb1d 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -211,7 +211,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index d264bb9abb12..eea5383f9029 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -272,7 +272,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index 532153b1a505..9da9fa046bcc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -244,7 +244,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 8a87ab4eca2a..a284c6a32408 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -221,7 +221,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py index 5c2f52a7540a..fb65e1494757 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -256,7 +256,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 2a04f195af11..d24db4526536 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -446,7 +446,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 340832bb539f..267fd394ce25 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -244,7 +244,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 1b3dfdcd83c6..2ea14292cccc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -240,7 +240,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 28016c935f15..62e36652c34f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -346,7 +346,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 96094dd7914b..6189f751514e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -296,7 +296,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 5c646e7869ff..9f012dcbf0b4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -264,8 +264,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 6ebd450849ab..4118d5c37e58 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -271,8 +271,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 26b62e7e508e..484d030d6dde 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -420,8 +420,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - 
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index fdf02c536f78..79a72ef6eb83 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -272,8 +272,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 9a6f09f8f5dc..76bb8e77814f 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -296,7 +296,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 5a0865e463cf..f526409eac68 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -288,8 +288,8 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index a6a04e056f59..e59070a4122b 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -228,7 +228,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git 
a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 085f4d17d640..504273db86dc 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -290,7 +290,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7390a2f69d23..3cd185e86325 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,7 @@ is_note_seq_available, is_omegaconf_available, is_onnx_available, + is_peft_available, is_scipy_available, is_tensorboard_available, is_torch_available, @@ -82,7 +83,9 @@ from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput +from .peft_utils import recurse_remove_peft_layers from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil +from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft logger = get_logger(__name__) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 587949ab0c52..e4c8ffafbc8f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -267,6 +267,14 @@ _invisible_watermark_available = False +_peft_available = importlib.util.find_spec("peft") is not None +try: + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + + def is_torch_available(): return _torch_available @@ -351,6 +359,10 @@ def is_invisible_watermark_available(): return _invisible_watermark_available +def is_peft_available(): + return _peft_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py new file mode 100644 index 000000000000..9b34183ffaac --- /dev/null +++ b/src/diffusers/utils/peft_utils.py @@ -0,0 +1,71 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PEFT utilities: Utilities related to peft library +""" +from .import_utils import is_torch_available + + +if is_torch_available(): + import torch + + +def recurse_remove_peft_layers(model): + r""" + Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. 
+ """ + from peft.tuners.lora import LoraLayer + + for name, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + recurse_remove_peft_layers(module) + + module_replaced = False + + if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to( + module.weight.device + ) + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + module_replaced = True + elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d): + new_module = torch.nn.Conv2d( + module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.groups, + module.bias, + ).to(module.weight.device) + + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + module_replaced = True + + if module_replaced: + setattr(model, name, new_module) + del module + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return model diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py new file mode 100644 index 000000000000..acc64a5034ec --- /dev/null +++ b/src/diffusers/utils/state_dict_utils.py @@ -0,0 +1,184 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +State dict utilities: utility methods for converting state dicts easily +""" +import enum + + +class StateDictType(enum.Enum): + """ + The mode to use when converting state dicts. 
+ """ + + DIFFUSERS_OLD = "diffusers_old" + # KOHYA_SS = "kohya_ss" # TODO: implement this + PEFT = "peft" + DIFFUSERS = "diffusers" + + +DIFFUSERS_TO_PEFT = { + ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", + ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", + ".k_proj.lora_linear_layer.up": ".k_proj.lora_B", + ".k_proj.lora_linear_layer.down": ".k_proj.lora_A", + ".v_proj.lora_linear_layer.up": ".v_proj.lora_B", + ".v_proj.lora_linear_layer.down": ".v_proj.lora_A", + ".out_proj.lora_linear_layer.up": ".out_proj.lora_B", + ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", +} + +DIFFUSERS_OLD_TO_PEFT = { + ".to_q_lora.up": ".q_proj.lora_B", + ".to_q_lora.down": ".q_proj.lora_A", + ".to_k_lora.up": ".k_proj.lora_B", + ".to_k_lora.down": ".k_proj.lora_A", + ".to_v_lora.up": ".v_proj.lora_B", + ".to_v_lora.down": ".v_proj.lora_A", + ".to_out_lora.up": ".out_proj.lora_B", + ".to_out_lora.down": ".out_proj.lora_A", +} + +PEFT_TO_DIFFUSERS = { + ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", + ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", + ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", + ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", + ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", + ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", + ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", + ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", +} + +DIFFUSERS_OLD_TO_DIFFUSERS = { + ".to_q_lora.up": ".q_proj.lora_linear_layer.up", + ".to_q_lora.down": ".q_proj.lora_linear_layer.down", + ".to_k_lora.up": ".k_proj.lora_linear_layer.up", + ".to_k_lora.down": ".k_proj.lora_linear_layer.down", + ".to_v_lora.up": ".v_proj.lora_linear_layer.up", + ".to_v_lora.down": ".v_proj.lora_linear_layer.down", + ".to_out_lora.up": ".out_proj.lora_linear_layer.up", + ".to_out_lora.down": ".out_proj.lora_linear_layer.down", +} + +PEFT_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, + StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, +} + +DIFFUSERS_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, + StateDictType.PEFT: PEFT_TO_DIFFUSERS, +} + + +def convert_state_dict(state_dict, mapping): + r""" + Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + mapping (`dict[str, str]`): + The mapping to use for conversion, the mapping should be a dictionary with the following structure: + - key: the pattern to replace + - value: the pattern to replace with + + Returns: + converted_state_dict (`dict`) + The converted state dict. + """ + converted_state_dict = {} + for k, v in state_dict.items(): + for pattern in mapping.keys(): + if pattern in k: + new_pattern = mapping[pattern] + k = k.replace(pattern, new_pattern) + break + converted_state_dict[k] = v + return converted_state_dict + + +def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): + r""" + Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or + new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. 
+    """
+    if original_type is None:
+        # Infer the original type from the keys
+        if any("to_out_lora" in k for k in state_dict.keys()):
+            original_type = StateDictType.DIFFUSERS_OLD
+        elif any("lora_linear_layer" in k for k in state_dict.keys()):
+            original_type = StateDictType.DIFFUSERS
+        else:
+            raise ValueError("Could not automatically infer state dict type")
+
+    if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
+        raise ValueError(f"Original type {original_type} is not supported")
+
+    mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
+    return convert_state_dict(state_dict, mapping)
+
+
+def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
+    r"""
+    Converts a state dict to the new diffusers format. The state dict can be in the previous diffusers format
+    (`DIFFUSERS_OLD`), the PEFT format (`PEFT`), or the new diffusers format (`DIFFUSERS`). In the last case the
+    method will return the state dict as is.
+
+    For now, the method only supports conversion from the old diffusers and PEFT formats to the new diffusers
+    format.
+
+    Args:
+        state_dict (`dict[str, torch.Tensor]`):
+            The state dict to convert.
+        original_type (`StateDictType`, *optional*):
+            The original type of the state dict; if not provided, the method will try to infer it automatically.
+        kwargs (`dict`, *optional*):
+            Additional arguments to pass to the method.
+
+            - **adapter_name**: For example, in the case of PEFT, some keys will be prepended with the adapter
+              name, which therefore needs special handling. By default, PEFT also takes care of that in the
+              `get_peft_model_state_dict` method:
+              https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
+              but we add it here in case we don't want to rely on that method.
+    """
+    peft_adapter_name = kwargs.pop("adapter_name", None)
+    if peft_adapter_name is not None:
+        peft_adapter_name = "." + peft_adapter_name
+    else:
+        peft_adapter_name = ""
+
+    if original_type is None:
+        # Infer the original type from the keys
+        if any("to_out_lora" in k for k in state_dict.keys()):
+            original_type = StateDictType.DIFFUSERS_OLD
+        elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
+            original_type = StateDictType.PEFT
+        elif any("lora_linear_layer" in k for k in state_dict.keys()):
+            # nothing to do
+            return state_dict
+        else:
+            raise ValueError("Could not automatically infer state dict type")
+
+    if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
+        raise ValueError(f"Original type {original_type} is not supported")
+
+    mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
+    return convert_state_dict(state_dict, mapping)
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 555ed700d184..a2f4de439e11 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,3 +1,4 @@
+import importlib
 import inspect
 import io
 import logging
@@ -29,9 +30,11 @@
     is_note_seq_available,
     is_onnx_available,
     is_opencv_available,
+    is_peft_available,
     is_torch_available,
     is_torch_version,
     is_torchsde_available,
+    is_transformers_available,
 )
 from .logging import get_logger
 
@@ -40,6 +43,15 @@
 logger = get_logger(__name__)
 
+_required_peft_version = is_peft_available() and version.parse(
+    version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse("0.5")
+_required_transformers_version = is_transformers_available() and version.parse(
+    version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+
 if is_torch_available():
     import torch
 
@@ -236,6 +248,21 @@ def require_torchsde(test_case):
     return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
 
 
+def require_peft_backend(test_case):
+    """
+    Decorator marking a test that requires the PEFT backend; this requires specific versions of PEFT and
+    transformers.
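+
+    Example (an illustrative sketch; `MyPeftLoraTests` is a hypothetical test class):
+
+    ```py
+    @require_peft_backend
+    class MyPeftLoraTests(unittest.TestCase):
+        ...
+    ```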
+    """
+    return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
+
+
+def deprecate_after_peft_backend(test_case):
+    """
+    Decorator marking a test that will be skipped (and eventually deprecated) once the PEFT backend is in use.
+    """
+    return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
+
+
 def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
     if isinstance(arry, str):
         # local_path = "/home/patrick_huggingface_co/"
diff --git a/tests/lora/test_lora_layers.py b/tests/lora/test_lora_layers_old_backend.py
similarity index 99%
rename from tests/lora/test_lora_layers.py
rename to tests/lora/test_lora_layers_old_backend.py
index 20d44e0c07c4..22e33767007a 100644
--- a/tests/lora/test_lora_layers.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -52,7 +52,15 @@
     XFormersAttnProcessor,
 )
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+    deprecate_after_peft_backend,
+    floats_tensor,
+    load_image,
+    nightly,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 
 def create_lora_layers(model, mock_weights: bool = True):
@@ -181,6 +189,7 @@ def state_dicts_almost_equal(sd1, sd2):
     return models_are_equal
 
 
+@deprecate_after_peft_backend
 class LoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -773,6 +782,7 @@ def test_stable_diffusion_inpaint_lora(self):
         assert np.abs(image_slice - image_slice_2).max() > 1e-2
 
 
+@deprecate_after_peft_backend
class SDXLLoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
new file mode 100644
index 000000000000..1862437fce88
--- /dev/null
+++ b/tests/lora/test_lora_layers_peft.py
@@ -0,0 +1,557 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    StableDiffusionPipeline,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+)
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, require_torch_gpu, slow
+
+
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
+    from peft.utils import get_peft_model_state_dict
+
+
+def create_unet_lora_layers(unet: nn.Module):
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
+    return lora_attn_procs, unet_lora_layers
+
+
+@require_peft_backend
+class PeftLoraLoaderMixinTests:
+    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    pipeline_class = None
+    scheduler_cls = None
+    scheduler_kwargs = None
+    has_two_text_encoders = False
+    unet_kwargs = None
+    vae_kwargs = None
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(**self.unet_kwargs)
+        scheduler = self.scheduler_cls(**self.scheduler_kwargs)
+        torch.manual_seed(0)
+        vae = AutoencoderKL(**self.vae_kwargs)
+        text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+        tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+
+        if self.has_two_text_encoders:
+            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+            tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+
+        text_lora_config = LoraConfig(
+            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+        )
+
+        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
+
+        if self.has_two_text_encoders:
+            pipeline_components = {
+                "unet": unet,
+                "scheduler": scheduler,
+                "vae": vae,
+                "text_encoder": text_encoder,
+                "tokenizer": tokenizer,
+                "text_encoder_2": text_encoder_2,
+                "tokenizer_2": tokenizer_2,
+            }
+        else:
+            pipeline_components = {
+                "unet": unet,
+                "scheduler": scheduler,
+                "vae": vae,
+                "text_encoder": text_encoder,
+                "tokenizer": tokenizer,
+                "safety_checker": None,
+                "feature_extractor": None,
+            }
+        lora_components = {
+            "unet_lora_layers": unet_lora_layers,
+            "unet_lora_attn_procs": unet_lora_attn_procs,
+        }
+        return pipeline_components, lora_components, text_lora_config
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def check_if_lora_correctly_set(self, model) -> bool:
+        """
+        Checks if the LoRA layers are correctly set with PEFT
+        """
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                return True
+        return False
+
+    def test_simple_inference(self):
+        """
+        Tests a simple inference and makes sure it works as expected
+        """
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs()
+        output_no_lora = pipe(**inputs).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+    def test_simple_inference_with_text_lora(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder
+        and makes sure it works as expected
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+        )
+
+    def test_simple_inference_with_text_lora_and_scale(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder + scale argument
+        and makes sure it works as expected
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+        )
+
+        output_lora_scale = pipe(
+            **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+        ).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+            "Lora + scale should change the output",
+        )
+
+        output_lora_0_scale = pipe(
+            **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+        ).images
+        self.assertTrue(
+            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+            "Lora + 0 scale should lead to same result as no LoRA",
+        )
+
+    def test_simple_inference_with_text_lora_fused(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder, fuses the LoRA weights into the base
+        model, and makes sure it works as expected
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        output_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+
+    def test_simple_inference_with_text_lora_unloaded(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder, then unloads the LoRA weights
+        and makes sure it works as expected
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        pipe.unload_lora_weights()
+        # unloading should remove the LoRA layers
+        self.assertFalse(
+            self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
+        )
+
+        if self.has_two_text_encoders:
+            self.assertFalse(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
+            )
+
+        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(
+            np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+            "Unloading LoRA should give the same output as no LoRA",
+        )
+
+    def test_simple_inference_with_text_lora_save_load(self):
+        """
+        Tests a simple use case where users could use saving utilities for LoRA.
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            if self.has_two_text_encoders:
+                text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                    safe_serialization=False,
+                )
+            else:
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    safe_serialization=False,
+                )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            pipe.unload_lora_weights()
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+
+    def test_simple_inference_save_pretrained(self):
+        """
+        Tests a simple use case where users could use saving utilities for LoRA through save_pretrained
+        """
+        components, _, text_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+            )
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+
+            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+            pipe_from_pretrained.to(self.torch_device)
+
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
+            "Lora not correctly set in text encoder",
+        )
+
+        if self.has_two_text_encoders:
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+                "Lora not correctly set in text encoder 2",
+            )
+
+        images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+
+
+class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
+    pipeline_class = StableDiffusionPipeline
+    scheduler_cls = DDIMScheduler
+    scheduler_kwargs = {
+        "beta_start": 0.00085,
+        "beta_end": 0.012,
+        "beta_schedule": "scaled_linear",
+        "clip_sample": False,
+        "set_alpha_to_one": False,
+        "steps_offset": 1,
+    }
+    unet_kwargs = {
+        "block_out_channels": (32, 64),
+        "layers_per_block": 2,
+        "sample_size": 32,
+        "in_channels": 4,
+        "out_channels": 4,
+        "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
+        "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
+        "cross_attention_dim": 32,
+    }
+    vae_kwargs = {
+        "block_out_channels": [32, 64],
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+        "latent_channels": 4,
+    }
+
+    @slow
+    @require_torch_gpu
+    def test_integration_logits_with_scale(self):
+        path = "runwayml/stable-diffusion-v1-5"
+        lora_id = "takuma104/lora-test-text-encoder-lora-target"
+
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
+        pipe.load_lora_weights(lora_id)
+        pipe = pipe.to("cuda")
+
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe.text_encoder),
+            "Lora not correctly set in text encoder",
+        )
+
+        prompt = "a red sks dog"
+
+        images = pipe(
+            prompt=prompt,
+            num_inference_steps=15,
+            cross_attention_kwargs={"scale": 0.5},
+            generator=torch.manual_seed(0),
+            output_type="np",
+        ).images
+
+        expected_slice_scale = np.array([0.307, 0.283, 0.310, 0.310, 0.300, 0.314, 0.336, 0.314, 0.321])
+
+        predicted_slice = images[0, -3:, -3:, -1].flatten()
+
+        self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
+
+    @slow
+    @require_torch_gpu
+    def test_integration_logits_no_scale(self):
+        path = "runwayml/stable-diffusion-v1-5"
+        lora_id = "takuma104/lora-test-text-encoder-lora-target"
+
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
+        pipe.load_lora_weights(lora_id)
+        pipe = pipe.to("cuda")
+
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe.text_encoder),
+            "Lora not correctly set in text encoder",
+        )
+
+        prompt = "a red sks dog"
+
+        images = pipe(prompt=prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images
+
+        expected_slice_scale = np.array([0.074, 0.064, 0.073, 0.0842, 0.069, 0.0641, 0.0794, 0.076, 0.084])
+
+        predicted_slice = images[0, -3:, -3:, -1].flatten()
+
+        self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
+
+
+class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
+    has_two_text_encoders = True
+    pipeline_class = StableDiffusionXLPipeline
+    scheduler_cls = EulerDiscreteScheduler
+    scheduler_kwargs = {
+        "beta_start": 0.00085,
+        "beta_end": 0.012,
+        "beta_schedule": "scaled_linear",
+        "timestep_spacing": "leading",
+        "steps_offset": 1,
+    }
+    unet_kwargs = {
+        "block_out_channels": (32, 64),
+        "layers_per_block": 2,
+        "sample_size": 32,
+        "in_channels": 4,
+        "out_channels": 4,
+        "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
+        "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
+        "attention_head_dim": (2, 4),
+        "use_linear_projection": True,
+        "addition_embed_type": "text_time",
+        "addition_time_embed_dim": 8,
+        "transformer_layers_per_block": (1, 2),
+        "projection_class_embeddings_input_dim": 80,  # 6 * 8 + 32
+        "cross_attention_dim": 64,
+    }
+    vae_kwargs = {
+        "block_out_channels": [32, 64],
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+        "latent_channels": 4,
+        "sample_size": 128,
+    }
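+
+
+# A minimal end-user sketch of the workflow exercised by the tests above (illustrative only; it assumes a
+# checkpoint whose LoRA weights target the text encoder, such as the hub repo used in the integration tests):
+#
+#     pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+#     pipe.load_lora_weights("takuma104/lora-test-text-encoder-lora-target")
+#     image = pipe("a red sks dog", cross_attention_kwargs={"scale": 0.5}).images[0]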