make scaling factor a config arg of vae/vqvae (huggingface#1860)
* make scaling factor config arg of vae

* fix

* make flake happy

* fix ldm

* fix upscaler

* quality

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* solve conflicts, address some comments

* examples

* examples min version

* doc

* fix type

* typo

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Co-authored-by: Pedro Cuenca <[email protected]>

* remove duplicate line

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

Co-authored-by: Anton Lozhkov <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
4 people authored Jan 26, 2023
1 parent 915a563 commit 1e216be
Showing 39 changed files with 95 additions and 55 deletions.
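
Every change below follows the same pattern: the hardcoded latent scaling constant 0.18215 is replaced by a `scaling_factor` entry registered on the VAE/VQ-VAE config. A minimal sketch of the resulting usage, assuming a standard Stable Diffusion checkpoint (the model id below is illustrative, not taken from this diff):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any VAE whose config carries scaling_factor works.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

image = torch.randn(1, 3, 512, 512)  # dummy pixel batch in [-1, 1]

with torch.no_grad():
    # Encode: scale latents to unit variance with the configured factor
    # (previously a hardcoded 0.18215).
    latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor

    # Decode: undo the scaling first, as decode_latents now does.
    image_out = vae.decode(latents / vae.config.scaling_factor).sample

Checkpoints published before this change keep working because the config default reproduces the old constant.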
4 changes: 2 additions & 2 deletions examples/community/clip_guided_stable_diffusion.py
@@ -150,7 +150,7 @@ def cond_fn(
else:
raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

- sample = 1 / 0.18215 * sample
+ sample = 1 / self.vae.config.scaling_factor * sample
image = self.vae.decode(sample).sample
image = (image / 2 + 0.5).clamp(0, 1)

@@ -336,7 +336,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth.py
@@ -803,7 +803,7 @@ def main(args):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_flax.py
@@ -533,7 +533,7 @@ def compute_loss(params):
latents = vae_outputs.latent_dist.sample(sample_rng)
# (NHWC) -> (NCHW)
latents = jnp.transpose(latents, (0, 3, 1, 2))
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise_rng, timestep_rng = jax.random.split(sample_rng)
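
On the Flax side the scaling is applied after transposing the VAE's NHWC latents to NCHW. A self-contained sketch of that ordering with dummy values (in the actual script the factor comes from `vae.config.scaling_factor`):

import jax
import jax.numpy as jnp

latents_nhwc = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 4))  # (N, H, W, C)
latents = jnp.transpose(latents_nhwc, (0, 3, 1, 2))  # -> (N, C, H, W)

scaling_factor = 0.18215  # stand-in for vae.config.scaling_factor
latents = latents * scaling_factor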
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora.py
@@ -853,7 +853,7 @@ def main(args):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -607,7 +607,7 @@ def collate_fn(examples):
optimizer.zero_grad()

latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -33,7 +33,7 @@


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
- check_min_version("0.10.0.dev0")
+ check_min_version("0.13.0.dev0")

logger = get_logger(__name__)

@@ -699,13 +699,13 @@ def collate_fn(examples):
# Convert images to latent space

latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Convert masked images to latent space
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
- masked_latents = masked_latents * 0.18215
+ masked_latents = masked_latents * vae.config.scaling_factor

masks = batch["masks"]
# resize the mask to latents shape as we concatenate the mask to the latents
@@ -51,7 +51,7 @@


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
- check_min_version("0.10.0.dev0")
+ check_min_version("0.13.0.dev0")


logger = get_logger(__name__)
@@ -555,7 +555,7 @@ def main():
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
@@ -31,7 +31,7 @@


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
- check_min_version("0.10.0.dev0")
+ check_min_version("0.13.0.dev0")

logger = get_logger(__name__)

@@ -788,7 +788,7 @@ def main(args):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image.py
@@ -636,7 +636,7 @@ def collate_fn(examples):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image_flax.py
@@ -438,7 +438,7 @@ def compute_loss(params):
latents = vae_outputs.latent_dist.sample(sample_rng)
# (NHWC) -> (NCHW)
latents = jnp.transpose(latents, (0, 3, 1, 2))
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise_rng, timestep_rng = jax.random.split(sample_rng)
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image_lora.py
@@ -689,7 +689,7 @@ def collate_fn(examples):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
2 changes: 1 addition & 1 deletion examples/textual_inversion/textual_inversion.py
@@ -711,7 +711,7 @@ def main():
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
2 changes: 1 addition & 1 deletion examples/textual_inversion/textual_inversion_flax.py
@@ -525,7 +525,7 @@ def compute_loss(params):
latents = vae_outputs.latent_dist.sample(sample_rng)
# (NHWC) -> (NCHW)
latents = jnp.transpose(latents, (0, 3, 1, 2))
- latents = latents * 0.18215
+ latents = latents * vae.config.scaling_factor

noise_rng, timestep_rng = jax.random.split(sample_rng)
noise = jax.random.normal(noise_rng, latents.shape)
10 changes: 9 additions & 1 deletion src/diffusers/models/autoencoder_kl.py
@@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
- latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
+ `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+ Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""

@register_to_config
@@ -71,6 +78,7 @@ def __init__(
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
+ scaling_factor: float = 0.18215,
):
super().__init__()

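
Because `scaling_factor` goes through `register_to_config`, it is serialized to `config.json` and exposed as `vae.config.scaling_factor`; older checkpoints whose config lacks the key should fall back to the signature default of 0.18215 when loaded. A sketch of overriding it (the 0.13025 value is a made-up illustration, not from this commit):

from diffusers import AutoencoderKL

# Tiny default-architecture VAE; only scaling_factor is of interest here.
vae = AutoencoderKL(scaling_factor=0.13025)
print(vae.config.scaling_factor)  # 0.13025

vae.save_pretrained("./tiny-vae")  # scaling_factor lands in config.json
reloaded = AutoencoderKL.from_pretrained("./tiny-vae")
assert reloaded.config.scaling_factor == 0.13025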
10 changes: 9 additions & 1 deletion src/diffusers/models/vae_flax.py
@@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
Latent space channels
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Norm num group
- sample_size (:obj:`int`, *optional*, defaults to `32`):
+ sample_size (:obj:`int`, *optional*, defaults to 32):
Sample input size
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
+ `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+ Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
parameters `dtype`
"""
@@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
latent_channels: int = 4
norm_num_groups: int = 32
sample_size: int = 32
+ scaling_factor: float = 0.18215
dtype: jnp.dtype = jnp.float32

def setup(self):
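
On the Flax module the new attribute is a plain dataclass field picked up by the config machinery the same way. A minimal sketch, assuming the default architecture:

from diffusers import FlaxAutoencoderKL

vae = FlaxAutoencoderKL(sample_size=32, scaling_factor=0.18215)
print(vae.config.scaling_factor)  # readable exactly like the PyTorch config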
8 changes: 8 additions & 0 deletions src/diffusers/models/vq_model.py
@@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin):
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+ scaling_factor (`float`, *optional*, defaults to `0.18215`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
+ `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+ Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""

@register_to_config
@@ -74,6 +81,7 @@ def __init__(
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
+ scaling_factor: float = 0.18215,
):
super().__init__()

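
The VQ-VAE gets the identical treatment, which is what lets the LDM and audio-diffusion pipelines below read `self.vqvae.config.scaling_factor`. A minimal sketch with the default architecture:

from diffusers import VQModel

vqvae = VQModel(scaling_factor=0.18215)
print(vqvae.config.scaling_factor)  # 0.18215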
@@ -369,7 +369,7 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept

def decode_latents(self, latents):
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -391,7 +391,7 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept

def decode_latents(self, latents):
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -490,7 +490,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)

- init_latents = 0.18215 * init_latents
+ init_latents = self.vae.config.scaling_factor * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -153,7 +153,7 @@ def __call__(
input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
generator=generator
)[0]
- input_images = 0.18215 * input_images
+ input_images = self.vqvae.config.scaling_factor * input_images

if start_step > 0:
images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
@@ -195,7 +195,7 @@ def __call__(

if self.vqvae is not None:
# 0.18215 was scaling factor used in training to ensure unit variance
- images = 1 / 0.18215 * images
+ images = 1 / self.vqvae.config.scaling_factor * images
images = self.vqvae.decode(images)["sample"]

images = (images / 2 + 0.5).clamp(0, 1)
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/dit/pipeline_dit.py
@@ -182,7 +182,7 @@ def __call__(
else:
latents = latent_model_input

- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
samples = self.vae.decode(latents).sample

samples = (samples / 2 + 0.5).clamp(0, 1)
@@ -182,7 +182,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

# scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vqvae.config.scaling_factor * latents
image = self.vqvae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
Expand Down
@@ -257,7 +257,7 @@ def prepare_extra_step_kwargs(self, generator, eta):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -328,7 +328,7 @@ def prepare_mask_latents(
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
- masked_image_latents = 0.18215 * masked_image_latents
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -474,7 +474,7 @@ def run_safety_checker(self, image, device, dtype):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -509,7 +509,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)

- init_latents = 0.18215 * init_latents
+ init_latents = self.vae.config.scaling_factor * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -267,7 +267,7 @@ def loop_body(step, args):
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

# scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
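
In the Flax pipelines decoding goes through `Module.apply` with explicitly bound params, and the unscaling is the same formula. A sketch with a freshly initialized toy VAE (shapes are illustrative; `init_weights` is assumed to behave as in diffusers of this era):

import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

vae = FlaxAutoencoderKL()  # default config, scaling_factor=0.18215
params = vae.init_weights(jax.random.PRNGKey(0))

latents = jnp.zeros((1, 4, 8, 8))  # toy (N, C, H, W) latents
latents = 1 / vae.config.scaling_factor * latents
image = vae.apply({"params": params}, latents, method=vae.decode).sample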
@@ -224,7 +224,7 @@ def _generate(
# Create init_latents
init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
- init_latents = 0.18215 * init_latents
+ init_latents = self.vae.config.scaling_factor * init_latents

def loop_body(step, args):
latents, scheduler_state = args
@@ -272,7 +272,7 @@ def loop_body(step, args):
latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))

# scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -259,7 +259,7 @@ def _generate(
{"params": params["vae"]}, masked_image, method=self.vae.encode
).latent_dist
masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
- masked_image_latents = 0.18215 * masked_image_latents
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
del mask_prng_seed

mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")
@@ -327,7 +327,7 @@ def loop_body(step, args):
)

# scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -366,7 +366,7 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept

def decode_latents(self, latents):
- latents = 1 / 0.18215 * latents
+ latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16