Skip to content

Commit

Permalink
feat: 实现新的基于http协议API服务接口
Browse files Browse the repository at this point in the history
  • Loading branch information
nl8590687 committed Feb 9, 2022
1 parent deb0a57 commit 6b70ec7
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 0 deletions.
176 changes: 176 additions & 0 deletions asrserver_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

"""
@author: nl8590687
ASRT语音识别基于HTTP协议的API服务器程序
"""

import base64
import json
from flask import Flask, Response, request

from speech_model import ModelSpeech
from speech_model_zoo import SpeechModel251
from speech_features import Spectrogram
from LanguageModel2 import ModelLanguage
from utils.ops import decode_wav_bytes

API_STATUS_CODE_OK = 200000 # OK
API_STATUS_CODE_CLIENT_ERROR = 400000
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001 # 请求数据格式错误
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400002 # 请求数据配置不支持
API_STATUS_CODE_SERVER_ERROR = 500000
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001 # 服务器运行中出错

app = Flask("ASRT API Service")

AUDIO_LENGTH = 1600
AUDIO_FEATURE_LENGTH = 200
CHANNELS = 1
# 默认输出的拼音的表示大小是1428,即1427个拼音+1个空白块
OUTPUT_SIZE = 1428
sm251 = SpeechModel251(
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
feat = Spectrogram()
ms = ModelSpeech(sm251, feat, max_label_length=64)
ms.load_model('save_models/' + sm251.get_model_name() + '.model.h5')

ml = ModelLanguage('model_language')
ml.LoadModel()


class AsrtApiResponse:
'''
ASRT语音识别基于HTTP协议的API接口响应类
'''
def __init__(self, status_code, status_message='', result=''):
self.status_code = status_code
self.status_message = status_message
self.result = result
def to_json(self):
'''
类转json
'''
return json.dumps(self, default=lambda o: o.__dict__,
sort_keys=True)

# api接口根url:GET
@app.route('/', methods=["GET"])
def index_get():
'''
根路径handle GET方法
'''
buffer = ''
with open('assets/default.html', 'r', encoding='utf-8') as file_handle:
buffer = file_handle.read()
return Response(buffer, mimetype='text/html; charset=utf-8')

# api接口根url:POST
@app.route('/', methods=["POST"])
def index_post():
'''
根路径handle POST方法
'''
json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'ok')
buffer = json_data.to_json()
return Response(buffer, mimetype='application/json')

# 获取分类列表
@app.route('/<level>', methods=["POST"])
def recognition_post(level):
'''
其他路径 POST方法
'''
#读取json文件内容
try:
if level == 'speech':
request_data = request.get_json()
samples = request_data['samples']
wavdata_bytes = base64.urlsafe_b64decode(bytes(samples,encoding='utf-8'))
sample_rate = request_data['sample_rate']
channels = request_data['channels']
byte_width = request_data['byte_width']

wavdata = decode_wav_bytes(samples_data=wavdata_bytes,
channels=channels, byte_width=byte_width)
result = ms.recognize_speech(wavdata, sample_rate)

json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'speech level')
json_data.result = result
buffer = json_data.to_json()
print('output:', buffer)
return Response(buffer, mimetype='application/json')
elif level == 'language':
request_data = request.get_json()

seq_pinyin = request_data['sequence_pinyin']

result = ml.SpeechToText(seq_pinyin)

json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'language level')
json_data.result = result
buffer = json_data.to_json()
print('output:', buffer)
return Response(buffer, mimetype='application/json')
elif level == 'all':
request_data = request.get_json()

samples = request_data['samples']
wavdata_bytes = base64.urlsafe_b64decode(samples)
sample_rate = request_data['sample_rate']
channels = request_data['channels']
byte_width = request_data['byte_width']

wavdata = decode_wav_bytes(samples_data=wavdata_bytes,
channels=channels, byte_width=byte_width)
result_speech = ms.recognize_speech(wavdata, sample_rate)
result = ml.SpeechToText(result_speech)

json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'all level')
json_data.result = result
buffer = json_data.to_json()
print('output:', buffer)
return Response(buffer, mimetype='application/json')
else:
request_data = request.get_json()
print('input:', request_data)
json_data = AsrtApiResponse(API_STATUS_CODE_CLIENT_ERROR, '')
buffer = json_data.to_json()
print('output:', buffer)
return Response(buffer, mimetype='application/json')
except Exception as except_general:
request_data = request.get_json()
#print(request_data['sample_rate'], request_data['channels'],
# request_data['byte_width'], len(request_data['samples']),
# request_data['samples'][-100:])
json_data = AsrtApiResponse(API_STATUS_CODE_SERVER_ERROR, str(except_general))
buffer = json_data.to_json()
return Response(buffer, mimetype='application/json')


if __name__ == '__main__':
# for development env
#app.run(host='0.0.0.0', port=20001)
# for production env
import waitress
waitress.serve(app, host='0.0.0.0', port=20001)
25 changes: 25 additions & 0 deletions assets/default.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>ASRT Speech Recognition API</title>
<style>
body {
width: 35em;
margin: 0 auto;
font-family: Tahoma, Verdana, Arial, sans-serif;
}
</style>
</head>
<body>
<h1>ASRT Speech Recognition API</h1>
<h3>framework version: 1.0</h3>
<p>If you see this page, the ASRT api server is successfully installed and working. </p>
<p>For online documentation and support please refer to <a href="https://asrt.ailemon.net">ASRT Project Page</a>.</p>
<p>Please call this web api by post menthod. </p>
<em>Thank you for using ASRT.</em>
<footer>
<p style="text-align: center;">Copyright © <a href="https://www.ailemon.net">ailemon.net</a></p>
</footer>
</body>
</html>
49 changes: 49 additions & 0 deletions client_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

'''
@author: nl8590687
ASRT语音识别asrserver http协议测试专用客户端
'''
import base64
import json
import time
import requests
from utils.ops import read_wav_bytes

URL = 'http://127.0.0.1:20001/all'

wav_bytes, sample_rate, channels, sample_width = read_wav_bytes('Y:\\SpeechData\\语音数据集\\data_thchs30\\train\\A11_0.wav')
datas = {
'channels': channels,
'sample_rate': sample_rate,
'byte_width': sample_width,
'samples': str(base64.urlsafe_b64encode(wav_bytes), encoding='utf-8')
}
headers = {'Content-Type': 'application/json'}

t0=time.time()
r = requests.post(URL, headers=headers, data=json.dumps(datas))
t1=time.time()
r.encoding='utf-8'

result = json.loads(r.text)
print(result)
print('time:', t1-t0, 's')
68 changes: 68 additions & 0 deletions utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def read_wav_data(filename: str) -> tuple:
wave_data = wave_data.T # 将矩阵转置
return wave_data, framerate, num_channel, num_sample_width

def read_wav_bytes(filename: str) -> tuple:
'''
读取一个wav文件,返回声音信号的时域谱矩阵和播放时间
'''
wav = wave.open(filename,"rb") # 打开一个wav格式的声音文件流
num_frame = wav.getnframes() # 获取帧数
num_channel=wav.getnchannels() # 获取声道数
framerate=wav.getframerate() # 获取帧速率
num_sample_width=wav.getsampwidth() # 获取实例的比特宽度,即每一帧的字节数
str_data = wav.readframes(num_frame) # 读取全部的帧
wav.close() # 关闭流
return str_data, framerate, num_channel, num_sample_width

def get_edit_distance(str1, str2) -> int:
'''
Expand Down Expand Up @@ -89,3 +101,59 @@ def visual_2D(img):
plt.imshow(img)
plt.colorbar(cax=None, ax=None, shrink=0.5)
plt.show()

def decode_wav_bytes(samples_data: bytes, channels: int = 1, byte_width: int = 2) -> list:
'''
解码wav格式样本点字节流,得到numpy数组
'''
numpy_type = np.short
if byte_width == 4:
numpy_type = np.int
elif byte_width != 2:
raise Exception('error: unsurpport byte width `' + str(byte_width) + '`')
wave_data = np.fromstring(samples_data, dtype = numpy_type) # 将声音文件数据转换为数组矩阵形式
wave_data.shape = -1, channels # 按照声道数将数组整形,单声道时候是一列数组,双声道时候是两列的矩阵
wave_data = wave_data.T # 将矩阵转置
return wave_data

def get_symbol_dict(dict_filename):
'''
读取拼音汉字的字典文件
返回读取后的字典
'''
txt_obj = open(dict_filename, 'r', encoding='UTF-8') # 打开文件并读入
txt_text = txt_obj.read()
txt_obj.close()
txt_lines = txt_text.split('\n') # 文本分割

dic_symbol = {} # 初始化符号字典
for i in txt_lines:
list_symbol=[] # 初始化符号列表
if i!='':
txt_l=i.split('\t')
pinyin = txt_l[0]
for word in txt_l[1]:
list_symbol.append(word)
dic_symbol[pinyin] = list_symbol

return dic_symbol

def get_language_model(model_language_filename):
'''
读取语言模型的文件
返回读取后的模型
'''
txt_obj = open(model_language_filename, 'r', encoding='UTF-8') # 打开文件并读入
txt_text = txt_obj.read()
txt_obj.close()
txt_lines = txt_text.split('\n') # 文本分割

dic_model = {} # 初始化符号字典
for i in txt_lines:
if i!='':
txt_l=i.split('\t')
if len(txt_l) == 1:
continue
dic_model[txt_l[0]] = txt_l[1]

return dic_model

0 comments on commit 6b70ec7

Please sign in to comment.