forked from KinglittleQ/GST-Tacotron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
80 lines (63 loc) · 2 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from utils import *
from Data import get_eval_data
from Hyperparameters import Hyperparameters as hp
import torch
from scipy.io.wavfile import write
from Network import *
from pypinyin import lazy_pinyin, Style
# Single device used for the model (load_model) and every input tensor
# (synthesis); this script always runs inference on the CPU.
device = torch.device("cpu")
def synthesis(model, eval_text, ref_wavs=None, speakers=None):
    """Synthesize ``eval_text`` once per reference wav.

    The text is converted with ``_pinyin`` and, for every reference wav,
    passed through ``model`` together with that wav's reference mels (as
    produced by ``get_eval_data``). The predicted magnitude spectrogram is
    turned into a waveform with ``spectrogram2wav``.

    Args:
        model: the network returned by ``load_model``; put into eval mode here.
        eval_text: input text (Chinese characters are romanized to pinyin).
        ref_wavs: optional iterable of reference wav paths. Defaults to
            ``ref_wav/nannan.wav``, ``ref_wav/xiaofeng.wav`` and
            ``ref_wav/donaldduck.wav``.
        speakers: optional names used as keys of the result, one per entry of
            ``ref_wavs``. Defaults to each wav's file name without extension
            (for the default wavs: ``nannan``, ``xiaofeng``, ``donaldduck``).

    Returns:
        dict mapping speaker name -> waveform returned by ``spectrogram2wav``.

    Raises:
        ValueError: if ``ref_wavs`` and ``speakers`` have different lengths.
    """
    import os  # local: only used to derive default speaker names

    eval_text = _pinyin(eval_text)
    model.eval()

    if ref_wavs is None:
        ref_wavs = [
            'ref_wav/nannan.wav',
            'ref_wav/xiaofeng.wav',
            'ref_wav/donaldduck.wav',
        ]
    ref_wavs = list(ref_wavs)
    if speakers is None:
        speakers = [os.path.splitext(os.path.basename(p))[0] for p in ref_wavs]
    speakers = list(speakers)
    # zip() would silently drop the surplus entries; fail loudly instead.
    if len(ref_wavs) != len(speakers):
        raise ValueError(
            'ref_wavs and speakers must have the same length '
            '({} != {})'.format(len(ref_wavs), len(speakers)))

    wavs = {}
    # Pure inference: don't build an autograd graph for every speaker.
    with torch.no_grad():
        for ref_wav, speaker in zip(ref_wavs, speakers):
            text, GO, ref_mels = get_eval_data(eval_text, ref_wav)
            # Only the magnitude output is needed to produce audio.
            _, mag_hat, _ = model(
                text.to(device), GO.to(device), ref_mels.to(device))
            mag_hat = mag_hat.squeeze().detach().cpu().numpy()
            wavs[speaker] = spectrogram2wav(mag_hat)
    return wavs
def load_model(checkpoint_path):
    """Construct a ``Tacotron`` on ``device`` and restore its weights.

    Args:
        checkpoint_path: path to a file written by ``torch.save`` that holds
            the model's ``state_dict``.

    Returns:
        The ``Tacotron`` instance with the checkpoint weights loaded.
    """
    net = Tacotron().to(device)

    def _keep_storage(storage, location):
        # Return each storage untouched, i.e. keep it on the CPU where it was
        # deserialized, whatever device it was saved from.
        return storage

    state = torch.load(checkpoint_path, map_location=_keep_storage)
    net.load_state_dict(state)
    return net
def _pinyin(s):
    """Romanize ``s`` into a space-separated pinyin string.

    Every token from ``lazy_pinyin(s, style=Style.TONE2)`` is lower-cased and
    then reduced to the characters in ``symbols`` (ASCII digits, lower-case
    ASCII letters and the space). The surviving tokens are joined with single
    spaces.

    Args:
        s: input text.

    Returns:
        The filtered pinyin string: runs of spaces collapsed to one space and
        no leading or trailing space.
    """
    symbols = set('0123456789abcdefghijklmnopqrstuvwxyz ')
    kept = []
    for token in lazy_pinyin(s, style=Style.TONE2):
        # Lower-case first so Latin letters in mixed input survive the filter
        # instead of being silently dropped.
        kept.append(''.join(c for c in token.lower() if c in symbols))
    # Punctuation-only tokens filter down to '' and tokens may carry their own
    # spaces; split()/join collapses every run of spaces and trims both ends
    # (the old char loop left a stray leading/trailing space).
    return ' '.join(' '.join(kept).split())
if __name__ == '__main__':
    import os

    # Synthesize one sample sentence with every default reference speaker and
    # write one wav per speaker into samples/.
    text = '''毛主席是中国的红太阳'''
    model = load_model('checkpoint/epoch100.pt')
    wavs = synthesis(model, text)
    # Writing into a missing directory raises FileNotFoundError; create it.
    os.makedirs('samples', exist_ok=True)
    for speaker, wav in wavs.items():
        write('samples/{}.wav'.format(speaker), hp.sr, wav)