Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated FastSpeech 2 architecture #24

Open
wants to merge 42 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
330c040
pitch cwt, std, and mean added
carankt Nov 25, 2020
f43ffd2
update dataloader
carankt Nov 25, 2020
180c0a1
:ship: Add training f0 modification code
rishikksh20 Nov 27, 2020
eb55497
update dataloader
carankt Nov 27, 2020
136bb02
Updated training code for version 2
rishikksh20 Nov 27, 2020
58284ec
:books: Update code for training
rishikksh20 Nov 29, 2020
df45948
add avg and p per utterance
carankt Nov 30, 2020
3e9ad15
:star: Training code
rishikksh20 Nov 30, 2020
8908d22
:bug: Detach the hidden space to stop gradient flow
rishikksh20 Dec 1, 2020
cb9d19d
add icwt
carankt Dec 1, 2020
54e7bca
bug fix for inference
carankt Dec 1, 2020
ff22962
:books: Fully runnable with inference code
rishikksh20 Dec 1, 2020
752e516
:bug: Update evaluate code
rishikksh20 Dec 1, 2020
a9a87e9
pitch preprocessing update
carankt Dec 2, 2020
71334b1
Update pitch.py
carankt Dec 2, 2020
bbaf968
Update pitch.py
carankt Dec 2, 2020
6f2316a
Update .gitignore
carankt Dec 3, 2020
6d7f2b4
cwt and icwt same lib
carankt Dec 3, 2020
a1ecd20
hs detach
carankt Dec 3, 2020
4e1f46c
Update pitch.py
carankt Dec 3, 2020
451c9ea
Update pitch.py
carankt Dec 3, 2020
e3b3a0e
pitch cwt update
carankt Dec 4, 2020
06830b2
inference fix
carankt Dec 4, 2020
b23e965
:books: Code clean-up
rishikksh20 Dec 5, 2020
07d072d
:bugs: Remove bugs, code clean up and update/add requirement
rishikksh20 Dec 5, 2020
b39385c
add cwt bins parameter
carankt Dec 14, 2020
86d3304
Update pitch_mod.py
carankt Dec 14, 2020
25e519c
Update fastspeech.py
carankt Dec 14, 2020
f0b977d
plot 2 waveform to writer
carankt Dec 15, 2020
4b01394
take predicted pitch and energy during inference
carankt Dec 15, 2020
b03dfa0
add pitch and energy predicted and GT plots
carankt Dec 15, 2020
09a9789
predict log energy
carankt Dec 15, 2020
077d2fe
do exp in inference for energy
carankt Dec 15, 2020
43c8962
output predicted p and e in inference
carankt Dec 15, 2020
f59a2f1
Update inference.py
carankt Dec 16, 2020
3fde1cd
Merge branch 'version2' of https://github.com/rishikksh20/FastSpeech2…
carankt Dec 16, 2020
65cea6d
Update nvidia_preprocessing.py
carankt Dec 16, 2020
7cf53b2
Merge branch 'version2' of https://github.com/rishikksh20/FastSpeech2…
carankt Dec 16, 2020
f130bd8
get predicted p and e
carankt Dec 16, 2020
69165d9
:books: Add try catch on pre-processing part
rishikksh20 Dec 17, 2020
e369ebc
Merge branch 'version2' of https://github.com/rishikksh20/FastSpeech2…
carankt Dec 21, 2020
081906c
line wise inference
carankt Dec 21, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,27 @@ idea/*
/trace_loss_nvidia.txt
/conf
/etc
.ipynb_checkpoints/Untitled-checkpoint.ipynb
dataset/audio/__pycache__/__init__.cpython-36.pyc
*.pyc
Untitled.ipynb
mel.npy
*.png
*.npy
Testing/2log_v2/no_exp_before_bins_fs2v2_2_31k_test_tts.wav
Testing/exp_log/test_tts.wav
Testing/exp_log_v2/exp_before_bins_fs2v2_2_31k_test_tts.wav
mel.png
mel.npy
Testing/v2_2/test_tts.wav
*.npy
*.png
mel.png
*.wav
*.npy
.ipynb_checkpoints/pitch_cwt-checkpoint.ipynb
pitch_cwt.ipynb
*.wav
Testing/test_tts.wav
*.wav
Testing/test_tts.wav
Binary file added Testing/test_tts.wav
Binary file not shown.
15 changes: 8 additions & 7 deletions configs/default.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
data:
data_dir: 'H:\Deepsync\backup\fastspeech\data\'
wav_dir: 'H:\Deepsync\backup\deepsync\LJSpeech-1.1\wavs\'
data_dir: './data/LJSpeech/good_file/'
wav_dir: '/mnt/Karan/LJSpeech-1.1/wavs/'
# Compute statistics
e_mean: 21.578571319580078
e_std: 18.916799545288086
Expand All @@ -10,7 +10,7 @@ data:
f0_mean: 206.5135564772342
f0_std: 53.633228905750336
p_min: 71.0
p_max: 676.2260946528305 # 799.8901977539062
p_max: 500.0 # 799.8901977539062
train_filelist: "./filelists/train_filelist.txt"
valid_filelist: "./filelists/valid_filelist.txt"
tts_cleaner_names: ['english_cleaners']
Expand All @@ -30,6 +30,7 @@ audio:
bits : 9 # bit depth of signal
mu_law : True # Recommended to suppress noise if using raw bits in hp.voc_mode below
peak_norm : False # Normalise to the peak of each wav file
cwt_bins : 10



Expand All @@ -46,7 +47,7 @@ model:
aheads: 2
elayers: 4
eunits: 1024
ddim: 384
ddim: 256
dlayers: 4
dunits: 1024
positionwise_layer_type : "conv1d" # linear
Expand Down Expand Up @@ -110,7 +111,7 @@ train:
# optimization related
eos: False #True
opt: 'noam'
accum_grad: 4
accum_grad: 1
grad_clip: 1.0
weight_decay: 0.001
patience: 0
Expand All @@ -126,7 +127,7 @@ train:
seed: 1 # random seed number
resume: "" # the snapshot path to resume (if set empty, no effect)
use_phonemes: True
batch_size : 16
batch_size : 48
# other
melgan_vocoder : True
save_interval : 1000
Expand All @@ -135,4 +136,4 @@ train:
summary_interval : 200
validation_step : 500
tts_max_mel_len : 870 # if you have a couple of extremely long spectrograms you might want to use this
tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training
tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training
100 changes: 92 additions & 8 deletions core/variance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import torch.nn.functional as F
from typing import Optional
from core.modules import LayerNorm

#import pycwt
import numpy as np
from sklearn import preprocessing

class VariancePredictor(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -149,7 +151,11 @@ def inference(self, xs: torch.Tensor, alpha: float = 1.0):

"""
out = self.predictor.inference(xs, False, alpha=alpha)
return self.to_one_hot(out) # Need to do One hot code
#print(out.shape, type(out))
#out = torch.from_numpy(np.load("/results/chkpts/LJ/Fastspeech2_V2/data/energy/LJ001-0001.npy")).cuda()
#print(out, "Energy Pricted")
out = torch.exp(out)
return self.to_one_hot(out), out # Need to do One hot code

def to_one_hot(self, x):
# e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
Expand All @@ -171,6 +177,7 @@ def __init__(
min=0,
max=0,
n_bins=256,
out=5,
):
"""Initilize pitch predictor module.

Expand All @@ -195,9 +202,29 @@ def __init__(
)
),
)
self.predictor = VariancePredictor(idim)
self.offset = offset
self.conv = torch.nn.ModuleList()
for idx in range(n_layers):
in_chans = idim if idx == 0 else n_chans
self.conv += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chans,
n_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
torch.nn.ReLU(),
LayerNorm(n_chans),
torch.nn.Dropout(dropout_rate),
)
]
self.spectrogram_out = torch.nn.Linear(n_chans, out)
self.mean = torch.nn.Linear(n_chans, 1)
self.std = torch.nn.Linear(n_chans, 1)

def forward(self, xs: torch.Tensor, x_masks: torch.Tensor):
def forward(self, xs: torch.Tensor, olens: torch.Tensor, x_masks: torch.Tensor):
"""Calculate forward propagation.

Args:
Expand All @@ -208,9 +235,42 @@ def forward(self, xs: torch.Tensor, x_masks: torch.Tensor):
Tensor: Batch of predicted durations in log domain (B, Tmax).

"""
return self.predictor(xs, x_masks)
xs = xs.transpose(1, -1) # (B, idim, Tmax)
for f in self.conv:
xs = f(xs) # (B, C, Tmax)

def inference(self, xs: torch.Tensor, alpha: float = 1.0):
# NOTE: calculate in log domain
xs = xs.transpose(1, -1)
f0_spec = self.spectrogram_out(xs) # (B, Tmax, 10)

if x_masks is not None:
# print("olen:", olens)
#f0_spec = f0_spec.transpose(1, -1)
# print("F0 spec dimension:", f0_spec.shape)
# print("x_masks dimension:", x_masks.shape)
f0_spec = f0_spec.masked_fill(x_masks, 0.0)
#f0_spec = f0_spec.transpose(1, -1)
# print("F0 spec dimension:", f0_spec.shape)
#xs = xs.transpose(1, -1)
xs = xs.masked_fill(x_masks, 0.0)
#xs = xs.transpose(1, -1)
# print("xs dimension:", xs.shape)
x_avg = xs.sum(dim=1).squeeze(1)
# print(x_avg)
# print("xs dim :", x_avg.shape)
# print("olens ;", olens.shape)
if olens is not None:
x_avg = x_avg / olens.unsqueeze(1)
# print(x_avg)
f0_mean = self.mean(x_avg).squeeze(-1)
f0_std = self.std(x_avg).squeeze(-1)

# if x_masks is not None:
# f0_spec = f0_spec.masked_fill(x_masks, 0.0)

return f0_spec, f0_mean, f0_std

def inference(self, xs: torch.Tensor, olens = None, alpha: float = 1.0):
"""Inference duration.

Args:
Expand All @@ -221,8 +281,14 @@ def inference(self, xs: torch.Tensor, alpha: float = 1.0):
LongTensor: Batch of predicted durations in linear domain (B, Tmax).

"""
out = self.predictor.inference(xs, False, alpha=alpha)
return self.to_one_hot(out)
f0_spec, f0_mean, f0_std = self.forward(xs, olens, x_masks=None) # (B, Tmax, 10)
#print(f0_spec)
f0_reconstructed = self.inverse(f0_spec, f0_mean, f0_std)
#print(f0_reconstructed)
#f0_reconstructed = torch.from_numpy(np.load("/results/chkpts/LJ/Fastspeech2_V2/data/pitch/LJ001-0001.npy").reshape(1,-1)).cuda()
#print(f0_reconstructed, "Pitch coef output")

return self.to_one_hot(f0_reconstructed), f0_reconstructed

def to_one_hot(self, x: torch.Tensor):
# e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
Expand All @@ -231,6 +297,24 @@ def to_one_hot(self, x: torch.Tensor):
quantize = torch.bucketize(x, self.pitch_bins).to(device=x.device) # .cuda()
return F.one_hot(quantize.long(), 256).float()

def inverse(self, Wavelet_lf0, f0_mean, f0_std):
scales = np.array([0.01, 0.02, 0.04, 0.08, 0.16]) #np.arange(1,11)
#print(Wavelet_lf0.shape)
Wavelet_lf0 = Wavelet_lf0.squeeze(0).cpu().numpy()
lf0_rec = np.zeros([Wavelet_lf0.shape[0], len(scales)])
for i in range(0,len(scales)):
lf0_rec[:,i] = Wavelet_lf0[:,i]*((i+200+2.5)**(-2.5))

lf0_rec_sum = np.sum(lf0_rec,axis = 1)
lf0_rec_sum_norm = preprocessing.scale(lf0_rec_sum)

f0_reconstructed = (torch.Tensor(lf0_rec_sum_norm).cuda()*f0_std) + f0_mean

f0_reconstructed = torch.exp(f0_reconstructed)
#print(f0_reconstructed.shape)
#print(f0_reconstructed.shape)
return f0_reconstructed.reshape(1,-1)


class PitchPredictorLoss(torch.nn.Module):
"""Loss function module for duration predictor.
Expand Down
Empty file added dataset/audio/__init__.py
Empty file.
Loading