Skip to content

Commit

Permalink
initial working api
Browse files Browse the repository at this point in the history
  • Loading branch information
rlellep committed Oct 26, 2022
1 parent 666d3c2 commit ed09e9a
Show file tree
Hide file tree
Showing 13 changed files with 194 additions and 93 deletions.
4 changes: 2 additions & 2 deletions app/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .auth import get_username
from .enums import Language, State
from .schemas import JobInfo, Result, WorkerResponse, ErrorMessage
from .enums import Speaker, State
from .schemas import JobInfo, Result, ErrorMessage #, WorkerResponse
from .routers import router
12 changes: 10 additions & 2 deletions app/api/enums.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from enum import Enum
from socket import VM_SOCKETS_INVALID_VERSION


class Language(str, Enum):
ESTONIAN = "et"
class Speaker(str, Enum):
    # Supported TTS speaker voices; each value is the voice id sent to the
    # synthesis worker (lower-case ASCII form).
    Albert = "albert"
    Mari = "mari"
    Kalev = "kalev"
    Vesta = "vesta"
    Kylli = "kylli"
    # NOTE(review): same value as Kylli, so Enum treats this member as an
    # alias of Kylli — presumably intentional to accept the Estonian
    # spelling; confirm it should not be a distinct voice.
    Külli = "kylli"
    Meelis = "meelis"



class State(str, Enum):
Expand Down
107 changes: 63 additions & 44 deletions app/api/routers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from curses import meta
from email.policy import default
import os
import re
import json
from tokenize import String
import uuid
import logging

import aiofiles
import aiofiles.os
from app.api.enums import Speaker
from fastapi import APIRouter, Depends, File, UploadFile, Form, HTTPException, Header, Response
from starlette.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app import database, api_settings
from app.api import JobInfo, Result, WorkerResponse, Language, State, get_username, ErrorMessage
from app.api import JobInfo, State, get_username, ErrorMessage #, Result, WorkerResponse
from app.rabbitmq import publish

FILENAME_RE = re.compile(r"[\w\- ]+\.wav|blob")
FILENAME_RE = re.compile(r"[\w\- ]+\.epub|blob")

LOGGER = logging.getLogger(__name__)

Expand All @@ -33,90 +38,104 @@ def check_uuid(job_id: str):


@router.post('/', response_model=JobInfo, response_model_exclude_none=True,
             description="Submit a new ebook job.", status_code=202,
             responses={400: {"model": ErrorMessage}})
async def create_job(response: Response,
                     file: UploadFile = File(..., media_type="application/epub+zip"),
                     speaker: Speaker = Form(default=Speaker.Mari),
                     speed: float = Form(default=1.0),
                     session: AsyncSession = Depends(database.get_session)):
    """Accept an uploaded EPUB, register a synthesis job and queue it.

    Validates the upload's content type and filename, stores the file as
    ``<job_id>.epub`` under the configured storage path, creates the DB
    record and publishes the job id to RabbitMQ. Returns the new job info
    with HTTP 202; raises HTTP 400 on an invalid upload.
    """
    if file.content_type != "application/epub+zip":
        raise HTTPException(400, "Unsupported file type")

    if not FILENAME_RE.fullmatch(file.filename):
        LOGGER.debug(f"Bad filename: {file.filename}")
        # fix: "undescores" -> "underscores" in the client-facing message
        raise HTTPException(400, "Filename contains unsuitable characters "
                                 "(allowed: letters, numbers, spaces, underscores) "
                                 "or does not end with '.epub'")

    # fix: the module is imported as `import uuid`, so a bare `uuid4()` call
    # would raise NameError; qualify the call explicitly.
    job_id = str(uuid.uuid4())
    filename = file.filename

    async with aiofiles.open(os.path.join(api_settings.storage_path, f"{job_id}.epub"), 'wb') as out_file:
        content = await file.read()
        await out_file.write(content)

    # NOTE(review): this empty .txt placeholder looks like a leftover from the
    # ASR flow — cleanup only removes .zip/.json/.epub, so it is never deleted.
    # Confirm whether it is still needed.
    async with aiofiles.open(os.path.join(api_settings.storage_path, f"{job_id}.txt"), 'w') as out_file:
        await out_file.write('')

    job_info = await database.create_job(session, job_id, filename, speaker, speed)
    await publish(job_id, file_extension="epub")

    response.headers['Content-Disposition'] = 'attachment; filename="api.json"'
    return job_info


@router.get('/{job_id}', response_model=JobInfo, response_model_exclude_none=True,
            responses={404: {"model": ErrorMessage},
                       400: {"model": ErrorMessage}},
            dependencies=[Depends(check_uuid)])
async def get_job_info(job_id: str, session: AsyncSession = Depends(database.get_session)):
    """Return the current metadata and state of the given job."""
    job_info = await database.read_job(session, job_id)
    return job_info


@router.get('/{job_id}/audiobook', response_class=FileResponse,
            responses={404: {"model": ErrorMessage},
                       400: {"model": ErrorMessage},
                       # fix: description said "original audio file" — copy-paste
                       # from the ASR service; this endpoint serves the result zip.
                       200: {"content": {"application/zip": {}},
                             "description": "Returns the synthesized audiobook archive."}
                       },
            dependencies=[Depends(check_uuid)])
async def get_audiobook(job_id: str, session: AsyncSession = Depends(database.get_session)):
    """Download the finished audiobook zip for a completed job.

    Marks the job EXPIRED once the archive is handed out (it will then be
    cleaned up). Raises HTTP 404 when the job is not completed yet or the
    archive is missing.
    """
    job_info = await database.read_job(session, job_id)
    file_path = os.path.join(api_settings.storage_path, f"{job_id}.zip")
    if job_info.state == State.COMPLETED and os.path.exists(file_path):
        await database.update_job(session, job_id, State.EXPIRED)
        return FileResponse(file_path, filename=f"{job_id}.zip")
    # fix: previously fell through returning None, which FastAPI turns into a
    # 500; surface the declared 404 instead.
    raise HTTPException(404, f"Audiobook for job '{job_id}' is not available.")


@router.get('/{job_id}/epub', response_class=FileResponse,
            responses={
                404: {"model": ErrorMessage},
                # fix: description said "original audio file" — this endpoint
                # serves the uploaded EPUB, not audio.
                200: {"content": {"application/epub+zip": {}},
                      "description": "Returns the original uploaded EPUB file."}
            },
            dependencies=[Depends(check_uuid)])
async def get_epub(job_id: str, _: str = Depends(get_username),
                   session: AsyncSession = Depends(database.get_session)):
    """Worker-facing download of the source EPUB (basic-auth protected).

    A queued job is moved to IN_PROGRESS when the worker fetches the file.
    """
    job_info = await database.read_job(session, job_id)
    if job_info.state in [State.QUEUED, State.IN_PROGRESS]:
        await database.update_job(session, job_id, State.IN_PROGRESS)
    return FileResponse(os.path.join(api_settings.storage_path, f"{job_id}.epub"), filename=f"{job_id}.epub")


# NOTE(review): this handler shares the name `submit_audiobook` with the
# '/{job_id}/audiobook' handler below. FastAPI registers both routes at
# decoration time, so both endpoints work, but the later def shadows this
# name at module level — consider renaming this one to `fail_job`.
@router.post('/{job_id}/failed', response_model=JobInfo, response_model_exclude_none=True,
             description="Post error message and fail job.", status_code=202,
             # fix: 409 is raised below but was not declared in the schema
             responses={400: {"model": ErrorMessage},
                        409: {"model": ErrorMessage}},
             dependencies=[Depends(check_uuid)])
async def submit_audiobook(job_id: str,
                           error: str = Form(default="Failed to synthesize audiobook."),
                           _: str = Depends(get_username),
                           session: AsyncSession = Depends(database.get_session)):
    """Mark an in-progress job as failed, recording the worker's error message.

    Raises HTTP 409 when the job is not currently IN_PROGRESS.
    """
    job_info = await database.read_job(session, job_id)
    if job_info.state == State.IN_PROGRESS:
        await database.update_job(session, job_id, State.ERROR, error_message=error)
    else:  # HTTP 409 - conflict
        raise HTTPException(409, f"Job '{job_id}' is not in progress. Current state: {job_info.state}")
    return await database.read_job(session, job_id)


@router.post('/{job_id}/audiobook', response_model=JobInfo, response_model_exclude_none=True,
             description="Post audiobook and complete job.", status_code=202,
             # fix: 422 and 409 are raised below but were not declared
             responses={400: {"model": ErrorMessage},
                        409: {"model": ErrorMessage},
                        422: {"model": ErrorMessage}},
             dependencies=[Depends(check_uuid)])
async def submit_audiobook(job_id: str,
                           file: UploadFile = File(..., media_type="application/zip"),
                           _: str = Depends(get_username),
                           session: AsyncSession = Depends(database.get_session)):
    """Store the synthesized audiobook zip and mark the job COMPLETED.

    Worker-facing endpoint (basic-auth protected). Raises HTTP 422 for a
    non-zip upload and HTTP 409 when the job is not IN_PROGRESS.
    """
    if file.content_type != "application/zip":
        raise HTTPException(422, "Unsupported content type.")
    job_info = await database.read_job(session, job_id)
    if job_info.state == State.IN_PROGRESS:
        async with aiofiles.open(os.path.join(api_settings.storage_path, f"{job_id}.zip"), 'wb') as out_file:
            content = await file.read()
            await out_file.write(content)
        await database.update_job(session, job_id, State.COMPLETED)
    else:  # HTTP 409 - conflict
        raise HTTPException(409, f"Job '{job_id}' is not in progress. Current state: {job_info.state}")
    return await database.read_job(session, job_id)
48 changes: 25 additions & 23 deletions app/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from doctest import Example
import imp
from pydoc import describe
from typing import Optional
from datetime import datetime

from pydantic import BaseModel, Field

from app.api import Language, State
#from fastapi import UploadFile, Form, File

from app.api import State, Speaker


class ErrorMessage(BaseModel):
Expand All @@ -15,35 +20,32 @@ class JobInfo(BaseModel):
description="Randomly generated job UUID.",
example="08d99935-6ffd-4780-870a-d6f0cc863d77")
created_at: datetime = Field(...,
description="Job creation time.")
description="Job creation time.")
updated_at: datetime = Field(...,
description="Last state change time.")
language: Language = Field(Language.ESTONIAN,
description="Input language ISO 2-letter code.")
description="Last state change time.")
speaker: str = Field(...,
description="Speaker voice to synthesize with",
example=Speaker.Mari)
speed: float = Field(...,
description="Speed to synthesize with.",
example=Speaker.Mari)
file_name: str = Field(...,
description="Original name of the uploaded file",
example="audio.wav")
description="Original name of the uploaded file",
example="book.epub")
state: str = Field(...,
description="Job state.",
example=State.QUEUED)
description="Job state.",
example=State.QUEUED)
error_message: Optional[str] = Field(None,
description="Error message for failed job.",
example="Parsing error.")

class Config:
orm_mode = True


class Result(JobInfo):
    """Job info extended with the outcome of a finished job."""
    error_message: Optional[str] = Field(None,
                                         description="An optional human-readable error message.")
    # fix: default of None with a plain `str` annotation is inconsistent;
    # Optional[str] matches the actual default (backward-compatible widening).
    audiobook: Optional[str] = Field(None,
                                     description="Synthesized audiobook zip file name.",
                                     example="book.zip")
7 changes: 4 additions & 3 deletions app/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ipaddress import ip_address
import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -9,9 +10,9 @@
from app.cleanup import cleanup

app = FastAPI(
title="ASR Service",
version=api_settings.version,
description="A service that performs automatic speech recognition (ASR) on uploaded audio files."
title="epub-api",
version='2.1.0',#api_settings.version,
description="A service that performs text-to-speech on uploaded epub audio book."
)

app.add_middleware(
Expand Down
5 changes: 3 additions & 2 deletions app/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ async def _run():
LOGGER.debug(f"Removing files for: {expired_jobs}")

for job_id in expired_jobs:
await aiofiles.os.remove(os.path.join(api_settings.storage_path, f"{job_id}.txt"))
await aiofiles.os.remove(os.path.join(api_settings.storage_path, f"{job_id}.wav"))
await aiofiles.os.remove(os.path.join(api_settings.storage_path, f"{job_id}.zip"))
await aiofiles.os.remove(os.path.join(api_settings.storage_path, f"{job_id}.json"))
await aiofiles.os.remove(os.path.join(api_settings.storage_path, f"{job_id}.epub"))

LOGGER.info("Cleanup finished.")

Expand Down
16 changes: 8 additions & 8 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@


class APISettings(BaseSettings):
version: str
username: str = 'user'
password: str = 'pass'
cleanup_interval: int = 60 # 1 minute - run db & file cleanup
version: str = '2.1.0'
username: str = 'guest'
password: str = 'guest'
cleanup_interval: int = 600 # 1 minute - run db & file cleanup
expiration_threshold: int = 1200 # 20 minutes - expire / cancel jobs without updates
removal_threshold: int = 86400 # 24 h - remove db records after expiration / cancellation

Expand All @@ -20,7 +20,7 @@ class MQSettings(BaseSettings):
port: int = 5672
username: str = 'guest'
password: str = 'guest'
exchange: str = 'speech-to-text'
exchange: str = 'epub_to_audiobook'
timeout: int = 1200 # 20 minutes

class Config:
Expand All @@ -30,9 +30,9 @@ class Config:
class DBSettings(BaseSettings):
host: str = 'localhost'
port: int = 3306
username: str = 'user'
password: str = 'pass'
database: str = 'speech_to_text'
username: str = 'guest'
password: str = 'guest'
database: str = 'epub_to_audiobook'

class Config:
env_prefix = 'mysql_'
Expand Down
5 changes: 3 additions & 2 deletions app/database/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
LOGGER = logging.getLogger(__name__)


async def create_job(session: AsyncSession, job_id: str, file_name: str, language: str) -> schemas.JobInfo:
async def create_job(session: AsyncSession, job_id: str, file_name: str, speaker: str, speed: float) -> schemas.JobInfo:
job_info = Job(
job_id=job_id,
file_name=file_name,
language=language,
speaker=speaker,
speed=speed,
state=State.QUEUED
)

Expand Down
6 changes: 4 additions & 2 deletions app/database/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy import Column, String, DateTime, Enum
from email.policy import default
from sqlalchemy import Column, String, Float, DateTime, Enum
from sqlalchemy.sql import func, text

from app.database import Base
Expand All @@ -11,8 +12,9 @@ class Job(Base):
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
updated_at = Column(DateTime(timezone=True), nullable=False,
server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))
language = Column(String(8), default="et", nullable=False)
file_name = Column(String(255), nullable=False)
speaker = Column(String(255), default="mari", nullable=False)
speed = Column(Float(), default=1.0, nullable=False)
state = Column(Enum(State), nullable=False)
error_message = Column(String(255), default=None, nullable=True)

Expand Down
8 changes: 4 additions & 4 deletions app/rabbitmq/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
LOGGER = logging.getLogger(__name__)


async def publish(correlation_id: str, file_extension: str, language: str):
async def publish(correlation_id: str, file_extension: str):
body = json.dumps({"correlation_id": correlation_id,
"file_extension": file_extension}).encode()
message = Message(
Expand All @@ -20,10 +20,10 @@ async def publish(correlation_id: str, file_extension: str, language: str):
)

try:
await mq_session.exchange.publish(message, routing_key=f"{mq_settings.exchange}.{language}")
await mq_session.exchange.publish(message, routing_key=f"{mq_settings.exchange}")
except Exception as e:
LOGGER.exception(e)
LOGGER.info("Attempting to restore the channel.")
await mq_session.channel.reopen()
await mq_session.exchange.publish(message, routing_key=f"{mq_settings.exchange}.{language}")
LOGGER.info(f"Sent request: {{id: {correlation_id}, routing_key: {mq_settings.exchange}.{language}}}")
await mq_session.exchange.publish(message, routing_key=f"{mq_settings.exchange}")
LOGGER.info(f"Sent request: {{id: {correlation_id}, routing_key: {mq_settings.exchange}}}")
Loading

0 comments on commit ed09e9a

Please sign in to comment.