From 090849619cf314ce3f39826c7c6e6f120fb0b440 Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Thu, 17 Nov 2022 01:20:25 +0800
Subject: [PATCH] update weights loader

---
 hifigan/utils.py | 22 ++++++++++++++++++++--
 save_state.py    | 14 +++++++++-----
 train.py         |  6 +++++-
 3 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/hifigan/utils.py b/hifigan/utils.py
index cb54b9d..92239da 100644
--- a/hifigan/utils.py
+++ b/hifigan/utils.py
@@ -3,7 +3,8 @@
 import sys
+import logging
 
 import torch
 import torchaudio
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 MATPLOTLIB_FLAG = False
@@ -55,4 +55,21 @@ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios=
     for k, v in images.items():
         writer.add_image(k, v, global_step, dataformats='HWC')
     for k, v in audios.items():
-        writer.add_audio(k, v, global_step, audio_sampling_rate)
\ No newline at end of file
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def load_state_dict(model_state_dict, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+    # Filter a loaded state dict against the model's own: drop parameters
+    # whose shapes do not match and parameters the model does not have.
+    for k in list(state_dict.keys()):  # snapshot the keys; entries are deleted below
+        if k in model_state_dict:
+            if state_dict[k].shape != model_state_dict[k].shape:
+                logging.info(f"Skip loading parameter: {k}, "
+                             f"required shape: {model_state_dict[k].shape}, "
+                             f"loaded shape: {state_dict[k].shape}")
+                # state_dict[k] = model_state_dict[k]
+                del state_dict[k]
+        else:
+            logging.info(f"Dropping parameter {k}")
+            del state_dict[k]
+    return state_dict
\ No newline at end of file
diff --git a/save_state.py b/save_state.py
index ddfc8b7..6bb0904 100644
--- a/save_state.py
+++ b/save_state.py
@@ -2,6 +2,14 @@
 import glob
 import torch
 from hifigan.model.hifigan import HifiGAN
+
+def save(ckpt_path: str):
+    model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=True)
+    # print(model.net_g.state_dict())
+    torch.save(model.net_g.state_dict(), "net_g.pt")
+    torch.save(model.net_period_d.state_dict(), "net_period_d.pt")
+    torch.save(model.net_scale_d.state_dict(), "net_scale_d.pt")
+
 def main():
     ckpt_path = None
     if os.path.exists("logs/lightning_logs"):
@@ -13,10 +21,6 @@
             ckpt_path = last_ckpt
 
     print(ckpt_path)
-
-    model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
-
-    torch.save(model.net_g.state_dict(), "out.pt")
-
+    save("logs/lightning_logs_v1/version_8/checkpoints/last.ckpt")
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/train.py b/train.py
index aa8f85f..37e5fc8 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@
 from hifigan.hparams import HParams
 from hifigan.data.dataset import MelDataset, MelDataset
+from hifigan.utils import load_state_dict
 
 
 def get_hparams(config_path: str) -> HParams:
     with open(config_path, "r") as f:
@@ -47,7 +48,10 @@ def main():
     valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=16, shuffle=False, pin_memory=True, collate_fn=collate_fn)
 
     model = HifiGAN(**hparams)
-    # model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs/version_10/checkpoints/last.ckpt", strict=False)
+    # model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs_v1/version_8/checkpoints/last.ckpt", strict=True)
+    # model.net_g.load_state_dict(torch.load("net_g.pt"), strict=False)
+    # model.net_period_d.load_state_dict(torch.load("net_period_d.pt"), strict=False)
+    # model.net_scale_d.load_state_dict(torch.load("net_scale_d.pt"), strict=False)
 
     checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True,
                                           every_n_train_steps=2000, save_weights_only=False)
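
Usage note: a minimal sketch (not applied by this patch) of how the new
load_state_dict() helper in hifigan/utils.py is meant to combine with the
weights exported by save_state.py when warm-starting train.py. The names
HifiGAN, hparams, net_g, and "net_g.pt" come from the diff itself; the
surrounding context is an assumption:

    import torch

    from hifigan.model.hifigan import HifiGAN
    from hifigan.utils import load_state_dict

    model = HifiGAN(**hparams)
    # Filter the exported generator weights against the freshly built model,
    # then load whatever survives; strict=False tolerates the dropped keys.
    filtered = load_state_dict(model.net_g.state_dict(), torch.load("net_g.pt"))
    model.net_g.load_state_dict(filtered, strict=False)

The same two lines apply to net_period_d / "net_period_d.pt" and
net_scale_d / "net_scale_d.pt".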