Skip to content

Commit

Permalink
[ENH] Add GET endpoints for documents (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
smokestacklightnin authored Feb 5, 2025
1 parent e5a602a commit e941366
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 30 deletions.
10 changes: 9 additions & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import io
import mimetypes
import uuid
from functools import cached_property
from pathlib import Path
Expand All @@ -25,11 +26,15 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
self.id = id or uuid.uuid4()
self.name = name
self.metadata = metadata
self.handler = handler or self.get_handler(name)
self.mime_type = (
mime_type or mimetypes.guess_type(name)[0] or "application/octet-stream"
)

@staticmethod
def supported_suffixes() -> set[str]:
Expand Down Expand Up @@ -76,8 +81,11 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
super().__init__(id=id, name=name, metadata=metadata, handler=handler)
super().__init__(
id=id, name=name, metadata=metadata, handler=handler, mime_type=mime_type
)
if "path" not in self.metadata:
metadata["path"] = str(ragna.local_root() / "documents" / str(self.id))

Expand Down
23 changes: 23 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import uuid
from typing import Annotated, Any, AsyncIterator

Expand Down Expand Up @@ -40,6 +41,28 @@ async def content_stream() -> AsyncIterator[bytes]:
],
)

# List documents for the authenticated user by delegating to the engine.
@router.get("/documents")
async def get_documents(user: UserDependency) -> list[schemas.Document]:
    documents = engine.get_documents(user=user.name)
    return documents

# Look up a single document by its UUID on behalf of the authenticated user.
@router.get("/documents/{id}")
async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document:
    document = engine.get_document(user=user.name, id=id)
    return document

# Stream the stored content of one document back to the client.
#
# The response media type is the document's ``mime_type`` and the document
# name is advertised through an ``inline`` Content-Disposition header.
@router.get("/documents/{id}/content")
async def get_document_content(
    user: UserDependency, id: uuid.UUID
) -> StreamingResponse:
    # Imported locally so this block does not depend on the module's import list.
    from urllib.parse import quote

    schema_document = engine.get_document(user=user.name, id=id)
    core_document = engine._to_core.document(schema_document)

    # Document names are supplied by the client, so they must not be
    # interpolated into the header verbatim: spaces, quotes or semicolons would
    # corrupt the parameter list and non-ASCII characters may not be encodable
    # as a header value. Following RFC 6266, ``filename`` carries a sanitized
    # printable-ASCII fallback inside a quoted-string and ``filename*`` carries
    # the exact name, percent-encoded as UTF-8.
    name = schema_document.name
    ascii_name = "".join(
        char if " " <= char <= "~" and char not in '"\\' else "_" for char in name
    )
    encoded_name = quote(name, safe="")
    headers = {
        "Content-Disposition": (
            f'inline; filename="{ascii_name}"; '
            f"filename*=UTF-8''{encoded_name}"
        )
    }

    # NOTE: the whole document is read into memory before it is streamed.
    return StreamingResponse(
        io.BytesIO(core_document.read()),
        media_type=core_document.mime_type,
        headers=headers,
    )

# Report the components made available by the engine.
@router.get("/components")
def get_components() -> schemas.Components:
    components = engine.get_components()
    return components
Expand Down
22 changes: 13 additions & 9 deletions ragna/deploy/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,24 @@ def add_documents(
session.commit()

def _get_orm_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[orm.Document]:
# FIXME also check if the user is allowed to access the documents
# FIXME: maybe just take the user id to avoid getting it twice in add_chat?
documents = (
session.execute(select(orm.Document).where(orm.Document.id.in_(ids)))
.scalars()
.all()
)
if len(documents) != len(ids):
expr = select(orm.Document)
if ids is not None:
expr = expr.where(orm.Document.id.in_(ids))
documents = session.execute(expr).scalars().all()

if (ids is not None) and (len(documents) != len(ids)):
raise RagnaException(
str(set(ids) - {document.id for document in documents})
)

return documents # type: ignore[no-any-return]

def get_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
return [
self._to_schema.document(document)
Expand Down Expand Up @@ -288,6 +288,7 @@ def document(
user_id=user_id,
name=document.name,
metadata_=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> orm.Source:
Expand Down Expand Up @@ -354,7 +355,10 @@ def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey:

def document(self, document: orm.Document) -> schemas.Document:
return schemas.Document(
id=document.id, name=document.name, metadata=document.metadata_
id=document.id,
name=document.name,
metadata=document.metadata_,
mime_type=document.mime_type,
)

def source(self, source: orm.Source) -> schemas.Source:
Expand Down
22 changes: 16 additions & 6 deletions ragna/deploy/_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import secrets
import uuid
from typing import Any, AsyncIterator, Optional, cast
from typing import Any, AsyncIterator, Collection, Optional, cast

from fastapi import status as http_status_code

Expand Down Expand Up @@ -156,7 +156,9 @@ def register_documents(
# We create core.Document's first, because they might update the metadata
core_documents = [
self._config.document(
name=registration.name, metadata=registration.metadata
name=registration.name,
metadata=registration.metadata,
mime_type=registration.mime_type,
)
for registration in document_registrations
]
Expand All @@ -182,17 +184,23 @@ async def store_documents(

streams = dict(ids_and_streams)

with self._database.get_session() as session:
documents = self._database.get_documents(
session, user=user, ids=streams.keys()
)
documents = self.get_documents(user=user, ids=streams.keys())

for document in documents:
core_document = cast(
ragna.core.LocalDocument, self._to_core.document(document)
)
await core_document._write(streams[document.id])

def get_documents(
    self, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
    """Fetch documents from the database inside a fresh session.

    ``user`` and ``ids`` are forwarded unchanged to the database layer;
    leaving ``ids`` as ``None`` requests no restriction by document ID.
    """
    database = self._database
    with database.get_session() as session:
        documents = database.get_documents(session, user=user, ids=ids)
        return documents

def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document:
    """Fetch the document with the given ``id`` via :meth:`get_documents`."""
    matches = self.get_documents(user=user, ids=[id])
    return matches[0]

def create_chat(
self, *, user: str, chat_creation: schemas.ChatCreation
) -> schemas.Chat:
Expand Down Expand Up @@ -280,6 +288,7 @@ def document(self, document: schemas.Document) -> core.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> core.Source:
Expand Down Expand Up @@ -328,6 +337,7 @@ def document(self, document: core.Document) -> schemas.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: core.Source) -> schemas.Source:
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Document(Base):
# Mind the trailing underscore here. Unfortunately, this is necessary, because
# metadata without the underscore is reserved by SQLAlchemy
metadata_ = Column(Json, nullable=False)
mime_type = Column(types.String, nullable=False)
chats = relationship(
"Chat",
secondary=document_chat_association_table,
Expand Down
2 changes: 2 additions & 0 deletions ragna/deploy/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ class Components(BaseModel):
# Payload describing a document that should be registered.
class DocumentRegistration(BaseModel):
    name: str
    # Every instance gets its own empty dict when no metadata is supplied.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Optional; stays None when the client does not state a MIME type.
    mime_type: str | None = None


# Schema of a registered document.
class Document(BaseModel):
    # A random UUID (uuid.uuid4) is generated when no id is supplied.
    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    name: str
    metadata: dict[str, Any]
    # Required here, in contrast to the optional field on DocumentRegistration.
    mime_type: str


class Source(BaseModel):
Expand Down
20 changes: 6 additions & 14 deletions tests/deploy/api/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ragna import assistants
from ragna.core import RagnaException
from ragna.deploy import Config
from tests.deploy.api.utils import upload_documents
from tests.deploy.utils import make_api_app, make_api_client


Expand Down Expand Up @@ -56,17 +57,8 @@ def test_unknown_component(tmp_local_root):
with open(document_path, "w") as file:
file.write("!\n")

with make_api_client(
config=Config(), ignore_unavailable_components=False
) as client:
document = (
client.post("/api/documents", json=[{"name": document_path.name}])
.raise_for_status()
.json()[0]
)

with open(document_path, "rb") as file:
client.put("/api/documents", files={"documents": (document["id"], file)})
with make_api_client(config=config, ignore_unavailable_components=False) as client:
document = upload_documents(client=client, document_paths=[document_path])[0]

response = client.post(
"/api/chats",
Expand All @@ -80,7 +72,7 @@ def test_unknown_component(tmp_local_root):
},
)

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

error = response.json()["error"]
assert "Unknown component" in error["message"]
error = response.json()["error"]
assert "Unknown component" in error["message"]
109 changes: 109 additions & 0 deletions tests/deploy/api/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import mimetypes

import pytest

from ragna.deploy import Config
from tests.deploy.api.utils import upload_documents
from tests.deploy.utils import make_api_client

# Newline-terminated sample texts written into the test documents.
_document_content_text = [
    "Needs more reverb\n",
    "Needs more cowbell\n",
]


# Shared parametrization: run a test once per MIME type sent at registration.
mime_types = pytest.mark.parametrize(
    "mime_type",
    [
        None,  # Let the mimetypes library decide
        "text/markdown",
        "application/pdf",
    ],
)


@mime_types
def test_get_documents(tmp_local_root, mime_type):
    """Every uploaded document is listed by ``GET /api/documents``."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()

    # Write one file per sample text and remember where it went.
    document_paths = []
    for idx, content in enumerate(_document_content_text):
        document_path = document_root / f"test{idx}.txt"
        with open(document_path, "w") as file:
            file.write(content)
        document_paths.append(document_path)

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        documents = upload_documents(
            client=client,
            document_paths=document_paths,
            mime_types=[mime_type] * len(document_paths),
        )
        response = client.get("/api/documents").raise_for_status()

    # Order both sides by ID, since the retrieval order is not guaranteed.
    uploaded = sorted(documents, key=lambda document: document["id"])
    listed = sorted(response.json(), key=lambda document: document["id"])
    assert uploaded == listed


@mime_types
def test_get_document(tmp_local_root, mime_type):
    """``GET /api/documents/{id}`` echoes the document returned at upload."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()
    document_path = document_root / "test.txt"
    with open(document_path, "w") as file:
        file.write(_document_content_text[0])

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        uploaded = upload_documents(
            client=client,
            document_paths=[document_path],
            mime_types=[mime_type],
        )
        document = uploaded[0]
        document_url = f"/api/documents/{document['id']}"
        response = client.get(document_url).raise_for_status()

    assert document == response.json()


@mime_types
def test_get_document_content(tmp_local_root, mime_type):
    """The content endpoint streams the uploaded text with the right MIME type."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()
    document_path = document_root / "test.txt"
    document_content = _document_content_text[0]
    with open(document_path, "w") as file:
        file.write(document_content)

    # Without an explicit MIME type the guess for the file name is expected.
    if mime_type is not None:
        expected_mime_type = mime_type
    else:
        expected_mime_type = mimetypes.guess_type(document_path.name)[0]

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        uploaded = upload_documents(
            client=client,
            document_paths=[document_path],
            mime_types=[mime_type],
        )
        document = uploaded[0]

        content_url = f"/api/documents/{document['id']}/content"
        with client.stream("GET", content_url) as response:
            # Drop any parameters that follow the media type itself.
            response_mime_type = response.headers["content-type"].split(";")[0]
            received_lines = [*response.iter_lines()]

    assert received_lines == [document_content.replace("\n", "")]

    assert document["mime_type"] == response_mime_type
    assert response_mime_type == expected_mime_type
37 changes: 37 additions & 0 deletions tests/deploy/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import contextlib


def upload_documents(*, client, document_paths, mime_types=None):
    """Register documents through the API and upload their contents.

    Args:
        client: HTTP client used for the ``POST`` and ``PUT`` requests against
            ``/api/documents``.
        document_paths: Paths of the local files to upload. Only the file name
            of each path is sent during registration.
        mime_types: Optional MIME types, one per path. When omitted, ``None``
            is sent for every document.

    Returns:
        The JSON-decoded body of the registration response.
    """
    if mime_types is None:
        mime_types = [None for _ in document_paths]
    else:
        assert len(mime_types) == len(document_paths)

    registrations = [
        {
            "name": document_path.name,
            "mime_type": mime_type,
        }
        for document_path, mime_type in zip(document_paths, mime_types)
    ]
    documents = (
        client.post("/api/documents", json=registrations).raise_for_status().json()
    )

    # The ExitStack keeps every file open for the duration of the request and
    # closes all of them afterwards, even if the upload raises.
    with contextlib.ExitStack() as stack:
        files = [
            stack.enter_context(open(document_path, "rb"))
            for document_path in document_paths
        ]
        # Fail right here on a rejected upload rather than in an unrelated
        # assertion later in the calling test.
        client.put(
            "/api/documents",
            files=[
                ("documents", (document["id"], file))
                for document, file in zip(documents, files)
            ],
        ).raise_for_status()

    return documents

0 comments on commit e941366

Please sign in to comment.