Skip to content

Commit

Permalink
Add vae_roundtrip.py example (#7104)
Browse files Browse the repository at this point in the history
* Add vae_roundtrip.py example

* Add cuda support to vae_roundtrip

* Move vae_roundtrip.py into research_projects/vae

* Fix channel scaling in vae roundrip and also support taesd.

* Apply ruff --fix for CI gatekeep check

---------

Co-authored-by: 脕lvaro Somoza <[email protected]>
  • Loading branch information
thomaseding and asomoza committed Jul 4, 2024
1 parent 31adeb4 commit 2e2684f
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/research_projects/vae/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# VAE

`vae_roundtrip.py` Demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.

```
cd examples/research_projects/vae
python vae_roundtrip.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
--subfolder="vae" \
--input_image="/path/to/your/input.png"
```
282 changes: 282 additions & 0 deletions examples/research_projects/vae/vae_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import argparse
import typing
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms # type: ignore

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import (
AutoencoderKL,
AutoencoderKLOutput,
)
from diffusers.models.autoencoders.autoencoder_tiny import (
AutoencoderTiny,
AutoencoderTinyOutput,
)
from diffusers.models.autoencoders.vae import DecoderOutput


SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]


def load_vae_model(
*,
device: torch.device,
model_name_or_path: str,
revision: Optional[str],
variant: Optional[str],
# NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
subfolder: Optional[str],
use_tiny_nn: bool,
) -> SupportedAutoencoder:
if use_tiny_nn:
# NOTE: These scaling factors don't have to be the same as each other.
down_scale = 2
up_scale = 2
vae = AutoencoderTiny.from_pretrained( # type: ignore
model_name_or_path,
subfolder=subfolder,
revision=revision,
variant=variant,
downscaling_scaling_factor=down_scale,
upsampling_scaling_factor=up_scale,
)
assert isinstance(vae, AutoencoderTiny)
else:
vae = AutoencoderKL.from_pretrained( # type: ignore
model_name_or_path,
subfolder=subfolder,
revision=revision,
variant=variant,
)
assert isinstance(vae, AutoencoderKL)
vae = vae.to(device)
vae.eval() # Set the model to inference mode
return vae


def pil_to_nhwc(
*,
device: torch.device,
image: Image.Image,
) -> torch.Tensor:
assert image.mode == "RGB"
transform = transforms.ToTensor()
nhwc = transform(image).unsqueeze(0).to(device) # type: ignore
assert isinstance(nhwc, torch.Tensor)
return nhwc


def nhwc_to_pil(
*,
nhwc: torch.Tensor,
) -> Image.Image:
assert nhwc.shape[0] == 1
hwc = nhwc.squeeze(0).cpu()
return transforms.ToPILImage()(hwc) # type: ignore


def concatenate_images(
*,
left: Image.Image,
right: Image.Image,
vertical: bool = False,
) -> Image.Image:
width1, height1 = left.size
width2, height2 = right.size
if vertical:
total_height = height1 + height2
max_width = max(width1, width2)
new_image = Image.new("RGB", (max_width, total_height))
new_image.paste(left, (0, 0))
new_image.paste(right, (0, height1))
else:
total_width = width1 + width2
max_height = max(height1, height2)
new_image = Image.new("RGB", (total_width, max_height))
new_image.paste(left, (0, 0))
new_image.paste(right, (width1, 0))
return new_image


def to_latent(
*,
rgb_nchw: torch.Tensor,
vae: SupportedAutoencoder,
) -> torch.Tensor:
rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore
encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
if isinstance(encoding_nchw, AutoencoderKLOutput):
latent = encoding_nchw.latent_dist.sample() # type: ignore
assert isinstance(latent, torch.Tensor)
elif isinstance(encoding_nchw, AutoencoderTinyOutput):
latent = encoding_nchw.latents
do_internal_vae_scaling = False # Is this needed?
if do_internal_vae_scaling:
latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore
latent = vae.unscale_latents(latent / 255.0) # type: ignore
assert isinstance(latent, torch.Tensor)
else:
assert False, f"Unknown encoding type: {type(encoding_nchw)}"
return latent


def from_latent(
*,
latent_nchw: torch.Tensor,
vae: SupportedAutoencoder,
) -> torch.Tensor:
decoding_nchw = vae.decode(latent_nchw) # type: ignore
assert isinstance(decoding_nchw, DecoderOutput)
rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore
assert isinstance(rgb_nchw, torch.Tensor)
return rgb_nchw


def main_kwargs(
*,
device: torch.device,
input_image_path: str,
pretrained_model_name_or_path: str,
revision: Optional[str],
variant: Optional[str],
subfolder: Optional[str],
use_tiny_nn: bool,
) -> None:
vae = load_vae_model(
device=device,
model_name_or_path=pretrained_model_name_or_path,
revision=revision,
variant=variant,
subfolder=subfolder,
use_tiny_nn=use_tiny_nn,
)
original_pil = Image.open(input_image_path).convert("RGB")
original_image = pil_to_nhwc(
device=device,
image=original_pil,
)
print(f"Original image shape: {original_image.shape}")
reconstructed_image: Optional[torch.Tensor] = None

with torch.no_grad():
latent_image = to_latent(rgb_nchw=original_image, vae=vae)
print(f"Latent shape: {latent_image.shape}")
reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
combined_image = concatenate_images(
left=original_pil,
right=reconstructed_pil,
vertical=False,
)
combined_image.show("Original | Reconstruction")
print(f"Reconstructed image shape: {reconstructed_image.shape}")


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Inference with VAE")
parser.add_argument(
"--input_image",
type=str,
required=True,
help="Path to the input image for inference.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
required=True,
help="Path to pretrained VAE model.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Model version.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Model file variant, e.g., 'fp16'.",
)
parser.add_argument(
"--subfolder",
type=str,
default=None,
help="Subfolder in the model file.",
)
parser.add_argument(
"--use_cuda",
action="store_true",
help="Use CUDA if available.",
)
parser.add_argument(
"--use_tiny_nn",
action="store_true",
help="Use tiny neural network.",
)
return parser.parse_args()


# EXAMPLE USAGE:
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
def main_cli() -> None:
args = parse_args()

input_image_path = args.input_image
assert isinstance(input_image_path, str)

pretrained_model_name_or_path = args.pretrained_model_name_or_path
assert isinstance(pretrained_model_name_or_path, str)

revision = args.revision
assert isinstance(revision, (str, type(None)))

variant = args.variant
assert isinstance(variant, (str, type(None)))

subfolder = args.subfolder
assert isinstance(subfolder, (str, type(None)))

use_cuda = args.use_cuda
assert isinstance(use_cuda, bool)

use_tiny_nn = args.use_tiny_nn
assert isinstance(use_tiny_nn, bool)

device = torch.device("cuda" if use_cuda else "cpu")

main_kwargs(
device=device,
input_image_path=input_image_path,
pretrained_model_name_or_path=pretrained_model_name_or_path,
revision=revision,
variant=variant,
subfolder=subfolder,
use_tiny_nn=use_tiny_nn,
)


if __name__ == "__main__":
main_cli()

0 comments on commit 2e2684f

Please sign in to comment.