diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py index ae0bccb0..cbf388e2 100644 --- a/ragna/deploy/_api/database.py +++ b/ragna/deploy/_api/database.py @@ -6,7 +6,7 @@ from urllib.parse import urlsplit from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import sessionmaker as _sessionmaker from ragna.core import RagnaException @@ -136,30 +136,49 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: ) +def _select_chat(*, eager: bool = False) -> Any: + selector = select(orm.Chat) + if eager: + selector = selector.options( # type: ignore[attr-defined] + joinedload(orm.Chat.messages).joinedload(orm.Message.sources), + joinedload(orm.Chat.documents), + ) + return selector + + def get_chats(session: Session, *, user: str) -> list[schemas.Chat]: return [ _orm_to_schema_chat(chat) for chat in session.execute( - select(orm.Chat).where(orm.Chat.user_id == _get_user_id(session, user)) + _select_chat(eager=True).where( + orm.Chat.user_id == _get_user_id(session, user) + ) ) .scalars() + .unique() .all() ] -def _get_orm_chat(session: Session, *, user: str, id: uuid.UUID) -> orm.Chat: - chat: Optional[orm.Chat] = session.execute( - select(orm.Chat).where( - (orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user)) +def _get_orm_chat( + session: Session, *, user: str, id: uuid.UUID, eager: bool = False +) -> orm.Chat: + chat: Optional[orm.Chat] = ( + session.execute( + _select_chat(eager=eager).where( + (orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user)) + ) ) - ).scalar_one_or_none() + .unique() + .scalar_one_or_none() + ) if chat is None: raise RagnaException() return chat def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: - return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id)) + return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True)) def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source: diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_api/orm.py index 7e3f8c8e..033d4b4e 100644 --- a/ragna/deploy/_api/orm.py +++ b/ragna/deploy/_api/orm.py @@ -86,7 +86,9 @@ class Chat(Base): source_storage = Column(types.String) assistant = Column(types.String) params = Column(Json) - messages = relationship("Message", cascade="all, delete") + messages = relationship( + "Message", cascade="all, delete", order_by="Message.timestamp" + ) prepared = Column(types.Boolean) diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py index 7439f0d5..53957a74 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_api/schemas.py @@ -56,9 +56,7 @@ class Message(BaseModel): content: str role: ragna.core.MessageRole sources: list[Source] = Field(default_factory=list) - timestamp: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.utcnow() - ) + timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) @classmethod def from_core(cls, message: ragna.core.Message) -> Message: diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index 0efa0014..41b154db 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -1,4 +1,5 @@ import json +import time import pytest from fastapi.testclient import TestClient @@ -12,6 +13,10 @@ class TestAssistant(RagnaDemoAssistant): def answer(self, prompt, sources, *, multiple_answer_chunks: bool): + # Simulate a "real" assistant through a small delay. See + # https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440 + # for why this is needed. + time.sleep(1e-3) content = next(super().answer(prompt, sources)) if multiple_answer_chunks: