diff --git a/amt/audio.py b/amt/audio.py index 7bb2a47..d0c1a16 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -14,7 +14,7 @@ # hard-coded audio hyperparameters config = load_config()["audio"] SAMPLE_RATE = config["sample_rate"] -N_FFT = config["n_fft"] +N_FFT = config["n_fft_large"] HOP_LENGTH = config["hop_len"] CHUNK_LENGTH = config["chunk_len"] N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk @@ -82,10 +82,12 @@ def __init__( self.config = load_config()["audio"] self.sample_rate = self.config["sample_rate"] self.chunk_len = self.config["chunk_len"] - self.n_fft = self.config["n_fft"] - self.n_fft_reduced = self.config["n_fft_reduced"] - self.n_mels = self.config["n_mels"] - self.n_mels_reduced = self.config["n_mels_reduced"] + self.n_fft_large = self.config["n_fft_large"] + self.n_fft_med = self.config["n_fft_med"] + self.n_fft_small = self.config["n_fft_small"] + self.n_mels_large = self.config["n_mels_large"] + self.n_mels_med = self.config["n_mels_med"] + self.n_mels_small = self.config["n_mels_small"] self.num_samples = self.sample_rate * self.chunk_len self.noise_ratio = noise_ratio @@ -129,28 +131,39 @@ def __init__( self.register_buffer(f"applause_{i}", applause) self.num_applause += 1 - self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, + self.spec_transform_large = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_large, hop_length=self.config["hop_len"], ) - self.mel_transform = torchaudio.transforms.MelScale( - n_mels=self.n_mels, + self.mel_transform_large = torchaudio.transforms.MelScale( + n_mels=self.n_mels_large, sample_rate=self.sample_rate, - n_stft=self.n_fft // 2 + 1, + n_stft=self.n_fft_large // 2 + 1, f_min=30, f_max=8000, ) - self.spec_transform_reduced = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft_reduced, + self.spec_transform_med = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_med, hop_length=self.config["hop_len"], ) - self.mel_transform_reduced = torchaudio.transforms.MelScale( - n_mels=self.n_mels_reduced, + self.mel_transform_med = torchaudio.transforms.MelScale( + n_mels=self.n_mels_med, sample_rate=self.sample_rate, - n_stft=self.n_fft_reduced // 2 + 1, + n_stft=self.n_fft_med // 2 + 1, f_min=30, f_max=8000, ) + self.spec_transform_small = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_small, + hop_length=self.config["hop_len"], + ) + self.mel_transform_small = torchaudio.transforms.MelScale( + n_mels=self.n_mels_small, + sample_rate=self.sample_rate, + n_stft=self.n_fft_small // 2 + 1, + f_min=30, + f_max=4000, + ) self.spec_aug = torch.nn.Sequential( torchaudio.transforms.TimeMasking( time_mask_param=self.time_mask_param, @@ -370,28 +383,43 @@ def norm_mel(self, mel_spec: torch.Tensor): def log_mel( self, wav: torch.Tensor, shift: int | None = None, detune: bool = False ): - spec = self.spec_transform(wav)[..., :-1] - spec_reduced = self.spec_transform_reduced(wav)[..., :-1] + spec_large = self.spec_transform_large(wav)[..., :-1] + spec_med = self.spec_transform_med(wav)[..., :-1] + spec_small = self.spec_transform_small(wav)[..., :-1] if shift is not None and shift != 0: - spec = self.shift_spec(spec, shift) - spec_reduced = self.shift_spec(spec_reduced, shift) + spec_large = self.shift_spec(spec_large, shift) + spec_med = self.shift_spec(spec_med, shift) + spec_small = self.shift_spec(spec_small, shift) elif detune is True: # Don't detune and spec shift at the same time if random.random() < self.detune_ratio: detune_shift = random.uniform( -self.detune_max_shift, self.detune_max_shift ) - spec = self.detune_spec(spec, detune_shift=detune_shift) - spec_reduced = self.detune_spec( - spec_reduced, detune_shift=detune_shift + spec_large = self.detune_spec( + spec_large, + detune_shift=detune_shift, + ) + spec_med = self.detune_spec( + spec_med, + detune_shift=detune_shift, + ) + spec_small = self.detune_spec( + spec_small, + detune_shift=detune_shift, ) - mel_spec = self.mel_transform(spec) - mel_spec_reduced = self.mel_transform_reduced(spec_reduced) + mel_spec_large = self.mel_transform_large(spec_large) + mel_spec_med = self.mel_transform_med(spec_med) + mel_spec_small = self.mel_transform_small(spec_small) # Norm - concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1) + concat_mel = torch.cat( + (mel_spec_large, mel_spec_med, mel_spec_small), + # (mel_spec_large, mel_spec_small), + dim=1, + ) log_mel = self.norm_mel(concat_mel) return log_mel diff --git a/config/config.json b/config/config.json index 7503777..23c56d9 100644 --- a/config/config.json +++ b/config/config.json @@ -11,12 +11,14 @@ }, "audio": { "sample_rate": 16000, - "n_fft": 2048, - "n_fft_reduced": 800, + "n_fft_large": 4096, + "n_fft_med": 2048, + "n_fft_small": 768, "hop_len": 160, "chunk_len": 30, - "n_mels": 384, - "n_mels_reduced": 128 + "n_mels_large": 384, + "n_mels_med": 256, + "n_mels_small": 128 }, "data": { "stride_factor": 15, diff --git a/config/models/medium-triple.json b/config/models/medium-triple.json new file mode 100644 index 0000000..5f92924 --- /dev/null +++ b/config/models/medium-triple.json @@ -0,0 +1,11 @@ +{ + "n_mels": 768, + "n_audio_ctx": 1500, + "n_audio_state": 768, + "n_audio_head": 12, + "n_audio_layer": 4, + "n_text_ctx": 4096, + "n_text_state": 768, + "n_text_head": 12, + "n_text_layer": 4 +} \ No newline at end of file