forked from r9y9/deepvoice3_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
synthesis.py
168 lines (138 loc) · 6.06 KB
/
synthesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# coding: utf-8
"""
Synthesis waveform from trained model.
usage: synthesis.py [options] <checkpoint> <text_list_file> <dst_dir>
options:
--hparams=<parmas> Hyper parameters [default: ].
--preset=<json> Path of preset parameters (json).
--checkpoint-seq2seq=<path> Load seq2seq model from checkpoint path.
--checkpoint-postnet=<path> Load postnet model from checkpoint path.
--file-name-suffix=<s> File name suffix [default: ].
--max-decoder-steps=<N> Max decoder steps [default: 500].
--replace_pronunciation_prob=<N> Prob [default: 0.0].
--speaker_id=<id> Speaker ID (for multi-speaker model).
--output-html Output html for blog post.
-h, --help Show help message.
"""
from docopt import docopt
import sys
import os
from os.path import dirname, join, basename, splitext
import audio
import torch
import numpy as np
import nltk
# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams, hparams_debug_string
from tqdm import tqdm
# Use the GPU for inference when CUDA is available, otherwise the CPU.
# `use_cuda` is read by _load() and `device` by tts().
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Text frontend used by tts() (it calls `_frontend.text_to_sequence`).
# Stays None until the __main__ section assigns it from `hparams.frontend`.
_frontend = None  # to be set later
def tts(model, text, p=0, speaker_id=None, fast=False):
    """Convert text to a speech waveform with a deepvoice3 model.

    Args:
        model: Model to run; it is moved to the module-level ``device`` and
            switched to eval mode before decoding.
        text (str): Input text to be synthesized.
        p (float): Forwarded to the frontend's ``text_to_sequence`` as ``p``
            (replace words with their pronunciation if p > 0). Default is 0.
        speaker_id (int, optional): Speaker index passed to the model as a
            one-element LongTensor; ``None`` passes ``speaker_ids=None``.
        fast (bool): If True, call ``model.make_generation_fast_()`` first.

    Returns:
        tuple: ``(waveform, alignment, spectrogram, mel)`` where ``waveform``
        is ``audio.inv_spectrogram`` of the transposed linear output,
        ``alignment`` is the first alignment as a numpy array, and
        ``spectrogram`` / ``mel`` are the first linear / mel outputs after
        ``audio._denormalize``.
    """
    def _first_as_numpy(batch):
        # Take the first (only) item of the batch and hand it back as numpy.
        return batch[0].cpu().data.numpy()

    model = model.to(device)
    model.eval()
    if fast:
        model.make_generation_fast_()

    # Encode the text into a (1, T) LongTensor of symbol ids.
    token_ids = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = torch.from_numpy(token_ids).unsqueeze(0).long().to(device)

    # Positions are 1-based: 1..T.
    num_tokens = sequence.size(-1)
    text_positions = torch.arange(1, num_tokens + 1).unsqueeze(0).long().to(device)

    if speaker_id is None:
        speaker_ids = None
    else:
        speaker_ids = torch.LongTensor([speaker_id]).to(device)

    # Greedy decoding; no gradients are needed for inference.
    with torch.no_grad():
        mel_outputs, linear_outputs, alignments, _ = model(
            sequence, text_positions=text_positions, speaker_ids=speaker_ids)

    linear_output = _first_as_numpy(linear_outputs)
    alignment = _first_as_numpy(alignments)
    mel_output = _first_as_numpy(mel_outputs)

    spectrogram = audio._denormalize(linear_output)
    mel = audio._denormalize(mel_output)

    # Predicted audio signal, computed from the linear output before
    # denormalization (the un-denormalized array is what the original passes).
    waveform = audio.inv_spectrogram(linear_output.T)

    return waveform, alignment, spectrogram, mel
def _load(checkpoint_path):
    """Load a checkpoint file with ``torch.load`` and return the result.

    With CUDA available the checkpoint is loaded with default device
    placement. Without CUDA an identity ``map_location`` is supplied, so
    storages saved from a GPU are not moved back to one (per the
    ``torch.load`` documentation, returning the storage unchanged leaves it
    on the CPU).

    Args:
        checkpoint_path: Path handed straight to ``torch.load``.

    Returns:
        Whatever object the checkpoint file contains.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    if not use_cuda:
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
if __name__ == "__main__":
    # Command-line entry point: parse the docopt arguments, build the model,
    # load the checkpoint(s), then write one wav file and one alignment plot
    # per line of the text list into dst_dir.
    args = docopt(__doc__)
    print("Command line args:\n", args)
    checkpoint_path = args["<checkpoint>"]
    text_list_file_path = args["<text_list_file>"]
    dst_dir = args["<dst_dir>"]
    checkpoint_seq2seq_path = args["--checkpoint-seq2seq"]
    checkpoint_postnet_path = args["--checkpoint-postnet"]
    max_decoder_steps = int(args["--max-decoder-steps"])
    file_name_suffix = args["--file-name-suffix"]
    replace_pronunciation_prob = float(args["--replace_pronunciation_prob"])
    output_html = args["--output-html"]
    speaker_id = args["--speaker_id"]
    if speaker_id is not None:
        speaker_id = int(speaker_id)
    preset = args["--preset"]

    # Load preset if specified
    if preset is not None:
        with open(preset) as f:
            hparams.parse_json(f.read())
    # Override hyper parameters (command-line values win over the preset)
    hparams.parse(args["--hparams"])
    assert hparams.name == "deepvoice3"

    # The frontend must be set here and on the train module before
    # build_model() is imported and called.
    _frontend = getattr(frontend, hparams.frontend)
    import train
    train._frontend = _frontend
    from train import plot_alignment, build_model

    # Model
    model = build_model()

    # Load the seq2seq and postnet checkpoints separately only when both
    # paths are given; otherwise load the single combined checkpoint.
    if checkpoint_postnet_path is not None and checkpoint_seq2seq_path is not None:
        checkpoint = _load(checkpoint_seq2seq_path)
        model.seq2seq.load_state_dict(checkpoint["state_dict"])
        checkpoint = _load(checkpoint_postnet_path)
        model.postnet.load_state_dict(checkpoint["state_dict"])
        checkpoint_name = splitext(basename(checkpoint_seq2seq_path))[0]
    else:
        checkpoint = _load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        checkpoint_name = splitext(basename(checkpoint_path))[0]

    model.seq2seq.decoder.max_decoder_steps = max_decoder_steps

    os.makedirs(dst_dir, exist_ok=True)
    # Read everything up front so the file is closed before synthesis starts.
    with open(text_list_file_path, "rb") as f:
        lines = f.readlines()

    # Loop-invariant: base name of the text list, used in the HTML asset paths.
    name = splitext(basename(text_list_file_path))[0]

    for idx, line in enumerate(lines):
        # Strip only the line terminator. Slicing off the last character
        # unconditionally would eat a real character when the final line has
        # no trailing newline, and would leave a stray "\r" for CRLF files.
        text = line.decode("utf-8").rstrip("\r\n")
        words = nltk.word_tokenize(text)
        waveform, alignment, _, _ = tts(
            model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True)
        dst_wav_path = join(dst_dir, "{}_{}{}.wav".format(
            idx, checkpoint_name, file_name_suffix))
        dst_alignment_path = join(
            dst_dir, "{}_{}{}_alignment.png".format(idx, checkpoint_name,
                                                    file_name_suffix))
        plot_alignment(alignment.T, dst_alignment_path,
                       info="{}, {}".format(hparams.builder, basename(checkpoint_path)))
        audio.save_wav(waveform, dst_wav_path)
        if output_html:
            # HTML snippet for embedding the sample and its alignment plot.
            print("""
{}
({} chars, {} words)
<audio controls="controls" >
<source src="/audio/{}/{}/{}" autoplay/>
Your browser does not support the audio element.
</audio>
<div align="center"><img src="/audio/{}/{}/{}" /></div>
""".format(text, len(text), len(words),
           hparams.builder, name, basename(dst_wav_path),
           hparams.builder, name, basename(dst_alignment_path)))
        else:
            print(idx, ": {}\n ({} chars, {} words)".format(text, len(text), len(words)))

    print("Finished! Check out {} for generated audio samples.".format(dst_dir))
    sys.exit(0)