Add Stable Diffusion 3 Example #2558

Merged

Conversation

Czxck001
Contributor

@Czxck001 Czxck001 commented Oct 13, 2024

It appears the community's focus has largely shifted to Flux.dev1, so the main purpose of this PR is to demonstrate the capability of Candle and to serve as a smoke test for the MMDiT (#2397).

As such, I intend to minimize intrusive changes to the existing stable-diffusion codebase, for example by using a renaming function to adapt the VAE var-builder to the official safetensors weights of the SD3 VAE. Still, there are some changes I have to make to candle_transformers::models::stable_diffusion to support the CLIP and VAE of SD3, including:

  • Add a forward_until_encoder_layer to ClipTextTransformer. The Comfy implementation for SD3 uses the penultimate hidden layer of CLIP-l and CLIP-g instead of the final layer (see sd3_clip.py and sdxl_clip.py). This practice, although not mentioned in the SD3 tech report, is described in Section 2.1 of the SDXL tech report.
  • Add the use_quant_conv and use_post_quant_conv options to AutoEncoderKL, as SD3's VAE does not have those layers. These changes are arguably not specific to SD3, since diffusers already supports these options.
  • Use get_qkv_linear to load the attention block in candle_transformers::models::stable_diffusion::attention, as some linear-layer weights of the VAE in the official SD3 Medium safetensors follow the dimension convention (channel, channel, 1, 1) instead of the regular (channel, channel) that is natively supported by the nn::linear constructor.
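To make the first and third items concrete, here is a minimal, std-only sketch. It is not the code from this PR and does not use Candle's types: `forward_until`, `penultimate`, and `linear_weight_shape` are hypothetical names, layers are modelled as plain functions, and weight shapes as slices of dimensions. It only illustrates stopping an encoder stack one layer early and accepting either weight layout for a linear layer.

```rust
/// Apply the first `until` layers of `layers` to `input` and return the
/// intermediate state. Hypothetical stand-in for an encoder stack; a real
/// layer would be a module operating on tensors.
fn forward_until<T>(input: T, layers: &[fn(T) -> T], until: usize) -> T {
    layers.iter().take(until).fold(input, |hidden, layer| layer(hidden))
}

/// Number of layers to run in order to obtain the penultimate hidden state.
fn penultimate(num_layers: usize) -> usize {
    num_layers.saturating_sub(1)
}

/// Hypothetical helper: accept a linear weight stored either as
/// (out, in) or as a 1x1 convolution kernel (out, in, 1, 1) and
/// return the 2-D shape a linear layer expects.
fn linear_weight_shape(dims: &[usize]) -> Result<(usize, usize), String> {
    match dims {
        [out, inp] | [out, inp, 1, 1] => Ok((*out, *inp)),
        other => Err(format!("unsupported linear weight shape: {other:?}")),
    }
}
```

With three layers, `forward_until(x, &layers, penultimate(layers.len()))` runs only the first two. A kernel with a spatial extent other than 1x1 is rejected rather than silently reshaped.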

These changes allow reusing the existing CLIP and VAE implementations, but inevitably add complexity to the existing codebase. @LaurentMazare Let me know whether these intrusive changes are justified. We may consider alternatives such as re-implementing the VAE and CLIP from scratch.

On top of these changes, I added flash-attention support for MMDiT, enabled when the flash-attn feature is turned on. I also ran a simple performance benchmark on GPUs such as the 3090 Ti and 4090.
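As a rough illustration of gating an attention path on a Cargo feature, the sketch below combines a compile-time check with a runtime switch. The enum, the function name, and the error message are assumptions made for this example, not the PR's code. Only the feature name flash-attn comes from the description above.

```rust
/// Which attention implementation to use (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
enum AttentionBackend {
    Flash,
    Standard,
}

/// Pick a backend from a runtime flag, honouring whether the crate was
/// compiled with the `flash-attn` feature.
fn select_backend(use_flash_attn: bool) -> Result<AttentionBackend, String> {
    // `cfg!` evaluates to a compile-time boolean constant.
    let compiled_in = cfg!(feature = "flash-attn");
    match (use_flash_attn, compiled_in) {
        (true, true) => Ok(AttentionBackend::Flash),
        (true, false) => Err("flash-attn requested but the feature is not enabled".to_string()),
        (false, _) => Ok(AttentionBackend::Standard),
    }
}
```

Returning an error when the flag is set without the feature makes a misconfigured build fail loudly instead of silently falling back to the slower path.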

A side note: the T5 implementation on the current main branch does not support FP16 yet. I attempted to insert simple clamping to the FP16 dynamic range, but it didn't work well on my GPUs. It looks like I need to wait for a more sophisticated implementation such as #2481. So for now, I use two different VarBuilders: one maps the safetensors weights to FP32 specifically for T5, the other is for the rest of the components.
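For readers unfamiliar with the idea, the following std-only sketch shows what clamping to the FP16 range and routing components to different precisions could look like. It is an assumption-laden illustration, not the PR's code: `clamp_to_f16_range`, `dtype_for`, the `DType` enum, and the component names are hypothetical, and using FP16 for the non-T5 components is assumed here rather than stated above. The constant 65504 is the largest finite value representable in IEEE 754 half precision.

```rust
/// Largest finite value of IEEE 754 half precision (FP16).
const F16_MAX: f32 = 65504.0;

/// Clamp a value into the finite FP16 range so that a later cast to
/// half precision does not overflow to infinity.
fn clamp_to_f16_range(x: f32) -> f32 {
    x.clamp(-F16_MAX, F16_MAX)
}

/// Precision used when loading a component's weights (hypothetical type).
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F16,
    F32,
}

/// Route T5 to FP32 and everything else to FP16, mirroring the idea of
/// keeping one weight loader per precision. Component names are made up.
fn dtype_for(component: &str) -> DType {
    if component == "t5" {
        DType::F32
    } else {
        DType::F16
    }
}
```

Clamping only prevents overflow. It does not recover the precision lost in half-precision arithmetic, which may be one reason the simple approach was not sufficient.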

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test
Collaborator

@LaurentMazare LaurentMazare left a comment

Looks pretty good, thanks for adding this. Would you mind replacing the sample image with a jpg version? (the png version you attached takes almost 1MB which is not great for the repo size)

@Czxck001
Contributor Author

@LaurentMazare Thank you for reminding me of this. The sample image has been replaced by a JPG. The original PNG should be excluded from Git objects after squash-merging.

candle-examples/examples/stable-diffusion-3/main.rs (outdated, resolved)
candle-examples/Cargo.toml (resolved)
candle-examples/examples/stable-diffusion-3/vae.rs (outdated, resolved)
candle-transformers/src/models/mmdit/blocks.rs (outdated, resolved)
@LaurentMazare LaurentMazare merged commit ca7cf5c into huggingface:main Oct 13, 2024
10 checks passed
@Czxck001 Czxck001 deleted the add-stable-diffusion-3-example branch October 13, 2024 20:09
@LaurentMazare
Collaborator

Merged, thanks a lot!

@super-fun-surf

It's so great to have this, thanks for the work. Stable Diffusion 3.5 Large is out now and it looks amazing, as it's a full base model that we can train on. It looks like it's not working with candle yet, though.

@Czxck001
Contributor Author

@super-fun-surf Working on it. SD3.5 changes the MMDiT architecture a little (namely MMDiT-X). I need to get that done first before implementing a working example.

@LaurentMazare
Collaborator

@Czxck001 I actually already started adding 3.5 and just opened a PR for it, #2578. It seems to be working well with at least the turbo model so far, and it would be great if you could take a look at the PR.

@Czxck001
Contributor Author

@LaurentMazare That's awesome! I'll take a look.

EricLBuehler pushed a commit to EricLBuehler/candle that referenced this pull request Nov 26, 2024
* Add stable diffusion 3 example

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <[email protected]>