
Commit

Attempt 3 to fix tests.
hameerabbasi committed Dec 19, 2024
1 parent 6b762b8 commit bc2a466
Showing 3 changed files with 81 additions and 13 deletions.
35 changes: 26 additions & 9 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1651,7 +1651,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1700,10 +1700,11 @@ def lora_state_dict(
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
# UNet and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -1712,6 +1713,7 @@
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)

allow_pickle = False
@@ -1738,16 +1740,32 @@
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

return state_dict
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

return state_dict, network_alphas

# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
@@ -1798,10 +1816,9 @@ def load_lora_weights(
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_unet(
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
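Taken together, the loader changes mean `AuraFlowLoraLoaderMixin.lora_state_dict` now returns a `(state_dict, network_alphas)` tuple and `load_lora_weights` hands the transformer keys to `load_lora_into_transformer`. A minimal usage sketch follows; the LoRA repository id and weight file name are placeholders, not part of this commit.

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)

# lora_state_dict now returns the weights together with network_alphas,
# mirroring StableDiffusionLoraLoaderMixin.lora_state_dict.
state_dict, network_alphas = AuraFlowPipeline.lora_state_dict(
    "your-user/aura-flow-lora",  # hypothetical LoRA repository id
    weight_name="pytorch_lora_weights.safetensors",
)

# load_lora_weights routes the weights through load_lora_into_transformer.
pipe.load_lora_weights("your-user/aura-flow-lora", adapter_name="style")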
46 changes: 44 additions & 2 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel
@@ -22,7 +22,14 @@
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -125,6 +132,8 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):

_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
transformer_name = "transformer"
text_encoder_name = "text_encoder"

def __init__(
self,
@@ -215,6 +224,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -241,10 +251,21 @@
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
if device is None:
device = self._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -402,6 +423,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
@@ -457,6 +479,10 @@
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
@@ -479,6 +505,8 @@ def __call__(
negative_prompt_attention_mask,
)

self._joint_attention_kwargs = joint_attention_kwargs

# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -488,6 +516,9 @@
batch_size = prompt_embeds.shape[0]

device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -511,6 +542,7 @@ def __call__(
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -551,6 +583,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
joint_attention_kwargs=self.joint_attention_kwargs,
)[0]

# perform guidance
@@ -579,7 +612,16 @@ def __call__(
# Offload all models
self.maybe_free_model_hooks()

if self.text_encoder is not None:
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if not return_dict:
return (image,)

return ImagePipelineOutput(images=image)

@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
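On the pipeline side, the new `joint_attention_kwargs` argument is stored on the pipeline, and its `scale` entry is read back as `lora_scale`, applied to the text encoder via `scale_lora_layers`/`unscale_lora_layers`, and forwarded to the transformer forward pass. Continuing the sketch above, a caller might pass it like this (the prompt and scale value are illustrative):

image = pipe(
    prompt="A photo of a corgi wearing a space suit",  # illustrative prompt
    num_inference_steps=30,
    # "scale" is picked up as lora_scale inside encode_prompt and the denoising loop
    joint_attention_kwargs={"scale": 0.8},
).images[0]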
13 changes: 11 additions & 2 deletions tests/lora/test_lora_layers_af.py
@@ -16,7 +16,7 @@
import unittest

import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
AuraFlowPipeline,
@@ -41,7 +41,7 @@
@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
@@ -60,6 +60,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer_cls = AuraFlowTransformer2DModel
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
attention_kwargs_name = "joint_attention_kwargs"

text_encoder_target_modules = ["q", "k", "v", "o"]

@@ -103,3 +104,11 @@ def get_dummy_inputs(self, with_generator=True):
pipeline_inputs.update({"generator": generator})

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
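The test-side tweaks rest on two assumed details of the shared `PeftLoraLoaderMixinTests` harness: `scheduler_cls` must hold the class itself (the harness instantiates it with `scheduler_kwargs`), and `attention_kwargs_name` names the keyword the pipeline accepts for attention kwargs. A rough standalone sketch of that assumed pattern, for illustration only:

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler_cls = FlowMatchEulerDiscreteScheduler  # the class itself, not an instance
scheduler_kwargs = {}
attention_kwargs_name = "joint_attention_kwargs"

# The harness is assumed to build components and inputs roughly like this:
scheduler = scheduler_cls(**scheduler_kwargs)              # would fail if scheduler_cls were an instance
pipeline_kwargs = {attention_kwargs_name: {"scale": 0.5}}  # i.e. joint_attention_kwargs={"scale": 0.5}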
