diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..5eaed6b
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[report]
+exclude_lines =
+ if __name__ == .__main__.:
\ No newline at end of file
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..45655fa
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000..e846fb5
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,27 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots/Logs**
+If applicable, add screenshots or logs to help explain your problem.
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000..d9baa67
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: 'request'
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is, e.g. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..be73c8e
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,33 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+name: Publish
+
+on:
+ release:
+ types: [released]
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install setuptools wheel twine
+
+ - name: Build and publish to PyPI
+ env:
+ TWINE_USERNAME: '__token__'
+ TWINE_PASSWORD: ${{ secrets.PYPI_ORG_TOKEN }}
+ run: |
+ make build
+ twine upload dist/*
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..7cd7c16
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,39 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Tests
+
+on: [ push, pull_request ]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [ "3.10", "3.11" ]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Install portaudio
+ run: |
+ sudo apt-get update
+ sudo apt-get install portaudio19-dev
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v3
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: pip3 install -r requirements.txt -r requirements-dev.txt
+
+ - name: Install package
+ run: python3 setup.py install
+
+ - name: Lint
+ run: make lint
+
+ - name: Unittest
+ run: make unittest
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..41158d0
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,11 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+
+## [0.0.1] - 2024-10-14
+
+### Added
+
+- Add speechmatics-flow client
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..f0db3b1
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,25 @@
+SOURCES := speechmatics_flow/ tests/ setup.py
+VERSION ?= $(shell cat VERSION)
+
+.PHONY: all
+all: lint test
+
+.PHONY: lint
+lint:
+ black --check --diff $(SOURCES)
+ ruff $(SOURCES)
+
+.PHONY: format
+format:
+ black $(SOURCES)
+
+.PHONY: test
+test: unittest
+
+.PHONY: unittest
+unittest:
+ pytest -v tests
+
+.PHONY: build
+build:
+ VERSION=$(VERSION) python setup.py sdist bdist_wheel
diff --git a/README.md b/README.md
index 72eed7f..7a881c9 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,39 @@
-# speechmatics-flow
\ No newline at end of file
+# speechmatics-flow
+
+Python client library and CLI for Speechmatics' Flow Service API.
+
+## Getting started
+
+To install from PyPI:
+
+```bash
+pip install speechmatics-flow
+```
+
+To install from source:
+
+```bash
+git clone https://github.com/speechmatics/speechmatics-flow
+cd speechmatics-flow && python setup.py install
+```
+
+Windows users may need to run the install command with an extra flag:
+
+```bash
+python setup.py install --user
+```
+
+## Example command-line usage
+
+- Setting URLs for connecting to flow service. These values can be used in place of the --url flag:
+
+*Note: Requires access to a microphone.*
+
+ ```bash
+ speechmatics-flow --url $URL --auth-token $TOKEN -
+ ```
+
+## Support
+
+If you have any issues with this library or encounter any bugs then please get in touch with us at
+support@speechmatics.com.
diff --git a/VERSION b/VERSION
new file mode 100644
index 0000000..8acdd82
--- /dev/null
+++ b/VERSION
@@ -0,0 +1 @@
+0.0.1
diff --git a/examples/stream_from_microphone.py b/examples/stream_from_microphone.py
new file mode 100644
index 0000000..ade0a07
--- /dev/null
+++ b/examples/stream_from_microphone.py
@@ -0,0 +1,85 @@
+import asyncio
+import io
+import ssl
+import sys
+
+import pyaudio
+
+from speechmatics_flow.client import WebsocketClient
+from speechmatics_flow.models import (
+ ConnectionSettings,
+ Interaction,
+ AudioSettings,
+ ConversationConfig,
+ ServerMessageType,
+)
+
+AUTH_TOKEN = "YOUR TOKEN HERE"
+
+
+# Create a websocket client
+ssl_context = ssl.create_default_context()
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+client = WebsocketClient(
+ ConnectionSettings(
+ url="wss://flow.api.speechmatics.com/v1/flow",
+ auth_token=AUTH_TOKEN,
+ ssl_context=None,
+ )
+)
+
+# Create a buffer to store binary messages sent from the server
+audio_buffer = io.BytesIO()
+
+
+# Create callback function which adds binary messages to audio buffer
+def binary_msg_handler(msg: bytes):
+ if isinstance(msg, (bytes, bytearray)):
+ audio_buffer.write(msg)
+
+
+# Register the callback to be called when the client receives an audio message from the server
+client.add_event_handler(ServerMessageType.audio, binary_msg_handler)
+
+
+async def audio_playback():
+ """Read from buffer and play audio back to the user"""
+ p = pyaudio.PyAudio()
+ stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, output=True)
+ try:
+ while True:
+ # Get the current value from the buffer
+ audio_to_play = audio_buffer.getvalue()
+ # Only proceed if there is audio data to play
+ if audio_to_play:
+ # Write the audio to the stream
+ stream.write(audio_to_play)
+ audio_buffer.seek(0)
+ audio_buffer.truncate(0)
+ # Pause briefly before checking the buffer again
+ await asyncio.sleep(0.05)
+ finally:
+        stream.stop_stream()
+        stream.close()
+ p.terminate()
+
+
+async def main():
+ tasks = [
+ # Use the websocket to connect to Flow Service and start a conversation
+ asyncio.create_task(
+ client.run(
+ interactions=[Interaction(sys.stdin.buffer)],
+ audio_settings=AudioSettings(),
+ conversation_config=ConversationConfig(),
+ )
+ ),
+ # Run audio playback handler which streams audio from audio buffer
+ asyncio.create_task(audio_playback()),
+ ]
+
+ await asyncio.gather(*tasks)
+
+
+asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..0e7055b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.ruff]
+# Allow lines to be as long as 120 characters.
+line-length = 120
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..1a60281
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = -ra --full-trace --cov=speechmatics_flow --cov-branch -o asyncio_mode=auto
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..dc8df89
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,6 @@
+pytest==7.1.1
+pytest-mock==3.7.0
+black==22.3.0
+ruff==0.0.280
+pre-commit==2.21.0
+pytest-cov==3.0.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e87dadc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+httpx==0.27.1
+pyaudio==0.2.14
+websockets>=10
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..7cd8bfd
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+"""
+Setuptools configuration for the Speechmatics Flow client
+"""
+
+import logging
+import os
+
+from setuptools import setup, find_packages
+
+
+def read(fname):
+ """
+ Load content of the file with path relative to where setup.py is.
+
+ Args:
+ fname (str): file name (or path relative to the project root)
+
+ Returns:
+ str: file content
+ """
+ fpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), fname)
+ with open(fpath, encoding="utf-8") as path:
+ return path.read()
+
+
+def read_list(fname):
+ """
+ Load the content of the file and split it into a list of lines.
+
+ Args:
+ fname (str): file name (or path relative to the project root)
+
+ Returns:
+ List[str]: file content (one string per line) with end-of-lines
+ characters stripped off and empty lines filtered out
+ """
+ content = read(fname)
+ retval = list(filter(None, content.split("\n")))
+
+ return retval
+
+
+def get_version(fname):
+ """
+ Retrieve the version from the VERSION file.
+
+ Args:
+ fname (str): file containing only the version
+
+ Returns:
+ str: version with whitespace characters stripped off
+ """
+ return read(fname).strip()
+
+
+logging.basicConfig(level=logging.INFO)
+setup(
+ name="speechmatics-flow",
+ version=os.getenv("VERSION", get_version("VERSION")),
+ packages=find_packages(exclude=["tests"]),
+ url="https://github.com/speechmatics/speechmatics-flow/",
+ license="MIT",
+ author="Speechmatics",
+ author_email="support@speechmatics.com",
+ description="Python library and CLI for Speechmatics Flow API",
+ long_description=read("README.md"),
+ long_description_content_type="text/markdown",
+ install_requires=read_list("requirements.txt"),
+ tests_require=read_list("requirements-dev.txt"),
+ entry_points={
+ "console_scripts": ["speechmatics-flow = speechmatics_flow.cli:main"]
+ },
+ project_urls={
+ "Documentation": "https://speechmatics.github.io/speechmatics-flow/",
+ "Source Code": "https://github.com/speechmatics/speechmatics-flow/",
+ },
+ classifiers=[
+ "Environment :: Console",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Multimedia :: Sound/Audio :: Speech",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ ],
+ include_package_data=True,
+    python_requires=">=3.10",
+)
diff --git a/speechmatics_flow/__init__.py b/speechmatics_flow/__init__.py
new file mode 100644
index 0000000..49b82d3
--- /dev/null
+++ b/speechmatics_flow/__init__.py
@@ -0,0 +1,3 @@
+from speechmatics_flow.client import * # noqa
+from speechmatics_flow.exceptions import * # noqa
+from speechmatics_flow.models import * # noqa
diff --git a/speechmatics_flow/cli.py b/speechmatics_flow/cli.py
new file mode 100755
index 0000000..7548a20
--- /dev/null
+++ b/speechmatics_flow/cli.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+# (c) 2024, Cantab Research Ltd.
+"""
+Command-line interface
+"""
+
+import json
+import logging
+import ssl
+import sys
+from dataclasses import dataclass
+from socket import gaierror
+from typing import Any, Dict
+
+import httpx
+from websockets.exceptions import WebSocketException
+
+from speechmatics_flow.cli_parser import parse_args
+from speechmatics_flow.client import WebsocketClient
+from speechmatics_flow.exceptions import TranscriptionError
+from speechmatics_flow.models import (
+ AudioSettings,
+ ConversationConfig,
+ ServerMessageType,
+ Interaction,
+ ConnectionSettings,
+)
+
+LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class Transcripts:
+ user_transcript: str = ""
+ agent_transcript: str = ""
+
+
+def get_log_level(verbosity):
+ """
+ Returns the appropriate log level given a verbosity level.
+
+ :param verbosity: Verbosity level.
+ :type verbosity: int
+
+ :return: The logging level (eg. logging.INFO).
+ :rtype: int
+
+ :raises SystemExit: If the given verbosity level is invalid.
+ """
+ try:
+ log_level = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}[verbosity]
+
+ return log_level
+ except KeyError as error:
+ key = int(str(error))
+ raise SystemExit(
+ f"Only supports 2 log levels eg. -vv, you are asking for " f"-{'v' * key}"
+ ) from error
+
+
+def get_connection_settings(args):
+ """
+ Helper function which returns a ConnectionSettings object based on the
+ command line options given to the program.
+
+ :param args: Keyword arguments, typically from the command line.
+ :type args: dict
+
+ :return: Settings for the WebSocket connection.
+ :rtype: speechmatics_flow.models.ConnectionSettings
+ """
+ auth_token = args.get("auth_token")
+ url = args.get("url")
+ generate_temp_token = args.get("generate_temp_token")
+ settings = ConnectionSettings(
+ url=url,
+ auth_token=auth_token,
+ generate_temp_token=generate_temp_token,
+ )
+
+ if args.get("buffer_size") is not None:
+ settings.message_buffer_size = args["buffer_size"]
+
+ if args.get("ssl_mode") == "insecure":
+ settings.ssl_context.check_hostname = False
+ settings.ssl_context.verify_mode = ssl.CERT_NONE
+ elif args.get("ssl_mode") == "none":
+ settings.ssl_context = None
+
+ return settings
+
+
+def get_conversation_config(
+ args,
+):
+ """
+ Helper function which returns a ConversationConfig object based on the
+ command line options given to the program.
+
+ :param args: Keyword arguments probably from the command line.
+ :type args: Dict
+
+ :return: Settings for the ASR engine.
+ :rtype: flow.models.ConversationConfig
+ """
+
+ config: Dict[str, Any] = {}
+ if args.get("config_file"):
+ with open(args["config_file"], encoding="utf-8") as config_file:
+ config = json.load(config_file)
+
+ if config.get("conversation_config"):
+ config.update(config.pop("conversation_config"))
+
+ return ConversationConfig(**config)
+
+
+def get_audio_settings(args):
+ """
+ Helper function which returns an AudioSettings object based on the command
+ line options given to the program.
+
+ Args:
+ args (dict): Keyword arguments, typically from the command line.
+
+ Returns:
+ flow.models.AudioSettings: Settings for the audio stream
+ in the connection.
+ """
+ settings = AudioSettings(
+ sample_rate=args.get("sample_rate"),
+ chunk_size=args.get("chunk_size"),
+ encoding=args.get("raw"),
+ )
+ return settings
+
+
+# pylint: disable=too-many-arguments,too-many-statements
+def add_printing_handlers(
+ api,
+ transcripts,
+ print_json=False,
+):
+ """
+ Adds a set of handlers to the websocket client which print out transcripts
+ as they are received. This includes partials if they are enabled.
+
+ Args:
+ api (client.WebsocketClient): Client instance.
+ transcripts (Transcripts): Allows the transcripts to be concatenated to
+ produce a final result.
+ print_json (bool, optional): Whether to print json transcript messages.
+ """
+ escape_seq = "\33[2K" if sys.stdout.isatty() else ""
+
+ def convert_to_txt(message):
+ if print_json:
+ print(json.dumps(message))
+ return
+ transcript = message["metadata"]["transcript"]
+ return transcript.replace(" ", "").replace("", "").replace("", "")
+
+ def transcript_handler(message):
+ plaintext = convert_to_txt(message)
+ if plaintext:
+ sys.stdout.write(f"{escape_seq}{plaintext}\n")
+ transcripts.user_transcript += plaintext
+
+ def partial_transcript_handler(message):
+ plaintext = convert_to_txt(message)
+ if plaintext:
+ sys.stderr.write(f"{escape_seq}{plaintext}\r")
+
+ def prompt_handler(message):
+ if print_json:
+ print(json.dumps(message))
+ return
+ new_response = message["prompt"]["response"]
+ new_plaintext_response = new_response.replace(" ", "").replace(" ", "")
+ if new_plaintext_response:
+ sys.stdout.write(f"{escape_seq}{new_plaintext_response}\n")
+            transcripts.agent_transcript += new_plaintext_response
+
+ def end_of_transcript_handler(_):
+ print("\n", file=sys.stderr)
+
+ api.add_event_handler(ServerMessageType.prompt, prompt_handler)
+ api.add_event_handler(ServerMessageType.AddTranscript, transcript_handler)
+ api.add_event_handler(
+ ServerMessageType.AddPartialTranscript, partial_transcript_handler
+ )
+ api.add_event_handler(ServerMessageType.EndOfTranscript, end_of_transcript_handler)
+
+
+# pylint: disable=too-many-branches
+# pylint: disable=too-many-statements
+def main(args=None):
+ """
+ Main entrypoint.
+
+ :param args: command-line arguments; defaults to None in which
+ case arguments will be retrieved from `sys.argv` (this is useful
+ mainly for unit tests).
+ :type args: List[str]
+ """
+ if not args:
+ args = vars(parse_args())
+
+ logging.basicConfig(level=get_log_level(args["verbose"]))
+ LOGGER.info("Args: %s", args)
+
+ try:
+ flow_main(args)
+ except (KeyboardInterrupt, ValueError, TranscriptionError, KeyError) as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(f"{type(error).__name__}: {error}")
+ except FileNotFoundError as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(
+ f"FileNotFoundError: {error.strerror}: '{error.filename}'."
+ + " Check to make sure the filename is spelled correctly, and that the file exists."
+ )
+ except httpx.HTTPStatusError as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(error.response.text)
+ except httpx.HTTPError as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(f"httpx.HTTPError: An unexpected http error occurred. {error}")
+ except ConnectionResetError as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(
+ f"ConnectionResetError: {error}.\n\nThe most likely reason for this is that the client "
+ + "has been configured to use SSL but the server does not support SSL. "
+ + "If this is the case then try using --ssl-mode=none"
+ )
+ except (WebSocketException, gaierror) as error:
+ LOGGER.info(error, exc_info=True)
+ sys.exit(
+ f"WebSocketError: An unexpected error occurred in the websocket: {error}.\n\n"
+ + "Check that the url and config provided is valid, "
+ + "and that the language in the url matches the config.\n"
+ )
+
+
+def flow_main(args):
+ """Main dispatch for "flow" mode commands.
+
+ :param args: arguments from parse_args()
+ :type args: argparse.Namespace
+ """
+ conversation_config = get_conversation_config(args)
+ settings = get_connection_settings(args)
+ api = WebsocketClient(settings)
+ transcripts = Transcripts()
+ add_printing_handlers(
+ api,
+ transcripts,
+ print_json=args["print_json"],
+ )
+
+ def run(stream):
+ try:
+ api.run_synchronously(
+ [Interaction(stream)],
+ get_audio_settings(args),
+ conversation_config,
+ from_cli=True,
+ )
+ except KeyboardInterrupt:
+ # Gracefully handle Ctrl-C, else we get a huge stack-trace.
+ LOGGER.warning("Keyboard interrupt received.")
+
+ run(sys.stdin.buffer)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/speechmatics_flow/cli_parser.py b/speechmatics_flow/cli_parser.py
new file mode 100644
index 0000000..ecd9e68
--- /dev/null
+++ b/speechmatics_flow/cli_parser.py
@@ -0,0 +1,129 @@
+# (c) 2024, Cantab Research Ltd.
+"""
+Parsers used by the CLI to handle CLI arguments
+"""
+import argparse
+import logging
+
+LOGGER = logging.getLogger(__name__)
+
+
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
+def get_arg_parser():
+ """
+    Creates a command-line argument parser object
+
+ :return: The argparser object with all commands and subcommands.
+ :rtype: argparse.ArgumentParser
+ """
+ parser = argparse.ArgumentParser(description="CLI for Speechmatics Flow API.")
+ parser.add_argument(
+ "-v",
+ dest="verbose",
+ action="count",
+ default=0,
+ help=(
+ "Set the log level for verbose logs. "
+ "The number of flags indicate the level, eg. "
+ "-v is INFO and -vv is DEBUG."
+ ),
+ )
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help=(
+ "Prints useful symbols to represent the messages on the wire. "
+ "Symbols are printed to STDERR, use only when STDOUT is "
+ "redirected to a file."
+ ),
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+        help="Websocket for Flow API URL (e.g. wss://flow.api.speechmatics.com/v1/flow)",
+ )
+ parser.add_argument(
+ "--auth-token",
+ type=str,
+ help="Authentication token to authorize the client.",
+ )
+ parser.add_argument(
+ "--generate-temp-token",
+ default=True,
+ action="store_true",
+ help="Automatically generate a temporary token for authentication.",
+ )
+ parser.add_argument(
+ "--ssl-mode",
+ default="regular",
+ choices=["regular", "insecure", "none"],
+ help=(
+ "Use a preset configuration for the SSL context. With `regular` "
+ "mode a valid certificate is expected. With `insecure` mode"
+ " a self signed certificate is allowed."
+ " With `none` then SSL is not used."
+ ),
+ )
+ parser.add_argument(
+ "--raw",
+ metavar="ENCODING",
+ type=str,
+ default="pcm_s16le",
+ help=(
+            "Indicate that the input audio is raw, provide the encoding "
+ "of this raw audio, eg. pcm_f32le."
+ ),
+ )
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16_000,
+ help="The sample rate in Hz of the input audio, if in raw format.",
+ )
+ parser.add_argument(
+ "--chunk-size",
+ type=int,
+ default=256,
+ help=(
+ "How much audio data, in bytes, to send to the server in each "
+ "websocket message. Larger values can increase latency, but "
+ "values which are too small create unnecessary overhead."
+ ),
+ )
+ parser.add_argument(
+ "--buffer-size",
+ default=512,
+ type=int,
+ help=(
+            "Maximum number of messages to send before waiting for "
+ "acknowledgements from the server."
+ ),
+ )
+ parser.add_argument(
+ "--print-json",
+ default=False,
+ action="store_true",
+ help=(
+ "Print the JSON partial & final transcripts received rather than "
+ "plaintext messages."
+ ),
+ )
+
+ return parser
+
+
+def parse_args(args=None):
+ """
+ Parses command-line arguments.
+
+ :param args: List of arguments to parse.
+ :type args: (List[str], optional)
+
+ :return: The set of arguments provided along with their values.
+ :rtype: Namespace
+ """
+ parser = get_arg_parser()
+ parsed_args = parser.parse_args(args=args)
+ return parsed_args
diff --git a/speechmatics_flow/client.py b/speechmatics_flow/client.py
new file mode 100644
index 0000000..8173523
--- /dev/null
+++ b/speechmatics_flow/client.py
@@ -0,0 +1,506 @@
+# (c) 2024, Cantab Research Ltd.
+"""
+Wrapper library to interface with Flow Service API.
+"""
+
+import asyncio
+import copy
+import json
+import logging
+import os
+from typing import List
+
+import httpx
+import pyaudio
+import websockets
+
+from speechmatics_flow.exceptions import (
+ ConversationEndedException,
+ EndOfTranscriptException,
+ ForceEndSession,
+ TranscriptionError,
+)
+from speechmatics_flow.models import (
+ ClientMessageType,
+ ServerMessageType,
+ AudioSettings,
+ ConversationConfig,
+ Interaction,
+ ConnectionSettings,
+)
+from speechmatics_flow.utils import read_in_chunks, json_utf8
+
+LOGGER = logging.getLogger(__name__)
+
+# If the logging level is set to DEBUG websockets logs very verbosely,
+# including a hex dump of every message being sent. Setting the websockets
+# logger at INFO level specifically prevents this spam.
+logging.getLogger("websockets.protocol").setLevel(logging.INFO)
+
+
+class WebsocketClient:
+ """
+ Manage a conversation session with the agent.
+
+ The best way to interact with this library is to instantiate this client
+ and then add a set of handlers to it. Handlers respond to particular types
+ of messages received from the server.
+
+ :param connection_settings: Settings for the WebSocket connection,
+ including the URL of the server.
+ :type connection_settings: models.ConnectionSettings
+ """
+
+ # pylint: disable=too-many-instance-attributes
+
+ def __init__(
+ self,
+ connection_settings: ConnectionSettings = None,
+ ):
+ self.connection_settings = connection_settings
+ self.websocket = None
+ self.conversation_config = None
+ self.audio_settings = None
+
+ self.event_handlers = {x: [] for x in ServerMessageType}
+ self.middlewares = {x: [] for x in ClientMessageType}
+
+ self.seq_no = 0
+ self.session_running = False
+ self._language_pack_info = None
+ self._transcription_config_needs_update = False
+ self._session_needs_closing = False
+ self._audio_buffer = None
+
+ # The following asyncio fields are fully instantiated in
+ # _init_synchronization_primitives
+ self._conversation_started = asyncio.Event
+ # Semaphore used to ensure that we don't send too much audio data to
+ # the server too quickly and burst any buffers downstream.
+ self._buffer_semaphore = asyncio.BoundedSemaphore
+
+ async def _init_synchronization_primitives(self):
+ """
+ Used to initialise synchronization primitives that require
+ an event loop
+ """
+ self._conversation_started = asyncio.Event()
+ self._buffer_semaphore = asyncio.BoundedSemaphore(
+ self.connection_settings.message_buffer_size
+ )
+
+ def _flag_conversation_started(self):
+ """
+ Handle a
+        :py:attr:`models.ServerMessageType.ConversationStarted`
+ message from the server.
+        This updates an internal flag to mark the session as
+        started, meaning AddAudio is now allowed.
+ """
+ self._conversation_started.set()
+
+ @json_utf8
+ def _start_conversation(self):
+ """
+ Constructs a
+ :py:attr:`models.ClientMessageType.StartConversation`
+ message.
+ This initiates the conversation session.
+ """
+ assert self.conversation_config is not None
+ msg = {
+ "message": ClientMessageType.StartConversation,
+ "audio_format": self.audio_settings.asdict(),
+ "conversation_config": self.conversation_config.asdict(),
+ }
+ self.session_running = True
+ self._call_middleware(ClientMessageType.StartConversation, msg, False)
+ LOGGER.debug(msg)
+ return msg
+
+ @json_utf8
+ def _end_of_audio(self):
+ """
+ Constructs an
+ :py:attr:`models.ClientMessageType.AudioEnded`
+ message.
+ """
+ msg = {"message": ClientMessageType.AudioEnded, "last_seq_no": self.seq_no}
+ self._call_middleware(ClientMessageType.AudioEnded, msg, False)
+ LOGGER.debug(msg)
+ return msg
+
+    async def _consumer(self, message, from_cli: bool = False):
+ """
+ Consumes messages and acts on them.
+
+ :param message: Message received from the server.
+ :type message: str
+
+ :raises TranscriptionError: on an error message received from the
+ server after the Session started.
+ :raises EndOfTranscriptException: on EndOfTranscription message.
+ :raises ForceEndSession: If this was raised by the user's event
+ handler.
+ """
+ LOGGER.debug(message)
+ if isinstance(message, (bytes, bytearray)):
+ # add an audio message to local buffer only when running from cli
+ if from_cli:
+ await self._audio_buffer.put(message)
+ # Flow service does not send message_type with binary data,
+ # so we need to set it here for event_handler to work
+ message_type = ServerMessageType.audio
+ else:
+ message = json.loads(message)
+ message_type = message.get("message")
+
+ if message_type is None:
+ return
+
+ for handler in self.event_handlers[message_type]:
+ try:
+ handler(copy.deepcopy(message))
+ except ForceEndSession:
+ LOGGER.warning("Session was ended forcefully by an event handler")
+ raise
+
+ if message_type == ServerMessageType.ConversationStarted:
+ self._flag_conversation_started()
+ elif message_type == ServerMessageType.AudioAdded:
+ self._buffer_semaphore.release()
+ elif message_type == ServerMessageType.ConversationEnded:
+ raise ConversationEndedException()
+ elif message_type == ServerMessageType.EndOfTranscript:
+ raise EndOfTranscriptException()
+ elif message_type == ServerMessageType.Warning:
+ LOGGER.warning(message["reason"])
+ elif message_type == ServerMessageType.Error:
+ raise TranscriptionError(message["reason"])
+
+ async def _read_from_microphone(self):
+ p = pyaudio.PyAudio()
+ print(f"Default input device: {p.get_default_input_device_info()['name']}")
+ print(f"Default output device: {p.get_default_output_device_info()['name']}")
+ print("Start speaking...")
+ stream = p.open(
+ format=pyaudio.paInt16,
+ channels=1,
+ rate=self.audio_settings.sample_rate,
+ input=True,
+ )
+ try:
+ while True:
+ if self._session_needs_closing:
+ break
+
+ await asyncio.wait_for(
+ self._buffer_semaphore.acquire(),
+ timeout=self.connection_settings.semaphore_timeout_seconds,
+ )
+
+ # audio_chunk size is 128 * 2 = 256 bytes which is about 8ms
+ audio_chunk = stream.read(num_frames=128, exception_on_overflow=False)
+
+ self.seq_no += 1
+ self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True)
+ await self.websocket.send(audio_chunk)
+ finally:
+ await self.websocket.send(self._end_of_audio())
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
+
+    async def _consumer_handler(self, from_cli: bool = False):
+ """
+ Controls the consumer loop for handling messages from the server.
+
+ raises: ConnectionClosedError when the upstream closes unexpectedly
+ """
+ while self.session_running:
+ try:
+ message = await self.websocket.recv()
+ except websockets.exceptions.ConnectionClosedOK:
+ # Can occur if a timeout has closed the connection.
+ LOGGER.info("Cannot receive from closed websocket.")
+ return
+ except websockets.exceptions.ConnectionClosedError as ex:
+ LOGGER.info("Disconnected while waiting for recv().")
+ raise ex
+ await self._consumer(message, from_cli)
+
+ async def _stream_producer(self, stream, audio_chunk_size):
+ async for audio_chunk in read_in_chunks(stream, audio_chunk_size):
+ if self._session_needs_closing:
+ break
+
+ await asyncio.wait_for(
+ self._buffer_semaphore.acquire(),
+ timeout=self.connection_settings.semaphore_timeout_seconds,
+ )
+
+ self.seq_no += 1
+ self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True)
+ yield audio_chunk
+
+ async def _producer_handler(self, interactions: List[Interaction]):
+ """
+ Controls the producer loop for sending messages to the server.
+ """
+ await self._conversation_started.wait()
+
+        if interactions[0].stream.name == "<stdin>":
+ return await self._read_from_microphone()
+
+ for interaction in interactions:
+ async for message in self._stream_producer(
+ interaction.stream, self.audio_settings.chunk_size
+ ):
+ try:
+ await self.websocket.send(message)
+ except Exception as e:
+ LOGGER.error(f"error sending message: {e}")
+ return
+ if interaction.callback:
+ interaction.callback(self)
+
+ await self.websocket.send(self._end_of_audio())
+
async def _playback_handler(self):
    """
    Reads audio binary messages from the playback buffer and plays them
    to the user through a PyAudio output stream.

    Runs until a session shutdown is requested; always releases the
    PyAudio resources on exit.
    """
    p = pyaudio.PyAudio()
    stream = p.open(
        format=pyaudio.paInt16,
        channels=1,
        rate=self.audio_settings.sample_rate,
        output=True,
    )
    try:
        while True:
            if self._session_needs_closing:
                break
            try:
                audio_message = await self._audio_buffer.get()
                # NOTE(review): stream.write() is a blocking call made on
                # the event loop; consider run_in_executor if playback
                # stalls other tasks — confirm with profiling.
                stream.write(audio_message)
                self._audio_buffer.task_done()
            except Exception as e:
                LOGGER.error(f"Error during audio playback: {e}")
                raise e
    finally:
        # Fix: stop the stream BEFORE closing it (PyAudio requires this
        # order; the previous code closed first), then release PortAudio.
        stream.stop_stream()
        stream.close()
        p.terminate()
        LOGGER.debug("Exiting playback handler")
+
def _call_middleware(self, event_name, *args):
    """
    Invoke every middleware registered for ``event_name`` with ``args``.

    :raises ForceEndSession: Propagated unchanged when raised by a user's
        middleware to force the session to end.
    """
    for mw in self.middlewares[event_name]:
        try:
            mw(*args)
        except ForceEndSession:
            LOGGER.warning("Session was ended forcefully by a middleware")
            raise
+
def add_event_handler(self, event_name, event_handler):
    """
    Register a callback for incoming server messages.

    The handler receives a copy of each incoming message of the given
    type. Passing ``'all'`` registers the handler for every known event.

    For example, a simple handler that just LOGGER.debugs out the
    :py:attr:`models.ServerMessageType.audio` messages received:

    >>> client = WebsocketClient(
        ConnectionSettings(url="wss://localhost:9000"))
    >>> handler = lambda msg: LOGGER.debug(msg)
    >>> client.add_event_handler(ServerMessageType.audio, handler)

    :param event_name: Message name to handle; see
        :py:class:`models.ServerMessageType` for the possible types.
    :type event_name: str

    :param event_handler: Called with each message of the given type.
    :type event_handler: Callable[[dict], None]

    :raises ValueError: If the given event name is not valid.
    """
    known = self.event_handlers
    if event_name == "all":
        # Fan the handler out to every registered event type.
        for handlers in known.values():
            handlers.append(event_handler)
        return
    if event_name not in known:
        raise ValueError(
            f"Unknown event name: {event_name!r}, expected to be "
            f"'all' or one of {list(known.keys())}."
        )
    known[event_name].append(event_handler)
+
def add_middleware(self, event_name, middleware):
    """
    Add middleware to handle outgoing messages sent to the server.
    Middlewares are passed a reference to the outgoing message, which
    they may alter.
    If `event_name` is set to 'all' then the handler will be added for
    every event.

    :param event_name: The name of the message for which middleware is
        being added. Refer to the V2 API docs for a list of the possible
        message types.
    :type event_name: str

    :param middleware: A function to be called to process an outgoing
        message of the given type. The function receives the message as
        the first argument and a second, boolean argument indicating
        whether the message is binary data (which implies it is an
        AddAudio message).
    :type middleware: Callable[[dict, bool], None]

    :raises ValueError: If the given event name is not valid.
    """
    if event_name == "all":
        for name in self.middlewares.keys():
            self.middlewares[name].append(middleware)
    elif event_name not in self.middlewares:
        # Fix: the previous message had no space between "'all'" and "or";
        # also use !r for consistency with add_event_handler's message.
        raise ValueError(
            f"Unknown event name: {event_name!r}, expected to be "
            f"'all' or one of {list(self.middlewares.keys())}."
        )
    else:
        self.middlewares[event_name].append(middleware)
+
+ async def _communicate(self, interactions: List[Interaction], from_cli=False):
+ """
+ Create the consumer/producer tasks for the conversation and run them
+ until the session finishes or one task fails.
+
+ Internal method called from :py:meth:`run` once the websocket is open.
+
+ :param interactions: Audio interactions whose streams are sent.
+ :param from_cli: When True, also start a playback task that plays
+ audio received from the server back to the user.
+
+ :raises Exception: Re-raises any unexpected exception from a finished
+ task; the expected session-ending exceptions are swallowed.
+ """
+ try:
+ start_conversation_msg = self._start_conversation()
+ except ForceEndSession:
+ # A middleware asked to end the session before it even began.
+ return
+ await self.websocket.send(start_conversation_msg)
+
+ tasks = [
+ asyncio.create_task(self._consumer_handler(from_cli)),
+ asyncio.create_task(self._producer_handler(interactions)),
+ ]
+
+ # Run the playback task that plays audio messages to the user when started from cli
+ if from_cli:
+ self._audio_buffer = asyncio.Queue()
+ tasks.append(asyncio.create_task(self._playback_handler()))
+
+ # FIRST_EXCEPTION: stop as soon as any task finishes with an error
+ # (or the first task completes normally with an exception attribute).
+ (done, pending) = await asyncio.wait(
+ tasks,
+ return_when=asyncio.FIRST_EXCEPTION,
+ )
+
+ # If a task is pending, the other one threw an exception, so tidy up
+ for task in pending:
+ task.cancel()
+
+ # Surface only unexpected failures; the listed exception types are the
+ # normal ways a session ends and are deliberately ignored.
+ for task in done:
+ exc = task.exception()
+ if exc and not isinstance(
+ exc,
+ (
+ EndOfTranscriptException,
+ ForceEndSession,
+ ConversationEndedException,
+ ),
+ ):
+ raise exc
+
+ async def run(
+ self,
+ interactions: List[Interaction],
+ audio_settings: AudioSettings = AudioSettings(),
+ conversation_config: ConversationConfig = None,
+ from_cli: bool = False,
+ ):
+ """
+ Begin a new recognition session.
+ This will run asynchronously. Most callers may prefer to use
+ :py:meth:`run_synchronously` which will block until the session is
+ finished.
+
+ :param interactions: A list of interactions with FlowService API.
+ :type interactions: List[Interaction]
+
+ :param audio_settings: Configuration for the audio stream.
+ :type audio_settings: models.AudioSettings
+
+ :param conversation_config: Configuration for the conversation.
+ :type conversation_config: models.ConversationConfig
+
+ :param from_cli: Whether the client was started from the command
+ line; enables audio playback of server messages.
+ :type from_cli: bool
+
+ :raises Exception: Can raise any exception returned by the
+ consumer/producer tasks.
+ """
+ # NOTE(review): the default ``AudioSettings()`` instance is created once
+ # at definition time and shared across calls — confirm it is never
+ # mutated by callers.
+ # Reset per-session state before connecting.
+ self.seq_no = 0
+ self._language_pack_info = None
+ self.conversation_config = conversation_config
+ self.audio_settings = audio_settings
+
+ await self._init_synchronization_primitives()
+
+ # NOTE(review): connection_settings.generate_temp_token is never
+ # consulted here — a temporary token is always fetched; confirm
+ # this is intended.
+ extra_headers = {}
+ auth_token = await get_temp_token(self.connection_settings.auth_token)
+ extra_headers["Authorization"] = f"Bearer {auth_token}"
+ try:
+ async with websockets.connect( # pylint: disable=no-member
+ self.connection_settings.url,
+ ssl=self.connection_settings.ssl_context,
+ ping_timeout=self.connection_settings.ping_timeout_seconds,
+ # Don't limit the max. size of incoming messages
+ max_size=None,
+ extra_headers=extra_headers,
+ ) as self.websocket:
+ await self._communicate(interactions, from_cli)
+ finally:
+ # Always clear session state, even if _communicate raised.
+ self.session_running = False
+ self._session_needs_closing = False
+ self.websocket = None
+
def stop(self):
    """
    Request that a running recognition session be shut down.

    Only meaningful alongside :py:meth:`run`; callers using
    :py:meth:`run_synchronously` normally never need this.
    """
    self._session_needs_closing = True
+
def run_synchronously(self, *args, timeout=None, **kwargs):
    """
    Run the transcription synchronously, blocking until it completes.

    :raises asyncio.TimeoutError: If the given timeout is exceeded.
    """
    # pylint: disable=no-value-for-parameter
    session = self.run(*args, **kwargs)
    asyncio.run(asyncio.wait_for(session, timeout=timeout))
+
+
async def get_temp_token(api_key):
    """
    Exchange a long-lived API key for a short-lived (300 s) Flow token
    via the management-platform API.

    :param api_key: Customer API key used as the Bearer credential.
    :return: The temporary key value issued by the API.
    :raises httpx.HTTPStatusError: If the API responds with an error status.
    """
    mp_api_url = os.getenv("SM_MANAGEMENT_PLATFORM_URL", "https://mp.speechmatics.com")
    endpoint = f"{mp_api_url}/v1/api_keys?type=flow"
    body = {"ttl": 300, "client_ref": "speechmatics-flow-python-client"}
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Fix: this coroutine previously made a blocking httpx.post() call,
    # stalling the event loop; use the async client instead.
    async with httpx.AsyncClient() as client:
        response = await client.post(endpoint, json=body, headers=headers)
    response.raise_for_status()
    data = response.json()
    return data["key_value"]
diff --git a/speechmatics_flow/exceptions.py b/speechmatics_flow/exceptions.py
new file mode 100644
index 0000000..d4d9e37
--- /dev/null
+++ b/speechmatics_flow/exceptions.py
@@ -0,0 +1,29 @@
+# (c) 2024, Cantab Research Ltd.
+"""
+Exceptions and errors used by the library.
+"""
+
+
class TranscriptionError(Exception):
    """
    Raised when an error occurs during transcription.
    """
+
+
class EndOfTranscriptException(Exception):
    """
    Signals that the transcription session has finished.
    """
+
+
class ForceEndSession(Exception):
    """
    Raised by user code (a middleware or event handler) to force the
    transcription session to end early.
    """
+
+
class ConversationEndedException(Exception):
    """
    Signals that the conversation session has ended.
    """
diff --git a/speechmatics_flow/models.py b/speechmatics_flow/models.py
new file mode 100644
index 0000000..ac2a629
--- /dev/null
+++ b/speechmatics_flow/models.py
@@ -0,0 +1,143 @@
+# (c) 2024, Cantab Research Ltd.
+"""
+Data models and message types used by the library.
+"""
+
+import io
+import ssl
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from typing import Callable, Dict, Optional, Literal
+
+
@dataclass
class AudioSettings:
    """Audio stream parameters."""

    # Raw-audio encoding; allowed values are "pcm_f32le" and "pcm_s16le".
    encoding: str = "pcm_s16le"

    # Sampling rate in hertz.
    sample_rate: int = 16000

    # Chunk size in bytes.
    chunk_size: int = 256

    def asdict(self):
        """Serializable form used in the StartConversation message."""
        return {
            "type": "raw",
            "encoding": self.encoding,
            "sample_rate": self.sample_rate,
        }
+
+
@dataclass
class ConnectionSettings:
    """Parameters controlling the websocket connection."""

    # Websocket server endpoint.
    url: str

    # Message buffer size in bytes.
    message_buffer_size: int = 512

    # SSL context used for secure connections.
    ssl_context: ssl.SSLContext = field(default_factory=ssl.create_default_context)

    # Semaphore timeout in seconds.
    semaphore_timeout_seconds: float = 120

    # Ping-pong timeout in seconds.
    ping_timeout_seconds: float = 60

    # Auth token used to authenticate the customer.
    auth_token: Optional[str] = None

    # Automatically generate a temporary token for authentication.
    generate_temp_token: Optional[bool] = True
+
+
@dataclass
class ConversationConfig:
    """Configuration parameters for conversation requests."""

    # Name of a predefined assistant template.
    template_id: Literal[
        "default", "flow-service-assistant-amelia", "flow-service-assistant-humphrey"
    ] = "default"

    # Optional overrides for variables defined in the template.
    template_variables: Optional[Dict[str, str]] = None

    def asdict(self):
        """Dict representation with None-valued fields omitted."""

        def drop_none(pairs):
            return {key: value for key, value in pairs if value is not None}

        return asdict(self, dict_factory=drop_none)
+
+
class ClientMessageType(str, Enum):
    # pylint: disable=invalid-name
    """Message types sent from the client to the server."""

    StartConversation = "StartConversation"
    """Starts a conversation job using the previously sent configuration."""

    AddAudio = "AddAudio"
    """Sends more audio data for recognition; the server confirms receipt
    with an :py:attr:`ServerMessageType.AudioAdded` message."""

    AudioEnded = "AudioEnded"
    """Signals that no more audio will be sent."""
+
+
class ServerMessageType(str, Enum):
    # pylint: disable=invalid-name
    """Defines various message types sent from server to client."""

    ConversationStarted = "ConversationStarted"
    """Server response to :py:attr:`ClientMessageType.StartConversation`,
    acknowledging that a conversation session has started."""

    AddPartialTranscript = "AddPartialTranscript"
    """Indicates a partial transcript, which is an incomplete transcript that
    is immediately produced and may change as more context becomes available.
    """

    AudioAdded = "AudioAdded"
    """Server response to :py:attr:`ClientMessageType.AddAudio`, indicating
    that audio has been added successfully."""

    AddTranscript = "AddTranscript"
    """Indicates the final transcript of a part of the audio."""

    audio = "audio"
    """Message contains binary data."""

    prompt = "prompt"
    """Message contains text data."""

    ConversationEnded = "ConversationEnded"
    """Message indicates the session ended."""

    EndOfTranscript = "EndOfTranscript"
    """Server response to :py:attr:`ClientMessageType.AudioEnded`,
    after the server has finished sending all :py:attr:`AddTranscript`
    messages. (Docstring fix: the previous cross-reference pointed at a
    nonexistent ``ClientMessageType.EndOfStream``.)"""

    Info = "Info"
    """Indicates a generic info message."""

    Warning = "Warning"
    """Indicates a generic warning message."""

    Error = "Error"
    """Indicates a generic error message."""
+
+
@dataclass
class Interaction:
    """A single client/server audio interaction."""

    # Audio source whose contents are streamed to the server.
    stream: io.BufferedReader

    # Optional hook invoked (with the client) once this interaction's
    # audio has been fully sent.
    callback: Optional[Callable] = None
diff --git a/speechmatics_flow/utils.py b/speechmatics_flow/utils.py
new file mode 100644
index 0000000..cf78315
--- /dev/null
+++ b/speechmatics_flow/utils.py
@@ -0,0 +1,52 @@
+# (c) 2024, Cantab Research Ltd.
+"""
+Helper functions used by the library.
+"""
+
+import asyncio
+import concurrent.futures
+import inspect
+import json
+
+
def json_utf8(func):
    """
    Decorator that JSON-encodes the wrapped function's return value.

    :param func: Function whose return value should be serialized.
    :return: Wrapper returning ``json.dumps(func(...))``.
    """
    # Improvement: preserve the wrapped function's metadata (__name__,
    # __doc__) instead of exposing the anonymous wrapper's.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Call ``func`` and JSON-encode its result."""
        return json.dumps(func(*args, **kwargs))

    return wrapper
+
+
async def read_in_chunks(stream, chunk_size):
    """
    Utility method for reading in and yielding chunks.

    Works with both async and synchronous file-like readers; synchronous
    reads are dispatched to the event loop's default executor so they
    don't block other tasks.

    :param stream: file-like object to read audio from
    :type stream: io.IOBase

    :param chunk_size: maximum chunk size in bytes
    :type chunk_size: int

    :return: a sequence of chunks of data, each at most ``chunk_size``
        bytes long; iteration stops at the first empty read (EOF)
    :rtype: collections.AsyncIterable
    """
    # Fix: the previous implementation created (and tore down) a brand-new
    # ThreadPoolExecutor for every chunk; reuse the loop's default
    # executor instead. Also dropped the docstring's ValueError claim —
    # no ValueError was ever raised.
    loop = asyncio.get_event_loop()
    is_async_reader = inspect.iscoroutinefunction(stream.read)
    while True:
        if is_async_reader:
            audio_chunk = await stream.read(chunk_size)
        else:
            # Run the blocking read() in the default executor thread pool.
            audio_chunk = await loop.run_in_executor(None, stream.read, chunk_size)

        if not audio_chunk:
            break
        yield audio_chunk
diff --git a/tests/data/example.wav b/tests/data/example.wav
new file mode 100644
index 0000000..0956fab
--- /dev/null
+++ b/tests/data/example.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd31c48cac8f24d3c18469438f64dc555c911eb2a0d915d3c0b477744d83e1aa
+size 1533868
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..e11060d
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,41 @@
+import pytest
+
+from speechmatics_flow import models
+
+# Shared template_variables fixture for the ConversationConfig tests below.
+TEMPLATE_VARS = {
+ "persona": "You are an aging English butler named Humphrey.",
+ "style": "Be charming but unpredictable.",
+ "context": "You are taking a customer's order for fast food.",
+}
+
+
@pytest.mark.parametrize(
    "kwargs, expected",
    [
        ({}, {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}),
        (
            {"encoding": "pcm_f32le", "sample_rate": 44100},
            {"type": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
        ),
    ],
)
def test_audio_settings(kwargs, expected):
    # asdict() must always report type "raw" alongside the chosen settings.
    assert models.AudioSettings(**kwargs).asdict() == expected
+
+
@pytest.mark.parametrize(
    "kwargs, expected",
    [
        ({}, {"template_id": "default"}),
        (
            {"template_id": "test", "template_variables": TEMPLATE_VARS},
            {"template_id": "test", "template_variables": TEMPLATE_VARS},
        ),
    ],
)
def test_conversation_config(kwargs, expected):
    # None-valued fields must be omitted from the serialized config.
    assert models.ConversationConfig(**kwargs).asdict() == expected