Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Flux IP Adapter #10261

Merged
merged 31 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2ee946f
Flux IP-Adapter
hlky Oct 28, 2024
d794ab5
test cfg
hlky Dec 17, 2024
7167fc4
make style
hlky Dec 17, 2024
dc26e47
temp remove copied from
hlky Dec 17, 2024
09e1e58
fix test
hlky Dec 17, 2024
ce5558f
fix test
hlky Dec 17, 2024
84f08d7
Merge branch 'main' into ipadapter-flux
hlky Dec 17, 2024
12833b1
v2
hlky Dec 17, 2024
0eb3eb8
fix
hlky Dec 17, 2024
188a515
make style
hlky Dec 17, 2024
08b1aeb
Merge branch 'main' into ipadapter-flux
hlky Dec 17, 2024
45a2fb1
Merge branch 'main' into ipadapter-flux
hlky Dec 18, 2024
19b4d54
Merge branch 'main' into ipadapter-flux
hlky Dec 18, 2024
2537016
temp remove copied from
hlky Dec 18, 2024
5b0a88b
Apply suggestions from code review
hlky Dec 19, 2024
eb67b2c
Move encoder_hid_proj to inside FluxTransformer2DModel
hlky Dec 19, 2024
956e417
Merge branch 'main' into ipadapter-flux
hlky Dec 20, 2024
248bbd4
merge
hlky Dec 20, 2024
02edb0f
separate encode_prompt, add copied from, image_encoder offload
hlky Dec 20, 2024
a7bcf50
make
hlky Dec 20, 2024
3516159
fix test
hlky Dec 20, 2024
9059b37
fix
hlky Dec 20, 2024
7db9b44
Update src/diffusers/pipelines/flux/pipeline_flux.py
hlky Dec 20, 2024
786babb
test_flux_prompt_embeds change not needed
hlky Dec 20, 2024
0f229c4
true_cfg -> true_cfg_scale
hlky Dec 20, 2024
9276ced
fix merge conflict
hlky Dec 20, 2024
cab0dd8
test_flux_ip_adapter_inference
hlky Dec 20, 2024
7938b42
Merge branch 'main' into ipadapter-flux
hlky Dec 20, 2024
253ef7e
add fast test
hlky Dec 20, 2024
a3bf2a3
FluxIPAdapterMixin not test mixin
hlky Dec 20, 2024
a573e71
Update pipeline_flux.py
hlky Dec 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


if is_transformers_available():
from transformers import CLIPVisionModelWithProjection

vision = True
else:
vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors"
--output_path "flux-ip-adapter-hf/"
"""


CTX = init_empty_weights if is_accelerate_available else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()


def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict


def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}

# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")

# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)

return converted_state_dict


def main(args):
original_ckpt = load_original_checkpoint(args)

num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
main(args)
6 changes: 5 additions & 1 deletion src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]

_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
Expand All @@ -75,8 +75,10 @@ def text_encoder_attn_modules(text_encoder):
"SanaLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
hlky marked this conversation as resolved.
Show resolved Hide resolved
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]

Expand All @@ -86,12 +88,14 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
Expand Down
Loading
Loading