Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Flux IP Adapter #10261

Merged
merged 31 commits into from
Dec 21, 2024
Merged

Support Flux IP Adapter #10261

merged 31 commits into from
Dec 21, 2024

Conversation

hlky
Copy link
Collaborator

@hlky hlky commented Dec 17, 2024

What does this PR do?

Adds support for XLabs Flux IP Adapter.

Example

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("assets_statue.jpg").resize((1024, 1024))

pipe.load_ip_adapter("XLabs-AI/flux-ip-adapter", weight_name="ip_adapter.safetensors", image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14")
pipe.set_ip_adapter_scale(1.0)

image = pipe(
    width=1024,
    height=1024,
    prompt='wearing sunglasses',
    negative_prompt="",
    true_cfg=4.0,
    generator=torch.Generator().manual_seed(4444),
    ip_adapter_image=image,
).images[0]

image.save('flux_ipadapter_4444.jpg')

Input Output
assets_statue flux_ipadapter_4444

flux-ip-adapter-v2

Details

Note: true_cfg=1.0 is important, and strength is sensitive, fixed strength may not work, see here for more strength schedules, good results will require experimentation with strength schedules and the start/stop values. Results also vary with input image, I had no success with the statue image used for v1 test.

Multiple input images is not yet supported (dev note: apply torch.mean to the batch of image_embeds and to ip_attention)

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("monalisa.jpg").resize((1024, 1024))

pipe.load_ip_adapter("XLabs-AI/flux-ip-adapter-v2", weight_name="ip_adapter.safetensors", image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14")

def LinearStrengthModel(start, finish, size):
    return [
        (start + (finish - start) * (i / (size - 1))) for i in range(size)
    ]

ip_strengths = LinearStrengthModel(0.4, 1.0, 19)
pipe.set_ip_adapter_scale(ip_strengths)

image = pipe(
    width=1024,
    height=1024,
    prompt='wearing red sunglasses, golden chain and a green cap',
    negative_prompt="",
    true_cfg_scale=1.0,
    generator=torch.Generator().manual_seed(0),
    ip_adapter_image=image,
).images[0]

image.save('result.jpg')

Input Output
monalisa result (14)

Notes

  • XLabs Flux IP Adapter produces bad results when used without CFG
    • Verifiable in original codebase, set --timestep_to_start_cfg greater than the number of steps to disable CFG
  • XLabs Flux IP Adapter also produces bad results when used with CFG in a batch (negative and positive concat)
  • This PR copies (most) of the changes from our pipeline_flux_with_cfg community example, except we run positive and negative separately.
  • Conversion script is optional, original weights will be converted on-the-fly from load_ip_adapter.
  • load_ip_adapter supports image_encoder_pretrained_model_name_or_path e.g. "openai/clip-vit-large-patch14" rather than just image_encoder_folder, also supports image_encoder_dtype with default torch.float16.
  • This required some changes to FluxTransformerBlock because of where ip_attention is applied to the hidden_states, see here in the original codebase.
  • flux-ip-adapter-v2 will be fixed and tested shortly.

Fixes #9825
Fixes #9403

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu @DN6

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky hlky added the roadmap Add to current release roadmap label Dec 17, 2024
src/diffusers/models/attention_processor.py Outdated Show resolved Hide resolved
src/diffusers/models/attention_processor.py Outdated Show resolved Hide resolved
src/diffusers/pipelines/flux/pipeline_flux.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nits. Looks good otherwise.

src/diffusers/pipelines/flux/pipeline_flux.py Outdated Show resolved Hide resolved
src/diffusers/pipelines/flux/pipeline_flux.py Outdated Show resolved Hide resolved
tests/pipelines/flux/test_pipeline_flux.py Outdated Show resolved Hide resolved
@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hlky Could we add a fast test using something similar to what's been done here

def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
if "image" in parameters.keys() and "strength" in parameters.keys():
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs)[0]
else:
output_without_adapter = expected_pipe_slice

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left one comment, looks good otherwise!

@hlky hlky merged commit be20709 into huggingface:main Dec 21, 2024
12 checks passed
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support IPAdapters for FLUX pipelines [Flux IPadapter] Support Xlabs IPadapter in diffusers
4 participants