Skip to content

Commit

Permalink
Fixed quantization of wav files to avoid out-of-bounds sample values.
Browse files Browse the repository at this point in the history
  • Loading branch information
geneing committed May 10, 2019
1 parent d57ff16 commit 7b317c4
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 39 deletions.
75 changes: 44 additions & 31 deletions audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ def _stft(y):
else:
return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size, pad_mode='constant')

# def melspectrogram(y):
# D = _stft(preemphasis(y))
# S = _amp_to_db(_linear_to_mel(np.abs(D)**hparams.magnitude_power)) - hparams.ref_level_db
# if not hparams.allow_clipping_in_normalization:
# assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
# return _normalize(S)

def melspectrogram(y):
    """Compute a normalized mel spectrogram (in dB) from a waveform.

    Args:
        y: 1-D audio signal (presumably floats in [-1, 1] — confirm with callers).

    Returns:
        Mel spectrogram scaled by _normalize into the configured range.
    """
    D = _stft(preemphasis(y))
    if not hparams.allow_clipping_in_normalization:
        # Sanity-check the power-scaled magnitude spectrogram only when
        # clipping is disallowed.  Previously this array was computed on
        # every call and then discarded, wasting a full mel projection.
        S_check = _amp_to_db(_linear_to_mel(np.abs(D)**hparams.magnitude_power)) - hparams.ref_level_db
        assert S_check.max() <= 0 and S_check.min() - hparams.min_level_db >= 0
    # The returned spectrogram intentionally uses the raw magnitude (power 1).
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
    return _normalize(S)


def _lws_processor():
    """Build an LWS analysis/synthesis processor from hparams (speech mode)."""
    window, hop = hparams.win_size, hparams.hop_size
    return lws.lws(window, hop, mode="speech")

Expand Down Expand Up @@ -98,41 +102,50 @@ def _db_to_amp(x):
# return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
#

# def _normalize(S):
# if hparams.allow_clipping_in_normalization:
# if hparams.symmetric_mels:
# return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
# -hparams.max_abs_value, hparams.max_abs_value)
# else:
# return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
#
# assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
# if hparams.symmetric_mels:
# return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
# else:
# return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
#
# def _denormalize(D):
# if hparams.allow_clipping_in_normalization:
# if hparams.symmetric_mels:
# return (((np.clip(D, -hparams.max_abs_value,
# hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
# + hparams.min_level_db)
# else:
# return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
#
# if hparams.symmetric_mels:
# return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
# else:
# return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

def _normalize(S):
    """Map a dB-scale spectrogram into the normalized training range.

    With symmetric_mels the output lies in [-max_abs_value, max_abs_value],
    otherwise in [0, max_abs_value].  When clipping is disallowed, the input
    is asserted to already lie within [min_level_db, 0].
    """
    # Fractional position of S between min_level_db and 0 dB.
    scaled = (S - hparams.min_level_db) / (-hparams.min_level_db)

    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return np.clip((2 * hparams.max_abs_value) * scaled - hparams.max_abs_value,
                           -hparams.max_abs_value, hparams.max_abs_value)
        return np.clip(hparams.max_abs_value * scaled, 0, hparams.max_abs_value)

    assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
    if hparams.symmetric_mels:
        return (2 * hparams.max_abs_value) * scaled - hparams.max_abs_value
    return hparams.max_abs_value * scaled

def _denormalize(D):
    """Invert _normalize: map normalized values back to the dB scale."""
    clip = hparams.allow_clipping_in_normalization
    if hparams.symmetric_mels:
        # Symmetric range [-max_abs_value, max_abs_value] -> [min_level_db, 0].
        x = np.clip(D, -hparams.max_abs_value, hparams.max_abs_value) if clip else D
        return (x + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value) + hparams.min_level_db
    # Asymmetric range [0, max_abs_value] -> [min_level_db, 0].
    x = np.clip(D, 0, hparams.max_abs_value) if clip else D
    return x * -hparams.min_level_db / hparams.max_abs_value + hparams.min_level_db
# symmetric mels
return 2 * hparams.max_abs_value * ((S - hparams.min_level_db) / -hparams.min_level_db) - hparams.max_abs_value

def _denormalize(S):
    """Invert the symmetric-mel normalization back to the dB scale."""
    span = 2 * hparams.max_abs_value
    return ((S + hparams.max_abs_value) * -hparams.min_level_db) / span + hparams.min_level_db


# Fatcord's preprocessing
def quantize(x, bits=None):
    """Quantize a float waveform in [-1, 1] to 2**bits integer levels.

    Args:
        x: float array of audio samples; values outside [-1, 1] are clipped
           first so the result never falls outside the valid level range.
        bits: bit depth; defaults to hparams.bits for backward compatibility.

    Returns:
        Integer (int64) array with values in [0, 2**bits - 1].
    """
    if bits is None:
        bits = hparams.bits
    # Clip BEFORE scaling so out-of-range samples cannot produce
    # out-of-bounds quantization indices.
    x = np.clip(x, -1., 1.)
    quant = ((x + 1.) / 2.) * (2**bits - 1)
    # np.int was removed from NumPy (1.24+); int64 gives a stable, explicit
    # dtype across platforms.  (The previous unclipped `quant` computation
    # was dead code and has been removed.)
    return quant.astype(np.int64)


Expand Down
32 changes: 32 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,38 @@ def __getitem__(self, index):
def __len__(self):
return len(self.metadata)

class Tacotron2Dataset(Dataset):
    """Dataset of (mel, wav) pairs produced by a Tacotron2-style preprocessor.

    Expects a pipe-delimited `train.txt` (columns: wav filename | mel
    filename | ...) plus `audio/` and `mels/` subdirectories of .npy files
    under ``data_path``.  The wav is converted according to hp.input_type.
    """

    def __init__(self, data_path):
        self.metadata = []
        self.path = os.path.join(data_path, "")
        with open(os.path.join(self.path, 'train.txt'), 'r', newline='') as f:
            csvreader = csv.reader(f, delimiter='|')
            for row in csvreader:
                self.metadata.append(row)

        self.mel_path = os.path.join(data_path, "mels")
        self.wav_path = os.path.join(data_path, "audio")
        # NOTE(review): test_path currently mirrors mel_path — presumably test
        # samples live alongside the training mels; confirm before relying on it.
        self.test_path = os.path.join(data_path, "mels")

    def __getitem__(self, index):
        entry = self.metadata[index]
        # Mels are stored transposed relative to what the model consumes.
        m = np.load(os.path.join(self.mel_path, entry[1])).T
        wav = np.load(os.path.join(self.wav_path, entry[0]))

        if hp.input_type == 'raw' or hp.input_type == 'mixture':
            wav = wav.astype(np.float32)
        elif hp.input_type == 'mulaw':
            # np.int was removed from NumPy (1.24+); int64 keeps an explicit dtype.
            wav = mulaw_quantize(wav, hp.mulaw_quantize_channels).astype(np.int64)
        elif hp.input_type == 'bits':
            wav = quantize(wav).astype(np.int64)
        else:
            raise ValueError("hp.input_type {} not recognized".format(hp.input_type))
        return m, wav

    def __len__(self):
        return len(self.metadata)


class MozillaTTS(Dataset):
def __init__(self, data_path):
self.metadata=[]
Expand Down
9 changes: 5 additions & 4 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
preemphasize = False, #whether to apply filter
preemphasis = 0.97, #filter coefficient.

magnitude_power=1., #The power of the spectrogram magnitude (1. for energy, 2. for power)
magnitude_power=2., #The power of the spectrogram magnitude (1. for energy, 2. for power)

# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
Expand Down Expand Up @@ -80,7 +80,7 @@
save_every_step=10000,
evaluate_every_step=10000,
# seq_len_factor can be adjusted to increase training sequence length (will increase GPU usage)
seq_len_factor=5,
seq_len_factor=7,

grad_norm=10,
# learning rate parameters
Expand All @@ -92,9 +92,10 @@
lr_step_interval=15000,

# sparsification
start_prune=40000,
start_prune=80000,
prune_steps=80000, # 20000
sparsity_target=0.85,
sparsity_target=0.90,
sparsity_target_rnn=0.90,
sparse_group=4,

adam_beta1=0.9,
Expand Down
10 changes: 7 additions & 3 deletions synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@
if mel.shape[0] > mel.shape[1]: #ugly hack for transposed mels
mel = mel.T

flist = glob.glob(f'{checkpoint_dir}/checkpoint_*.pth')
latest_checkpoint = max(flist, key=os.path.getctime)
if checkpoint_path is None:
flist = glob.glob(f'{checkpoint_dir}/checkpoint_*.pth')
latest_checkpoint = max(flist, key=os.path.getctime)
else:
latest_checkpoint = checkpoint_path
print('Loading: %s'%latest_checkpoint)
# build model, create optimizer
model = build_model().to(device)
Expand All @@ -64,7 +67,7 @@
#print("rnn2: %.3f million"%(num_params_count(model.rnn2)))
print("fc1: %.3f million"%(num_params_count(model.fc1)))
#print("fc2: %.3f million"%(num_params_count(model.fc2)))
#print("fc3: %.3f million"%(num_params_count(model.fc3)))
print("fc3: %.3f million"%(num_params_count(model.fc3)))


#onnx export
Expand All @@ -81,6 +84,7 @@


mel0 = mel.copy()
mel0=np.hstack([np.ones([80,40])*(-4), mel0, np.ones([80,40])*(-4)])
start = time.time()
output0 = model.generate(mel0, batched=False, target=2000, overlap=64)
total_time = time.time() - start
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
raise ValueError("input_type:{} not supported".format(hp.input_type))

# Pruner for reducing memory footprint
layers = [(model.I,hp.sparsity_target), (model.rnn1,hp.sparsity_target), (model.fc1,hp.sparsity_target), (model.fc3,hp.sparsity_target)] #(model.fc2,hp.sparsity_target),
layers = [(model.I,hp.sparsity_target), (model.rnn1,hp.sparsity_target_rnn), (model.fc1,hp.sparsity_target), (model.fc3,hp.sparsity_target)] #(model.fc2,hp.sparsity_target),
pruner = Pruner(layers, hp.start_prune, hp.prune_steps, hp.sparsity_target)

global global_step, global_epoch, global_test_step
Expand Down Expand Up @@ -445,6 +445,8 @@ def test_prune(model):
except KeyboardInterrupt:
print("Interrupted!")
pass
except Exception as e:
print(e)
finally:
print("saving model....")
save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)
Expand Down

0 comments on commit 7b317c4

Please sign in to comment.