Skip to content

Commit

Permalink
update weights loader
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Nov 16, 2022
1 parent b5414eb commit 0908496
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
21 changes: 19 additions & 2 deletions hifigan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import torch
import torchaudio
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

MATPLOTLIB_FLAG = False

Expand Down Expand Up @@ -55,4 +55,21 @@ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios=
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats='HWC')
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
writer.add_audio(k, v, global_step, audio_sampling_rate)


def load_state_dict(model_state_dict, state_dict: Dict[str, Any]) -> None:
is_changed = False
for k in state_dict:
if k in model_state_dict:
if state_dict[k].shape != model_state_dict[k].shape:
logging.info(f"Skip loading parameter: {k}, "
f"required shape: {model_state_dict[k].shape}, "
f"loaded shape: {state_dict[k].shape}")
# state_dict[k] = model_state_dict[k]
del state_dict[k]
is_changed = True
else:
logging.info(f"Dropping parameter {k}")
is_changed = True
return state_dict
14 changes: 9 additions & 5 deletions save_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
import glob
import torch
from hifigan.model.hifigan import HifiGAN

def save(ckpt_path: str):
model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=True)
# print(model.net_g.state_dict())
torch.save(model.net_g.state_dict(), "net_g.pt")
torch.save(model.net_period_d.state_dict(), "net_period_d.pt")
torch.save(model.net_scale_d.state_dict(), "net_scale_d.pt")

def main():
ckpt_path = None
if os.path.exists("logs/lightning_logs"):
Expand All @@ -13,10 +21,6 @@ def main():
ckpt_path = last_ckpt

print(ckpt_path)

model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)

torch.save(model.net_g.state_dict(), "out.pt")

save("logs/lightning_logs_v1/version_8/checkpoints/last.ckpt")
if __name__ == "__main__":
main()
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from hifigan.hparams import HParams
from hifigan.data.dataset import MelDataset, MelDataset
from hifigan.utils import load_state_dict

def get_hparams(config_path: str) -> HParams:
with open(config_path, "r") as f:
Expand Down Expand Up @@ -47,7 +48,10 @@ def main():
valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=16, shuffle=False, pin_memory=True, collate_fn=collate_fn)

model = HifiGAN(**hparams)
# model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs/version_10/checkpoints/last.ckpt", strict=False)
# model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs_v1/version_8/checkpoints/last.ckpt", strict=True)
# model.net_g._load_from_state_dict(torch.load("net_g.pt"), strict=False)
# model.net_period_d._load_from_state_dict(torch.load("net_period_d.pt"), strict=False)
# model.net_scale_d._load_from_state_dict(torch.load("net_scale_d.pt"), strict=False)

checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False)

Expand Down

0 comments on commit 0908496

Please sign in to comment.