From 751908cc1b55e856c1494a5cdcfaa1c4ad9866d5 Mon Sep 17 00:00:00 2001 From: boocmp Date: Sun, 29 Sep 2024 16:29:55 +0700 Subject: [PATCH] bento replacement. --- .github/workflows/generalized-deploy.yaml | 1 + Dockerfile | 81 +---- Makefile | 6 - apis/openapi.yaml | 144 --------- bento.yaml | 43 --- bentofile.yaml | 20 -- env/docker/entrypoint.sh | 56 ---- env/python/install.sh | 41 --- env/python/version.txt | 1 - pyproject.toml | 14 + .../requirements.txt => requirements.txt | 15 +- sampuru/__init__.py | 5 + sampuru/batch.py | 111 +++++++ sampuru/batch_test.py | 52 ++++ sampuru/job.py | 21 ++ sampuru/local.py | 98 ++++++ sampuru/redis.py | 193 ++++++++++++ sampuru/runnable.py | 22 ++ sampuru/runner.py | 255 ++++++++++++++++ sampuru/runner_test.py | 196 ++++++++++++ src/__init__.py | 1 + src/ipc/__init__.py | 1 + src/{utils => }/ipc/client.py | 4 +- src/{utils => }/ipc/messages.py | 0 .../ipc_server.py => ipc/server.py} | 4 +- .../__main__.py => ipc/server_test.py} | 21 +- src/ipc_server/__init__.py | 1 - src/requirements.txt | 3 +- src/runners/audio_transcriber.py | 285 +++--------------- src/service.py | 52 ++-- src/stream_transcriber.py | 64 ++-- src/stt_api.py | 86 +++--- src/utils/ipc/__init__.py | 2 - 33 files changed, 1155 insertions(+), 744 deletions(-) delete mode 100644 apis/openapi.yaml delete mode 100644 bento.yaml delete mode 100644 bentofile.yaml delete mode 100644 env/docker/entrypoint.sh delete mode 100644 env/python/install.sh delete mode 100644 env/python/version.txt create mode 100644 pyproject.toml rename env/python/requirements.txt => requirements.txt (57%) create mode 100644 sampuru/__init__.py create mode 100644 sampuru/batch.py create mode 100644 sampuru/batch_test.py create mode 100644 sampuru/job.py create mode 100644 sampuru/local.py create mode 100644 sampuru/redis.py create mode 100644 sampuru/runnable.py create mode 100644 sampuru/runner.py create mode 100644 sampuru/runner_test.py create mode 100644 src/__init__.py create mode 100644 src/ipc/__init__.py rename src/{utils => }/ipc/client.py (95%) rename src/{utils => }/ipc/messages.py (100%) rename src/{ipc_server/ipc_server.py => ipc/server.py} (97%) rename src/{tests/ipc_test/__main__.py => ipc/server_test.py} (76%) delete mode 100644 src/ipc_server/__init__.py delete mode 100644 src/utils/ipc/__init__.py diff --git a/.github/workflows/generalized-deploy.yaml b/.github/workflows/generalized-deploy.yaml index 22d1e2d..c72d30a 100644 --- a/.github/workflows/generalized-deploy.yaml +++ b/.github/workflows/generalized-deploy.yaml @@ -4,6 +4,7 @@ on: - main - dev - bentoml + - replace-bentoml name: stt deployments jobs: diff --git a/Dockerfile b/Dockerfile index 498d217..b57857e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,81 +1,20 @@ -# =========================================== -# -# THIS IS A GENERATED DOCKERFILE. 
DO NOT EDIT -# -# =========================================== - -# Block SETUP_BENTO_BASE_IMAGE -FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 as base-container - -ENV LANG=C.UTF-8 - -ENV LC_ALL=C.UTF-8 - -ENV PYTHONIOENCODING=UTF-8 - -ENV PYTHONUNBUFFERED=1 +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 ENV NVIDIA_VISIBLE_DEVICES=all +RUN apt update && apt install -y python3-pip -USER root - -ENV DEBIAN_FRONTEND=noninteractive -RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache -RUN set -eux && \ - apt-get update -y && \ - apt-get install -q -y --no-install-recommends --allow-remove-essential \ - ca-certificates gnupg2 bash build-essential libsndfile1 ffmpeg - -RUN \ - set -eux && \ - apt-get install -y --no-install-recommends --allow-remove-essential software-properties-common && \ - # add deadsnakes ppa to install python - add-apt-repository ppa:deadsnakes/ppa && \ - apt-get update -y && \ - apt-get install -y --no-install-recommends --allow-remove-essential curl python3.11 python3.11-dev python3.11-distutils +RUN pip install --upgrade pip +RUN pip install --upgrade setuptools -RUN ln -sf /usr/bin/python3.11 /usr/bin/python3 && \ - ln -sf /usr/bin/pip3.11 /usr/bin/pip3 +COPY ./requirements.txt ./requirements.txt -RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ - python3 get-pip.py && \ - rm -rf get-pip.py +RUN pip install -r requirements.txt -# Block SETUP_BENTO_USER -ARG BENTO_USER=bentoml -ARG BENTO_USER_UID=1034 -ARG BENTO_USER_GID=1034 -RUN groupadd -g $BENTO_USER_GID -o $BENTO_USER && useradd -m -u $BENTO_USER_UID -g $BENTO_USER_GID -o -r $BENTO_USER +COPY . /app +WORKDIR /app +RUN pip install . -ENV BENTOML_CONFIG=src/configuration.yaml -ARG BENTO_PATH=/home/bentoml/bento -ENV BENTO_PATH=$BENTO_PATH -ENV BENTOML_HOME=/home/bentoml/ -ENV BENTOML_DO_NOT_TRACK=True - -RUN mkdir $BENTO_PATH && chown bentoml:bentoml $BENTO_PATH -R -WORKDIR $BENTO_PATH - -# Block SETUP_BENTO_COMPONENTS -COPY --chown=bentoml:bentoml ./env/python ./env/python/ -# install python packages with install.sh -RUN bash -euxo pipefail /home/bentoml/bento/env/python/install.sh -COPY --chown=bentoml:bentoml . ./ - -# Block SETUP_BENTO_ENTRYPOINT -RUN rm -rf /var/lib/{apt,cache,log} -# Default port for BentoServer EXPOSE 3000 -# Expose Prometheus port -EXPOSE 3001 - -RUN chmod +x /home/bentoml/bento/env/docker/entrypoint.sh - -USER bentoml - -RUN mkdir /home/bentoml/.cache -RUN mkdir /home/bentoml/.cache/torch - -ENTRYPOINT [ "/home/bentoml/bento/env/docker/entrypoint.sh" ] +CMD [ "python3", "-m", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "stt:app", "--workers", "1", "-b", "0.0.0.0:3000"] diff --git a/Makefile b/Makefile index 09edce1..f668639 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,8 @@ .PHONY: all build docker serve-docker all: serve-docker -build: - python -m bentoml delete -y stt:git || true - python -m bentoml build --version git src - cp -R ~/bentoml/bentos/stt/git/* . - docker: build docker build -t stt:latest . 
serve-docker: docker - sudo chown -R 1034:1043 models docker run -it --rm -p 3000:3000 -v ./models/huggingface:/home/bentoml/.cache/huggingface stt:latest diff --git a/apis/openapi.yaml b/apis/openapi.yaml deleted file mode 100644 index 3461303..0000000 --- a/apis/openapi.yaml +++ /dev/null @@ -1,144 +0,0 @@ -components: - schemas: - InternalServerError: - description: Internal Server Error - properties: - msg: - title: Message - type: string - type: - title: Error Type - type: string - required: - - msg - - type - title: InternalServerError - type: object - InvalidArgument: - description: Bad Request - properties: - msg: - title: Message - type: string - type: - title: Error Type - type: string - required: - - msg - - type - title: InvalidArgument - type: object - NotFound: - description: Not Found - properties: - msg: - title: Message - type: string - type: - title: Error Type - type: string - required: - - msg - - type - title: NotFound - type: object -info: - contact: - email: contact@bentoml.com - name: BentoML Team - description: "# stt:None\n\n[![pypi_status](https://img.shields.io/badge/BentoML-1.1.10-informational)](https://pypi.org/project/BentoML)\n - title: stt - version: None -openapi: 3.0.2 -paths: - /livez: - get: - description: Health check endpoint for Kubernetes. Healthy endpoint responses - with a 200 OK status. - responses: - '200': - description: Successful Response - tags: - - Infrastructure - /metrics: - get: - description: Prometheus metrics endpoint. The /metrics responses - with a 200. The output can then be used by a Prometheus sidecar - to scrape the metrics of the service. - responses: - '200': - description: Successful Response - tags: - - Infrastructure - /process_audio: - post: - consumes: - - null - description: '' - operationId: stt__process_audio - produces: - - application/json - requestBody: - content: - '*/*': - schema: - format: binary - type: string - required: true - x-bentoml-io-descriptor: - args: - kind: binaryio - mime_type: null - id: bentoml.io.File - responses: - 200: - content: - application/json: - schema: - type: object - description: Successful Response - x-bentoml-io-descriptor: - args: - has_json_encoder: false - has_pydantic_model: false - id: bentoml.io.JSON - 400: - content: - application/json: - schema: - $ref: '#/components/schemas/InvalidArgument' - description: Bad Request - 404: - content: - application/json: - schema: - $ref: '#/components/schemas/NotFound' - description: Not Found - 500: - content: - application/json: - schema: - $ref: '#/components/schemas/InternalServerError' - description: Internal Server Error - summary: "InferenceAPI(BytesIOFile \u2192 JSON)" - tags: - - Service APIs - x-bentoml-name: process_audio - /readyz: - get: - description: A 200 OK status from /readyz endpoint - indicated the service is ready to accept traffic. From that point and onward, - Kubernetes will use /livez endpoint to perform periodic health - checks. - responses: - '200': - description: Successful Response - tags: - - Infrastructure -servers: -- url: . -tags: -- description: BentoML Service API endpoints for inference. - name: Service APIs -- description: Common infrastructure endpoints for observability. 
- name: Infrastructure diff --git a/bento.yaml b/bento.yaml deleted file mode 100644 index 5534cbc..0000000 --- a/bento.yaml +++ /dev/null @@ -1,43 +0,0 @@ -service: service:svc -name: stt -version: git -bentoml_version: 1.1.10 -creation_time: '2023-12-15T21:32:03.847558+00:00' -labels: {} -models: [] -runners: -- name: audio_transcriber - runnable_type: AudioTranscriber - embedded: false - models: [] - resource_config: null -apis: -- name: process_audio - input_type: BytesIOFile - output_type: JSON -docker: - distro: debian - python_version: '3.11' - cuda_version: 12.1.1 - env: - BENTOML_CONFIG: src/configuration.yaml - system_packages: null - setup_script: null - base_image: null - dockerfile_template: null -python: - requirements_txt: requirements.txt - packages: null - lock_packages: null - index_url: null - no_index: null - trusted_host: null - find_links: null - extra_index_url: null - pip_args: null - wheels: null -conda: - environment_yml: null - channels: null - dependencies: null - pip: null diff --git a/bentofile.yaml b/bentofile.yaml deleted file mode 100644 index 1771af3..0000000 --- a/bentofile.yaml +++ /dev/null @@ -1,20 +0,0 @@ -service: "service:svc" -include: - - "service.py" - - "stt_api.py" - - "runners/__init__.py" - - "runners/audio_transcriber.py" - - "utils/google_streaming/google_streaming_api_pb2.py" - - "utils/ipc/__init__.py" - - "utils/ipc/client.py" - - "utils/ipc/messages.py" - - "utils/ipc/server.py" - - "utils/config/config.py" - - "utils/service_key/brave_service_key.py" - - "configuration.yaml" -python: - requirements_txt: "requirements.txt" -docker: - cuda_version: "12.1.1" - env: - BENTOML_CONFIG: "src/configuration.yaml" diff --git a/env/docker/entrypoint.sh b/env/docker/entrypoint.sh deleted file mode 100644 index df1892d..0000000 --- a/env/docker/entrypoint.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env bash -set -Eeuo pipefail - -# check to see if this file is being run or sourced from another script -_is_sourced() { - # https://unix.stackexchange.com/a/215279 - [ "${#FUNCNAME[@]}" -ge 2 ] && - [ "${FUNCNAME[0]}" = '_is_sourced' ] && - [ "${FUNCNAME[1]}" = 'source' ] -} - -_main() { - # For backwards compatibility with the yatai<1.0.0, adapting the old "yatai" command to the new "start" command. - if [ "${#}" -gt 0 ] && [ "${1}" = 'python' ] && [ "${2}" = '-m' ] && { [ "${3}" = 'bentoml._internal.server.cli.runner' ] || [ "${3}" = "bentoml._internal.server.cli.api_server" ]; }; then # SC2235, use { } to avoid subshell overhead - if [ "${3}" = 'bentoml._internal.server.cli.runner' ]; then - set -- bentoml start-runner-server "${@:4}" - elif [ "${3}" = 'bentoml._internal.server.cli.api_server' ]; then - set -- bentoml start-http-server "${@:4}" - fi - # If no arg or first arg looks like a flag. - elif [[ "$#" -eq 0 ]] || [[ "${1:0:1}" =~ '-' ]]; then - # This is provided for backwards compatibility with places where user may have - # discover this easter egg and use it in their scripts to run the container. - if [[ -v BENTOML_SERVE_COMPONENT ]]; then - echo "\$BENTOML_SERVE_COMPONENT is set! 
Calling 'bentoml start-*' instead" - if [ "${BENTOML_SERVE_COMPONENT}" = 'http_server' ]; then - set -- bentoml start-http-server "$@" "$BENTO_PATH" - elif [ "${BENTOML_SERVE_COMPONENT}" = 'grpc_server' ]; then - set -- bentoml start-grpc-server "$@" "$BENTO_PATH" - elif [ "${BENTOML_SERVE_COMPONENT}" = 'runner' ]; then - set -- bentoml start-runner-server "$@" "$BENTO_PATH" - fi - else - set -- bentoml serve "$@" "$BENTO_PATH" - fi - fi - # Overide the BENTOML_PORT if PORT env var is present. Used for Heroku and Yatai. - if [[ -v PORT ]]; then - echo "\$PORT is set! Overiding \$BENTOML_PORT with \$PORT ($PORT)" - export BENTOML_PORT=$PORT - fi - # Handle serve and start commands that is passed to the container. - # Assuming that serve and start commands are the first arguments - # Note that this is the recommended way going forward to run all bentoml containers. - if [ "${#}" -gt 0 ] && { [ "${1}" = 'serve' ] || [ "${1}" = 'serve-http' ] || [ "${1}" = 'serve-grpc' ] || [ "${1}" = 'start-http-server' ] || [ "${1}" = 'start-grpc-server' ] || [ "${1}" = 'start-runner-server' ]; }; then - exec bentoml "$@" "$BENTO_PATH" - else - # otherwise default to run whatever the command is - # This should allow running bash, sh, python, etc - exec "$@" - fi -} - -if ! _is_sourced; then - _main "$@" -fi diff --git a/env/python/install.sh b/env/python/install.sh deleted file mode 100644 index b2b512f..0000000 --- a/env/python/install.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash -set -exuo pipefail - -# Parent directory https://stackoverflow.com/a/246128/8643197 -BASEDIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]:-$0}"; )" &> /dev/null && pwd 2> /dev/null; )" - -PIP_ARGS=(--no-warn-script-location) - -# BentoML by default generates two requirement files: -# - ./env/python/requirements.lock.txt: all dependencies locked to its version presented during `build` -# - ./env/python/requirements.txt: all dependencies as user specified in code or requirements.txt file -REQUIREMENTS_TXT="$BASEDIR/requirements.txt" -REQUIREMENTS_LOCK="$BASEDIR/requirements.lock.txt" -WHEELS_DIR="$BASEDIR/wheels" -BENTOML_VERSION=${BENTOML_VERSION:-1.1.10} -# Install python packages, prefer installing the requirements.lock.txt file if it exist -if [ -f "$REQUIREMENTS_LOCK" ]; then - echo "Installing pip packages from 'requirements.lock.txt'.." - pip install -r "$REQUIREMENTS_LOCK" "${PIP_ARGS[@]}" -else - if [ -f "$REQUIREMENTS_TXT" ]; then - echo "Installing pip packages from 'requirements.txt'.." - pip install -r "$REQUIREMENTS_TXT" "${PIP_ARGS[@]}" - fi -fi - -# Install user-provided wheels -if [ -d "$WHEELS_DIR" ]; then - echo "Installing wheels packaged in Bento.." 
- pip install "$WHEELS_DIR"/*.whl "${PIP_ARGS[@]}" -fi - -# Install the BentoML from PyPI if it's not already installed -if python3 -c "import bentoml" &> /dev/null; then - existing_bentoml_version=$(python3 -c "import bentoml; print(bentoml.__version__)") - if [ "$existing_bentoml_version" != "$BENTOML_VERSION" ]; then - echo "WARNING: using BentoML version ${existing_bentoml_version}" - fi -else - pip install bentoml=="$BENTOML_VERSION" -fi diff --git a/env/python/version.txt b/env/python/version.txt deleted file mode 100644 index c8d5014..0000000 --- a/env/python/version.txt +++ /dev/null @@ -1 +0,0 @@ -3.11.5 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5ebad58 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "stt" +version = "0.1.0" +description = "brave stt service" +readme = "README.md" + + +[project.optional-dependencies] +dev = ["black", "pylint"] +test = ["pytest", "pytest-asyncio"] + +[tool.setuptools.package-dir] +stt = "src" +sampuru = "sampuru" diff --git a/env/python/requirements.txt b/requirements.txt similarity index 57% rename from env/python/requirements.txt rename to requirements.txt index 4eea27a..1285fd0 100644 --- a/env/python/requirements.txt +++ b/requirements.txt @@ -1,11 +1,16 @@ -ctranslate2 -faster_whisper -fastapi +accelerate aiofiles asyncio +fastapi +faster_whisper +gunicorn +huggingface_hub +msgspec pydantic pydantic-settings six -msgspec +torch transformers -torch \ No newline at end of file +uvicorn +scipy +silero_vad diff --git a/sampuru/__init__.py b/sampuru/__init__.py new file mode 100644 index 0000000..9fb208f --- /dev/null +++ b/sampuru/__init__.py @@ -0,0 +1,5 @@ +from .batch import BatchParameters +from .runnable import Runnable +from .local import LocalRunner + +# from .redis import RedisRunner, RemoteRedisRunner diff --git a/sampuru/batch.py b/sampuru/batch.py new file mode 100644 index 0000000..28d3e53 --- /dev/null +++ b/sampuru/batch.py @@ -0,0 +1,111 @@ +""" +module for defining the parameters for forming batches +""" + +import logging +import time + +from dataclasses import dataclass +from typing import List, Optional + +from scipy import stats + + +logger = logging.getLogger(__name__) + + +@dataclass +class BatchParameters: + """ + configuration for batching of requests + """ + + current_size: int = 1 + max_latency_ms: int = 1000 + max_observations: int = 50 + max_size: int = 10 + min_size: int = 1 + predicted_time_ms: int = 0 + update_interval_s: int = 10 + last_updated_at: float = 0 + + def update_current_size( + self, + queue_length: int, + jobs_per_second: float, + previous_batch_sizes: List[int], + previous_batch_times_ms: List[float], + ): + """ + calculate the optimal batch size given the configuration, current + observed incoming jobs per second and observed runtime of past batches + """ + + logger.info( + "update_current_size: before update - queue length %d, " + + "incoming jobs per second: %f, batch size %d, predicted batch time %f ms", + queue_length, + jobs_per_second, + self.current_size, + self.predicted_time_ms, + ) + batch_size = self.min_size + slope, intercept = self.regress_observations( + previous_batch_sizes, previous_batch_times_ms + ) + while batch_size < self.max_size: + predicted_batch_time_ms = self.predict(slope, intercept, batch_size) + predicted_throughput = (batch_size / predicted_batch_time_ms) * 1000 + if jobs_per_second < predicted_throughput: + self.current_size = batch_size + break + batch_size += 1 + else: + 
self.current_size = self.max_size
+        self.predicted_time_ms = self.predict(slope, intercept, self.current_size)
+        logger.info(
+            "update_current_size: after update - batch size %d, "
+            "predicted batch time %f ms",
+            self.current_size,
+            self.predicted_time_ms,
+        )
+        self.last_updated_at = time.time()
+
+    def get_timeout(self):
+        """
+        timeout for getting a batch, based on max latency and predicted batch time
+        """
+        return max(self.max_latency_ms - self.predicted_time_ms, 0.01)
+
+    @staticmethod
+    def regress_observations(
+        previous_batch_sizes: List[int], previous_batch_times_ms: List[float]
+    ) -> (float, float):
+        """
+        perform a linear regression on observed worker performance
+        """
+        res = stats.theilslopes(previous_batch_times_ms, previous_batch_sizes)
+        return res.slope, res.intercept
+
+    @staticmethod
+    def predict(slope: float, intercept: float, size: int) -> float:
+        """
+        based on the slope and intercept determined by a previous regression,
+        predict the time needed to process a batch of the given size
+        """
+        return slope * size + intercept
+
+    def should_update(self, jobs_per_second: float, batches_processed: int):
+        """
+        should we attempt to update the batch size, based on the time elapsed
+        since the last update
+        """
+        if jobs_per_second > 0:
+            elapsed = time.time() - self.last_updated_at
+            logger.debug(
+                "should_update: batches processed %d, time since last update %fs",
+                batches_processed,
+                elapsed,
+            )
+            return self.last_updated_at == 0 or (elapsed >= self.update_interval_s)
+        return False
diff --git a/sampuru/batch_test.py b/sampuru/batch_test.py
new file mode 100644
index 0000000..2b1021d
--- /dev/null
+++ b/sampuru/batch_test.py
@@ -0,0 +1,52 @@
+import pytest
+
+from sampuru.batch import BatchParameters
+
+
+def test_update_current_size():
+    params = BatchParameters(
+        max_latency_ms=1000,
+        max_observations=50,
+        max_size=10,
+        min_size=1,
+        update_interval_s=10,
+    )
+    params.update_current_size(
+        queue_length=0,
+        previous_batch_sizes=[1, 2, 3],
+        previous_batch_times_ms=[100, 150, 200],
+        # regression gives 50*size + 50, i.e. 300ms at size 5 (~16.6 jobs/s throughput)
+        jobs_per_second=16,
+    )
+    assert params.current_size == 5
+    assert params.predicted_time_ms == 300
+
+
+def test_get_timeout():
+    params = BatchParameters(max_latency_ms=1000, predicted_time_ms=500)
+    assert params.get_timeout() == 500
+
+    params.predicted_time_ms = 1500
+    assert params.get_timeout() == 0.01
+
+
+def test_regress_observations():
+    batch_sizes = [1, 2, 3]
+    batch_times_ms = [10, 20, 30]
+    slope, intercept = BatchParameters.regress_observations(batch_sizes, batch_times_ms)
+    assert slope == pytest.approx(10.0)
+    assert intercept == pytest.approx(0.0)
+
+
+def test_predict():
+    slope = 10.0
+    intercept = 0.0
+    assert BatchParameters.predict(slope, intercept, 5) == 50.0
+
+
+def test_should_update():
+    params = BatchParameters(
+        current_size=5, predicted_time_ms=100, update_interval_s=10
+    )
+    assert params.should_update(jobs_per_second=10, batches_processed=17)
+    assert not params.should_update(jobs_per_second=0, batches_processed=8)
diff --git a/sampuru/job.py b/sampuru/job.py
new file mode 100644
index 0000000..38d1763
--- /dev/null
+++ b/sampuru/job.py
@@ -0,0 +1,21 @@
+"""
+job definition
+"""
+
+from asyncio import Future
+from typing import Any, List, Optional
+from dataclasses import dataclass
+
+
+@dataclass
+class Job:
+    """
+    a single unit of work to be sent to a remote runner
+    """
+
+    runnable_name: str
+    submitter_name: str
+    data: List[Any]
+    future: 
Future + job_id: Optional[str] = None + metadata: Optional[dict] = None diff --git a/sampuru/local.py b/sampuru/local.py new file mode 100644 index 0000000..5d240d5 --- /dev/null +++ b/sampuru/local.py @@ -0,0 +1,98 @@ +import asyncio +import collections +import logging +import time + +from contextlib import asynccontextmanager +from typing import Any, List + +from sampuru.runner import Runner, Job + +logger = logging.getLogger(__name__) + + +class LocalRunner(Runner): + """ + a local runner which keeps jobs in a collections.deque + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.queue = asyncio.Queue() + self.waker = asyncio.Condition() + + self.x = collections.deque(maxlen=self.batch_params.max_observations) + self.y = collections.deque(maxlen=self.batch_params.max_observations) + + self.request_count = 0 + self.request_count_history = collections.deque(maxlen=30) + + async def run(self, *args, loop, **kwargs): + self.loop = loop + self.loop.create_task(self.tock()) + await super().run(*args, **kwargs) + + async def get_observations(self) -> (List[int], List[float]): + return list(self.x), list(self.y) + + async def observe(self, num_jobs: int, elapsed_ms: float): + self.x.append(num_jobs) + self.y.append(elapsed_ms) + + async def create_response_future(self) -> asyncio.Future: + return self.loop.create_future() + + async def enqueue(self, job: Job): + self.request_count += 1 + await self.queue.put(job) + + async def get_queue_length(self) -> int: + return self.queue.qsize() + + async def get_from_queue(self) -> Job: + try: + return await self.queue.get() + except RuntimeError as e: + if str(e) != "Event loop is closed": + raise e + + @asynccontextmanager + async def get_batch_with_timeout(self, timeout_ms: int = None) -> List[Job]: + batch = [] + start_time = time.time() + try: + while True: + remaining_timeout_s = None + if timeout_ms: + elapsed_time_s = time.time() - start_time + remaining_timeout_s = max( + (timeout_ms / 1000) - elapsed_time_s, 0.001 + ) + job = await asyncio.wait_for( + self.get_from_queue(), timeout=remaining_timeout_s + ) + if job: + batch.append(job) + if len(batch) >= self.batch_params.max_size: + break + if ( + await self.get_queue_length() + > 2 * self.batch_params.current_size + ): + continue + if len(batch) >= self.batch_params.current_size: + break + except asyncio.TimeoutError: + pass + yield batch + for job in batch: + self.queue.task_done() + + async def tock(self): + while True: + self.request_count_history.append(self.request_count) + self.request_count = 0 + await asyncio.sleep(1) + + async def get_jobs_per_second(self): + return sum(self.request_count_history) / len(self.request_count_history) diff --git a/sampuru/redis.py b/sampuru/redis.py new file mode 100644 index 0000000..29ef375 --- /dev/null +++ b/sampuru/redis.py @@ -0,0 +1,193 @@ +import asyncio +import collections +import json +import logging +import time + +from contextlib import asynccontextmanager +from typing import List + +import redis + +from sampuru.runner import Runner, Job + +logger = logging.getLogger(__name__) + +redis_host = "localhost" +stream_key = "skey" +stream2_key = "s2key" +group1 = "grp1" + +list_key = "lke" + +# TODO +# - truncate stream ( XTRIM ~ rps * 5 * 60 ) +# - delete old streams +# - truncate list +# - rename keys / groups +# - handle multiple workers + + +class RedisRunner(Runner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.request_count = 0 + + self.pool = 
redis.asyncio.ConnectionPool.from_url("redis://localhost") + # self.r.ping() + + self.futures = {} + + self.request_count_history = collections.deque(maxlen=30) + + async def get_observations(self) -> (List[int], List[float]): + client = redis.asyncio.Redis(connection_pool=self.pool) + result = await client.lrange(list_key, 0, -1) + + x = [] + y = [] + for elem in result: + # get string value + tmp = elem.decode("utf-8") + + _x, _y = tmp.split(":") + x.append(int(_x)) + y.append(float(_y)) + return x, y + + async def observe(self, batch_size: int, elapsed_s: float): + # FIXME may want to ensure not all one x value + client = redis.asyncio.Redis(connection_pool=self.pool) + await client.rpush(list_key, f"{batch_size}:{elapsed_s}") + + async def create_response_future(self): + return self.loop.create_future() + + def get_current_stream_key(self): + return self.get_stream_key(time.time()) + + def get_previous_stream_key(self): + return self.get_stream_key(time.time() - 60) + + def get_stream_key(self, timestamp: int = None): + return f"{stream_key}:{int(timestamp / 60)}" + + async def enqueue(self, job: Job): + client = redis.asyncio.Redis(connection_pool=self.pool) + self.request_count += 1 + jid = await client.xadd( + self.get_current_stream_key(), + {"data": json.dumps(job.data).encode("utf-8"), "submitter": job.submitter}, + ) + self.futures[jid] = job.future + + @asynccontextmanager + async def get_batch_with_timeout(self, timeout_s): + # FIXME + client = redis.asyncio.Redis(connection_pool=self.pool) + streams = [self.get_previous_stream_key(), self.get_current_stream_key()] + for stream in streams: + try: + await client.xgroup_create(stream, group1, mkstream=True) + except Exception as e: + pass + logger.debug( + "get_batch_with_timeout: waiting up to %fs for %d items", + timeout_s, + self.batch_size, + ) + reply = await client.xreadgroup( + groupname=group1, + consumername=self.name, + block=int(timeout_s * 1000), + count=self.batch_size, + streams={stream: ">" for stream in streams}, + ) + batch = [] + if len(reply) > 0: + for d_stream in reply: + for element in d_stream[1]: + future = await self.create_response_future() + job = Job( + data=json.loads(element[1][b"data"].decode("utf-8")), + future=future, + jid=element[0], + submitter=element[1][b"submitter"], + metadata={"stream_name": d_stream[0]}, + ) + batch.append(job) + + try: + yield batch + finally: + for job in batch: + await client.xadd( + f"{stream2_key}:{job.submitter}", + { + "orig_jid": job.jid, + "data": json.dumps(await job.future).encode("utf-8"), + }, + ) + await client.xack(job.metadata["stream_name"], group1, job.jid) + + async def handle_responses(self): + client = redis.asyncio.Redis(connection_pool=self.pool) + try: + await client.xgroup_create( + f"{stream2_key}:{self.name}", group1, mkstream=True + ) + except Exception as e: + pass + + while True: + reply = await client.xreadgroup( + groupname=group1, + consumername=self.name, + block=1000, + streams={f"{stream2_key}:{self.name}": ">"}, + ) + if len(reply) > 0: + d_stream = reply[0] + for element in d_stream[1]: + jid = element[1][b"orig_jid"] + response = json.loads(element[1][b"data"].decode("utf-8")) + if jid in self.futures and not self.futures[jid].done(): + self.futures[jid].set_result(response) + await client.xack(d_stream[0], group1, jid) + + async def run(self, warmup: bool = False): + self.loop.create_task(self.tock()) + await super().run(warmup) + + async def tock(self): + client = redis.asyncio.Redis(connection_pool=self.pool) + last_entries_added 
= {} + while True: + try: + stream_key = self.get_current_stream_key() + result = await client.xinfo_stream(stream_key) + entries_added = result["entries-added"] + active_consumers = await client.xinfo_consumers(stream_key, group1) + # TODO cleanup last entries and consumers + if stream_key in last_entries_added: + # given request_count_history averaging, per worker rps will slowly go up/down + self.request_count_history.append( + (entries_added - last_entries_added[stream_key]) + / len(active_consumers) + ) + last_entries_added[stream_key] = entries_added + except Exception as e: + # print("tock exception", e) + pass + await asyncio.sleep(1) + + async def get_rps(self): + return sum(self.request_count_history) / len(self.request_count_history) + + +class RemoteRedisRunner(RedisRunner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, remote=True) + + async def run(self, warmup: bool = False): + await self.handle_responses() diff --git a/sampuru/runnable.py b/sampuru/runnable.py new file mode 100644 index 0000000..9d9bab2 --- /dev/null +++ b/sampuru/runnable.py @@ -0,0 +1,22 @@ +import abc + +from typing import Any, List + + +class Runnable(abc.ABC): + # whether a runner should flatten the incoming data into separate jobs when batching + # + # considerations: + # - if data from the same request is split into multiple jobs it may be split over + # multiple workers or multiple forward passes. this can increase the median + # latency + # - if the length of data from different requests is highly irregular, the quality + # of the performance prediction can degrade + flatten = False + + @abc.abstractmethod + def forward(self, data: List[Any]) -> List[Any]: + """ + perform a forward pass over the data, producing one output for each input + """ + raise NotImplementedError diff --git a/sampuru/runner.py b/sampuru/runner.py new file mode 100644 index 0000000..22e7881 --- /dev/null +++ b/sampuru/runner.py @@ -0,0 +1,255 @@ +""" +module for defining runners +""" + +import abc +import asyncio +import logging +import random +import time + +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +from sampuru.batch import BatchParameters +from sampuru.job import Job +from sampuru.runnable import Runnable + +logger = logging.getLogger(__name__) + + +class Runner(abc.ABC): + """ + a runner provides an efficient interface to perform work using a particular runnable + """ + + def __init__( + self, + cls, + runnable_init_params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + batch_params: Optional[BatchParameters] = None, + ): + self.name = name or f"{cls.__name__}:{random.randint(1000, 9999)}" + self.last_observations = None + + self.batch_params = batch_params or BatchParameters() + self.batches_processed_since_last_update = 0 + + self.runnable = None + self.runnable_cls = cls + self.runnable_init_params = runnable_init_params or {} + + self.loop = asyncio.get_event_loop() + + @abc.abstractmethod + async def get_observations(self) -> (List[int], List[float]): + """ + get past records of batch processing returning a list containing the number of + jobs processed and a list of the elapsed time to process those jobs + """ + raise NotImplementedError + + @abc.abstractmethod + async def observe(self, num_jobs: int, elapsed_ms: float): + """ + record an observed batch processing event noting the number of jobs + processed and the elapsed time + """ + raise NotImplementedError + + @abc.abstractmethod + async def 
create_response_future(self) -> asyncio.Future:
+        """
+        create a future to be resolved when the processing for a request is
+        complete
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def enqueue(self, job: Job):
+        """
+        enqueue a job to be processed
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def get_queue_length(self) -> int:
+        """
+        get the current length of the queue
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    @asynccontextmanager
+    async def get_batch_with_timeout(self, timeout_ms: int) -> List[Job]:
+        """
+        get a batch with timeout
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def get_jobs_per_second(self) -> float:
+        """
+        get the current incoming requests per second
+        """
+        raise NotImplementedError
+
+    async def forward(
+        self, data: List[Any], timeout_ms: Optional[int] = None
+    ) -> List[Any]:
+        """
+        asynchronously performs a forward pass of the runnable over the data
+        by submitting job(s) to the work queue. jobs in the queue will be
+        processed by batching
+        """
+        futures = []
+        try:
+            if self.runnable_cls.flatten:
+                # we should flatten each element of data into its own job
+                for d in data:
+                    future = await self.create_response_future()
+                    job = Job(
+                        runnable_name=str(self.runnable_cls.__name__),
+                        submitter_name=self.name,
+                        data=[d],
+                        future=future,
+                    )
+                    await self.enqueue(job)
+                    futures.append(future)
+
+                outputs = await asyncio.wait_for(
+                    asyncio.gather(*futures),
+                    timeout=timeout_ms / 1000 if timeout_ms else None,
+                )
+                return [out[0] for out in outputs]
+
+            future = await self.create_response_future()
+            futures.append(future)
+            job = Job(
+                runnable_name=str(self.runnable_cls.__name__),
+                submitter_name=self.name,
+                data=data,
+                future=future,
+            )
+            await self.enqueue(job)
+            return await asyncio.wait_for(
+                future, timeout=timeout_ms / 1000 if timeout_ms else None
+            )
+        except asyncio.TimeoutError:
+            for future in futures:
+                if not future.done():
+                    future.cancel("cancelling due to timeout")
+
+    async def update_batch_size(self):
+        """
+        update the batch size based on current observed performance
+        """
+        queue_length = await self.get_queue_length()
+        jobs_per_second = await self.get_jobs_per_second()
+        observations = await self.get_observations()
+
+        # only trust a new regression once we have at least 3 unique batch sizes
+        if len(set(observations[0])) > 2:
+            self.last_observations = observations
+
+        # otherwise fall back to the last sufficient observations
+        previous_batch_sizes, previous_batch_times_ms = self.last_observations
+
+        self.batch_params.update_current_size(
+            queue_length, jobs_per_second, previous_batch_sizes, previous_batch_times_ms
+        )
+        self.batches_processed_since_last_update = 0
+
+    async def run_warmup(self):
+        """
+        run in warmup mode, slowly stepping up from the minimum batch size
+        to create observations for performance prediction
+        """
+        i = self.batch_params.min_size
+        batch_warmup_increment = max(
+            1, int(0.1 * (self.batch_params.max_size - self.batch_params.min_size))
+        )
+        while i < self.batch_params.min_size + 3 * batch_warmup_increment:
+            logger.debug("run_warmup: waiting for batch size %d", i)
+            self.batch_params.current_size = i
+            async with self.get_batch_with_timeout(
+                self.batch_params.max_latency_ms
+            ) as batch:
+                if not batch:
+                    continue
+                await self.run_batch(batch)
+                if len(batch) >= self.batch_params.current_size:
+                    i += batch_warmup_increment
+
+    async def run(self, warmup: bool = False, run_forever: bool = True):
+        """
+        run continuously, first performing a warmup if desired
+        """
+        if not 
self.runnable: + self.runnable = self.runnable_cls(**self.runnable_init_params) + if warmup: + await self.run_warmup() + await self.update_batch_size() + + while True: + async with self.get_batch_with_timeout( + self.batch_params.get_timeout() + ) as batch: + if not batch: + continue + await self.run_batch(batch) + self.batches_processed_since_last_update += 1 + + if self.batch_params.should_update( + await self.get_jobs_per_second(), + self.batches_processed_since_last_update, + ): + await self.update_batch_size() + + if not run_forever: + break + + async def run_batch(self, jobs: List[Job]): + """ + run a batch of jobs, flattening their data to perform a single forward pass of + the runnable + """ + + if not self.runnable: + self.runnable = self.runnable_cls(**self.runnable_init_params) + + # only attempt to run jobs that haven't already been completed + # e.g. through cancellation + jobs = [job for job in jobs if not job.future.done()] + + logger.debug("run_batch: running %d jobs", len(jobs)) + if len(jobs) > 0: + start_time = time.time() + try: + outputs = await self.loop.run_in_executor( + None, self.runnable.forward, sum([job.data for job in jobs], []) + ) + elapsed_ms = (time.time() - start_time) * 1000 + logger.debug("run_batch: finished execution") + + for job in jobs: + future = job.future + job_dim = len(job.data) + if not future.done(): + future.set_result(outputs[:job_dim]) + # advance the outputs + outputs = outputs[job_dim:] + logger.debug("run_batch: set results") + + await self.observe(len(jobs), elapsed_ms) + logger.debug( + "run_batch: recorded observation, ran %d jobs in %fms", + len(jobs), + elapsed_ms, + ) + except Exception as e: + logger.exception("run_batch: execution failed") + for job in jobs: + if not job.future.done(): + job.future.set_exception(e) diff --git a/sampuru/runner_test.py b/sampuru/runner_test.py new file mode 100644 index 0000000..d899eed --- /dev/null +++ b/sampuru/runner_test.py @@ -0,0 +1,196 @@ +""" +runner tests +""" + +import asyncio +import logging +import time + +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +import pytest + +from sampuru.batch import BatchParameters +from sampuru.job import Job +from sampuru.runnable import Runnable +from sampuru.runner import Runner + +logger = logging.getLogger(__name__) + + +class MockRunner(Runner): + def __init__( + self, + cls, + runnable_init_params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + batch_params: Optional[BatchParameters] = None, + ): + super().__init__(cls, runnable_init_params, name, batch_params) + self.observations = [] + self.queue = asyncio.Queue() + + async def get_observations(self) -> (List[int], List[float]): + jobs_processed = [obs[0] for obs in self.observations] + elapsed_times = [obs[1] for obs in self.observations] + return jobs_processed, elapsed_times + + async def observe(self, num_jobs: int, elapsed_ms: float): + self.observations.append((num_jobs, elapsed_ms)) + + async def create_response_future(self) -> asyncio.Future: + return asyncio.Future() + + async def enqueue(self, job: Job): + await self.queue.put(job) + + async def get_from_queue(self) -> Job: + try: + return await self.queue.get() + except RuntimeError as e: + if str(e) != "Event loop is closed": + raise e + + @asynccontextmanager + async def get_batch_with_timeout(self, timeout_ms: int = None) -> List[Job]: + batch = [] + start_time = time.time() + try: + while True: + remaining_timeout_s = None + if timeout_ms: + elapsed_time_s = 
time.time() - start_time
+                    remaining_timeout_s = max(
+                        (timeout_ms / 1000) - elapsed_time_s, 0.001
+                    )
+                job = await asyncio.wait_for(
+                    self.get_from_queue(), timeout=remaining_timeout_s
+                )
+                if job:
+                    batch.append(job)
+                    if len(batch) >= self.batch_params.current_size:
+                        break
+        except asyncio.TimeoutError:
+            pass
+        if batch:
+            yield batch
+        for job in batch:
+            self.queue.task_done()
+
+    async def get_jobs_per_second(self) -> float:
+        return 100.0
+
+
+class MockRunnable(Runnable):
+    flatten = False
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def forward(self, data):
+        return data
+
+
+class FlatMockRunnable(MockRunnable):
+    flatten = True
+
+
+class ExceptionMockRunnable(MockRunnable):
+    def forward(self, data):
+        raise Exception("error in forward")
+
+
+@pytest.mark.asyncio
+async def test_runner_init():
+    runner = MockRunner(MockRunnable, name="test_runner")
+    assert runner.name == "test_runner"
+    assert isinstance(runner.batch_params, BatchParameters)
+    assert runner.runnable is None
+    assert runner.runnable_cls == MockRunnable
+    assert runner.runnable_init_params == {}
+
+
+@pytest.mark.asyncio
+async def test_runner_forward():
+    runner = MockRunner(MockRunnable)
+    data = [1, 2, 3]
+    out = asyncio.create_task(runner.forward(data))
+    # forward should enqueue a single job given flatten = False
+    job = await asyncio.wait_for(runner.queue.get(), timeout=0.1)
+    # the job data should be as expected
+    assert job.data == data
+    # if we resolve the future we should get the expected output from forward
+    job.future.set_result(data)
+    assert await out == data
+
+    runner = MockRunner(FlatMockRunnable)
+    out = asyncio.create_task(runner.forward(data))
+    # forward should enqueue three jobs given flatten = True
+    job1 = await asyncio.wait_for(runner.queue.get(), timeout=0.1)
+    job2 = await asyncio.wait_for(runner.queue.get(), timeout=0.1)
+    job3 = await asyncio.wait_for(runner.queue.get(), timeout=0.1)
+    # if we resolve the futures we should get the combined output from forward
+    job1.future.set_result([data[0]])
+    job2.future.set_result([data[1]])
+    job3.future.set_result([data[2]])
+    assert await out == data
+
+
+@pytest.mark.asyncio
+async def test_runner_run_batch():
+    runner = MockRunner(MockRunnable, batch_params=BatchParameters(current_size=2))
+    jobs = [
+        Job(
+            runnable_name="MockRunnable",
+            submitter_name="test_runner",
+            data=[1, 2],
+            future=asyncio.Future(),
+        ),
+        Job(
+            runnable_name="MockRunnable",
+            submitter_name="test_runner",
+            data=[3, 4],
+            future=asyncio.Future(),
+        ),
+        Job(
+            runnable_name="MockRunnable",
+            submitter_name="test_runner",
+            data=[5],
+            future=asyncio.Future(),
+        ),
+    ]
+    for job in jobs:
+        await runner.enqueue(job)
+
+    async with runner.get_batch_with_timeout(timeout_ms=1000) as batch:
+        await runner.run_batch(batch)
+
+    assert len(runner.observations) == 1
+    assert runner.observations[0][0] == 2  # Number of jobs processed
+    assert runner.observations[0][1] > 0  # Elapsed time
+
+    for job in jobs[:2]:
+        assert job.future.done()
+        assert job.future.result() == job.data
+
+    assert not jobs[2].future.done()
+    jobs[2].future.set_result([])  # settle the leftover future
+
+    runner = MockRunner(ExceptionMockRunnable)
+    jobs[0].future = asyncio.Future()
+    await runner.enqueue(jobs[0])
+    async with runner.get_batch_with_timeout(timeout_ms=1000) as batch:
+        await runner.run_batch(batch)
+    # an exception within the runnable should be propagated to the job future
+    with pytest.raises(Exception):
+        await jobs[0].future
+
+
+@pytest.mark.asyncio
+async def test_runner_run():
+    
runner = MockRunner(MockRunnable, batch_params=BatchParameters(current_size=1)) + data = [1, 2, 3] + asyncio.create_task(runner.run(run_forever=False)) + result = await runner.forward(data) + assert result == data diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..ba6d90a --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +from .service import app diff --git a/src/ipc/__init__.py b/src/ipc/__init__.py new file mode 100644 index 0000000..fe0338a --- /dev/null +++ b/src/ipc/__init__.py @@ -0,0 +1 @@ +from .server import * diff --git a/src/utils/ipc/client.py b/src/ipc/client.py similarity index 95% rename from src/utils/ipc/client.py rename to src/ipc/client.py index 26d554a..af70878 100644 --- a/src/utils/ipc/client.py +++ b/src/ipc/client.py @@ -3,7 +3,7 @@ class Publisher: - def __init__(self, pair, host="localhost", port=3015): + def __init__(self, pair, host="127.0.0.1", port=3015): self._pair: str = pair self._host = host self._port = port @@ -39,7 +39,7 @@ async def op(): class Subscriber: - def __init__(self, pair: str, host="localhost", port=3015): + def __init__(self, pair: str, host="127.0.0.1", port=3015): self._pair: str = pair self._host = host self._port = port diff --git a/src/utils/ipc/messages.py b/src/ipc/messages.py similarity index 100% rename from src/utils/ipc/messages.py rename to src/ipc/messages.py diff --git a/src/ipc_server/ipc_server.py b/src/ipc/server.py similarity index 97% rename from src/ipc_server/ipc_server.py rename to src/ipc/server.py index cba73cf..53c178e 100644 --- a/src/ipc_server/ipc_server.py +++ b/src/ipc/server.py @@ -1,5 +1,5 @@ import asyncio -from utils.ipc import messages +from . import messages Publishers: dict[str, asyncio.StreamReader] = {} Subscribers: dict[str, bool] = {} @@ -108,7 +108,7 @@ async def run_ipc_server(host, port): await server.serve_forever() -def start_ipc_server(host="localhost", port=3015): +def start_ipc_server(host="127.0.0.1", port=3015): asyncio.run(run_ipc_server(host, port)) diff --git a/src/tests/ipc_test/__main__.py b/src/ipc/server_test.py similarity index 76% rename from src/tests/ipc_test/__main__.py rename to src/ipc/server_test.py index a51893c..c9e79df 100644 --- a/src/tests/ipc_test/__main__.py +++ b/src/ipc/server_test.py @@ -1,8 +1,10 @@ import asyncio -from utils.ipc.client import Publisher, Subscriber -from utils.ipc import messages -from ipc_server import run_ipc_server +import pytest + +from stt.ipc.client import Publisher, Subscriber +from stt.ipc import messages +from stt.ipc.server import run_ipc_server async def publisher(pair): @@ -41,15 +43,16 @@ async def batch(pair): except Exception as e: print(e) pass + print("done") + +@pytest.mark.asyncio +async def test_server(): + srv = asyncio.create_task(run_ipc_server("localhost", 3015)) -async def main(): - tasks = [ asyncio.create_task(run_ipc_server("localhost", 3015))] + tasks = [] for i in range(20): tasks.append(asyncio.create_task(batch(str(i)))) - for t in tasks: await t - - -asyncio.run(main()) + srv.cancel() diff --git a/src/ipc_server/__init__.py b/src/ipc_server/__init__.py deleted file mode 100644 index ee4027c..0000000 --- a/src/ipc_server/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ipc_server import * diff --git a/src/requirements.txt b/src/requirements.txt index 36b5cbd..3c073bb 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,4 +1,3 @@ -ctranslate2 faster_whisper fastapi aiofiles @@ -7,4 +6,4 @@ pydantic pydantic-settings six msgspec -whisperx \ No newline at end of file 
+silero_vad diff --git a/src/runners/audio_transcriber.py b/src/runners/audio_transcriber.py index aabd21f..0310a6e 100644 --- a/src/runners/audio_transcriber.py +++ b/src/runners/audio_transcriber.py @@ -1,56 +1,18 @@ -import bentoml -import ctranslate2 -from faster_whisper import WhisperModel, decode_audio +import time +from typing import List -class AudioTranscriber(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self): - device = "cuda" if ctranslate2.get_cuda_device_count() > 0 else "cpu" - compute_type = ( - "int8_float16" if ctranslate2.get_cuda_device_count() > 0 else "int8" - ) - - print(device, " ", compute_type) - - model = "medium" - self.model = WhisperModel(model, device=device, compute_type=compute_type) - - @bentoml.Runnable.method(batchable=False) - def transcribe_audio(self, audio, lang): - if len(lang) < 2: - lang = "en" - else: - lang = lang[0:2] - - segments, info = self.model.transcribe( - audio, - vad_filter=True, - vad_parameters=dict(min_silence_duration_ms=500), - language=lang, - ) - - text = "" - for segment in segments: - text += segment.text - - return {"text": text} - +import torch -import numpy as np -import io -from datetime import datetime -from faster_whisper.vad import get_speech_timestamps, collect_chunks +import string from pydantic import BaseModel +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +from sampuru import Runnable -class BatchInput(BaseModel): - audio: bytes - pair: str - lang: str = "en" +import logging +logger = logging.getLogger(__name__) class BatchOutput(BaseModel): text: str @@ -60,220 +22,43 @@ class BatchOutput(BaseModel): restore_time: float -""" -class BatchableAudioTranscriber(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self): - device = "cuda" if ctranslate2.get_cuda_device_count() > 0 else "cpu" - compute_type = "float16" if ctranslate2.get_cuda_device_count() > 0 else "int8" +class WhisperHFRunnable(Runnable): + def __init__(self, model_id: str = "openai/whisper-tiny"): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - print(device, " ", compute_type) - - model = "base.en" - self.model = WhisperModel(model, device=device, compute_type=compute_type) - - def transcribe(self, audios): - segments, info = self.model.transcribe( - audios, - vad_filter=False, - vad_parameters=dict(min_silence_duration_ms=250), - language="en", - condition_on_previous_text=False, - word_timestamps=True, - no_speech_threshold=10, + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + use_safetensors=True, ) - return segments - - @bentoml.Runnable.method(batchable=True) - def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]: - result = [] - - # merging audio - ts = datetime.now() - - batch_list = [] - audio_batch = np.ndarray(1, dtype=np.float32) - for input in inputs: - wav = decode_audio(io.BytesIO(input.audio)) - chunks = get_speech_timestamps(wav) - if len(chunks) == 0: - batch_list.append( - BatchItem( - start_time=len(audio_batch) / 16000.0, - end_time=len(audio_batch) / 16000.0, - chunks_count=0, - ) - ) - else: - wav = collect_chunks(wav, chunks=chunks) - wav = np.append(wav, np.zeros(16000, dtype=np.float32)) - batch_list.append( - BatchItem( - start_time=len(audio_batch) / 16000.0, - 
end_time=(len(audio_batch) + len(wav)) / 16000.0, - chunks_count=len(chunks), - ) - ) - audio_batch = np.append(audio_batch, wav) - - for item in batch_list: - print(item) - - merge_time = (datetime.now() - ts).total_seconds() - - ts = datetime.now() - segments = self.transcribe(audio_batch) - transcribe_time = (datetime.now() - ts).total_seconds() - - ts = datetime.now() - output = [segment for segment in segments] - - for segment in output: - for word in segment.words: - for item in batch_list: - item.add(word) - - restore_time = (datetime.now() - ts).total_seconds() - - for item in batch_list: - result.append( - BatchOutput( - text=item.transcription, - batched_count=len(inputs), - merge_audio_time=merge_time, - transcribe_time=transcribe_time, - restore_time=restore_time, - ) - ) - - return result - -""" - - -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor -import torch -from itertools import groupby + model.to(device) + processor = AutoProcessor.from_pretrained(model_id) -class BatchableAudioTranscriber(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self): - pass - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.processor = Wav2Vec2Processor.from_pretrained( - # "facebook/wav2vec2-base-960h" - "facebook/wav2vec2-large-960h-lv60-self" + self.pipe = pipeline( + "automatic-speech-recognition", + model=model, + tokenizer=processor.tokenizer, + feature_extractor=processor.feature_extractor, + torch_dtype=torch_dtype, + device=device, ) - self.model = Wav2Vec2ForCTC.from_pretrained( - # "facebook/wav2vec2-base-960h" - "facebook/wav2vec2-large-960h-lv60-self" - ).cuda() - - def transcribe(self, audios): - input_values = self.processor( - audios, return_tensors="pt", sampling_rate=16000, padding=True - ).input_values.cuda() - - with torch.no_grad(): - logits = self.model(input_values).logits - predicted_ids = torch.argmax(logits, dim=-1) - transcriptions = self.processor.batch_decode(predicted_ids) - - return transcriptions - - @bentoml.Runnable.method(batchable=True) - def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]: - # merging audio - audio_batch = [] - for input in inputs: - audio_batch.append(np.frombuffer(input.audio, dtype=np.float32)) - - ts = datetime.now() - segments = self.transcribe(audio_batch) - transcribe_time = (datetime.now() - ts).total_seconds() + def forward(self, data: List[bytes]) -> List[BatchOutput]: + start = time.time() + result = self.pipe(data, batch_size=len(data)) + logger.debug(result) + transcribe_time = time.time() - start + no_punctuation = str.maketrans('', '', string.punctuation) return [ BatchOutput( - text=text, - batched_count=len(inputs), + text=r["text"].translate(no_punctuation), + batched_count=len(data), merge_audio_time=0, transcribe_time=transcribe_time, restore_time=0, ) - for text in segments - ] - - -""" -from transformers import WhisperProcessor, WhisperForConditionalGeneration -import torch - - -class BatchableAudioTranscriber(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self): - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.processor = WhisperProcessor.from_pretrained("openai/whisper-base.en") - self.model = WhisperForConditionalGeneration.from_pretrained( - "openai/whisper-base.en", attn_implementation="sdpa" - ).cuda() - - self.model.generation_config.cache_implementation = "static" - self.model.forward = 
torch.compile( - self.model.forward, mode="reduce-overhead", fullgraph=True - ) - - def transcribe(self, audios): - input_features = self.processor( - audios, return_tensors="pt", sampling_rate=16000, padding=True - ).input_features.cuda() - - for _ in range(2): - self.model.generate(input_features) - - predicted_ids = self.model.generate(input_features) - transcriptions = self.processor.batch_decode(predicted_ids, skip_special_tokens=True) - print(transcriptions) - - return transcriptions - - @bentoml.Runnable.method(batchable=True) - def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]: - result = [] - - # merging audio - ts = datetime.now() - audio_batch = [] - for input in inputs: - wav = decode_audio(io.BytesIO(input.audio)) - chunks = get_speech_timestamps(wav) - if len(chunks) == 0: - audio_batch.append(np.zeros(16000, dtype=np.float32)) - else: - wav = collect_chunks(wav, chunks=chunks) - audio_batch.append(wav) - - merge_time = (datetime.now() - ts).total_seconds() - - ts = datetime.now() - segments = self.transcribe(audio_batch) - transcribe_time = (datetime.now() - ts).total_seconds() - - return [ - BatchOutput( - text=text, - batched_count=len(inputs), - merge_audio_time=merge_time, - transcribe_time=transcribe_time, - restore_time=0, - ) - for text in segments + for r in result ] -""" diff --git a/src/service.py b/src/service.py index cfea8e7..c822e97 100644 --- a/src/service.py +++ b/src/service.py @@ -1,28 +1,44 @@ -import io +import asyncio +import logging import os -import bentoml -from bentoml.io import JSON, File +from threading import Thread -from stt_api import app, runner_audio_transcriber +from .stt_api import app, runner_audio_transcriber +from .ipc import run_ipc_server -import ipc_server - -svc = bentoml.Service( - "stt", - runners=[runner_audio_transcriber], +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s", + level=logging.DEBUG, + datefmt="%Y-%m-%d %H:%M:%S", ) -svc.mount_asgi_app(app) +def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +def multiprocessing_startup(): + loop = asyncio.new_event_loop() + loop.create_task(run_ipc_server("127.0.0.1", 3015)) + t = Thread(target=start_background_loop, args=(loop,), daemon=True) + t.start() + + +@app.on_event("startup") +async def app_startup(): + loop = asyncio.get_event_loop() + loop.create_task(runner_audio_transcriber.run(loop=loop, warmup=True)) -@svc.on_deployment -def on_deployment(): - if not os.fork(): - ipc_server.start_ipc_server() + logger = logging.getLogger("uvicorn.access") + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s") + ) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) -@svc.api(input=File(), output=JSON()) -async def process_audio(input_file: io.BytesIO): - transcript = await runner_audio_transcriber.transcribe_audio.async_run(input_file) - return transcript +multiprocessing_startup() diff --git a/src/stream_transcriber.py b/src/stream_transcriber.py index 5751291..9a1b764 100644 --- a/src/stream_transcriber.py +++ b/src/stream_transcriber.py @@ -1,5 +1,8 @@ +import torch + from faster_whisper import decode_audio -from faster_whisper.vad import get_speech_timestamps, collect_chunks, VadOptions + +from silero_vad import (get_speech_timestamps, load_silero_vad, collect_chunks) import numpy as np import io @@ -24,7 +27,8 @@ def 
split_speech_timestamps(speech_timestamps, buffered, split_time): timestamps = [] while speech_timestamps: timestamps.append([]) - while speech_timestamps and speech_timestamps[0]["end"] < max_offset - buffered: + while speech_timestamps and speech_timestamps[0][ + "end"] < max_offset - buffered: timestamps[-1].append(speech_timestamps.pop(0)) max_offset += secs2len(split_time) @@ -34,7 +38,8 @@ def split_speech_timestamps(speech_timestamps, buffered, split_time): class StreamTranscriber: - def __init__(self): + + def __init__(self, loop, pool): self._raw_stream_data = bytes() self._raw_stream_data_duration = 0 @@ -44,36 +49,42 @@ def __init__(self): self._speech_timestamps = [] self._last_chunk_received = False - self._vad_options = VadOptions( - min_speech_duration_ms=125, min_silence_duration_ms=125, speech_pad_ms=125 - ) + self._vad_model = load_silero_vad() + + self._min_speech_duration_ms = 125 + self._min_silence_duration_ms = 125 + self._speech_pad_ms = 125 + + self.loop = loop + self.pool = pool - def consume(self, stream_data: bytes): + async def consume(self, stream_data: bytes): self._last_chunk_received = len(stream_data) == 0 self._raw_stream_data += stream_data try: raw_audio_buffer = decode_audio(io.BytesIO(self._raw_stream_data)) - raw_audio_buffer = raw_audio_buffer[self._vad_detected_offset :] - except: + raw_audio_buffer = raw_audio_buffer[self._vad_detected_offset:] + except Exception as e: return self._raw_stream_data_duration = buf2secs(raw_audio_buffer) speech_timestamps = get_speech_timestamps( - raw_audio_buffer, vad_options=self._vad_options - ) + raw_audio_buffer, + self._vad_model, + min_speech_duration_ms=self._min_speech_duration_ms, + min_silence_duration_ms=self._min_silence_duration_ms, + speech_pad_ms=self._speech_pad_ms) if not speech_timestamps: return if not self._last_chunk_received: # remove the speech chunks which probably are not ended - while ( - speech_timestamps - and speech_timestamps[-1]["end"] - > len(raw_audio_buffer) - self._vad_options.min_silence_duration_ms * 16 - ): + while (speech_timestamps + and speech_timestamps[-1]["end"] > len(raw_audio_buffer) - + self._min_silence_duration_ms * 16): del speech_timestamps[-1] if not speech_timestamps: @@ -85,31 +96,22 @@ def consume(self, stream_data: bytes): if self._speech_audio_buffers: buffered = buf2secs(self._speech_audio_buffers[-1]) - print(speech_timestamps) - speech_timestamps = split_speech_timestamps( speech_timestamps, buffered, 5, ) - print(speech_timestamps) - for chunks in speech_timestamps: - speech = collect_chunks(raw_audio_buffer, chunks) - if ( - not self._speech_audio_buffers - or buf2secs(self._speech_audio_buffers[-1]) > 5 - ): + speech = collect_chunks( + chunks, torch.tensor(raw_audio_buffer, + dtype=torch.float32)).numpy() + if (not self._speech_audio_buffers + or buf2secs(self._speech_audio_buffers[-1]) > 5): self._speech_audio_buffers.append(speech) else: self._speech_audio_buffers[-1] = np.append( - self._speech_audio_buffers[-1], speech - ) - - [print(buf2secs(x)) for x in self._speech_audio_buffers] - - print(self._raw_stream_data_duration, len2secs(self._vad_detected_offset)) + self._speech_audio_buffers[-1], speech) def should_transcribe(self): if not self._speech_audio_buffers: diff --git a/src/stt_api.py b/src/stt_api.py index 9a6f1e0..34945dd 100644 --- a/src/stt_api.py +++ b/src/stt_api.py @@ -1,31 +1,42 @@ +import asyncio import json -import io from datetime import datetime -import bentoml -from runners.audio_transcriber import ( - BatchableAudioTranscriber, 
-    BatchInput,
-)
+from concurrent.futures import ProcessPoolExecutor
+
+import numpy as np
 from fastapi import FastAPI, Request, Depends
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.encoders import jsonable_encoder
-import utils.google_streaming.google_streaming_api_pb2 as speech
-from utils.service_key.brave_service_key import check_stt_request
+from sampuru import LocalRunner, BatchParameters
 
-import utils.ipc as ipc
-from stream_transcriber import StreamTranscriber
+from .utils.google_streaming import google_streaming_api_pb2 as speech
+from .utils.service_key.brave_service_key import check_stt_request
 
-runner_audio_transcriber = bentoml.Runner(
-    BatchableAudioTranscriber,
+from .ipc import (
+    client as ipc_client,
+    messages as ipc_messages,
+)
+from .stream_transcriber import StreamTranscriber
+
+from .runners.audio_transcriber import (
+    WhisperHFRunnable,
+)
+
+runner_audio_transcriber = LocalRunner(
+    WhisperHFRunnable,
     name="audio_transcriber",
-    max_batch_size=32,
+    runnable_init_params={
+        "model_id": "openai/whisper-base",
+    },
+    batch_params=BatchParameters(max_size=32),
 )
 
+vad_executor_pool = ProcessPoolExecutor()
 
-def TextToProtoMessage(text: ipc.messages.Text):
+
+def TextToProtoMessage(text: ipc_messages.Text):
     event = speech.SpeechRecognitionEvent()
     rr = speech.SpeechRecognitionResult()
     rr.stability = 1.0
@@ -34,7 +45,7 @@ def TextToProtoMessage(text: ipc.messages.Text):
     event.result.append(rr)
 
     proto = event.SerializeToString()
-    return len(proto).to_bytes(4, signed=False) + proto
+    return len(proto).to_bytes(4, "big", signed=False) + proto
 
 
 app = FastAPI()
@@ -50,7 +61,8 @@ async def handleUpstream(
     pair: str,
     request: Request,
     lang: str = "en",
-    is_valid_brave_key=Depends(check_stt_request),
+    # is_valid_brave_key=Depends(check_stt_request),
+    is_valid_brave_key=True,
 ):
     if not is_valid_brave_key:
         return JSONResponse(
@@ -60,42 +72,34 @@
     try:
         mic_data = bytes()
         text = ""
-        stream = StreamTranscriber()
-        async with ipc.client.Publisher(pair) as pipe:
+        stream = StreamTranscriber(asyncio.get_event_loop(), vad_executor_pool)
+        async with ipc_client.Publisher(pair) as pipe:
             try:
                 async for chunk in request.stream():
-                    stream.consume(chunk)
+                    await stream.consume(chunk)
                     while stream.should_transcribe():
                         process_time = datetime.now()
-                        transciption = await runner_audio_transcriber.async_run(
+                        transciption = await runner_audio_transcriber.forward(
                             [
-                                BatchInput(
-                                    audio=stream.get_speech_audio(),
-                                    lang=lang,
-                                    pair=pair,
-                                )
+                                {
+                                    "raw": np.frombuffer(
+                                        stream.get_speech_audio(), dtype=np.float32
+                                    ),
+                                    "sampling_rate": 16000,
+                                }
                             ]
                         )
                         process_time = (datetime.now() - process_time).total_seconds()
                         out = transciption[0]
-                        print(
-                            pair,
-                            " : ",
-                            out.batched_count,
-                            "",
-                            out.merge_audio_time,
-                            " ",
-                            out.transcribe_time,
-                            " ",
-                            out.restore_time,
-                        )
+                        # print( pair, " : ", out.batched_count, "", out.merge_audio_time, " ", out.transcribe_time, " ", out.restore_time,)
                         if out.text:
+                            print(out.text.lower())
                             text += out.text.lower() + " "
 
                         await pipe.push(
-                            ipc.messages.Text(
+                            ipc_messages.Text(
                                 text,
                                 False,
                                 len(mic_data),
@@ -108,7 +112,7 @@ async def handleUpstream(
 
             finally:
                 if text:
-                    await pipe.push(ipc.messages.Text(text, True))
+                    await pipe.push(ipc_messages.Text(text, True))
 
     except Exception as e:
         raise
@@ -121,7 +125,9 @@ async def handleUpstream(
 
 @app.get("/down")
 async def handleDownstream(
-    pair: str, output: str = "pb", is_valid_brave_key=Depends(check_stt_request)
+    pair: str,
+    output: str = "pb",  # is_valid_brave_key=Depends(check_stt_request)
+    is_valid_brave_key=True,
 ):
     if not is_valid_brave_key:
         return JSONResponse(
@@ -130,7 +136,7 @@ async def handleDownstream(
 
     async def handleStream(pair):
         try:
-            async with ipc.client.Subscriber(pair) as pipe:
+            async with ipc_client.Subscriber(pair) as pipe:
                 while True:
                     text = await pipe.pull()
                     if not text:
diff --git a/src/utils/ipc/__init__.py b/src/utils/ipc/__init__.py
deleted file mode 100644
index 54001d8..0000000
--- a/src/utils/ipc/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from . import client
-from . import messages
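
The replacement transcription path above amounts to feeding raw 16 kHz float32 PCM into a Hugging Face
automatic-speech-recognition pipeline and stripping punctuation from the decoded text. A minimal sketch of
that flow, assuming the same "openai/whisper-base" checkpoint used by the runner; the silent placeholder
buffer and local variable names are illustrative and not part of the patch:

    import string

    import numpy as np
    import torch
    from transformers import pipeline

    # Build the ASR pipeline roughly the way WhisperHFRunnable does above.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=device,
    )

    # stt_api.py builds one dict per utterance from the VAD-collected speech buffer;
    # a one-second silent buffer stands in for real audio here.
    batch = [{"raw": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}]
    results = pipe(batch, batch_size=len(batch))

    # forward() strips punctuation from each decoded segment before returning it.
    no_punctuation = str.maketrans("", "", string.punctuation)
    print([r["text"].translate(no_punctuation) for r in results])

The sampuru batching layer and the gunicorn/uvicorn entrypoint from the Dockerfile sit on top of this call;
the sketch only covers the model invocation itself.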