forked from alumae/kiirkirjutaja
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
120 lines (94 loc) · 4.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import sys
import argparse
import logging
message_format = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=message_format, stream=sys.stderr, level=logging.INFO)
import argparse
import re
import ray
import torch
# Needed for loading the speaker change detection model
from pytorch_lightning.utilities import argparse_utils
setattr(argparse_utils, "_gpus_arg_default", lambda x: 0)
from vad import SpeechSegmentGenerator
from turn import TurnGenerator
from asr import TurnDecoder
from lid import LanguageFilter
from online_scd.model import SCDModel
import vosk
from unk_decoder import UnkDecoder
from compound import CompoundReconstructor
from words2numbers import Words2Numbers
from punctuate import Punctuate
from confidence import confidence_filter
from presenters import *
import utils
import gc
import tracemalloc
#date_strftime_format = "%y-%b-%d %H:%M:%S"
# Module-level startup: this runs at import time, before main() is called.
# Start the Ray runtime, limited to 4 CPUs. The two heavier text
# post-processors below are hosted as Ray actors.
ray.init(num_cpus=4)
# Wrap the post-processor classes so that instances can be created as Ray
# actors via `.remote(...)`.
RemotePunctuate = ray.remote(Punctuate)
RemoteWords2Numbers = ray.remote(Words2Numbers)
# Post-processors that process_result() calls directly in this (driver) process.
unk_decoder = UnkDecoder()
compound_reconstructor = CompoundReconstructor()
# Actor handles. process_result() invokes them with
# `.post_process.remote(text)` and waits on the result with ray.get().
remote_words2numbers = RemoteWords2Numbers.remote()
# The punctuation actor is constructed from a checkpoint and a tokenizer file.
# NOTE(review): these are relative paths, presumably resolved against the
# current working directory — confirm the process is started from the repo root.
remote_punctuate = RemotePunctuate.remote("models/punctuator/checkpoints/best.ckpt", "models/punctuator/tokenizer.json")
def process_result(result):
    """Post-process a single decoder result dict.

    The result is always passed through the unknown-word decoder first. If
    the decoded dict has no ``"result"`` entry, it is returned as-is.
    Otherwise the words are joined into one string and run through compound
    reconstruction, then words-to-numbers and punctuation (both Ray actors,
    awaited one after the other). The final string is mapped back onto the
    word-level result, and the confidence filter is applied.

    Args:
        result: a decoder result dict. It is assumed that ``result["result"]``,
            when present, is a list of dicts each having a ``"word"`` key.

    Returns:
        The post-processed result dict.
    """
    decoded = unk_decoder.post_process(result)
    if "result" not in decoded:
        return decoded

    words = [word_info["word"] for word_info in decoded["result"]]
    sentence = compound_reconstructor.post_process(" ".join(words))
    # Each actor call is submitted and then awaited before the next stage.
    sentence = ray.get(remote_words2numbers.post_process.remote(sentence))
    sentence = ray.get(remote_punctuate.post_process.remote(sentence))

    rebuilt = utils.reconstruct_full_result(decoded, sentence)
    return confidence_filter(rebuilt)
def _create_presenter(args):
    """Return the result presenter selected by the command-line options.

    The URL options are checked in a fixed order (YouTube, FAB speech
    interface, FAB broadcast, Zoom) and the first one that is set wins.
    When none is set, a WordByWordPresenter on ``args.word_output_file`` is used.
    """
    if args.youtube_caption_url is not None:
        return YoutubeLivePresenter(captions_url=args.youtube_caption_url)
    if args.fab_speechinterface_url is not None:
        return FabLiveWordByWordPresenter(fab_speech_iterface_url=args.fab_speechinterface_url)
    if args.fab_bcast_url is not None:
        return FabBcastWordByWordPresenter(fab_bcast_url=args.fab_bcast_url)
    if args.zoom_caption_url is not None:
        return ZoomPresenter(captions_url=args.zoom_caption_url)
    return WordByWordPresenter(args.word_output_file)


def main(args):
    """Run the streaming recognition pipeline over ``args.input_file``.

    1. SpeechSegmentGenerator produces speech segments.
    2. TurnGenerator, using the speaker-change-detection model, splits each
       segment into turns.
    3. Each turn's audio chunks pass through LanguageFilter and are decoded
       with the Vosk model.
    4. Every decoder result that contains a ``"result"`` entry goes through
       ``process_result``. It is then handed to the presenter as a final or
       partial hypothesis, depending on ``res["final"]``.

    The presenter is notified at segment start and end, and at every turn
    boundary after the first turn of a segment.

    Args:
        args: parsed command-line namespace. It must provide ``input_file``,
            ``word_output_file`` and the four caption/stream URL options.
    """
    presenter = _create_presenter(args)

    # Model paths are relative, so they are presumably resolved against the
    # current working directory.
    scd_model = SCDModel.load_from_checkpoint("models/online-speaker-change-detector/checkpoints/epoch=102.ckpt")
    vosk_model = vosk.Model("models/asr_model")

    speech_segment_generator = SpeechSegmentGenerator(args.input_file)
    language_filter = LanguageFilter()

    for speech_segment in speech_segment_generator.speech_segments():
        presenter.segment_start()
        turn_generator = TurnGenerator(scd_model, speech_segment)
        for i, turn in enumerate(turn_generator.turns()):
            # The first turn of a segment is opened by segment_start().
            if i > 0:
                presenter.new_turn()
            turn_decoder = TurnDecoder(vosk_model, language_filter.filter(turn.chunks()))
            for res in turn_decoder.decode_results():
                if "result" not in res:
                    continue
                processed_res = process_result(res)
                if res["final"]:
                    presenter.final_result(processed_res["result"])
                else:
                    presenter.partial_result(processed_res["result"])
        presenter.segment_end()
        # Explicit collection after every segment, presumably to keep memory
        # bounded on long-running streams.
        gc.collect()
def build_arg_parser():
    """Build the command-line parser for this script.

    At most one caption/stream URL option is expected. If several are given,
    main() uses the first one in the order they are declared here. Without
    any of them, recognized words go to ``--word-output-file`` (stdout by
    default).

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Run streaming speech recognition on an audio input and present the recognized words.")
    parser.add_argument('--youtube-caption-url', type=str,
                        help="present results with YoutubeLivePresenter using this captions URL")
    parser.add_argument('--fab-speechinterface-url', type=str,
                        help="present results word by word to this FAB speech interface URL")
    parser.add_argument('--fab-bcast-url', type=str,
                        help="present results word by word to this FAB broadcast URL")
    parser.add_argument('--zoom-caption-url', type=str,
                        help="present results with ZoomPresenter using this captions URL")
    parser.add_argument('--word-output-file', type=argparse.FileType('w'), default=sys.stdout,
                        help="file to write words to when no URL option is given (default: stdout)")
    parser.add_argument('input_file',
                        help="audio input to transcribe (passed to SpeechSegmentGenerator)")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    try:
        main(args)
    finally:
        # argparse.FileType opens the output file but never closes it. Close
        # it explicitly so buffered words are flushed. stdout (the default,
        # and what FileType returns for "-") is left open.
        if args.word_output_file is not sys.stdout:
            args.word_output_file.close()