forked from as-ideas/ForwardTacotron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_forward.py
103 lines (89 loc) · 4.58 KB
/
train_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import itertools
import os
from pathlib import Path

import numpy as np
import torch
from torch import optim
from torch.utils.data.dataloader import DataLoader

from models.forward_tacotron import ForwardTacotron
from models.tacotron import Tacotron
from trainer.forward_trainer import ForwardTrainer
from utils import hparams as hp
from utils.checkpoints import restore_checkpoint
from utils.dataset import get_tts_datasets
from utils.display import *
from utils.paths import Paths
from utils.text.symbols import phonemes
def create_gta_features(model: ForwardTacotron,
                        train_set: DataLoader,
                        val_set: DataLoader,
                        save_path: Path):
    """Run ``model`` over the train and val batches and save its predicted mels as .npy files.

    Batches from ``train_set`` are processed first, followed by batches from
    ``val_set``. Each batch is unpacked as ``(x, mels, ids, mel_lens, dur)``.
    ``x``, ``mels`` and ``dur`` are fed to ``model(x, mels, dur)``, and the
    second of its three outputs is kept. For every item in the batch, that
    prediction's last axis is trimmed to ``mel_lens[j]``. The result is written
    to ``save_path / f'{item_id}.npy'``.

    Note: the model is switched to eval mode and is left in eval mode afterwards.

    Args:
        model: the forward TTS model. Its outputs are unpacked as a 3-tuple here.
            NOTE(review): confirm the model's forward signature/return arity
            matches this unpacking.
        train_set: loader yielding ``(x, mels, ids, mel_lens, dur)`` batches.
        val_set: loader yielding batches with the same layout as ``train_set``.
        save_path: directory the .npy files are written into. It is assumed to
            exist already.
    """
    model.eval()
    # Move inputs to whichever device the model's weights live on.
    device = next(model.parameters()).device
    total_batches = len(train_set) + len(val_set)
    batches = itertools.chain(train_set, val_set)
    # Inference only: no autograd graph is needed for any batch.
    with torch.no_grad():
        for i, (x, mels, ids, mel_lens, dur) in enumerate(batches, 1):
            x, mels, dur = x.to(device), mels.to(device), dur.to(device)
            _, gta, _ = model(x, mels, dur)
            gta = gta.cpu().numpy()
            for j, item_id in enumerate(ids):
                # Drop padded frames past this item's true length. Convert the
                # length explicitly because it may be a tensor element.
                mel_len = int(mel_lens[j])
                mel = gta[j][:, :mel_len]
                np.save(str(save_path / f'{item_id}.npy'), mel, allow_pickle=False)
            bar = progbar(i, total_batches)
            msg = f'{bar} {i}/{total_batches} Batches '
            stream(msg)
if __name__ == '__main__':
    # Command-line entry point: either train the forward TTS model or, with
    # --force_gta, dump ground-truth-aligned (GTA) mel features from it.

    # Parse arguments.
    parser = argparse.ArgumentParser(description='Train ForwardTacotron TTS')
    parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # Load hparams from file
    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    # Training requires precomputed alignment files. Use an explicit check
    # rather than `assert`, which is silently stripped under `python -O`.
    if not os.listdir(paths.alg):
        raise FileNotFoundError(f'Could not find alignment files in {paths.alg}, please predict '
                                f'alignments first with python train_tacotron.py --force_align!')

    force_gta = args.force_gta

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        # Every scheduled batch size must split evenly across the visible GPUs.
        n_gpus = torch.cuda.device_count()
        for _, _, batch_size in hp.forward_schedule:
            if batch_size % n_gpus != 0:
                raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    # Instantiate the forward TTS model from hyperparameters.
    print('\nInitialising Forward TTS Model...\n')
    model = ForwardTacotron(embed_dims=hp.forward_embed_dims,
                            num_chars=len(phonemes),
                            durpred_rnn_dims=hp.forward_durpred_rnn_dims,
                            durpred_conv_dims=hp.forward_durpred_conv_dims,
                            durpred_dropout=hp.forward_durpred_dropout,
                            pitch_rnn_dims=hp.forward_pitch_rnn_dims,
                            pitch_conv_dims=hp.forward_pitch_conv_dims,
                            pitch_dropout=hp.forward_pitch_dropout,
                            pitch_emb_dims=hp.forward_pitch_emb_dims,
                            pitch_proj_dropout=hp.forward_pitch_proj_dropout,
                            rnn_dim=hp.forward_rnn_dims,
                            postnet_k=hp.forward_postnet_K,
                            postnet_dims=hp.forward_postnet_dims,
                            prenet_k=hp.forward_prenet_K,
                            prenet_dims=hp.forward_prenet_dims,
                            highways=hp.forward_num_highways,
                            dropout=hp.forward_dropout,
                            n_mels=hp.num_mels).to(device)

    # Count trainable parameters only.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'num params {params}')

    optimizer = optim.Adam(model.parameters())
    # Load the latest 'forward' checkpoint into model/optimizer, creating one if absent.
    restore_checkpoint('forward', paths, model, optimizer, create_if_missing=True)

    if force_gta:
        print('Creating Ground Truth Aligned Dataset...\n')
        # Batch size 8 and r=1 are fixed for GTA feature extraction.
        train_set, val_set = get_tts_datasets(paths.data, 8, r=1, model_type='forward')
        create_gta_features(model, train_set, val_set, paths.gta)
        print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')
    else:
        trainer = ForwardTrainer(paths)
        trainer.train(model, optimizer)