Skip to content

Commit

Permalink
Formatting. Folders.
Browse files Browse the repository at this point in the history
  • Loading branch information
boocmp committed Jun 7, 2024
1 parent 1b5a219 commit a0b45a4
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 135 deletions.
2 changes: 1 addition & 1 deletion env/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa
RUN set -eux && \
apt-get update -y && \
apt-get install -q -y --no-install-recommends --allow-remove-essential \
ca-certificates gnupg2 bash build-essential
ca-certificates gnupg2 bash build-essential

RUN \
set -eux && \
Expand Down
1 change: 1 addition & 0 deletions src/ipc_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .ipc_server import *
57 changes: 21 additions & 36 deletions src/utils/ipc/server.py → src/ipc_server/ipc_server.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import asyncio

if __name__ == '__main__':
import messages
else:
from . import messages
from utils.ipc import messages

Publishers: dict[str, asyncio.StreamReader] = {}
Subscribers: dict[str, bool] = {}

PublisherAppear = asyncio.Condition()
SubscriberAppear = asyncio.Condition()

async def AddPublisher(
pair: str,
reader: asyncio.StreamReader
):

async def AddPublisher(pair: str, reader: asyncio.StreamReader):
async with PublisherAppear:
if pair in Publishers:
Publishers[pair] = None
Expand All @@ -23,6 +17,7 @@ async def AddPublisher(

PublisherAppear.notify_all()


async def RemovePublisher(pair: str):
async with PublisherAppear:
if pair in Publishers:
Expand All @@ -34,44 +29,39 @@ async def AddSubscriber(pair: str):
Subscribers[pair] = pair not in Subscribers
SubscriberAppear.notify_all()


async def RemoveSubscriber(pair: str):
async with SubscriberAppear:
if pair in Subscribers:
del Subscribers[pair]


async def wait_for_publisher(
pair: str,
timeout: float = 10.0
):
async def wait_for_publisher(pair: str, timeout: float = 10.0):
async def waiter():
async with PublisherAppear:
await PublisherAppear.wait_for(lambda: pair in Publishers)
return Publishers[pair]

return await asyncio.wait_for(waiter(), timeout)


async def wait_for_subscriber(pair: str, timeout: float = 10.0):
async def waiter():
async with SubscriberAppear:
await SubscriberAppear.wait_for(lambda: pair in Subscribers)
return Subscribers[pair]

return await asyncio.wait_for(waiter(), timeout)


async def pipe(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter
):
async def pipe(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
while not reader.at_eof():
writer.write(await reader.read(1024))
await writer.drain()


async def handle_publish(
pair: str,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter
pair: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
await AddPublisher(pair, reader)
try:
Expand All @@ -81,12 +71,9 @@ async def handle_publish(
except Exception as e:
await RemovePublisher(pair)
writer.close()


async def handle_subscribe(
pair: str,
writer: asyncio.StreamWriter
):

async def handle_subscribe(pair: str, writer: asyncio.StreamWriter):
await AddSubscriber(pair)

try:
Expand All @@ -101,10 +88,7 @@ async def handle_subscribe(
writer.close()


async def handle_connection(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter
):
async def handle_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
try:
req = await messages.receive_request(reader)
match req:
Expand All @@ -117,15 +101,16 @@ async def handle_connection(
except Exception as e:
writer.close()

HOST, PORT = "localhost", 3015

async def run_ipc_server():
server = await asyncio.start_server(handle_connection, HOST, PORT)
async def run_ipc_server(host, port):
server = await asyncio.start_server(handle_connection, host, port)
async with server:
await server.serve_forever()

def start_ipc_server():
asyncio.run(run_ipc_server())

if __name__ == '__main__':
start_ipc_server()
def start_ipc_server(host="localhost", port=3015):
asyncio.run(run_ipc_server(host, port))


if __name__ == "__main__":
start_ipc_server()
13 changes: 8 additions & 5 deletions src/runners/audio_transcriber.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import bentoml
import ctranslate2
from faster_whisper import WhisperModel
import io
import numpy as np


class AudioTranscriber(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
device = "cuda" if ctranslate2.get_cuda_device_count() > 0 else "cpu"
compute_type = "int8_float16" if ctranslate2.get_cuda_device_count() > 0 else "int8"
compute_type = (
"int8_float16" if ctranslate2.get_cuda_device_count() > 0 else "int8"
)

print(device, " ", compute_type)

Expand All @@ -19,10 +20,12 @@ def __init__(self):

@bentoml.Runnable.method(batchable=False)
def transcribe_audio(self, audio):
segments, info = self.model.transcribe(audio, vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500))
segments, info = self.model.transcribe(
audio, vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500)
)

text = ""
for segment in segments:
text += segment.text

return { "text" : text }
return {"text": text}
5 changes: 3 additions & 2 deletions src/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

from stt_api import app, runner_audio_transcriber

from utils.ipc import server
from ipc_server import server

svc = bentoml.Service(
"stt",
runners=[ runner_audio_transcriber ],
runners=[runner_audio_transcriber],
)

svc.mount_asgi_app(app)


@svc.on_deployment
def on_deployment():
if not os.fork():
Expand Down
30 changes: 17 additions & 13 deletions src/stt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io

import bentoml
from runners.audio_transcriber import AudioTranscriber
from runners.audio_transcriber import AudioTranscriber

from fastapi import FastAPI, Request, Depends, Cookie
from fastapi.responses import StreamingResponse, JSONResponse, Response
Expand Down Expand Up @@ -34,6 +34,7 @@ def to_bytes(self):
message = self._event.SerializeToString()
return len(message).to_bytes(4, signed=False) + message


app = FastAPI()


Expand All @@ -44,12 +45,12 @@ async def handleSticky():

@app.post("/up")
async def handleUpstream(
pair: str,
request: Request,
is_valid_brave_key = Depends(check_stt_request)
pair: str, request: Request, is_valid_brave_key=Depends(check_stt_request)
):
if not is_valid_brave_key:
return JSONResponse(content = jsonable_encoder({ "status" : "Invalid Brave Service Key" }))
return JSONResponse(
content=jsonable_encoder({"status": "Invalid Brave Service Key"})
)

try:
mic_data = bytes()
Expand All @@ -62,18 +63,21 @@ async def handleUpstream(
await pipe.push(ipc.messages.Text(text["text"], False))

except Exception as e:
return JSONResponse(content = jsonable_encoder({ "status" : "exception", "exception" : str(e) }) )
return JSONResponse(
content=jsonable_encoder({"status": "exception", "exception": str(e)})
)

return JSONResponse(content=jsonable_encoder({"status": "ok"}))

return JSONResponse(content = jsonable_encoder({ "status" : "ok" }))

@app.get("/down")
async def handleDownstream(
pair: str,
output: str = "pb",
is_valid_brave_key = Depends(check_stt_request)
pair: str, output: str = "pb", is_valid_brave_key=Depends(check_stt_request)
):
if not is_valid_brave_key:
return JSONResponse(content = jsonable_encoder({ "status" : "Invalid Brave Service Key" }))
return JSONResponse(
content=jsonable_encoder({"status": "Invalid Brave Service Key"})
)

async def handleStream(pair):
try:
Expand All @@ -88,8 +92,8 @@ async def handleStream(pair):

yield event.to_bytes()
else:
yield json.dumps({ "text" : r.text })
yield json.dumps({"text": r.text})
except Exception as e:
yield json.dumps({ "exception" : str(e)})
yield json.dumps({"exception": str(e)})

return StreamingResponse(handleStream(pair))
55 changes: 55 additions & 0 deletions src/tests/ipc_test/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio

from utils.ipc.client import Publisher, Subscriber
from utils.ipc import messages
from ipc_server import run_ipc_server


async def publisher(pair):
async def op():
async with Publisher(pair) as pipe:
for i in range(0, 30):
await pipe.push(messages.Text(f"{pair} -> {i}", False))
await asyncio.sleep(1)

try:
await asyncio.wait_for(op(), 3)
except Exception as e:
print(e)
pass


async def subscriber(pair):
try:
async with Subscriber(pair) as pipe:
while True:
r = await pipe.pull()
if r is None:
break
print(r)

except asyncio.IncompleteReadError:
pass
except Exception as e:
print(e)
pass


async def batch(pair):
try:
await asyncio.gather(subscriber(pair), publisher(pair))
except Exception as e:
print(e)
pass


async def main():
tasks = [ asyncio.create_task(run_ipc_server("localhost", 3015))]
for i in range(20):
tasks.append(asyncio.create_task(batch(str(i))))

for t in tasks:
await t


asyncio.run(main())
Empty file added src/utils/__init__.py
Empty file.
6 changes: 4 additions & 2 deletions src/utils/config/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pydantic import Field
from pydantic_settings import BaseSettings


class AppSettings(BaseSettings):
master_services_key_seed: str = Field('dummy')
master_services_key_seed: str = Field("dummy")


app_settings = AppSettings()
app_settings = AppSettings()
3 changes: 0 additions & 3 deletions src/utils/ipc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from . import messages
from . import client
from . import server
Loading

0 comments on commit a0b45a4

Please sign in to comment.