From 2356621059184b53fc10c338f8a46a17d923fdd4 Mon Sep 17 00:00:00 2001 From: JinZr Date: Wed, 9 Oct 2024 14:04:21 +0800 Subject: [PATCH] minor updates --- egs/libritts/CODEC/encodec/encodec.py | 31 +---- egs/libritts/CODEC/encodec/loss.py | 168 -------------------------- egs/libritts/CODEC/encodec/train.py | 25 ++-- 3 files changed, 19 insertions(+), 205 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index aa0373bfab..e1b646d725 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -157,27 +157,7 @@ def _forward_generator( x=speech, x_hat=speech_hat ) - # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g( - # commit_loss, - # speech, - # speech_hat, - # fmap, - # fmap_hat, - # y, - # y_hat, - # y_p, - # y_p_hat, - # y_s, - # y_s_hat, - # fmap_p, - # fmap_p_hat, - # fmap_s, - # fmap_s_hat, - # args=self.params, - # ) - stats = dict( - # generator_loss=loss.item(), generator_wav_reconstruction_loss=wav_reconstruction_loss.item(), generator_mel_reconstruction_loss=mel_reconstruction_loss.item(), generator_feature_stft_loss=feature_stft_loss.item(), @@ -187,7 +167,6 @@ def _forward_generator( generator_period_adv_loss=gen_period_adv_loss.item(), generator_scale_adv_loss=gen_scale_adv_loss.item(), generator_commit_loss=commit_loss.item(), - # d_weight=d_weight.item(), ) if return_sample: @@ -260,18 +239,16 @@ def _forward_discriminator( speech_hat.contiguous().detach() ) - disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor( - 0.0 - ), torch.tensor(0.0) + disc_period_real_adv_loss = torch.tensor(0.0) + disc_period_fake_adv_loss = torch.tensor(0.0) if self.multi_period_discriminator is not None: y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( speech.contiguous(), speech_hat.contiguous().detach(), ) - disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor( - 0.0 - ), torch.tensor(0.0) + disc_scale_real_adv_loss = torch.tensor(0.0) + disc_scale_fake_adv_loss = torch.tensor(0.0) if self.multi_scale_discriminator is not None: y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( speech.contiguous(), diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index ae1e34bddf..8ec80bb9c9 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -317,171 +317,3 @@ def forward( wav_loss = F.l1_loss(x, x_hat) return wav_loss - - -def adversarial_g_loss(y_disc_gen): - """Hinge loss""" - loss = 0.0 - for i in range(len(y_disc_gen)): - stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze() - loss += stft_loss - return loss / len(y_disc_gen) - - -def feature_loss(fmap_r, fmap_gen): - loss = 0.0 - for i in range(len(fmap_r)): - for j in range(len(fmap_r[i])): - stft_loss = ( - (fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean()) - ).mean() - loss += stft_loss - return loss / (len(fmap_r) * len(fmap_r[0])) - - -def sim_loss(y_disc_r, y_disc_gen): - loss = 0.0 - for i in range(len(y_disc_r)): - loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) - return loss / len(y_disc_r) - - -def reconstruction_loss(x, x_hat, args, eps=1e-7): - # NOTE (lsx): hard-coded now - L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss - # loss_sisnr = sisnr_loss(G_x, x) # - # L += 0.01*loss_sisnr - # 2^6=64 -> 2^10=1024 - # NOTE (lsx): add 2^11 - for i in range(6, 12): - # for i in range(5, 12): # Encodec setting - s = 2**i - melspec = MelSpectrogram( - sample_rate=args.sampling_rate, - n_fft=max(s, 512), - win_length=s, - hop_length=s // 4, - n_mels=64, - wkwargs={"device": x_hat.device}, - ).to(x_hat.device) - S_x = melspec(x) - S_x_hat = melspec(x_hat) - l1_loss = (S_x - S_x_hat).abs().mean() - l2_loss = ( - ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean( - dim=-2 - ) - ** 0.5 - ).mean() - - alpha = (s / 2) ** 0.5 - L += l1_loss + alpha * l2_loss - return L - - -def adopt_weight(weight, global_step, threshold=0, value=0.0): - if global_step < threshold: - weight = value - return weight - - -def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - print("last_layer cannot be none") - assert 1 == 2 - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() - d_weight = d_weight * args.lambda_adv - return d_weight - - -def loss_g( - codebook_loss, - speech, - speech_hat, - fmap, - fmap_hat, - y, - y_hat, - y_df, - y_df_hat, - y_ds, - y_ds_hat, - fmap_f, - fmap_f_hat, - fmap_s, - fmap_s_hat, - args=None, -): - """ - args: - codebook_loss: commit loss. - speech: ground-truth wav. - speech_hat: reconstructed wav. - fmap: real stft-D feature map. - fmap_hat: fake stft-D feature map. - y: real stft-D logits. - y_hat: fake stft-D logits. - global_step: global training step. - y_df: real MPD logits. - y_df_hat: fake MPD logits. - y_ds: real MSD logits. - y_ds_hat: fake MSD logits. - fmap_f: real MPD feature map. - fmap_f_hat: fake MPD feature map. - fmap_s: real MSD feature map. - fmap_s_hat: fake MSD feature map. - """ - rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args) - adv_g_loss = adversarial_g_loss(y_hat) - adv_mpd_loss = adversarial_g_loss(y_df_hat) - adv_msd_loss = adversarial_g_loss(y_ds_hat) - adv_loss = ( - adv_g_loss + adv_mpd_loss + adv_msd_loss - ) / 3.0 # NOTE(lsx): need to divide by 3? - feat_loss = feature_loss( - fmap, fmap_hat - ) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits? - feat_loss_mpd = feature_loss( - fmap_f, fmap_f_hat - ) # + sim_loss(y_df_hat_r, y_df_hat_g) - feat_loss_msd = feature_loss( - fmap_s, fmap_s_hat - ) # + sim_loss(y_ds_hat_r, y_ds_hat_g) - feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 - d_weight = torch.tensor(1.0) - - # disc_factor = adopt_weight( - # args.lambda_adv, global_step, threshold=args.discriminator_iter_start - # ) - disc_factor = 1 - if disc_factor == 0.0: - fm_loss_wt = 0 - else: - fm_loss_wt = args.lambda_feat - - loss = ( - rec_loss - + d_weight * disc_factor * adv_loss - + fm_loss_wt * feat_loss_tot - + args.lambda_com * codebook_loss - ) - return loss, rec_loss, adv_loss, feat_loss_tot, d_weight - - -if __name__ == "__main__": - # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True) - # aa = [torch.rand(192, 192) for _ in range(3)] - # bb = [torch.rand(192, 192) for _ in range(3)] - # print(la(bb, aa)) - # print(feature_loss(aa, bb)) - la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge") - aa = torch.Tensor([0.1, 0.2, 0.3, 0.4]) - bb = torch.Tensor([0.4, 0.3, 0.2, 0.1]) - print(la(aa)) - print(adversarial_g_loss(aa)) - print(la(bb)) - print(adversarial_g_loss(bb)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 8475ab6e86..11f352911f 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -14,7 +14,6 @@ from codec_datamodule import LibriTTSCodecDataModule from encodec import Encodec from lhotse.utils import fix_random_seed -from loss import adopt_weight from scheduler import WarmupCosineLrScheduler from torch import nn from torch.cuda.amp import GradScaler, autocast @@ -189,10 +188,10 @@ def get_params() -> AttributeDict: "audio_normalization": False, "chunk_size": 1.0, # in seconds "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 1.0, # loss scaling coefficient for waveform loss - "lambda_feat": 3.0, # loss scaling coefficient for feat loss + "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_feat": 4.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 100.0, # loss scaling coefficient for commitment loss + "lambda_com": 1.0, # loss scaling coefficient for commitment loss } ) @@ -361,6 +360,12 @@ def prepare_input( return audio, audio_lens, features, features_lens +def train_discriminator(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], @@ -447,7 +452,7 @@ def save_bad_model(suffix: str = ""): try: with autocast(enabled=params.use_fp16): - d_weight = adopt_weight( + d_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -483,7 +488,7 @@ def save_bad_model(suffix: str = ""): scaler.step(optimizer_d) with autocast(enabled=params.use_fp16): - g_weight = adopt_weight( + g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -702,7 +707,7 @@ def compute_validation_loss( loss_info = MetricsTracker() loss_info["samples"] = batch_size - d_weight = adopt_weight( + d_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -735,7 +740,7 @@ def compute_validation_loss( for k, v in stats_d.items(): loss_info[k] = v * batch_size - g_weight = adopt_weight( + g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -845,7 +850,7 @@ def scan_pessimistic_batches_for_oom( + disc_period_fake_adv_loss + disc_scale_real_adv_loss + disc_scale_fake_adv_loss - ) * adopt_weight( + ) * train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_train_start, @@ -873,7 +878,7 @@ def scan_pessimistic_batches_for_oom( ) loss_g = ( (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * adopt_weight( + * train_discriminator( params.lambda_adv, 0, threshold=params.discriminator_epoch_start,