-
Notifications
You must be signed in to change notification settings - Fork 31
/
ctc_decoder.py
112 lines (90 loc) · 3.81 KB
/
ctc_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details checkout either of these references:
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873
"""
import numpy as np
import math
import collections
NEG_INF = -float("inf")
def make_new_beam():
fn = lambda : (NEG_INF, NEG_INF)
return collections.defaultdict(fn)
def logsumexp(*args):
"""
Stable log sum exp.
"""
if all(a == NEG_INF for a in args):
return NEG_INF
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max)
for a in args))
return a_max + lsp
def decode(probs, beam_size=10, blank=0):
"""
Performs inference for the given output probabilities.
Arguments:
probs: The output probabilities (e.g. log post-softmax) for each
time step. Should be an array of shape (time x output dim).
beam_size (int): Size of the beam to use during inference.
blank (int): Index of the CTC blank label.
Returns the output label sequence and the corresponding negative
log-likelihood estimated by the decoder.
"""
T, S = probs.shape
# Elements in the beam are (prefix, (p_blank, p_no_blank))
# Initialize the beam with the empty sequence, a probability of
# 1 for ending in blank and zero for ending in non-blank
# (in log space).
beam = [(tuple(), (0.0, NEG_INF))]
for t in range(T): # Loop over time
# A default dictionary to store the next step candidates.
next_beam = make_new_beam()
for s in range(S): # Loop over vocab
p = probs[t, s]
# The variables p_b and p_nb are respectively the
# probabilities for the prefix given that it ends in a
# blank and does not end in a blank at this time step.
for prefix, (p_b, p_nb) in beam: # Loop over beam
# If we propose a blank the prefix doesn't change.
# Only the probability of ending in blank gets updated.
if s == blank:
n_p_b, n_p_nb = next_beam[prefix]
n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
continue
# Extend the prefix by the new character s and add it to
# the beam. Only the probability of not ending in blank
# gets updated.
end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,)
n_p_b, n_p_nb = next_beam[n_prefix]
if s != end_t:
n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
else:
# We don't include the previous probability of not ending
# in blank (p_nb) if s is repeated at the end. The CTC
# algorithm merges characters not separated by a blank.
n_p_nb = logsumexp(n_p_nb, p_b + p)
# *NB* this would be a good place to include an LM score.
next_beam[n_prefix] = (n_p_b, n_p_nb)
# If s is repeated at the end we also update the unchanged
# prefix. This is the merging case.
if s == end_t:
n_p_b, n_p_nb = next_beam[prefix]
n_p_nb = logsumexp(n_p_nb, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
# Sort and trim the beam before moving on to the
# next time-step.
beam = sorted(next_beam.items(),
key=lambda x : logsumexp(*x[1]),
reverse=True)
beam = beam[:beam_size]
best = beam[0]
return best[0], -logsumexp(*best[1])