Skip to content

Commit

Permalink
add list_metadata functionality to web API (#501)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
nenb and pmeier authored Aug 27, 2024
1 parent 474736e commit 3cfabaa
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def list_corpuses(self) -> list[str]:

def list_metadata(
self, corpus_name: Optional[str] = None
) -> dict[str, dict[str, tuple[type, list[Any]]]]:
) -> dict[str, dict[str, tuple[str, list[Any]]]]:
"""List available metadata for corpuses.
Args:
Expand Down
29 changes: 21 additions & 8 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,9 @@ async def get_components(_: UserDependency) -> schemas.Components:
],
)

@app.get("/corpuses")
async def get_corpuses(
_: UserDependency,
source_storage: Optional[str] = None,
) -> dict[str, list[str]]:
def _get_source_storage_components(
source_storage: Optional[str],
) -> list[SourceStorage]:
if source_storage is not None:
component = components_map.get(source_storage)
if component is None or not isinstance(component, SourceStorage):
Expand All @@ -159,17 +157,32 @@ async def get_corpuses(
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=RagnaException.MESSAGE,
)
source_storages = [component]
return [component]
else:
source_storages = [
return [
source_storage
for source_storage in components_map.values()
if isinstance(source_storage, SourceStorage)
]

@app.get("/corpuses")
async def get_corpuses(
_: UserDependency, source_storage: Optional[str] = None
) -> dict[str, list[str]]:
return {
source_storage.display_name(): source_storage.list_corpuses()
for source_storage in source_storages
for source_storage in _get_source_storage_components(source_storage)
}

# GET /corpuses/metadata: report the metadata available in the corpuses of the
# selected source storage(s), keyed by each storage's display name.
# Kept as comments (not a docstring) so the generated API description is unchanged.
@app.get("/corpuses/metadata")
async def get_corpus_metadata(
    _: UserDependency,
    source_storage: Optional[str] = None,
    corpus_name: Optional[str] = None,
) -> dict[str, dict[str, dict[str, tuple[str, list[Any]]]]]:
    # Resolve the storages up front; the helper handles the case where
    # `source_storage` is given but does not name a SourceStorage component.
    storages = _get_source_storage_components(source_storage)
    # `corpus_name` is forwarded unchanged to every storage's list_metadata().
    return {
        storage.display_name(): storage.list_metadata(corpus_name)
        for storage in storages
    }

make_session = database.get_sessionmaker(config.api.database_url)
Expand Down
11 changes: 3 additions & 8 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from typing import TYPE_CHECKING, Any, Optional, cast

import ragna
from ragna.core import (
Document,
MetadataFilter,
MetadataOperator,
Source,
)
from ragna.core import Document, MetadataFilter, MetadataOperator, Source

from ._utils import raise_no_corpuses_available, raise_non_existing_corpus
from ._vector_database import VectorDatabaseSourceStorage
Expand Down Expand Up @@ -69,7 +64,7 @@ def _get_collection(

def list_metadata(
self, corpus_name: Optional[str] = None
) -> dict[str, dict[str, tuple[type, list[Any]]]]:
) -> dict[str, dict[str, tuple[str, list[Any]]]]:
if corpus_name is None:
corpus_names = self.list_corpuses()
else:
Expand All @@ -91,7 +86,7 @@ def list_metadata(
corpus_metadata[key].add(value)

metadata[corpus_name] = {
key: ({type(value) for value in values}.pop(), sorted(values))
key: ({type(value).__name__ for value in values}.pop(), sorted(values))
for key, values in corpus_metadata.items()
}

Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_corpus(

def list_metadata(
self, corpus_name: Optional[str] = None
) -> dict[str, dict[str, tuple[type, list[Any]]]]:
) -> dict[str, dict[str, tuple[str, list[Any]]]]:
if corpus_name is None:
corpus_names = self.list_corpuses()
else:
Expand All @@ -73,7 +73,7 @@ def list_metadata(
corpus_metadata[key].add(value)

metadata[corpus_name] = {
key: ({type(value) for value in values}.pop(), sorted(values))
key: ({type(value).__name__ for value in values}.pop(), sorted(values))
for key, values in corpus_metadata.items()
}

Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _get_table(

def list_metadata(
self, corpus_name: Optional[str] = None
) -> dict[str, dict[str, tuple[type, list[Any]]]]:
) -> dict[str, dict[str, tuple[str, list[Any]]]]:
if corpus_name is None:
corpus_names = self.list_corpuses()
else:
Expand All @@ -118,7 +118,7 @@ def list_metadata(

metadata[corpus_name] = {
key: (
{type(value) for value in values}.pop(),
{type(value).__name__ for value in values}.pop(),
sorted(values),
)
for key, values in corpus_metadata.items()
Expand Down
30 changes: 30 additions & 0 deletions tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name)
corpuses = client.get("/corpuses").raise_for_status().json()
assert corpuses == {source_storage: []}

corpuses_metadata = client.get("/corpuses/metadata").raise_for_status().json()
assert corpuses_metadata == {source_storage: {}}

assert client.get("/chats").raise_for_status().json() == [chat]
assert client.get(f"/chats/{chat['id']}").raise_for_status().json() == chat

Expand Down Expand Up @@ -105,6 +108,33 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name)
"/corpuses", params={"source_storage": "unknown_source_storage"}
).raise_for_status()

corpuses_metadata = client.get("/corpuses/metadata").raise_for_status().json()
assert corpus_name in corpuses_metadata[source_storage]
metadata_keys = corpuses_metadata[source_storage][corpus_name].keys()
assert list(metadata_keys) == ["document_id", "document_name", "path"]
for key in metadata_keys:
assert corpuses_metadata[source_storage][corpus_name][key][0] == "str"

# Query the metadata endpoint restricted to one source storage and one corpus.
# The query parameter must be the literal string "corpus_name"; using the
# variable as the key would send a parameter named after the corpus itself,
# which the API ignores, so the filtered code path would never be exercised.
corpuses_metadata = (
    client.get(
        "/corpuses/metadata",
        params={
            "source_storage": source_storage,
            "corpus_name": corpus_name,
        },
    )
    .raise_for_status()
    .json()
)
assert corpus_name in corpuses_metadata[source_storage]
metadata_keys = corpuses_metadata[source_storage][corpus_name].keys()
assert list(metadata_keys) == ["document_id", "document_name", "path"]
for key in metadata_keys:
assert corpuses_metadata[source_storage][corpus_name][key][0] == "str"

with pytest.raises(httpx.HTTPStatusError, match="422 Unprocessable Entity"):
client.get(
"/corpuses/metadata",
params={"source_storage": "unknown_source_storage"},
).raise_for_status()

prompt = "?"
if stream_answer:
with client.stream(
Expand Down
2 changes: 1 addition & 1 deletion tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def test_list_metadata(tmp_local_root, cls):
corpus_metadata[key].add(value)

expected_metadata[corpus_name] = {
key: ({type(value) for value in values}.pop(), values)
key: ({type(value).__name__ for value in values}.pop(), values)
for key, values in corpus_metadata.items()
}

Expand Down

0 comments on commit 3cfabaa

Please sign in to comment.