From 71d3b623873552497253665c079df69e93ce1b41 Mon Sep 17 00:00:00 2001 From: Kwangsoo Yeo Date: Wed, 22 Nov 2023 16:04:11 -0800 Subject: [PATCH] v2.0 python (#256) --- .github/workflows/python-codestyle.yml | 4 +- .github/workflows/python-demos.yml | 16 ++- .github/workflows/python-perf.yml | 36 +++--- .github/workflows/python.yml | 4 +- binding/python/_cheetah.py | 81 +++++++++++-- binding/python/setup.py | 39 +++--- binding/python/test_cheetah.py | 161 +++++++++++++++++++++---- binding/python/test_util.py | 72 +++++++++++ demo/python/requirements.txt | 2 +- demo/python/setup.py | 4 +- 10 files changed, 338 insertions(+), 81 deletions(-) create mode 100644 binding/python/test_util.py diff --git a/.github/workflows/python-codestyle.yml b/.github/workflows/python-codestyle.yml index 09a056f6..56068ef0 100644 --- a/.github/workflows/python-codestyle.yml +++ b/.github/workflows/python-codestyle.yml @@ -3,12 +3,12 @@ name: Python Codestyle on: workflow_dispatch: push: - branches: [master] + branches: [ master ] paths: - 'binding/python/*.py' - 'demo/python/*.py' pull_request: - branches: [master] + branches: [ master, 'v[0-9]+.[0-9]+' ] paths: - 'binding/python/*.py' - 'demo/python/*.py' diff --git a/.github/workflows/python-demos.yml b/.github/workflows/python-demos.yml index a3969050..7d69c409 100644 --- a/.github/workflows/python-demos.yml +++ b/.github/workflows/python-demos.yml @@ -3,13 +3,13 @@ name: Python Demos on: workflow_dispatch: push: - branches: [master] + branches: [ master ] paths: - '.github/workflows/python-demos.yml' - 'demo/python/**' - '!demo/python/README.md' pull_request: - branches: [master] + branches: [ master, 'v[0-9]+.[0-9]+' ] paths: - '.github/workflows/python-demos.yml' - 'demo/python/**' @@ -39,6 +39,12 @@ jobs: - name: Pre-build dependencies run: python -m pip install --upgrade pip + # ************** REMOVE AFTER RELEASE ******************** + - name: Build binding + run: | + pip install wheel && cd ../../binding/python && python 
setup.py sdist bdist_wheel && pip install dist/pvcheetah-2.0.0-py3-none-any.whl + # ******************************************************** + - name: Install dependencies run: pip install -r requirements.txt @@ -55,6 +61,12 @@ jobs: steps: - uses: actions/checkout@v3 + # ************** REMOVE AFTER RELEASE ******************** + - name: Build binding + run: | + pip3 install wheel && cd ../../binding/python && python3 setup.py sdist bdist_wheel && pip3 install dist/pvcheetah-2.0.0-py3-none-any.whl + # ******************************************************** + - name: Install dependencies run: pip3 install -r requirements.txt diff --git a/.github/workflows/python-perf.yml b/.github/workflows/python-perf.yml index 0013ef12..154bd74f 100644 --- a/.github/workflows/python-perf.yml +++ b/.github/workflows/python-perf.yml @@ -3,7 +3,7 @@ name: Python performance on: workflow_dispatch: push: - branches: [master] + branches: [ master ] paths: - '.github/workflows/python-perf.yml' - 'binding/python/test_cheetah_perf.py' @@ -14,7 +14,7 @@ on: - 'lib/raspberry-pi/**' - 'lib/windows/**' pull_request: - branches: [master] + branches: [ master, 'v[0-9]+.[0-9]+' ] paths: - '.github/workflows/python-perf.yml' - 'binding/python/test_cheetah_perf.py' @@ -39,14 +39,14 @@ jobs: os: [ubuntu-latest, windows-latest, macos-latest] include: - os: ubuntu-latest - init_performance_threshold_sec: 2.1 - proc_performance_threshold_sec: 0.5 + init_performance_threshold_sec: 4.0 + proc_performance_threshold_sec: 0.8 - os: windows-latest - init_performance_threshold_sec: 2.4 - proc_performance_threshold_sec: 0.6 + init_performance_threshold_sec: 4.0 + proc_performance_threshold_sec: 0.7 - os: macos-latest - init_performance_threshold_sec: 3.0 - proc_performance_threshold_sec: 0.8 + init_performance_threshold_sec: 4.5 + proc_performance_threshold_sec: 2.5 steps: - uses: actions/checkout@v3 @@ -74,20 +74,20 @@ jobs: machine: [rpi3-32, rpi3-64, rpi4-32, rpi4-64, jetson] include: - machine: rpi3-32 - 
class CheetahError(Exception):
    """Base exception for the Cheetah binding.

    Carries a short summary message plus an optional list of detail messages
    (the "message stack"), which `__str__` renders as numbered lines below the
    summary.

    :param message: Summary of the failure.
    :param message_stack: Optional sequence of detail messages. `None` is treated as an empty stack.
    """

    def __init__(self, message: str = '', message_stack: Sequence[str] = None):
        super().__init__(message)

        self._message = message
        # Copy the sequence so that later mutation of the caller's object cannot change what this error reports.
        self._message_stack = list() if message_stack is None else list(message_stack)

    def __str__(self):
        # Without detail entries the summary is returned unchanged (no trailing colon).
        if not self._message_stack:
            return self._message

        details = ''.join('\n [%d] %s' % (i, entry) for i, entry in enumerate(self._message_stack))
        return self._message + ':' + details

    @property
    def message(self) -> str:
        """Summary message passed to the constructor."""
        return self._message

    @property
    def message_stack(self) -> Sequence[str]:
        """Detail messages attached to this error; empty when none were supplied."""
        return self._message_stack
def _get_error_stack(self) -> Sequence[str]:
    """Read the native library's error stack and return it as decoded strings.

    Calls `self._get_error_stack_func` to obtain an array of C strings and its
    depth, decodes each entry as UTF-8, and hands the array back to
    `self._free_error_stack_func`.

    :return: The error messages, in the order the library reported them.
    :raises: The exception class mapped from the returned status in
        `self._PICOVOICE_STATUS_TO_EXCEPTION` when the stack cannot be retrieved.
    """

    message_stack_ref = POINTER(c_char_p)()
    message_stack_depth = c_int()
    status = self._get_error_stack_func(byref(message_stack_ref), byref(message_stack_depth))
    if status is not self.PicovoiceStatuses.SUCCESS:
        raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](message='Unable to get Cheetah error state')

    try:
        # Indexing a POINTER(c_char_p) yields `bytes`; decode before the native array is released.
        message_stack = [message_stack_ref[i].decode('utf-8') for i in range(message_stack_depth.value)]
    finally:
        # Release the native array even if an entry fails to decode.
        self._free_error_stack_func(message_stack_ref)

    return message_stack
os.path.join(os.path.dirname(__file__), '../..', model_file), + os.path.join(package_folder, model_file)) +manifest_in += "include pvcheetah/%s\n" % model_file -os.mkdir(os.path.join(package_folder, 'lib')) -for platform in ('common',) + platforms: +for platform in INCLUDE_LIBS: shutil.copytree( os.path.join(os.path.dirname(__file__), '../../lib', platform), os.path.join(package_folder, 'lib', platform)) - -MANIFEST_IN = """ -include pvcheetah/LICENSE -include pvcheetah/__init__.py -include pvcheetah/_cheetah.py -include pvcheetah/_factory.py -include pvcheetah/_util.py -recursive-include pvcheetah/lib/ * -""" + manifest_in += "recursive-include pvcheetah/lib/%s *\n" % platform with open(os.path.join(os.path.dirname(__file__), 'MANIFEST.in'), 'w') as f: - f.write(MANIFEST_IN.strip('\n ')) + f.write(manifest_in) with open(os.path.join(os.path.dirname(__file__), 'README.md'), 'r') as f: long_description = f.read() setuptools.setup( name="pvcheetah", - version="1.1.3", + version="2.0.0", author="Picovoice", author_email="hello@picovoice.ai", description="Cheetah Speech-to-Text Engine.", diff --git a/binding/python/test_cheetah.py b/binding/python/test_cheetah.py index 9ec2869f..40b42fef 100644 --- a/binding/python/test_cheetah.py +++ b/binding/python/test_cheetah.py @@ -1,5 +1,5 @@ # -# Copyright 2018-2022 Picovoice Inc. +# Copyright 2018-2023 Picovoice Inc. # # You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE" # file accompanying this source. @@ -17,51 +17,162 @@ from parameterized import parameterized -from _cheetah import Cheetah +from _cheetah import Cheetah, CheetahError from _util import * +from test_util import * -TEST_PARAMS = [ - [False, "Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel"], - [True, "Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel."], -] + +parameters = load_test_data() class CheetahTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - with wave.open(os.path.join(os.path.dirname(__file__), '../../resources/audio_samples/test.wav'), 'rb') as f: - buffer = f.readframes(f.getnframes()) - cls.pcm = struct.unpack('%dh' % (len(buffer) / struct.calcsize('h')), buffer) + cls._access_key = sys.argv[1] + cls._audio_directory = os.path.join('..', '..', 'resources', 'audio_samples') - @staticmethod - def _create_cheetah(enable_automatic_punctuation: bool) -> Cheetah: + @classmethod + def _create_cheetah(cls, enable_automatic_punctuation: bool) -> Cheetah: return Cheetah( - access_key=sys.argv[1], + access_key=cls._access_key, model_path=default_model_path('../..'), library_path=default_library_path('../..'), enable_automatic_punctuation=enable_automatic_punctuation) - @parameterized.expand(TEST_PARAMS) - def test_transcribe(self, enable_automatic_punctuation: bool, ref: str): - o = self._create_cheetah(enable_automatic_punctuation) + @parameterized.expand(parameters) + def test_process( + self, + _: str, + audio_file: str, + expected_transcript: str, + punctuations: List[str], + error_rate: float): + o = None + + try: + o = self._create_cheetah(False) + + pcm = read_wav_file( + file_name=os.path.join(self._audio_directory, audio_file), + sample_rate=o.sample_rate) + + transcript = '' + num_frames = len(pcm) // o.frame_length + for i in range(num_frames): + frame = pcm[i * o.frame_length:(i + 1) * o.frame_length] + partial_transcript, _ = o.process(frame) + transcript += partial_transcript + + final_transcript = o.flush() + transcript += final_transcript + + normalized_transcript = expected_transcript + for punctuation in punctuations: + normalized_transcript = normalized_transcript.replace(punctuation, "") + + self.assertLessEqual( + get_word_error_rate(transcript, normalized_transcript), + error_rate) + 
finally: + if o is not None: + o.delete() + + @parameterized.expand(parameters) + def test_process_with_punctuation( + self, + _: str, + audio_file: str, + expected_transcript: str, + punctuations: List[str], + error_rate: float): + o = None + + try: + o = self._create_cheetah(True) + + pcm = read_wav_file( + file_name=os.path.join(self._audio_directory, audio_file), + sample_rate=o.sample_rate) - transcript = '' - num_frames = len(self.pcm) // o.frame_length - for i in range(num_frames): - frame = self.pcm[i * o.frame_length:(i + 1) * o.frame_length] - partial_transcript, _ = o.process(frame) - transcript += partial_transcript + transcript = '' + num_frames = len(pcm) // o.frame_length + for i in range(num_frames): + frame = pcm[i * o.frame_length:(i + 1) * o.frame_length] + partial_transcript, _ = o.process(frame) + transcript += partial_transcript - final_transcript = o.flush() - transcript += final_transcript - print(transcript) - self.assertEqual(transcript, ref) + final_transcript = o.flush() + transcript += final_transcript + + self.assertLessEqual( + get_word_error_rate(transcript, expected_transcript), + error_rate) + finally: + if o is not None: + o.delete() def test_version(self): o = self._create_cheetah(False) self.assertIsInstance(o.version, str) self.assertGreater(len(o.version), 0) + def test_message_stack(self): + relative = '../../' + + error = None + try: + c = Cheetah( + access_key='invalid', + library_path=default_library_path(relative), + model_path=default_model_path(relative), + enable_automatic_punctuation=True) + self.assertIsNone(c) + except CheetahError as e: + error = e.message_stack + + self.assertIsNotNone(error) + self.assertGreater(len(error), 0) + + try: + c = Cheetah( + access_key='invalid', + library_path=default_library_path(relative), + model_path=default_model_path(relative), + enable_automatic_punctuation=True) + self.assertIsNone(c) + except CheetahError as e: + self.assertEqual(len(error), len(e.message_stack)) + 
self.assertListEqual(list(error), list(e.message_stack)) + + def test_process_flush_message_stack(self): + relative = '../../' + + c = Cheetah( + access_key=sys.argv[1], + library_path=default_library_path(relative), + model_path=default_model_path(relative), + enable_automatic_punctuation=True) + test_pcm = [0] * c._frame_length + + address = c._handle + c._handle = None + + try: + res = c.process(test_pcm) + self.assertIsNone(res) + except CheetahError as e: + self.assertGreater(len(e.message_stack), 0) + self.assertLess(len(e.message_stack), 8) + + try: + res = c.flush() + self.assertIsNone(res) + except CheetahError as e: + self.assertGreater(len(e.message_stack), 0) + self.assertLess(len(e.message_stack), 8) + + c._handle = address + if __name__ == '__main__': if len(sys.argv) != 2: diff --git a/binding/python/test_util.py b/binding/python/test_util.py new file mode 100644 index 00000000..4ba1815e --- /dev/null +++ b/binding/python/test_util.py @@ -0,0 +1,72 @@ +# +# Copyright 2023 Picovoice Inc. +# +# You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE" +# file accompanying this source. +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# + +import struct +import wave +from typing import * + + +def load_test_data() -> List[Tuple[str, str, str, List[str], float]]: + parameters = [ + ( + "en", + "test.wav", + "Mr. 
def read_wav_file(file_name: str, sample_rate: int) -> Tuple:
    """Read a 16-bit PCM WAV file and return the samples of its first channel.

    :param file_name: Path of the WAV file to read.
    :param sample_rate: Sample rate the file is required to have.
    :return: Tuple of signed 16-bit samples; for multi-channel files only the first channel is kept.
    :raises ValueError: If the file's sample rate differs from `sample_rate` or its samples are not 16-bit.
    """

    # The context manager closes the file even when one of the validation errors below is raised.
    with wave.open(file_name, mode="rb") as wav_file:
        channels = wav_file.getnchannels()
        num_frames = wav_file.getnframes()

        if wav_file.getframerate() != sample_rate:
            raise ValueError(
                "Audio file should have a sample rate of %d, got %d" % (sample_rate, wav_file.getframerate()))

        if wav_file.getsampwidth() != 2:
            raise ValueError(
                "Audio file should have 16-bit samples, got %d-bit" % (8 * wav_file.getsampwidth()))

        samples = wav_file.readframes(num_frames)

    # WAV data is little-endian; size the format from the bytes actually read rather than the header's frame count.
    frames = struct.unpack('<%dh' % (len(samples) // 2), samples)

    if channels == 2:
        print("Picovoice processes single-channel audio but stereo file is provided. Processing left channel only.")

    return frames[::channels]


def get_word_error_rate(transcript: str, expected_transcript: str, use_cer: bool = False) -> float:
    """Return the error rate of `transcript` measured against `expected_transcript`.

    The rate is the case-insensitive Levenshtein distance between the two token
    sequences divided by the number of tokens in the expected transcript.

    :param transcript: Transcript produced by the engine.
    :param expected_transcript: Reference transcript.
    :param use_cer: When true, compare character by character instead of word by word.
    :return: Error rate; `0.0` means the token sequences match.
    """

    transcript_split = list(transcript) if use_cer else transcript.split()
    expected_split = list(expected_transcript) if use_cer else expected_transcript.split()

    # Normalise by the reference length. An empty reference is treated as length 1 so the call cannot divide by zero.
    # NOTE(review): thresholds tuned against the former character-count denominator may need to be re-tuned.
    return _levenshtein_distance(transcript_split, expected_split) / max(len(expected_split), 1)


def _levenshtein_distance(words1: Sequence[str], words2: Sequence[str]) -> int:
    """Return the case-insensitive edit distance between two token sequences."""

    rows = len(words1) + 1
    cols = len(words2) + 1

    # res[i][j] is the distance between the first i tokens of words1 and the first j tokens of words2.
    res = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        res[i][0] = i
    for j in range(cols):
        res[0][j] = j

    for i in range(1, rows):
        for j in range(1, cols):
            res[i][j] = min(
                res[i - 1][j] + 1,
                res[i][j - 1] + 1,
                res[i - 1][j - 1] + (0 if words1[i - 1].upper() == words2[j - 1].upper() else 1)
            )

    return res[rows - 1][cols - 1]
a/demo/python/setup.py +++ b/demo/python/setup.py @@ -28,7 +28,7 @@ setuptools.setup( name="pvcheetahdemo", - version="1.1.6", + version="2.0.0", author="Picovoice", author_email="hello@picovoice.ai", description="Cheetah speech-to-text engine demos", @@ -36,7 +36,7 @@ long_description_content_type="text/markdown", url="https://github.com/Picovoice/cheetah", packages=["pvcheetahdemo"], - install_requires=["pvcheetah==1.1.3", "pvrecorder==1.2.1"], + install_requires=["pvcheetah==2.0.0", "pvrecorder==1.2.1"], include_package_data=True, classifiers=[ "Development Status :: 5 - Production/Stable",