[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill #10259

Merged: 20 commits into main from expand-flux-lora (Dec 20, 2024)

Conversation

@sayakpaul (Member) commented Dec 17, 2024

What does this PR do?

Fixes #10180, #10227, #10184

In short, this PR enables few-step inference for Flux Control, Fill, Redux, etc.

Fill + Turbo LoRA
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
import torch

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
pipe.load_lora_weights(adapter_id)

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=8,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fill-dev.png")
Flux Control LoRA + Turbo LoRA (a different Turbo LoRA from the previous example)
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = control_pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")

Todos

  • Integration tests
  • Docs

@sayakpaul added the lora label Dec 17, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member, Author) commented:

Requesting a review from @BenjaminBossan for the initial implementation. Will then request reviews from others.

@BenjaminBossan (Member) left a comment

Thanks for extending the functionality of loading LoRA adapters when shapes need to be expanded. The PR LGTM, I only have a nit.

One question that came up (maybe it was already discussed and I just missed it or forgot): Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.

@sayakpaul requested a review from a-r-r-o-w December 17, 2024 11:57
@sayakpaul marked this pull request as ready for review December 17, 2024 11:57
@sayakpaul changed the title from "[WIP][LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill" to "[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill" Dec 17, 2024
@sayakpaul (Member, Author) replied:

@BenjaminBossan

Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.

The LoRA state dict expansion is permanent. But support for undoing the model-level state dict expansion is being added in #10206.
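
To illustrate what "permanent" means here, a minimal sketch (assuming `pipe` is a FluxControlPipeline built from black-forest-labs/FLUX.1-dev as in the second example above; the printed sizes assume Flux's 64-channel packed latent input, doubled to 128 by the Control LoRA):

# a minimal sketch, not part of the PR itself
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
print(pipe.transformer.x_embedder.in_features)  # 128: expanded from 64 to accept the control input

pipe.unload_lora_weights()
print(pipe.transformer.x_embedder.in_features)  # still 128: the model-level expansion is not undone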

Comment on lines -343 to -356
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
    "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because the LoRA input features are fewer than the original's.
# We only support expanding the module, not shrinking it.
with self.assertRaises(NotImplementedError):
    pipe.load_lora_weights(lora_state_dict, "adapter-1")
@sayakpaul (Member, Author) commented Dec 17, 2024:

Removing this part of the test because, when the LoRA input feature dimensions are smaller than the original's, we now expand the LoRA instead of erroring out.

This is tested below with test_lora_expanding_shape_with_normal_lora() and test_load_regular_lora().
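
For intuition, a hedged sketch of the expansion idea (a hypothetical standalone helper, not the PR's actual code): a narrower lora_A weight is padded with zero columns to match the expanded input dimension, so the LoRA's contribution is unchanged on the original channels and zero on the new ones.

import torch

def expand_lora_a(lora_A: torch.Tensor, target_in_features: int) -> torch.Tensor:
    """Zero-pad a (rank, in_features) lora_A weight along the input dimension."""
    rank, in_features = lora_A.shape
    expanded = torch.zeros(rank, target_in_features, dtype=lora_A.dtype, device=lora_A.device)
    expanded[:, :in_features] = lora_A  # original channels keep their weights
    return expanded  # new channels stay zero, so the LoRA ignores them

# e.g. a regular Flux LoRA (in_features=64) targeting an expanded x_embedder (in_features=128)
lora_A = torch.randn(16, 64)
assert expand_lora_a(lora_A, 128).shape == (16, 128)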

@BenjaminBossan (Member) replied:

The LoRA state dict expansion is permanent. But support for undoing the model-level state dict expansion is being added in #10206.

Yes, something similar for LoRA would be nice, but it's not as important, as the overhead for LoRA should be relatively small.

@a-r-r-o-w (Member) left a comment

Thanks for making this work @sayakpaul! Just some tiny nits and a question

@@ -478,21 +463,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
    "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
Member commented:

So this error is now removed because expanding LoRA state dicts is now supported.

Member Author replied:

Exactly!

@sayakpaul requested a review from a-r-r-o-w December 20, 2024 07:29
@a-r-r-o-w (Member) left a comment

Thanks Sayak! LGTM

The following case can be looked into in a separate PR because it is not too important:

  • Step 1: Load N "normal" LoRAs
  • Step 2: Load a shape-expander LoRA
  • Step 3: Run inference

This should currently raise an input size mismatch error on the normal LoRAs when the inputs pass through the normal LoRA computation path, which is okay. This is because we don't handle the case of rejigging and expanding all the previous N LoRAs when a shape-expander LoRA is loaded after them.

Maybe if someone reports an issue, it would be worth looking into. But otherwise, the recommendation is to load the shape-expander LoRA first, and then all the other LoRAs. Thanks again for making this possible!
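
Concretely, the recommended order would look something like this (a sketch assuming a FluxControlPipeline `pipe`, reusing model IDs from the examples above):

# load the shape-expanding Control LoRA first...
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
# ...then the "normal" LoRAs, which are expanded on load to match
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["depth", "turbo"], adapter_weights=[0.85, 0.125])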

-lora_A_weight_name = f"{name}.lora_A.weight"
-lora_B_weight_name = f"{name}.lora_B.weight"
-if lora_A_weight_name not in state_dict.keys():
+lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
Member commented:

Awesome!

"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

pipe.load_lora_weights(lora_state_dict, "adapter-2")
Member commented:

Very cool that this would now be possible! 🔥

@sayakpaul (Member, Author) commented:

Just completed running ALL the integration tests for Flux LoRA (test_lora_layers_flux.py). All green 🟢

@sayakpaul (Member, Author) commented:

Merging.

@sayakpaul merged commit 17128c4 into main Dec 20, 2024
15 checks passed
@sayakpaul deleted the expand-flux-lora branch December 20, 2024 09:00
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
…d Fill (huggingface#10259)

* lora expansion with dummy zeros.

* updates

* fix working 🥳

* working.

* use torch.device meta for state dict expansion.

* tests

Co-authored-by: a-r-r-o-w <[email protected]>

* fixes

* fixes

* switch to debug

* fix

* Apply suggestions from code review

Co-authored-by: Aryan <[email protected]>

* fix stuff

* docs

---------

Co-authored-by: a-r-r-o-w <[email protected]>
Co-authored-by: Aryan <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
…d Fill (#10259)
@Cyb4Black commented:

Thanks for extending the functionality of loading LoRA adapters when shapes need to be expanded. The PR LGTM, I only have a nit.

One question that came up (maybe it was already discussed and I just missed it or forgot): Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.

Hey, I'd like to chime in on this one, as we actually have a use case.
We want to host Flux behind an OpenAI-like API and use extra parameters to enable or disable preloaded LoRAs.
For now this means we cannot use the Control LoRAs, as we would not be able to serve the next user's request without the Control LoRA enabled.
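
For what it's worth, per-request toggling of preloaded LoRAs can be sketched with the existing adapter API (a hypothetical request handler with made-up adapter names; as discussed above, a Control LoRA's shape expansion would still persist for requests that disable it):

# hypothetical request handler; "style" and "turbo" are preloaded adapter names
def generate(pipe, prompt: str, use_style: bool = False, use_turbo: bool = False):
    active = [name for name, flag in [("style", use_style), ("turbo", use_turbo)] if flag]
    if active:
        pipe.enable_lora()
        pipe.set_adapters(active)
    else:
        pipe.disable_lora()  # fall back to the base model for this request
    return pipe(prompt=prompt).images[0]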

Successfully merging this pull request may close these issues.

Can't load multiple loras when using Flux Control LoRA