From a584d42ce5853d160c3c1bfb5ff0f0ee65c301e6 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 27 Sep 2023 18:16:32 +0200
Subject: [PATCH] [LoRA, Xformers] Fix xformers lora (#5201)

* fix xformers lora

* improve

* fix
---
 src/diffusers/models/attention_processor.py   | 11 ++++-------
 src/diffusers/models/autoencoder_kl.py        | 10 ++++++----
 src/diffusers/models/controlnet.py            | 10 ++++++----
 src/diffusers/models/prior_transformer.py     | 10 ++++++----
 src/diffusers/models/unet_2d_condition.py     | 10 ++++++----
 src/diffusers/models/unet_3d_condition.py     | 10 ++++++----
 .../pipelines/audioldm2/modeling_audioldm2.py | 10 ++++++----
 .../versatile_diffusion/modeling_text_unet.py | 10 ++++++----
 tests/lora/test_lora_layers_old_backend.py    |  2 +-
 9 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fba5bddb5def..53f6ca019094 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -310,19 +310,16 @@ def set_attention_slice(self, slice_size):
 
         self.set_processor(processor)
 
-    def set_processor(self, processor: "AttnProcessor"):
-        if (
-            hasattr(self, "processor")
-            and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
-            and self.to_q.lora_layer is not None
-        ):
+    def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
+        if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
             deprecate(
                 "set_processor to offload LoRA",
                 "0.26.0",
-                "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
             )
             # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
             # We need to remove all LoRA layers
+            # Don't forget to remove ALL `_remove_lora` from the codebase
             for module in self.modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 76666a4cc295..21c8f64fd916 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -196,7 +196,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -220,9 +222,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -244,7 +246,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index db05b0689cff..1a82b0421f88 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -517,7 +517,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -541,9 +543,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -565,7 +567,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py
index 8ada0a7c08a5..6c5e406ad378 100644
--- a/src/diffusers/models/prior_transformer.py
+++ b/src/diffusers/models/prior_transformer.py
@@ -192,7 +192,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -216,9 +218,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -240,7 +242,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def forward(
         self,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 385f0a42c598..866254a89545 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -613,7 +613,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -637,9 +639,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -660,7 +662,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def set_attention_slice(self, slice_size):
         r"""
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 58c848fdb97f..01af31061d10 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -366,7 +366,9 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -390,9 +392,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -454,7 +456,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index d39b2c99ddd0..e855c2f0d6f1 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -538,7 +538,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -562,9 +564,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -586,7 +588,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 99bf1d22ee91..f2b191496aaa 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -820,7 +820,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -844,9 +846,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ def set_default_attn_processor(self):
                 f" {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def set_attention_slice(self, slice_size):
         r"""
diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index 391f2f8449fc..ae90f8b6a4b8 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -1786,7 +1786,7 @@ def test_lora_on_off(self):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
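Usage note (not part of the patch): the sketch below is a minimal illustration of the behavior this change produces, assuming a diffusers checkout that contains this commit (old, non-PEFT LoRA backend); the checkpoint id "runwayml/stable-diffusion-v1-5" is used only as an example. Because `_remove_lora` defaults to False, swapping attention processors (for instance to enable xformers) no longer strips loaded LoRA layers; only `set_default_attn_processor` passes `_remove_lora=True`, which is also why the test above now calls it instead of `set_attn_processor(AttnProcessor())`.

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import XFormersAttnProcessor

    # Example checkpoint; any UNet2DConditionModel behaves the same way.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    # Suppose LoRA layers were loaded into this UNet (e.g. through
    # `pipe.load_lora_weights(...)` on the surrounding pipeline).

    # After this patch, replacing the processors keeps any loaded LoRA layers,
    # because `Attention.set_processor` is called with `_remove_lora=False`.
    # (xformers itself is only needed once the model actually runs.)
    unet.set_attn_processor(XFormersAttnProcessor())

    # Restoring the defaults is now the one path that offloads LoRA:
    # `set_default_attn_processor` forwards `_remove_lora=True` and, if LoRA
    # layers are present, emits the deprecation warning that points to
    # `pipe.unload_lora_weights()`.
    unet.set_default_attn_processor()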