diff --git a/backend/openapi-schema.yml b/backend/openapi-schema.yml
index 7730d02c..14687465 100644
--- a/backend/openapi-schema.yml
+++ b/backend/openapi-schema.yml
@@ -327,6 +327,69 @@ components:
- has_full_access
title: DocumentWithAccessInfo
type: object
+ ExportError:
+ properties:
+ error:
+ title: Error
+ type: string
+ required:
+ - error
+ title: ExportError
+ type: object
+ ExportFormat:
+ description: An enumeration.
+ enum:
+ - VTT
+ - SRT
+ title: ExportFormat
+ type: string
+ ExportResult:
+ properties:
+ result:
+ title: Result
+ type: string
+ required:
+ - result
+ title: ExportResult
+ type: object
+ ExportTask:
+ properties:
+ document_id:
+ format: uuid
+ title: Document Id
+ type: string
+ task_parameters:
+ $ref: '#/components/schemas/ExportTaskParameters'
+ task_type:
+ default: EXPORT
+ enum:
+ - EXPORT
+ title: Task Type
+ type: string
+ required:
+ - task_parameters
+ - document_id
+ title: ExportTask
+ type: object
+ ExportTaskParameters:
+ properties:
+ format:
+ $ref: '#/components/schemas/ExportFormat'
+ include_speaker_names:
+ title: Include Speaker Names
+ type: boolean
+ include_word_timing:
+ title: Include Word Timing
+ type: boolean
+ max_line_length:
+ title: Max Line Length
+ type: integer
+ required:
+ - format
+ - include_speaker_names
+ - include_word_timing
+ title: ExportTaskParameters
+ type: object
HTTPValidationError:
properties:
detail:
@@ -529,6 +592,7 @@ components:
- TRANSCRIBE
- ALIGN
- REENCODE
+ - EXPORT
title: TaskType
type: string
TranscribeTask:
@@ -887,6 +951,57 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
description: Validation Error
summary: Update Document
+ /api/v1/documents/{document_id}/add_export_result/:
+ post:
+ operationId: add_export_result_api_v1_documents__document_id__add_export_result__post
+ parameters:
+ - in: path
+ name: document_id
+ required: true
+ schema:
+ format: uuid
+ title: Document Id
+ type: string
+ - in: query
+ name: task_id
+ required: true
+ schema:
+ title: Task Id
+ type: string
+ - in: header
+ name: authorization
+ required: false
+ schema:
+ title: Authorization
+ type: string
+ - in: header
+ name: Share-Token
+ required: false
+ schema:
+ title: Share-Token
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ anyOf:
+ - $ref: '#/components/schemas/ExportResult'
+ - $ref: '#/components/schemas/ExportError'
+ title: Result
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema: {}
+ description: Successful Response
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: Add Export Result
/api/v1/documents/{document_id}/add_media_file/:
post:
operationId: add_media_file_api_v1_documents__document_id__add_media_file__post
@@ -924,6 +1039,66 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
description: Validation Error
summary: Add Media File
+ /api/v1/documents/{document_id}/export/:
+ get:
+ operationId: export_api_v1_documents__document_id__export__get
+ parameters:
+ - in: path
+ name: document_id
+ required: true
+ schema:
+ format: uuid
+ title: Document Id
+ type: string
+ - in: query
+ name: format
+ required: true
+ schema:
+ $ref: '#/components/schemas/ExportFormat'
+ - in: query
+ name: include_speaker_names
+ required: true
+ schema:
+ title: Include Speaker Names
+ type: boolean
+ - in: query
+ name: include_word_timing
+ required: true
+ schema:
+ title: Include Word Timing
+ type: boolean
+ - in: query
+ name: max_line_length
+ required: false
+ schema:
+ title: Max Line Length
+ type: integer
+ - in: header
+ name: authorization
+ required: false
+ schema:
+ title: Authorization
+ type: string
+ - in: header
+ name: Share-Token
+ required: false
+ schema:
+ title: Share-Token
+ type: string
+ responses:
+ '200':
+ content:
+ text/plain:
+ schema:
+ type: string
+ description: Successful Response
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: Export
/api/v1/documents/{document_id}/set_duration/:
post:
operationId: set_duration_api_v1_documents__document_id__set_duration__post
@@ -1211,6 +1386,7 @@ paths:
- $ref: '#/components/schemas/SpeakerIdentificationTask'
- $ref: '#/components/schemas/TranscribeTask'
- $ref: '#/components/schemas/AlignTask'
+ - $ref: '#/components/schemas/ExportTask'
- $ref: '#/components/schemas/UnknownTask'
title: Task
required: true
diff --git a/backend/pdm.lock b/backend/pdm.lock
index 0baa4114..bb43cb65 100644
--- a/backend/pdm.lock
+++ b/backend/pdm.lock
@@ -73,6 +73,12 @@ dependencies = [
"typing-extensions>=4.0.0; python_version < \"3.11\"",
]
+[[package]]
+name = "async-timeout"
+version = "4.0.3"
+requires_python = ">=3.7"
+summary = "Timeout context manager for asyncio programs"
+
[[package]]
name = "attrs"
version = "23.1.0"
@@ -987,6 +993,15 @@ dependencies = [
"packaging",
]
+[[package]]
+name = "redis"
+version = "5.0.1"
+requires_python = ">=3.7"
+summary = "Python client for Redis database and key-value store"
+dependencies = [
+ "async-timeout>=4.0.2; python_full_version <= \"3.11.2\"",
+]
+
[[package]]
name = "referencing"
version = "0.31.0"
@@ -1284,7 +1299,7 @@ summary = "Jupyter interactive widgets for Jupyter Notebook"
[metadata]
lock_version = "4.1"
-content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9cc1a892f"
+content_hash = "sha256:15954986dfbaa2cc507675379441f8b6682f69561b69cc87cb8dcd97489c8303"
[metadata.files]
"alembic 1.12.1" = [
@@ -1338,6 +1353,10 @@ content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9c
{url = "https://files.pythonhosted.org/packages/80/e2/2b4651eff771f6fd900d233e175ddc5e2be502c7eb62c0c42f975c6d36cd/async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"},
{url = "https://files.pythonhosted.org/packages/fa/9f/3c3503693386c4b0f245eaf5ca6198e3b28879ca0a40bde6b0e319793453/async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"},
]
+"async-timeout 4.0.3" = [
+ {url = "https://files.pythonhosted.org/packages/87/d6/21b30a550dafea84b1b8eee21b5e23fa16d010ae006011221f33dcd8d7f8/async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
+ {url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
+]
"attrs 23.1.0" = [
{url = "https://files.pythonhosted.org/packages/97/90/81f95d5f705be17872843536b1868f351805acf6971251ff07c1b8334dbb/attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"},
{url = "https://files.pythonhosted.org/packages/f0/eb/fcb708c7bf5056045e9e98f62b93bd7467eb718b0202e7698eb11d66416c/attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"},
@@ -2577,6 +2596,10 @@ content_hash = "sha256:b9412fb0704ebdac6028e09e8e76d0e12d606505a34f92fd9a701fc9c
{url = "https://files.pythonhosted.org/packages/7e/a9/2146d5117ad8a81185331e0809a6b48933c10171f5bac253c6df9fce991c/QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"},
{url = "https://files.pythonhosted.org/packages/eb/9a/7ce646daefb2f85bf5b9c8ac461508b58fa5dcad6d40db476187fafd0148/QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"},
]
+"redis 5.0.1" = [
+ {url = "https://files.pythonhosted.org/packages/0b/34/a01250ac1fc9bf9161e07956d2d580413106ce02d5591470130a25c599e3/redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
+ {url = "https://files.pythonhosted.org/packages/4a/4c/3c3b766f4ecbb3f0bec91ef342ee98d179e040c25b6ecc99e510c2570f2a/redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
+]
"referencing 0.31.0" = [
{url = "https://files.pythonhosted.org/packages/29/c1/69342fbc8efd1aac5cda853cea771763b95d92325c4f8f83b499c07bc698/referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"},
{url = "https://files.pythonhosted.org/packages/61/11/5e947c3f2a73e7fb77fd1c3370aa04e107f3c10ceef4880c2e25ef19679c/referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"},
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index b8805275..93884c98 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -10,6 +10,7 @@ authors = [
]
dependencies = [
+ "redis>=5.0.1",
"fastapi>=0.92.0",
"uvicorn[standard]>=0.20.0",
"sqlmodel @ git+https://github.com/transcribee/sqlmodel.git@transcribee_main",
diff --git a/backend/transcribee_backend/config.py b/backend/transcribee_backend/config.py
index 1c38a480..7bc482b8 100644
--- a/backend/transcribee_backend/config.py
+++ b/backend/transcribee_backend/config.py
@@ -23,6 +23,9 @@ class Settings(BaseSettings):
metrics_username = "transcribee"
metrics_password = "transcribee"
+ redis_host = "localhost"
+ redis_port = 6379
+
class ModelConfig(BaseModel):
id: str
diff --git a/backend/transcribee_backend/db/__init__.py b/backend/transcribee_backend/db/__init__.py
index 49b6f4fa..1fab73d7 100644
--- a/backend/transcribee_backend/db/__init__.py
+++ b/backend/transcribee_backend/db/__init__.py
@@ -6,10 +6,14 @@
from fastapi import Request
from prometheus_client import Histogram
from prometheus_fastapi_instrumentator import routing
+from redis.asyncio import Redis
from sqlalchemy import event
from sqlmodel import Session, create_engine
from starlette.websockets import WebSocket
+from transcribee_backend.config import settings
+from transcribee_backend.util.redis_task_channel import RedisTaskChannel
+
DEFAULT_SOCKET_PATH = Path(__file__).parent.parent.parent / "db" / "sockets"
DATABASE_URL = os.environ.get(
@@ -22,6 +26,8 @@
pool_size=32,
max_overflow=1024, # we keep open a database connection for every worker
)
+redis = Redis(host=settings.redis_host, port=settings.redis_port)
+redis_task_channel = RedisTaskChannel(redis)
query_histogram = Histogram(
"sql_queries",
@@ -31,6 +37,10 @@
)
+def get_redis_task_channel():
+ return redis_task_channel
+
+
def get_session(request: Request):
handler = routing.get_route_name(request)
with Session(engine) as session, query_counter(session, path=handler):
diff --git a/backend/transcribee_backend/models/task.py b/backend/transcribee_backend/models/task.py
index 6cc0a1ee..1ce3af0e 100644
--- a/backend/transcribee_backend/models/task.py
+++ b/backend/transcribee_backend/models/task.py
@@ -7,7 +7,7 @@
from sqlmodel import JSON, Column, Field, ForeignKey, Relationship, SQLModel, col
from sqlmodel.sql.sqltypes import GUID
from transcribee_proto.api import Document as ApiDocument
-from transcribee_proto.api import TaskType
+from transcribee_proto.api import ExportTaskParameters, TaskType
from typing_extensions import Self
from transcribee_backend.config import settings
@@ -291,9 +291,16 @@ class AlignTask(TaskBase):
task_parameters: Dict[str, Any]
+class ExportTask(TaskBase):
+ task_type: Literal[TaskType.EXPORT] = TaskType.EXPORT
+ task_parameters: ExportTaskParameters
+
+
class UnknownTask(TaskBase):
task_type: str
task_parameters: Dict[str, Any]
-CreateTask = SpeakerIdentificationTask | TranscribeTask | AlignTask | UnknownTask
+CreateTask = (
+ SpeakerIdentificationTask | TranscribeTask | AlignTask | ExportTask | UnknownTask
+)
diff --git a/backend/transcribee_backend/routers/document.py b/backend/transcribee_backend/routers/document.py
index 15033f07..0abf65d6 100644
--- a/backend/transcribee_backend/routers/document.py
+++ b/backend/transcribee_backend/routers/document.py
@@ -1,5 +1,6 @@
import datetime
import enum
+import json
import pathlib
import uuid
from dataclasses import dataclass
@@ -21,13 +22,15 @@
status,
)
from fastapi.exceptions import RequestValidationError
-from pydantic import BaseModel
+from fastapi.responses import PlainTextResponse
+from pydantic import BaseModel, parse_obj_as
from pydantic.error_wrappers import ErrorWrapper
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import desc
from sqlmodel import Session, col, select
from transcribee_proto.api import Document as ApiDocument
from transcribee_proto.api import DocumentWithAccessInfo as ApiDocumentWithAccessInfo
+from transcribee_proto.api import ExportTaskParameters
from transcribee_backend.auth import (
generate_share_token,
@@ -38,6 +41,7 @@
)
from transcribee_backend.config import get_model_config, settings
from transcribee_backend.db import (
+ get_redis_task_channel,
get_session,
get_session_ws,
)
@@ -48,6 +52,7 @@
DocumentShareTokenBase,
)
from transcribee_backend.models.task import TaskAttempt, TaskResponse
+from transcribee_backend.util.redis_task_channel import RedisTaskChannel
from .. import media_storage
from ..models import (
@@ -601,3 +606,49 @@ def delete_share_tokens(
session.delete(token)
session.commit()
return
+
+
+class ExportResult(BaseModel):
+ result: str
+
+
+class ExportError(BaseModel):
+ error: str
+
+
+ExportRes = ExportResult | ExportError
+
+
+@document_router.get("/{document_id}/export/", response_class=PlainTextResponse)
+async def export(
+ export_parameters: ExportTaskParameters = Depends(),
+ auth: AuthInfo = Depends(get_doc_min_readonly_auth),
+ redis_task_channel: RedisTaskChannel = Depends(get_redis_task_channel),
+ session: Session = Depends(get_session),
+):
+ export_task = Task(
+ task_type=TaskType.EXPORT,
+ task_parameters=export_parameters.dict(),
+ document_id=auth.document.id,
+ )
+ session.add(export_task)
+ session.commit()
+
+ result = parse_obj_as(
+ ExportRes,
+ json.loads(await redis_task_channel.wait_for_result(str(export_task.id))),
+ )
+ if isinstance(result, ExportError):
+ raise Exception(result.error)
+ else:
+ return result.result
+
+
+@document_router.post("/{document_id}/add_export_result/")
+async def add_export_result(
+ result: ExportRes,
+ task_id: str,
+ auth: AuthInfo = Depends(get_doc_worker_auth),
+ redis_task_channel: RedisTaskChannel = Depends(get_redis_task_channel),
+) -> None:
+ await redis_task_channel.put_result(task_id, result.json())
diff --git a/backend/transcribee_backend/util/redis_task_channel.py b/backend/transcribee_backend/util/redis_task_channel.py
new file mode 100644
index 00000000..9882165c
--- /dev/null
+++ b/backend/transcribee_backend/util/redis_task_channel.py
@@ -0,0 +1,21 @@
+from redis.asyncio import Redis
+
+
+class RedisTaskChannel:
+ redis: Redis
+ prefix: str
+
+ def __init__(self, redis, prefix="task-channel:"):
+ self.redis = redis
+ self.prefix = prefix
+
+ async def put_result(self, id: str, value: str):
+ # https://github.com/redis/redis-py/issues/2897
+ return await self.redis.rpush(self._redis_key(id), value) # type: ignore
+
+ async def wait_for_result(self, id) -> str:
+ # https://github.com/redis/redis-py/issues/2897
+ return (await self.redis.blpop(self._redis_key(id)))[1] # type: ignore
+
+ def _redis_key(self, id):
+ return self.prefix + id
diff --git a/frontend/src/openapi-schema.ts b/frontend/src/openapi-schema.ts
index 80c750fc..cb3cb59e 100644
--- a/frontend/src/openapi-schema.ts
+++ b/frontend/src/openapi-schema.ts
@@ -31,10 +31,18 @@ export interface paths {
/** Update Document */
patch: operations["update_document_api_v1_documents__document_id___patch"];
};
+ "/api/v1/documents/{document_id}/add_export_result/": {
+ /** Add Export Result */
+ post: operations["add_export_result_api_v1_documents__document_id__add_export_result__post"];
+ };
"/api/v1/documents/{document_id}/add_media_file/": {
/** Add Media File */
post: operations["add_media_file_api_v1_documents__document_id__add_media_file__post"];
};
+ "/api/v1/documents/{document_id}/export/": {
+ /** Export */
+ get: operations["export_api_v1_documents__document_id__export__get"];
+ };
"/api/v1/documents/{document_id}/set_duration/": {
/** Set Duration */
post: operations["set_duration_api_v1_documents__document_id__set_duration__post"];
@@ -331,6 +339,47 @@ export interface components {
/** Name */
name: string;
};
+ /** ExportError */
+ ExportError: {
+ /** Error */
+ error: string;
+ };
+ /**
+ * ExportFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ ExportFormat: "VTT" | "SRT";
+ /** ExportResult */
+ ExportResult: {
+ /** Result */
+ result: string;
+ };
+ /** ExportTask */
+ ExportTask: {
+ /**
+ * Document Id
+ * Format: uuid
+ */
+ document_id: string;
+ task_parameters: components["schemas"]["ExportTaskParameters"];
+ /**
+ * Task Type
+ * @default EXPORT
+ * @enum {string}
+ */
+ task_type?: "EXPORT";
+ };
+ /** ExportTaskParameters */
+ ExportTaskParameters: {
+ format: components["schemas"]["ExportFormat"];
+ /** Include Speaker Names */
+ include_speaker_names: boolean;
+ /** Include Word Timing */
+ include_word_timing: boolean;
+ /** Max Line Length */
+ max_line_length?: number;
+ };
/** HTTPValidationError */
HTTPValidationError: {
/** Detail */
@@ -454,7 +503,7 @@ export interface components {
* @description An enumeration.
* @enum {string}
*/
- TaskType: "IDENTIFY_SPEAKERS" | "TRANSCRIBE" | "ALIGN" | "REENCODE";
+ TaskType: "IDENTIFY_SPEAKERS" | "TRANSCRIBE" | "ALIGN" | "REENCODE" | "EXPORT";
/** TranscribeTask */
TranscribeTask: {
/**
@@ -740,6 +789,40 @@ export interface operations {
};
};
};
+ /** Add Export Result */
+ add_export_result_api_v1_documents__document_id__add_export_result__post: {
+ parameters: {
+ query: {
+ task_id: string;
+ };
+ header?: {
+ authorization?: string;
+ "Share-Token"?: string;
+ };
+ path: {
+ document_id: string;
+ };
+ };
+ requestBody: {
+ content: {
+ "application/json": components["schemas"]["ExportResult"] | components["schemas"]["ExportError"];
+ };
+ };
+ responses: {
+ /** @description Successful Response */
+ 200: {
+ content: {
+ "application/json": unknown;
+ };
+ };
+ /** @description Validation Error */
+ 422: {
+ content: {
+ "application/json": components["schemas"]["HTTPValidationError"];
+ };
+ };
+ };
+ };
/** Add Media File */
add_media_file_api_v1_documents__document_id__add_media_file__post: {
parameters: {
@@ -770,6 +853,38 @@ export interface operations {
};
};
};
+ /** Export */
+ export_api_v1_documents__document_id__export__get: {
+ parameters: {
+ query: {
+ format: components["schemas"]["ExportFormat"];
+ include_speaker_names: boolean;
+ include_word_timing: boolean;
+ max_line_length?: number;
+ };
+ header?: {
+ authorization?: string;
+ "Share-Token"?: string;
+ };
+ path: {
+ document_id: string;
+ };
+ };
+ responses: {
+ /** @description Successful Response */
+ 200: {
+ content: {
+ "text/plain": string;
+ };
+ };
+ /** @description Validation Error */
+ 422: {
+ content: {
+ "application/json": components["schemas"]["HTTPValidationError"];
+ };
+ };
+ };
+ };
/** Set Duration */
set_duration_api_v1_documents__document_id__set_duration__post: {
parameters: {
@@ -980,7 +1095,7 @@ export interface operations {
};
requestBody: {
content: {
- "application/json": components["schemas"]["SpeakerIdentificationTask"] | components["schemas"]["TranscribeTask"] | components["schemas"]["AlignTask"] | components["schemas"]["UnknownTask"];
+ "application/json": components["schemas"]["SpeakerIdentificationTask"] | components["schemas"]["TranscribeTask"] | components["schemas"]["AlignTask"] | components["schemas"]["ExportTask"] | components["schemas"]["UnknownTask"];
};
};
responses: {
diff --git a/packaging/Procfile b/packaging/Procfile
index 1d7d5190..f1ef3817 100644
--- a/packaging/Procfile
+++ b/packaging/Procfile
@@ -2,3 +2,4 @@ backend: ./start_backend.sh
worker: pdm run -p ../worker/ start --coordinator http://127.0.0.1:8000 --token dev_worker --reload
frontend: pnpm --prefix ../frontend/ dev --clearScreen false
db: ./start_db.sh
+redis: redis-server
diff --git a/proto/transcribee_proto/api.py b/proto/transcribee_proto/api.py
index 75e6ec1d..515114bf 100644
--- a/proto/transcribee_proto/api.py
+++ b/proto/transcribee_proto/api.py
@@ -11,6 +11,7 @@ class TaskType(str, enum.Enum):
TRANSCRIBE = "TRANSCRIBE"
ALIGN = "ALIGN"
REENCODE = "REENCODE"
+ EXPORT = "EXPORT"
class DocumentMedia(BaseModel):
@@ -52,6 +53,18 @@ class TranscribeTaskParameters(BaseModel):
model: str
+class ExportFormat(str, enum.Enum):
+ VTT = "VTT"
+ SRT = "SRT"
+
+
+class ExportTaskParameters(BaseModel):
+ format: ExportFormat
+ include_speaker_names: bool
+ include_word_timing: bool
+ max_line_length: int | None
+
+
class TranscribeTask(TaskBase):
task_type: Literal[TaskType.TRANSCRIBE] = TaskType.TRANSCRIBE
task_parameters: TranscribeTaskParameters
@@ -67,7 +80,14 @@ class ReencodeTask(TaskBase):
task_parameters: Dict[str, Any]
-AssignedTask = SpeakerIdentificationTask | TranscribeTask | AlignTask | ReencodeTask
+class ExportTask(TaskBase):
+ task_type: Literal[TaskType.EXPORT] = TaskType.EXPORT
+ task_parameters: ExportTaskParameters
+
+
+AssignedTask = (
+ SpeakerIdentificationTask | TranscribeTask | AlignTask | ReencodeTask | ExportTask
+)
class LoginResponse(BaseModel):
diff --git a/shell.nix b/shell.nix
index fd11cefe..28b836c0 100644
--- a/shell.nix
+++ b/shell.nix
@@ -61,6 +61,10 @@ pkgs.mkShell {
# Our database
postgresql
+
+    # Redis, used by the backend to relay export task results (see RedisTaskChannel)
+ redis
+
openssl # needed for psycopg2
] ++
diff --git a/worker/transcribee_worker/webvtt/export_webvtt.py b/worker/transcribee_worker/webvtt/export_webvtt.py
new file mode 100644
index 00000000..c51f9a89
--- /dev/null
+++ b/worker/transcribee_worker/webvtt/export_webvtt.py
@@ -0,0 +1,140 @@
+import logging
+
+from transcribee_proto.document import Atom, Document, Paragraph
+
+from .webvtt_writer import VttCue, WebVtt, escape_vtt_string, formatted_time
+
+
+def get_speaker_name(
+ speaker: str | None,
+ speaker_names: dict[str, str],
+) -> str:
+ if speaker is None:
+ return "Unknown Speaker"
+ else:
+ try:
+ return speaker_names[speaker]
+ except KeyError:
+ return f"Unnamed Speaker {speaker}"
+
+
+def atom_to_string(item: Atom, include_word_timings: bool):
+ if include_word_timings and isinstance(item.start, float):
+ return (
+ f"<{formatted_time(item.start)}>{escape_vtt_string(str(item.text))}"
+ )
+ else:
+ return escape_vtt_string(str(item.text))
+
+
+def can_generate_vtt(paras: list[Paragraph] | None):
+ if paras is None:
+ return (False, "No document content")
+
+ for para in paras:
+ for atom in para.children:
+ if not isinstance(atom.end, float) or not isinstance(atom.start, float):
+ return (False, "Missing timings for at least one atom")
+
+ return (True, "")
+
+
+def paragraph_to_cues(
+ paragraph: Paragraph,
+ include_word_timings: bool,
+ include_speaker_names: bool,
+ max_line_length: int | None,
+ speaker_names,
+):
+ cues = []
+ cue_payload = ""
+ cue_length = 0
+ cue_start = None
+ cue_end = None
+
+ def push_payload(payload):
+ nonlocal cue_start, cue_end
+        if include_speaker_names and paragraph.speaker:
+            payload = (
+                f"<v {escape_vtt_string(get_speaker_name(paragraph.speaker, speaker_names))}>"
+                + payload
+            )
+
+ assert cue_start is not None
+ assert cue_end is not None
+
+ if cue_start >= cue_end:
+ logging.debug(
+ f"found {cue_start=} that is not before {cue_end=}"
+ ", fixing the end to be behind cue_start"
+ )
+ cue_end = cue_start + 0.02
+
+ cues.append(
+ VttCue(
+ start_time=cue_start,
+ end_time=cue_end,
+ payload=payload,
+ payload_escaped=True,
+ )
+ )
+
+ for atom in paragraph.children:
+ atom_text = str(atom.text)
+ if (
+ max_line_length is not None
+ and cue_start is not None
+ and cue_end is not None
+ and cue_length + len(atom_text) > max_line_length
+ ):
+ push_payload(cue_payload)
+
+ cue_payload = ""
+ cue_length = 0
+ cue_start = None
+ cue_end = None
+
+ if atom.start and (cue_start is None or atom.start < cue_start):
+ cue_start = atom.start
+
+ if atom.end and (cue_end is None or atom.end > cue_end):
+ cue_end = atom.end
+
+ cue_payload += atom_to_string(atom, include_word_timings)
+ cue_length += len(atom_text)
+
+ if len(cue_payload) > 0:
+ if cue_start is None or cue_end is None:
+ raise ValueError(
+ "Paragraph contains no timings, cannot generate cue(s)."
+ " Make sure to only call this function if `canGenerateVtt` returns true",
+ )
+ push_payload(cue_payload)
+
+ return cues
+
+
+def generate_web_vtt(
+ doc: Document,
+ include_speaker_names: bool,
+ include_word_timing: bool,
+ max_line_length: int | None,
+) -> WebVtt:
+ vtt = WebVtt(
+ "This file was generated using transcribee."
+ " Find out more at https://github.com/bugbakery/transcribee"
+ )
+ for par in doc.children:
+ if len(par.children) == 0:
+ continue
+
+ for cue in paragraph_to_cues(
+ par,
+ include_word_timing,
+ include_speaker_names,
+ max_line_length,
+ doc.speaker_names,
+ ):
+ vtt.add(cue)
+
+ return vtt
diff --git a/worker/transcribee_worker/webvtt/webvtt_writer.py b/worker/transcribee_worker/webvtt/webvtt_writer.py
new file mode 100644
index 00000000..fff33a6c
--- /dev/null
+++ b/worker/transcribee_worker/webvtt/webvtt_writer.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python3
+
+import dataclasses
+import enum
+import re
+from abc import ABC, abstractmethod
+
+
+class SubtitleFormat(str, enum.Enum):
+ VTT = "vtt"
+ SRT = "srt"
+
+
+class VttElement(ABC):
+ @abstractmethod
+ def to_string(self, format: SubtitleFormat) -> str:
+ ...
+
+
+class Vertical(str, enum.Enum):
+ RL = "rl"
+ LR = "lr"
+
+
+class Align(str, enum.Enum):
+ START = "start"
+ CENTER = "center"
+ END = "end"
+
+
+@dataclasses.dataclass
+class VttCueSettings(VttElement):
+ vertical: None | Vertical
+ line: None | int | str
+ position: None | str
+ size: None | str
+ align: None | Align
+
+ def to_string(self, format: SubtitleFormat):
+ if format == SubtitleFormat.SRT:
+ return ""
+
+ def format_elem(name, elem):
+ if elem is None:
+ return []
+ else:
+ return [f"{name}:{elem}"]
+
+ return " ".join(
+ format_elem("vertical", self.vertical)
+ + format_elem("line", self.line)
+ + format_elem("position", self.position)
+ + format_elem("size", self.size)
+ + format_elem("align", self.align)
+ )
+
+
+def formatted_time(secs: float):
+ subseconds = int((secs % 1.0) * 1000)
+ seconds = int(secs % 60)
+ minutes = int((secs / 60) % 60)
+    hours = int(secs / 60 / 60)
+
+ return f"{hours:02}:{minutes:02}:{seconds:02}.{subseconds:03}"
+
+
+class VttCue(VttElement):
+ identifier: str | None
+ start_time: float
+ end_time: float
+ settings: VttCueSettings | None
+ payload: str
+
+ def __init__(
+ self,
+ start_time: float,
+ end_time: float,
+ payload: str,
+ payload_escaped: bool | None = None,
+ identifier: str | None = None,
+ identifier_escaped: bool | None = None,
+ settings: VttCueSettings | None = None,
+ ):
+ if start_time >= end_time:
+ raise ValueError("Cue end time must be greater than cue start time")
+
+ self.start_time = start_time
+ self.end_time = end_time
+
+ if not payload_escaped:
+ payload = escape_vtt_string(payload)
+
+ self.payload = payload
+
+ if identifier and not identifier_escaped:
+ identifier = escape_vtt_string(identifier)
+
+ if identifier and "\n" in identifier:
+ raise ValueError("WebVTT cue identifiers MUST NOT contain a newline")
+
+ if identifier and "-->" in identifier:
+ raise ValueError("WebVTT cue identifiers MUST NOT contain -->")
+
+ self.identifier = identifier
+ self.settings = settings
+
+ def to_string(self, format: SubtitleFormat):
+ ident = f"{self.identifier}\n" if self.identifier else ""
+ time = f"{formatted_time(self.start_time)} --> {formatted_time(self.end_time)}"
+ settings = f"{self.settings.to_string(format)}".strip() if self.settings else ""
+
+ return ident + time + settings + "\n" + self.payload
+
+
+class VttComment(VttElement):
+ comment_text: str
+
+ def __init__(self, text: str, escaped: bool = False):
+ if not escaped:
+ text = escape_vtt_string(text)
+ if "-->" in text:
+ raise ValueError("WebVTT comments MUST NOT contain -->")
+
+ self.comment_text = text
+
+ def to_string(self, format: SubtitleFormat):
+ if format != SubtitleFormat.VTT:
+ return ""
+ return f"NOTE {self.comment_text}"
+
+
+def escape_vtt_string(text: str) -> str:
+    escape_map = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
+    text = re.sub(r"[&<>]", lambda obj: escape_map[obj.group(0)], text)
+    return text
+
+
+class VttHeader(VttElement):
+ header_text: str
+
+ def __init__(self, text: str, escaped: bool = False):
+ if not escaped:
+ text = escape_vtt_string(text)
+
+ if "-->" in text:
+ raise ValueError("WebVTT text header MUST NOT contain -->")
+
+ if "\n" in text:
+ raise ValueError("WebVTT text header MUST NOT contain newlines")
+
+ self.header_text = text
+
+ def to_string(self, format: SubtitleFormat):
+ if format != SubtitleFormat.VTT:
+ return ""
+
+ return f"WEBVTT {self.header_text}"
+
+
+class WebVtt:
+ elements: list[VttElement]
+
+ def __init__(self, header=""):
+ self.elements = [VttHeader(header)]
+
+ def add(self, element: VttElement):
+ self.elements.append(element)
+
+ def to_string(self, format: SubtitleFormat = SubtitleFormat.VTT):
+ as_strings = [elem.to_string(format) for elem in self.elements]
+ return "\n\n".join([elem for elem in as_strings if len(elem) > 0]) + "\n"
diff --git a/worker/transcribee_worker/worker.py b/worker/transcribee_worker/worker.py
index 4572b658..66b144c5 100644
--- a/worker/transcribee_worker/worker.py
+++ b/worker/transcribee_worker/worker.py
@@ -15,6 +15,8 @@
from transcribee_proto.api import (
AlignTask,
AssignedTask,
+ ExportFormat,
+ ExportTask,
ReencodeTask,
SpeakerIdentificationTask,
TaskType,
@@ -29,6 +31,8 @@
from transcribee_worker.torchaudio_align import align
from transcribee_worker.types import ProgressCallbackType
from transcribee_worker.util import aenumerate, load_audio
+from transcribee_worker.webvtt.export_webvtt import generate_web_vtt
+from transcribee_worker.webvtt.webvtt_writer import SubtitleFormat
from transcribee_worker.whisper_transcribe import (
transcribe_clean_async,
)
@@ -116,6 +120,7 @@ def __init__(
TaskType.ALIGN,
TaskType.TRANSCRIBE,
TaskType.REENCODE,
+ TaskType.EXPORT,
]
def claim_task(self) -> Optional[AssignedTask]:
@@ -185,6 +190,8 @@ def progress_callback(*, progress, step: Optional[str] = "", extra_data=None):
await self.align(task, progress_callback)
elif task.task_type == TaskType.REENCODE:
await self.reencode(task, progress_callback)
+ elif task.task_type == TaskType.EXPORT:
+ await self.export(task, progress_callback)
else:
raise ValueError(f"Invalid task type: '{task.task_type}'")
@@ -308,6 +315,36 @@ async def reencode(
None, self.add_document_media_file, task, output_path, tags
)
+ async def export(self, task: ExportTask, progress_callback: ProgressCallbackType):
+ async with self.api_client.document(task.document.id) as doc:
+ params = task.task_parameters
+ res = None
+ try:
+ vtt = generate_web_vtt(
+ EditorDocument.parse_obj(automerge.dump(doc.doc)),
+ params.include_speaker_names,
+ params.include_word_timing,
+ params.max_line_length,
+ )
+
+ if params.format == ExportFormat.VTT:
+ res = vtt.to_string(SubtitleFormat.VTT)
+ elif params.format == ExportFormat.SRT:
+ res = vtt.to_string(SubtitleFormat.SRT)
+
+ res = {"result": res}
+ except ValueError as e:
+ res = {"error": str(e)}
+
+ if res is not None:
+ logging.info(
+ f"Uploading document export for {task.document.id=} {res=}"
+ )
+ self.api_client.post(
+ f"documents/{task.document.id}/add_export_result/?task_id={task.id}",
+ json=res,
+ )
+
def set_duration(self, task: AssignedTask, duration: float):
logging.debug(
f"Setting audio duration for document {task.document.id=} {duration=}"