From 3cef0f7da1f2ed90e5d0618bcad82f824d00dc5a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 May 2024 08:49:32 +0200 Subject: [PATCH] mypy for sqlalchemy (#402) --- pyproject.toml | 4 ++++ ragna/deploy/_api/database.py | 2 +- ragna/deploy/_api/orm.py | 28 ++++++++++++++-------------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 81581d48..bd7a7ca7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,10 @@ xfail_strict = true files = "ragna" +plugins = [ + "sqlmypy", +] + show_error_codes = true pretty = true diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py index 30d62ef9..2a61b048 100644 --- a/ragna/deploy/_api/database.py +++ b/ragna/deploy/_api/database.py @@ -256,7 +256,7 @@ def update_chat(session: Session, user: str, chat: schemas.Chat) -> None: orm_chat = _get_orm_chat(session, user=user, id=chat.id) orm_chat.prepared = chat.prepared - orm_chat.messages = [ + orm_chat.messages = [ # type: ignore[assignment] _schema_to_orm_message(session, chat_id=chat.id, message=message) for message in chat.messages ] diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_api/orm.py index 033d4b4e..04a3583e 100644 --- a/ragna/deploy/_api/orm.py +++ b/ragna/deploy/_api/orm.py @@ -41,7 +41,7 @@ class User(Base): __tablename__ = "users" id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] - name = Column(types.String) + name = Column(types.String, nullable=False) document_chat_association_table = Table( @@ -57,10 +57,10 @@ class Document(Base): id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] user_id = Column(ForeignKey("users.id")) - name = Column(types.String) + name = Column(types.String, nullable=False) # Mind the trailing underscore here. Unfortunately, this is necessary, because # metadata without the underscore is reserved by SQLAlchemy - metadata_ = Column(Json) + metadata_ = Column(Json, nullable=False) chats = relationship( "Chat", secondary=document_chat_association_table, @@ -77,19 +77,19 @@ class Chat(Base): id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] user_id = Column(ForeignKey("users.id")) - name = Column(types.String) + name = Column(types.String, nullable=False) documents = relationship( "Document", secondary=document_chat_association_table, back_populates="chats", ) - source_storage = Column(types.String) - assistant = Column(types.String) - params = Column(Json) + source_storage = Column(types.String, nullable=False) + assistant = Column(types.String, nullable=False) + params = Column(Json, nullable=False) messages = relationship( "Message", cascade="all, delete", order_by="Message.timestamp" ) - prepared = Column(types.Boolean) + prepared = Column(types.Boolean, nullable=False) source_message_association_table = Table( @@ -111,9 +111,9 @@ class Source(Base): document_id = Column(ForeignKey("documents.id")) document = relationship("Document", back_populates="sources") - location = Column(types.String) - content = Column(types.String) - num_tokens = Column(types.Integer) + location = Column(types.String, nullable=False) + content = Column(types.String, nullable=False) + num_tokens = Column(types.Integer, nullable=False) messages = relationship( "Message", @@ -127,11 +127,11 @@ class Message(Base): id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] chat_id = Column(ForeignKey("chats.id")) - content = Column(types.String) - role = Column(types.Enum(MessageRole)) + content = Column(types.String, nullable=False) + role = Column(types.Enum(MessageRole), nullable=False) sources = relationship( "Source", secondary=source_message_association_table, back_populates="messages", ) - timestamp = Column(types.DateTime) + timestamp = Column(types.DateTime, nullable=False)