This repository has been archived by the owner on Feb 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy path: api-server.py
155 lines (123 loc) · 4.12 KB
/
api-server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from concurrent import futures
import time
import math
import logging
from pathlib import Path
import itertools as it
import grpc
import interfaces.libreasr_pb2 as ap
import interfaces.libreasr_pb2_grpc as apg
from libreasr.lib.inference import *
from libreasr.lib.utils import tensorize
# Size of the gRPC thread pool handling concurrent RPCs.
WORKERS = 4
# gRPC bind address per served language (IPv6 any-interface wildcard).
PORTS = {
    "en": "[::]:50051",
    "de": "[::]:50052",
    "fr": "[::]:50053",
}
# streaming
# reset threshold: compared against the scaled step count in should_reset();
# the unit is the value of int(10.0 * downsample * n_buffer * steps) —
# presumably milliseconds of audio, TODO confirm against the transforms.
THRESH = 4000
# Number of incoming audio frames concatenated into one window before
# the streaming transform is applied (see TranscribeStream).
BUFFER_N_FRAMES = 3
def log_print(*args, **kwargs):
    """Print *args* prefixed with an ``[api-server]`` tag.

    All keyword arguments are forwarded to :func:`print` unchanged
    (``file=``, ``end=``, etc.), so this is a drop-in tagged logger.
    """
    tagged = ("[api-server]",) + args
    print(*tagged, **kwargs)
def get_settings(conf):
    """Pull streaming parameters out of a model configuration dict.

    Scans ``conf["transforms"]["stream"]`` for the ``StackDownsample``
    and ``Buffer`` transforms and returns ``(downsample, n_buffer)``.
    Either slot is ``None`` when the corresponding transform is absent.
    """
    downsample, n_buffer = None, None
    for transform in conf["transforms"]["stream"]:
        name = transform["name"]
        if name == "StackDownsample":
            downsample = transform["args"]["downsample"]
        elif name == "Buffer":
            n_buffer = transform["args"]["n_buffer"]
    return downsample, n_buffer
def should_reset(steps, downsample, n_buffer):
    """Return ``True`` (and log) once the stream warrants a decoder reset.

    The raw step count is scaled by the per-step duration implied by the
    downsampling factor and buffer size (factor ``10.0`` — presumably
    milliseconds per frame; TODO confirm) and compared against the
    module-level ``THRESH``.
    """
    # one step length
    elapsed = int(10.0 * downsample * n_buffer * steps)
    if elapsed < THRESH:
        return False
    log_print("reset")
    return True
class ASRServicer(apg.ASRServicer):
    """gRPC servicer implementing one-shot and streaming ASR for one language.

    One instance serves a single language: the model, vocabulary and
    transform pipelines are loaded once in ``__init__`` via ``load_stuff``.
    """

    def __init__(self, lang):
        # lang: language code (e.g. "en") — also the key into PORTS.
        self.lang_name = lang
        # NOTE(review): load_stuff is project code; assumed to return
        # (config dict, language/vocab object, model, one-shot transform,
        # streaming transform) — confirm against libreasr.lib.inference.
        conf, lang, m, x_tfm, x_tfm_stream = load_stuff(lang)
        self.conf = conf
        # Streaming parameters extracted from the config (either may be None
        # if the corresponding transform is missing).
        self.downsample, self.n_buffer = get_settings(conf)
        self.lang = lang
        self.model = m
        self.x_tfm = x_tfm
        self.x_tfm_stream = x_tfm_stream

    def Transcribe(self, request, context):
        """One-shot transcription RPC: audio bytes in, single Transcript out.

        ``request.data`` is the raw audio payload, ``request.sr`` its
        sample rate; returns ``ap.Transcript`` with the decoded text.
        """
        # tensorize
        aud, sr = request.data, request.sr
        aud = tensorize(aud)
        # print
        log_print(f"Transcribe(lang={self.lang_name}, sr={sr}, shape={aud.shape})")
        # tfms: wrap with sample rate, then apply the one-shot transform
        # pipeline (x_tfm returns a sequence; the first element is the input).
        aud = AudioTensor(aud, sr)
        aud = self.x_tfm(aud)[0]
        # inference
        out = self.model.transcribe(aud)
        # out[0] is assumed to be the transcript string — TODO confirm.
        return ap.Transcript(data=out[0])

    def TranscribeStream(self, request_iterator, context):
        """Bidirectional streaming RPC: yields incremental transcript diffs.

        Buffers incoming frames into overlapping windows of
        ``BUFFER_N_FRAMES``, feeds them through the streaming transform
        into the model, and emits only the characters that changed since
        the previous hypothesis (suppressing immediate duplicates).
        """

        def stream():
            # Generator adapting the gRPC request iterator into transformed
            # audio windows for model.transcribe_stream.
            started = False
            frames = []
            counter = 0
            printed = False
            for i, frame in enumerate(request_iterator):
                # fill up frames
                t = tensorize(frame.data)
                frames.append(t)
                counter += 1
                # may continue? (wait until a full window is buffered)
                if not len(frames) == BUFFER_N_FRAMES:
                    continue
                # cat all frames along dim=1 — assumes (channel, time)
                # layout from tensorize; TODO confirm.
                aud = torch.cat(frames, dim=1)
                # clear first (slide the window by one frame)
                del frames[0]
                # convert to AudioTensor
                aud = AudioTensor(aud, frame.sr)
                # print (only once, on the first full window)
                if not printed:
                    log_print(
                        f"TranscribeStream(lang={self.lang_name}, sr={frame.sr}, shape={aud.shape})"
                    )
                    printed = True
                aud = self.x_tfm_stream(aud)
                yield aud

        # inference: each output is (token ids, latest token, reset callback).
        outputs = self.model.transcribe_stream(stream(), self.lang.denumericalize)
        last = ""
        last_diff = ""
        steps = 0
        for i, (y, y_one, reset_fn) in enumerate(outputs):
            steps += 1
            if y_one != "":
                # New token emitted: compute the character-level diff between
                # the previous and current full hypothesis.
                now = self.lang.denumericalize(y)
                diff = "".join(y for x, y in it.zip_longest(last, now) if x != y)
                last = now
                # bail if we just output the same thing twice
                if diff == last_diff:
                    continue
                last_diff = diff
                yield ap.Transcript(data=diff)
            elif should_reset(steps, self.downsample, self.n_buffer):
                # No new token and the stream has run long enough:
                # reset the decoder state and restart the step counter.
                reset_fn()
                steps = 0
def serve(lang):
    """Start the gRPC ASR server for *lang* and block until termination.

    Looks up the bind address in ``PORTS`` (raises ``KeyError`` for an
    unknown language), registers an ``ASRServicer``, and serves forever.
    """
    addr = PORTS[lang]
    pool = futures.ThreadPoolExecutor(max_workers=WORKERS)
    server = grpc.server(pool)
    apg.add_ASRServicer_to_server(ASRServicer(lang), server)
    server.add_insecure_port(addr)
    server.start()
    log_print("gRPC server running on", addr, "language", lang)
    server.wait_for_termination()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("lang", help="language to serve")
args = parser.parse_args()
logging.basicConfig()
serve(args.lang)