Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jun 25, 2024
1 parent 377ef00 commit 6b2e343
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions ragna/deploy/_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import uuid
from typing import Any, Optional, cast
from typing import Any, Optional
from urllib.parse import urlsplit

from sqlalchemy import create_engine, select
Expand All @@ -28,8 +28,7 @@ def __init__(self, url: str) -> None:
self._to_orm = SchemaToOrmConverter()
self._to_schema = OrmToSchemaConverter()

# FIXME: make this get_user ??
def _get_user_id(self, session: Session, *, username: str) -> uuid.UUID:
def _get_user(self, session: Session, *, username: str) -> orm.User:
user: Optional[orm.User] = session.execute(
select(orm.User).where(orm.User.name == username)
).scalar_one_or_none()
Expand All @@ -41,7 +40,7 @@ def _get_user_id(self, session: Session, *, username: str) -> uuid.UUID:
session.add(user)
session.commit()

return cast(uuid.UUID, user.id)
return user

def add_document(
self,
Expand All @@ -54,7 +53,7 @@ def add_document(
session.add(
orm.Document(
id=document.id,
user_id=self._get_user_id(session, username=user),
user_id=self._get_user(session, username=user).id,
name=document.name,
metadata_=metadata,
)
Expand All @@ -74,11 +73,10 @@ def add_documents(
This function allows adding multiple documents at once by calling `add_all`. This is
important when there is non-negligible latency attached to each database operation.
"""
user_id = self._get_user_id(session, username=user)
documents = [
orm.Document(
id=document.id,
user_id=user_id,
user_id=self._get_user(session, username=user).id,
name=document.name,
metadata_=metadata,
)
Expand All @@ -92,7 +90,7 @@ def get_document(
) -> schemas.Document:
document = session.execute(
select(orm.Document).where(
(orm.Document.user_id == self._get_user_id(session, username=user))
(orm.Document.user_id == self._get_user(session, username=user).id)
& (orm.Document.id == id)
)
).scalar_one_or_none()
Expand All @@ -115,7 +113,7 @@ def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None:

orm_chat = self._to_orm.chat(
chat,
user_id=self._get_user_id(session, username=user),
user_id=self._get_user(session, username=user).id,
# We have to pass the documents here, because SQLAlchemy does not allow a
# second instance of orm.Document with the same primary key in the session.
documents=documents,
Expand All @@ -137,7 +135,7 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]:
self._to_schema.chat(chat)
for chat in session.execute(
self._select_chat(eager=True).where(
orm.Chat.user_id == self._get_user_id(session, username=user)
orm.Chat.user_id == self._get_user(session, username=user).id
)
)
.scalars()
Expand All @@ -152,7 +150,7 @@ def _get_orm_chat(
session.execute(
self._select_chat(eager=eager).where(
(orm.Chat.id == id)
& (orm.Chat.user_id == self._get_user_id(session, username=user))
& (orm.Chat.user_id == self._get_user(session, username=user).id)
)
)
.unique()
Expand All @@ -169,7 +167,7 @@ def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Cha

def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None:
orm_chat = self._to_orm.chat(
chat, user_id=self._get_user_id(session, username=user)
chat, user_id=self._get_user(session, username=user).id
)
session.merge(orm_chat)
session.commit()
Expand Down

0 comments on commit 6b2e343

Please sign in to comment.