Skip to content

Commit

Permalink
feat: 添加基于http协议的新调用接口
Browse files Browse the repository at this point in the history
  • Loading branch information
nl8590687 committed Mar 16, 2022
1 parent 9c2e369 commit 797bb53
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 8 deletions.
38 changes: 35 additions & 3 deletions asrt_sdk/Recorder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool Python SDK.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

"""
@author: nl8590687
ASRT语音识别Python SDK 录音功能库
"""

import wave
from pyaudio import PyAudio, paInt16
import struct
import threading

import numpy as np


class AudioRecorder():
Expand Down Expand Up @@ -70,11 +94,19 @@ def SaveAudioToFile(self, filename):
pass

def GetAudioStream(self):
return b"".join(self.__audio_buffers__)
bytesStream = b"".join(self.__audio_buffers__)
#print(bytesStream[-1000:])
#f=open('test0.bin','wb')
#f.write(bytesStream)
#f.close()
return bytesStream

def GetAudioSamples(self):
audio_bin_serials = self.GetAudioStream()
return self.__audio_stream_to_short__(audio_bin_serials)
wave_data = np.fromstring(audio_bin_serials, dtype = np.short) # 将声音文件数据转换为数组矩阵形式
#print(wave_data[-1000:])
return wave_data
#return self.__audio_stream_to_short__(audio_bin_serials)
pass

def __audio_stream_to_short__(self, audio_stream):
Expand Down
113 changes: 112 additions & 1 deletion asrt_sdk/SpeechRecognizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,122 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool Python SDK.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

"""
@author: nl8590687
ASRT语音识别Python SDK 语音识别接口调用类库
"""

from .Recorder import AudioRecorder
import threading
import time
import wave
import requests
import numpy as np

from .Recorder import AudioRecorder
from .utils import *

def get_speech_recognizer(host: str, port: str, protocol: str):
    """
    Build an ASRT speech recognition SDK client for the given endpoint.

    Args:
        host: Host domain name or IP address.
        port: Host port number.
        protocol: Network protocol name. 'http' and 'https' are accepted
            case-insensitively.

    Returns:
        An HttpSpeechRecognizer for HTTP(S) protocols, or None when the
        protocol is not supported.
    """
    # Forward the normalized value so that a protocol accepted by this
    # case-insensitive check (e.g. 'HTTP') is never rejected by a stricter,
    # case-sensitive check further down.
    normalized = protocol.lower()
    if normalized in ('http', 'https'):
        return HttpSpeechRecognizer(host, port, normalized)
    return None


class BaseSpeechRecognizer():
    """
    Base class for ASRT speech recognition SDK clients.

    Stores the endpoint description and declares the recognition methods that
    protocol-specific subclasses are expected to override.
    """

    def __init__(self, host: str, port: str, protocol: str):
        """
        Args:
            host: Host domain name or IP address.
            port: Host port number.
            protocol: Network protocol name.
        """
        self.host = host
        self.port = port
        self.protocol = protocol

    def recognite(self, wav_data, frame_rate, channels, byte_width):
        """
        Recognize raw wave data. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError is the idiomatic signal for an abstract method
        # and is still an Exception subclass, so existing handlers keep working.
        raise NotImplementedError("Method Unimpletment")

    def recognite_speech(self, wav_data, frame_rate, channels, byte_width):
        """
        Run the speech-only recognition step on raw wave data.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Method Unimpletment")

    def recognite_language(self, sequence_pinyin):
        """
        Run the language recognition step on a pinyin sequence.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Method Unimpletment")

    def recognite_file(self, filename):
        """
        Read a WAVE file and pass its contents to ``recognite``.

        Args:
            filename: Path of the WAVE file to read.

        Returns:
            Whatever the subclass implementation of ``recognite`` returns.
        """
        wave_data = read_wav_datas(filename)
        return self.recognite(
            wav_data=wave_data.str_data,
            frame_rate=wave_data.sample_rate,
            channels=wave_data.channels,
            byte_width=wave_data.byte_width,
        )

class HttpSpeechRecognizer(BaseSpeechRecognizer):
    """
    ASRT speech recognition SDK client that calls the server over HTTP(S).

    Args:
        host: Host domain name or IP address.
        port: Host port number (a string or an int).
        protocol: Network protocol, 'http' or 'https' (case-insensitive).
        sub_path: Sub-path of the URL under which the API is served,
            defaults to ''.
        timeout: Optional timeout in seconds handed to ``requests.post``.
            The default None waits indefinitely, as before.

    Raises:
        ValueError: if ``protocol`` is neither 'http' nor 'https'.
    """

    def __init__(self, host: str, port: str, protocol: str, sub_path: str = '',
                 timeout=None):
        normalized = protocol.lower()
        if normalized not in ('http', 'https'):
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError('Unsupport netword protocol `' + protocol + '`')
        super().__init__(host, port, normalized)
        # format() also accepts an int port, which plain concatenation rejected.
        self._url_ = '{}://{}:{}'.format(normalized, host, port)
        self.sub_path = sub_path
        self.timeout = timeout

    def _post(self, endpoint: str, request_body) -> AsrtApiResponse:
        """POST ``request_body`` as JSON to ``endpoint`` and parse the reply."""
        response_object = requests.post(
            self._url_ + self.sub_path + endpoint,
            headers={'Content-Type': 'application/json'},
            data=request_body.to_json(),
            timeout=self.timeout,
        )
        response_body = AsrtApiResponse()
        response_body.from_json(**json.loads(response_object.text))
        return response_body

    def recognite(self, wav_data, frame_rate: int, channels: int, byte_width: int) -> AsrtApiResponse:
        """
        Send raw wave data to the '/all' endpoint.

        Args:
            wav_data: Raw sample bytes.
            frame_rate: Sample rate of the audio.
            channels: Number of channels.
            byte_width: Width of one sample in bytes.

        Returns:
            AsrtApiResponse populated from the JSON reply.
        """
        request_body = AsrtApiSpeechRequest(wav_data, frame_rate, channels, byte_width)
        return self._post('/all', request_body)

    def recognite_speech(self, wav_data, frame_rate, channels, byte_width):
        """
        Send raw wave data to the '/speech' endpoint.

        Takes the same arguments as ``recognite`` and returns an
        AsrtApiResponse populated from the JSON reply.
        """
        request_body = AsrtApiSpeechRequest(wav_data, frame_rate, channels, byte_width)
        return self._post('/speech', request_body)

    def recognite_language(self, sequence_pinyin):
        """
        Send a pinyin sequence to the '/language' endpoint.

        Args:
            sequence_pinyin: Pinyin sequence to convert.

        Returns:
            AsrtApiResponse populated from the JSON reply.
        """
        request_body = AsrtApiLanguageRequest(sequence_pinyin)
        return self._post('/language', request_body)

class SpeechRecognizer():
def __init__(self, url_server = 'http://127.0.0.1:20000/', token_client = 'qwertasd'):
Expand Down
2 changes: 2 additions & 0 deletions asrt_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@
from . import Recorder, SpeechRecognizer
from .Recorder import AudioRecorder
from .SpeechRecognizer import SpeechRecognizer
from .SpeechRecognizer import get_speech_recognizer, HttpSpeechRecognizer
from .utils import read_wav_datas

__version__ = '1.0.0'
160 changes: 160 additions & 0 deletions asrt_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool Python SDK.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

"""
@author: nl8590687
ASRT语音识别Python SDK 基础库模块
"""

import base64
import json
import wave
import numpy as np

# Status codes of the ASRT HTTP API (presumably the values carried in
# AsrtApiResponse.status_code -- confirm against the server implementation).
API_STATUS_CODE_OK = 200000  # OK
API_STATUS_CODE_CLIENT_ERROR = 400000
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001  # malformed request data
# This code used to be assigned to API_STATUS_CODE_CLIENT_ERROR_FORMAT a second
# time, which silently overwrote 400001. It now has its own name.
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002  # unsupported request data configuration
API_STATUS_CODE_SERVER_ERROR = 500000
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001  # error while the server was running

class AsrtApiSpeechRequest:
    """
    Request payload of the ASRT HTTP API for speech (acoustic model) calls.

    The raw sample bytes are kept as URL-safe base64 text so that the object
    can be serialized to JSON.
    """

    def __init__(self, samples, sample_rate, channels, byte_width):
        encoded = base64.urlsafe_b64encode(samples)
        self.samples = encoded.decode('utf-8')
        self.sample_rate = sample_rate
        self.channels = channels
        self.byte_width = byte_width

    def to_json(self):
        """Serialize this object to a JSON string with sorted keys."""
        def attributes_of(obj):
            # Objects json cannot encode natively are written as their attribute dict.
            return obj.__dict__

        return json.dumps(self, sort_keys=True, default=attributes_of)

    def from_json(self, **entries):
        """Set or overwrite attributes from the given decoded JSON fields."""
        attributes = self.__dict__
        for field_name, field_value in entries.items():
            attributes[field_name] = field_value

    def __str__(self):
        """Return the JSON form of this request."""
        return self.to_json()

class AsrtApiLanguageRequest:
    """
    Request payload of the ASRT HTTP API for language model calls.

    Carries the pinyin sequence to be sent to the server.
    """

    def __init__(self, sequence_pinyin):
        self.sequence_pinyin = sequence_pinyin

    def to_json(self):
        """Serialize this object to a JSON string with sorted keys."""
        def attributes_of(obj):
            # Objects json cannot encode natively are written as their attribute dict.
            return obj.__dict__

        serialized = json.dumps(self, sort_keys=True, default=attributes_of)
        return serialized

    def from_json(self, **entries):
        """Set or overwrite attributes from the given decoded JSON fields."""
        vars(self).update(entries)

    def __str__(self):
        """Return the JSON form of this request."""
        return self.to_json()

class AsrtApiResponse:
    """
    Response payload of the ASRT HTTP API.

    Holds a status code, a status message and the result value.
    """

    def __init__(self, status_code=0, status_message='', result=''):
        self.status_code = status_code
        self.status_message = status_message
        self.result = result

    def to_json(self):
        """Serialize this object to a JSON string with sorted keys."""
        def attributes_of(obj):
            # Objects json cannot encode natively are written as their attribute dict.
            return obj.__dict__

        return json.dumps(self, sort_keys=True, default=attributes_of)

    def from_json(self, **entries):
        """Set or overwrite attributes from the given decoded JSON fields."""
        attributes = self.__dict__
        for field_name in entries:
            attributes[field_name] = entries[field_name]

    def __str__(self):
        """Return the JSON form of this response."""
        return self.to_json()

class WaveData:
    """
    Container for the raw contents and format parameters of WAVE audio.

    Attributes are set from the constructor arguments: ``str_data`` (raw
    frame bytes), ``sample_rate``, ``channels`` and ``byte_width``.
    ``filename`` starts out empty.
    """

    def __init__(self, str_data, frame_rate, channels, byte_width) -> None:
        self.str_data = str_data
        self.sample_rate = frame_rate
        self.channels = channels
        self.byte_width = byte_width
        self.filename = ''

    def get_samples(self):
        """
        Convert ``str_data`` into an array of ``np.short`` samples.

        Returns:
            numpy.ndarray: array of shape (channels, frames), one row per
            channel. The array does not share memory with ``str_data`` and
            is writable.
        """
        # np.fromstring's binary mode is deprecated in favour of np.frombuffer.
        # frombuffer yields a read-only view over the bytes, so copy it to keep
        # returning an independent, writable array.
        samples = np.frombuffer(self.str_data, dtype=np.short).copy()
        # One column per channel (interleaved frames become rows), then
        # transpose so that each row holds one channel.
        return samples.reshape(-1, self.channels).T

    def set_filename(self, filename):
        """Record the name of the file this wave data belongs to."""
        self.filename = filename

def read_wav_datas(filename):
    """
    Read the frames and format parameters of a WAVE file.

    Args:
        filename: Path of the WAVE file to read.

    Returns:
        WaveData: holds all frames of the file as raw bytes together with
        its frame rate, channel count and sample width in bytes.
    """
    # The context manager closes the file even when one of the reads raises;
    # the explicit close() used before was skipped in that case.
    with wave.open(filename, 'rb') as wav_file:
        num_frame = wav_file.getnframes()
        str_data = wav_file.readframes(num_frame)
        frame_rate = wav_file.getframerate()
        channels = wav_file.getnchannels()
        byte_width = wav_file.getsampwidth()
    return WaveData(str_data, frame_rate, channels, byte_width)
Loading

0 comments on commit 797bb53

Please sign in to comment.