From 4a06c74547cd88440d95468698cbf6baf91fe506 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 26 Sep 2023 05:44:52 -0700 Subject: [PATCH] Min-SNR Gamma: follow-up fix for zero-terminal SNR models on v-prediction or epsilon (#5177) * merge with main * fix flax example * fix onnx example --------- Co-authored-by: bghira Co-authored-by: Sayak Paul --- examples/controlnet/train_controlnet_flax.py | 13 ++++++++++--- .../text_to_image/train_text_to_image_decoder.py | 15 ++++++++++++++- .../train_text_to_image_lora_decoder.py | 15 ++++++++++++++- .../train_text_to_image_lora_prior.py | 15 ++++++++++++++- .../text_to_image/train_text_to_image_prior.py | 15 ++++++++++++++- .../text_to_image/train_text_to_image.py | 13 +++++++++++-- examples/text_to_image/train_text_to_image.py | 16 +++++++++++++--- .../text_to_image/train_text_to_image_lora.py | 16 +++++++++++++--- .../train_text_to_image_lora_sdxl.py | 16 +++++++++++++--- .../text_to_image/train_text_to_image_sdxl.py | 5 +++++ 10 files changed, 121 insertions(+), 18 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d04c616c57eb..34e8c69ff64b 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -907,10 +907,17 @@ def compute_loss(params, minibatch, sample_rng): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - snr_loss_weights = snr_loss_weights + 1 + snr_loss_weights = base_weights + 1 + else: + # Epsilon and sample prediction use the base weights. 
+ snr_loss_weights = base_weights + # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value. + # The loss weight for such timesteps is set unconditionally to 1; jnp.where is used because JAX arrays are immutable. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + snr_loss_weights = jnp.where(snr == 0, 1.0, snr_loss_weights) + loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 364ed7e03189..affc26101ba2 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -801,9 +801,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 9d96a936d0ca..0a38c98f51c8 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -654,9 +654,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index e4aec111b8f7..aaa8792af08a 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -685,9 +685,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index d451e1bfe40d..38aa5eee8f8a 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -833,9 +833,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. 
# This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index b89de5e001c5..7e4a93dc0381 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -872,12 +872,21 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": # velocity objective prediction requires SNR weights to be floored to a min value of 1. 
- mse_loss_weights = mse_loss_weights + 1 + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample prediction use the base weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 0d14e6ccd548..82201b22919a 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -952,12 +952,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. 
+ # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 5845bda0e54f..0d562bc59dec 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -783,12 +783,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 7a8c2c353eb0..6b870a3ab5f2 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1072,12 +1072,22 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index fd37301d8f05..1c579ef1fb11 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1100,6 +1100,11 @@ def compute_time_ids(original_size, crops_coords_top_left): # Epsilon and sample both use the same loss weights. mse_loss_weights = base_weight + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss.