[LoRA, Xformers] Fix xformers lora (#5201)
* fix xformers lora

* improve

* fix
patrickvonplaten authored Sep 27, 2023
1 parent cdcc01b commit a584d42
Showing 9 changed files with 47 additions and 36 deletions.
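The substantive change is in `Attention.set_processor` (first hunk below): LoRA layers used to be stripped whenever the incoming processor was not a LoRA processor, so swapping processors, for example to enable xformers, silently dropped freshly loaded LoRA weights. Removal is now opt-in through a private `_remove_lora` flag that only `set_default_attn_processor` forwards as `True`, and every model's `set_attn_processor` threads the flag through. A minimal sketch of the user-facing scenario this appears to fix is below; the checkpoint id, LoRA path, and prompt are illustrative assumptions, not taken from the commit.

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative assumptions: any Stable Diffusion checkpoint and LoRA file work here.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora.safetensors")  # hypothetical path

# Enabling xformers swaps attention processors via `set_attn_processor`.
# Before this commit that call also removed the freshly loaded LoRA layers,
# because the old guard stripped them for any non-LoRA processor; now removal
# only happens when `_remove_lora=True` is passed explicitly.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]
```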
11 changes: 4 additions & 7 deletions src/diffusers/models/attention_processor.py
@@ -310,19 +310,16 @@ def set_attention_slice(self, slice_size):
 
         self.set_processor(processor)
 
-    def set_processor(self, processor: "AttnProcessor"):
-        if (
-            hasattr(self, "processor")
-            and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
-            and self.to_q.lora_layer is not None
-        ):
+    def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
+        if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
             deprecate(
                 "set_processor to offload LoRA",
                 "0.26.0",
-                "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
             )
             # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
             # We need to remove all LoRA layers
+            # Don't forget to remove ALL `_remove_lora` from the codebase
             for module in self.modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
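For readability, here is a minimal sketch (not the library code itself) of the guard the hunk above introduces: LoRA layers are only torn down when the caller explicitly opts in via `_remove_lora`.

```python
# Simplified sketch of the new Attention.set_processor behaviour, with the
# deprecation warning omitted; purely for illustration.
def set_processor_sketch(attn, processor, _remove_lora=False):
    has_lora = hasattr(attn, "processor") and attn.to_q.lora_layer is not None
    if has_lora and _remove_lora:
        # Deprecated opt-in path: strip the LoRA layer from every sub-module
        # that carries one before installing the new processor.
        for module in attn.modules():
            if hasattr(module, "set_lora_layer"):
                module.set_lora_layer(None)
    attn.processor = processor  # the real method does a bit more bookkeeping
```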
10 changes: 6 additions & 4 deletions src/diffusers/models/autoencoder_kl.py
@@ -196,7 +196,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -220,9 +222,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -244,7 +246,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
10 changes: 6 additions & 4 deletions src/diffusers/models/controlnet.py
@@ -517,7 +517,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -541,9 +543,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -565,7 +567,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
10 changes: 6 additions & 4 deletions src/diffusers/models/prior_transformer.py
@@ -192,7 +192,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -216,9 +218,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -240,7 +242,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def forward(
         self,
10 changes: 6 additions & 4 deletions src/diffusers/models/unet_2d_condition.py
@@ -613,7 +613,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
@@ -637,9 +639,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -660,7 +662,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def set_attention_slice(self, slice_size):
         r"""
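The `unet_2d_condition.py` hunks above are the canonical definition mirrored by the `# Copied from` blocks in the other models. The practical effect, sketched below with an illustrative UNet (the checkpoint id is an assumption), is that installing a non-LoRA processor no longer silently drops LoRA layers, while resetting to the default processor still removes them via the deprecated `_remove_lora=True` path.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import XFormersAttnProcessor

# Illustrative checkpoint; in practice LoRA layers would have been attached
# first, e.g. through a pipeline's `load_lora_weights`.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# After this commit: installs xformers attention but keeps any LoRA layers.
unet.set_attn_processor(XFormersAttnProcessor())

# Still removes LoRA layers, because it forwards `_remove_lora=True`
# (and emits the deprecation warning added in attention_processor.py).
unet.set_default_attn_processor()
```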
10 changes: 6 additions & 4 deletions src/diffusers/models/unet_3d_condition.py
@@ -366,7 +366,9 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -390,9 +392,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -454,7 +456,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
10 changes: 6 additions & 4 deletions src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -538,7 +538,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -562,9 +564,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -586,7 +588,7 @@ def set_default_attn_processor(self):
                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
@@ -820,7 +820,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -844,9 +846,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                 else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 
             for sub_name, child in module.named_children():
                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ def set_default_attn_processor(self):
                 f" {next(iter(self.attn_processors.values()))}"
             )
 
-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)
 
     def set_attention_slice(self, slice_size):
         r"""
2 changes: 1 addition & 1 deletion tests/lora/test_lora_layers_old_backend.py
@@ -1786,7 +1786,7 @@ def test_lora_on_off(self):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
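The test change follows from the new semantics: `set_attn_processor(AttnProcessor())` no longer strips LoRA layers, so the test switches to `set_default_attn_processor()`, which opts into removal via `_remove_lora=True`. Outside of internal code and tests, the deprecation message added in `attention_processor.py` points at the supported way to drop LoRA weights, sketched below with an assumed `pipe` that already has LoRA weights loaded.

```python
# Supported, non-deprecated path per the deprecation message in this commit;
# `pipe` is an assumed pipeline with LoRA weights already loaded.
pipe.unload_lora_weights()
```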
