diff --git a/README.md b/README.md
index 54ade1d..b6e661e 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ class AudioPipeline(torch.nn.Module):
 
 device = "cpu"
 
-hifigan = torch.hub.load("vtuber-plan/hifi-gan:v0.2.1", "hifigan_48k", force_reload=True).to(device)
+hifigan = torch.hub.load("vtuber-plan/hifi-gan:v0.3.0", "hifigan_48k", force_reload=True).to(device)
 
 # Load audio
 wav, sr = torchaudio.load("test.wav")
diff --git a/hifigan/hub/__init__.py b/hifigan/hub/__init__.py
index 6d070d5..b2400e2 100644
--- a/hifigan/hub/__init__.py
+++ b/hifigan/hub/__init__.py
@@ -1,5 +1,5 @@
 CKPT_URLS = {
-    "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.2.1/hifigan-48k-C8FDBD55FE7700384955A6EC41AF1D84.pt",
+    "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.0/hifigan-48k-B67C217083569F978E07EFD1AD7B1766.pt",
 }
 import torch
 from ..model.generators.generator import Generator
diff --git a/hifigan/model/hifigan.py b/hifigan/model/hifigan.py
index f2c27e8..64208d1 100644
--- a/hifigan/model/hifigan.py
+++ b/hifigan/model/hifigan.py
@@ -48,11 +48,10 @@ def __init__(self, **kwargs):
             n_fft=self.hparams.data.filter_length,
             n_mel=self.hparams.data.n_mel_channels,
             win_length=self.hparams.data.win_length,
-            hop_length=self.hparams.data.hop_length,
-            aug=True)
+            hop_length=self.hparams.data.hop_length)
         for param in self.audio_pipeline.parameters():
             param.requires_grad = False
-        
+
         # metrics
         self.valid_mel_loss = torchmetrics.MeanMetric()
 
@@ -61,7 +60,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, optimize
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
 
         with torch.inference_mode():
-            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel = self.audio_pipeline(x_wav.squeeze(1), aug=True)
             x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         x_mel, ids_slice = rand_slice_segments(x_mel, x_mel_lengths, self.hparams.train.segment_size // self.hparams.data.hop_length)
@@ -175,7 +174,7 @@ def validation_step(self, batch, batch_idx):
         y_wav, y_wav_lengths = batch["y_wav_values"], batch["y_wav_lengths"]
 
         with torch.inference_mode():
-            x_mel = self.audio_pipeline(x_wav.squeeze(1))
+            x_mel = self.audio_pipeline(x_wav.squeeze(1), aug=False)
             x_mel_lengths = (x_wav_lengths / self.hparams.data.hop_length).long()
 
         y_spec = spectrogram_torch_audio(y_wav.squeeze(1),
diff --git a/hifigan/model/pipeline.py b/hifigan/model/pipeline.py
index 1ff0efe..ac2377e 100644
--- a/hifigan/model/pipeline.py
+++ b/hifigan/model/pipeline.py
@@ -17,13 +17,11 @@ def __init__(
         n_fft=1024,
         n_mel=128,
         win_length=1024,
-        hop_length=256,
-        aug=False
+        hop_length=256
     ):
         super().__init__()
 
         self.freq=freq
-        self.aug=aug
 
         pad = int((n_fft-hop_length)/2)
         self.spec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length,
@@ -32,18 +30,18 @@
         # self.strech = T.TimeStretch(hop_length=hop_length, n_freq=freq)
         self.spec_aug = torch.nn.Sequential(
             T.FrequencyMasking(freq_mask_param=80),
-            T.TimeMasking(time_mask_param=80),
+            # T.TimeMasking(time_mask_param=80),
         )
         self.mel_scale = T.MelScale(n_mels=n_mel, sample_rate=freq, n_stft=n_fft // 2 + 1)
 
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+    def forward(self, waveform: torch.Tensor, aug: bool=False) -> torch.Tensor:
         shift_waveform = waveform
 
         # Convert to power spectrogram
         spec = self.spec(shift_waveform)
         spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
         # Apply SpecAugment
-        if self.aug:
+        if aug:
             spec = self.spec_aug(spec)
         # Convert to mel-scale
         mel = self.mel_scale(spec)
diff --git a/train.py b/train.py
index c7cb8ea..024ca2e 100644
--- a/train.py
+++ b/train.py
@@ -54,8 +54,11 @@ def main():
 
     devices = [int(n.strip()) for n in args.device.split(",")]
 
-    checkpoint_callback = ModelCheckpoint(dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False)
-    earlystop_callback = EarlyStopping(monitor="valid/loss_mel_epoch", mode="min", patience=7)
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False,
+        monitor="valid/loss_mel_epoch", mode="min", save_top_k=3
+    )
+    earlystop_callback = EarlyStopping(monitor="valid/loss_mel_epoch", mode="min", patience=3)
 
     trainer_params = {
         "accelerator": args.accelerator,
@@ -84,7 +87,7 @@ def main():
         batch_per_gpu = hparams.train.batch_size // len(devices)
     else:
         batch_per_gpu = hparams.train.batch_size
-    train_loader = DataLoader(train_dataset, batch_size=batch_per_gpu, num_workers=4, shuffle=True, pin_memory=True, collate_fn=collate_fn)
+    train_loader = DataLoader(train_dataset, batch_size=batch_per_gpu, num_workers=8, shuffle=True, pin_memory=True, collate_fn=collate_fn)
     valid_loader = DataLoader(valid_dataset, batch_size=4, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn)
 
     # model
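The net effect of the `pipeline.py` and `hifigan.py` changes is that SpecAugment is now toggled per call (`aug=True` in `training_step`, `aug=False` in `validation_step`) rather than fixed at construction time, so one `AudioPipeline` instance serves both paths. A minimal usage sketch under that new signature; the `freq=48000` value and the `test.wav` input are illustrative assumptions, not taken from this patch:

```python
import torch
import torchaudio

from hifigan.model.pipeline import AudioPipeline

# Build the pipeline once; `aug` is no longer a constructor argument.
# freq=48000 is assumed here to match the 48 kHz checkpoint.
audio_pipeline = AudioPipeline(freq=48000, n_fft=1024, n_mel=128,
                               win_length=1024, hop_length=256)

wav, sr = torchaudio.load("test.wav")  # (channels, time) waveform

with torch.inference_mode():
    train_mel = audio_pipeline(wav, aug=True)   # frequency masking applied (training path)
    valid_mel = audio_pipeline(wav, aug=False)  # deterministic mels (validation path)
```

Note that after this change only `T.FrequencyMasking` is active in `spec_aug`; the `T.TimeMasking` transform is commented out, so `aug=True` applies frequency masking alone.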