Skip to content

Commit

Permalink
mypy for sqlalchemy (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored May 9, 2024
1 parent d7e4783 commit 3cef0f7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ xfail_strict = true

files = "ragna"

plugins = [
"sqlmypy",
]

show_error_codes = true
pretty = true

Expand Down
2 changes: 1 addition & 1 deletion ragna/deploy/_api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def update_chat(session: Session, user: str, chat: schemas.Chat) -> None:
orm_chat = _get_orm_chat(session, user=user, id=chat.id)

orm_chat.prepared = chat.prepared
orm_chat.messages = [
orm_chat.messages = [ # type: ignore[assignment]
_schema_to_orm_message(session, chat_id=chat.id, message=message)
for message in chat.messages
]
Expand Down
28 changes: 14 additions & 14 deletions ragna/deploy/_api/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class User(Base):
__tablename__ = "users"

id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined]
name = Column(types.String)
name = Column(types.String, nullable=False)


document_chat_association_table = Table(
Expand All @@ -57,10 +57,10 @@ class Document(Base):

id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined]
user_id = Column(ForeignKey("users.id"))
name = Column(types.String)
name = Column(types.String, nullable=False)
# Mind the trailing underscore here. Unfortunately, this is necessary, because
# metadata without the underscore is reserved by SQLAlchemy
metadata_ = Column(Json)
metadata_ = Column(Json, nullable=False)
chats = relationship(
"Chat",
secondary=document_chat_association_table,
Expand All @@ -77,19 +77,19 @@ class Chat(Base):

id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined]
user_id = Column(ForeignKey("users.id"))
name = Column(types.String)
name = Column(types.String, nullable=False)
documents = relationship(
"Document",
secondary=document_chat_association_table,
back_populates="chats",
)
source_storage = Column(types.String)
assistant = Column(types.String)
params = Column(Json)
source_storage = Column(types.String, nullable=False)
assistant = Column(types.String, nullable=False)
params = Column(Json, nullable=False)
messages = relationship(
"Message", cascade="all, delete", order_by="Message.timestamp"
)
prepared = Column(types.Boolean)
prepared = Column(types.Boolean, nullable=False)


source_message_association_table = Table(
Expand All @@ -111,9 +111,9 @@ class Source(Base):
document_id = Column(ForeignKey("documents.id"))
document = relationship("Document", back_populates="sources")

location = Column(types.String)
content = Column(types.String)
num_tokens = Column(types.Integer)
location = Column(types.String, nullable=False)
content = Column(types.String, nullable=False)
num_tokens = Column(types.Integer, nullable=False)

messages = relationship(
"Message",
Expand All @@ -127,11 +127,11 @@ class Message(Base):

id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined]
chat_id = Column(ForeignKey("chats.id"))
content = Column(types.String)
role = Column(types.Enum(MessageRole))
content = Column(types.String, nullable=False)
role = Column(types.Enum(MessageRole), nullable=False)
sources = relationship(
"Source",
secondary=source_message_association_table,
back_populates="messages",
)
timestamp = Column(types.DateTime)
timestamp = Column(types.DateTime, nullable=False)

0 comments on commit 3cef0f7

Please sign in to comment.