
[Core] Add AuraFlow #8796

Merged: 42 commits from lavender-flow into main on Jul 11, 2024

Conversation

@sayakpaul (Member) commented Jul 5, 2024

What does this PR do?

Adds Aura Flow from Fal.

Test code:

import torch

from diffusers import AuraFlowPipeline
from diffusers.utils import make_image_grid

# Load the pipeline in half precision and move it to the GPU.
pipeline = AuraFlowPipeline.from_pretrained(
    "AuraDiffusion/auradiffusion-v0.1a0",
    torch_dtype=torch.float16,
).to("cuda")

# Generate four 512x512 images from one prompt, seeding for reproducibility.
images = pipeline(
    prompt="a cute cat with tiger like looks",
    height=512,
    width=512,
    num_inference_steps=50,
    num_images_per_prompt=4,
    generator=torch.Generator().manual_seed(666),
    guidance_scale=3.5,
).images

# Arrange the results in a 1x4 grid and save.
make_image_grid(images, rows=1, cols=4).save("demo_hf.png")

Warning

To download the model you must be a member of the AuraDiffusion org. Follow this (internal) Slack message.

Gives:

[image: demo_hf.png, a 1×4 grid of the generated cat images]

TODOs

  • Docs
  • Tests
  • Scheduler (@yiyixuxu would you be able to help out here? I couldn't find a way to use our existing flow matching scheduler in this case)

Because of the last point above, the noise scheduling code is taken from the original codebase. But I think this PR is still ready for a first review.

@sayakpaul requested review from DN6 and yiyixuxu on Jul 5, 2024 08:33
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -0,0 +1,401 @@
# Copyright 2024 Stability AI, Lavender Flow, The HuggingFace Team. All rights reserved.
@sayakpaul (Member, Author) commented Jul 5, 2024:

A new model class, because it differs from the SD3 one (non-exhaustive list):

  • Uses register tokens
  • Mixes MMDiT blocks with another, simpler kind of DiT block (which takes the concatenation of encoder_hidden_states and hidden_states as its input)
  • The final layer norm is different
  • Position embeddings are different (it uses learned positional embeddings)
  • The feedforward is different: we only support GeLU and its variants, while this model uses SwiGLU (see the sketch after this list)
  • No pooled projections
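For reference, a minimal sketch of a SwiGLU feedforward block, i.e. a SiLU-gated linear unit (module and argument names here are illustrative, not the actual diffusers implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    # SiLU-gated feedforward: one projection produces both the value and the gate.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, hidden_dim * 2)
        self.out = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj(x).chunk(2, dim=-1)  # split into value and gate halves
        return self.out(value * F.silu(gate))        # SwiGLU: value * SiLU(gate)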

def _set_gradient_checkpointing(self, module, value=False):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value

@sayakpaul (Member, Author):

I have deliberately kept additional methods like feedforward chunking, QKV fusion, etc. out of this class because it helps with the initial review.

@yiyixuxu (Collaborator) left a comment:

very nice! left some comments, mainly on the attention processor

Review threads on src/diffusers/models/attention_processor.py (3 threads, outdated, resolved)
@sayakpaul (Member, Author):

Looking into the test failures 👀

@bghira (Contributor) commented Jul 8, 2024

I've been testing a fork of this with LoRA support, and it works without any changes: just add the peft adapter to the Transformer2D model and the SD3 LoRA loader mixin to the pipeline (see the sketch below).
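A hedged sketch of what that could look like (the class name and adapter config are illustrative, and this is not the AuraFlow LoRA support that later landed in diffusers):

import torch
from diffusers import AuraFlowPipeline
from diffusers.loaders import SD3LoraLoaderMixin
from peft import LoraConfig

class AuraFlowLoraPipeline(AuraFlowPipeline, SD3LoraLoaderMixin):
    # Reuse the SD3 LoRA loading logic unchanged; AuraFlow also exposes a
    # `transformer` component, which is what the mixin operates on.
    pass

pipe = AuraFlowLoraPipeline.from_pretrained(
    "AuraDiffusion/auradiffusion-v0.1a0", torch_dtype=torch.float16
)
# Attach a fresh peft adapter to the transformer (rank and target modules are illustrative).
pipe.transformer.add_adapter(LoraConfig(r=16, target_modules=["to_q", "to_k", "to_v"]))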

@sayakpaul (Member, Author) commented Jul 9, 2024

@yiyixuxu @DN6 I have addressed the comments. Here are some extended comments from my end:

@bghira, I will add LoRA support in an immediate follow-up PR once this one is merged, to keep the review scope concrete and manageable. It's not just a matter of adding the classes you mentioned: we also need to scale and unscale the layers appropriately to handle the LoRA scale, add features like fuse_lora(), etc. So I'm keeping that out of this PR (a conceptual sketch of the fusion follows).
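For context, a conceptual sketch of what fuse_lora() does: it folds the LoRA update into the base weight so inference runs without extra matmuls (illustrative, not the diffusers implementation):

import torch

def fuse_lora_weight(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor, scale: float) -> torch.Tensor:
    # W: (out, in) base weight; A: (r, in) and B: (out, r) are the LoRA factors.
    # The fused weight is W + scale * B @ A; afterwards the adapter can be dropped.
    return W + scale * (B @ A)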

@sayakpaul marked this pull request as ready for review on Jul 9, 2024 05:00
@sayakpaul (Member, Author):

I have also added fast tests and decided to make the default value of the negative prompt None instead of "This is watermark, jpeg image white background, web image". I think this aligns better with our other pipelines. I will include this negative prompt in the docs once I start adding them (a usage sketch follows).
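For users who want the previous behavior back once the default becomes None, a hypothetical call reusing the pipeline from the test snippet above:

images = pipeline(
    prompt="a cute cat with tiger like looks",
    negative_prompt="This is watermark, jpeg image white background, web image",
    num_inference_steps=50,
    guidance_scale=3.5,
).images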

@sayakpaul requested review from DN6 and yiyixuxu on Jul 9, 2024 06:59
@yiyixuxu (Collaborator) left a comment:

very nice!!

Review thread on src/diffusers/models/normalization.py (outdated, resolved)
@@ -158,7 +158,12 @@ def scale_noise(
     def _sigma_to_t(self, sigma):
         return sigma * self.config.num_train_timesteps

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
@yiyixuxu (Collaborator):

umm, I don't think these changes are introduced in this PR

@sayakpaul (Member, Author):

@yiyixuxu we merged #8799 into the current PR branch, so its commits show up here. But they are still attributed to you, so I guess that is okay?

@yiyixuxu (Collaborator):

you're right, I was confused 😅

    padding="max_length",
    return_tensors="pt",
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
@yiyixuxu (Collaborator):

just curious what else is in text_inputs other than the text_input_ids?

@sayakpaul (Member, Author):

attention_mask
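As a quick illustration (using a stand-in tokenizer checkpoint; AuraFlow's actual text encoder may differ), a padded tokenizer call returns both fields:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")  # stand-in checkpoint
text_inputs = tokenizer(
    "a cute cat", padding="max_length", max_length=16, return_tensors="pt"
)
print(text_inputs.keys())  # dict_keys(['input_ids', 'attention_mask'])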

Review thread on src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py (outdated, resolved)
@yiyixuxu (Collaborator) left a comment:

thanks!

@sayakpaul mentioned this pull request on Jul 10, 2024
@yiyixuxu merged commit 2261510 into main on Jul 11, 2024 (16 of 18 checks passed)
@yiyixuxu deleted the lavender-flow branch on Jul 11, 2024 at 18:50