Skip to content

Commit

Permalink
Fix #399 : eager loading of Chats docs, msgs, srcs (#401)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Nick Byrne <[email protected]>
  • Loading branch information
3 people authored May 6, 2024
1 parent 04168ab commit e397acb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
35 changes: 27 additions & 8 deletions ragna/deploy/_api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import urlsplit

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import sessionmaker as _sessionmaker

from ragna.core import RagnaException
Expand Down Expand Up @@ -136,30 +136,49 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat:
)


def _select_chat(*, eager: bool = False) -> Any:
selector = select(orm.Chat)
if eager:
selector = selector.options( # type: ignore[attr-defined]
joinedload(orm.Chat.messages).joinedload(orm.Message.sources),
joinedload(orm.Chat.documents),
)
return selector


def get_chats(session: Session, *, user: str) -> list[schemas.Chat]:
return [
_orm_to_schema_chat(chat)
for chat in session.execute(
select(orm.Chat).where(orm.Chat.user_id == _get_user_id(session, user))
_select_chat(eager=True).where(
orm.Chat.user_id == _get_user_id(session, user)
)
)
.scalars()
.unique()
.all()
]


def _get_orm_chat(session: Session, *, user: str, id: uuid.UUID) -> orm.Chat:
chat: Optional[orm.Chat] = session.execute(
select(orm.Chat).where(
(orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user))
def _get_orm_chat(
session: Session, *, user: str, id: uuid.UUID, eager: bool = False
) -> orm.Chat:
chat: Optional[orm.Chat] = (
session.execute(
_select_chat(eager=eager).where(
(orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user))
)
)
).scalar_one_or_none()
.unique()
.scalar_one_or_none()
)
if chat is None:
raise RagnaException()
return chat


def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat:
return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id))
return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True))


def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source:
Expand Down
4 changes: 3 additions & 1 deletion ragna/deploy/_api/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ class Chat(Base):
source_storage = Column(types.String)
assistant = Column(types.String)
params = Column(Json)
messages = relationship("Message", cascade="all, delete")
messages = relationship(
"Message", cascade="all, delete", order_by="Message.timestamp"
)
prepared = Column(types.Boolean)


Expand Down
4 changes: 1 addition & 3 deletions ragna/deploy/_api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ class Message(BaseModel):
content: str
role: ragna.core.MessageRole
sources: list[Source] = Field(default_factory=list)
timestamp: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.utcnow()
)
timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)

@classmethod
def from_core(cls, message: ragna.core.Message) -> Message:
Expand Down
5 changes: 5 additions & 0 deletions tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time

import pytest
from fastapi.testclient import TestClient
Expand All @@ -12,6 +13,10 @@

class TestAssistant(RagnaDemoAssistant):
def answer(self, prompt, sources, *, multiple_answer_chunks: bool):
# Simulate a "real" assistant through a small delay. See
# https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440
# for why this is needed.
time.sleep(1e-3)
content = next(super().answer(prompt, sources))

if multiple_answer_chunks:
Expand Down

0 comments on commit e397acb

Please sign in to comment.