diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 22486298c984..d37621e50fcf 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -332,15 +332,6 @@ def parse_args(input_args=None): help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " "More details here: https://arxiv.org/abs/2303.09556.", ) - parser.add_argument( - "--force_snr_gamma", - action="store_true", - help=( - "When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN" - " condition when computing the SNR with a sigma value of zero. This parameter overrides the check," - " allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results." - ), - ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument( "--allow_tf32", @@ -554,18 +545,6 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Check for terminal SNR in combination with SNR Gamma - if ( - args.snr_gamma - and not args.force_snr_gamma - and ( - hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr - ) - ): - raise ValueError( - f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n" - "When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n" - "This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero." - ) text_encoder_one = text_encoder_cls_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) @@ -1013,9 +992,17 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss.