recognize.py
import argparse
import itertools as it

import soundfile as sf
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.models import BaseFairseqModel
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import CriterionType
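# NOTE: this script assumes a fairseq checkout with the wav2letter (flashlight)
# Python bindings installed. The W2lDecoder / W2lViterbiDecoder classes below
# mirror fairseq's examples/speech_recognition/w2l_decoder.py, so that module
# is not imported here.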
parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
parser.add_argument('--wav_path', type=str,
                    default='~/xxx.wav',
                    help='path of wave file')
parser.add_argument('--w2v_path', type=str,
                    default='~/wav2vec2_vox_960h.pt',
                    help='path of pre-trained wav2vec-2.0 model')
parser.add_argument('--target_dict_path', type=str,
                    default='dict.ltr.txt',
                    help='path of target dict (dict.ltr.txt)')
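# Example invocation (the defaults above are placeholders; point them at your
# own wav file, checkpoint, and letter dictionary):
#   python recognize.py --wav_path ~/audio.wav \
#       --w2v_path ~/wav2vec2_vox_960h.pt --target_dict_path dict.ltr.txt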
class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)
        w2v_encoder = Wav2VecEncoder(args, target_dict)
        return cls(w2v_encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x


class W2lDecoder(object):
    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1
        self.criterion_type = CriterionType.CTC
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.asg_transitions = None

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = list()
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]
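# decode() returns, for each batch element, a list of n-best hypotheses (nbest=1
# here); each hypothesis is a dict with the blank-collapsed "tokens" and a
# placeholder "score" of 0.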
def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == 'wordpiece':
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == 'letter':
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != 'none':
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence
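# With symbol='letter', a dictionary string such as "H E L L O | W O R L D |"
# becomes "HELLO WORLD": spaces are dropped and "|" word boundaries turn into spaces.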
def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            # multi-channel input: average the channels to get a mono waveform
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    return feats


def load_model(model_path, target_dict):
    w2v = torch.load(model_path)
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    return [model]


def main():
    args = parser.parse_args()
    sample = dict()
    net_input = dict()
    feature = get_feature(args.wav_path)
    target_dict = Dictionary.load(args.target_dict_path)
    model = load_model(args.w2v_path, target_dict)
    model[0].eval()
    generator = W2lViterbiDecoder(target_dict)
    net_input["source"] = feature.unsqueeze(0)
    padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
    net_input["padding_mask"] = padding_mask
    sample["net_input"] = net_input
    with torch.no_grad():
        hypo = generator.generate(model, sample, prefix_tokens=None)
    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    print(post_process(hyp_pieces, 'letter'))


if __name__ == '__main__':
    main()