From df87a0fe2c018f036c20ecb1d0aef3354fdc5365 Mon Sep 17 00:00:00 2001
From: JinZr
Date: Wed, 9 Oct 2024 14:12:41 +0800
Subject: [PATCH] updated train.py

---
 egs/libritts/CODEC/encodec/train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 11f352911f..934d480f59 100755
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -527,6 +527,7 @@ def save_bad_model(suffix: str = ""):
             + params.lambda_feat * feature_loss
             + params.lambda_com * commit_loss
         )
+        loss_info["generator_loss"] = gen_loss
         for k, v in stats_g.items():
             if "returned_sample" not in k:
                 loss_info[k] = v * batch_size
@@ -737,6 +738,7 @@ def compute_validation_loss(
             + disc_scale_fake_adv_loss
         ) * d_weight
         assert disc_loss.requires_grad is False
+        loss_info["discriminator_loss"] = disc_loss
         for k, v in stats_d.items():
             loss_info[k] = v * batch_size
 
@@ -778,6 +780,7 @@ def compute_validation_loss(
             + params.lambda_com * commit_loss
         )
         assert gen_loss.requires_grad is False
+        loss_info["generator_loss"] = gen_loss
         for k, v in stats_g.items():
             if "returned_sample" not in k:
                 loss_info[k] = v * batch_size