From 50ee5f9918fa1dc3f2b6a896eda149ffdb5e8a7d Mon Sep 17 00:00:00 2001
From: Anartz Nuin
Date: Wed, 11 Oct 2023 15:02:12 +0100
Subject: [PATCH] first commit

---
 .github/workflows/lint.yml      |  21 ++
 .github/workflows/publish.yml   |  27 ++
 .github/workflows/unittests.yml |  21 ++
 .gitignore                      |  33 +++
 Dockerfile                      |  36 +++
 Makefile                        |  40 +++
 README.md                       |  48 ++++
 profiling/README.md             |  48 ++++
 profiling/client-load.js        |  34 +++
 pytest.ini                      |   4 +
 requirements-dev.txt            |   6 +
 requirements.txt                |   6 +
 setup.cfg                       |   2 +
 stream_transcriber/__init__.py  |   0
 stream_transcriber/server.py    | 196 +++++++++++++++
 stream_transcriber/streams.py   | 432 ++++++++++++++++++++++++++++++++
 unittests/__init__.py           |   0
 unittests/conftest.py           |  54 ++++
 unittests/test_server.py        |  99 ++++++++
 unittests/test_streams.py       | 195 ++++++++++++++
 20 files changed, 1302 insertions(+)
 create mode 100644 .github/workflows/lint.yml
 create mode 100644 .github/workflows/publish.yml
 create mode 100644 .github/workflows/unittests.yml
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 Makefile
 create mode 100644 README.md
 create mode 100644 profiling/README.md
 create mode 100644 profiling/client-load.js
 create mode 100644 pytest.ini
 create mode 100644 requirements-dev.txt
 create mode 100644 requirements.txt
 create mode 100644 setup.cfg
 create mode 100644 stream_transcriber/__init__.py
 create mode 100644 stream_transcriber/server.py
 create mode 100644 stream_transcriber/streams.py
 create mode 100644 unittests/__init__.py
 create mode 100644 unittests/conftest.py
 create mode 100644 unittests/test_server.py
 create mode 100644 unittests/test_streams.py

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..134e256
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,21 @@
+name: Lint
+on:
+  push:
+    branches: "*"
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Run lint tools
+        run: |
+          make lint-local
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..aba5b17
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,27 @@
+name: Build and Publish Docker Image
+on:
+  release:
+    types: [released]
+jobs:
+  build-and-publish:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - name: Build Docker image
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: ./Dockerfile
+          target: production
+          push: false
+          tags: |
+            "${{ vars.ACR_ADDRESS }}/${{ vars.IMAGE_NAME }}:${{ github.ref_name }}"
+      - name: Login to Azure ACR
+        uses: azure/docker-login@v1
+        with:
+          login-server: ${{ vars.ACR_ADDRESS }}
+          username: ${{ secrets.ACR_USERNAME }}
+          password: ${{ secrets.ACR_PASSWORD }}
+      - name: Push Docker image to Azure ACR
+        run: docker push ${{ vars.ACR_ADDRESS }}/${{ vars.IMAGE_NAME }}:${{ github.ref_name }}
diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml
new file mode 100644
index 0000000..54fbd37
--- /dev/null
+++ b/.github/workflows/unittests.yml
@@ -0,0 +1,21 @@
+name: Unit tests
+on:
+  push:
+    branches: "*"
+jobs:
+  unittests:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Run unit tests
+        run: |
+          make unittest-local
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a5daf72
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,33 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+
+# Virtual environments
+venv/
+env/
+ENV/
+.venv/
+.ENV/
+
+# IDEs and editors
+.idea/
+.vscode/
+*.sublime-project
+*.sublime-workspace
+
+# Logs and databases
+*.log
+*.sqlite3
+*.db
+
+# Other
+*.pyc
+.DS_Store
+.env
+.coverage
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..1077678
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,36 @@
+FROM python:3.11-alpine as builder
+RUN mkdir /install
+WORKDIR /install
+COPY requirements.txt /requirements.txt
+RUN pip install --prefix="/install" -r /requirements.txt
+
+
+FROM python:3.11-alpine as production
+RUN apk upgrade -U && apk add ffmpeg
+COPY --from=builder /install /usr/local
+COPY ./stream_transcriber /stream_transcriber
+WORKDIR /
+EXPOSE 8765
+EXPOSE 8000
+ENTRYPOINT ["python"]
+CMD ["-m", "stream_transcriber.server"]
+
+FROM production as base-dev
+RUN apk add --no-cache make
+COPY ./requirements-dev.txt /requirements-dev.txt
+RUN pip install -r /requirements-dev.txt
+
+FROM base-dev as lint
+WORKDIR /
+COPY ./Makefile /Makefile
+COPY ./setup.cfg /setup.cfg
+COPY ./stream_transcriber /stream_transcriber
+ENTRYPOINT ["make", "lint-local"]
+
+FROM base-dev as unittest
+WORKDIR /
+COPY ./stream_transcriber /stream_transcriber
+COPY ./unittests /unittests
+COPY ./Makefile /Makefile
+COPY ./pytest.ini /pytest.ini
+ENTRYPOINT [ "make", "unittest-local" ]
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..4d05213
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,40 @@
+.DEFAULT_GOAL := all
+IMAGE_NAME ?= stream-demo-server
+TAG ?= manual
+DOCKER := DOCKER_BUILDKIT=1 docker
+SOURCES := stream_transcriber/
+ACR_ADDRESS := speechmatics.azurecr.io
+ACR_IMAGE_NAME := ${ACR_ADDRESS}/${IMAGE_NAME}
+
+.PHONY: all lint build publish format build-linux-amd64 lint-local unittest unittest-local
+
+all: lint build
+
+lint:
+	${DOCKER} build -t ${IMAGE_NAME}:${TAG}-lint --target lint .
+	${DOCKER} run --rm --name ${IMAGE_NAME}-lint ${IMAGE_NAME}:${TAG}-lint
+lint-local:
+	black --check --diff ${SOURCES}
+	pylint ${SOURCES}
+	pycodestyle ${SOURCES}
+
+format:
+	black ${SOURCES}
+
+unittest:
+	${DOCKER} build -t ${IMAGE_NAME}:${TAG}-unittest --target unittest .
+	${DOCKER} run --rm --name ${IMAGE_NAME}-unittest ${IMAGE_NAME}:${TAG}-unittest
+unittest-local:
+	AUTH_TOKEN=token pytest -v unittests
+
+build:
+	${DOCKER} build -t ${IMAGE_NAME}:${TAG} --target production .
+
+# Build locally an image for linux/amd64
+build-linux-amd64:
+	${DOCKER} build --platform linux/amd64 -t ${IMAGE_NAME}:${TAG} --target production .
+
+publish:
+	docker tag ${IMAGE_NAME}:${TAG} ${ACR_IMAGE_NAME}:${TAG}
+	docker image inspect ${ACR_IMAGE_NAME}:${TAG}
+	docker push ${ACR_IMAGE_NAME}:${TAG}
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..22fe28f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,48 @@
+# Stream Radio Server
+
+A Python WebSocket server that transcribes and translates multiple radio streams and lets clients subscribe to the results.
+
+## Getting Started
+
+Install the required dependencies with:
+
+```
+brew install ffmpeg
+pip3 install -r requirements.txt
+```
+
+## Running
+
+Start the server with:
+
+```bash
+python3 -m stream_transcriber.server --port 8765
+```
+
+Connect a client to e.g. `ws://localhost:8765`. With https://github.com/vi/websocat this can be done with:
+```bash
+websocat ws://127.0.0.1:8765
+```
+> {"message": "Initialised", "info": "Waiting for message specifying desired stream url"}
+
+The server expects an initial JSON message naming the desired stream, which starts the streaming:
+```json
+{"name": "english"}
+```
+
+The client then receives audio chunks and JSON messages until the stream ends or the client disconnects.
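+
+For quick experiments, a minimal Python subscriber might look like the following sketch. It uses the `websockets` package from `requirements.txt` and assumes the default stream name and port shown above:
+
+```python
+import asyncio
+import json
+
+import websockets
+
+
+async def main():
+    async with websockets.connect("ws://127.0.0.1:8765") as websocket:
+        # The server greets with the "Initialised" message first
+        print(await websocket.recv())
+        await websocket.send(json.dumps({"name": "english"}))
+        while True:
+            message = await websocket.recv()
+            if isinstance(message, bytes):
+                continue  # binary frames carry raw audio; skip them here
+            print(json.loads(message))
+
+
+asyncio.run(main())
+```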
+
+## Running tests
+
+Run the following command:
+
+```bash
+make unittest
+```
+
+This runs the tests in a Docker container with the intended Python version and all dependencies installed. To run the tests directly on your machine, run:
+
+```bash
+make unittest-local
+```
\ No newline at end of file
diff --git a/profiling/README.md b/profiling/README.md
new file mode 100644
index 0000000..ff2a0db
--- /dev/null
+++ b/profiling/README.md
@@ -0,0 +1,48 @@
+# Profiling the server under load
+
+## Dependencies
+
+In addition to the dependencies needed to run the server, you will need the following:
+
+
+- CLI tools:
+  - k6
+  - ffmpeg
+- Python packages:
+  - memory_profiler
+  - matplotlib
+
+## Run profiling
+
+We can collect some statistics while the server is under load:
+
+1. Start the server with mprof to record the evolution of memory consumption over time. It also tracks the memory of child processes (ffmpeg):
+
+```bash
+SM_MANAGEMENT_PLATFORM_URL='' AUTH_TOKEN='' mprof run --multiprocess python3 -m stream_transcriber.server --port 8765
+```
+
+2. A simple way to keep an eye on CPU usage while the server is running, in another terminal:
+
+```bash
+# 1. Find the pid of the server
+ps | grep server.py
+
+# 2. Watch snapshots every 1s
+watch -n 1 'ps -p <pid> -o %cpu,%mem,cmd'
+```
+
+3. Generate some load using [k6](https://k6.io):
+
+```bash
+k6 run profiling/client-load.js
+```
+NOTE: with very high numbers of clients you might hit the maximum number of open file descriptors; how to raise it depends on your OS. On macOS the current limit can be retrieved with `ulimit -n` and changed with `ulimit -n <number>`.
+
+4. The CPU and memory snapshots appear every second in the separate terminal.
+
+5. To visualise the graph of memory consumption over time, press Ctrl+C in the terminal running the server to stop it. Now use:
+
+```bash
+mprof plot
+```
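+
+If you prefer to save the graph to a file instead of opening a window, `mprof plot` accepts an output path (a sketch, assuming a recent memory_profiler; the filename is a placeholder):
+
+```bash
+mprof plot --output memory-profile.png
+```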
diff --git a/profiling/client-load.js b/profiling/client-load.js
new file mode 100644
index 0000000..77071aa
--- /dev/null
+++ b/profiling/client-load.js
@@ -0,0 +1,34 @@
+import ws from 'k6/ws';
+import { check } from 'k6';
+
+export const options = {
+    discardResponseBodies: true,
+    scenarios: {
+        users: {
+            executor: "ramping-vus",
+            startVUs: 1,
+            stages: [
+                { duration: '1m', target: 1 },
+                { duration: '2m', target: 200 },
+                { duration: '5m', target: 200 },
+                { duration: '2m', target: 1 },
+                { duration: '1m', target: 1 },
+            ],
+        },
+    },
+};
+
+export default function () {
+    const url = 'ws://127.0.0.1:8765';
+    const res = ws.connect(url, function (socket) {
+        socket.on('open', function open() {
+            console.log('connected')
+            const streams = ["english", "german", "french", "spanish"];
+            const random = Math.floor(Math.random() * streams.length);
+            socket.send(`{"name": "${streams[random]}"}`)
+        });
+        // socket.on('message', (data) => console.log('Message received: ', data));
+        socket.on('close', () => console.log('disconnected'));
+    });
+    check(res, { 'status is 101': (r) => r && r.status === 101 });
+}
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..a1471d0
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+addopts = -ra --full-trace --cov=stream_transcriber --cov-branch -o asyncio_mode=auto
+pythonpath = stream_transcriber
+testpaths = unittests
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..a73a194
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,6 @@
+pycodestyle==2.11.0
+pylint==3.0.1
+black==23.9.1
+pytest==7.4.2
+pytest-asyncio==0.21.1
+pytest-cov==4.1.0
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3b58c62
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+websockets~=11.0.3
+httpx[http2]~=0.23
+polling2~=0.5
+toml~=0.10.2
+prometheus-client~=0.16.0
+speechmatics-python~=1.9.0
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..15fbabe
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[pycodestyle]
+max-line-length = 120
\ No newline at end of file
diff --git a/stream_transcriber/__init__.py b/stream_transcriber/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/stream_transcriber/server.py b/stream_transcriber/server.py
new file mode 100644
index 0000000..4fb41af
--- /dev/null
+++ b/stream_transcriber/server.py
@@ -0,0 +1,196 @@
+"""
+Main server entry point
+"""
+
+import asyncio
+import datetime
+import logging
+import json
+import os
+import signal
+from argparse import ArgumentParser
+import traceback
+from prometheus_client import start_http_server, Gauge
+
+import websockets
+from websockets.exceptions import ConnectionClosed
+
+from stream_transcriber.streams import (
+    SUPPORTED_STREAMS,
+    STREAMS,
+    StreamState,
+    load_stream,
+    streams_gauge,
+)
+
+LOGGER = logging.getLogger("server")
+SM_STREAM_TIMEOUT = os.getenv("SM_STREAM_TIMEOUT", "10")
+clients_gauge = Gauge(
+    "connected_clients", "Amount of clients connected to the ws server"
+)
+
+
+async def close_stream_with_delay(key):
+    """
+    Close the transcriber stream when the number of connections drops to 0
+    """
+    try:
+        stream_state = STREAMS[key]
+    except KeyError:
+        LOGGER.warning("Stream %s not found in the streams dictionary", key)
+        return
+
+    if len(stream_state.connections) > 0:
+        return
+
+    # Grace period: give clients a chance to reconnect before tearing down
+    await asyncio.sleep(int(SM_STREAM_TIMEOUT))
+    if len(stream_state.connections) > 0:
+        return
+
+    try:
+        LOGGER.info("No connections left. Closing transcription", extra={"stream": key})
+        stream_state.internal_task.cancel()
+        await stream_state.internal_task
+    except asyncio.CancelledError:
+        pass  # expected: the internal task was cancelled just above
+    finally:
+        LOGGER.info("Closed stream %s as no more clients are connected", key)
+
+
+async def ws_handler(websocket):
+    """
+    Websocket handler - receives client connections and attaches them to the correct audio stream
+    """
+    try:
+        stream_name = None
+        LOGGER.info("Client connected")
+        await websocket.send(
+            json.dumps(
+                {
+                    "message": "Initialised",
+                    "info": "Waiting for message specifying desired stream url",
+                }
+            )
+        )
+        stream_data = json.loads(await websocket.recv())
+        LOGGER.info("Received stream connection data %s", stream_data)
+        if "name" not in stream_data:
+            raise ValueError("Stream name not specified")
+
+        stream_name = stream_data["name"]
+        if stream_name not in SUPPORTED_STREAMS:
+            raise ValueError(f"stream {stream_name} is not supported")
+
+        if stream_name in STREAMS:
+            LOGGER.info("already started", extra={"stream": stream_name})
+            STREAMS[stream_name].connections.append(websocket)
+            for old_message in STREAMS[stream_name].previous_messages:
+                await websocket.send(json.dumps(old_message))
+        else:
+            LOGGER.info(
+                "Creating a new Transcription session", extra={"stream": stream_name}
+            )
+            STREAMS[stream_name] = StreamState(
+                internal_task=asyncio.create_task(load_stream(stream_name)),
+                connections=[websocket],
+            )
+            streams_gauge.labels(stream_name).inc()
+
+        with clients_gauge.track_inprogress():
+            await websocket.wait_closed()
+    except json.JSONDecodeError as error:
+        LOGGER.warning(
+            "Error decoding incoming JSON message with stream name: %s", error
+        )
+    except ValueError as error:
+        LOGGER.warning(
+            "Non recognized stream in incoming select stream message: %s", error
+        )
+    except (
+        asyncio.InvalidStateError,
+        asyncio.CancelledError,
+        asyncio.IncompleteReadError,
+        ConnectionClosed,
+    ) as error:
+        LOGGER.error("Error in websocket connection handler: %s", error)
+    except Exception:  # pylint: disable=broad-except
+        LOGGER.error(
+            "Unexpected exception in websocket connection handler:\n %s",
+            traceback.format_exc(),
+        )
+    finally:
+        LOGGER.info("Connection closed, cleaning up")
+        if stream_name and stream_name in STREAMS:
+            stream_state = STREAMS[stream_name]
+            stream_state.connections.remove(websocket)
+            await close_stream_with_delay(stream_name)
+
+
+async def main(port):
+    """
+    Main entry point for the websocket server
+    """
+    LOGGER.info("Starting WebSocket Server")
+    # Set the stop condition when receiving SIGTERM.
+    loop = asyncio.get_running_loop()
+    stop = loop.create_future()
+    loop.add_signal_handler(signal.SIGTERM, stop.set_result, None)
+    # pylint: disable=locally-disabled, no-member
+    async with websockets.serve(ws_handler, "0.0.0.0", port):
+        await stop  # Wait for SIGTERM
+
+
+class ExtraFormatter(logging.Formatter):
+    """
+    Formatter that also emits extra log-record context, like the stream name
+    """
+
+    def format(self, record: logging.LogRecord) -> str:
+        # Any attribute not present on a default LogRecord was passed via extra={...}
+        default_attrs = logging.LogRecord(
+            None, None, None, None, None, None, None
+        ).__dict__.keys()
+        extras = set(record.__dict__.keys()) - default_attrs
+        log_items = [
+            (
+                '"msg": "%(message)s", "time": "%(asctime)s",'
+                '"level": "%(levelname)s", "source": "%(name)s"'
+            )
+        ]
+        for attr in extras:
+            log_items.append(f'"{attr}": "%({attr})s"')
+        format_str = f'{{{", ".join(log_items)}}}'
+        # pylint: disable=locally-disabled, protected-access
+        self._style._fmt = format_str
+        record.levelname = record.levelname.lower()
+        record.msg = record.msg.replace('"', r"\"")
+        return super().format(record)
+
+    def formatTime(self, record, datefmt=None):
+        return datetime.datetime.fromtimestamp(record.created).isoformat()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--port", default=8765, type=int, help="Port for the Websocket server"
+    )
+    args = parser.parse_args()
+    FORMAT_STRING = (
+        '{"msg": "%(message)s","time": "%(asctime)s", '
+        '"source": "%(name)s", "level": "%(levelname)s"}'
+    )
+    logging.basicConfig(
+        level=os.getenv("LOG_LEVEL", "debug").upper(), format=FORMAT_STRING
+    )
+    logging.Formatter.formatTime = (
+        lambda self, record, datefmt=None: datetime.datetime.fromtimestamp(
+            record.created
+        ).isoformat()
+    )
+    handler = logging.StreamHandler()
+    formatter = ExtraFormatter()
+    handler.setFormatter(formatter)
+    LOGGER.addHandler(handler)
+    LOGGER.propagate = False  # the root handler would otherwise log every record twice
+    LOGGER.setLevel(os.getenv("LOG_LEVEL", "debug").upper())
+
+    # Start server for exposing metrics
+    start_http_server(8000)
+
+    asyncio.run(main(args.port))
diff --git a/stream_transcriber/streams.py b/stream_transcriber/streams.py
new file mode 100644
index 0000000..1c2a5b3
--- /dev/null
+++ b/stream_transcriber/streams.py
@@ -0,0 +1,432 @@
+"""
+Functions, variables and classes for handling the transcription streams
+"""
+import asyncio
+import logging
+import subprocess
+import json
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List
+from collections import deque
+import os
+from functools import partial
+
+from prometheus_client import Gauge
+import websockets
+
+from speechmatics.models import (
+    ConnectionSettings,
+    TranscriptionConfig,
+    AudioSettings,
+    ServerMessageType,
+    RTTranslationConfig,
+)
+from speechmatics.client import WebsocketClient
+
+streams_gauge = Gauge(
+    "open_streams", "Amount of opened transcription/translation streams", ["language"]
+)
+
+
+LOGGER = logging.getLogger("server")
+AUTH_TOKEN = os.environ["AUTH_TOKEN"]
+
+CONNECTION_URL = os.getenv("SM_RT_RUNTIME_URL", "wss://neu.rt.speechmatics.com/v2")
+
+ENABLE_TRANSCRIPTION_PARTIALS = os.getenv(
+    "SM_ENABLE_TRANSCRIPTION_PARTIALS", "False"
+).lower() in ("true", "1", "t")
+ENABLE_TRANSLATION_PARTIALS = os.getenv(
+    "SM_ENABLE_TRANSLATION_PARTIALS", "False"
+).lower() in ("true", "1", "t")
+
+MAX_DELAY = int(os.getenv("SM_MAX_DELAY", "3"))
+
+FRAME_RATE = 16000
+FFMPEG_OUTPUT_FORMAT = "f32le"
+ENCODING = f"pcm_{FFMPEG_OUTPUT_FORMAT}"
+settings = AudioSettings(encoding=ENCODING, sample_rate=FRAME_RATE)
+
+
+@dataclass
+class SupportedStream:
+    """
+    SupportedStream holds all the metadata about a supported stream:
+    its endpoint URL, source language and available translation languages
+    """
+
+    url: str
+    language: str
+    translation_languages: list[str] = field(default_factory=list)
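+
+
+# Adding a new source is a matter of registering it below. A hypothetical
+# entry might look like this (a sketch only; the URL and languages are
+# placeholders, not part of the supported set):
+#   "italian": SupportedStream(
+#       url="https://example.com/radio.mp3",
+#       language="it",
+#       translation_languages=["en"],
+#   ),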
+SUPPORTED_STREAMS: Dict[str, SupportedStream] = {
+    "english": SupportedStream(
+        url="https://stream.live.vc.bbcmedia.co.uk/bbc_world_service",
+        # pylint: disable=locally-disabled, line-too-long
+        # url="https://a.files.bbci.co.uk/media/live/manifesto/audio/simulcast/hls/uk/high/cfs/bbc_world_service.m3u8",
+        language="en",
+        translation_languages=["fr", "de", "ja", "ko", "es"],
+    ),
+    "german": SupportedStream(
+        # pylint: disable=locally-disabled, line-too-long
+        url="https://icecast.ndr.de/ndr/ndrinfo/niedersachsen/mp3/128/stream.mp3?1680704013057&aggregator=web",
+        language="de",
+        translation_languages=["en"],
+    ),
+    "french": SupportedStream(
+        url="https://icecast.radiofrance.fr/franceinter-midfi.mp3",
+        language="fr",
+        translation_languages=["en"],
+    ),
+    "spanish": SupportedStream(
+        url="https://22333.live.streamtheworld.com/CADENASERAAC_SC",
+        language="es",
+        translation_languages=["en"],
+    ),
+    "korean": SupportedStream(
+        url="http://fmt01.egihosting.com:9468/",
+        language="ko",
+        translation_languages=["en"],
+    ),
+}
+
+
+@dataclass
+class StreamState:
+    """
+    StreamState is used to keep track of clients connected to the stream.
+    Also holds the current transcription task to allow cancellation
+    """
+
+    internal_task: asyncio.Task
+    connections: List = field(default_factory=list)
+    previous_messages: deque = field(default_factory=partial(deque, maxlen=20))
+
+
+# Keyed by stream name, e.g. "english"
+STREAMS: Dict[str, StreamState] = {}
+
+
+# pylint: disable=locally-disabled, too-many-locals, too-many-statements
+async def load_stream(stream_name: str):
+    """
+    Initialise a supported stream
+
+    Sets ffmpeg to read from the stream URL and opens a transcriber session to transcribe the stream
+    """
+    LOGGER.info("loading stream", extra={"stream": stream_name})
+    stream_meta = SUPPORTED_STREAMS[stream_name]
+    stream_url = stream_meta.url
+    language = stream_meta.language
+
+    conf = TranscriptionConfig(
+        language=language,
+        operating_point="enhanced",
+        max_delay=MAX_DELAY,
+        enable_partials=ENABLE_TRANSCRIPTION_PARTIALS,
+        translation_config=RTTranslationConfig(
+            target_languages=stream_meta.translation_languages,
+            enable_partials=ENABLE_TRANSLATION_PARTIALS,
+        ),
+    )
+
+    stream_with_arg = (
+        f"{stream_url}?rcvbuf=15000000"
+        if stream_url.startswith("srt://")
+        else stream_url
+    )
+    ffmpeg_args = [
+        *("-re",),
+        # *("-v", "48"),
+        *("-i", stream_with_arg),
+        *("-f", FFMPEG_OUTPUT_FORMAT),
+        *("-ar", FRAME_RATE),
+        *("-ac", 1),
+        *("-acodec", ENCODING),
+        "-",
+    ]
+    LOGGER.info(
+        "Running ffmpeg with args: %s", ffmpeg_args, extra={"stream": stream_name}
+    )
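+
+    # For reference, the spawned process is equivalent to running (a sketch,
+    # with the stream URL substituted for the placeholder):
+    #   ffmpeg -re -i <stream_url> -f f32le -ar 16000 -ac 1 -acodec pcm_f32le -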
+    process = await asyncio.create_subprocess_exec(
+        "ffmpeg",
+        *map(str, ffmpeg_args),
+        limit=1024 * 1024,  # 1 MiB reduces errors with long outputs (default is 64 KiB)
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    LOGGER.info("ffmpeg started", extra={"stream": stream_name})
+
+    url = f"{CONNECTION_URL}/{language}?sm-app=radio-stream-translation-demo"
+
+    LOGGER.info("Starting SM websocket client", extra={"url": url})
+
+    start_time = time.time()
+    sm_client = WebsocketClient(
+        ConnectionSettings(
+            url=url,
+            auth_token=AUTH_TOKEN,
+            generate_temp_token=True,
+        )
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.AddTranscript,
+        event_handler=partial(
+            send_transcript,
+            stream_name=stream_name,
+            start_time=start_time,
+        ),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.AddPartialTranscript,
+        event_handler=partial(
+            send_transcript,
+            stream_name=stream_name,
+            start_time=start_time,
+        ),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.AddTranslation,
+        event_handler=partial(
+            send_translation,
+            stream_name=stream_name,
+            start_time=start_time,
+        ),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.AddPartialTranslation,
+        event_handler=partial(
+            send_translation,
+            stream_name=stream_name,
+            start_time=start_time,
+        ),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.EndOfTranscript,
+        event_handler=partial(finish_session, stream_name=stream_name),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.RecognitionStarted,
+        event_handler=partial(receive_message, stream_name=stream_name),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.Error,
+        event_handler=partial(receive_message, stream_name=stream_name, level="error"),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.Warning,
+        event_handler=partial(
+            receive_message, stream_name=stream_name, level="warning"
+        ),
+    )
+    sm_client.add_event_handler(
+        event_name=ServerMessageType.Info,
+        event_handler=partial(receive_message, stream_name=stream_name),
+    )
+
+    try:
+        runtime_stream = asyncio.StreamReader()
+        broadcast_stream = asyncio.StreamReader()
+
+        stream_clone_task = asyncio.create_task(
+            stream_tee(
+                process.stdout, runtime_stream, broadcast_stream, settings.chunk_size
+            )
+        )
+
+        LOGGER.info(
+            "Starting transcription",
+            extra={"stream": stream_name},
+        )
+        asr_task = asyncio.create_task(sm_client.run(runtime_stream, conf, settings))
+        send_audio_task = asyncio.create_task(send_audio(broadcast_stream, stream_name))
+        log_task = asyncio.create_task(log_ffmpeg(process))
+
+        done, pending = await asyncio.wait(
+            [log_task, send_audio_task, asr_task, stream_clone_task],
+            return_when=asyncio.FIRST_EXCEPTION,
+        )
+        for done_routine in done:
+            if done_routine.exception() is not None:
+                LOGGER.error(
+                    "Exception in return %s",
+                    done_routine.exception(),
+                    extra={"stream": stream_name},
+                )
+        for pending_routine in pending:
+            # A task that is still pending cannot be inspected for an
+            # exception; it can only be cancelled
+            pending_routine.cancel()
+
+    except asyncio.CancelledError:
+        LOGGER.warning("Task Cancelled", extra={"stream": stream_name})
+    finally:
+        LOGGER.warning("Stream %s exited, cleaning up tasks", stream_name)
+        if stream_name in STREAMS:
+            stream = STREAMS.pop(stream_name)
+            streams_gauge.labels(stream_name).dec()
+            LOGGER.info(
+                "Popped closed transcription stream, closing all connections",
+                extra={"stream": stream_name},
+            )
+            await force_close_connections(stream)
+        sm_client.stop()
+        if process.returncode is None:
+            process.kill()
+        asr_task.cancel()
+        stream_clone_task.cancel()
+        send_audio_task.cancel()
+
+    LOGGER.info("Finished transcription", extra={"stream": stream_name})
+
+
+async def stream_tee(source, target, target_two, chunk_size):
+    """
+    Copy data from a source stream into two target streams.
+    """
+    while True:
+        data = await source.read(chunk_size)
+        if not data:  # EOF
+            break
+        target.feed_data(data)
+        target_two.feed_data(data)
+    # Propagate EOF so readers of the targets also terminate
+    target.feed_eof()
+    target_two.feed_eof()
+
+
+async def send_audio(stream, stream_name):
+    """
+    Broadcast audio to connected clients
+    """
+    while True:
+        data = await stream.read(settings.chunk_size)
+        if not data:
+            break
+        stream_state = STREAMS[stream_name]
+        # pylint: disable=locally-disabled, no-member
+        websockets.broadcast(stream_state.connections, data)
+
+
+def send_transcript(message, stream_name, start_time):
+    """
+    Event handler function to send transcript data to the client
+    """
+    if stream_name not in STREAMS:
+        # no clients to serve
+        LOGGER.warning(
+            "Tried to send transcript message to closed stream",
+            extra={"stream": stream_name},
+        )
+        return
+
+    LOGGER.debug("Received message from transcriber", extra={"stream": stream_name})
+
+    message["current_timestamp"] = time.time()
+    message["metadata"]["session_start_time"] = start_time
+    stream_state = STREAMS[stream_name]
+    LOGGER.debug("Received %s", message, extra={"stream": stream_name})
+    LOGGER.debug(
+        "Broadcasting message to %s clients",
+        len(stream_state.connections),
+        extra={"stream": stream_name},
+    )
+    # pylint: disable=locally-disabled, no-member
+    websockets.broadcast(stream_state.connections, json.dumps(message))
+
+
+def send_translation(message, stream_name, start_time):
+    """
+    Event handler function to send translation data to the client
+    """
+    if stream_name not in STREAMS:
+        # no clients to serve
+        LOGGER.warning(
+            "Tried to send translation message to closed stream",
+            extra={"stream": stream_name},
+        )
+        return
+
+    message["current_timestamp"] = time.time()
+
+    message["metadata"] = {"session_start_time": start_time}
+
+    stream_state = STREAMS[stream_name]
+
+    LOGGER.debug("Received %s", message, extra={"stream": stream_name})
+
+    LOGGER.debug(
+        "Broadcasting message to %s clients",
+        len(stream_state.connections),
+        extra={"stream": stream_name},
+    )
+    # pylint: disable=locally-disabled, no-member
+    websockets.broadcast(stream_state.connections, json.dumps(message))
+
+
+def finish_session(message, stream_name):
+    """
+    Handles finishing the session when end of transcript is reached
+    """
+    LOGGER.info(
+        "Received end of transcript: %s", message, extra={"stream": stream_name}
+    )
+    if stream_name not in STREAMS:
+        # no clients to serve
+        return
+    stream_state = STREAMS[stream_name]
+    loop = asyncio.get_event_loop()
+    for connection in stream_state.connections:
+        loop.create_task(connection.close())
+
+
+async def log_ffmpeg(process):
+    """
+    Log stderr from ffmpeg - ffmpeg writes all its logs to stderr,
+    since stdout carries the decoded audio
+    """
+    while True:
+        line = (await process.stderr.readline()).decode("utf-8").strip()
+        if len(line) == 0:
+            break
+        LOGGER.info(line, extra={"source": "ffmpeg"})
+
+
+def receive_message(message, stream_name, level="debug"):
+    """
+    Receive miscellaneous transcriber messages and forward them to clients
+    """
+    if stream_name not in STREAMS:
+        # no clients to serve
+        return
+
+    stream_state = STREAMS[stream_name]
+
+    if level == "info":
+        LOGGER.info("Received %s", message, extra={"stream": stream_name})
+    elif level == "warning":
+        LOGGER.warning("Received %s", message, extra={"stream": stream_name})
+        stream_state.previous_messages.append(message)
+    elif level == "error":
+        LOGGER.error("Received %s", message, extra={"stream": stream_name})
+        stream_state.previous_messages.append(message)
+    elif level == "debug":
+        LOGGER.debug("Received %s", message, extra={"stream": stream_name})
extra={"stream": stream_name}) + + # pylint: disable=locally-disabled, no-member + websockets.broadcast(stream_state.connections, json.dumps(message)) + + +async def force_close_connections(stream): + """ + Force closes all websocket connections for a particular stream. + Used when the stream dies unexpectedly. + """ + LOGGER.warning("Force closing all connections") + for conn in stream.connections: + await conn.close(1011, "An unexpected error occurred on the server") diff --git a/unittests/__init__.py b/unittests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/unittests/conftest.py b/unittests/conftest.py new file mode 100644 index 0000000..00e09f9 --- /dev/null +++ b/unittests/conftest.py @@ -0,0 +1,54 @@ +from unittest.mock import MagicMock, patch +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def set_sm_stream_timeout(): + with patch("stream_transcriber.server.SM_STREAM_TIMEOUT", 1): + yield + + +@pytest.fixture(scope="function") +def stream(): + stream_state = MagicMock() + + stream_state.internal_task = CustomAsyncMock() + stream_state.internal_task.cancel = MagicMock() + + stream_key = "english" + patcher = patch.dict( + "stream_transcriber.server.STREAMS", {stream_key: stream_state} + ) + patcher.start() + + yield stream_key, stream_state + patcher.stop() + + +@pytest.fixture() +def ws_server_protocol(recv_content, send_content): + mock_websocket = MagicMock() + + async def recv(): + return recv_content + + async def send(content): + return send_content + + mock_websocket.recv = MagicMock(side_effect=recv) + mock_websocket.send = MagicMock(side_effect=send) + + return mock_websocket + + +class CustomAsyncMock(MagicMock): + def __call__(self, *args, **kwargs): + sup = super(CustomAsyncMock, self) + + async def coro(): + return sup.__call__(*args, **kwargs) + + return coro() + + def __await__(self): + return self().__await__() diff --git a/unittests/test_server.py b/unittests/test_server.py new file mode 100644 index 0000000..3ca093f --- /dev/null +++ b/unittests/test_server.py @@ -0,0 +1,99 @@ +import asyncio +import pytest +from unittest.mock import MagicMock, patch + +from stream_transcriber.server import close_stream_with_delay, ws_handler, STREAMS + + +@pytest.mark.asyncio +async def test_closes_stream_when_stream_does_not_exist(): + with patch("stream_transcriber.server.LOGGER") as mock_logger: + key = "non_existing_stream_key" + await close_stream_with_delay(key) + mock_logger.warning.assert_called_with( + "Stream %s not found in the streams dictionary", key + ) + + +@pytest.mark.asyncio +async def test_does_not_close_stream_when_connections_exist(stream): + stream_key, stream_state = stream + stream_state.connections = [MagicMock()] + + with patch("stream_transcriber.server.LOGGER") as mock_logger: + await close_stream_with_delay(stream_key) + stream_state.internal_task.cancel.assert_not_called() + mock_logger.info.assert_not_called() + + +@pytest.mark.asyncio +async def test_does_not_close_stream_when_connections_reappear(stream): + stream_key, stream_state = stream + stream_state.connections = [] + + async def add_connection(): + await asyncio.sleep(0.5) + stream_state.connections.append(MagicMock()) + + with patch("stream_transcriber.server.LOGGER") as mock_logger: + await asyncio.gather(close_stream_with_delay(stream_key), add_connection()) + stream_state.internal_task.cancel.assert_not_called() + mock_logger.info.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_stream_when_no_connections_left(stream): + 
+    stream_key, stream_state = stream
+    stream_state.connections = []
+
+    await close_stream_with_delay(stream_key)
+    stream_state.internal_task.cancel.assert_called()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("recv_content,send_content", [("not-json", "")])
+async def test_ws_handler_non_json_stream_select_message(ws_server_protocol):
+    with patch("stream_transcriber.server.LOGGER") as mock_logger:
+        await ws_handler(ws_server_protocol)
+        assert "Error decoding incoming JSON message with stream name" in str(
+            mock_logger.warning.call_args_list[0]
+        )
+        assert STREAMS == {}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("recv_content,send_content", [('{"foo":"bar"}', "")])
+async def test_ws_handler_bad_format_in_stream_select_message(ws_server_protocol):
+    with patch("stream_transcriber.server.LOGGER") as mock_logger:
+        await ws_handler(ws_server_protocol)
+        assert "Non recognized stream in incoming select stream message" in str(
+            mock_logger.warning.call_args_list[0]
+        )
+        assert STREAMS == {}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("recv_content,send_content", [('{"name":"english"}', "")])
+async def test_ws_handler_first_connection_for_stream(ws_server_protocol):
+    with patch("stream_transcriber.server.load_stream") as mock_load_stream:
+        with patch("stream_transcriber.server.LOGGER") as mock_logger:
+            await ws_handler(ws_server_protocol)
+            assert any(
+                "Creating a new Transcription session" in str(args)
+                for args, _ in mock_logger.info.call_args_list
+            )
+            assert "english" in STREAMS
+            assert STREAMS["english"].connections is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("recv_content,send_content", [('{"name":"english"}', "")])
+async def test_ws_handler_connection_for_existing_stream(ws_server_protocol, stream):
+    with patch("stream_transcriber.server.LOGGER") as mock_logger:
+        await ws_handler(ws_server_protocol)
+        assert any(
+            "already started" in str(args)
+            for args, _ in mock_logger.info.call_args_list
+        )
+        assert "english" in STREAMS
+        assert STREAMS["english"].connections.append.called
diff --git a/unittests/test_streams.py b/unittests/test_streams.py
new file mode 100644
index 0000000..eb6b117
--- /dev/null
+++ b/unittests/test_streams.py
@@ -0,0 +1,195 @@
+import asyncio
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from stream_transcriber.streams import (
+    SUPPORTED_STREAMS,
+    finish_session,
+    force_close_connections,
+    load_stream,
+    receive_message,
+    send_audio,
+    send_transcript,
+    send_translation,
+    stream_tee,
+)
+from unittests.conftest import CustomAsyncMock
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "source_content",
+    [
+        b"test",
+        b"",
+        b"hello world",
+        b"12345",
+    ],
+)
+async def test_stream_tee(source_content):
+    source = asyncio.StreamReader()
+    source.feed_data(source_content)
+    source.feed_eof()
+
+    target_one = asyncio.StreamReader()
+    target_two = asyncio.StreamReader()
+
+    await stream_tee(source, target_one, target_two, 4)
+
+    assert await target_one.read(len(source_content)) == source_content
+    assert await target_two.read(len(source_content)) == source_content
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("content", [b"audio-content", b""])
+async def test_send_audio(stream, content):
+    stream_key, stream_state = stream
+    audio_stream = asyncio.StreamReader()
+    audio_stream.feed_data(content)
+    audio_stream.feed_eof()
+
+    stream_state.connections = [MagicMock()]
+
+    with patch("websockets.broadcast") as mock_broadcast:
+        await send_audio(audio_stream, stream_key)
+        if content:
+            mock_broadcast.assert_called_with(stream_state.connections, content)
+        else:
+            mock_broadcast.assert_not_called()
+
+
+def test_send_transcript_for_closed_stream():
+    with patch("stream_transcriber.streams.LOGGER") as mock_logger:
+        with patch("websockets.broadcast") as mock_broadcast:
+            send_transcript({}, "not-opened", "2023-01-01T00:00:00.000Z")
+            mock_broadcast.assert_not_called()
+            assert "Tried to send transcript message to closed stream" in str(
+                mock_logger.warning.call_args_list[0]
+            )
+
+
+def test_send_transcript(stream):
+    stream_key, stream_state = stream
+    stream_state.connections = [MagicMock()]
+
+    session_start_time = "2023-01-01T00:00:00.000Z"
+    transcript = {"metadata": {}}
+
+    expected_sent_transcript = transcript.copy()
+    mocked_time = 1696780852
+    with patch("time.time", return_value=mocked_time):
+        expected_sent_transcript["current_timestamp"] = mocked_time
+        expected_sent_transcript["metadata"]["session_start_time"] = session_start_time
+        with patch("websockets.broadcast") as mock_broadcast:
+            send_transcript(transcript, stream_key, session_start_time)
+            mock_broadcast.assert_called_with(
+                stream_state.connections, json.dumps(expected_sent_transcript)
+            )
+
+
+def test_send_translation_for_closed_stream():
+    with patch("stream_transcriber.streams.LOGGER") as mock_logger:
+        with patch("websockets.broadcast") as mock_broadcast:
+            send_translation({}, "not-opened", "2023-01-01T00:00:00.000Z")
+            mock_broadcast.assert_not_called()
+            assert "Tried to send translation message to closed stream" in str(
+                mock_logger.warning.call_args_list[0]
+            )
+
+
+def test_send_translation(stream):
+    stream_key, stream_state = stream
+    stream_state.connections = [MagicMock()]
+
+    session_start_time = "2023-01-01T00:00:00.000Z"
+    translation = {"metadata": {}}
+
+    expected_sent_translation = translation.copy()
+    mocked_time = 1696780852
+    with patch("time.time", return_value=mocked_time):
+        expected_sent_translation["current_timestamp"] = mocked_time
+        expected_sent_translation["metadata"]["session_start_time"] = session_start_time
+        with patch("websockets.broadcast") as mock_broadcast:
+            send_translation(translation, stream_key, session_start_time)
+            mock_broadcast.assert_called_with(
+                stream_state.connections, json.dumps(expected_sent_translation)
+            )
+
+
+def test_finish_session_with_closed_stream():
+    with patch("asyncio.get_event_loop") as mock_loop:
+        mock_loop.return_value = MagicMock()
+        finish_session({}, "not-opened")
+        mock_loop.return_value.create_task.assert_not_called()
+
+
+@pytest.mark.parametrize("amount_connections", [0, 1, 2])
+def test_finish_session(stream, amount_connections):
+    stream_key, stream_state = stream
+    stream_state.connections = []
+    for _ in range(amount_connections):
+        stream_state.connections.append(MagicMock())
+
+    with patch("asyncio.get_event_loop") as mock_loop:
+        mock_loop.return_value = MagicMock()
+        finish_session({}, stream_key)
+        assert mock_loop.return_value.create_task.call_count == amount_connections
+
+
+def test_receive_message_for_closed_stream():
+    with patch("websockets.broadcast") as mock_broadcast:
+        receive_message({}, "not-opened")
+        mock_broadcast.assert_not_called()
+
+
+@pytest.mark.parametrize("level", ["info", "warning", "error", "debug"])
+def test_receive_message(stream, level):
+    stream_key, stream_state = stream
+    stream_state.previous_messages = []
+    message = {"metadata": {}}
+    with patch("websockets.broadcast") as mock_broadcast:
+        receive_message(message, stream_key, level)
+        mock_broadcast.assert_called_with(stream_state.connections, json.dumps(message))
+        # messages received with level warning or error are also appended to the stream previous_messages
+        if level in ("warning", "error"):
+            assert len(stream_state.previous_messages) == 1
+            assert stream_state.previous_messages[0] == message
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("amount_connections", [0, 1, 2])
+async def test_force_close_connections(stream, amount_connections):
+    _, stream_state = stream
+    stream_state.connections = []
+    for _ in range(amount_connections):
+        stream_state.connections.append(AsyncMock())
+
+    await force_close_connections(stream_state)
+
+    for connection in stream_state.connections:
+        connection.close.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("asyncio.create_subprocess_exec")
+@patch("asyncio.wait", new=AsyncMock(return_value=(AsyncMock(), AsyncMock())))
+async def test_load_stream(mock_subprocess_exec, stream):
+    # The asyncio.wait patch uses new=, so it injects no argument; the
+    # injected mock belongs to asyncio.create_subprocess_exec
+    stream_key, stream_state = stream
+    with patch("asyncio.create_task") as mock_create_task:
+        mock_create_task.return_value = CustomAsyncMock()
+        with patch(
+            "stream_transcriber.streams.force_close_connections"
+        ) as mock_force_close:
+            await load_stream(stream_key)
+            # Active stream should call force_close_connections
+            mock_force_close.assert_called_once_with(stream_state)
+
+            # Inactive stream should not call force_close_connections
+            mock_force_close.reset_mock()
+            closed_stream = next(
+                key for key in SUPPORTED_STREAMS.keys() if key != stream_key
+            )
+            await load_stream(closed_stream)
+            mock_force_close.assert_not_called()
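
For a quick local smoke test of the image defined in the Dockerfile above, one possible invocation is the following sketch. It assumes `make build` produced the default `stream-demo-server:manual` tag and that the placeholder token is replaced with a valid Speechmatics API key:

```bash
docker run --rm \
  -p 8765:8765 \
  -p 8000:8000 \
  -e AUTH_TOKEN="<your-token>" \
  stream-demo-server:manual
```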