From b5414eb92b2bfe6702e4516247adb3ea22c7a327 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Wed, 16 Nov 2022 23:02:41 +0800 Subject: [PATCH] update model params --- configs/48k.json | 6 +++--- hifigan/hub/__init__.py | 2 +- test.py | 4 ++-- train.py | 5 +++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/configs/48k.json b/configs/48k.json index e7cf95a..1f85321 100644 --- a/configs/48k.json +++ b/configs/48k.json @@ -11,10 +11,10 @@ "eval_interval": 1000, "seed": 1234, "max_epochs": 20000, - "learning_rate": 2e-4, + "learning_rate": 8e-4, "betas": [0.8, 0.99], "eps": 1e-9, - "batch_size": 16, + "batch_size": 32, "fp16_run": true, "lr_decay": 0.999875, "segment_size": 16384, @@ -49,7 +49,7 @@ ], "upsample_rates": [8,8,4,2], "upsample_initial_channel": 512, - "upsample_kernel_sizes": [16,16,4,4], + "upsample_kernel_sizes": [16,16,8,4], "use_spectral_norm": false } } \ No newline at end of file diff --git a/hifigan/hub/__init__.py b/hifigan/hub/__init__.py index d77b83c..7c9ccf7 100644 --- a/hifigan/hub/__init__.py +++ b/hifigan/hub/__init__.py @@ -19,7 +19,7 @@ def hifigan_48k( ], upsample_rates=[8,8,4,2], upsample_initial_channel=512, - upsample_kernel_sizes=[16,16,4,4] + upsample_kernel_sizes=[16,16,8,4] ) if pretrained: checkpoint = torch.hub.load_state_dict_from_url( diff --git a/test.py b/test.py index 2369783..b990595 100644 --- a/test.py +++ b/test.py @@ -9,7 +9,7 @@ def load_local(): if os.path.exists("logs/lightning_logs"): versions = glob.glob("logs/lightning_logs/version_*") if len(list(versions)) > 0: - last_ver = sorted(list(versions))[-1] + last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1] last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt") if os.path.exists(last_ckpt): ckpt_path = last_ckpt @@ -32,7 +32,7 @@ def load_remote(): hifigan = load_local().to(device) # Load audio -wav, sr = torchaudio.load("test.wav") +wav, sr = torchaudio.load("zszy_48k.wav") assert sr == 48000 mel = mel_spectrogram_torch(wav, 2048, 256, 48000, 512, 2048, 0, None, False) diff --git a/train.py b/train.py index de6dc85..aa8f85f 100644 --- a/train.py +++ b/train.py @@ -47,8 +47,9 @@ def main(): valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=16, shuffle=False, pin_memory=True, collate_fn=collate_fn) model = HifiGAN(**hparams) + # model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs/version_10/checkpoints/last.ckpt", strict=False) - checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000) + checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False) devices = [int(n.strip()) for n in args.device.split(",")] trainer_params = { @@ -76,7 +77,7 @@ def main(): if os.path.exists("logs/lightning_logs"): versions = glob.glob("logs/lightning_logs/version_*") if len(list(versions)) > 0: - last_ver = sorted(list(versions))[-1] + last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1] last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt") if os.path.exists(last_ckpt): ckpt_path = last_ckpt