[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill #10259
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-authored-by: a-r-r-o-w <[email protected]>
Requesting a review from @BenjaminBossan for the initial stuff. Will then request reviews from others.
Thanks for extending the functionality of loading LoRA adapters when shapes need to be expanded. The PR LGTM, I only have a nit.
One question that came up (maybe it was already discussed and I just missed it or forgot): Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.
The LoRA state dict expansion is permanent, but the ability to undo model-level state dict expansion is being added in #10206.
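For intuition, here is a minimal sketch of what this kind of expansion with dummy zeros looks like; the shapes and names below are illustrative, not taken from the PR's actual loader code.

```python
import torch

# Illustrative shapes: a regular Flux LoRA's x_embedder lora_A was trained
# against in_features=64, while the Control/Fill transformer's x_embedder
# expects in_features=128.
rank, old_in, new_in = 4, 64, 128
lora_A = torch.randn(rank, old_in)  # nn.Linear weight layout: (out_features, in_features)

# Zero-pad the extra input columns so the LoRA contributes nothing on channels
# it was never trained on; once written back into the state dict, this padding
# stays even if the adapter is later unloaded.
expanded_A = torch.zeros(rank, new_in, dtype=lora_A.dtype)
expanded_A[:, :old_in] = lora_A
```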
```python
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
    "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because lora input features is less than original. We only
# support expanding the module, not shrinking it
with self.assertRaises(NotImplementedError):
    pipe.load_lora_weights(lora_state_dict, "adapter-1")
```
Removing this part of the test because, when the LoRA input feature dimensions are smaller than the original, we now expand the LoRA state dict instead of erroring. This is tested below with test_lora_expanding_shape_with_normal_lora() and test_load_regular_lora().
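As a rough sketch of the new behaviour, reusing the dummy tensors from the removed test above (not a real checkpoint), the same call now succeeds:

```python
# The x_embedder lora_A here has in_features=1, smaller than the module's input
# width; instead of raising NotImplementedError, the loader zero-pads the LoRA
# weights to the module's input width at load time.
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
    "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")  # expands instead of erroring
```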
Yes, something similar for LoRA would be nice, but it's not as important, as the overhead for LoRA should be relatively small.
Thanks for making this work @sayakpaul! Just some tiny nits and a question
```
@@ -478,21 +463,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
    "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
```
So this error is now removed because expanding LoRA state dicts is now supported.
Exactly!
Co-authored-by: Aryan <[email protected]>
Thanks Sayak! LGTM
The case for the following can be looked into in a separate PR because it is not too important.
- Step 1: Load N "normal" LoRAs
- Step 2: Load the shape-expander LoRA
- Step 3: Run inference

This should currently raise an input size mismatch error on the normal LoRAs when the inputs pass through the normal LoRA computation path - which is okay. This is because we don't handle the case of rejigging and expanding all the previous N LoRAs if a shape-expander LoRA is loaded after them.
Maybe if someone reports an issue, it would be worth looking into. But otherwise, the recommendation would be to load the shape-expander LoRA first, and then all the other LoRAs. Thanks again for making this possible!
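For reference, a hedged sketch of that recommended ordering; the repo ids and adapter names below are illustrative placeholders, not taken from this thread.

```python
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Load the shape-expander (Control) LoRA first so x_embedder is expanded once...
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora", adapter_name="control")
# ...then load regular Flux LoRAs, which get zero-padded to the expanded shape.
pipe.load_lora_weights("user/some-regular-flux-lora", adapter_name="style")

pipe.set_adapters(["control", "style"], adapter_weights=[1.0, 0.8])
```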
```
lora_A_weight_name = f"{name}.lora_A.weight"
lora_B_weight_name = f"{name}.lora_B.weight"
if lora_A_weight_name not in state_dict.keys():
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
```
Awesome!
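For context, the `.base_layer` handling matters when the transformer already has PEFT adapters injected: the wrapped Linear then shows up under a `.base_layer` suffix in `named_modules()`, so the name has to be normalized before building the LoRA key names. A small, hedged illustration using peft directly (the module and config here are made up):

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = torch.nn.Linear(64, 128)

model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["x_embedder"]), Tiny())

# The original Linear now lives at "x_embedder.base_layer"; stripping ".base_layer"
# recovers the name under which the LoRA keys (x_embedder.lora_A.weight, ...) live.
print([name for name, _ in model.named_modules() if name.endswith("base_layer")])
```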
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight, | ||
} | ||
|
||
pipe.load_lora_weights(lora_state_dict, "adapter-2") |
Very cool that this would now be possible! 🔥
Just completed running ALL the integration tests for Flux LoRA (…)
Merging.
[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)

* lora expansion with dummy zeros.
* updates
* fix working 🥳
* working.
* use torch.device meta for state dict expansion.
* tests

Co-authored-by: a-r-r-o-w <[email protected]>

* fixes
* fixes
* switch to debug
* fix
* Apply suggestions from code review

Co-authored-by: Aryan <[email protected]>

* fix stuff
* docs

---------

Co-authored-by: a-r-r-o-w <[email protected]>
Co-authored-by: Aryan <[email protected]>
Hey, I'd like to chime in on this one, as we actually have a use case.
What does this PR do?
Fixes #10180, #10227, #10184
In short, this PR enables few-step inference for Flux Control, Fill, Redux, etc.
Fill + Turbo LoRA
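As a rough usage sketch for this combination (the checkpoint and LoRA repo ids are examples, and the image/mask URLs are placeholders):

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
# With this PR, a regular (non-Fill) Flux LoRA such as a turbo LoRA loads directly.
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")

image = load_image("https://example.com/input.png")  # placeholder URL
mask = load_image("https://example.com/mask.png")    # placeholder URL
result = pipe(
    prompt="a white paper cup on a marble table",
    image=image,
    mask_image=mask,
    guidance_scale=30.0,
    num_inference_steps=8,  # few-step inference enabled by the turbo LoRA
).images[0]
```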
Flux Control LoRA + Turbo LoRA (different from the previous one)
Todos