
flux fill cannot use lora(flux turbo lora) #10184

Closed

Suprhimp opened this issue Dec 11, 2024 · 17 comments
Labels: bug (Something isn't working), lora

Comments

@Suprhimp commented Dec 11, 2024

Describe the bug

I want to use the Flux Fill pipeline with the Turbo LoRA, but when I load the pipeline and then load the LoRA weights, it raises an error.

Reproduction

import torch
from diffusers import FluxFillPipeline

def model_fn(model_dir: str) -> FluxFillPipeline:
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha")
    pipe.fuse_lora()

    return pipe

Logs

NotImplementedError: Only LoRAs with input/output features higher than the current module's input/output features are currently supported. The provided LoRA contains in_features=64 and out_features=3072, which are lower than module_in_features=384 and module_out_features=3072. If you require support for this please open an issue at https://github.com/huggingface/diffusers/issues.
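For context, a minimal sketch (module names as in current diffusers) of why the shapes differ: FLUX.1-Fill-dev concatenates the noise latents with the masked-image and mask latents channelwise before the input projection, so its x_embedder expects 384 input features, while FLUX.1-dev, which the Turbo LoRA targets, expects 64.

import torch
from diffusers import FluxTransformer2DModel

# Compare the input projection widths of the two transformers (download-heavy).
fill = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
dev = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(fill.x_embedder.in_features)  # 384
print(dev.x_embedder.in_features)   # 64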

System Info

latest (diffusers installed from GitHub), Python 3.10, Ubuntu with an NVIDIA GPU

Who can help?

@sayakpaul

Suprhimp added the bug label on Dec 11, 2024
@sayakpaul (Member)

I am not sure you can. You're trying to load a LoRA that was obtained on FLUX.1-dev into a different model.

@Suprhimp (Author)

I thought I could use it because I tried with ComfyUI first, and it worked well there.

@sayakpaul (Member)

Cc: @a-r-r-o-w could be nice to try the LoRA expansion idea here as well.

@a-r-r-o-w (Member)

Will have to take a look at what Comfy is doing to be sure they're not simply dropping the input LoRA layer (which requires expansion/shrinking) and using all the other layers; in that case it would work easily. Or perhaps they are not passing the channelwise-concatenated latents through the LoRA layer and just passing the denoising latents. I currently don't have the bandwidth to try this out, but will be sure to look into it soon!
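As a rough illustration of the first idea, a sketch assuming the incompatible input layer is simply skipped (a guess at what Comfy might do, not confirmed behavior): filter the x_embedder LoRA keys out of the state dict before loading, since that is the only module whose input width (64 vs 384) doesn't match.

import torch
from diffusers import FluxFillPipeline

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
)
# Load the LoRA state dict and drop every key touching x_embedder.
state_dict = FluxFillPipeline.lora_state_dict("alimama-creative/FLUX.1-Turbo-Alpha")
filtered = {k: v for k, v in state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(filtered)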

@sayakpaul (Member)

> Or perhaps they are not passing the channelwise-concatenated latents through the LoRA layer and just passing the denoising latents.

Valid.

@Suprhimp (Author)

I hope I can use Turbo with Flux Fill soon :)

@kadirnar (Contributor) commented Dec 12, 2024

@sayakpaul
This code works. Is it useful for you?

https://github.com/nftblackmagic/catvton-flux/blob/main/tryon_inference_lora.py#L26

Code:

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "xiaozaa/flux1-fill-dev-diffusers",  # The official Flux-Fill weights
    torch_dtype=torch.bfloat16
)
print("Start loading LoRA weights")
state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
    pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha",  # The tryon LoRA weights
    weight_name="pytorch_lora_weights.safetensors",
    return_alphas=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
    raise ValueError("Invalid LoRA checkpoint.")

FluxFillPipeline.load_lora_into_transformer(
    state_dict=state_dict,
    network_alphas=network_alphas,
    transformer=transformer,
)

yiyixuxu added the lora label on Dec 13, 2024
@Suprhimp (Author)

I got an error ;)

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel

def model_fn(model_dir: str) -> FluxFillPipeline:
    transformer = FluxTransformer2DModel.from_pretrained(
        "xiaozaa/flux1-fill-dev-diffusers",  # The official Flux-Fill weights
        torch_dtype=torch.bfloat16
    )
    print("Start loading LoRA weights")
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict="alimama-creative/FLUX.1-Turbo-Alpha",  # The Turbo LoRA weights
        return_alphas=True
    )
    is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=transformer,
    )

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        transformer=transformer,
    ).to("cuda")

    # pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha")
    # pipe.fuse_lora()
    return pipe

error message:

RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
        size mismatch for x_embedder.lora_A.default_0.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 384]).

But loading xiaozaa/catvton-flux-lora-alpha works.

I also tried it this way, but it doesn't work either.

import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel

def model_fn(model_dir: str) -> FluxFillPipeline:
    transformer = FluxTransformer2DModel.from_pretrained(
        "xiaozaa/flux1-fill-dev-diffusers",  # The official Flux-Fill weights
        torch_dtype=torch.bfloat16
    )

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        transformer=transformer,
    ).to("cuda")

    pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha")
    pipe.fuse_lora()
    return pipe

@sayakpaul (Member)

@kadirnar it won't work, as the LoRA you showed in the example was obtained on the Flux Fill checkpoint itself (so its x_embedder shapes already match).
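For illustration, a hedged sketch of what expanding a FLUX.1-dev LoRA for the Fill transformer could mean (conceptually what the expand-flux-lora branch discussed below automates): zero-pad the x_embedder down-projection from 64 to 384 input features, under the assumption that the extra masked-image/mask channels should receive no LoRA contribution. The key name below is an assumption and may differ per checkpoint format.

import torch
from diffusers import FluxFillPipeline

state_dict = FluxFillPipeline.lora_state_dict("alimama-creative/FLUX.1-Turbo-Alpha")
key = "transformer.x_embedder.lora_A.weight"  # assumed key name
if key in state_dict:
    lora_A = state_dict[key]                       # shape [rank, 64]
    padded = torch.zeros(lora_A.shape[0], 384, dtype=lora_A.dtype)
    padded[:, : lora_A.shape[1]] = lora_A          # keep original columns, zeros elsewhere
    state_dict[key] = padded                       # now [rank, 384]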

@sayakpaul (Member)

I tried this huggingface:0d96a89...huggingface:ed91c533f but the outputs are pure garbage. Wondering what the strategy here is.

@sayakpaul (Member)

Oh, we had to set the scales accordingly:

from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = control_pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")

^ should work:

[output image]

Make sure to install from the expand-flux-lora branch.

@sayakpaul (Member)

@Suprhimp #10227 (comment)

@Suprhimp (Author)

Thanks for handling this issue, I appreciate it.

But when I run my example

from diffusers import FluxFillPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")

pipe.set_adapters("hyper-sd", adapter_weights=1.0)
pipe.enable_model_cpu_offload()


image = load_image("my-image-url")
mask = load_image("my-mask-url")


image = pipe(
    prompt="star fish",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    guidance_scale=30,
    num_inference_steps=8,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("output.png")

It doesn't work.

Traceback (most recent call last):
  File "/home/ubuntu/flux_fill_test.py", line 121, in <module>
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
  File "/home/ubuntu/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1856, in load_lora_weights
    has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
  File "/home/ubuntu/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2337, in _maybe_expand_transformer_param_shape_or_error_
    raise NotImplementedError(
NotImplementedError: Only LoRAs with input/output features higher than the current module's input/output features are currently supported. The provided LoRA contains in_features=64 and out_features=3072, which are lower than module_in_features=384 and module_out_features=3072. If you require support for this please open an issue at https://github.com/huggingface/diffusers/issues.

> I tried this huggingface:0d96a89...huggingface:ed91c533f but the outputs are pure garbage. Wondering what the strategy here is.

I edited the diffusers code inside my venv directly with your commit and ran the example, but it failed. ;)

@sayakpaul (Member)

Install diffusers from the expand-flux-lora branch.

@Suprhimp (Author)

Thanks for the quick reply.

I installed that branch, and this is the new error:

Loading hyper-sd was unsucessful with the following error: 
Error(s) in loading state_dict for FluxTransformer2DModel:
        size mismatch for x_embedder.lora_A.hyper-sd.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 384]).
Traceback (most recent call last):
  File "/home/ubuntu/flux_fill_test.py", line 121, in <module>
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
  File "/home/ubuntu/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1871, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/ubuntu/venv/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1935, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/ubuntu/venv/lib/python3.10/site-packages/diffusers/loaders/peft.py", line 325, in load_lora_adapter
    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
  File "/home/ubuntu/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 445, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
  File "/home/ubuntu/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
        size mismatch for x_embedder.lora_A.hyper-sd.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 384]).

@sayakpaul (Member)

I don't know if it's an issue with your local installation, but I just ran it and it worked.

Please check rigorously that you have done the installation correctly:

pip uninstall diffusers -y
git clone https://github.com/huggingface/diffusers/
cd diffusers
git checkout expand-flux-lora
pip install -e .

@JakobLS also tried it out in #10227 (comment) and has confirmed it's working. So, I am not sure what's wrong in your case.
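As a sanity check (a sketch, not from the thread), one can confirm the editable install actually points at the cloned repo rather than a leftover site-packages copy:

import diffusers

print(diffusers.__version__)
print(diffusers.__file__)  # should point inside the cloned diffusers/ directory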

@Suprhimp (Author)

Oh, it worked!

I had just installed with pip install git+https://github.com/huggingface/diffusers.git@expand-flux-lora, but maybe pip did not clean up the old install completely.

Thanks :)
