Commit

Min-SNR Gamma: follow-up fix for zero-terminal SNR models on v-prediction or epsilon (#5177)

* merge with main

* fix flax example

* fix onnx example

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Sep 26, 2023
1 parent 89d8f84 commit 4a06c74
Showing 10 changed files with 121 additions and 18 deletions.
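
For orientation before the per-file diffs: all ten files converge on the same weighting logic. A minimal, self-contained PyTorch sketch of that logic (a distillation, not a verbatim excerpt; `compute_snr` in these scripts returns ᾱ_t / (1 − ᾱ_t) for each sampled timestep):

```python
import torch

def min_snr_loss_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    """Per-timestep MSE loss weights for Min-SNR gamma weighting,
    including the zero-terminal-SNR guard this commit adds."""
    # Base Min-SNR weighting: min(SNR, gamma) / SNR.
    base_weight = (
        torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    )
    if prediction_type == "v_prediction":
        # The velocity objective floors the SNR weight at one.
        mse_loss_weights = base_weight + 1
    else:
        # Epsilon and sample prediction use the base weights.
        mse_loss_weights = base_weight
    # Zero-terminal-SNR schedules hit snr == 0 at the final timestep,
    # where base_weight evaluates to 0/0 = NaN; pin those weights to 1.
    mse_loss_weights[snr == 0] = 1.0
    return mse_loss_weights
```

For example, `min_snr_loss_weights(torch.tensor([0.0, 1.0, 25.0]), 5.0, "epsilon")` returns `tensor([1.0000, 1.0000, 0.2000])`.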
13 changes: 10 additions & 3 deletions examples/controlnet/train_controlnet_flax.py
@@ -907,10 +907,17 @@ def compute_loss(params, minibatch, sample_rng):
 
     if args.snr_gamma is not None:
         snr = jnp.array(compute_snr(timesteps))
-        snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+        base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
         if noise_scheduler.config.prediction_type == "v_prediction":
             # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-            snr_loss_weights = snr_loss_weights + 1
+            snr_loss_weights = base_weights + 1
+        else:
+            # Epsilon and sample prediction use the base weights.
+            snr_loss_weights = base_weights
+        # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+        # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+        # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+        snr_loss_weights = snr_loss_weights.at[snr == 0].set(1.0)
 
         loss = loss * snr_loss_weights
 
     loss = loss.mean()
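
The guard in the Flax hunk is written with `.at[].set()` because JAX arrays are immutable and do not support the in-place masked assignment the PyTorch scripts use. An equivalent, jit-friendly formulation uses `jnp.where`; a small runnable sketch (illustrative values, not from the commit):

```python
import jax.numpy as jnp

snr = jnp.array([0.0, 4.0])                  # snr == 0 at the terminal timestep
weights = jnp.minimum(snr, 5.0) / snr        # 0/0 -> nan at snr == 0
weights = jnp.where(snr == 0, 1.0, weights)  # guard without an indexed update
print(weights)                               # [1. 1.]
```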
@@ -801,9 +801,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
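
To see the failure mode those comments describe, evaluate the base weight with a zero-SNR timestep in the batch (illustrative values):

```python
import torch

snr = torch.tensor([0.0, 4.0])  # a zero-terminal-SNR schedule yields snr == 0 at t = T
gamma = 5.0
base_weight = torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
print(base_weight)              # tensor([nan, 1.]) -- 0/0 poisons the weight
base_weight[snr == 0] = 1.0     # the guard added in this commit
print(base_weight)              # tensor([1., 1.])
```

The NaN would otherwise propagate through the weighted mean and take the whole loss (and its gradients) with it, which is why training diverges within a single step.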
@@ -654,9 +654,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
@@ -685,9 +685,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
15 changes: 14 additions & 1 deletion examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -833,9 +833,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
@@ -872,12 +872,21 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
     if noise_scheduler.config.prediction_type == "v_prediction":
         # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-        mse_loss_weights = mse_loss_weights + 1
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample prediction use the base weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
16 changes: 13 additions & 3 deletions examples/text_to_image/train_text_to_image.py
@@ -952,12 +952,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
     if noise_scheduler.config.prediction_type == "v_prediction":
-        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-        mse_loss_weights = mse_loss_weights + 1
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
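
`compute_snr(timesteps)` is the scripts' local helper that reads the signal-to-noise ratio off the scheduler's ᾱ schedule. A rough, self-contained equivalent (a sketch under that assumption, not the exact helper):

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler()  # any scheduler exposing alphas_cumprod

def compute_snr(timesteps: torch.Tensor) -> torch.Tensor:
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) = (alpha_t / sigma_t) ** 2
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2
```

With a conventional schedule this is strictly positive; only zero-terminal-SNR schedules drive it to exactly zero at the final timestep, which is the case the new guard handles.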
16 changes: 13 additions & 3 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -783,12 +783,22 @@ def collate_fn(examples):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
     if noise_scheduler.config.prediction_type == "v_prediction":
-        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-        mse_loss_weights = mse_loss_weights + 1
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
16 changes: 13 additions & 3 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -1072,12 +1072,22 @@ def compute_time_ids(original_size, crops_coords_top_left):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(timesteps)
-    mse_loss_weights = (
+    base_weight = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
+
     if noise_scheduler.config.prediction_type == "v_prediction":
-        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-        mse_loss_weights = mse_loss_weights + 1
+        # Velocity objective needs to be floored to an SNR weight of one.
+        mse_loss_weights = base_weight + 1
+    else:
+        # Epsilon and sample both use the same loss weights.
+        mse_loss_weights = base_weight
+
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
5 changes: 5 additions & 0 deletions examples/text_to_image/train_text_to_image_sdxl.py
@@ -1100,6 +1100,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
         # Epsilon and sample both use the same loss weights.
         mse_loss_weights = base_weight
 
+    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+    mse_loss_weights[snr == 0] = 1.0
+
     # We first calculate the original loss. Then we mean over the non-batch dimensions and
     # rebalance the sample-wise losses with their respective loss weights.
     # Finally, we take the mean of the rebalanced loss.
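
Where does an SNR of exactly zero come from? Schedules rescaled for zero terminal SNR (Lin et al., 2023, "Common Diffusion Noise Schedules and Sample Steps are Flawed") force ᾱ_T = 0. In diffusers this is exposed as the `rescale_betas_zero_snr` scheduler flag; a quick check (a sketch, assuming a recent diffusers version):

```python
from diffusers import DDIMScheduler

sched = DDIMScheduler(rescale_betas_zero_snr=True, prediction_type="v_prediction")
snr = sched.alphas_cumprod / (1.0 - sched.alphas_cumprod)
print(snr[-1])  # tensor(0.) -- the terminal timestep whose loss weight the guard pins to 1
```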
