From 6b70ec7836655903c19e7fb8825bbc95e76b5c1f Mon Sep 17 00:00:00 2001 From: nl <3210346136@qq.com> Date: Wed, 9 Feb 2022 18:52:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=96=B0=E7=9A=84?= =?UTF-8?q?=E5=9F=BA=E4=BA=8Ehttp=E5=8D=8F=E8=AE=AEAPI=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- asrserver_http.py | 176 ++++++++++++++++++++++++++++++++++++++++++++ assets/default.html | 25 +++++++ client_http.py | 49 ++++++++++++ utils/ops.py | 68 +++++++++++++++++ 4 files changed, 318 insertions(+) create mode 100644 asrserver_http.py create mode 100644 assets/default.html create mode 100644 client_http.py diff --git a/asrserver_http.py b/asrserver_http.py new file mode 100644 index 0000000..a5761f5 --- /dev/null +++ b/asrserver_http.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . 
+# ============================================================================ + +""" +@author: nl8590687 +ASRT语音识别基于HTTP协议的API服务器程序 +""" + +import base64 +import json +from flask import Flask, Response, request + +from speech_model import ModelSpeech +from speech_model_zoo import SpeechModel251 +from speech_features import Spectrogram +from LanguageModel2 import ModelLanguage +from utils.ops import decode_wav_bytes + +API_STATUS_CODE_OK = 200000 # OK +API_STATUS_CODE_CLIENT_ERROR = 400000 +API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001 # 请求数据格式错误 +API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400002 # 请求数据配置不支持 +API_STATUS_CODE_SERVER_ERROR = 500000 +API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001 # 服务器运行中出错 + +app = Flask("ASRT API Service") + +AUDIO_LENGTH = 1600 +AUDIO_FEATURE_LENGTH = 200 +CHANNELS = 1 +# 默认输出的拼音的表示大小是1428,即1427个拼音+1个空白块 +OUTPUT_SIZE = 1428 +sm251 = SpeechModel251( + input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS), + output_size=OUTPUT_SIZE + ) +feat = Spectrogram() +ms = ModelSpeech(sm251, feat, max_label_length=64) +ms.load_model('save_models/' + sm251.get_model_name() + '.model.h5') + +ml = ModelLanguage('model_language') +ml.LoadModel() + + +class AsrtApiResponse: + ''' + ASRT语音识别基于HTTP协议的API接口响应类 + ''' + def __init__(self, status_code, status_message='', result=''): + self.status_code = status_code + self.status_message = status_message + self.result = result + def to_json(self): + ''' + 类转json + ''' + return json.dumps(self, default=lambda o: o.__dict__, + sort_keys=True) + +# api接口根url:GET +@app.route('/', methods=["GET"]) +def index_get(): + ''' + 根路径handle GET方法 + ''' + buffer = '' + with open('assets/default.html', 'r', encoding='utf-8') as file_handle: + buffer = file_handle.read() + return Response(buffer, mimetype='text/html; charset=utf-8') + +# api接口根url:POST +@app.route('/', methods=["POST"]) +def index_post(): + ''' + 根路径handle POST方法 + ''' + json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'ok') + buffer = json_data.to_json() 
+ return Response(buffer, mimetype='application/json') + +# 获取分类列表 +@app.route('/', methods=["POST"]) +def recognition_post(level): + ''' + 其他路径 POST方法 + ''' + #读取json文件内容 + try: + if level == 'speech': + request_data = request.get_json() + samples = request_data['samples'] + wavdata_bytes = base64.urlsafe_b64decode(bytes(samples,encoding='utf-8')) + sample_rate = request_data['sample_rate'] + channels = request_data['channels'] + byte_width = request_data['byte_width'] + + wavdata = decode_wav_bytes(samples_data=wavdata_bytes, + channels=channels, byte_width=byte_width) + result = ms.recognize_speech(wavdata, sample_rate) + + json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'speech level') + json_data.result = result + buffer = json_data.to_json() + print('output:', buffer) + return Response(buffer, mimetype='application/json') + elif level == 'language': + request_data = request.get_json() + + seq_pinyin = request_data['sequence_pinyin'] + + result = ml.SpeechToText(seq_pinyin) + + json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'language level') + json_data.result = result + buffer = json_data.to_json() + print('output:', buffer) + return Response(buffer, mimetype='application/json') + elif level == 'all': + request_data = request.get_json() + + samples = request_data['samples'] + wavdata_bytes = base64.urlsafe_b64decode(samples) + sample_rate = request_data['sample_rate'] + channels = request_data['channels'] + byte_width = request_data['byte_width'] + + wavdata = decode_wav_bytes(samples_data=wavdata_bytes, + channels=channels, byte_width=byte_width) + result_speech = ms.recognize_speech(wavdata, sample_rate) + result = ml.SpeechToText(result_speech) + + json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'all level') + json_data.result = result + buffer = json_data.to_json() + print('output:', buffer) + return Response(buffer, mimetype='application/json') + else: + request_data = request.get_json() + print('input:', request_data) + json_data = 
AsrtApiResponse(API_STATUS_CODE_CLIENT_ERROR, '') + buffer = json_data.to_json() + print('output:', buffer) + return Response(buffer, mimetype='application/json') + except Exception as except_general: + request_data = request.get_json() + #print(request_data['sample_rate'], request_data['channels'], + # request_data['byte_width'], len(request_data['samples']), + # request_data['samples'][-100:]) + json_data = AsrtApiResponse(API_STATUS_CODE_SERVER_ERROR, str(except_general)) + buffer = json_data.to_json() + return Response(buffer, mimetype='application/json') + + +if __name__ == '__main__': + # for development env + #app.run(host='0.0.0.0', port=20001) + # for production env + import waitress + waitress.serve(app, host='0.0.0.0', port=20001) diff --git a/assets/default.html b/assets/default.html new file mode 100644 index 0000000..a13d213 --- /dev/null +++ b/assets/default.html @@ -0,0 +1,25 @@ + + + + + ASRT Speech Recognition API + + + +

ASRT Speech Recognition API

+

framework version: 1.0

+

If you see this page, the ASRT api server is successfully installed and working.

+

For online documentation and support please refer to ASRT Project Page.

+

Please call this web api by post method.

+ Thank you for using ASRT. + + + \ No newline at end of file diff --git a/client_http.py b/client_http.py new file mode 100644 index 0000000..2396d35 --- /dev/null +++ b/client_http.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . +# ============================================================================ + +''' +@author: nl8590687 +ASRT语音识别asrserver http协议测试专用客户端 +''' +import base64 +import json +import time +import requests +from utils.ops import read_wav_bytes + +URL = 'http://127.0.0.1:20001/all' + +wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('Y:\\SpeechData\\语音数据集\\data_thchs30\\train\\A11_0.wav') +datas = { + 'channels': channels, + 'sample_rate': sample_rate, + 'byte_width': sample_width, + 'samples': str(base64.urlsafe_b64encode(wav_bytes), encoding='utf-8') +} +headers = {'Content-Type': 'application/json'} + +t0=time.time() +r = requests.post(URL, headers=headers, data=json.dumps(datas)) +t1=time.time() +r.encoding='utf-8' + +result = json.loads(r.text) +print(result) +print('time:', t1-t0, 's') diff --git a/utils/ops.py b/utils/ops.py index 23ea543..b579f19 100644 --- a/utils/ops.py +++ b/utils/ops.py @@ -44,6 +44,18 @@ def read_wav_data(filename: str) -> tuple: wave_data = wave_data.T # 将矩阵转置 return wave_data, framerate, num_channel, 
num_sample_width +def read_wav_bytes(filename: str) -> tuple: + ''' + 读取一个wav文件,返回声音信号的时域谱矩阵和播放时间 + ''' + wav = wave.open(filename,"rb") # 打开一个wav格式的声音文件流 + num_frame = wav.getnframes() # 获取帧数 + num_channel=wav.getnchannels() # 获取声道数 + framerate=wav.getframerate() # 获取帧速率 + num_sample_width=wav.getsampwidth() # 获取实例的比特宽度,即每一帧的字节数 + str_data = wav.readframes(num_frame) # 读取全部的帧 + wav.close() # 关闭流 + return str_data, framerate, num_channel, num_sample_width def get_edit_distance(str1, str2) -> int: ''' @@ -89,3 +101,59 @@ def visual_2D(img): plt.imshow(img) plt.colorbar(cax=None, ax=None, shrink=0.5) plt.show() + +def decode_wav_bytes(samples_data: bytes, channels: int = 1, byte_width: int = 2) -> list: + ''' + 解码wav格式样本点字节流,得到numpy数组 + ''' + numpy_type = np.short + if byte_width == 4: + numpy_type = np.int + elif byte_width != 2: + raise Exception('error: unsurpport byte width `' + str(byte_width) + '`') + wave_data = np.fromstring(samples_data, dtype = numpy_type) # 将声音文件数据转换为数组矩阵形式 + wave_data.shape = -1, channels # 按照声道数将数组整形,单声道时候是一列数组,双声道时候是两列的矩阵 + wave_data = wave_data.T # 将矩阵转置 + return wave_data + +def get_symbol_dict(dict_filename): + ''' + 读取拼音汉字的字典文件 + 返回读取后的字典 + ''' + txt_obj = open(dict_filename, 'r', encoding='UTF-8') # 打开文件并读入 + txt_text = txt_obj.read() + txt_obj.close() + txt_lines = txt_text.split('\n') # 文本分割 + + dic_symbol = {} # 初始化符号字典 + for i in txt_lines: + list_symbol=[] # 初始化符号列表 + if i!='': + txt_l=i.split('\t') + pinyin = txt_l[0] + for word in txt_l[1]: + list_symbol.append(word) + dic_symbol[pinyin] = list_symbol + + return dic_symbol + +def get_language_model(model_language_filename): + ''' + 读取语言模型的文件 + 返回读取后的模型 + ''' + txt_obj = open(model_language_filename, 'r', encoding='UTF-8') # 打开文件并读入 + txt_text = txt_obj.read() + txt_obj.close() + txt_lines = txt_text.split('\n') # 文本分割 + + dic_model = {} # 初始化符号字典 + for i in txt_lines: + if i!='': + txt_l=i.split('\t') + if len(txt_l) == 1: + continue + dic_model[txt_l[0]] = txt_l[1] + + 
return dic_model