
[LoRA] Add LoRA support to AuraFlow #10216

Open · wants to merge 18 commits into `main`
4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi…
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.

@@ -52,6 +53,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse…
## Mochi1LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin

## AmusedLoraLoaderMixin

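To make the doc addition concrete, here is a minimal usage sketch of the new mixin through `AuraFlowPipeline`; the LoRA repository id is a placeholder, and `load_lora_weights` comes from the loader mixin just as for the other pipelines listed above:

```python
import torch
from diffusers import AuraFlowPipeline

# AuraFlowPipeline now inherits AuraFlowLoraLoaderMixin (see the pipeline diff below).
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Placeholder repository id -- substitute a real AuraFlow LoRA checkpoint.
pipe.load_lora_weights("your-org/your-auraflow-lora", adapter_name="style")

image = pipe(
    "a watercolor painting of a lighthouse at dawn",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("auraflow_lora.png")
```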
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
@@ -64,6 +64,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -88,6 +89,7 @@ def text_encoder_attn_modules(text_encoder):
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
LoraLoaderMixin,
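As a quick, illustrative sanity check of the export added above (not part of the PR), the mixin should be importable both from the package namespace and from its defining module:

```python
# Both import paths should resolve to the same class once this branch is installed.
from diffusers.loaders import AuraFlowLoraLoaderMixin
from diffusers.loaders.lora_pipeline import AuraFlowLoraLoaderMixin as _FromModule

assert AuraFlowLoraLoaderMixin is _FromModule
```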
481 changes: 481 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py

Large diffs are not rendered by default.
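The 481 added lines are not rendered above. Judging by the other LoRA loader mixins in `lora_pipeline.py`, the new class gives the pipeline the usual adapter-management surface; a hedged sketch of that surface in use (adapter names and repository ids below are placeholders, not values from this PR):

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Load two hypothetical adapters and blend them.
pipe.load_lora_weights("your-org/auraflow-lora-style", adapter_name="style")
pipe.load_lora_weights("your-org/auraflow-lora-detail", adapter_name="detail")
pipe.set_adapters(["style", "detail"], adapter_weights=[0.8, 0.4])

# Optionally fuse the LoRA weights into the base weights for inference, then undo it.
pipe.fuse_lora()
pipe.unfuse_lora()

# Remove all adapters.
pipe.unload_lora_weights()
```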

1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -55,6 +55,7 @@
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
}


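The mapping extended above associates a transformer class name with a function that adapts the weights passed to `set_adapters` to whatever layout that model's PEFT layers expect; for `AuraFlowTransformer2DModel` the entry is an identity passthrough. A toy illustration (the dictionary below mimics the mapping and is not imported from diffusers):

```python
# Illustrative copy of the shape of the mapping in src/diffusers/loaders/peft.py.
_scale_expansion_fns = {
    "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
}

fn = _scale_expansion_fns["AuraFlowTransformer2DModel"]
print(fn(None, 0.75))        # 0.75 -- a scalar adapter weight is forwarded unchanged
print(fn(None, [0.5, 1.0]))  # [0.5, 1.0] -- a list of weights is forwarded unchanged
```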
26 changes: 22 additions & 4 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,14 +13,15 @@
# limitations under the License.


from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -253,7 +254,7 @@ def forward(
return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -451,6 +452,7 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
height, width = hidden_states.shape[-2:]
@@ -463,7 +465,19 @@
encoder_hidden_states = torch.cat(
    [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
)

if joint_attention_kwargs is not None:
    joint_attention_kwargs = joint_attention_kwargs.copy()
    lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
    lora_scale = 1.0

if USE_PEFT_BACKEND:
    # weight the lora layers by setting `lora_scale` for each PEFT layer
    scale_lora_layers(self, lora_scale)
else:
    if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:

@@ -538,6 +552,10 @@ def custom_forward(*inputs):
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)

if USE_PEFT_BACKEND:
    # remove `lora_scale` from each PEFT layer
    unscale_lora_layers(self, lora_scale)

if not return_dict:
    return (output,)

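Because `AuraFlowTransformer2DModel` now inherits `PeftAdapterMixin`, PEFT adapters can be attached to the transformer directly, which is what the LoRA loading path relies on. A hedged training-setup sketch (rank, alpha, and target module names are illustrative assumptions, not values from this PR):

```python
from peft import LoraConfig

from diffusers import AuraFlowTransformer2DModel

# Note: this downloads the full AuraFlow transformer weights.
transformer = AuraFlowTransformer2DModel.from_pretrained("fal/AuraFlow", subfolder="transformer")

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # assumed attention projection names
)
transformer.add_adapter(lora_config)  # provided by PeftAdapterMixin

lora_tensors = [name for name, _ in transformer.named_parameters() if "lora" in name]
print(f"Injected {len(lora_tensors)} LoRA parameter tensors")
```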
47 changes: 44 additions & 3 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,16 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel

from ...image_processor import VaeImageProcessor
from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -104,7 +112,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class AuraFlowPipeline(DiffusionPipeline):
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -214,6 +222,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -240,10 +249,21 @@
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
if device is None:
    device = self._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
    self._lora_scale = lora_scale

    # dynamically adjust the LoRA scale
    if self.text_encoder is not None and USE_PEFT_BACKEND:
        scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
    batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -401,6 +421,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
@@ -456,6 +477,10 @@
Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

Examples:

@@ -478,6 +503,8 @@
negative_prompt_attention_mask,
)

self._joint_attention_kwargs = joint_attention_kwargs

# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -487,6 +514,9 @@
batch_size = prompt_embeds.shape[0]

device = self._execution_device
lora_scale = (
    self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -510,6 +540,7 @@
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -550,6 +581,7 @@
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
joint_attention_kwargs=self.joint_attention_kwargs,
)[0]

# perform guidance
@@ -578,7 +610,16 @@
# Offload all models
self.maybe_free_model_hooks()

if self.text_encoder is not None:
    if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

if not return_dict:
    return (image,)

return ImagePipelineOutput(images=image)

@property
def joint_attention_kwargs(self):
    return self._joint_attention_kwargs
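Taken together, the `__call__` changes let callers modulate LoRA strength per invocation: the `joint_attention_kwargs` dict is forwarded to the transformer, and its `"scale"` entry is also passed to `encode_prompt` as `lora_scale` for any text-encoder LoRA layers, then undone with `unscale_lora_layers` after generation. A usage sketch (the LoRA repository id is a placeholder):

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-org/your-auraflow-lora")  # placeholder

# "scale" is popped inside the transformer's forward and applied via scale_lora_layers.
image = pipe(
    "a red fox in a snowy birch forest",
    joint_attention_kwargs={"scale": 0.7},
).images[0]
```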
114 changes: 114 additions & 0 deletions tests/lora/test_lora_layers_af.py
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
require_peft_backend,
)


if is_peft_available():
    pass

sys.path.append(".")

from utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"sample_size": 64,
"patch_size": 1,
"in_channels": 4,
"num_mmdit_layers": 1,
"num_single_dit_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"caption_projection_dim": 32,
"out_channels": 4,
"pos_embed_max_size": 64,
}
transformer_cls = AuraFlowTransformer2DModel
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
attention_kwargs_name = "joint_attention_kwargs"

text_encoder_target_modules = ["q", "k", "v", "o"]

vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 4,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": None,
"scaling_factor": 0.13025,
}

@property
def output_shape(self):
    return (1, 64, 64, 3)

def get_dummy_inputs(self, with_generator=True):
    batch_size = 1
    sequence_length = 10
    num_channels = 4
    sizes = (32, 32)

    generator = torch.manual_seed(0)
    noise = floats_tensor((batch_size, num_channels) + sizes)
    input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

    pipeline_inputs = {
        "prompt": "A painting of a squirrel eating a burger",
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
        "height": 8,
        "width": 8,
        "output_type": "np",
    }
    if with_generator:
        pipeline_inputs.update({"generator": generator})

    return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
    pass

@unittest.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
    pass
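The class above inherits most of its coverage from `PeftLoraLoaderMixinTests` (imported from the local `utils` module), which runs the shared LoRA tests (loading, fusing, scaling, unloading) against the dummy transformer, text encoder, and VAE configured here; it can be run locally with `pytest tests/lora/test_lora_layers_af.py` once `peft` is installed.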