diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..5eaed6b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[report] +exclude_lines = + if __name__ == .__main__.: \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..45655fa --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..e846fb5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,27 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots/Logs** +If applicable, add screenshots or logs to help explain your problem. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..d9baa67 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: 'request' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is, e.g. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. 
+ +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..be73c8e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,33 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Publish + +on: + release: + types: [released] + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine + + - name: Build and publish to PyPI + env: + TWINE_USERNAME: '__token__' + TWINE_PASSWORD: ${{ secrets.PYPI_ORG_TOKEN }} + run: | + make build + twine upload dist/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..7cd7c16 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Tests + +on: [ push, pull_request ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ "3.10", "3.11" ] + + steps: + - uses: actions/checkout@v2 + + - name: Install portaudio + run: | + sudo apt-get update + sudo apt-get install portaudio19-dev + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: pip3 install -r requirements.txt -r requirements-dev.txt + + - name: Install 
package + run: python3 setup.py install + + - name: Lint + run: make lint + + - name: Unittest + run: make unittest diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..41158d0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,11 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +## [0.0.1] - 2024-10-14 + +### Added + +- Add speechmatics-flow client diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..f0db3b1 --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +SOURCES := speechmatics_flow/ tests/ setup.py +VERSION ?= $(shell cat VERSION) + +.PHONY: all +all: lint test + +.PHONY: lint +lint: + black --check --diff $(SOURCES) + ruff $(SOURCES) + +.PHONY: format +format: + black $(SOURCES) + +.PHONY: test +test: unittest + +.PHONY: unittest +unittest: + pytest -v tests + +.PHONY: build +build: + VERSION=$(VERSION) python setup.py sdist bdist_wheel diff --git a/README.md b/README.md index 72eed7f..7a881c9 100644 --- a/README.md +++ b/README.md @@ -1 +1,39 @@ -# speechmatics-flow \ No newline at end of file +# speechmatics-flow + +Python client library and CLI for Speechmatics' Flow Service API. + +## Getting started + +To install from PyPI: + +```bash +pip install speechmatics-flow +``` + +To install from source: + +```bash +git clone https://github.com/speechmatics/speechmatics-flow +cd speechmatics-flow && python setup.py install +``` + +Windows users may need to run the install command with an extra flag: + +```bash +python setup.py install --user +``` + +## Example command-line usage + +- Setting URLs for connecting to flow service. 
These values can be used in places of the --url flag: + +*Note: Requires access to microphone + + ```bash + speechmatics-flow --url $URL --auth-token $TOKEN - + ``` + +## Support + +If you have any issues with this library or encounter any bugs then please get in touch with us at +support@speechmatics.com. diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..8acdd82 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.0.1 diff --git a/examples/stream_from_microphone.py b/examples/stream_from_microphone.py new file mode 100644 index 0000000..ade0a07 --- /dev/null +++ b/examples/stream_from_microphone.py @@ -0,0 +1,85 @@ +import asyncio +import io +import ssl +import sys + +import pyaudio + +from speechmatics_flow.client import WebsocketClient +from speechmatics_flow.models import ( + ConnectionSettings, + Interaction, + AudioSettings, + ConversationConfig, + ServerMessageType, +) + +AUTH_TOKEN = "YOUR TOKEN HERE" + + +# Create a websocket client +ssl_context = ssl.create_default_context() +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE +client = WebsocketClient( + ConnectionSettings( + url="wss://flow.api.speechmatics.com/v1/flow", + auth_token=AUTH_TOKEN, + ssl_context=None, + ) +) + +# Create a buffer to store binary messages sent from the server +audio_buffer = io.BytesIO() + + +# Create callback function which adds binary messages to audio buffer +def binary_msg_handler(msg: bytes): + if isinstance(msg, (bytes, bytearray)): + audio_buffer.write(msg) + + +# Register the callback to be called when the client receives an audio message from the server +client.add_event_handler(ServerMessageType.audio, binary_msg_handler) + + +async def audio_playback(): + """Read from buffer and play audio back to the user""" + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, output=True) + try: + while True: + # Get the current value from the buffer + audio_to_play = audio_buffer.getvalue() + # Only proceed 
if there is audio data to play + if audio_to_play: + # Write the audio to the stream + stream.write(audio_to_play) + audio_buffer.seek(0) + audio_buffer.truncate(0) + # Pause briefly before checking the buffer again + await asyncio.sleep(0.05) + finally: + stream.close() + stream.stop_stream() + p.terminate() + + +async def main(): + tasks = [ + # Use the websocket to connect to Flow Service and start a conversation + asyncio.create_task( + client.run( + interactions=[Interaction(sys.stdin.buffer)], + audio_settings=AudioSettings(), + conversation_config=ConversationConfig(), + ) + ), + # Run audio playback handler which streams audio from audio buffer + asyncio.create_task(audio_playback()), + ] + + await asyncio.gather(*tasks) + + +asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0e7055b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.ruff] +# Allow lines to be as long as 120 characters. +line-length = 120 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..1a60281 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -ra --full-trace --cov=speechmatics_flow --cov-branch -o asyncio_mode=auto \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..dc8df89 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +pytest==7.1.1 +pytest-mock==3.7.0 +black==22.3.0 +ruff==0.0.280 +pre-commit==2.21.0 +pytest-cov==3.0.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e87dadc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +httpx==0.27.1 +pyaudio==0.2.14 +websockets>=10 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7cd8bfd --- /dev/null +++ b/setup.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +Setuptools configuration for the Speechmatics Flow client +""" + +import logging +import os + +from setuptools import setup, find_packages + + +def 
read(fname): + """ + Load content of the file with path relative to where setup.py is. + + Args: + fname (str): file name (or path relative to the project root) + + Returns: + str: file content + """ + fpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), fname) + with open(fpath, encoding="utf-8") as path: + return path.read() + + +def read_list(fname): + """ + Load the content of the file and split it into a list of lines. + + Args: + fname (str): file name (or path relative to the project root) + + Returns: + List[str]: file content (one string per line) with end-of-lines + characters stripped off and empty lines filtered out + """ + content = read(fname) + retval = list(filter(None, content.split("\n"))) + + return retval + + +def get_version(fname): + """ + Retrieve the version from the VERSION file. + + Args: + fname (str): file containing only the version + + Returns: + str: version with whitespace characters stripped off + """ + return read(fname).strip() + + +logging.basicConfig(level=logging.INFO) +setup( + name="speechmatics-flow", + version=os.getenv("VERSION", get_version("VERSION")), + packages=find_packages(exclude=["tests"]), + url="https://github.com/speechmatics/speechmatics-flow/", + license="MIT", + author="Speechmatics", + author_email="support@speechmatics.com", + description="Python library and CLI for Speechmatics Flow API", + long_description=read("README.md"), + long_description_content_type="text/markdown", + install_requires=read_list("requirements.txt"), + tests_require=read_list("requirements-dev.txt"), + entry_points={ + "console_scripts": ["speechmatics-flow = speechmatics_flow.cli:main"] + }, + project_urls={ + "Documentation": "https://speechmatics.github.io/speechmatics-flow/", + "Source Code": "https://github.com/speechmatics/speechmatics-flow/", + }, + classifiers=[ + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + 
"Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + include_package_data=True, + python_requires=">=3.7", +) diff --git a/speechmatics_flow/__init__.py b/speechmatics_flow/__init__.py new file mode 100644 index 0000000..49b82d3 --- /dev/null +++ b/speechmatics_flow/__init__.py @@ -0,0 +1,3 @@ +from speechmatics_flow.client import * # noqa +from speechmatics_flow.exceptions import * # noqa +from speechmatics_flow.models import * # noqa diff --git a/speechmatics_flow/cli.py b/speechmatics_flow/cli.py new file mode 100755 index 0000000..7548a20 --- /dev/null +++ b/speechmatics_flow/cli.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# (c) 2024, Cantab Research Ltd. +""" +Command-line interface +""" + +import json +import logging +import ssl +import sys +from dataclasses import dataclass +from socket import gaierror +from typing import Any, Dict + +import httpx +from websockets.exceptions import WebSocketException + +from speechmatics_flow.cli_parser import parse_args +from speechmatics_flow.client import WebsocketClient +from speechmatics_flow.exceptions import TranscriptionError +from speechmatics_flow.models import ( + AudioSettings, + ConversationConfig, + ServerMessageType, + Interaction, + ConnectionSettings, +) + +LOGGER = logging.getLogger(__name__) + + +@dataclass +class Transcripts: + user_transcript: str = "" + agent_transcript: str = "" + + +def get_log_level(verbosity): + """ + Returns the appropriate log level given a verbosity level. + + :param verbosity: Verbosity level. + :type verbosity: int + + :return: The logging level (eg. logging.INFO). + :rtype: int + + :raises SystemExit: If the given verbosity level is invalid. 
+ """ + try: + log_level = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}[verbosity] + + return log_level + except KeyError as error: + key = int(str(error)) + raise SystemExit( + f"Only supports 2 log levels eg. -vv, you are asking for " f"-{'v' * key}" + ) from error + + +def get_connection_settings(args): + """ + Helper function which returns a ConnectionSettings object based on the + command line options given to the program. + + :param args: Keyword arguments, typically from the command line. + :type args: dict + + :return: Settings for the WebSocket connection. + :rtype: speechmatics_flow.models.ConnectionSettings + """ + auth_token = args.get("auth_token") + url = args.get("url") + generate_temp_token = args.get("generate_temp_token") + settings = ConnectionSettings( + url=url, + auth_token=auth_token, + generate_temp_token=generate_temp_token, + ) + + if args.get("buffer_size") is not None: + settings.message_buffer_size = args["buffer_size"] + + if args.get("ssl_mode") == "insecure": + settings.ssl_context.check_hostname = False + settings.ssl_context.verify_mode = ssl.CERT_NONE + elif args.get("ssl_mode") == "none": + settings.ssl_context = None + + return settings + + +def get_conversation_config( + args, +): + """ + Helper function which returns a ConversationConfig object based on the + command line options given to the program. + + :param args: Keyword arguments probably from the command line. + :type args: Dict + + :return: Settings for the ASR engine. + :rtype: flow.models.ConversationConfig + """ + + config: Dict[str, Any] = {} + if args.get("config_file"): + with open(args["config_file"], encoding="utf-8") as config_file: + config = json.load(config_file) + + if config.get("conversation_config"): + config.update(config.pop("conversation_config")) + + return ConversationConfig(**config) + + +def get_audio_settings(args): + """ + Helper function which returns an AudioSettings object based on the command + line options given to the program. 
+ + Args: + args (dict): Keyword arguments, typically from the command line. + + Returns: + flow.models.AudioSettings: Settings for the audio stream + in the connection. + """ + settings = AudioSettings( + sample_rate=args.get("sample_rate"), + chunk_size=args.get("chunk_size"), + encoding=args.get("raw"), + ) + return settings + + +# pylint: disable=too-many-arguments,too-many-statements +def add_printing_handlers( + api, + transcripts, + print_json=False, +): + """ + Adds a set of handlers to the websocket client which print out transcripts + as they are received. This includes partials if they are enabled. + + Args: + api (client.WebsocketClient): Client instance. + transcripts (Transcripts): Allows the transcripts to be concatenated to + produce a final result. + print_json (bool, optional): Whether to print json transcript messages. + """ + escape_seq = "\33[2K" if sys.stdout.isatty() else "" + + def convert_to_txt(message): + if print_json: + print(json.dumps(message)) + return + transcript = message["metadata"]["transcript"] + return transcript.replace(" ", "").replace("", "").replace("", "") + + def transcript_handler(message): + plaintext = convert_to_txt(message) + if plaintext: + sys.stdout.write(f"{escape_seq}{plaintext}\n") + transcripts.user_transcript += plaintext + + def partial_transcript_handler(message): + plaintext = convert_to_txt(message) + if plaintext: + sys.stderr.write(f"{escape_seq}{plaintext}\r") + + def prompt_handler(message): + if print_json: + print(json.dumps(message)) + return + new_response = message["prompt"]["response"] + new_plaintext_response = new_response.replace(" ", "").replace(" ", "") + if new_plaintext_response: + sys.stdout.write(f"{escape_seq}{new_plaintext_response}\n") + transcripts.user_transcript += new_plaintext_response + + def end_of_transcript_handler(_): + print("\n", file=sys.stderr) + + api.add_event_handler(ServerMessageType.prompt, prompt_handler) + api.add_event_handler(ServerMessageType.AddTranscript, 
transcript_handler) + api.add_event_handler( + ServerMessageType.AddPartialTranscript, partial_transcript_handler + ) + api.add_event_handler(ServerMessageType.EndOfTranscript, end_of_transcript_handler) + + +# pylint: disable=too-many-branches +# pylint: disable=too-many-statements +def main(args=None): + """ + Main entrypoint. + + :param args: command-line arguments; defaults to None in which + case arguments will be retrieved from `sys.argv` (this is useful + mainly for unit tests). + :type args: List[str] + """ + if not args: + args = vars(parse_args()) + + logging.basicConfig(level=get_log_level(args["verbose"])) + LOGGER.info("Args: %s", args) + + try: + flow_main(args) + except (KeyboardInterrupt, ValueError, TranscriptionError, KeyError) as error: + LOGGER.info(error, exc_info=True) + sys.exit(f"{type(error).__name__}: {error}") + except FileNotFoundError as error: + LOGGER.info(error, exc_info=True) + sys.exit( + f"FileNotFoundError: {error.strerror}: '{error.filename}'." + + " Check to make sure the filename is spelled correctly, and that the file exists." + ) + except httpx.HTTPStatusError as error: + LOGGER.info(error, exc_info=True) + sys.exit(error.response.text) + except httpx.HTTPError as error: + LOGGER.info(error, exc_info=True) + sys.exit(f"httpx.HTTPError: An unexpected http error occurred. {error}") + except ConnectionResetError as error: + LOGGER.info(error, exc_info=True) + sys.exit( + f"ConnectionResetError: {error}.\n\nThe most likely reason for this is that the client " + + "has been configured to use SSL but the server does not support SSL. 
" + + "If this is the case then try using --ssl-mode=none" + ) + except (WebSocketException, gaierror) as error: + LOGGER.info(error, exc_info=True) + sys.exit( + f"WebSocketError: An unexpected error occurred in the websocket: {error}.\n\n" + + "Check that the url and config provided is valid, " + + "and that the language in the url matches the config.\n" + ) + + +def flow_main(args): + """Main dispatch for "flow" mode commands. + + :param args: arguments from parse_args() + :type args: argparse.Namespace + """ + conversation_config = get_conversation_config(args) + settings = get_connection_settings(args) + api = WebsocketClient(settings) + transcripts = Transcripts() + add_printing_handlers( + api, + transcripts, + print_json=args["print_json"], + ) + + def run(stream): + try: + api.run_synchronously( + [Interaction(stream)], + get_audio_settings(args), + conversation_config, + from_cli=True, + ) + except KeyboardInterrupt: + # Gracefully handle Ctrl-C, else we get a huge stack-trace. + LOGGER.warning("Keyboard interrupt received.") + + run(sys.stdin.buffer) + + +if __name__ == "__main__": + main() diff --git a/speechmatics_flow/cli_parser.py b/speechmatics_flow/cli_parser.py new file mode 100644 index 0000000..ecd9e68 --- /dev/null +++ b/speechmatics_flow/cli_parser.py @@ -0,0 +1,129 @@ +# (c) 2024, Cantab Research Ltd. +""" +Parsers used by the CLI to handle CLI arguments +""" +import argparse +import logging + +LOGGER = logging.getLogger(__name__) + + +# pylint: disable=too-many-locals +# pylint: disable=too-many-statements +def get_arg_parser(): + """ + Creates a command-line argument parser objct + + :return: The argparser object with all commands and subcommands. + :rtype: argparse.ArgumentParser + """ + parser = argparse.ArgumentParser(description="CLI for Speechmatics Flow API.") + parser.add_argument( + "-v", + dest="verbose", + action="count", + default=0, + help=( + "Set the log level for verbose logs. " + "The number of flags indicate the level, eg. 
" + "-v is INFO and -vv is DEBUG." + ), + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help=( + "Prints useful symbols to represent the messages on the wire. " + "Symbols are printed to STDERR, use only when STDOUT is " + "redirected to a file." + ), + ) + parser.add_argument( + "--url", + type=str, + help="Websocket for Flow API URL (e.g. wss://flow.api.speechmatics.com/v1/flow", + ) + parser.add_argument( + "--auth-token", + type=str, + help="Authentication token to authorize the client.", + ) + parser.add_argument( + "--generate-temp-token", + default=True, + action="store_true", + help="Automatically generate a temporary token for authentication.", + ) + parser.add_argument( + "--ssl-mode", + default="regular", + choices=["regular", "insecure", "none"], + help=( + "Use a preset configuration for the SSL context. With `regular` " + "mode a valid certificate is expected. With `insecure` mode" + " a self signed certificate is allowed." + " With `none` then SSL is not used." + ), + ) + parser.add_argument( + "--raw", + metavar="ENCODING", + type=str, + default="pcm_s16le", + help=( + "Indicate that the input audio is raw, provide the encoding" + "of this raw audio, eg. pcm_f32le." + ), + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16_000, + help="The sample rate in Hz of the input audio, if in raw format.", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=256, + help=( + "How much audio data, in bytes, to send to the server in each " + "websocket message. Larger values can increase latency, but " + "values which are too small create unnecessary overhead." + ), + ) + parser.add_argument( + "--buffer-size", + default=512, + type=int, + help=( + "Maximum number of messages to send before waiting for" + "acknowledgements from the server." 
+ ), + ) + parser.add_argument( + "--print-json", + default=False, + action="store_true", + help=( + "Print the JSON partial & final transcripts received rather than " + "plaintext messages." + ), + ) + + return parser + + +def parse_args(args=None): + """ + Parses command-line arguments. + + :param args: List of arguments to parse. + :type args: (List[str], optional) + + :return: The set of arguments provided along with their values. + :rtype: Namespace + """ + parser = get_arg_parser() + parsed_args = parser.parse_args(args=args) + return parsed_args diff --git a/speechmatics_flow/client.py b/speechmatics_flow/client.py new file mode 100644 index 0000000..8173523 --- /dev/null +++ b/speechmatics_flow/client.py @@ -0,0 +1,506 @@ +# (c) 2024, Cantab Research Ltd. +""" +Wrapper library to interface with Flow Service API. +""" + +import asyncio +import copy +import json +import logging +import os +from typing import List + +import httpx +import pyaudio +import websockets + +from speechmatics_flow.exceptions import ( + ConversationEndedException, + EndOfTranscriptException, + ForceEndSession, + TranscriptionError, +) +from speechmatics_flow.models import ( + ClientMessageType, + ServerMessageType, + AudioSettings, + ConversationConfig, + Interaction, + ConnectionSettings, +) +from speechmatics_flow.utils import read_in_chunks, json_utf8 + +LOGGER = logging.getLogger(__name__) + +# If the logging level is set to DEBUG websockets logs very verbosely, +# including a hex dump of every message being sent. Setting the websockets +# logger at INFO level specifically prevents this spam. +logging.getLogger("websockets.protocol").setLevel(logging.INFO) + + +class WebsocketClient: + """ + Manage a conversation session with the agent. + + The best way to interact with this library is to instantiate this client + and then add a set of handlers to it. Handlers respond to particular types + of messages received from the server. 
+ + :param connection_settings: Settings for the WebSocket connection, + including the URL of the server. + :type connection_settings: models.ConnectionSettings + """ + + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + connection_settings: ConnectionSettings = None, + ): + self.connection_settings = connection_settings + self.websocket = None + self.conversation_config = None + self.audio_settings = None + + self.event_handlers = {x: [] for x in ServerMessageType} + self.middlewares = {x: [] for x in ClientMessageType} + + self.seq_no = 0 + self.session_running = False + self._language_pack_info = None + self._transcription_config_needs_update = False + self._session_needs_closing = False + self._audio_buffer = None + + # The following asyncio fields are fully instantiated in + # _init_synchronization_primitives + self._conversation_started = asyncio.Event + # Semaphore used to ensure that we don't send too much audio data to + # the server too quickly and burst any buffers downstream. + self._buffer_semaphore = asyncio.BoundedSemaphore + + async def _init_synchronization_primitives(self): + """ + Used to initialise synchronization primitives that require + an event loop + """ + self._conversation_started = asyncio.Event() + self._buffer_semaphore = asyncio.BoundedSemaphore( + self.connection_settings.message_buffer_size + ) + + def _flag_conversation_started(self): + """ + Handle a + :py:attr:`models.ClientMessageType.ConversationStarted` + message from the server. + This updates an internal flag to mark the session started + as started meaning, AddAudio is now allowed. + """ + self._conversation_started.set() + + @json_utf8 + def _start_conversation(self): + """ + Constructs a + :py:attr:`models.ClientMessageType.StartConversation` + message. + This initiates the conversation session. 
+ """ + assert self.conversation_config is not None + msg = { + "message": ClientMessageType.StartConversation, + "audio_format": self.audio_settings.asdict(), + "conversation_config": self.conversation_config.asdict(), + } + self.session_running = True + self._call_middleware(ClientMessageType.StartConversation, msg, False) + LOGGER.debug(msg) + return msg + + @json_utf8 + def _end_of_audio(self): + """ + Constructs an + :py:attr:`models.ClientMessageType.AudioEnded` + message. + """ + msg = {"message": ClientMessageType.AudioEnded, "last_seq_no": self.seq_no} + self._call_middleware(ClientMessageType.AudioEnded, msg, False) + LOGGER.debug(msg) + return msg + + async def _consumer(self, message, from_cli: False): + """ + Consumes messages and acts on them. + + :param message: Message received from the server. + :type message: str + + :raises TranscriptionError: on an error message received from the + server after the Session started. + :raises EndOfTranscriptException: on EndOfTranscription message. + :raises ForceEndSession: If this was raised by the user's event + handler. 
+ """ + LOGGER.debug(message) + if isinstance(message, (bytes, bytearray)): + # add an audio message to local buffer only when running from cli + if from_cli: + await self._audio_buffer.put(message) + # Flow service does not send message_type with binary data, + # so we need to set it here for event_handler to work + message_type = ServerMessageType.audio + else: + message = json.loads(message) + message_type = message.get("message") + + if message_type is None: + return + + for handler in self.event_handlers[message_type]: + try: + handler(copy.deepcopy(message)) + except ForceEndSession: + LOGGER.warning("Session was ended forcefully by an event handler") + raise + + if message_type == ServerMessageType.ConversationStarted: + self._flag_conversation_started() + elif message_type == ServerMessageType.AudioAdded: + self._buffer_semaphore.release() + elif message_type == ServerMessageType.ConversationEnded: + raise ConversationEndedException() + elif message_type == ServerMessageType.EndOfTranscript: + raise EndOfTranscriptException() + elif message_type == ServerMessageType.Warning: + LOGGER.warning(message["reason"]) + elif message_type == ServerMessageType.Error: + raise TranscriptionError(message["reason"]) + + async def _read_from_microphone(self): + p = pyaudio.PyAudio() + print(f"Default input device: {p.get_default_input_device_info()['name']}") + print(f"Default output device: {p.get_default_output_device_info()['name']}") + print("Start speaking...") + stream = p.open( + format=pyaudio.paInt16, + channels=1, + rate=self.audio_settings.sample_rate, + input=True, + ) + try: + while True: + if self._session_needs_closing: + break + + await asyncio.wait_for( + self._buffer_semaphore.acquire(), + timeout=self.connection_settings.semaphore_timeout_seconds, + ) + + # audio_chunk size is 128 * 2 = 256 bytes which is about 8ms + audio_chunk = stream.read(num_frames=128, exception_on_overflow=False) + + self.seq_no += 1 + 
self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True) + await self.websocket.send(audio_chunk) + finally: + await self.websocket.send(self._end_of_audio()) + stream.stop_stream() + stream.close() + p.terminate() + + async def _consumer_handler(self, from_cli: False): + """ + Controls the consumer loop for handling messages from the server. + + raises: ConnectionClosedError when the upstream closes unexpectedly + """ + while self.session_running: + try: + message = await self.websocket.recv() + except websockets.exceptions.ConnectionClosedOK: + # Can occur if a timeout has closed the connection. + LOGGER.info("Cannot receive from closed websocket.") + return + except websockets.exceptions.ConnectionClosedError as ex: + LOGGER.info("Disconnected while waiting for recv().") + raise ex + await self._consumer(message, from_cli) + + async def _stream_producer(self, stream, audio_chunk_size): + async for audio_chunk in read_in_chunks(stream, audio_chunk_size): + if self._session_needs_closing: + break + + await asyncio.wait_for( + self._buffer_semaphore.acquire(), + timeout=self.connection_settings.semaphore_timeout_seconds, + ) + + self.seq_no += 1 + self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True) + yield audio_chunk + + async def _producer_handler(self, interactions: List[Interaction]): + """ + Controls the producer loop for sending messages to the server. 
+ """ + await self._conversation_started.wait() + + if interactions[0].stream.name == "": + return await self._read_from_microphone() + + for interaction in interactions: + async for message in self._stream_producer( + interaction.stream, self.audio_settings.chunk_size + ): + try: + await self.websocket.send(message) + except Exception as e: + LOGGER.error(f"error sending message: {e}") + return + if interaction.callback: + interaction.callback(self) + + await self.websocket.send(self._end_of_audio()) + + async def _playback_handler(self): + """ + Reads audio binary messages from the playback buffer and plays them to the user. + """ + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt16, + channels=1, + rate=self.audio_settings.sample_rate, + output=True, + ) + try: + while True: + if self._session_needs_closing: + break + try: + audio_message = await self._audio_buffer.get() + stream.write(audio_message) + self._audio_buffer.task_done() + except Exception as e: + LOGGER.error(f"Error during audio playback: {e}") + raise e + finally: + stream.close() + stream.stop_stream() + p.terminate() + LOGGER.debug("Exiting playback handler") + + def _call_middleware(self, event_name, *args): + """ + Call the middlewares attached to the client for the given event name. + + :raises ForceEndSession: If this was raised by the user's middleware. + """ + for middleware in self.middlewares[event_name]: + try: + middleware(*args) + except ForceEndSession: + LOGGER.warning("Session was ended forcefully by a middleware") + raise + + def add_event_handler(self, event_name, event_handler): + """ + Add an event handler (callback function) to handle an incoming + message from the server. Event handlers are passed a copy of the + incoming message from the server. If `event_name` is set to 'all' then + the handler will be added for every event. 
+ + For example, a simple handler that just LOGGER.debugs out the + :py:attr:`models.ServerMessageType.audio` + messages received: + + >>> client = WebsocketClient( + ConnectionSettings(url="wss://localhost:9000")) + >>> handler = lambda msg: LOGGER.debug(msg) + >>> client.add_event_handler(ServerMessageType.audio, handler) + + :param event_name: The name of the message for which a handler is + being added. Refer to + :py:class:`models.ServerMessageType` for a list + of the possible message types. + :type event_name: str + + :param event_handler: A function to be called when a message of the + given type is received. + :type event_handler: Callable[[dict], None] + + :raises ValueError: If the given event name is not valid. + """ + if event_name == "all": + for name in self.event_handlers.keys(): + self.event_handlers[name].append(event_handler) + elif event_name not in self.event_handlers: + raise ValueError( + f"Unknown event name: {event_name!r}, expected to be " + f"'all' or one of {list(self.event_handlers.keys())}." + ) + else: + self.event_handlers[event_name].append(event_handler) + + def add_middleware(self, event_name, middleware): + """ + Add middleware to handle outgoing messages sent to the server. + Middlewares are passed a reference to the outgoing message, which + they may alter. + If `event_name` is set to 'all' then the handler will be added for + every event. + + :param event_name: The name of the message for which middleware is + being added. Refer to the V2 API docs for a list of the possible + message types. + :type event_name: str + + :param middleware: A function to be called to process an outgoing + message of the given type. The function receives the message as + the first argument and a second, boolean argument indicating + whether the message is binary data (which implies it is an + AddAudio message). + :type middleware: Callable[[dict, bool], None] + + :raises ValueError: If the given event name is not valid. 
+ """ + if event_name == "all": + for name in self.middlewares.keys(): + self.middlewares[name].append(middleware) + elif event_name not in self.middlewares: + raise ValueError( + ( + f"Unknown event name: {event_name}, expected to be 'all'" + f"or one of {list(self.middlewares.keys())}." + ) + ) + else: + self.middlewares[event_name].append(middleware) + + async def _communicate(self, interactions: List[Interaction], from_cli=False): + """ + Create a producer/consumer for transcription messages and + communicate with the server. + Internal method called from _run. + """ + try: + start_conversation_msg = self._start_conversation() + except ForceEndSession: + return + await self.websocket.send(start_conversation_msg) + + tasks = [ + asyncio.create_task(self._consumer_handler(from_cli)), + asyncio.create_task(self._producer_handler(interactions)), + ] + + # Run the playback task that plays audio messages to the user when started from cli + if from_cli: + self._audio_buffer = asyncio.Queue() + tasks.append(asyncio.create_task(self._playback_handler())) + + (done, pending) = await asyncio.wait( + tasks, + return_when=asyncio.FIRST_EXCEPTION, + ) + + # If a task is pending, the other one threw an exception, so tidy up + for task in pending: + task.cancel() + + for task in done: + exc = task.exception() + if exc and not isinstance( + exc, + ( + EndOfTranscriptException, + ForceEndSession, + ConversationEndedException, + ), + ): + raise exc + + async def run( + self, + interactions: List[Interaction], + audio_settings: AudioSettings = AudioSettings(), + conversation_config: ConversationConfig = None, + from_cli: bool = False, + ): + """ + Begin a new recognition session. + This will run asynchronously. Most callers may prefer to use + :py:meth:`run_synchronously` which will block until the session is + finished. + + :param interactions: A list of interactions with FlowService API. 
+ :type interactions: List[Interaction] + + :param audio_settings: Configuration for the audio stream. + :type audio_settings: models.AudioSettings + + :param conversation_config: Configuration for the conversation. + :type conversation_config: models.ConversationConfig + + :raises Exception: Can raise any exception returned by the + consumer/producer tasks. + """ + self.seq_no = 0 + self._language_pack_info = None + self.conversation_config = conversation_config + self.audio_settings = audio_settings + + await self._init_synchronization_primitives() + + extra_headers = {} + auth_token = await get_temp_token(self.connection_settings.auth_token) + extra_headers["Authorization"] = f"Bearer {auth_token}" + try: + async with websockets.connect( # pylint: disable=no-member + self.connection_settings.url, + ssl=self.connection_settings.ssl_context, + ping_timeout=self.connection_settings.ping_timeout_seconds, + # Don't limit the max. size of incoming messages + max_size=None, + extra_headers=extra_headers, + ) as self.websocket: + await self._communicate(interactions, from_cli) + finally: + self.session_running = False + self._session_needs_closing = False + self.websocket = None + + def stop(self): + """ + Indicates that the recognition session should be forcefully stopped. + Only used in conjunction with `run`. + You probably don't need to call this if you're running the client via + :py:meth:`run_synchronously`. + """ + self._session_needs_closing = True + + def run_synchronously(self, *args, timeout=None, **kwargs): + """ + Run the transcription synchronously. + :raises asyncio.TimeoutError: If the given timeout is exceeded. 
+ """ + # pylint: disable=no-value-for-parameter + asyncio.run(asyncio.wait_for(self.run(*args, **kwargs), timeout=timeout)) + + +async def get_temp_token(api_key): + """ + Used to get a temporary token from management platform api + """ + mp_api_url = os.getenv("SM_MANAGEMENT_PLATFORM_URL", "https://mp.speechmatics.com") + endpoint = f"{mp_api_url}/v1/api_keys?type=flow" + body = {"ttl": 300, "client_ref": "speechmatics-flow-python-client"} + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + response = httpx.post(endpoint, json=body, headers=headers) + response.raise_for_status() + data = response.json() + return data["key_value"] diff --git a/speechmatics_flow/exceptions.py b/speechmatics_flow/exceptions.py new file mode 100644 index 0000000..d4d9e37 --- /dev/null +++ b/speechmatics_flow/exceptions.py @@ -0,0 +1,29 @@ +# (c) 2024, Cantab Research Ltd. +""" +Exceptions and errors used by the library. +""" + + +class TranscriptionError(Exception): + """ + Indicates an error in transcription. + """ + + +class EndOfTranscriptException(Exception): + """ + Indicates that the transcription session has finished. + """ + + +class ForceEndSession(Exception): + """ + Can be raised by the user from a middleware or event handler + to force the transcription session to end early. + """ + + +class ConversationEndedException(Exception): + """ + Indicates the session ended. + """ diff --git a/speechmatics_flow/models.py b/speechmatics_flow/models.py new file mode 100644 index 0000000..ac2a629 --- /dev/null +++ b/speechmatics_flow/models.py @@ -0,0 +1,143 @@ +# (c) 2024, Cantab Research Ltd. +""" +Data models and message types used by the library. 
+""" + +import io +import ssl +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Callable, Dict, Optional, Literal + + +@dataclass +class AudioSettings: + """Defines audio parameters.""" + + encoding: str = "pcm_s16le" + """Encoding format when raw audio is used. Allowed values are + `pcm_f32le` and `pcm_s16le`""" + + sample_rate: int = 16000 + """Sampling rate in hertz.""" + + chunk_size: int = 256 + """Chunk size in bytes.""" + + def asdict(self): + return { + "type": "raw", + "encoding": self.encoding, + "sample_rate": self.sample_rate, + } + + +@dataclass +class ConnectionSettings: + """Defines connection parameters.""" + + url: str + """Websocket server endpoint.""" + + message_buffer_size: int = 512 + """Message buffer size in bytes.""" + + ssl_context: ssl.SSLContext = field(default_factory=ssl.create_default_context) + """SSL context.""" + + semaphore_timeout_seconds: float = 120 + """Semaphore timeout in seconds.""" + + ping_timeout_seconds: float = 60 + """Ping-pong timeout in seconds.""" + + auth_token: Optional[str] = None + """auth token to authenticate a customer.""" + + generate_temp_token: Optional[bool] = True + """Automatically generate a temporary token for authentication.""" + + +@dataclass +class ConversationConfig: + """Defines configuration parameters for conversation requests.""" + + template_id: Literal[ + "default", "flow-service-assistant-amelia", "flow-service-assistant-humphrey" + ] = "default" + """Name of a predefined template.""" + + template_variables: Optional[Dict[str, str]] = None + """Optional parameter to allow overriding the default values of variables defined in the template.""" + + def asdict(self): + return asdict( + self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None} + ) + + +class ClientMessageType(str, Enum): + # pylint: disable=invalid-name + """Defines various messages sent from client to server.""" + + StartConversation = "StartConversation" + """Initiates a 
class ClientMessageType(str, Enum):
    # pylint: disable=invalid-name
    """Defines various messages sent from client to server."""

    StartConversation = "StartConversation"
    """Initiates a conversation job based on configuration set previously."""

    AddAudio = "AddAudio"
    """Adds more audio data to the recognition job. The server confirms
    receipt by sending an :py:attr:`ServerMessageType.AudioAdded` message."""

    AudioEnded = "AudioEnded"
    """Indicates audio input has finished."""


class ServerMessageType(str, Enum):
    # pylint: disable=invalid-name
    """Defines various message types sent from server to client."""

    ConversationStarted = "ConversationStarted"
    """Server response to :py:attr:`ClientMessageType.StartConversation`,
    acknowledging that a conversation session has started."""

    AddPartialTranscript = "AddPartialTranscript"
    """Indicates a partial transcript, which is an incomplete transcript that
    is immediately produced and may change as more context becomes available.
    """

    AudioAdded = "AudioAdded"
    """Server response to :py:attr:`ClientMessageType.AddAudio`, indicating
    that audio has been added successfully."""

    AddTranscript = "AddTranscript"
    """Indicates the final transcript of a part of the audio."""

    audio = "audio"
    """Message contains binary data"""

    prompt = "prompt"
    """Message contains text data"""

    ConversationEnded = "ConversationEnded"
    """Message indicates the session ended."""

    EndOfTranscript = "EndOfTranscript"
    """Server response to :py:attr:`ClientMessageType.AudioEnded`,
    after the server has finished sending all :py:attr:`AddTranscript`
    messages. (Fixed: previously referenced a nonexistent
    ``ClientMessageType.EndOfStream``.)"""

    Info = "Info"
    """Indicates a generic info message."""

    Warning = "Warning"
    """Indicates a generic warning message."""

    Error = "Error"
    """Indicates a generic error message."""


@dataclass
class Interaction:
    """Defines various interactions between client and server."""

    # Audio source whose contents are streamed to the server.
    stream: io.BufferedReader
    # Optional callable invoked with the client once this interaction's
    # stream has been fully sent.
    callback: Optional[Callable] = None
# (c) 2024, Cantab Research Ltd.
"""
Helper functions used by the library.
"""

import asyncio
import concurrent.futures
import functools
import inspect
import json


def json_utf8(func):
    """A decorator to turn a function's return value into JSON."""

    # functools.wraps preserves the wrapped function's name/docstring for
    # debugging and introspection.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return json.dumps(func(*args, **kwargs))

    return wrapper


async def read_in_chunks(stream, chunk_size):
    """
    Utility method for reading in and yielding chunks.

    :param stream: file-like object to read audio from
    :type stream: io.IOBase

    :param chunk_size: maximum chunk size in bytes
    :type chunk_size: int

    :return: a sequence of chunks of data where the length in bytes of each
        chunk is <= chunk_size
    :rtype: collections.AsyncIterable
    """
    while True:
        # Work with both async and synchronous file readers.
        if inspect.iscoroutinefunction(stream.read):
            audio_chunk = await stream.read(chunk_size)
        else:
            # Off-load the blocking read to the event loop's default
            # executor. (Previously a fresh ThreadPoolExecutor was created
            # and torn down for every single chunk.)
            audio_chunk = await asyncio.get_running_loop().run_in_executor(
                None, stream.read, chunk_size
            )

        if not audio_chunk:
            break
        yield audio_chunk
import pytest

from speechmatics_flow import models

TEMPLATE_VARS = {
    "persona": "You are an aging English butler named Humphrey.",
    "style": "Be charming but unpredictable.",
    "context": "You are taking a customer's order for fast food.",
}


@pytest.mark.parametrize(
    "kwargs, expected",
    [
        ({}, {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}),
        (
            {"encoding": "pcm_f32le", "sample_rate": 44100},
            {"type": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
        ),
    ],
)
def test_audio_settings(kwargs, expected):
    """AudioSettings.asdict produces the raw-audio wire format."""
    assert models.AudioSettings(**kwargs).asdict() == expected


@pytest.mark.parametrize(
    "kwargs, expected",
    [
        ({}, {"template_id": "default"}),
        (
            {"template_id": "test", "template_variables": TEMPLATE_VARS},
            {"template_id": "test", "template_variables": TEMPLATE_VARS},
        ),
    ],
)
def test_conversation_config(kwargs, expected):
    """ConversationConfig.asdict drops unset (None) fields."""
    assert models.ConversationConfig(**kwargs).asdict() == expected