# ctcgreedydecoder.py -- 63 lines (47 loc), 1.31 KB.
# (GitHub page residue -- notification banner, "Copy path" control and the
#  1..63 line-number gutter -- was removed from this position.)
# Copyright huaqiaoz
import numpy as np
import math,sys
# Recognizable characters, listed in model class-index order.
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
# Padding symbol: the decoder drops it and uses it as a separator between
# repeated characters. It occupies the last class index.
kPadSymbol = '#'
kAlphabet = "".join((alphabet, kPadSymbol))

# Leftover module-level placeholders from the C++ out-parameter style; they are
# not read by the decoding logic but remain importable.
conf = 1.0
argmax = prob = cache = None
def softmax(begin, end, argmax=None, prob=None, data=None):
    """Pick the top-scoring class in ``data[begin:end]`` and its softmax probability.

    This mirrors a C++ "softmax and choose" helper that reported its results
    through pointer out-parameters. Python cannot do that, so the results are
    *returned*; ``argmax`` and ``prob`` are accepted only for backward
    compatibility with the old signature and are ignored.

    Args:
        begin: Start index (inclusive) of the score window within ``data``.
        end: Stop index (exclusive) of the score window within ``data``.
        argmax: Ignored (legacy out-parameter placeholder).
        prob: Ignored (legacy out-parameter placeholder).
        data: Raw (pre-softmax) scores. It is flattened before slicing, so a
            flat list/array or a (frames, classes) array both work.

    Returns:
        ``(index, probability)``. ``index`` is the position of the largest score
        relative to ``begin`` (the first one on ties). ``probability`` is that
        score's softmax value, ``1 / sum(exp(x - max))``.

    Raises:
        TypeError: if ``data`` is not supplied.
        ValueError: if the window is empty, or the scores yield a non-finite
            normalizer (NaN or infinite scores).
    """
    if data is None:
        raise TypeError("softmax() requires the score sequence via data=")
    window = np.asarray(data, dtype=np.float64).ravel()[begin:end]
    if window.size == 0:
        raise ValueError("empty score range [{}, {})".format(begin, end))

    best = int(np.argmax(window))
    # Shift by the max before exponentiating so exp() cannot overflow. The max
    # element then contributes exp(0) == 1, so a finite total is always >= 1.
    total = float(np.sum(np.exp(window - window[best])))
    if not np.isfinite(total):
        raise ValueError("softmax normalizer is not finite (NaN/inf scores)")
    return best, 1.0 / total
def CTCGreedyDecoder(data, alphabet, pad_symbol, conf=None, return_conf=False):
    """Greedy (best-path) CTC decoding of per-frame class scores.

    For every frame the highest-scoring class is taken. Pad symbols are
    dropped, and a symbol is emitted only if it differs from the previously
    emitted one or a pad separated the two. For example, frames
    ``a a # a b b`` decode to ``"aab"``.

    Args:
        data: Raw (pre-softmax) scores in frame-major order, either flat with
            length ``frames * len(alphabet)`` or shaped ``(frames, len(alphabet))``.
        alphabet: Maps class index -> output symbol.
        pad_symbol: The alphabet entry treated as padding.
        conf: Ignored on input. It was a C++ out-parameter and is kept only so
            existing call sites remain valid; use ``return_conf`` instead.
        return_conf: If true, also return the confidence: the product of every
            frame's top softmax probability (1.0 for zero frames).

    Returns:
        The decoded string, or ``(string, confidence)`` when ``return_conf``.

    Raises:
        ValueError: if ``alphabet`` is empty, the shape/length of ``data`` does
            not match ``len(alphabet)``, or the scores are not finite.
    """
    num_classes = len(alphabet)
    if num_classes == 0:
        raise ValueError("alphabet must not be empty")

    scores = np.asarray(data, dtype=np.float64)
    if scores.ndim > 1 and scores.shape[-1] != num_classes:
        raise ValueError(
            "last dimension of data ({}) must equal len(alphabet) ({})".format(
                scores.shape[-1], num_classes))
    if scores.size % num_classes:
        raise ValueError(
            "size of data ({}) is not a multiple of len(alphabet) ({})".format(
                scores.size, num_classes))
    # One row per frame (the C++ original stepped through a flat buffer
    # num_classes entries at a time).
    scores = scores.reshape(-1, num_classes)

    best = scores.argmax(axis=1)
    # Top softmax probability per frame is 1 / sum(exp(x - max)); shifting by
    # the row max keeps exp() from overflowing.
    norm = np.exp(scores - scores.max(axis=1, keepdims=True)).sum(axis=1)
    if not np.all(np.isfinite(norm)):
        raise ValueError("scores must be finite (softmax normalizer is NaN/inf)")
    conf = float(np.prod(1.0 / norm))

    pieces = []
    prev_pad = False
    for class_index in best.tolist():
        symbol = alphabet[class_index]
        if symbol == pad_symbol:
            prev_pad = True
        elif prev_pad or not pieces or symbol != pieces[-1]:
            # New symbol, or a repeat that a pad separated from the last one.
            pieces.append(symbol)
            prev_pad = False
    res = "".join(pieces)
    return (res, conf) if return_conf else res