Add vae_roundtrip.py example (#7104)
* Add vae_roundtrip.py example
* Add cuda support to vae_roundtrip
* Move vae_roundtrip.py into research_projects/vae
* Fix channel scaling in vae roundtrip and also support taesd.
* Apply ruff --fix for CI gatekeep check

Co-authored-by: Álvaro Somoza <[email protected]>
1 parent 31adeb4 · commit 2e2684f
Showing 2 changed files with 293 additions and 0 deletions.
@@ -0,0 +1,11 @@
# VAE

`vae_roundtrip.py` demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.

```
cd examples/research_projects/vae
python vae_roundtrip.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --subfolder="vae" \
    --input_image="/path/to/your/input.png"
```
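
The script can also roundtrip through a tiny autoencoder such as TAESD when `--use_tiny_nn` is passed (mirroring the usage notes inside `vae_roundtrip.py` itself); for example:

```
cd examples/research_projects/vae
python vae_roundtrip.py \
    --pretrained_model_name_or_path="madebyollin/taesd" \
    --use_tiny_nn \
    --input_image="/path/to/your/input.png"
```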
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import typing
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms  # type: ignore

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import (
    AutoencoderKL,
    AutoencoderKLOutput,
)
from diffusers.models.autoencoders.autoencoder_tiny import (
    AutoencoderTiny,
    AutoencoderTinyOutput,
)
from diffusers.models.autoencoders.vae import DecoderOutput


SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]


def load_vae_model(
    *,
    device: torch.device,
    model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> SupportedAutoencoder:
    if use_tiny_nn:
        # NOTE: These scaling factors don't have to be the same as each other.
        down_scale = 2
        up_scale = 2
        vae = AutoencoderTiny.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
            downscaling_scaling_factor=down_scale,
            upsampling_scaling_factor=up_scale,
        )
        assert isinstance(vae, AutoencoderTiny)
    else:
        vae = AutoencoderKL.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
        )
        assert isinstance(vae, AutoencoderKL)
    vae = vae.to(device)
    vae.eval()  # Set the model to inference mode
    return vae


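# NOTE: torchvision's ToTensor converts a PIL image to a float tensor with
# values in [0, 1] and shape (C, H, W); unsqueeze(0) adds the batch dimension.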
def pil_to_nhwc(
    *,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    assert image.mode == "RGB"
    transform = transforms.ToTensor()
    nhwc = transform(image).unsqueeze(0).to(device)  # type: ignore
    assert isinstance(nhwc, torch.Tensor)
    return nhwc


def nhwc_to_pil(
    *,
    nhwc: torch.Tensor,
) -> Image.Image:
    assert nhwc.shape[0] == 1
    hwc = nhwc.squeeze(0).cpu()
    return transforms.ToPILImage()(hwc)  # type: ignore


def concatenate_images(
    *,
    left: Image.Image,
    right: Image.Image,
    vertical: bool = False,
) -> Image.Image:
    width1, height1 = left.size
    width2, height2 = right.size
    if vertical:
        total_height = height1 + height2
        max_width = max(width1, width2)
        new_image = Image.new("RGB", (max_width, total_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (0, height1))
    else:
        total_width = width1 + width2
        max_height = max(height1, height2)
        new_image = Image.new("RGB", (total_width, max_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (width1, 0))
    return new_image


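# NOTE: VaeImageProcessor.normalize maps pixel values from [0, 1] to [-1, 1],
# the range the VAE expects. AutoencoderKL returns a latent distribution that
# must be sampled, whereas AutoencoderTiny returns latents directly.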
def to_latent(
    *,
    rgb_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)  # type: ignore
    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
    if isinstance(encoding_nchw, AutoencoderKLOutput):
        latent = encoding_nchw.latent_dist.sample()  # type: ignore
        assert isinstance(latent, torch.Tensor)
    elif isinstance(encoding_nchw, AutoencoderTinyOutput):
        latent = encoding_nchw.latents
        do_internal_vae_scaling = False  # Is this needed?
        if do_internal_vae_scaling:
            latent = vae.scale_latents(latent).mul(255).round().byte()  # type: ignore
            latent = vae.unscale_latents(latent / 255.0)  # type: ignore
        assert isinstance(latent, torch.Tensor)
    else:
        assert False, f"Unknown encoding type: {type(encoding_nchw)}"
    return latent


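# NOTE: decode() yields images in [-1, 1]; VaeImageProcessor.denormalize maps
# them back to [0, 1] for display.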
def from_latent(
    *,
    latent_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    decoding_nchw = vae.decode(latent_nchw)  # type: ignore
    assert isinstance(decoding_nchw, DecoderOutput)
    rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample)  # type: ignore
    assert isinstance(rgb_nchw, torch.Tensor)
    return rgb_nchw


def main_kwargs(
    *,
    device: torch.device,
    input_image_path: str,
    pretrained_model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> None:
    vae = load_vae_model(
        device=device,
        model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
    original_pil = Image.open(input_image_path).convert("RGB")
    original_image = pil_to_nhwc(
        device=device,
        image=original_pil,
    )
    print(f"Original image shape: {original_image.shape}")
    reconstructed_image: Optional[torch.Tensor] = None

    with torch.no_grad():
        latent_image = to_latent(rgb_nchw=original_image, vae=vae)
        print(f"Latent shape: {latent_image.shape}")
        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
    reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
    combined_image = concatenate_images(
        left=original_pil,
        right=reconstructed_pil,
        vertical=False,
    )
    combined_image.show("Original | Reconstruction")
    print(f"Reconstructed image shape: {reconstructed_image.shape}")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Inference with VAE")
    parser.add_argument(
        "--input_image",
        type=str,
        required=True,
        help="Path to the input image for inference.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained VAE model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit id).",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Model file variant, e.g., 'fp16'.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default=None,
        help="Subfolder within the model repo or directory, e.g., 'vae'.",
    )
    parser.add_argument(
        "--use_cuda",
        action="store_true",
        help="Use CUDA if available.",
    )
    parser.add_argument(
        "--use_tiny_nn",
        action="store_true",
        help="Use the tiny autoencoder (AutoencoderTiny, e.g., TAESD).",
    )
    return parser.parse_args()


# EXAMPLE USAGE:
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
def main_cli() -> None:
    args = parse_args()

    input_image_path = args.input_image
    assert isinstance(input_image_path, str)

    pretrained_model_name_or_path = args.pretrained_model_name_or_path
    assert isinstance(pretrained_model_name_or_path, str)

    revision = args.revision
    assert isinstance(revision, (str, type(None)))

    variant = args.variant
    assert isinstance(variant, (str, type(None)))

    subfolder = args.subfolder
    assert isinstance(subfolder, (str, type(None)))

    use_cuda = args.use_cuda
    assert isinstance(use_cuda, bool)

    use_tiny_nn = args.use_tiny_nn
    assert isinstance(use_tiny_nn, bool)

    device = torch.device("cuda" if use_cuda else "cpu")

    main_kwargs(
        device=device,
        input_image_path=input_image_path,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )


if __name__ == "__main__":
    main_cli()
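
For readers who want the same roundtrip without the CLI, here is a minimal sketch using diffusers' `VaeImageProcessor` preprocess/postprocess helpers in place of the manual normalize/denormalize calls above. It is not part of the commit; the model path, file names, and `vae_scale_factor` value are illustrative assumptions.

```python
# Minimal sketch (not part of the commit): roundtrip an image through an
# AutoencoderKL using VaeImageProcessor's preprocess/postprocess helpers.
# The model path, file names, and vae_scale_factor are illustrative.
import torch
from PIL import Image

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"
).eval()
processor = VaeImageProcessor(vae_scale_factor=8)  # SD v1 VAE downsamples 8x

image = Image.open("input.png").convert("RGB")
pixels = processor.preprocess(image)  # NCHW float tensor scaled to [-1, 1]
with torch.no_grad():
    latents = vae.encode(pixels).latent_dist.sample()
    decoded = vae.decode(latents).sample  # back to [-1, 1]
processor.postprocess(decoded, output_type="pil")[0].save("roundtrip.png")
```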