Skip to content

Commit

Permalink
update model params
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Nov 16, 2022
1 parent ac52d28 commit b5414eb
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
6 changes: 3 additions & 3 deletions configs/48k.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
"eval_interval": 1000,
"seed": 1234,
"max_epochs": 20000,
"learning_rate": 2e-4,
"learning_rate": 8e-4,
"betas": [0.8, 0.99],
"eps": 1e-9,
"batch_size": 16,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 16384,
Expand Down Expand Up @@ -49,7 +49,7 @@
],
"upsample_rates": [8,8,4,2],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16,4,4],
"upsample_kernel_sizes": [16,16,8,4],
"use_spectral_norm": false
}
}
2 changes: 1 addition & 1 deletion hifigan/hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def hifigan_48k(
],
upsample_rates=[8,8,4,2],
upsample_initial_channel=512,
upsample_kernel_sizes=[16,16,4,4]
upsample_kernel_sizes=[16,16,8,4]
)
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
Expand Down
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def load_local():
if os.path.exists("logs/lightning_logs"):
versions = glob.glob("logs/lightning_logs/version_*")
if len(list(versions)) > 0:
last_ver = sorted(list(versions))[-1]
last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1]
last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt")
if os.path.exists(last_ckpt):
ckpt_path = last_ckpt
Expand All @@ -32,7 +32,7 @@ def load_remote():
hifigan = load_local().to(device)

# Load audio
wav, sr = torchaudio.load("test.wav")
wav, sr = torchaudio.load("zszy_48k.wav")
assert sr == 48000

mel = mel_spectrogram_torch(wav, 2048, 256, 48000, 512, 2048, 0, None, False)
Expand Down
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def main():
valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=16, shuffle=False, pin_memory=True, collate_fn=collate_fn)

model = HifiGAN(**hparams)
# model = HifiGAN.load_from_checkpoint(checkpoint_path="logs/lightning_logs/version_10/checkpoints/last.ckpt", strict=False)

checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000)
checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False)

devices = [int(n.strip()) for n in args.device.split(",")]
trainer_params = {
Expand Down Expand Up @@ -76,7 +77,7 @@ def main():
if os.path.exists("logs/lightning_logs"):
versions = glob.glob("logs/lightning_logs/version_*")
if len(list(versions)) > 0:
last_ver = sorted(list(versions))[-1]
last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1]
last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt")
if os.path.exists(last_ckpt):
ckpt_path = last_ckpt
Expand Down

0 comments on commit b5414eb

Please sign in to comment.