PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters (#5147)

* more fixes

* up

* up

* style

* add in setup

* oops

* more changes

* v1 rzfactor CI

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* few todos

* protect torch import

* style

* fix fuse text encoder

* Update src/diffusers/loaders.py

Co-authored-by: Sayak Paul <[email protected]>

* replace with `recurse_replace_peft_layers`

* keep old modules for BC

* adjustments on `adjust_lora_scale_text_encoder`

* nit

* move tests

* add conversion utils

* remove unneeded methods

* use class method instead

* oops

* use `base_version`

* fix examples

* fix CI

* fix weird error with python 3.8

* fix

* better fix

* style

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* add comment

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* conv2d support for recurse remove

* added docstrings

* more docstring

* add deprecate

* revert

* try to fix merge conflicts

* peft integration features for text encoder

1. support multiple rank/alpha values
2. support multiple active adapters
3. support disabling and enabling adapters

* fix bug

* fix code quality

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* fix bugs

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* address comments

Co-Authored-By: Benjamin Bossan <[email protected]>
Co-Authored-By: Patrick von Platen <[email protected]>

* fix code quality

* address comments

* address comments

* Apply suggestions from code review

* find and replace

---------

Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
6 people authored Sep 27, 2023
1 parent 940f941 commit 02247d9
Showing 4 changed files with 212 additions and 19 deletions.
117 changes: 105 additions & 12 deletions src/diffusers/loaders.py
@@ -35,18 +35,23 @@
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING


if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel

if is_accelerate_available():
from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
@@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where `i` is the number of adapters already loaded into the model.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)

@@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder(
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
_pipeline=None,
):
"""
@@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder(
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where `i` is the number of adapters already loaded into the model.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

@@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder(
if cls.use_peft_backend:
from peft import LoraConfig

lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)

target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
lora_config = LoraConfig(**lora_config_kwargs)

# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
# resolve a default adapter name if none was provided
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
# inject LoRA layers and load the state dict
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

is_model_cpu_offload = False
is_sequential_cpu_offload = False
@@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder):

self.num_fused_loras -= 1

def set_adapter_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional[PreTrainedModel] = None,
text_encoder_weights: List[float] = None,
):
"""
Sets the adapter layers for the text encoder.
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")

def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
return weights

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError(
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
"""
Disables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
`text_encoder` attribute.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")

text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)

def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
"""
Enables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=True)


class FromSingleFileMixin:
"""
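For orientation, here is a minimal usage sketch of the pipeline-level API added in this commit: named adapters at load time, per-adapter weights on the text encoder, and enable/disable toggles. The base model, the LoRA repository ids, and the adapter names are placeholders, and the diffusers PEFT backend (`USE_PEFT_BACKEND`) is assumed to be enabled.

```python
# Illustrative sketch only; repo ids and adapter names are placeholders, and a
# PEFT installation with the diffusers PEFT backend enabled is assumed.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load two LoRA checkpoints under explicit names via the new `adapter_name` argument.
pipe.load_lora_weights("user/style-lora", adapter_name="style")      # placeholder repo id
pipe.load_lora_weights("user/subject-lora", adapter_name="subject")  # placeholder repo id

# Activate both adapters on the text encoder with per-adapter weights.
pipe.set_adapter_for_text_encoder(["style", "subject"], text_encoder_weights=[1.0, 0.6])

# Temporarily run the text encoder without any LoRA, then switch it back on.
pipe.disable_lora_for_text_encoder()
pipe.enable_lora_for_text_encoder()
```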
8 changes: 2 additions & 6 deletions src/diffusers/models/lora.py
@@ -19,19 +19,15 @@
from torch import nn

from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
from ..utils import logging, scale_lora_layers


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer

for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
scale_lora_layers(text_encoder, weight=lora_scale)
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
9 changes: 8 additions & 1 deletion src/diffusers/utils/__init__.py
@@ -84,7 +84,14 @@
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import recurse_remove_peft_layers
from .peft_utils import (
get_adapter_name,
get_peft_kwargs,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft

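Since `get_adapter_name`, `get_peft_kwargs`, `scale_lora_layers`, `set_adapter_layers`, and `set_weights_and_activate_adapters` are now re-exported from `diffusers.utils`, a quick sanity check (illustrative only) is simply:

```python
# These imports should resolve after this change; they are the helpers used by
# the text-encoder LoRA code paths above.
from diffusers.utils import (
    get_adapter_name,
    get_peft_kwargs,
    scale_lora_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
```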
97 changes: 97 additions & 0 deletions src/diffusers/utils/peft_utils.py
@@ -14,6 +14,8 @@
"""
PEFT utilities: Utilities related to peft library
"""
import collections

from .import_utils import is_torch_available


@@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model):
torch.cuda.empty_cache()

return model


def scale_lora_layers(model, weight):
"""
Adjust the weight (scale) applied to the LoRA layers of the model.
Args:
model (`torch.nn.Module`):
The model to scale.
weight (`float`):
The weight to be given to the LoRA layers.
"""
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.scale_layer(weight)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank occurring most often
r = collections.Counter(rank_dict.values()).most_common()[0][0]

# for modules whose rank differs from the most common rank, add them to the `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring most often
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

# for modules whose alpha differs from the most common alpha, add them to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}

# layer names without the Diffusers-specific LoRA suffix
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})

lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
}
return lora_config_kwargs
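As a toy illustration (not part of the commit) of how `get_peft_kwargs` folds heterogeneous ranks into a base `r` plus a `rank_pattern`, with key names that are assumptions modelled on PEFT-style `<module>.lora_A/lora_B.weight` entries:

```python
# Toy example; key names and values are made up for illustration.
import torch

from diffusers.utils import get_peft_kwargs

rank_dict = {
    "layers.0.q_proj.lora_B.weight": 4,
    "layers.0.k_proj.lora_B.weight": 4,
    "layers.0.v_proj.lora_B.weight": 8,  # the odd one out
}
# A dummy state dict: only the key names matter for deriving `target_modules`.
peft_state_dict = {k: torch.zeros(1) for k in rank_dict}
peft_state_dict.update({k.replace("lora_B", "lora_A"): torch.zeros(1) for k in rank_dict})

kwargs = get_peft_kwargs(rank_dict, network_alpha_dict=None, peft_state_dict=peft_state_dict)
print(kwargs["r"])                       # 4 -> the most common rank becomes the base rank
print(kwargs["rank_pattern"])            # {'layers.0.v_proj': 8}
print(sorted(kwargs["target_modules"]))  # ['layers.0.k_proj', 'layers.0.q_proj', 'layers.0.v_proj']
```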


def get_adapter_name(model):
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
return f"default_{len(module.r)}"
return "default_0"


def set_adapter_layers(model, enabled=True):
from peft.tuners.tuners_utils import BaseTunerLayer

for module in model.modules():
if isinstance(module, BaseTunerLayer):
# The recent version of PEFT needs to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=enabled)
else:
module.disable_adapters = not enabled


def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer

# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.scale_layer(weight)

# set multiple active adapters
for module in model.modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_names)
else:
module.active_adapter = adapter_names
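To make the behaviour of the new helpers concrete, here is a minimal sketch on a tiny PEFT-injected module. The `TinyTextEncoder` class and the adapter names are made up for illustration, and a PEFT version that provides `inject_adapter_in_model` and multi-adapter LoRA layers (as this commit assumes elsewhere) is required.

```python
# Illustrative sketch; module and adapter names are made up, and a recent PEFT
# version with multi-adapter support is assumed.
import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils import (
    scale_lora_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)


class TinyTextEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.k_proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.q_proj(x) + self.k_proj(x)


model = TinyTextEncoder()

# Inject two adapters with different ranks/alphas into the same module.
inject_adapter_in_model(
    LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj"]), model, adapter_name="style"
)
inject_adapter_in_model(
    LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "k_proj"]), model, adapter_name="subject"
)

# Activate both adapters with per-adapter weights ...
set_weights_and_activate_adapters(model, ["style", "subject"], [1.0, 0.5])

# ... optionally rescale whatever is currently active ...
scale_lora_layers(model, weight=0.8)

# ... and toggle every LoRA layer off and back on.
set_adapter_layers(model, enabled=False)
set_adapter_layers(model, enabled=True)
```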
