-
Notifications
You must be signed in to change notification settings - Fork 3
/
speech_to_docx.py
169 lines (149 loc) · 7.01 KB
/
speech_to_docx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from argparse import ArgumentParser
import codecs
import logging
import os
import sys
import tempfile
from docx import Document
import numpy as np
from wav_io.wav_io import transform_to_wavpcm, load_sound
from wav_io.wav_io import TARGET_SAMPLING_FREQUENCY
from asr.asr import initialize_model_for_speech_recognition
from asr.asr import initialize_model_for_speech_classification
from asr.asr import initialize_model_for_speech_segmentation
from asr.asr import transcribe, check_language
from asr.asr import asr_logger
from utils.utils import time_to_str
speech_to_srt_logger = logging.getLogger(__name__)


def _logged_io_error(err_msg):
    """Log ``err_msg`` at ERROR level and return an ``IOError`` carrying it.

    Callers write ``raise _logged_io_error(...)`` so the raise site stays
    visible in the code and in tracebacks.
    """
    speech_to_srt_logger.error(err_msg)
    return IOError(err_msg)


def _require_dir(dir_name):
    """Raise a logged ``IOError`` unless ``dir_name`` is an existing directory."""
    if not os.path.isdir(dir_name):
        raise _logged_io_error(f'The directory "{dir_name}" does not exist!')


def _call_logging_errors(func, *args, **kwargs):
    """Return ``func(*args, **kwargs)``; on any exception, log its text and re-raise it."""
    try:
        return func(*args, **kwargs)
    except BaseException as ex:
        # BaseException on purpose: the exception is always re-raised unchanged,
        # so this only adds a log record (including for e.g. KeyboardInterrupt).
        speech_to_srt_logger.error(str(ex))
        raise


def _resolve_model_paths(model_dir_arg):
    """Return the ``(wav2vec2, ast, whisper)`` model directory paths.

    If ``model_dir_arg`` is None, all three are None and are passed on as
    ``model_info=None``. Otherwise ``model_dir_arg`` is normalized. It and its
    ``wav2vec2``, ``ast`` and ``whisper`` subdirectories are checked in that
    order. The first missing one raises a logged ``IOError``.
    """
    if model_dir_arg is None:
        return None, None, None
    model_dir = os.path.normpath(model_dir_arg)
    _require_dir(model_dir)
    subdirs = []
    for subdir_name in ('wav2vec2', 'ast', 'whisper'):
        subdir = os.path.join(model_dir, subdir_name)
        _require_dir(subdir)
        subdirs.append(subdir)
    wav2vec2_path, audiotransformer_path, whisper_path = subdirs
    return wav2vec2_path, audiotransformer_path, whisper_path


def _check_output_name(output_docx_fname, audio_fname):
    """Validate the output DocX path.

    Raises a logged ``IOError`` in three cases:
    - the output's parent directory is given but does not exist;
    - the output's base name is blank;
    - the output's base name equals the input audio's base name.
    """
    output_docx_dir = os.path.dirname(output_docx_fname)
    if output_docx_dir:
        _require_dir(output_docx_dir)
    output_base_name = os.path.basename(output_docx_fname)
    if not output_base_name.strip():
        raise _logged_io_error(f'The file name "{output_docx_fname}" is incorrect!')
    audio_base_name = os.path.basename(audio_fname)
    if output_base_name == audio_base_name:
        # Comparing base names (not full paths) is stricter than necessary, but it
        # guarantees the DocX can never overwrite the input audio file.
        raise _logged_io_error(
            f'The input audio and the output DocX file have the same name! '
            f'{audio_base_name} = {output_base_name}'
        )


def _load_sound_via_tmp_wav(audio_fname):
    """Convert ``audio_fname`` to a temporary PCM WAV file and load it.

    The temporary file is deleted afterwards, even when conversion or loading
    fails. Returns whatever ``load_sound`` returns. ``main`` treats None as an
    empty sound and a non-ndarray value as a two-channel sound.
    """
    tmp_wav_name = ''
    try:
        # The handle is closed right away (delete=False keeps the file), so that
        # transform_to_wavpcm can reopen the path by name; an open
        # NamedTemporaryFile cannot be reopened by name on Windows.
        with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.wav') as fp:
            tmp_wav_name = fp.name
        _call_logging_errors(transform_to_wavpcm, audio_fname, tmp_wav_name)
        speech_to_srt_logger.info(f'The sound "{audio_fname}" is converted to the "{tmp_wav_name}".')
        input_sound = _call_logging_errors(load_sound, tmp_wav_name)
        speech_to_srt_logger.info(f'The sound "{tmp_wav_name}" is loaded.')
    finally:
        if os.path.isfile(tmp_wav_name):
            os.remove(tmp_wav_name)
            speech_to_srt_logger.info(f'The sound "{tmp_wav_name}" is removed.')
    return input_sound


def _transcribe_sound(input_sound, audio_fname, language_name,
                      wav2vec2_path, audiotransformer_path, whisper_path):
    """Downmix to mono if needed, load the three models and transcribe.

    Returns what ``transcribe`` returns. ``_save_docx`` iterates it as
    ``(start_time, end_time, sentence_text)`` triples. Model initialization
    errors are logged and re-raised.
    """
    if not isinstance(input_sound, np.ndarray):
        # Assumes load_sound returns a pair of channel arrays for stereo input
        # (TODO: confirm against wav_io.load_sound).
        speech_to_srt_logger.info(f'The sound "{audio_fname}" is stereo.')
        input_sound = (input_sound[0] + input_sound[1]) / 2.0
    speech_to_srt_logger.info(f'The total duration of the sound "{audio_fname}" is '
                              f'{time_to_str(input_sound.shape[0] / TARGET_SAMPLING_FREQUENCY)}.')
    segmenter = _call_logging_errors(initialize_model_for_speech_segmentation,
                                     language_name, model_info=wav2vec2_path)
    speech_to_srt_logger.info('The Wav2Vec2-based segmenter is loaded.')
    vad = _call_logging_errors(initialize_model_for_speech_classification,
                               model_info=audiotransformer_path)
    speech_to_srt_logger.info('The AST-based voice activity detector is loaded.')
    asr = _call_logging_errors(initialize_model_for_speech_recognition,
                               language_name, model_info=whisper_path)
    speech_to_srt_logger.info('The Whisper-based ASR is initialized.')
    return transcribe(input_sound, segmenter, vad, asr, min_segment_size=1, max_segment_size=20)


def _save_docx(texts_with_timestamps, output_docx_fname):
    """Write one ``<start> - <end> - <text>`` paragraph per segment and save the DocX.

    Each segment paragraph is followed by an empty paragraph as a visual
    separator. An empty ``texts_with_timestamps`` yields an empty document.
    """
    doc = Document()
    for start_time, end_time, sentence_text in texts_with_timestamps:
        doc.add_paragraph(f'{time_to_str(start_time)} - {time_to_str(end_time)} - {sentence_text}')
        doc.add_paragraph('')
    doc.save(output_docx_fname)


def main():
    """Transcribe an audio file and save the timestamped text as a DocX file.

    Command-line arguments:
    - ``--lang``: speech language (default 'ru'), checked by ``check_language``;
    - ``-i/--input``: input sound file (required);
    - ``-m/--model``: optional directory containing ``wav2vec2``, ``ast`` and
      ``whisper`` subdirectories;
    - ``-o/--output``: output DocX file (required).

    Raises:
        IOError: an input/model path does not exist or the output path is invalid.
        Exceptions from audio conversion/loading or model initialization are
        logged and re-raised unchanged.
    """
    parser = ArgumentParser()
    parser.add_argument('--lang', dest='language', type=str, required=False, default='ru',
                        help='The language of input speech (Russian or English).')
    parser.add_argument('-i', '--input', dest='input_name', type=str, required=True,
                        help='The input sound file name.')
    parser.add_argument('-m', '--model', dest='model_dir', type=str, required=False, default=None,
                        help='The path to directory with Wav2Vec2, AudioTransformer and Whisper.')
    parser.add_argument('-o', '--output', dest='output_name', type=str, required=True,
                        help='The output DocX file name.')
    args = parser.parse_args()

    language_name = check_language(args.language)
    wav2vec2_path, audiotransformer_path, whisper_path = _resolve_model_paths(args.model_dir)

    audio_fname = os.path.normpath(args.input_name)
    if not os.path.isfile(audio_fname):
        raise _logged_io_error(f'The file "{audio_fname}" does not exist!')
    output_docx_fname = os.path.normpath(args.output_name)
    _check_output_name(output_docx_fname, audio_fname)

    input_sound = _load_sound_via_tmp_wav(audio_fname)
    if input_sound is None:
        speech_to_srt_logger.info(f'The sound "{audio_fname}" is empty.')
        texts_with_timestamps = []
    else:
        texts_with_timestamps = _transcribe_sound(input_sound, audio_fname, language_name,
                                                  wav2vec2_path, audiotransformer_path, whisper_path)
    _save_docx(texts_with_timestamps, output_docx_fname)
if __name__ == '__main__':
    # Send INFO-and-above records from this script's logger and from the
    # asr.asr module's asr_logger both to stdout and to speech_to_docx.log
    # (a relative path, i.e. created in the current working directory).
    speech_to_srt_logger.setLevel(logging.INFO)
    asr_logger.setLevel(logging.INFO)
    fmt_str = '%(filename)s[LINE:%(lineno)d]# %(levelname)-8s ' \
              '[%(asctime)s] %(message)s'
    formatter = logging.Formatter(fmt_str)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    # Explicit UTF-8: log messages embed user-supplied file names (possibly
    # non-ASCII, e.g. Cyrillic), which the platform default encoding may not
    # be able to encode.
    file_handler = logging.FileHandler('speech_to_docx.log', encoding='utf-8')
    file_handler.setFormatter(formatter)
    for target_logger in (speech_to_srt_logger, asr_logger):
        target_logger.addHandler(stdout_handler)
        target_logger.addHandler(file_handler)
    main()