Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Increase performance of list_documents by eager loading #392

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions backend/openapi-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1251,8 +1251,7 @@ paths:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/AssignedTaskResponse'
schema: {}
description: Successful Response
'422':
content:
Expand Down Expand Up @@ -1289,8 +1288,7 @@ paths:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/AssignedTaskResponse'
schema: {}
description: Successful Response
'422':
content:
Expand Down Expand Up @@ -1327,8 +1325,7 @@ paths:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/AssignedTaskResponse'
schema: {}
description: Successful Response
'422':
content:
Expand Down
5 changes: 4 additions & 1 deletion backend/transcribee_backend/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Tuple

from fastapi import Depends, Header, HTTPException
from sqlalchemy.orm import joinedload
from sqlmodel import Session, col, or_, select

from transcribee_backend.db import get_session
Expand Down Expand Up @@ -157,7 +158,9 @@ def get_authorized_task(
session: Session = Depends(get_session),
authorized_worker: Worker = Depends(get_authorized_worker),
):
statement = select(Task).where(Task.id == task_id)
statement = (
select(Task).where(Task.id == task_id).options(joinedload(Task.current_attempt))
)
task = session.exec(statement).one_or_none()
if task is None:
raise HTTPException(status_code=404)
Expand Down
35 changes: 29 additions & 6 deletions backend/transcribee_backend/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ class Task(TaskBase, table=True):
"secondaryjoin": "Task.id==TaskDependency.dependant_on_id",
},
)
dependency_links: List[TaskDependency] = Relationship(
sa_relationship_kwargs={
"primaryjoin": "Task.id==TaskDependency.dependent_task_id",
"viewonly": True,
},
)
dependants: List["Task"] = Relationship(
back_populates="dependencies",
link_model=TaskDependency,
Expand Down Expand Up @@ -155,12 +161,29 @@ class TaskResponse(TaskBase):

@classmethod
def from_orm(cls, task: Task, update={}) -> Self:
return super().from_orm(
task,
update={
"dependencies": [x.id for x in task.dependencies],
**update,
},
# The following code is equivalent to this:
# return super().from_orm(
# task,
# update={
# "dependencies": [x.dependant_on_id for x in task.dependency_links],
# **update,
# },
# )
# But much faster, because from_orm destructures the `obj` to mix it
# with the `update` dict, which causes an access to all attributes,
# including `dependencies`/`dependents` which are then all seperately
# selected from the database, causing many query
# Even with a small number of document this cuts the loading time of
# the `/api/v1/documents/` endpoint roughly in half on my test machine
return cls(
id=task.id,
state=task.state,
dependencies=[x.dependant_on_id for x in task.dependency_links],
current_attempt=None,
document_id=task.document_id,
task_type=task.task_type,
task_parameters=task.task_parameters,
**update,
)


Expand Down
13 changes: 12 additions & 1 deletion backend/transcribee_backend/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import desc
from sqlmodel import Session, col, select
from transcribee_proto.api import Document as ApiDocument
Expand Down Expand Up @@ -409,6 +410,12 @@ def list_documents(
select(Document)
.where(Document.user == token.user)
.order_by(desc(Document.changed_at), Document.id)
.options(
selectinload("tasks"),
selectinload("media_files"),
selectinload("media_files.tags"),
selectinload("tasks.dependency_links"),
)
)
results = session.exec(statement)
return [doc.as_api_document() for doc in results]
Expand Down Expand Up @@ -455,7 +462,11 @@ def get_document_tasks(
auth: AuthInfo = Depends(get_doc_min_readonly_auth),
session: Session = Depends(get_session),
) -> List[TaskResponse]:
statement = select(Task).where(Task.document_id == auth.document.id)
statement = (
select(Task)
.where(Task.document_id == auth.document.id)
.options(selectinload(Task.dependency_links))
)
return [TaskResponse.from_orm(x) for x in session.exec(statement)]


Expand Down
19 changes: 10 additions & 9 deletions backend/transcribee_backend/routers/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import APIRouter, Body, Depends, Query
from fastapi.exceptions import HTTPException
from sqlalchemy.orm import aliased
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.sql.operators import is_
from sqlmodel import Session, col, select
from transcribee_proto.api import KeepaliveBody
Expand Down Expand Up @@ -104,7 +104,7 @@ def keepalive(
keepalive_data: KeepaliveBody = Body(),
session: Session = Depends(get_session),
task: Task = Depends(get_authorized_task),
) -> Optional[AssignedTaskResponse]:
):
# mostly to please the type checker, get_authorized_task already ensures
# that the task has a current attempt
if task.current_attempt is None:
Expand All @@ -115,7 +115,6 @@ def keepalive(
session.add(task.current_attempt)
session.add(task)
session.commit()
return AssignedTaskResponse.from_orm(task)


@task_router.post("/{task_id}/mark_completed/")
Expand All @@ -124,11 +123,10 @@ def mark_completed(
session: Session = Depends(get_session),
task: Task = Depends(get_authorized_task),
now: datetime.datetime = Depends(now_tz_aware),
) -> Optional[AssignedTaskResponse]:
):
finish_current_attempt(
session=session, task=task, now=now, extra_data=extra_data, successful=True
)
return AssignedTaskResponse.from_orm(task)


@task_router.post("/{task_id}/mark_failed/")
Expand All @@ -137,21 +135,24 @@ def mark_failed(
session: Session = Depends(get_session),
task: Task = Depends(get_authorized_task),
now: datetime.datetime = Depends(now_tz_aware),
) -> Optional[AssignedTaskResponse]:
):
now = now_tz_aware()

finish_current_attempt(
session=session, task=task, now=now, extra_data=extra_data, successful=False
)

return AssignedTaskResponse.from_orm(task)


@task_router.get("/")
def list_tasks(
session: Session = Depends(get_session),
token: UserToken = Depends(get_user_token),
) -> List[TaskResponse]:
statement = select(Task).join(Document).where(Document.user == token.user)
statement = (
select(Task)
.join(Document)
.where(Document.user == token.user)
.options(selectinload(Task.dependency_links))
)
results = session.exec(statement)
return [TaskResponse.from_orm(x) for x in results]
9 changes: 3 additions & 6 deletions backend/transcribee_backend/routers/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, delete, select
from sqlmodel import Session, delete
from transcribee_proto.api import LoginResponse

from transcribee_backend.auth import (
Expand All @@ -12,7 +12,7 @@
)
from transcribee_backend.db import get_session
from transcribee_backend.exceptions import UserAlreadyExists
from transcribee_backend.models import CreateUser, User, UserBase, UserToken
from transcribee_backend.models import CreateUser, UserBase, UserToken
from transcribee_backend.models.user import ChangePasswordRequest

user_router = APIRouter()
Expand Down Expand Up @@ -57,11 +57,8 @@ def logout(
@user_router.get("/me/")
def read_user(
token: UserToken = Depends(get_user_token),
session: Session = Depends(get_session),
) -> UserBase:
statement = select(User).where(User.id == token.user_id)
user = session.exec(statement).one()
return UserBase(username=user.username)
return UserBase(username=token.user.username)


@user_router.post("/change_password/")
Expand Down
6 changes: 3 additions & 3 deletions frontend/src/openapi-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ export interface operations {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["AssignedTaskResponse"];
"application/json": unknown;
};
};
/** @description Validation Error */
Expand Down Expand Up @@ -1044,7 +1044,7 @@ export interface operations {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["AssignedTaskResponse"];
"application/json": unknown;
};
};
/** @description Validation Error */
Expand Down Expand Up @@ -1074,7 +1074,7 @@ export interface operations {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["AssignedTaskResponse"];
"application/json": unknown;
};
};
/** @description Validation Error */
Expand Down