-
Notifications
You must be signed in to change notification settings - Fork 20
/
tacotron2torch.cpp
76 lines (42 loc) · 1.66 KB
/
tacotron2torch.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include "tacotron2torch.h"
// Default constructor. Intentionally empty: the TorchScript module is not
// loaded here — callers must invoke Initialize() before DoInference().
Tacotron2Torch::Tacotron2Torch()
{
}
bool Tacotron2Torch::Initialize(const std::string &SavedModelFolder, ETTSRepo::Enum InTTSRepo)
{
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
Model = torch::jit::load(SavedModelFolder);
}
catch (const c10::Error& e) {
return false;
}
CurrentRepo = InTTSRepo;
return true;
}
/// @brief Run Tacotron2 inference over a phoneme/token ID sequence.
/// @param InputIDs Token IDs for the input text (int32, widened to int64 below).
/// @param ArgsFloat Unused by this backend; kept for interface compatibility.
/// @param ArgsInt   Unused by this backend; kept for interface compatibility.
/// @param SpeakerID Speaker index; -1 maps to speaker 0 (model requires an ID).
/// @param EmotionID Unused by this backend; kept for interface compatibility.
/// @return Post-net mel spectrogram tensor; the attention alignment is stored
///         in the Attention member as a side effect.
TFTensor<float> Tacotron2Torch::DoInference(const std::vector<int32_t> &InputIDs, const std::vector<float> &ArgsFloat, const std::vector<int32_t> ArgsInt, int32_t SpeakerID, int32_t EmotionID)
{
    // without this memory consumption is 4x
    torch::NoGradGuard no_grad;

    // Widen IDs to int64 (LibTorch's index dtype) via range construction
    // instead of a manual reserve/push_back loop — same values, same order.
    std::vector<int64_t> IInputIDs(InputIDs.begin(), InputIDs.end());

    torch::TensorOptions Opts = torch::TensorOptions().requires_grad(false);

    // This Tacotron2 always takes in speaker IDs
    if (SpeakerID == -1)
        SpeakerID = 0;

    auto InSpkid = torch::tensor({SpeakerID}, Opts);
    auto InIDS = torch::tensor(IInputIDs, Opts).unsqueeze(0); // add batch dim

    std::vector<torch::jit::IValue> inputs{InSpkid, InIDS};

    // Infer
    c10::IValue Output = Model(inputs);

    // Output = list (mel_outputs, mel_outputs_postnet, gate_outputs, alignments)
    auto OutputL = Output.toList();
    auto MelTens = OutputL[1].get().toTensor();
    auto AttTens = OutputL[3].get().toTensor();//.transpose(1,2); // [1, dec_t, enc_t ] -> [1, enc_t, dec_t]

    // Stash the alignment for callers, return the post-net mel.
    Attention = VoxUtil::CopyTensor<float>(AttTens);
    return VoxUtil::CopyTensor<float>(MelTens);
}