
Can't load multiple loras when using Flux Control LoRA #10180

Open
jonathanyin12 opened this issue Dec 10, 2024 · 11 comments · Fixed by #10182 · May be fixed by #10259
Labels
bug (Something isn't working) · help wanted (Extra attention is needed) · lora

Comments

@jonathanyin12
Contributor

Describe the bug

I was trying out the FluxControlPipeline with the Control LoRA introduced in #9999, but had issues loading multiple LoRAs.

For example, if I load the depth LoRA first and then the 8-step LoRA, it errors on the 8-step LoRA; if I load the 8-step LoRA first and then the depth LoRA, it errors when loading the depth LoRA.

Reproduction

from diffusers import FluxControlPipeline
from huggingface_hub import hf_hub_download
import torch

control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
control_pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))

Logs

AttributeError                            Traceback (most recent call last)
Cell In[6], line 8
      5 control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
      7 control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
----> 8 control_pipe.load_lora_weights(
      9         hf_hub_download(
     10             "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"
     11         ),
     12         adapter_name="HyperFlux",
     13     )

File ~/.venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:1856, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   1849 transformer_norm_state_dict = {
   1850     k: state_dict.pop(k)
   1851     for k in list(state_dict.keys())
   1852     if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
   1853 }
   1855 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1856 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
   1857     transformer, transformer_lora_state_dict, transformer_norm_state_dict
   1858 )
   1860 if has_param_with_expanded_shape:
   1861     logger.info(
   1862         "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
   1863         "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
   1864         "To get a comprehensive list of parameter names that were modified, enable debug logging."
   1865     )

File ~/.venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:2316, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
   2314 if isinstance(module, torch.nn.Linear):
   2315     module_weight = module.weight.data
-> 2316     module_bias = module.bias.data if hasattr(module, "bias") else None
   2317     bias = module_bias is not None
   2319     lora_A_weight_name = f"{name}.lora_A.weight"

AttributeError: 'NoneType' object has no attribute 'data'

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.5
  • Transformers version: 4.47.0
  • Accelerate version: 1.2.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@a-r-r-o-w @sayakpaul

@jonathanyin12 added the bug (Something isn't working) label Dec 10, 2024
@a-r-r-o-w
Member

Oh, we should have anticipated this use case. I think the correct check should be module_bias = module.bias.data if module.bias is not None else None instead.
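
For illustration, a minimal sketch of what the corrected check would look like inside _maybe_expand_transformer_param_shape_or_error_ (paraphrased from the traceback above, not the exact diffusers source):

if isinstance(module, torch.nn.Linear):
    module_weight = module.weight.data
    # nn.Linear always exposes a `bias` attribute; it is simply None when the layer
    # was created with bias=False, so hasattr() is always True and a None check is needed.
    module_bias = module.bias.data if module.bias is not None else None
    bias = module_bias is not None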

Even with the above fix, I don't think the weights would load as expected, because the depth control LoRA would expand the input features of x_embedder to 128, while the Hyper-SD LoRA has input features of 64. Will try it out and report back shortly.

cc @yiyixuxu as well

@a-r-r-o-w
Member

It does indeed error out with the corrected if-statement as well, for the reason explained above.

trace
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump4.py", line 9, in <module>
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1868, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1932, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/aryan/work/diffusers/src/diffusers/loaders/peft.py", line 320, in load_lora_adapter
    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 445, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
        size mismatch for x_embedder.lora_A.default_1.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 128]).

I do believe this should work as expected, allowing the depth control LoRA to be combined with N-step Hyper-SD LoRAs. This is a unique case that has probably never been investigated before. Not completely sure how we would handle this either :/

My initial thought is to expand the LoRA shapes as well and set the weights of the linear layer corresponding to the depth control input to 0. This should effectively stop the control latent from interfering with the effect of Hyper-SD, so the LoRA will operate only on the denoising latent. I'll experiment and let the results speak for whether this is something we should prioritize supporting (there are 10000+ available Flux LoRAs that might be compatible), and will let YiYi and Sayak comment on how best to handle this situation if it works as expected.
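
A rough, self-contained sketch of that padding idea (the rank value and variable names are illustrative, not taken from either checkpoint): pad the Hyper-SD lora_A matrix from 64 to 128 input features with zeros, so the control-latent channels contribute nothing through this LoRA.

import torch

rank, old_in, new_in = 16, 64, 128
lora_A = torch.randn(rank, old_in)                          # stand-in for Hyper-SD's x_embedder lora_A weight
expanded_A = torch.zeros(rank, new_in, dtype=lora_A.dtype)
expanded_A[:, :old_in] = lora_A                             # zero columns for the control latent -> no contribution there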

Are you facing any errors when trying to run inference with LoRAs, but without control LoRAs? Either way, I think the above-mentioned condition needs to be updated.

@a-r-r-o-w added the help wanted (Extra attention is needed) and lora labels Dec 10, 2024
@jonathanyin12
Contributor Author

I just tried using the normal FluxPipeline, which also encounters the same issue.

Repro script:

from diffusers import FluxPipeline
import torch
from huggingface_hub import hf_hub_download

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
)
pipe.load_lora_weights(
    "strangerzonehf/Flux-Midjourney-Mix2-LoRA",
)

@sayakpaul
Member

First of all, thanks for your PR (#10182), @jonathanyin12! We appreciate it!

@a-r-r-o-w

I echo your thoughts on why we should support this, as it would enable quite a lot!

My initial thoughts are to expand the lora shapes as well, and set the weights of the linear layer corresponding to the depth control input to 0. This should effectively remove the control latent from interfering with the effect of hyper-sd and it will operate only on the denoising latent.

Yeah, I think that would be the way to go here. The expanded input channels are initialized to zero anyway, so it won't interfere with the Hyper-SD LoRA.

Would love to get @BenjaminBossan's thoughts here too. LMK if anything is unclear from the issues and the comments above.

@sayakpaul
Member

#10182 doesn't completely solve this issue, so re-opening.

@a-r-r-o-w
Member

An important note on why our integration tests passed despite the bug in shape expansion (now fixed, thanks to @jonathanyin12!):

Most LoRAs on the Hub don't train the x_embedder layer and are limited to the QKV/out projections, so we can load different LoRAs without issues unless they also have x_embedder-specific LoRA layers (it is the only layer that undergoes shape expansion).
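
As a rough first-pass check (key naming differs between LoRA export formats, so treat this heuristic snippet as illustrative only), one can inspect whether a given LoRA file contains x_embedder entries at all:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
state_dict = load_file(path)
# If no key mentions x_embedder, this LoRA never touches the only layer that undergoes shape expansion.
print(any("x_embedder" in k for k in state_dict))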

@BenjaminBossan
Member

LMK if anything is unclear from the issues and the comments above.

because the depth control lora would expand the input features of x_embedder to 128

Yes, this part is unclear to me. Is the x_embedder a normal part of the base model? If yes, why does it need to be expanded for LoRA to work? It sounds a bit to me like the different LoRAs have been trained on different base models.

@sayakpaul
Member

It needs to be expanded because:

  1. Otherwise, we won't be able to accommodate the additional inputs along the channel dimension.
  2. The Control LoRA was obtained through the following:
    a. The base model was fine-tuned on additional structural inputs, with its x_embedder layer expanded from (3072, 64) to (3072, 128).
    b. A LoRA was derived from the fine-tuned copy using an approximation method. This derived LoRA then becomes loadable into the base model (with the layer in question expanded).

In the base model (assuming no prior expansion), the concerned layer is expanded like so before the LoRA is loaded:

slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight  # copy the original weights into the zero-initialized expanded tensor
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
    expanded_module.bias.data.copy_(module_bias)
setattr(parent_module, current_module_name, expanded_module)
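
A minimal standalone sketch of the same expansion (illustrative names, not the diffusers implementation), growing x_embedder from in_features=64 to 128 while zero-initializing the new input columns:

import torch
import torch.nn as nn

old = nn.Linear(64, 3072, bias=True)         # original x_embedder: weight shape (3072, 64)
expanded = nn.Linear(128, 3072, bias=True)   # expanded layer: weight shape (3072, 128)
with torch.no_grad():
    expanded.weight.zero_()                  # new input columns start at zero
    expanded.weight[:, :64].copy_(old.weight)
    expanded.bias.copy_(old.bias)            # bias depends only on out_features, so it is unchanged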

@sayakpaul
Member

sayakpaul commented Dec 11, 2024

Additionally, I used @a-r-r-o-w's idea of expanding the LoRA state dict with zeros (minimal changes are in the expand-flux-lora branch) but it is probably wrong as I get garbage results.

test code
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)

control_pipe.set_adapters(["depth", "hyper-sd"])

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = control_pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50, # when lowered the results are still garbage
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")

Cc: @jonathanyin12

@BenjaminBossan
Member

2. The base model was fine-tuned with additional structural inputs with its x_embedder layer expanded from (3072, 64) to (3072, 128).

Okay, so technically it is a different base model, although just this one layer was adjusted. Yes, I don't really see any better option than to pad with zeros in that case.

Additionally, I used @a-r-r-o-w's idea of expanding the LoRA state dict with zeros (minimal changes are in the expand-flux-lora branch) but it is probably wrong as I get garbage results.

I did a quick check of the diff but at first glance I see no issue.

@sayakpaul
Member

@jonathanyin12 can you check #10184 (comment)?
