diff --git a/configs/config.json b/configs/config.json index 650bbf5..45e6db1 100644 --- a/configs/config.json +++ b/configs/config.json @@ -35,9 +35,9 @@ "filter_length": 1024, "hop_length": 256, "win_length": 1024, - "n_mel_channels": 128, - "mel_fmin": 0.0, - "mel_fmax": null, + "n_mel_channels": 80, + "mel_fmin": 0, + "mel_fmax": 8000, "min_file_length": 0.3, "max_file_length": 10.0, "unit_interpolate_mode": "nearest" diff --git a/data_utils.py b/data_utils.py index f5f778a..7098cfb 100644 --- a/data_utils.py +++ b/data_utils.py @@ -35,13 +35,11 @@ def __init__(self, audiopaths, hparams, all_in_mem: bool = False): self.sampling_rate = hparams.data.sampling_rate self.use_sr = hparams.train.use_sr self.spec_len = hparams.train.max_speclen - # self.spk_map = hparams.spk self.num_mels = hparams.data.n_mel_channels self.mel_fmin = hparams.data.mel_fmin self.mel_fmax = hparams.data.mel_fmax self.min_file_length = hparams.data.min_file_length * self.sampling_rate self.max_file_length = hparams.data.max_file_length * self.sampling_rate - # self.spk_map_inv = {v: k for k, v in self.spk_map.items()} random.seed(1234) random.shuffle(self.audiopaths) @@ -56,7 +54,11 @@ def _filter_long_files(self, audio_paths): filtered = [] for p, speaker in audio_paths: - if self.min_file_length <(Path(p).stat().st_size // 2) < self.max_file_length: + if ( + self.min_file_length + < (Path(p).stat().st_size // 2) + < self.max_file_length + ): filtered.append([p, speaker]) print("Audiopaths before filtering:", len(audio_paths)) @@ -90,16 +92,19 @@ def get_audio(self, filename): ) spec = torch.squeeze(spec, 0) - # load ppgs - ppg_path = filename.replace(".wav", ".ppg.pt") - ppg = torch.load(ppg_path) - # load f0 and uv f0_path = filename.replace(".wav", ".rmvpe.pt") loaded_data = torch.load(f0_path) f0 = loaded_data["f0"].unsqueeze(0) uv = loaded_data["uv"] + # load ppgs + ppg_path = filename.replace(".wav", ".ppg.pt") + ppg = torch.load(ppg_path) + ppg = utils.repeat_expand_2d( + ppg.squeeze(0), f0.shape[1], mode=self.unit_interpolate_mode + ) + # load hubert hubert_path = filename.replace(".wav", ".soft.pt") c = torch.load(hubert_path) diff --git a/inference/infer_tool.py b/inference/infer_tool.py index 0a92104..e96b627 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -4,26 +4,22 @@ import json import logging import os -import pickle import time from pathlib import Path import librosa import numpy as np +import ppgs # import onnxruntime import soundfile import torch import torchaudio -import cluster import utils -from inference import slicer from models import SynthesizerTrn -from modules.mel_processing import mel_spectrogram_torch # from models_cf import SynthesizerTrn -from modules.speaker_encoder import ResNetSpeakerEncoder logging.getLogger("matplotlib").setLevel(logging.WARNING) @@ -132,12 +128,8 @@ def __init__( net_g_path, config_path, device=None, - cluster_model_path="logs/44k/kmeans_10000.pt", - speaker_encoder_path="logs/44k/speaker_encoder.pt", - feature_retrieval=False, ): self.net_g_path = net_g_path - self.feature_retrieval = feature_retrieval if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: @@ -152,43 +144,18 @@ def __init__( if self.hps_ms.data.unit_interpolate_mode is not None else "left" ) - self.vol_embedding = ( - self.hps_ms.model.vol_embedding - if self.hps_ms.model.vol_embedding is not None - else False - ) + + # contentvec encoder self.speech_encoder = ( self.hps_ms.model.speech_encoder if 
self.hps_ms.model.speech_encoder is not None else "vec768l12" ) - # load hubert and model self.load_model() self.hubert_model = utils.get_speech_encoder( self.speech_encoder, device=self.dev ) - self.volume_extractor = utils.Volume_Extractor(self.hop_size) - - if os.path.exists(cluster_model_path): - if self.feature_retrieval: - with open(cluster_model_path, "rb") as f: - self.cluster_model = pickle.load(f) - self.big_npy = None - else: - self.cluster_model = cluster.get_cluster_model(cluster_model_path) - else: - self.feature_retrieval = False - - self.speaker_encoder = ResNetSpeakerEncoder( - input_dim=80, proj_dim=512, log_input=True - ) - checkpoint = torch.load( - speaker_encoder_path, - map_location="cpu", - ) - self.speaker_encoder.load_state_dict(checkpoint) - self.speaker_encoder.eval() if not hasattr(self, "audio16k_resample_transform"): self.audio16k_resample_transform = torchaudio.transforms.Resample( @@ -199,8 +166,6 @@ def load_model(self): # get model configuration self.net_g_ms = SynthesizerTrn( self.hps_ms.data.n_mel_channels, - self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, - num_mel_channels=self.hps_ms.data.n_mel_channels, **self.hps_ms.model, ) _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None) @@ -245,18 +210,6 @@ def get_unit_f0( wav = torch.from_numpy(wav).unsqueeze(0).to(self.dev) - # compute energy - energy = utils.audio_to_energy( - wav, - filter_length=self.hps_ms.data.filter_length, - n_mel_channels=self.hps_ms.data.n_mel_channels, - hop_length=self.hps_ms.data.hop_length, - win_length=self.hps_ms.data.win_length, - sampling_rate=self.hps_ms.data.sampling_rate, - mel_fmin=self.hps_ms.data.mel_fmin, - mel_fmax=self.hps_ms.data.mel_fmax, - ) - wav16k = self.audio16k_resample_transform(wav)[0] c = self.hubert_model.encoder(wav16k) @@ -279,16 +232,14 @@ def get_unit_f0( c = c.unsqueeze(0) f0 = f0.unsqueeze(0) - energy = energy.unsqueeze(0) - return c, f0, uv, energy + return c, f0, uv def infer( self, tran, raw_path, target_audio_path, - target_speaker, cluster_infer_ratio=0, n_timesteps=2, f0_filter=False, @@ -311,25 +262,6 @@ def infer( # resample to target sample rate wav = self.audio_resample_transform(wav).squeeze(0).numpy() - # wav_tgt = self.audio_resample_transform(wav_tgt) - - # speaker_embeddings = glob(f"/mnt/datasets/VC_Dataset/{speaker}/*.emb.pt")[ - # :20 - # ] - # speaker_embeddings = [ - # torch.FloatTensor(torch.load(speaker_embedding)) - # for speaker_embedding in speaker_embeddings - # ] - # sid = torch.mean(torch.stack(speaker_embeddings), dim=0).to(self.dev) - - speaker_embeddings = [] - for f in target_audio_path: - speaker_embeddings.append( - self.speaker_encoder.compute_embedding(f).to(self.dev) - ) - - sid = torch.mean(torch.stack(speaker_embeddings), dim=0).to(self.dev) - # get the root path of the file c_targets = [] mels = [] @@ -363,11 +295,11 @@ def infer( mels.append(mel_spec_tgt) mel_lengths.append(torch.LongTensor([mel_spec_tgt.shape[2]]).to(self.dev)) - style_cond = self.net_g_ms.compute_conditional_latent(mels, mel_lengths, sid) + # compute cond latent and speaker embedding + speaker_embedding = self.net_g_ms.compute_conditional_latent(mels, mel_lengths) - # sid = self.avg_speaker_embeddings[speaker].unsqueeze(0).to(self.dev) - # sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0) - c, f0, uv, energy = self.get_unit_f0( + # get contentvec, f0, uv and ppg + c, f0, uv = self.get_unit_f0( wav, c_targets, tran, @@ -377,18 +309,25 @@ def infer( cr_threshold=cr_threshold, ) + # Infer PPGs + 
audio = ppgs.load.audio(raw_path) + ppg = ppgs.from_audio(audio, ppgs.SAMPLE_RATE, gpu=0).to(self.dev) + ppg = utils.repeat_expand_2d( + ppg.squeeze(0), f0.shape[-1], self.unit_interpolate_mode + ).unsqueeze(0) + c = c.to(self.dtype) f0 = f0.to(self.dtype) uv = uv.to(self.dtype) + ppg = ppg.to(self.dtype) with torch.no_grad(): o, _ = self.net_g_ms.vc( c, - style_cond=style_cond, f0=f0, - g=sid, uv=uv, - energy=energy, + ppgs=ppg, + g=speaker_embedding, n_timesteps=n_timesteps, temperature=temperature, guidance_scale=guidance_scale, @@ -405,21 +344,16 @@ def unload_model(self): # unload model self.net_g_ms = self.net_g_ms.to("cpu") del self.net_g_ms - if hasattr(self, "enhancer"): - self.enhancer.enhancer = self.enhancer.enhancer.to("cpu") - del self.enhancer.enhancer - del self.enhancer gc.collect() def slice_inference( self, raw_audio_path, raw_target_audio_path, - target_speaker, tran, cluster_infer_ratio, n_timesteps=2, - f0_predictor="pm", + f0_predictor="rmvpe", cr_threshold=0.05, temperature=1.0, guidance_scale=0.0, @@ -429,7 +363,6 @@ def slice_inference( tran, raw_audio_path, target_audio_path=raw_target_audio_path, - target_speaker=target_speaker, cluster_infer_ratio=cluster_infer_ratio, n_timesteps=n_timesteps, f0_predictor=f0_predictor, @@ -441,81 +374,6 @@ def slice_inference( return out_audio - # global_frame = 0 - # audio = [] - # for slice_tag, data in audio_data: - # # padd - # length = int(np.ceil(len(data) / audio_sr * self.target_sample)) - # if slice_tag: - # _audio = np.zeros(length) - # audio.extend(list(pad_array(_audio, length))) - # global_frame += length // self.hop_size - # continue - # if per_size != 0: - # datas = split_list_by_n(data, per_size, lg_size) - # else: - # datas = [data] - # for k, dat in enumerate(datas): - # per_length = ( - # int(np.ceil(len(dat) / audio_sr * self.target_sample)) - # if clip_seconds != 0 - # else length - # ) - # if clip_seconds != 0: - # print( - # f"###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======" - # ) - # # padd - # pad_len = int(audio_sr * pad_seconds) - # dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) - # raw_path = io.BytesIO() - # soundfile.write(raw_path, dat, audio_sr, format="wav") - # raw_path.seek(0) - - # out_audio, out_sr, out_frame = self.infer( - # tran, - # raw_path, - # target_audio_path=raw_target_audio_path, - # cluster_infer_ratio=cluster_infer_ratio, - # auto_predict_f0=auto_predict_f0, - # noice_scale=noice_scale, - # f0_predictor=f0_predictor, - # cr_threshold=cr_threshold, - # f0_adain_alpha=f0_adain_alpha, - # loudness_envelope_adjustment=loudness_envelope_adjustment, - # ) - - # global_frame += out_frame - # _audio = out_audio.cpu().numpy() - # pad_len = int(self.target_sample * pad_seconds) - # _audio = _audio[pad_len:-pad_len] - # _audio = pad_array(_audio, per_length) - # if lg_size != 0 and k != 0: - # lg1 = ( - # audio[-(lg_size_r + lg_size_c_r) : -lg_size_c_r] - # if lgr_num != 1 - # else audio[-lg_size:] - # ) - # lg2 = ( - # _audio[lg_size_c_l : lg_size_c_l + lg_size_r] - # if lgr_num != 1 - # else _audio[0:lg_size] - # ) - # lg_pre = lg1 * (1 - lg) + lg2 * lg - # audio = ( - # audio[0 : -(lg_size_r + lg_size_c_r)] - # if lgr_num != 1 - # else audio[0:-lg_size] - # ) - # audio.extend(lg_pre) - # _audio = ( - # _audio[lg_size_c_l + lg_size_r :] - # if lgr_num != 1 - # else _audio[lg_size:] - # ) - # audio.extend(list(_audio)) - # return np.array(audio) - class RealTimeVC: def __init__(self): diff --git a/models.py b/models.py index 18920cd..8d71c2e 
100644 --- a/models.py +++ b/models.py @@ -23,7 +23,7 @@ def __init__( n_layers, filter_channels=None, n_heads=None, - p_dropout=None, + p_dropout=0.0, utt_emb_dim=0, ): super().__init__() @@ -51,7 +51,7 @@ def __init__( n_layers=6, n_heads=2, p_dropout=0.1, - utt_emb_dim=512, + utt_emb_dim=utt_emb_dim, ) # ppg decoder @@ -63,14 +63,14 @@ def __init__( n_layers=6, n_heads=2, p_dropout=0.1, - utt_emb_dim=512, + utt_emb_dim=utt_emb_dim, ) - # decoder - self.encoder = attentions.Encoder( + # encoder + self.encoder = attentions.Decoder( hidden_channels=hidden_channels, filter_channels=filter_channels, - n_heads=2, + n_heads=n_heads, n_layers=6, kernel_size=3, p_dropout=p_dropout, @@ -90,8 +90,6 @@ def forward( self, x, x_mask, - cond=None, - cond_mask=None, f0=None, uv=None, ppgs=None, @@ -104,20 +102,20 @@ def forward( x = x + self.uv_emb(uv.long()).transpose(1, 2) # ppg decoder - ppg = self.ppg_decoder(x, x_mask, ppgs, cond, cond_mask, utt_emb) + ppg_pred = self.ppg_decoder(x, x_mask, ppgs, utt_emb) # pitch lf0 = 2595.0 * torch.log10(1.0 + f0 / 700.0) / 500 f0_norm = normalize_f0(lf0, x_mask, uv) - f0_pred = self.f0_decoder(x, x_mask, f0_norm, cond, cond_mask, utt_emb) + f0_pred = self.f0_decoder(x, x_mask, f0_norm, utt_emb) f0 = f0_to_coarse(f0.squeeze(1)) f0_emb = self.f0_emb(f0, x_mask, utt_emb) # add f0 and ppg to x - x = x + f0_emb + ppg + aux_embeddings = +f0_emb + ppg_pred # encode prosodic features - x = self.encoder(x, x_mask, utt_emb) + x = self.encoder(x, x_mask, aux_embeddings, x_mask, utt_emb) # # project to mu mu = self.proj_m(x) * x_mask @@ -128,8 +126,6 @@ def vc( self, x, x_mask, - cond=None, - cond_mask=None, f0=None, uv=None, ppgs=None, @@ -142,21 +138,21 @@ def vc( x = x + self.uv_emb(uv.long()).transpose(1, 2) # ppg decoder - ppg = self.ppg_decoder(x, x_mask, ppgs, cond, cond_mask, utt_emb) + ppg_pred = self.ppg_decoder(x, x_mask, ppgs, utt_emb) # pitch lf0 = 2595.0 * torch.log10(1.0 + f0 / 700.0) / 500 f0_norm = normalize_f0(lf0, x_mask, uv) - f0_pred = self.f0_decoder(x, x_mask, f0_norm, cond, cond_mask, utt_emb) + f0_pred = self.f0_decoder(x, x_mask, f0_norm, utt_emb) f0 = (700 * (torch.pow(10, f0_pred * 500 / 2595) - 1)).squeeze(1) f0 = f0_to_coarse(f0) f0_emb = self.f0_emb(f0, x_mask, utt_emb) # add f0 and ppg to x - x = x + f0_emb + ppg + aux_embeddings = f0_emb + ppg_pred # encode prosodic features - x = self.encoder(x, x_mask, utt_emb) + x = self.encoder(x, x_mask, aux_embeddings, x_mask, utt_emb) # # project to mu mu = self.proj_m(x) * x_mask @@ -215,18 +211,18 @@ def __init__( self.mel_encoder = MelStyleEncoder( in_channels=spec_channels, hidden_channels=256, - cond_channels=hidden_channels, utt_channels=speaker_embedding, kernel_size=5, + p_dropout=0.1, n_heads=8, dim_head=64, ) # conditional flow matching decoder self.decoder = ConditionalFlowMatching( - in_channels=self.spec_channels, - hidden_channels=self.hidden_channels, - out_channels=self.spec_channels, + in_channels=spec_channels, + hidden_channels=hidden_channels, + out_channels=spec_channels, spk_emb_dim=speaker_embedding, estimator="dit", ) @@ -241,14 +237,12 @@ def forward(self, c, f0, uv, spec, ppgs=None, c_lengths=None): ) # reference mel encoder - g, cond, cond_mask = self.mel_encoder(spec, x_mask) + g = self.mel_encoder(spec, x_mask) # content encoder mu_y, x_mask, f0_pred, lf0 = self.enc_p( c, x_mask, - cond=cond, - cond_mask=cond_mask, f0=f0, uv=uv, ppgs=ppgs, @@ -257,7 +251,7 @@ def forward(self, c, f0, uv, spec, ppgs=None, c_lengths=None): # Compute loss of score-based decoder diff_loss, _ 
= self.decoder.forward( - spec, None, x_mask, mu_y, spk=g, cond=cond, cond_mask=cond_mask + spec, None, x_mask, mu_y, spk=g, cond=None, cond_mask=None ) prior_loss = torch.sum( @@ -290,14 +284,12 @@ def infer( ) # reference mel encoder - g, cond, cond_mask = self.mel_encoder(spec, x_mask) + g = self.mel_encoder(spec, x_mask) # content encoder mu_y, x_mask, *_ = self.enc_p( c, x_mask, - cond=cond, - cond_mask=cond_mask, f0=f0, uv=uv, ppgs=ppgs, @@ -317,8 +309,8 @@ def infer( mu_y, n_timesteps, spk=g, - cond=cond, - cond_mask=cond_mask, + cond=None, + cond_mask=None, solver="euler", ) @@ -328,8 +320,6 @@ def infer( def vc( self, c, - cond, - cond_mask, f0, uv, ppgs, @@ -353,22 +343,14 @@ def vc( mu_y, x_mask = self.enc_p.vc( c, x_mask, - c_lengths, - cond=cond, - cond_mask=cond_mask, f0=f0, uv=uv, ppgs=ppgs, - utt_emb=g.squeeze(-1), + utt_emb=g, ) # fix length compatibility y_max_length = int(c_lengths.max()) - y_max_length_ = commons.fix_len_compatibility(y_max_length) - mu_y = commons.fix_y_by_max_length(mu_y, y_max_length_) - x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, y_max_length_), 1).to( - c.dtype - ) z = torch.randn_like(mu_y) * temperature decoder_outputs = self.decoder.inference( @@ -376,9 +358,9 @@ def vc( x_mask, mu_y, n_timesteps, - spk=g.squeeze(-1), - cond=cond, - cond_mask=cond_mask, + spk=g, + cond=None, + cond_mask=None, solver=solver, ) decoder_outputs = decoder_outputs[:, :, :y_max_length] @@ -388,22 +370,18 @@ def vc( @torch.no_grad() def compute_conditional_latent(self, mels, mel_lengths=None): speaker_embeddings = [] - cond_latents = [] for mel, length in zip(mels, mel_lengths): x_mask = torch.unsqueeze(commons.sequence_mask(length, mel.size(2)), 1).to( mel.dtype ) # reference mel encoder and perceiver latents - speaker_embedding, cond_latent, conds_mask = self.mel_encoder(mel, x_mask) + speaker_embedding = self.mel_encoder(mel, x_mask) speaker_embeddings.append(speaker_embedding.squeeze(0)) - cond_latents.append(cond_latent.squeeze(0)) - cond_latents = torch.stack(cond_latents, dim=0) speaker_embeddings = torch.stack(speaker_embeddings, dim=0) # mean pooling for cond_latents and speaker_embeddings speaker_embeddings = speaker_embeddings.mean(dim=0, keepdim=True) - conds = cond_latents.mean(dim=0, keepdim=True) - return conds, conds_mask, speaker_embeddings + return speaker_embeddings diff --git a/modules/attentions.py b/modules/attentions.py index 470cf1e..fb0cacc 100644 --- a/modules/attentions.py +++ b/modules/attentions.py @@ -6,22 +6,7 @@ from torch.nn import functional as F from modules import commons -from modules.modules import AdaIN1d, ConditionalLayerNorm, FiLM - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) +from modules.modules import AdaIN1d, ConditionalLayerNorm, FiLM, LayerNorm class ConditioningEncoder(nn.Module): @@ -108,7 +93,9 @@ def __init__( p_dropout=p_dropout, ) ) - self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.norm_layers_0.append( + ConditionalLayerNorm(hidden_channels, utt_emb_dim) + ) self.encdec_attn_layers.append( ConditioningEncoder( hidden_channels, @@ -145,7 +132,7 @@ def forward(self, x, x_mask, h, h_mask, g=None): for i in range(self.n_layers): y = 
self.self_attn_layers[i](x, x, self_attn_mask) y = self.drop(y) - x = self.norm_layers_0[i](x + y) + x = self.norm_layers_0[i](x + y, g) y = self.encdec_attn_layers[i](x, x_mask, h, h_mask) y = self.drop(y) @@ -168,8 +155,6 @@ def __init__( dim_head=None, kernel_size=1, p_dropout=0.0, - window_size=4, - utt_emb_dim=None, ): super().__init__() self.hidden_channels = hidden_channels @@ -178,7 +163,6 @@ def __init__( self.n_layers = n_layers self.kernel_size = kernel_size self.p_dropout = p_dropout - self.window_size = window_size self.drop = nn.Dropout(p_dropout) self.attn_layers = nn.ModuleList() @@ -198,9 +182,7 @@ def __init__( p_dropout=p_dropout, ) ) - self.norm_layers_1.append( - ConditionalLayerNorm(hidden_channels, utt_emb_dim) - ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) self.ffn_layers_1.append( FFN( @@ -211,11 +193,9 @@ def __init__( p_dropout=p_dropout, ) ) - self.norm_layers_2.append( - ConditionalLayerNorm(hidden_channels, utt_emb_dim) - ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask, g=None): + def forward(self, x, x_mask): # attn mask attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask @@ -224,12 +204,12 @@ def forward(self, x, x_mask, g=None): # self-attention y = self.attn_layers[i](x, x, attn_mask) y = self.drop(y) - x = self.norm_layers_1[i](x + y, g) + x = self.norm_layers_1[i](x + y) # feed-forward y = self.ffn_layers_1[i](x, x_mask) y = self.drop(y) - x = self.norm_layers_2[i](x + y, g) + x = self.norm_layers_2[i](x + y) x = x * x_mask return x diff --git a/modules/cfm/cfm_neuralode.py b/modules/cfm/cfm_neuralode.py index ebe3727..62eb4f3 100644 --- a/modules/cfm/cfm_neuralode.py +++ b/modules/cfm/cfm_neuralode.py @@ -87,9 +87,9 @@ def __init__( hidden_channels=hidden_channels, out_channels=out_channel, filter_channels=hidden_channels * 4, - dropout=0.05, - n_layers=4, - n_heads=4, + dropout=0.00, + n_layers=6, + n_heads=2, kernel_size=3, utt_emb_dim=spk_emb_dim, ) diff --git a/modules/cfm/dit.py b/modules/cfm/dit.py index 6efb1d9..18fe605 100644 --- a/modules/cfm/dit.py +++ b/modules/cfm/dit.py @@ -2,117 +2,144 @@ import torch import torch.nn as nn +from einops import rearrange -from modules.attentions import FFN, ConditioningEncoder, MultiHeadAttention +from modules.attentions import FFN, MultiHeadAttention +from modules.modules import ConditionalLayerNorm, LayerNorm -# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390 -class DiTConVBlock(nn.Module): - """ - A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
- """ +class ConditionalGroupNorm(nn.Module): + def __init__(self, groups, normalized_shape, context_dim): + super().__init__() + self.norm = nn.GroupNorm(groups, normalized_shape, affine=False) + self.context_mlp = nn.Sequential( + nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape) + ) + self.context_mlp[1].weight.data.zero_() + self.context_mlp[1].bias.data.zero_() + + def forward(self, x, context): + context = self.context_mlp(context) + ndims = " 1" * len(x.shape[2:]) + context = rearrange(context, f"b c -> b c{ndims}") + + scale, shift = context.chunk(2, dim=1) + x = self.norm(x) * (scale + 1.0) + shift + return x + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8, context_dim=None): + super().__init__() + self.conv1d = torch.nn.Conv1d(dim, dim_out, 3, padding=1) + if context_dim is None: + self.norm = torch.nn.GroupNorm(groups, dim_out) + else: + self.norm = ConditionalGroupNorm(groups, dim_out, context_dim) + self.mish = nn.Mish() + + def forward(self, x, mask, utt_emb=None): + output = self.conv1d(x * mask) + if utt_emb is not None: + output = self.norm(output, utt_emb) + else: + output = self.norm(output) + output = self.mish(output) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8, context_dim=512): + super().__init__() + self.mlp = torch.nn.Sequential( + nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out) + ) + self.block1 = Block1D(dim, dim_out, groups=groups, context_dim=context_dim) + self.block2 = Block1D(dim_out, dim_out, groups=groups, context_dim=context_dim) + + self.res_conv = ( + torch.nn.Conv1d(dim, dim_out, 1) if dim != dim_out else torch.nn.Identity() + ) + + def forward(self, x, mask, time_emb, utt_emb=None): + h = self.block1(x, mask, utt_emb) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask, utt_emb) + output = h + self.res_conv(x * mask) + return output + + +class Encoder(nn.Module): def __init__( self, hidden_channels, filter_channels, - num_heads, - kernel_size=3, - p_dropout=0.1, - utt_emb_dim=0, + time_channels, + n_heads, + n_layers, + dim_head=None, + kernel_size=1, + p_dropout=0.0, ): super().__init__() - self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6) - self.attn = MultiHeadAttention( - hidden_channels, hidden_channels, num_heads, p_dropout=p_dropout - ) - self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6) - self.cross_attn = ConditioningEncoder( - hidden_channels=hidden_channels, - n_heads=num_heads, - dim_head=None, - p_dropout=p_dropout, - cond_emb_dim=192, - ) - self.norm3 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6) - self.mlp = FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - self.adaLN_modulation = nn.Sequential( - nn.Linear(utt_emb_dim, hidden_channels), - nn.SiLU(), - nn.Linear(hidden_channels, 9 * hidden_channels, bias=True), - ) + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size - def forward(self, x, c, x_mask, cond, cond_mask): - """ - Args: - x : [batch_size, channel, time] - c : [batch_size, channel] - x_mask : [batch_size, 1, time] - return the same shape as x - """ - x = x * x_mask + self.attn_layers = nn.ModuleList() - # attn mask - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + self.norm_layers_1 = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() 
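# --- Illustrative example (not part of the patch): exercising the conditional normalisation blocks added above ---
# A minimal shape check for the time/utterance-conditioned ResNet blocks this patch introduces
# in modules/cfm/dit.py. It assumes the patch is applied as-is; the tensor sizes below are
# illustrative (80-mel input concatenated with mu -> 160 channels, 192 hidden channels,
# 512-dim utterance embedding) and are not read from any config.
import torch

from modules.cfm.dit import ConditionalGroupNorm, ResnetBlock1D

x = torch.randn(2, 160, 128)       # (batch, in_channels = 2 * n_mel_channels, frames)
mask = torch.ones(2, 1, 128)       # frame mask, 1 = valid
t_emb = torch.randn(2, 192)        # timestep embedding (time_emb_dim)
utt = torch.randn(2, 512)          # utterance / speaker embedding (context_dim)

block = ResnetBlock1D(dim=160, dim_out=192, time_emb_dim=192, groups=8, context_dim=512)
y = block(x, mask, t_emb, utt)     # -> (2, 192, 128)

# ConditionalGroupNorm's scale/shift MLP is zero-initialised, so before training it reduces to a
# plain (affine-free) GroupNorm of its input regardless of the conditioning vector.
norm = ConditionalGroupNorm(groups=8, normalized_shape=192, context_dim=512)
z = norm(y, utt)                   # -> (2, 192, 128)
print(y.shape, z.shape)
# --- end example ---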
- ( - shift_msa, - scale_msa, - gate_msa, - shift_mca, - scale_mca, - gate_mca, - shift_mlp, - scale_mlp, - gate_mlp, - ) = ( - self.adaLN_modulation(c).unsqueeze(2).chunk(9, dim=1) - ) # shape: [batch_size, channel, 1] - - # self attention - modulated_x = self.modulate( - self.norm1(x.transpose(1, 2)).transpose(1, 2), shift_msa, scale_msa - ) - x = ( - x - + gate_msa - * self.attn( - modulated_x, - c=modulated_x, - attn_mask=attn_mask, + self.ffn_layers_1 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + dim_head=dim_head, + p_dropout=p_dropout, + ) + ) + self.norm_layers_1.append( + ConditionalLayerNorm(hidden_channels, time_channels) ) - * x_mask - ) - # cross attention - modulated_cross_x = self.modulate( - self.norm2(x.transpose(1, 2)).transpose(1, 2), shift_mca, scale_mca - ) - x = ( - x - + gate_mca - * self.cross_attn(modulated_cross_x, x_mask, cond, cond_mask) - * x_mask - ) + self.ffn_layers_1.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size=kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append( + ConditionalLayerNorm(hidden_channels, time_channels) + ) - x = x + gate_mlp * self.mlp( - self.modulate( - self.norm3(x.transpose(1, 2)).transpose(1, 2), shift_mlp, scale_mlp - ), - x_mask, - ) + self.norm = LayerNorm(hidden_channels) - return x + def forward(self, x, x_mask, t): + # attn mask + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask - @staticmethod - def modulate(x, shift, scale): - return x * (1 + scale) + shift + for i in range(self.n_layers): + # self-attention + attn_input = self.norm_layers_1[i](x, t) + x = self.attn_layers[i](attn_input, attn_input, attn_mask) + x + + # feed-forward + ffn_input = self.norm_layers_2[i](x, t) + x = self.ffn_layers_1[i](ffn_input, x_mask) + x + + return self.norm(x * x_mask) class DitWrapper(nn.Module): @@ -120,6 +147,7 @@ class DitWrapper(nn.Module): def __init__( self, + in_channels, hidden_channels, filter_channels, num_heads, @@ -130,92 +158,36 @@ def __init__( time_channels=0, ): super().__init__() - self.time_fusion = FiLMLayer(hidden_channels, time_channels) - self.conv_layers = nn.ModuleList( - [ - ConvNeXtBlock(hidden_channels, filter_channels, utt_emb_dim) - for _ in range(conv_layers) - ] - ) - self.block = DiTConVBlock( - hidden_channels, - hidden_channels, - num_heads, - kernel_size, - p_dropout, - utt_emb_dim, - ) - - def forward(self, x, c, t, x_mask, cond, cond_mask): - x = self.time_fusion(x, t) * x_mask - for layer in self.conv_layers: - x = layer(x, c, x_mask) - x = self.block(x, c, x_mask, cond, cond_mask) - return x + self.conv_layers = nn.ModuleList([]) -class FiLMLayer(nn.Module): - """ - Feature-wise Linear Modulation (FiLM) layer - Reference: https://arxiv.org/abs/1709.07871 - """ - - def __init__(self, in_channels, cond_channels): - super(FiLMLayer, self).__init__() - self.in_channels = in_channels - self.film = nn.Conv1d(cond_channels, in_channels * 2, 1) - - def forward(self, x, c): - gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1) - return gamma * x + beta - + for _ in range(conv_layers): + self.conv_layers.append( + ResnetBlock1D( + dim=in_channels, + dim_out=hidden_channels, + time_emb_dim=time_channels, + groups=8, + context_dim=utt_emb_dim, + ) + ) + in_channels = hidden_channels -class ConvNeXtBlock(nn.Module): - def __init__(self, in_channels, filter_channels, gin_channels): - super().__init__() - 
self.dwconv = nn.Conv1d( - in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels - ) - self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels) - self.pwconv = nn.Sequential( - nn.Linear(in_channels, filter_channels), - nn.GELU(), - nn.Linear(filter_channels, in_channels), + self.block = Encoder( + hidden_channels=hidden_channels, + filter_channels=filter_channels, + time_channels=time_channels, + n_heads=num_heads, + n_layers=1, + kernel_size=kernel_size, + p_dropout=p_dropout, ) - def forward(self, x, c, x_mask) -> torch.Tensor: - residual = x - x = self.dwconv(x) * x_mask - x = self.norm(x.transpose(1, 2), c) - x = self.pwconv(x).transpose(1, 2) - x = residual + x - return x * x_mask - - -class StyleAdaptiveLayerNorm(nn.Module): - def __init__(self, in_channels, cond_channels): - """ - Style Adaptive Layer Normalization (SALN) module. - - Parameters: - in_channels: The number of channels in the input feature maps. - cond_channels: The number of channels in the conditioning input. - """ - super(StyleAdaptiveLayerNorm, self).__init__() - self.in_channels = in_channels - - self.saln = nn.Linear(cond_channels, in_channels * 2, 1) - self.norm = nn.LayerNorm(in_channels, elementwise_affine=False) - - self.reset_parameters() - - def reset_parameters(self): - nn.init.constant_(self.saln.bias.data[: self.in_channels], 1) - nn.init.constant_(self.saln.bias.data[self.in_channels :], 0) - - def forward(self, x, c): - gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1) - return gamma * self.norm(x) + beta + def forward(self, x, c, t, x_mask, cond, cond_mask): + for layer in self.conv_layers: + x = layer(x, x_mask, t, c) + x = self.block(x, x_mask, t) + return x class SinusoidalPosEmb(nn.Module): @@ -241,7 +213,7 @@ def __init__(self, in_channels, out_channels, filter_channels): self.layer = nn.Sequential( nn.Linear(in_channels, filter_channels), - nn.SiLU(inplace=True), + nn.SiLU(), nn.Linear(filter_channels, out_channels), ) @@ -267,38 +239,55 @@ def __init__( self.hidden_channels = hidden_channels self.out_channels = out_channels self.filter_channels = filter_channels + self.n_layers = n_layers self.time_embeddings = SinusoidalPosEmb(hidden_channels) self.time_mlp = TimestepEmbedding( hidden_channels, hidden_channels, filter_channels ) - # in projection - self.in_proj = nn.Conv1d(in_channels, hidden_channels, 1) + self.down_blocks = nn.ModuleList() + self.up_blocks = nn.ModuleList() + for idx in range(n_layers // 2): + self.down_blocks.append( + DitWrapper( + in_channels=in_channels, + hidden_channels=hidden_channels, + filter_channels=filter_channels, + num_heads=n_heads, + kernel_size=kernel_size, + p_dropout=dropout, + utt_emb_dim=utt_emb_dim, + conv_layers=2, + time_channels=hidden_channels, + ) + ) + in_channels = hidden_channels - self.blocks = nn.ModuleList( - [ + for idx in range(n_layers // 2): + self.up_blocks.append( DitWrapper( + in_channels=hidden_channels * 2, hidden_channels=hidden_channels, filter_channels=filter_channels, num_heads=n_heads, kernel_size=kernel_size, p_dropout=dropout, utt_emb_dim=utt_emb_dim, - conv_layers=3, + conv_layers=2, time_channels=hidden_channels, ) - for _ in range(n_layers) - ] - ) + ) + + self.final_block = Block1D(hidden_channels, hidden_channels) self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.initialize_weights() + # self.initialize_weights() - def initialize_weights(self): - for block in self.blocks: - nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0) - 
nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0) + # def initialize_weights(self): + # for block in self.blocks: + # nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0) def forward(self, x, mask, mu, t, spks=None, cond=None, cond_mask=None): """Forward pass of the UNet1DConditional model. @@ -320,11 +309,17 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, cond_mask=None): t = self.time_mlp(self.time_embeddings(t)) x = torch.cat((x, mu), dim=1) - x = self.in_proj(x) * mask + skip_connections = [] + for idx, block in enumerate(self.down_blocks): + x = block(x, spks, t, mask, cond, cond_mask) + skip_connections.append(x) - for block in self.blocks: + for idx, block in enumerate(self.up_blocks): + skip_x = skip_connections.pop() + x = torch.cat([x, skip_x], dim=1) x = block(x, spks, t, mask, cond, cond_mask) + x = self.final_block(x, mask) output = self.final_proj(x * mask) return output * mask diff --git a/modules/perceiver_encoder.py b/modules/perceiver_encoder.py index eaea499..e413b69 100644 --- a/modules/perceiver_encoder.py +++ b/modules/perceiver_encoder.py @@ -252,13 +252,13 @@ def __init__( use_flash=False, cross_attn_include_queries=True, ), - RMSNorm(hidden_channels), FeedForward(dim=hidden_channels, mult=ff_mult), - RMSNorm(hidden_channels), ] ) ) + self.norm = RMSNorm(hidden_channels) + def forward(self, x, x_mask=None): batch = x.shape[0] @@ -274,11 +274,11 @@ def forward(self, x, x_mask=None): device=x_mask.device, ) - for attn, norm_attn, ff, norm_ff in self.layers: - y = attn(latents, x, mask=None) - latents = norm_attn(y + latents) - y = ff(latents) - latents = norm_ff(y + latents) + for attn, ff in self.layers: + latents = attn(latents, x, mask=None) + latents + latents = ff(latents) + latents + + latents = self.norm(latents) latents = rearrange(latents, "b n d -> b d n") diff --git a/modules/reference_encoder.py b/modules/reference_encoder.py index 62a4454..168b461 100644 --- a/modules/reference_encoder.py +++ b/modules/reference_encoder.py @@ -1,8 +1,31 @@ import torch import torch.nn as nn +from modules.attentions import Encoder from modules.mel_encoder import MelEncoder -from modules.perceiver_encoder import PerceiverResampler + + +class Conv1dGLU(nn.Module): + """ + Conv1d + GLU(Gated Linear Unit) with residual connection. + For GLU refer to https://arxiv.org/abs/1612.08083 paper. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, dropout): + super(Conv1dGLU, self).__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d( + in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2 + ) + self.p_dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.conv1(x) + x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) + x = x1 * torch.sigmoid(x2) + x = residual + self.p_dropout(x) + return x class MelStyleEncoder(nn.Module): @@ -11,56 +34,63 @@ class MelStyleEncoder(nn.Module): def __init__( self, in_channels=80, - hidden_channels=256, - cond_channels=192, + hidden_channels=192, utt_channels=512, kernel_size=5, + p_dropout=0.0, n_heads=2, - dim_head=None, + dim_head=64, ): super(MelStyleEncoder, self).__init__() - # encoder - self.encoder = MelEncoder( - in_channels=in_channels, - out_channels=hidden_channels, - hidden_channels=cond_channels, - kernel_size=kernel_size, - dilation_rate=1, - n_layers=16, + # encode + self.spectral = nn.Sequential( + nn.Conv1d(in_channels, hidden_channels, 1), + nn.Mish(), + nn.Dropout(p_dropout), + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.Mish(), + nn.Dropout(p_dropout), + ) + + self.temporal = nn.Sequential( + Conv1dGLU(hidden_channels, hidden_channels, kernel_size, p_dropout), + Conv1dGLU(hidden_channels, hidden_channels, kernel_size, p_dropout), ) - # perceiver encoder - self.perceiver_encoder = PerceiverResampler( + # attn + self.attn = Encoder( hidden_channels=hidden_channels, - depth=2, - num_latents=32, + filter_channels=hidden_channels * 4, + n_layers=2, + n_heads=n_heads, dim_head=dim_head, - heads=n_heads, - ff_mult=4, + kernel_size=3, + p_dropout=0.1, ) - self.cond_proj = nn.Conv1d(hidden_channels, cond_channels, kernel_size=1) - self.utt_proj = nn.Conv1d(hidden_channels, utt_channels, kernel_size=1) + self.fc = nn.Conv1d(hidden_channels, utt_channels, kernel_size=1) def temporal_avg_pool(self, x, mask=None): # avg pooling - len_ = mask.sum(dim=-1) - x = torch.sum(x * mask, dim=-1) / len_ - return x + len_ = mask.sum(dim=2) + x = x.sum(dim=2) + out = torch.div(x, len_) + return out def forward(self, x, x_mask=None): - # encode mel (x) - encoded_mel = self.encoder(x, x_mask) + # spectral + x = self.spectral(x) * x_mask + # temporal + x = self.temporal(x) * x_mask - # perceiver encoder - cond, cond_mask = self.perceiver_encoder(encoded_mel, x_mask) + # attention + x = self.attn(x, x_mask) - # project to cond and utt embeddings - utt_emb = self.utt_proj(cond) * cond_mask - cond = self.cond_proj(cond) * cond_mask + # fc + x = self.fc(x) * x_mask # temoral average pooling for utterance embedding - utt_emb = self.temporal_avg_pool(utt_emb, mask=cond_mask) + utt_emb = self.temporal_avg_pool(x, mask=x_mask) - return utt_emb, cond, cond_mask + return utt_emb diff --git a/modules/variance_decoder.py b/modules/variance_decoder.py index b9fe5ca..34215aa 100644 --- a/modules/variance_decoder.py +++ b/modules/variance_decoder.py @@ -5,6 +5,82 @@ from modules.modules import AdainResBlk1d, ConditionalLayerNorm +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + dim_head=None, + kernel_size=1, + p_dropout=0.0, + utt_emb_dim=512, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + 
self.attn_layers = nn.ModuleList() + + self.norm_layers_1 = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + self.ffn_layers_1 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append( + attentions.MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + dim_head=dim_head, + p_dropout=p_dropout, + ) + ) + self.norm_layers_1.append( + ConditionalLayerNorm(hidden_channels, utt_emb_dim) + ) + + self.ffn_layers_1.append( + attentions.FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size=kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append( + ConditionalLayerNorm(hidden_channels, utt_emb_dim) + ) + + def forward(self, x, x_mask, g=None): + # attn mask + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + + for i in range(self.n_layers): + # self-attention + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y, g) + + # feed-forward + y = self.ffn_layers_1[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y, g) + + x = x * x_mask + return x + + class ConditionalEmbedding(nn.Module): def __init__(self, num_embeddings, d_model, style_dim=512): super().__init__() @@ -43,7 +119,7 @@ def __init__( self.drop = nn.Dropout(p_dropout) - self.aux_decoder = attentions.Decoder( + self.aux_decoder = Encoder( hidden_channels=hidden_channels, filter_channels=hidden_channels * 4, n_heads=n_heads, @@ -51,13 +127,12 @@ def __init__( kernel_size=kernel_size, dim_head=dim_head, utt_emb_dim=utt_emb_dim, - causal_ffn=True, p_dropout=0.1, ) self.proj = nn.Conv1d(hidden_channels, output_channels, 1) - def forward(self, x, x_mask, aux, h, h_mask, utt_emb): + def forward(self, x, x_mask, aux, utt_emb): # detach x x = torch.detach(x) @@ -65,7 +140,8 @@ def forward(self, x, x_mask, aux, h, h_mask, utt_emb): x = x + self.aux_prenet(aux) * x_mask x = self.prenet(x) * x_mask - x = self.aux_decoder(x, x_mask, h, h_mask, utt_emb) + # attention + x = self.aux_decoder(x, x_mask, utt_emb) # out projection x = self.proj(x) * x_mask diff --git a/preprocess.ipynb b/preprocess.ipynb deleted file mode 100644 index 744d808..0000000 --- a/preprocess.ipynb +++ /dev/null @@ -1,439 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess Config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "from random import shuffle\n", - "\n", - "from loguru import logger\n", - "from tqdm import tqdm\n", - "from pathlib import Path\n", - "\n", - "min_duration = 22050 * 2.0\n", - "max_duration = 22050 * 10.0\n", - "\n", - "config_template = json.load(open(\"configs_template/config_template.json\"))\n", - "\n", - "training_files = [\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_borderlands2_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_baldursgate3_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_worldofwarcraft_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_mario_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/de_gametts_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/pl_archolos_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/de_borderlands2_xphone.csv\",\n", - " 
\"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_warcraft_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_sqnarrator_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_emotional_train_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/de_emotional_train_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/ru_witcher3_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_witcher3_skyrim_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_fallout4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_naruto_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/de_kcd_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/pl_witcher3_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/de_diablo4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/en_diablo4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/fr_diablo4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/pl_diablo4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/ru_diablo4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/ru_skyrim_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/jp_one_piece_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/metadata/filelists/xphoneBERT/jp_skyrim_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/fr/Fallout4/fr_fallout4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/de/Fallout4/de_fallout4_xphone.csv\",\n", - " \"/mnt/datasets/TTS_Data/en/Fallout4/en_fallout4_xphone.csv\",\n", - "]\n", - "\n", - "all_lines = []\n", - "\n", - "for file in training_files:\n", - " with open(file) as f:\n", - " lines = f.readlines()\n", - " all_lines.extend(lines)\n", - "\n", - "len(all_lines)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wavs = []\n", - "train = []\n", - "val = []\n", - "spk_dict = {}\n", - "spk_id = 0\n", - "speaker_items_count = {}\n", - "duplicate_wavs = set()\n", - "\n", - "shuffle(all_lines)\n", - "\n", - "for line in tqdm(all_lines):\n", - " cols = line.strip().split(\"|\")\n", - " speaker_name = cols[1]\n", - " wav_path = cols[0]\n", - "\n", - " if not os.path.exists(wav_path):\n", - " continue\n", - "\n", - " if not (max_duration >= (Path(wav_path).stat().st_size // 2) > min_duration):\n", - " continue\n", - "\n", - " if speaker_name not in spk_dict:\n", - " speaker_items_count[speaker_name] = 0\n", - " spk_dict[speaker_name] = spk_id\n", - " spk_id += 1\n", - " else:\n", - " speaker_items_count[speaker_name] += 1\n", - "\n", - " if (wav_path, speaker_name) in duplicate_wavs:\n", - " continue\n", - "\n", - " if speaker_items_count[speaker_name] < 150:\n", - " duplicate_wavs.add(wav_path)\n", - " wavs.append((wav_path, speaker_name))\n", - "\n", - "shuffle(wavs)\n", - "\n", - "with open(\"/home/alexander/Projekte/so-vits-svc/filelists/voice_conversion_train.txt\", \"w\") as f:\n", - " for wav_path, speaker_name in wavs:\n", - " speaker_id = spk_dict[speaker_name]\n", - " f.write(f\"{wav_path}|{speaker_id}\\n\")\n", - "\n", - "\n", - "config_template[\"spk\"] = spk_dict\n", - "config_template[\"model\"][\"n_speakers\"] = spk_id\n", - "config_template[\"model\"][\"speech_encoder\"] = \"vec768l12\"\n", - 
"\n", - "\n", - "logger.info(\"Writing to configs/config_vc.json\")\n", - "with open(\"configs/config_vc.json\", \"w\") as f:\n", - " json.dump(config_template, f, indent=2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "from random import shuffle\n", - "\n", - "from loguru import logger\n", - "from tqdm import tqdm\n", - "from pathlib import Path\n", - "from glob import glob\n", - "import wave \n", - "\n", - "min_duration = 22050 * 1.0\n", - "max_duration = 22050 * 8.0\n", - "wavs = []\n", - "train = []\n", - "val = []\n", - "spk_dict = {}\n", - "spk_id = 6860\n", - "speaker_items_count = {}\n", - "duplicate_wavs = set()\n", - "\n", - "all_wavs = glob(\"/mnt/datasets/TTS_Data/en/FF7/**/*.wav\", recursive=True)\n", - "\n", - "shuffle(all_wavs)\n", - "\n", - "for file_path in tqdm(all_wavs):\n", - "\n", - " wav_path = file_path\n", - " speaker_name = file_path.split(\"/\")[-2]\n", - " \n", - " if \"announcer\" in speaker_name:\n", - " continue\n", - "\n", - " if not os.path.exists(wav_path):\n", - " continue\n", - "\n", - " # Open the WAV file\n", - " with wave.open(wav_path, 'r') as wav_file:\n", - " # Get the number of frames and the frame rate\n", - " frames = wav_file.getnframes()\n", - " frame_rate = wav_file.getframerate()\n", - "\n", - " # Calculate the duration in seconds\n", - " duration_seconds = frames / float(frame_rate)\n", - "\n", - " if not (8.0 >= duration_seconds > 1.0):\n", - " continue\n", - "\n", - " if speaker_name not in spk_dict:\n", - " speaker_items_count[speaker_name] = 0\n", - " spk_dict[speaker_name] = spk_id\n", - " spk_id += 1\n", - " else:\n", - " speaker_items_count[speaker_name] += 1\n", - "\n", - " if (wav_path, speaker_name) in duplicate_wavs:\n", - " continue\n", - "\n", - " if speaker_items_count[speaker_name] < 200:\n", - " duplicate_wavs.add(wav_path)\n", - " wavs.append((wav_path, speaker_name))\n", - "\n", - "shuffle(wavs)\n", - "\n", - "with open(\"/home/alexander/Projekte/so-vits-svc/filelists/voice_conversion_train_ff7.txt\", \"w\") as f:\n", - " for wav_path, speaker_name in wavs:\n", - " speaker_id = spk_dict[speaker_name]\n", - " f.write(f\"{wav_path}|{speaker_id}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"/home/alexander/Projekte/so-vits-svc/filelists/voice_conversion_train.txt\", \"r\") as f:\n", - " lines = f.readlines()\n", - " \n", - "# get max speaker id\n", - "max_speaker_id = 0\n", - "for line in lines:\n", - " cols = line.strip().split(\"|\")\n", - " wav_path = cols[0]\n", - " speaker_id = cols[1]\n", - " \n", - " if int(speaker_id) > max_speaker_id:\n", - " max_speaker_id = int(speaker_id)\n", - "\n", - "max_speaker_id" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess F0 and Hubert" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "file_paths = []\n", - "for file, speaker in wavs:\n", - " file_paths.append(file)\n", - "\n", - "print(len(file_paths))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def find_duplicates(lst):\n", - " seen = set()\n", - " duplicates = set()\n", - "\n", - " for sublist in lst:\n", - " # Convert the list into a tuple to make it hashable\n", - " t = tuple(sublist)\n", - "\n", - " if t in seen:\n", - " duplicates.add(t)\n", - " seen.add(t)\n", - 
"\n", - " return list(duplicates)\n", - "\n", - "\n", - "dups = find_duplicates(wavs)\n", - "dups" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "de_duped = []\n", - "with open(\n", - " \"/home/alexander/Projekte/so-vits-svc/filelists/voice_conversion_train.txt\", \"r\"\n", - ") as rf:\n", - " for line in rf:\n", - " fil = line.strip().split(\"|\")[0]\n", - " if not any(fil in dup for dup, speak in dups):\n", - " de_duped.append(line)\n", - "\n", - "with open(\n", - " \"/home/alexander/Projekte/so-vits-svc/filelists/voice_conversion_train.txt\", \"w\"\n", - ") as wf:\n", - " for line in de_duped:\n", - " wf.write(line)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import utils\n", - "import librosa\n", - "import torch\n", - "import numpy as np\n", - "from concurrent.futures import ProcessPoolExecutor\n", - "import torch.multiprocessing as mp\n", - "from tqdm import tqdm\n", - "\n", - "sampling_rate = 22050\n", - "hop_length = 256\n", - "speech_encoder = \"vec768l12\"\n", - "device = \"cuda:0\"\n", - "f0p = \"crepe\"\n", - "\n", - "save_path = \"/mnt/datasets/VC_Data\"\n", - "\n", - "\n", - "def process_one(filename, hmodel, f0p, device, diff=False, mel_extractor=None):\n", - " filename, speaker = filename\n", - " wav, sr = librosa.load(filename, sr=sampling_rate)\n", - " audio_norm = torch.FloatTensor(wav)\n", - " audio_norm = audio_norm.unsqueeze(0)\n", - "\n", - " # get only the filename without path and without the extension\n", - " only_filename = os.path.splitext(os.path.basename(filename))[0]\n", - "\n", - " soft_path = os.path.join(save_path, only_filename + f\"_{speaker}_.soft.pt\")\n", - " if not os.path.exists(soft_path):\n", - " wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)\n", - " wav16k = torch.from_numpy(wav16k).to(device)\n", - " c = hmodel.encoder(wav16k)\n", - " torch.save(c.cpu(), soft_path)\n", - "\n", - " f0_path = filename.replace(\".wav\", \".pitch.pt\")\n", - " if not os.path.exists(f0_path):\n", - " f0_predictor = utils.get_f0_predictor(\n", - " f0p,\n", - " sampling_rate=sampling_rate,\n", - " hop_length=hop_length,\n", - " device=device,\n", - " threshold=0.05,\n", - " )\n", - "\n", - " f0, uv = f0_predictor.compute_f0_uv(wav)\n", - "\n", - " # Assuming f0 and uv are numpy arrays\n", - " f0_tensor = torch.from_numpy(f0)\n", - " uv_tensor = torch.from_numpy(uv)\n", - "\n", - " # Save as a dictionary for clarity\n", - " data_to_save = {\"f0\": f0_tensor, \"uv\": uv_tensor}\n", - " torch.save(data_to_save, f0_path)\n", - "\n", - " # np.save(f0_path, np.asanyarray((f0, uv), dtype=object))\n", - "\n", - "\n", - "def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device=\"cpu\"):\n", - " logger.info(\"Loading speech encoder for content...\")\n", - " rank = mp.current_process()._identity\n", - " rank = rank[0] if len(rank) > 0 else 0\n", - " if torch.cuda.is_available():\n", - " gpu_id = rank % torch.cuda.device_count()\n", - " device = torch.device(f\"cuda:{gpu_id}\")\n", - " logger.info(f\"Rank {rank} uses device {device}\")\n", - " hmodel = utils.get_speech_encoder(speech_encoder, device=device)\n", - " logger.info(f\"Loaded speech encoder for rank {rank}\")\n", - " for filename in tqdm(file_chunk, position=rank):\n", - " process_one(filename, hmodel, f0p, device, diff, mel_extractor)\n", - "\n", - "\n", - "def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):\n", 
- " with ProcessPoolExecutor(max_workers=num_processes) as executor:\n", - " tasks = []\n", - " for i in range(num_processes):\n", - " start = int(i * len(filenames) / num_processes)\n", - " end = int((i + 1) * len(filenames) / num_processes)\n", - " file_chunk = filenames[start:end]\n", - " tasks.append(\n", - " executor.submit(\n", - " process_batch, file_chunk, f0p, diff, mel_extractor, device=device\n", - " )\n", - " )\n", - " for task in tqdm(tasks, position=0):\n", - " task.result()\n", - "\n", - "\n", - "parallel_process(wavs, 7, f0p, False, None, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from glob import glob\n", - "import wave\n", - "from tqdm import tqdm\n", - "import os\n", - "\n", - "all_wavs = glob(\"/mnt/datasets/TTS_Data/en/FF7/**/*.wav\", recursive=True)\n", - "\n", - "for wav_path in tqdm(all_wavs):\n", - " # Open the WAV file\n", - " with wave.open(wav_path, 'r') as wav_file:\n", - " # Get the number of frames and the frame rate\n", - " frames = wav_file.getnframes()\n", - " frame_rate = wav_file.getframerate()\n", - "\n", - " # Calculate the duration in seconds\n", - " duration_seconds = frames / float(frame_rate)\n", - "\n", - "\n", - " if not (7.0 >= duration_seconds > 1.0):\n", - " os.remove(wav_path)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/preprocess_f0_hubert.py b/preprocess_f0_hubert.py index 1f9c176..3ca78fc 100644 --- a/preprocess_f0_hubert.py +++ b/preprocess_f0_hubert.py @@ -84,7 +84,7 @@ def parallel_process(filenames, num_processes, f0p, device): soft_path = file_path.replace(".wav", ".soft.pt") f0_path = file_path.replace(".wav", ".rmvpe.pt") - if not os.path.exists(soft_path) and not os.path.exists(f0_path): + if not os.path.exists(soft_path) or not os.path.exists(f0_path): wav_paths.append(file_path.strip()) # preprocess f0 and hubert diff --git a/preprocess_ppgs.py b/preprocess_ppgs.py index 9bb0c39..8ae9bda 100644 --- a/preprocess_ppgs.py +++ b/preprocess_ppgs.py @@ -1,12 +1,13 @@ import ppgs if __name__ == "__main__": + # build paths wav_paths = [] with open("/workspace/vc_train.csv", "r") as f: for line in f: file_path = line.split("|")[0] wav_paths.append(file_path.strip()) - ppgs_paths = [path.replace(".wav", ".ppg.pt") for path in wav_paths] - ppgs.from_files_to_files(wav_paths, ppgs_paths, gpu=0) \ No newline at end of file + # compute ppgs + ppgs.from_files_to_files(wav_paths, ppgs_paths, gpu=0) diff --git a/testing.ipynb b/testing.ipynb index d912289..80d222a 100644 --- a/testing.ipynb +++ b/testing.ipynb @@ -40,13 +40,13 @@ "import torch\n", "\n", "dit = DiT(\n", - " in_channels=100 * 2,\n", - " hidden_channels=256,\n", - " out_channels=100,\n", - " filter_channels=256 * 4,\n", + " in_channels=80 * 2,\n", + " hidden_channels=192,\n", + " out_channels=80,\n", + " filter_channels=192 * 4,\n", " dropout=0.05,\n", - " n_layers=4,\n", - " n_heads=4,\n", + " n_layers=6,\n", + " n_heads=2,\n", " kernel_size=3,\n", " utt_emb_dim=512,\n", ")\n", @@ -54,9 +54,9 @@ "# print dit parameter count\n", "print(f\"DiT parameter count: {sum(p.numel() for p in 
dit.parameters())}\")\n", "\n", - "x = torch.randn(1, 100, 128)\n", + "x = torch.randn(1, 80, 128)\n", "x_mask = torch.ones(1, 1, 128)\n", - "mu = torch.randn(1, 100, 128)\n", + "mu = torch.randn(1, 80, 128)\n", "t = torch.Tensor([0.2])\n", "spks = torch.randn(1, 512)\n", "cond = torch.randn(1, 192, 32)\n", @@ -75,7 +75,7 @@ "from models import SynthesizerTrn\n", "\n", "vc_model = SynthesizerTrn(\n", - " spec_channels=128,\n", + " spec_channels=80,\n", " hidden_channels=192,\n", " filter_channels=768,\n", " n_heads=2,\n", @@ -91,7 +91,7 @@ "c = torch.randn(1, 768, 56)\n", "c_lengths = torch.Tensor([56])\n", "ppgs = torch.randn(1, 40, 56)\n", - "spec = torch.randn(1, 128, 56)\n", + "spec = torch.randn(1, 80, 56)\n", "f0 = torch.randn(1, 1, 56)\n", "uv = torch.ones(1, 56)\n", "g = torch.randn(1, 512)\n", @@ -101,7 +101,17 @@ "print(f\"VC model parameter count: {sum(p.numel() for p in vc_model.parameters())}\")\n", "\n", "# (prior_loss, diff_loss, f0_pred, lf0)\n", - "vc_model(c=c, f0=f0, uv=uv, spec=spec, ppgs=ppgs, c_lengths=c_lengths)" + "vc_model(c=c, f0=f0, uv=uv, spec=spec, ppgs=ppgs, c_lengths=c_lengths)\n", + "vc_model.mel_encoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vc_model.decoder.estimator.blocks" ] }, { @@ -171,91 +181,93 @@ "metadata": {}, "outputs": [], "source": [ - "from glob import glob\n", - "\n", - "wavs_paths = glob(\"/home/alex/Projekte/TestData/Extended/Dexter/**/*.wav\", recursive=True)\n", + "import random\n", + "from pathlib import Path\n", "\n", - "with open(\"/home/alex/Projekte/so-vits-svc/filelists/gametts_train.txt\", \"w\") as f:\n", - " for path in wavs_paths:\n", - " f.write(f\"{path}|0\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import wget\n", + "train_all = [\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_borderlands2_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_baldursgate3_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_worldofwarcraft_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_mario_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/de_gametts_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/pl_archolos_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/de_borderlands2_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_warcraft_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_sqnarrator_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_emotional_train_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/de_emotional_train_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/ru_witcher3_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_witcher3_skyrim_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_fallout4_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_naruto_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/de_kcd_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/pl_witcher3_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/de_diablo4_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/en_diablo4_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/fr_diablo4_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/pl_diablo4_xphone.csv\",\n", + " 
\"/workspace/metadata/filelists/xphoneBERT/ru_diablo4_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/ru_skyrim_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/jp_one_piece_xphone.csv\",\n", + " \"/workspace/metadata/filelists/xphoneBERT/jp_skyrim_xphone.csv\",\n", + " \"/workspace/dataset/fr/Fallout4/fr_fallout4_xphone.csv\",\n", + " \"/workspace/dataset/de/Fallout4/de_fallout4_xphone.csv\",\n", + " \"/workspace/dataset/en/Fallout4/en_fallout4_xphone.csv\",\n", + "]\n", "\n", - "if not os.path.exists(\"/content/cfm-vc//workspace/pretrained_models/checkpoint_best_legacy_500.pt\"):\n", - " wget.download(\n", - " \"https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt\", out=\"/home/cfm-vc//workspace/pretrained_models/checkpoint_best_legacy_500.pt\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torchaudio\n", - "from tqdm import tqdm\n", + "all_lines = []\n", "\n", - "wav_paths = []\n", - "with open(\"/workspace/vc_test.csv\", \"r\") as f:\n", - " for line in f:\n", - " file_path = line.split(\"|\")[0]\n", - " wav_paths.append(file_path.strip())\n", + "for file in train_all:\n", + " with open(file, \"r\") as f:\n", + " lines = f.readlines()\n", + " all_lines.extend(lines)\n", "\n", - "bad_files = []\n", - "for path in tqdm(wav_paths):\n", - " try:\n", - " audio, sr = torchaudio.load(path)\n", - " except:\n", - " bad_files.append(path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", + "random.shuffle(all_lines)\n", "\n", - "for f in bad_files:\n", - " if os.path.exists(f):\n", - " os.remove(f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_lines = []\n", + "files_max_per_speaker = 300\n", + "min_audio_length = 0.3 * 22050\n", + "max_audio_length = 12.0 * 22050\n", "\n", + "speaker_files_dict = {}\n", "\n", - "with open(\"/workspace/vc_train.csv\", \"r\") as f:\n", - " for line in f:\n", - " file_path = line.split(\"|\")[0]\n", - " if os.path.exists(file_path):\n", - " all_lines.append(line)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"/workspace/vc_train_2.csv\", \"w\") as wf:\n", + "with open(\n", + " \"/workspace/vocoder_train.csv\", \"w\"\n", + ") as wf:\n", " for line in all_lines:\n", - " wf.write(line)" + " cols = line.split(\"|\")\n", + " filename = cols[0]\n", + " speaker = cols[1]\n", + " text = cols[-2]\n", + " text_orig = cols[-1]\n", + " \n", + " filename = filename.replace(\"/mnt/datasets/TTS_Data\", \"/workspace/dataset\")\n", + "\n", + " if any(\n", + " v in text_orig\n", + " for v in [\"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\", \"v8\", \"v9\", \"v10\"]\n", + " ):\n", + " continue\n", + "\n", + " if not Path(filename).exists():\n", + " continue\n", + "\n", + " if max_audio_length < Path(filename).stat().st_size // 2 < min_audio_length:\n", + " continue\n", + "\n", + " if any(char in \"#[]{}*\" for char in text_orig):\n", + " continue\n", + "\n", + " # if len(text) < 4 or len(text) > 350:\n", + " # continue\n", + "\n", + " if speaker not in speaker_files_dict:\n", + " speaker_files_dict[speaker] = []\n", + " speaker_files_dict[speaker].append(line)\n", + " wf.write(f\"{filename}|{speaker}\\n\")\n", + " else:\n", + " if len(speaker_files_dict[speaker]) < files_max_per_speaker:\n", + " 
speaker_files_dict[speaker].append(line)\n", + " wf.write(f\"{filename}|{speaker}\\n\")" ] } ], diff --git a/train.py b/train.py index f4c2d29..3b20054 100644 --- a/train.py +++ b/train.py @@ -25,6 +25,8 @@ global_step = 0 start_time = time.time() +# os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging. + class ModelEmaV2(torch.nn.Module): def __init__(self, model, decay=0.9999, device=None): @@ -249,8 +251,8 @@ def train_and_evaluate( ) with autocast(enabled=False, dtype=half_type): - f0_loss = F.smooth_l1_loss(f0_pred, lf0.detach()) - loss_gen_all = diff_loss + prior_loss + f0_loss # + reversal_loss + f0_loss = F.smooth_l1_loss(f0_pred, lf0.detach()) # f0 loss + loss_gen_all = diff_loss + prior_loss + f0_loss optim_g.zero_grad() scaler.scale(loss_gen_all).backward() @@ -263,7 +265,7 @@ def train_and_evaluate( if rank == 0: if global_step % hps.train.log_interval == 0: lr = optim_g.param_groups[0]["lr"] - losses = [diff_loss, prior_loss] + losses = [diff_loss, prior_loss, f0_loss] reference_loss = 0 for i in losses: reference_loss += i diff --git a/vdecoder/IstftGenerator/model.py b/vdecoder/IstftGenerator/model.py index ce6531b..af371bf 100644 --- a/vdecoder/IstftGenerator/model.py +++ b/vdecoder/IstftGenerator/model.py @@ -110,8 +110,8 @@ class ISTFTHead(nn.Module): def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): super().__init__() - out_dim = n_fft + 2 - # self.out = torch.nn.Linear(dim, out_dim) + + # self.out = nn.Linear(dim, out_dim) self.istft = ISTFT( n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding ) @@ -127,7 +127,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
""" - # x = self.out(x).transpose(1, 2) + x = x.float() mag, p = x.chunk(2, dim=1) mag = torch.exp(mag) mag = torch.clip( @@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # safeguard to prevent excessively large magnitudes S = torch.polar(mag, p) audio = self.istft(S) - return audio + return audio.unsqueeze(1) class Generator(nn.Module): @@ -151,36 +151,37 @@ class Generator(nn.Module): def __init__( self, - input_channels: int, - num_layers: int, + hparams, ): super().__init__() - self.num_layers = num_layers - self.input_channels = input_channels - - upsample_rates = [8, 8] - upsample_kernel_sizes = [16, 16] - upsample_initial_channel = 512 - resblock_kernel_sizes = [3, 7, 11] - resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + self.input_channels = hparams.num_mels + self.upsample_rates = hparams.upsample_rates + self.upsample_kernel_sizes = hparams.upsample_kernel_sizes + self.upsample_initial_channel = hparams.upsample_initial_channel + self.resblock_kernel_sizes = hparams.resblock_kernel_sizes + self.resblock_dilation_sizes = hparams.resblock_dilation_sizes + self.post_n_fft = hparams.post_n_fft + self.post_hop_size = hparams.post_hop_size - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.post_n_fft = 32 - self.post_hop_size = 8 + self.num_kernels = len(self.resblock_kernel_sizes) + self.num_upsamples = len(self.upsample_rates) self.conv_pre = weight_norm( - nn.Conv1d(input_channels, upsample_initial_channel, 7, 1, padding=3) + nn.Conv1d( + self.input_channels, self.upsample_initial_channel, 7, 1, padding=3 + ) ) self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + for i, (u, k) in enumerate( + zip(self.upsample_rates, self.upsample_kernel_sizes) + ): self.ups.append( weight_norm( nn.ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), + self.upsample_initial_channel // (2**i), + self.upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2, @@ -190,9 +191,9 @@ def __init__( self.resblocks = nn.ModuleList() for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) + ch = self.upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) + zip(self.resblock_kernel_sizes, self.resblock_dilation_sizes) ): self.resblocks.append(ResBlock1(ch, k, d)) @@ -201,7 +202,6 @@ def __init__( ) self.ups.apply(init_weights) self.conv_post.apply(init_weights) - # self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) self.head = ISTFTHead( dim=self.post_n_fft + 2, @@ -218,18 +218,13 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: - # x = self.embed(x) - # x = self.norm(x.transpose(1, 2)) - # x = x.transpose(1, 2) - # for conv_block in self.convnext: - # x = conv_block(x) - - # x = self.final_layer_norm(x.transpose(1, 2)) - + # prenet x = self.conv_pre(x) + for i in range(self.num_upsamples): x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) + xs = None for j in range(self.num_kernels): if xs is None: @@ -239,7 +234,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = xs / self.num_kernels x = F.leaky_relu(x) - # x = self.reflection_pad(x) x = self.conv_post(x) audio_output = self.head(x) @@ -248,9 +242,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def remove_weight_norm(self): print("Removing weight norm...") - for l in self.ups: - 
remove_parametrizations(l, "weight") - for l in self.resblocks: - l.remove_weight_norm() + for conv in self.ups: + remove_parametrizations(conv, "weight") + for conv in self.resblocks: + conv.remove_weight_norm() remove_parametrizations(self.conv_pre, "weight") remove_parametrizations(self.conv_post, "weight") diff --git a/vdecoder/IstftGenerator/module.py b/vdecoder/IstftGenerator/module.py index 93f9595..36c0740 100644 --- a/vdecoder/IstftGenerator/module.py +++ b/vdecoder/IstftGenerator/module.py @@ -100,7 +100,7 @@ def forward(self, x): return x def remove_weight_norm(self): - for l in self.convs1: - remove_parametrizations(l, "weight") - for l in self.convs2: - remove_parametrizations(l, "weight") + for conv in self.convs1: + remove_parametrizations(conv, "weight") + for conv in self.convs2: + remove_parametrizations(conv, "weight")
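In preprocess_f0_hubert.py, switching `and` to `or` means a wav is queued whenever either derived artifact is missing, rather than only when both are. The same keep/skip decision written as a small helper via De Morgan's law (the function name is illustrative, not part of the repo):

```python
import os


def needs_processing(file_path: str) -> bool:
    """Queue a wav for preprocessing if either derived artifact is missing."""
    soft_path = file_path.replace(".wav", ".soft.pt")
    f0_path = file_path.replace(".wav", ".rmvpe.pt")
    # not A or not B  ==  not (A and B)
    return not (os.path.exists(soft_path) and os.path.exists(f0_path))
```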
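In preprocess_ppgs.py, `ppgs.from_files_to_files` still needs the list of output paths that the removed `ppgs_paths = [...]` line used to build, so the script as patched would raise a `NameError`. A minimal standalone sketch with that mapping kept (the CSV path and `gpu=0` follow the committed script):

```python
import ppgs

if __name__ == "__main__":
    # build paths: one input wav per line of the training filelist
    wav_paths = []
    with open("/workspace/vc_train.csv", "r") as f:
        for line in f:
            wav_paths.append(line.split("|")[0].strip())

    # one .ppg.pt written next to every .wav
    ppgs_paths = [path.replace(".wav", ".ppg.pt") for path in wav_paths]

    # compute ppgs
    ppgs.from_files_to_files(wav_paths, ppgs_paths, gpu=0)
```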
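The vocoder filelist cell in testing.ipynb estimates clip length from file size (`st_size // 2`, i.e. two bytes per sample for 16-bit mono PCM) and is meant to keep clips whose length falls between `min_audio_length` and `max_audio_length`; as committed the comparison reads `max < size < min`, which can never hold, so no length filtering happens. A sketch of the intended gate, assuming 16-bit mono WAVs at 22050 Hz and ignoring the WAV header (the helper name is hypothetical):

```python
from pathlib import Path

SAMPLE_RATE = 22050            # matches the 0.3 * 22050 and 12.0 * 22050 bounds in the cell
MIN_SAMPLES = 0.3 * SAMPLE_RATE
MAX_SAMPLES = 12.0 * SAMPLE_RATE


def keep_by_length(wav_path: str) -> bool:
    """Approximate the sample count from the file size: 2 bytes per 16-bit sample."""
    n_samples = Path(wav_path).stat().st_size // 2
    return MIN_SAMPLES < n_samples < MAX_SAMPLES


# inside the filelist loop:
#     if not keep_by_length(filename):
#         continue
```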
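The refactored `Generator` in vdecoder/IstftGenerator/model.py now reads its architecture from a single `hparams` object instead of hard-coded constants. A sketch of instantiating it with the values the old constructor carried, assuming `hparams` only needs the fields visible in the hunk; `SimpleNamespace` stands in for whatever config container the repo actually passes, and `num_mels=80` mirrors the 80-channel mel setting used elsewhere in this change set:

```python
from types import SimpleNamespace

import torch

from vdecoder.IstftGenerator.model import Generator

# previously hard-coded architecture settings, now supplied via hparams
hparams = SimpleNamespace(
    num_mels=80,
    upsample_rates=[8, 8],
    upsample_kernel_sizes=[16, 16],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    post_n_fft=32,
    post_hop_size=8,
)

gen = Generator(hparams)

mel = torch.randn(1, hparams.num_mels, 100)   # (B, num_mels, frames)
audio = gen(mel)                              # (B, 1, T) — ISTFTHead now returns audio.unsqueeze(1)
print(audio.shape)
```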
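The trimmed `ISTFTHead` treats its input as stacked log-magnitude and phase: `chunk(2, dim=1)` splits the `n_fft + 2` channels into two `n_fft // 2 + 1`-bin halves, the magnitude is exponentiated and clipped, `torch.polar` assembles the complex spectrogram, and the inverse STFT yields a `(B, 1, T)` waveform. A self-contained sketch of that path using `torch.istft` in place of the repo's `ISTFT` module; the clip ceiling is a placeholder, since the exact safeguard constant sits outside the hunk:

```python
import torch

MAG_CLIP = 1e2  # placeholder ceiling; the repo's safeguard constant is not shown in the hunk


def istft_head_sketch(x: torch.Tensor, n_fft: int = 32, hop_length: int = 8) -> torch.Tensor:
    """x: (B, n_fft + 2, frames), as produced by conv_post. Returns (B, 1, T)."""
    x = x.float()
    log_mag, phase = x.chunk(2, dim=1)        # two (B, n_fft // 2 + 1, frames) halves
    mag = torch.exp(log_mag).clip(max=MAG_CLIP)
    spec = torch.polar(mag, phase)            # complex spectrogram
    audio = torch.istft(
        spec,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=torch.hann_window(n_fft),
    )
    return audio.unsqueeze(1)


audio = istft_head_sketch(torch.randn(1, 32 + 2, 128))
print(audio.shape)  # torch.Size([1, 1, (128 - 1) * 8]) with center=True
```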