
Commit

v0.3.0
jstzwj committed Nov 22, 2022
1 parent 2900a6d commit a76cbc1
Showing 5 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ class AudioPipeline(torch.nn.Module):
 
 device = "cpu"
 
-hifigan = torch.hub.load("vtuber-plan/hifi-gan:v0.2.1", "hifigan_48k", force_reload=True).to(device)
+hifigan = torch.hub.load("vtuber-plan/hifi-gan:v0.3.0", "hifigan_48k", force_reload=True).to(device)
 
 # Load audio
 wav, sr = torchaudio.load("test.wav")
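For orientation, a minimal sketch of how the bumped hub tag might be exercised end to end. It assumes the AudioPipeline helper shown earlier in the README, a generator that maps a [batch, n_mel, frames] mel tensor to a waveform, and illustrative mel settings rather than the repository's published 48 kHz config:

import torch
import torchaudio

device = "cpu"

# Load the released 48 kHz vocoder from the updated tag.
hifigan = torch.hub.load("vtuber-plan/hifi-gan:v0.3.0", "hifigan_48k", force_reload=True).to(device)
hifigan.eval()

# Load audio and resample to the assumed model sampling rate (48 kHz).
wav, sr = torchaudio.load("test.wav")                 # [channels, samples]
wav = torchaudio.functional.resample(wav, sr, 48000)

# AudioPipeline is the mel front end defined earlier in the README;
# these constructor values are placeholders, not the exact 48 kHz config.
pipeline = AudioPipeline(freq=48000, n_fft=1024, n_mel=128,
                         win_length=1024, hop_length=256).to(device)

with torch.inference_mode():
    mel = pipeline(wav)        # waveform -> mel spectrogram, [channels, n_mel, frames]
    out_wav = hifigan(mel)     # mel spectrogram -> reconstructed waveform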
2 changes: 1 addition & 1 deletion hifigan/hub/__init__.py
@@ -1,5 +1,5 @@
 CKPT_URLS = {
-    "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.2.1/hifigan-48k-C8FDBD55FE7700384955A6EC41AF1D84.pt",
+    "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.0/hifigan-48k-B67C217083569F978E07EFD1AD7B1766.pt",
 }
 import torch
 from ..model.generators.generator import Generator
9 changes: 4 additions & 5 deletions hifigan/model/hifigan.py
@@ -48,11 +48,10 @@ def __init__(self, **kwargs):
             n_fft=self.hparams.data.filter_length,
             n_mel=self.hparams.data.n_mel_channels,
             win_length=self.hparams.data.win_length,
-            hop_length=self.hparams.data.hop_length,
-            aug=True)
+            hop_length=self.hparams.data.hop_length)
         for param in self.audio_pipeline.parameters():
             param.requires_grad = False
 
         # metrics
         self.valid_mel_loss = torchmetrics.MeanMetric()
 
@@ -61,7 +60,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimize
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
 
         with torch.inference_mode():
-            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel = self.audio_pipeline(x_wav.squeeze(1), aug=True)
             x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         x_mel, ids_slice = rand_slice_segments(x_mel, x_mel_lengths, self.hparams.train.segment_size // self.hparams.data.hop_length)
@@ -175,7 +174,7 @@ def validation_step(self, batch, batch_idx):
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
 
         with torch.inference_mode():
-            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel = self.audio_pipeline(x_wav.squeeze(1), aug=False)
             x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         y_spec = spectrogram_torch_audio(y_wav.squeeze(1),
10 changes: 4 additions & 6 deletions hifigan/model/pipeline.py
@@ -17,13 +17,11 @@ def __init__(
         n_fft=1024,
         n_mel=128,
         win_length=1024,
-        hop_length=256,
-        aug=False
+        hop_length=256
     ):
         super().__init__()
 
         self.freq=freq
-        self.aug=aug
 
         pad = int((n_fft-hop_length)/2)
         self.spec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length,
@@ -32,18 +30,18 @@ def __init__(
         # self.strech = T.TimeStretch(hop_length=hop_length, n_freq=freq)
         self.spec_aug = torch.nn.Sequential(
             T.FrequencyMasking(freq_mask_param=80),
-            T.TimeMasking(time_mask_param=80),
+            # T.TimeMasking(time_mask_param=80),
         )
 
         self.mel_scale = T.MelScale(n_mels=n_mel, sample_rate=freq, n_stft=n_fft // 2 + 1)
 
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+    def forward(self, waveform: torch.Tensor, aug: bool=False) -> torch.Tensor:
         shift_waveform = waveform
         # Convert to power spectrogram
         spec = self.spec(shift_waveform)
         spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
         # Apply SpecAugment
-        if self.aug:
+        if aug:
             spec = self.spec_aug(spec)
         # Convert to mel-scale
         mel = self.mel_scale(spec)
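A quick sketch of the new per-call switch: the same pipeline instance can now emit augmented mels for training and clean mels for validation. The import path and class name are assumed from the file layout, and only the constructor values shown above are taken from the defaults (freq is an illustrative sampling rate):

import torch
from hifigan.model.pipeline import AudioPipeline  # assumed class name / import path

# Illustrative settings; n_fft, n_mel, win_length and hop_length mirror the defaults above.
pipeline = AudioPipeline(freq=16000, n_fft=1024, n_mel=128, win_length=1024, hop_length=256)

wav = torch.randn(2, 16000)              # dummy [batch, samples] waveform

train_mel = pipeline(wav, aug=True)      # SpecAugment (frequency masking) applied
valid_mel = pipeline(wav, aug=False)     # clean mel, same call site otherwise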
9 changes: 6 additions & 3 deletions train.py
@@ -54,8 +54,11 @@ def main():
 
     devices = [int(n.strip()) for n in args.device.split(",")]
 
-    checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False)
-    earlystop_callback = EarlyStopping(monitor="valid/loss_mel_epoch", mode="min", patience=7)
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False,
+        monitor="valid/loss_mel_epoch", mode="min", save_top_k=3
+    )
+    earlystop_callback = EarlyStopping(monitor="valid/loss_mel_epoch", mode="min", patience=3)
 
     trainer_params = {
         "accelerator": args.accelerator,
@@ -84,7 +87,7 @@ def main():
         batch_per_gpu = hparams.train.batch_size // len(devices)
     else:
         batch_per_gpu = hparams.train.batch_size
-    train_loader = DataLoader(train_dataset, batch_size=batch_per_gpu, num_workers=4, shuffle=True, pin_memory=True, collate_fn=collate_fn)
+    train_loader = DataLoader(train_dataset, batch_size=batch_per_gpu, num_workers=8, shuffle=True, pin_memory=True, collate_fn=collate_fn)
     valid_loader = DataLoader(valid_dataset, batch_size=4, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn)
 
     # model
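One note on the monitored key: both callbacks watch "valid/loss_mel_epoch", so the LightningModule must log an epoch-level mel loss under that name. A hedged sketch of a logging call that would produce such a key (the actual logging in hifigan.py may differ; loss_mel is a hypothetical tensor):

# Inside validation_step: with on_step=True and on_epoch=True, Lightning logs
# both "valid/loss_mel_step" and "valid/loss_mel_epoch"; the latter is the key
# monitored by ModelCheckpoint and EarlyStopping above.
self.log("valid/loss_mel", loss_mel, on_step=True, on_epoch=True, sync_dist=True)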
