From 33b5cbcb9216869c69e1dae229b9dd66b767ff9e Mon Sep 17 00:00:00 2001 From: Robin Ole Heinemann Date: Thu, 28 Dec 2023 19:36:28 +0100 Subject: [PATCH] add endpoint to generate webvtt export --- backend/openapi-schema.yml | 176 ++++++++++++++++++ backend/pdm.lock | 25 ++- backend/pyproject.toml | 1 + backend/transcribee_backend/config.py | 3 + backend/transcribee_backend/db/__init__.py | 10 + backend/transcribee_backend/models/task.py | 11 +- .../transcribee_backend/routers/document.py | 53 +++++- .../util/redis_task_channel.py | 21 +++ frontend/src/openapi-schema.ts | 119 +++++++++++- packaging/Procfile | 1 + proto/transcribee_proto/api.py | 22 ++- shell.nix | 4 + .../webvtt/export_webvtt.py | 140 ++++++++++++++ .../webvtt/webvtt_writer.py | 171 +++++++++++++++++ worker/transcribee_worker/worker.py | 37 ++++ 15 files changed, 787 insertions(+), 7 deletions(-) create mode 100644 backend/transcribee_backend/util/redis_task_channel.py create mode 100644 worker/transcribee_worker/webvtt/export_webvtt.py create mode 100644 worker/transcribee_worker/webvtt/webvtt_writer.py diff --git a/backend/openapi-schema.yml b/backend/openapi-schema.yml index 7730d02c..14687465 100644 --- a/backend/openapi-schema.yml +++ b/backend/openapi-schema.yml @@ -327,6 +327,69 @@ components: - has_full_access title: DocumentWithAccessInfo type: object + ExportError: + properties: + error: + title: Error + type: string + required: + - error + title: ExportError + type: object + ExportFormat: + description: An enumeration. + enum: + - VTT + - SRT + title: ExportFormat + type: string + ExportResult: + properties: + result: + title: Result + type: string + required: + - result + title: ExportResult + type: object + ExportTask: + properties: + document_id: + format: uuid + title: Document Id + type: string + task_parameters: + $ref: '#/components/schemas/ExportTaskParameters' + task_type: + default: EXPORT + enum: + - EXPORT + title: Task Type + type: string + required: + - task_parameters + - document_id + title: ExportTask + type: object + ExportTaskParameters: + properties: + format: + $ref: '#/components/schemas/ExportFormat' + include_speaker_names: + title: Include Speaker Names + type: boolean + include_word_timing: + title: Include Word Timing + type: boolean + max_line_length: + title: Max Line Length + type: integer + required: + - format + - include_speaker_names + - include_word_timing + title: ExportTaskParameters + type: object HTTPValidationError: properties: detail: @@ -529,6 +592,7 @@ components: - TRANSCRIBE - ALIGN - REENCODE + - EXPORT title: TaskType type: string TranscribeTask: @@ -887,6 +951,57 @@ paths: $ref: '#/components/schemas/HTTPValidationError' description: Validation Error summary: Update Document + /api/v1/documents/{document_id}/add_export_result/: + post: + operationId: add_export_result_api_v1_documents__document_id__add_export_result__post + parameters: + - in: path + name: document_id + required: true + schema: + format: uuid + title: Document Id + type: string + - in: query + name: task_id + required: true + schema: + title: Task Id + type: string + - in: header + name: authorization + required: false + schema: + title: Authorization + type: string + - in: header + name: Share-Token + required: false + schema: + title: Share-Token + type: string + requestBody: + content: + application/json: + schema: + anyOf: + - $ref: '#/components/schemas/ExportResult' + - $ref: '#/components/schemas/ExportError' + title: Result + required: true + responses: + '200': + content: + application/json: + schema: {} + description: Successful Response + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Add Export Result /api/v1/documents/{document_id}/add_media_file/: post: operationId: add_media_file_api_v1_documents__document_id__add_media_file__post @@ -924,6 +1039,66 @@ paths: $ref: '#/components/schemas/HTTPValidationError' description: Validation Error summary: Add Media File + /api/v1/documents/{document_id}/export/: + get: + operationId: export_api_v1_documents__document_id__export__get + parameters: + - in: path + name: document_id + required: true + schema: + format: uuid + title: Document Id + type: string + - in: query + name: format + required: true + schema: + $ref: '#/components/schemas/ExportFormat' + - in: query + name: include_speaker_names + required: true + schema: + title: Include Speaker Names + type: boolean + - in: query + name: include_word_timing + required: true + schema: + title: Include Word Timing + type: boolean + - in: query + name: max_line_length + required: false + schema: + title: Max Line Length + type: integer + - in: header + name: authorization + required: false + schema: + title: Authorization + type: string + - in: header + name: Share-Token + required: false + schema: + title: Share-Token + type: string + responses: + '200': + content: + text/plain: + schema: + type: string + description: Successful Response + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Export /api/v1/documents/{document_id}/set_duration/: post: operationId: set_duration_api_v1_documents__document_id__set_duration__post @@ -1211,6 +1386,7 @@ paths: - $ref: '#/components/schemas/SpeakerIdentificationTask' - $ref: '#/components/schemas/TranscribeTask' - $ref: '#/components/schemas/AlignTask' + - $ref: '#/components/schemas/ExportTask' - $ref: '#/components/schemas/UnknownTask' title: Task required: true diff --git a/backend/pdm.lock b/backend/pdm.lock index 0baa4114..bb43cb65 100644 --- a/backend/pdm.lock +++ b/backend/pdm.lock @@ -73,6 +73,12 @@ dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.11\"", ] +[[package]] +name = "async-timeout" +version = "4.0.3" +requires_python = ">=3.7" +summary = "Timeout context manager for asyncio programs" + [[package]] name = "attrs" version = "23.1.0" @@ -987,6 +993,15 @@ dependencies = [ "packaging", ] +[[package]] +name = "redis" +version = "5.0.1" +requires_python = ">=3.7" +summary = "Python client for Redis database and key-value store" +dependencies = [ + "async-timeout>=4.0.2; python_full_version <= \"3.11.2\"", +] + [[package]] name = "referencing" version = "0.31.0" @@ -1284,7 +1299,7 @@ summary = "Jupyter interactive widgets for Jupyter Notebook" [metadata] lock_version = "4.1" -content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9cc1a892f" +content_hash = "sha256:15954986dfbaa2cc507675379441f8b6682f69561b69cc87cb8dcd97489c8303" [metadata.files] "alembic 1.12.1" = [ @@ -1338,6 +1353,10 @@ content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9c {url = "https://files.pythonhosted.org/packages/80/e2/2b4651eff771f6fd900d233e175ddc5e2be502c7eb62c0c42f975c6d36cd/async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"}, {url = "https://files.pythonhosted.org/packages/fa/9f/3c3503693386c4b0f245eaf5ca6198e3b28879ca0a40bde6b0e319793453/async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"}, ] +"async-timeout 4.0.3" = [ + {url = "https://files.pythonhosted.org/packages/87/d6/21b30a550dafea84b1b8eee21b5e23fa16d010ae006011221f33dcd8d7f8/async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] "attrs 23.1.0" = [ {url = "https://files.pythonhosted.org/packages/97/90/81f95d5f705be17872843536b1868f351805acf6971251ff07c1b8334dbb/attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, {url = "https://files.pythonhosted.org/packages/f0/eb/fcb708c7bf5056045e9e98f62b93bd7467eb718b0202e7698eb11d66416c/attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, @@ -2577,6 +2596,10 @@ content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9c {url = "https://files.pythonhosted.org/packages/7e/a9/2146d5117ad8a81185331e0809a6b48933c10171f5bac253c6df9fce991c/QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, {url = "https://files.pythonhosted.org/packages/eb/9a/7ce646daefb2f85bf5b9c8ac461508b58fa5dcad6d40db476187fafd0148/QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, ] +"redis 5.0.1" = [ + {url = "https://files.pythonhosted.org/packages/0b/34/a01250ac1fc9bf9161e07956d2d580413106ce02d5591470130a25c599e3/redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {url = "https://files.pythonhosted.org/packages/4a/4c/3c3b766f4ecbb3f0bec91ef342ee98d179e040c25b6ecc99e510c2570f2a/redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] "referencing 0.31.0" = [ {url = "https://files.pythonhosted.org/packages/29/c1/69342fbc8efd1aac5cda853cea771763b95d92325c4f8f83b499c07bc698/referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"}, {url = "https://files.pythonhosted.org/packages/61/11/5e947c3f2a73e7fb77fd1c3370aa04e107f3c10ceef4880c2e25ef19679c/referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"}, diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b8805275..93884c98 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -10,6 +10,7 @@ authors = [ ] dependencies = [ + "redis>=5.0.1", "fastapi>=0.92.0", "uvicorn[standard]>=0.20.0", "sqlmodel @ git+https://github.com/transcribee/sqlmodel.git@transcribee_main", diff --git a/backend/transcribee_backend/config.py b/backend/transcribee_backend/config.py index 1c38a480..7bc482b8 100644 --- a/backend/transcribee_backend/config.py +++ b/backend/transcribee_backend/config.py @@ -23,6 +23,9 @@ class Settings(BaseSettings): metrics_username = "transcribee" metrics_password = "transcribee" + redis_host = "localhost" + redis_port = 6379 + class ModelConfig(BaseModel): id: str diff --git a/backend/transcribee_backend/db/__init__.py b/backend/transcribee_backend/db/__init__.py index 49b6f4fa..1fab73d7 100644 --- a/backend/transcribee_backend/db/__init__.py +++ b/backend/transcribee_backend/db/__init__.py @@ -6,10 +6,14 @@ from fastapi import Request from prometheus_client import Histogram from prometheus_fastapi_instrumentator import routing +from redis.asyncio import Redis from sqlalchemy import event from sqlmodel import Session, create_engine from starlette.websockets import WebSocket +from transcribee_backend.config import settings +from transcribee_backend.util.redis_task_channel import RedisTaskChannel + DEFAULT_SOCKET_PATH = Path(__file__).parent.parent.parent / "db" / "sockets" DATABASE_URL = os.environ.get( @@ -22,6 +26,8 @@ pool_size=32, max_overflow=1024, # we keep open a database connection for every worker ) +redis = Redis(host=settings.redis_host, port=settings.redis_port) +redis_task_channel = RedisTaskChannel(redis) query_histogram = Histogram( "sql_queries", @@ -31,6 +37,10 @@ ) +def get_redis_task_channel(): + return redis_task_channel + + def get_session(request: Request): handler = routing.get_route_name(request) with Session(engine) as session, query_counter(session, path=handler): diff --git a/backend/transcribee_backend/models/task.py b/backend/transcribee_backend/models/task.py index 6cc0a1ee..1ce3af0e 100644 --- a/backend/transcribee_backend/models/task.py +++ b/backend/transcribee_backend/models/task.py @@ -7,7 +7,7 @@ from sqlmodel import JSON, Column, Field, ForeignKey, Relationship, SQLModel, col from sqlmodel.sql.sqltypes import GUID from transcribee_proto.api import Document as ApiDocument -from transcribee_proto.api import TaskType +from transcribee_proto.api import ExportTaskParameters, TaskType from typing_extensions import Self from transcribee_backend.config import settings @@ -291,9 +291,16 @@ class AlignTask(TaskBase): task_parameters: Dict[str, Any] +class ExportTask(TaskBase): + task_type: Literal[TaskType.EXPORT] = TaskType.EXPORT + task_parameters: ExportTaskParameters + + class UnknownTask(TaskBase): task_type: str task_parameters: Dict[str, Any] -CreateTask = SpeakerIdentificationTask | TranscribeTask | AlignTask | UnknownTask +CreateTask = ( + SpeakerIdentificationTask | TranscribeTask | AlignTask | ExportTask | UnknownTask +) diff --git a/backend/transcribee_backend/routers/document.py b/backend/transcribee_backend/routers/document.py index 15033f07..0abf65d6 100644 --- a/backend/transcribee_backend/routers/document.py +++ b/backend/transcribee_backend/routers/document.py @@ -1,5 +1,6 @@ import datetime import enum +import json import pathlib import uuid from dataclasses import dataclass @@ -21,13 +22,15 @@ status, ) from fastapi.exceptions import RequestValidationError -from pydantic import BaseModel +from fastapi.responses import PlainTextResponse +from pydantic import BaseModel, parse_obj_as from pydantic.error_wrappers import ErrorWrapper from sqlalchemy.orm import selectinload from sqlalchemy.sql.expression import desc from sqlmodel import Session, col, select from transcribee_proto.api import Document as ApiDocument from transcribee_proto.api import DocumentWithAccessInfo as ApiDocumentWithAccessInfo +from transcribee_proto.api import ExportTaskParameters from transcribee_backend.auth import ( generate_share_token, @@ -38,6 +41,7 @@ ) from transcribee_backend.config import get_model_config, settings from transcribee_backend.db import ( + get_redis_task_channel, get_session, get_session_ws, ) @@ -48,6 +52,7 @@ DocumentShareTokenBase, ) from transcribee_backend.models.task import TaskAttempt, TaskResponse +from transcribee_backend.util.redis_task_channel import RedisTaskChannel from .. import media_storage from ..models import ( @@ -601,3 +606,49 @@ def delete_share_tokens( session.delete(token) session.commit() return + + +class ExportResult(BaseModel): + result: str + + +class ExportError(BaseModel): + error: str + + +ExportRes = ExportResult | ExportError + + +@document_router.get("/{document_id}/export/", response_class=PlainTextResponse) +async def export( + export_parameters: ExportTaskParameters = Depends(), + auth: AuthInfo = Depends(get_doc_min_readonly_auth), + redis_task_channel: RedisTaskChannel = Depends(get_redis_task_channel), + session: Session = Depends(get_session), +): + export_task = Task( + task_type=TaskType.EXPORT, + task_parameters=export_parameters.dict(), + document_id=auth.document.id, + ) + session.add(export_task) + session.commit() + + result = parse_obj_as( + ExportRes, + json.loads(await redis_task_channel.wait_for_result(str(export_task.id))), + ) + if isinstance(result, ExportError): + raise Exception(result.error) + else: + return result.result + + +@document_router.post("/{document_id}/add_export_result/") +async def add_export_result( + result: ExportRes, + task_id: str, + auth: AuthInfo = Depends(get_doc_worker_auth), + redis_task_channel: RedisTaskChannel = Depends(get_redis_task_channel), +) -> None: + await redis_task_channel.put_result(task_id, result.json()) diff --git a/backend/transcribee_backend/util/redis_task_channel.py b/backend/transcribee_backend/util/redis_task_channel.py new file mode 100644 index 00000000..9882165c --- /dev/null +++ b/backend/transcribee_backend/util/redis_task_channel.py @@ -0,0 +1,21 @@ +from redis.asyncio import Redis + + +class RedisTaskChannel: + redis: Redis + prefix: str + + def __init__(self, redis, prefix="task-channel:"): + self.redis = redis + self.prefix = prefix + + async def put_result(self, id: str, value: str): + # https://github.com/redis/redis-py/issues/2897 + return await self.redis.rpush(self._redis_key(id), value) # type: ignore + + async def wait_for_result(self, id) -> str: + # https://github.com/redis/redis-py/issues/2897 + return (await self.redis.blpop(self._redis_key(id)))[1] # type: ignore + + def _redis_key(self, id): + return self.prefix + id diff --git a/frontend/src/openapi-schema.ts b/frontend/src/openapi-schema.ts index 80c750fc..cb3cb59e 100644 --- a/frontend/src/openapi-schema.ts +++ b/frontend/src/openapi-schema.ts @@ -31,10 +31,18 @@ export interface paths { /** Update Document */ patch: operations["update_document_api_v1_documents__document_id___patch"]; }; + "/api/v1/documents/{document_id}/add_export_result/": { + /** Add Export Result */ + post: operations["add_export_result_api_v1_documents__document_id__add_export_result__post"]; + }; "/api/v1/documents/{document_id}/add_media_file/": { /** Add Media File */ post: operations["add_media_file_api_v1_documents__document_id__add_media_file__post"]; }; + "/api/v1/documents/{document_id}/export/": { + /** Export */ + get: operations["export_api_v1_documents__document_id__export__get"]; + }; "/api/v1/documents/{document_id}/set_duration/": { /** Set Duration */ post: operations["set_duration_api_v1_documents__document_id__set_duration__post"]; @@ -331,6 +339,47 @@ export interface components { /** Name */ name: string; }; + /** ExportError */ + ExportError: { + /** Error */ + error: string; + }; + /** + * ExportFormat + * @description An enumeration. + * @enum {string} + */ + ExportFormat: "VTT" | "SRT"; + /** ExportResult */ + ExportResult: { + /** Result */ + result: string; + }; + /** ExportTask */ + ExportTask: { + /** + * Document Id + * Format: uuid + */ + document_id: string; + task_parameters: components["schemas"]["ExportTaskParameters"]; + /** + * Task Type + * @default EXPORT + * @enum {string} + */ + task_type?: "EXPORT"; + }; + /** ExportTaskParameters */ + ExportTaskParameters: { + format: components["schemas"]["ExportFormat"]; + /** Include Speaker Names */ + include_speaker_names: boolean; + /** Include Word Timing */ + include_word_timing: boolean; + /** Max Line Length */ + max_line_length?: number; + }; /** HTTPValidationError */ HTTPValidationError: { /** Detail */ @@ -454,7 +503,7 @@ export interface components { * @description An enumeration. * @enum {string} */ - TaskType: "IDENTIFY_SPEAKERS" | "TRANSCRIBE" | "ALIGN" | "REENCODE"; + TaskType: "IDENTIFY_SPEAKERS" | "TRANSCRIBE" | "ALIGN" | "REENCODE" | "EXPORT"; /** TranscribeTask */ TranscribeTask: { /** @@ -740,6 +789,40 @@ export interface operations { }; }; }; + /** Add Export Result */ + add_export_result_api_v1_documents__document_id__add_export_result__post: { + parameters: { + query: { + task_id: string; + }; + header?: { + authorization?: string; + "Share-Token"?: string; + }; + path: { + document_id: string; + }; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ExportResult"] | components["schemas"]["ExportError"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": unknown; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** Add Media File */ add_media_file_api_v1_documents__document_id__add_media_file__post: { parameters: { @@ -770,6 +853,38 @@ export interface operations { }; }; }; + /** Export */ + export_api_v1_documents__document_id__export__get: { + parameters: { + query: { + format: components["schemas"]["ExportFormat"]; + include_speaker_names: boolean; + include_word_timing: boolean; + max_line_length?: number; + }; + header?: { + authorization?: string; + "Share-Token"?: string; + }; + path: { + document_id: string; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "text/plain": string; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** Set Duration */ set_duration_api_v1_documents__document_id__set_duration__post: { parameters: { @@ -980,7 +1095,7 @@ export interface operations { }; requestBody: { content: { - "application/json": components["schemas"]["SpeakerIdentificationTask"] | components["schemas"]["TranscribeTask"] | components["schemas"]["AlignTask"] | components["schemas"]["UnknownTask"]; + "application/json": components["schemas"]["SpeakerIdentificationTask"] | components["schemas"]["TranscribeTask"] | components["schemas"]["AlignTask"] | components["schemas"]["ExportTask"] | components["schemas"]["UnknownTask"]; }; }; responses: { diff --git a/packaging/Procfile b/packaging/Procfile index 1d7d5190..f1ef3817 100644 --- a/packaging/Procfile +++ b/packaging/Procfile @@ -2,3 +2,4 @@ backend: ./start_backend.sh worker: pdm run -p ../worker/ start --coordinator http://127.0.0.1:8000 --token dev_worker --reload frontend: pnpm --prefix ../frontend/ dev --clearScreen false db: ./start_db.sh +redis: redis-server diff --git a/proto/transcribee_proto/api.py b/proto/transcribee_proto/api.py index 75e6ec1d..515114bf 100644 --- a/proto/transcribee_proto/api.py +++ b/proto/transcribee_proto/api.py @@ -11,6 +11,7 @@ class TaskType(str, enum.Enum): TRANSCRIBE = "TRANSCRIBE" ALIGN = "ALIGN" REENCODE = "REENCODE" + EXPORT = "EXPORT" class DocumentMedia(BaseModel): @@ -52,6 +53,18 @@ class TranscribeTaskParameters(BaseModel): model: str +class ExportFormat(str, enum.Enum): + VTT = "VTT" + SRT = "SRT" + + +class ExportTaskParameters(BaseModel): + format: ExportFormat + include_speaker_names: bool + include_word_timing: bool + max_line_length: int | None + + class TranscribeTask(TaskBase): task_type: Literal[TaskType.TRANSCRIBE] = TaskType.TRANSCRIBE task_parameters: TranscribeTaskParameters @@ -67,7 +80,14 @@ class ReencodeTask(TaskBase): task_parameters: Dict[str, Any] -AssignedTask = SpeakerIdentificationTask | TranscribeTask | AlignTask | ReencodeTask +class ExportTask(TaskBase): + task_type: Literal[TaskType.EXPORT] = TaskType.EXPORT + task_parameters: ExportTaskParameters + + +AssignedTask = ( + SpeakerIdentificationTask | TranscribeTask | AlignTask | ReencodeTask | ExportTask +) class LoginResponse(BaseModel): diff --git a/shell.nix b/shell.nix index fd11cefe..28b836c0 100644 --- a/shell.nix +++ b/shell.nix @@ -61,6 +61,10 @@ pkgs.mkShell { # Our database postgresql + + # Our database2 ? + redis + openssl # needed for psycopg2 ] ++ diff --git a/worker/transcribee_worker/webvtt/export_webvtt.py b/worker/transcribee_worker/webvtt/export_webvtt.py new file mode 100644 index 00000000..c51f9a89 --- /dev/null +++ b/worker/transcribee_worker/webvtt/export_webvtt.py @@ -0,0 +1,140 @@ +import logging + +from transcribee_proto.document import Atom, Document, Paragraph + +from .webvtt_writer import VttCue, WebVtt, escape_vtt_string, formatted_time + + +def get_speaker_name( + speaker: str | None, + speaker_names: dict[str, str], +) -> str: + if speaker is None: + return "Unknown Speaker" + else: + try: + return speaker_names[speaker] + except KeyError: + return f"Unnamed Speaker {speaker}" + + +def atom_to_string(item: Atom, include_word_timings: bool): + if include_word_timings and isinstance(item.start, float): + return ( + f"<{formatted_time(item.start)}>{escape_vtt_string(str(item.text))}" + ) + else: + return escape_vtt_string(str(item.text)) + + +def can_generate_vtt(paras: list[Paragraph] | None): + if paras is None: + return (False, "No document content") + + for para in paras: + for atom in para.children: + if not isinstance(atom.end, float) or not isinstance(atom.start, float): + return (False, "Missing timings for at least one atom") + + return (True, "") + + +def paragraph_to_cues( + paragraph: Paragraph, + include_word_timings: bool, + include_speaker_names: bool, + max_line_length: int | None, + speaker_names, +): + cues = [] + cue_payload = "" + cue_length = 0 + cue_start = None + cue_end = None + + def push_payload(payload): + nonlocal cue_start, cue_end + if include_speaker_names and paragraph.speaker: + payload = ( + f"" + + payload + ) + + assert cue_start is not None + assert cue_end is not None + + if cue_start >= cue_end: + logging.debug( + f"found {cue_start=} that is not before {cue_end=}" + ", fixing the end to be behind cue_start" + ) + cue_end = cue_start + 0.02 + + cues.append( + VttCue( + start_time=cue_start, + end_time=cue_end, + payload=payload, + payload_escaped=True, + ) + ) + + for atom in paragraph.children: + atom_text = str(atom.text) + if ( + max_line_length is not None + and cue_start is not None + and cue_end is not None + and cue_length + len(atom_text) > max_line_length + ): + push_payload(cue_payload) + + cue_payload = "" + cue_length = 0 + cue_start = None + cue_end = None + + if atom.start and (cue_start is None or atom.start < cue_start): + cue_start = atom.start + + if atom.end and (cue_end is None or atom.end > cue_end): + cue_end = atom.end + + cue_payload += atom_to_string(atom, include_word_timings) + cue_length += len(atom_text) + + if len(cue_payload) > 0: + if cue_start is None or cue_end is None: + raise ValueError( + "Paragraph contains no timings, cannot generate cue(s)." + " Make sure to only call this function if `canGenerateVtt` returns true", + ) + push_payload(cue_payload) + + return cues + + +def generate_web_vtt( + doc: Document, + include_speaker_names: bool, + include_word_timing: bool, + max_line_length: int | None, +) -> WebVtt: + vtt = WebVtt( + "This file was generated using transcribee." + " Find out more at https://github.com/bugbakery/transcribee" + ) + for par in doc.children: + if len(par.children) == 0: + continue + + for cue in paragraph_to_cues( + par, + include_word_timing, + include_speaker_names, + max_line_length, + doc.speaker_names, + ): + vtt.add(cue) + + return vtt diff --git a/worker/transcribee_worker/webvtt/webvtt_writer.py b/worker/transcribee_worker/webvtt/webvtt_writer.py new file mode 100644 index 00000000..fff33a6c --- /dev/null +++ b/worker/transcribee_worker/webvtt/webvtt_writer.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +import dataclasses +import enum +import re +from abc import ABC, abstractmethod + + +class SubtitleFormat(str, enum.Enum): + VTT = "vtt" + SRT = "srt" + + +class VttElement(ABC): + @abstractmethod + def to_string(self, format: SubtitleFormat) -> str: + ... + + +class Vertical(str, enum.Enum): + RL = "rl" + LR = "lr" + + +class Align(str, enum.Enum): + START = "start" + CENTER = "center" + END = "end" + + +@dataclasses.dataclass +class VttCueSettings(VttElement): + vertical: None | Vertical + line: None | int | str + position: None | str + size: None | str + align: None | Align + + def to_string(self, format: SubtitleFormat): + if format == SubtitleFormat.SRT: + return "" + + def format_elem(name, elem): + if elem is None: + return [] + else: + return [f"{name}:{elem}"] + + return " ".join( + format_elem("vertical", self.vertical) + + format_elem("line", self.line) + + format_elem("position", self.position) + + format_elem("size", self.size) + + format_elem("align", self.align) + ) + + +def formatted_time(secs: float): + subseconds = int((secs % 1.0) * 1000) + seconds = int(secs % 60) + minutes = int((secs / 60) % 60) + hours = int((secs / 60 / 60) % 60) + + return f"{hours:02}:{minutes:02}:{seconds:02}.{subseconds:03}" + + +class VttCue(VttElement): + identifier: str | None + start_time: float + end_time: float + settings: VttCueSettings | None + payload: str + + def __init__( + self, + start_time: float, + end_time: float, + payload: str, + payload_escaped: bool | None = None, + identifier: str | None = None, + identifier_escaped: bool | None = None, + settings: VttCueSettings | None = None, + ): + if start_time >= end_time: + raise ValueError("Cue end time must be greater than cue start time") + + self.start_time = start_time + self.end_time = end_time + + if not payload_escaped: + payload = escape_vtt_string(payload) + + self.payload = payload + + if identifier and not identifier_escaped: + identifier = escape_vtt_string(identifier) + + if identifier and "\n" in identifier: + raise ValueError("WebVTT cue identifiers MUST NOT contain a newline") + + if identifier and "-->" in identifier: + raise ValueError("WebVTT cue identifiers MUST NOT contain -->") + + self.identifier = identifier + self.settings = settings + + def to_string(self, format: SubtitleFormat): + ident = f"{self.identifier}\n" if self.identifier else "" + time = f"{formatted_time(self.start_time)} --> {formatted_time(self.end_time)}" + settings = f"{self.settings.to_string(format)}".strip() if self.settings else "" + + return ident + time + settings + "\n" + self.payload + + +class VttComment(VttElement): + comment_text: str + + def __init__(self, text: str, escaped: bool = False): + if not escaped: + text = escape_vtt_string(text) + if "-->" in text: + raise ValueError("WebVTT comments MUST NOT contain -->") + + self.comment_text = text + + def to_string(self, format: SubtitleFormat): + if format != SubtitleFormat.VTT: + return "" + return f"NOTE {self.comment_text}" + + +def escape_vtt_string(text: str) -> str: + escape_map = {"&": "&", "<": "<", ">": ">"} + re.sub(r"[&<>]", lambda obj: escape_map[obj.group(0)], text) + return text + + +class VttHeader(VttElement): + header_text: str + + def __init__(self, text: str, escaped: bool = False): + if not escaped: + text = escape_vtt_string(text) + + if "-->" in text: + raise ValueError("WebVTT text header MUST NOT contain -->") + + if "\n" in text: + raise ValueError("WebVTT text header MUST NOT contain newlines") + + self.header_text = text + + def to_string(self, format: SubtitleFormat): + if format != SubtitleFormat.VTT: + return "" + + return f"WEBVTT {self.header_text}" + + +class WebVtt: + elements: list[VttElement] + + def __init__(self, header=""): + self.elements = [VttHeader(header)] + + def add(self, element: VttElement): + self.elements.append(element) + + def to_string(self, format: SubtitleFormat = SubtitleFormat.VTT): + as_strings = [elem.to_string(format) for elem in self.elements] + return "\n\n".join([elem for elem in as_strings if len(elem) > 0]) + "\n" diff --git a/worker/transcribee_worker/worker.py b/worker/transcribee_worker/worker.py index 4572b658..66b144c5 100644 --- a/worker/transcribee_worker/worker.py +++ b/worker/transcribee_worker/worker.py @@ -15,6 +15,8 @@ from transcribee_proto.api import ( AlignTask, AssignedTask, + ExportFormat, + ExportTask, ReencodeTask, SpeakerIdentificationTask, TaskType, @@ -29,6 +31,8 @@ from transcribee_worker.torchaudio_align import align from transcribee_worker.types import ProgressCallbackType from transcribee_worker.util import aenumerate, load_audio +from transcribee_worker.webvtt.export_webvtt import generate_web_vtt +from transcribee_worker.webvtt.webvtt_writer import SubtitleFormat from transcribee_worker.whisper_transcribe import ( transcribe_clean_async, ) @@ -116,6 +120,7 @@ def __init__( TaskType.ALIGN, TaskType.TRANSCRIBE, TaskType.REENCODE, + TaskType.EXPORT, ] def claim_task(self) -> Optional[AssignedTask]: @@ -185,6 +190,8 @@ def progress_callback(*, progress, step: Optional[str] = "", extra_data=None): await self.align(task, progress_callback) elif task.task_type == TaskType.REENCODE: await self.reencode(task, progress_callback) + elif task.task_type == TaskType.EXPORT: + await self.export(task, progress_callback) else: raise ValueError(f"Invalid task type: '{task.task_type}'") @@ -308,6 +315,36 @@ async def reencode( None, self.add_document_media_file, task, output_path, tags ) + async def export(self, task: ExportTask, progress_callback: ProgressCallbackType): + async with self.api_client.document(task.document.id) as doc: + params = task.task_parameters + res = None + try: + vtt = generate_web_vtt( + EditorDocument.parse_obj(automerge.dump(doc.doc)), + params.include_speaker_names, + params.include_word_timing, + params.max_line_length, + ) + + if params.format == ExportFormat.VTT: + res = vtt.to_string(SubtitleFormat.VTT) + elif params.format == ExportFormat.SRT: + res = vtt.to_string(SubtitleFormat.SRT) + + res = {"result": res} + except ValueError as e: + res = {"error": str(e)} + + if res is not None: + logging.info( + f"Uploading document export for {task.document.id=} {res=}" + ) + self.api_client.post( + f"documents/{task.document.id}/add_export_result/?task_id={task.id}", + json=res, + ) + def set_duration(self, task: AssignedTask, duration: float): logging.debug( f"Setting audio duration for document {task.document.id=} {duration=}"