
Commit 1c79095

Address review comments.
hameerabbasi committed Dec 19, 2024
1 parent bc2a466 commit 1c79095
Showing 2 changed files with 19 additions and 29 deletions.
46 changes: 19 additions & 27 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1648,10 +1648,11 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):

_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1700,11 +1701,10 @@ def lora_state_dict(
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -1713,7 +1713,6 @@ def lora_state_dict(
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)

allow_pickle = False
@@ -1740,30 +1739,14 @@ user_agent=user_agent,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

return state_dict, network_alphas
return state_dict

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer
def load_lora_weights(
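
Note: with `lora_state_dict` now copied from `SD3LoraLoaderMixin` rather than `StableDiffusionLoraLoaderMixin`, the method returns only the state dict and no longer produces `network_alphas` or accepts `unet_config`. A minimal usage sketch, assuming the mixin is reached through `AuraFlowPipeline`; the Hub repo id and weight file name below are hypothetical placeholders, not part of this commit:

from diffusers import AuraFlowPipeline

# Hypothetical repo id and weight name, for illustration only.
state_dict = AuraFlowPipeline.lora_state_dict(
    "some-user/auraflow-lora-example",
    weight_name="pytorch_lora_weights.safetensors",
)

# A single return value: there is no accompanying network_alphas mapping
# to unpack, unlike the StableDiffusion/SDXL variants.
print(f"loaded {len(state_dict)} LoRA tensors")
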
@@ -2025,10 +2008,12 @@ def load_lora_into_text_encoder(
# Unsafe code />

@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -2042,6 +2027,9 @@ def save_lora_weights(
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -2055,10 +2043,14 @@
"""
state_dict = {}

if not (transformer_lora_layers):
raise ValueError("You must pass `transformer_lora_layers`.")
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")

state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

# Save the model
cls.write_lora_layers(
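
For context, a hedged usage sketch of the updated `save_lora_weights`, which now accepts `text_encoder_lora_layers` alongside `transformer_lora_layers`. The layer dictionaries and output directory below are illustrative stand-ins; in practice they would be LoRA state dicts extracted from the fine-tuned transformer and text encoder:

import os

import torch

from diffusers import AuraFlowPipeline

# Illustrative stand-ins for real LoRA state dicts.
transformer_lora_layers = {"blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64)}
text_encoder_lora_layers = {"layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 64)}

output_dir = "auraflow-lora-output"  # assumed output directory
os.makedirs(output_dir, exist_ok=True)

# At least one of the two dictionaries must be provided; passing only the
# text encoder layers is now valid as well.
AuraFlowPipeline.save_lora_weights(
    save_directory=output_dir,
    transformer_lora_layers=transformer_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
)
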
2 changes: 0 additions & 2 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -132,8 +132,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):

_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
transformer_name = "transformer"
text_encoder_name = "text_encoder"

def __init__(
self,
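
The `transformer_name` and `text_encoder_name` attributes removed from `AuraFlowPipeline` above are not lost: they are now defined once on `AuraFlowLoraLoaderMixin` (as `TRANSFORMER_NAME` and `TEXT_ENCODER_NAME`) and inherited by the pipeline. A small sanity-check sketch, assuming those constants resolve to the same strings the pipeline previously hard-coded:

from diffusers import AuraFlowPipeline

# Inherited from AuraFlowLoraLoaderMixin rather than redefined on the pipeline.
assert AuraFlowPipeline.transformer_name == "transformer"
assert AuraFlowPipeline.text_encoder_name == "text_encoder"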
