Skip to content

Commit 24563ca

Browse files
bghira and bghira authored
SNR gamma fixes for v_prediction training (#5106)
Co-authored-by: bghira <[email protected]>
1 parent 914586f commit 24563ca

File tree

5 files changed

+15
-0
lines changed

5 files changed

+15
-0
lines changed

examples/controlnet/train_controlnet_flax.py

+3
Original file line number | Diff line number | Diff line change
@@ -908,6 +908,9 @@ def compute_loss(params, minibatch, sample_rng):
908908
if args.snr_gamma is not None:
909909
snr = jnp.array(compute_snr(timesteps))
910910
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
911+
if noise_scheduler.config.prediction_type == "v_prediction":
912+
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
913+
snr_loss_weights = snr_loss_weights + 1
911914
loss = loss * snr_loss_weights
912915

913916
loss = loss.mean()

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

+3
Original file line number | Diff line number | Diff line change
@@ -875,6 +875,9 @@ def collate_fn(examples):
875875
mse_loss_weights = (
876876
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
877877
)
878+
if noise_scheduler.config.prediction_type == "v_prediction":
879+
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
880+
mse_loss_weights = mse_loss_weights + 1
878881
# We first calculate the original loss. Then we mean over the non-batch dimensions and
879882
# rebalance the sample-wise losses with their respective loss weights.
880883
# Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image.py

+3
Original file line number | Diff line number | Diff line change
@@ -955,6 +955,9 @@ def collate_fn(examples):
955955
mse_loss_weights = (
956956
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
957957
)
958+
if noise_scheduler.config.prediction_type == "v_prediction":
959+
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
960+
mse_loss_weights = mse_loss_weights + 1
958961
# We first calculate the original loss. Then we mean over the non-batch dimensions and
959962
# rebalance the sample-wise losses with their respective loss weights.
960963
# Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image_lora.py

+3
Original file line number | Diff line number | Diff line change
@@ -786,6 +786,9 @@ def collate_fn(examples):
786786
mse_loss_weights = (
787787
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
788788
)
789+
if noise_scheduler.config.prediction_type == "v_prediction":
790+
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
791+
mse_loss_weights = mse_loss_weights + 1
789792
# We first calculate the original loss. Then we mean over the non-batch dimensions and
790793
# rebalance the sample-wise losses with their respective loss weights.
791794
# Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image_lora_sdxl.py

+3
Original file line number | Diff line number | Diff line change
@@ -1075,6 +1075,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
10751075
mse_loss_weights = (
10761076
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
10771077
)
1078+
if noise_scheduler.config.prediction_type == "v_prediction":
1079+
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
1080+
mse_loss_weights = mse_loss_weights + 1
10781081
# We first calculate the original loss. Then we mean over the non-batch dimensions and
10791082
# rebalance the sample-wise losses with their respective loss weights.
10801083
# Finally, we take the mean of the rebalanced loss.

0 commit comments

Comments
 (0)