[pipelines] allow models to run with a user-provided dtype map instead of a single dtype #10108

Open
sayakpaul opened this issue Dec 4, 2024 · 0 comments · May be fixed by #10301
Labels
enhancement New feature or request

Comments

@sayakpaul
Member

Newer models such as Mochi-1 run the text encoder and VAE decoding in FP32 while keeping the denoising process under torch.bfloat16 autocast.
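To make the precision trade-off concrete: float16 stores more mantissa bits than bfloat16 but tops out at 65504, while bfloat16 keeps float32's exponent range at the cost of precision. The pure-Python sketch below emulates the three formats with the standard `struct` module (format code `e` is IEEE 754 half precision; bfloat16 is emulated by rounding a float32 bit pattern to its upper 16 bits). It only illustrates the number formats and is not how Mochi-1 or any diffusers pipeline is implemented; the helper names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    # Round a Python float (binary64) to the nearest IEEE 754 binary32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_float16(x: float) -> float:
    # Round to IEEE 754 binary16; struct raises OverflowError past float16's range.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    # bfloat16 = float32's sign and 8-bit exponent plus only 7 mantissa bits.
    # Drop the low 16 bits of the float32 pattern with round-to-nearest-even.
    # Assumes a finite input (NaN handling is out of scope for this sketch).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


if __name__ == "__main__":
    # A large activation and a value needing fine precision.
    for value in (70000.0, 1.001):
        try:
            fp16 = to_float16(value)
        except OverflowError:
            fp16 = "overflow"
        print(f"{value}: fp32={to_float32(value)!r} fp16={fp16!r} bf16={to_bfloat16(value)!r}")
```

With these inputs, float16 overflows on 70000.0 while bfloat16 keeps it (rounded to 70144.0), and bfloat16 rounds 1.001 down to 1.0 where float16 keeps 1.0009765625. This is the kind of behavior that motivates running some components in FP32 and others in bfloat16.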

Currently, our pipelines cannot run the different models involved in different dtypes, because we set a single global torch_dtype when initializing the pipeline.

Some pipelines, like SDXL, work around this for the VAE: it has a config attribute called force_upcast, which is handled within the pipeline implementation like so:

```python
if not output_type == "latent":
    # make sure the VAE is in float32 mode, as it overflows in float16
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

    if needs_upcasting:
        self.upcast_vae()
        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
    elif latents.dtype != self.vae.dtype:
        if torch.backends.mps.is_available():
            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
            self.vae = self.vae.to(latents.dtype)
```
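The branching in the SDXL snippet can be distilled into a small pure-Python function, which makes it easier to see that this is a per-pipeline special case for one component rather than a general mechanism. Dtypes are plain strings so it runs without PyTorch; `plan_vae_decode` and `VaeDecodePlan` are hypothetical names. The sketch assumes `upcast_vae()` leaves the VAE in float32, whereas the real snippet reads the latents' target dtype back from `vae.post_quant_conv`, which this simplifies.

```python
from dataclasses import dataclass


@dataclass
class VaeDecodePlan:
    vae_dtype: str      # dtype the VAE ends up in before decoding
    latents_dtype: str  # dtype the latents are cast to before decoding


def plan_vae_decode(
    vae_dtype: str, latents_dtype: str, force_upcast: bool, mps_available: bool
) -> VaeDecodePlan:
    # Upcast only when the VAE is float16 AND its config opts in via force_upcast.
    needs_upcasting = vae_dtype == "float16" and force_upcast
    if needs_upcasting:
        # Assumption: upcast_vae() moves the VAE to float32 and latents follow it.
        return VaeDecodePlan(vae_dtype="float32", latents_dtype="float32")
    if latents_dtype != vae_dtype and mps_available:
        # MPS workaround: cast the VAE to the latents' dtype instead.
        return VaeDecodePlan(vae_dtype=latents_dtype, latents_dtype=latents_dtype)
    # Every other case leaves both dtypes untouched.
    return VaeDecodePlan(vae_dtype=vae_dtype, latents_dtype=latents_dtype)
```

Note that only the VAE gets this treatment, and only via a config flag; the text encoder and denoiser still inherit the global torch_dtype.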

Another option is to decouple the major computation stages of the pipeline so that users can run each stage in whatever supported torch_dtype they want. Here is an example.

But this is an involved process and a power-user thing, IMO. What if we allowed users to pass a torch_dtype map like so:

```python
{"unet": torch.bfloat16, "vae": torch.float32, "text_encoder": torch.float32}
```
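A minimal sketch of how such a map could be resolved against a pipeline's components. It assumes dtypes written as strings (so it runs without PyTorch; a real version would take `torch.dtype` objects) and two conventions that are not part of this proposal: a `"default"` key for unlisted components and a float32 fallback when no default is given. `resolve_component_dtypes` is a hypothetical helper, not an existing diffusers function.

```python
from typing import Dict, Iterable, Mapping, Union

# Either one dtype for every component, or a per-component map.
DtypeSpec = Union[str, Mapping[str, str]]


def resolve_component_dtypes(
    components: Iterable[str],
    torch_dtype: DtypeSpec,
    fallback: str = "float32",
) -> Dict[str, str]:
    """Return the dtype each component should be loaded in (hypothetical helper)."""
    names = list(components)
    if isinstance(torch_dtype, str):
        # Current behavior: a single global dtype for all components.
        return {name: torch_dtype for name in names}
    # Assumed convention: "default" covers components not listed explicitly.
    unknown = set(torch_dtype) - set(names) - {"default"}
    if unknown:
        # Reject typos such as "text_encodr" instead of silently ignoring them.
        raise ValueError(f"unknown components in dtype map: {sorted(unknown)}")
    default = torch_dtype.get("default", fallback)
    return {name: torch_dtype.get(name, default) for name in names}


if __name__ == "__main__":
    components = ["unet", "vae", "text_encoder"]
    print(resolve_component_dtypes(components, {"unet": "bfloat16", "default": "float32"}))
```

Rejecting unknown keys is a deliberate choice in this sketch: with a typo, silently falling back would load that component in the wrong precision, which is the failure mode a dtype map is meant to prevent.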

This, along with @a-r-r-o-w's idea of an upcast marker, could really benefit pipelines that are not robust to precision changes.

Cc: @DN6 @yiyixuxu @hlky

@hlky added the enhancement (New feature or request) label on Dec 4, 2024