Skip to content

Commit

Permalink
487 corpus name as protocol (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
nenb authored Aug 14, 2024
1 parent aea6b67 commit bd2962c
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 21 deletions.
10 changes: 7 additions & 3 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,24 @@ class SourceStorage(Component, abc.ABC):
__ragna_protocol_methods__ = ["store", "retrieve"]

@abc.abstractmethod
def store(self, documents: list[Document]) -> None:
def store(self, corpus_name: Optional[str], documents: list[Document]) -> None:
"""Store content of documents.
Args:
corpus_name: Name of the corpus to store the documents in.
documents: Documents to store.
"""
...

@abc.abstractmethod
def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source]:
def retrieve(
self, corpus_name: Optional[str], metadata_filter: MetadataFilter, prompt: str
) -> list[Source]:
"""Retrieve sources for a given prompt.
Args:
documents: Documents to retrieve sources from.
corpus_name: Name of the corpus to retrieve sources from.
metadata_filter: Filter to select available sources.
prompt: Prompt to retrieve sources for.
Returns:
Expand Down
11 changes: 9 additions & 2 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def chat(
*,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
corpus_name: Optional[str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
Expand All @@ -95,13 +96,15 @@ def chat(
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
corpus_name: Corpus name to use for the source storage.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
input=input,
source_storage=source_storage,
assistant=assistant,
corpus_name=corpus_name,
**params,
)

Expand Down Expand Up @@ -144,6 +147,7 @@ class Chat:
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
corpus_name: Corpus name to use for the source storage.
**params: Additional parameters passed to the source storage and assistant.
"""

Expand All @@ -154,6 +158,7 @@ def __init__(
input: Union[MetadataFilter, None, Iterable[Union[Document, str, Path]]],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
corpus_name: Optional[str],
**params: Any,
) -> None:
self._rag = rag
Expand All @@ -165,6 +170,8 @@ def __init__(
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))

self.corpus_name = corpus_name

special_params = SpecialChatParams().model_dump()
special_params.update(params)
params = special_params
Expand All @@ -190,7 +197,7 @@ async def prepare(self) -> Message:
if self._prepared:
return welcome

await self._run(self.source_storage.store, self.documents)
await self._run(self.source_storage.store, self.corpus_name, self.documents)
self._prepared = True

self._messages.append(welcome)
Expand All @@ -215,7 +222,7 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
)

sources = await self._run(
self.source_storage.retrieve, self.metadata_filter, prompt
self.source_storage.retrieve, self.corpus_name, self.metadata_filter, prompt
)

question = Message(content=prompt, role=MessageRole.USER, sources=sources)
Expand Down
4 changes: 4 additions & 0 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def schema_to_core_chat(
input=input,
source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type]
assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type]
corpus_name=chat.metadata.corpus_name,
user=user,
chat_id=chat.id,
chat_name=chat.metadata.name,
Expand Down Expand Up @@ -289,6 +290,9 @@ async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:

welcome = schemas.Message.from_core(await core_chat.prepare())

if chat.prepared:
return welcome

chat.prepared = True
chat.messages.append(welcome)
database.update_chat(session, user=user, chat=chat)
Expand Down
2 changes: 2 additions & 0 deletions ragna/deploy/_api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None:
documents=documents,
source_storage=chat.metadata.source_storage,
assistant=chat.metadata.assistant,
corpus_name=chat.metadata.corpus_name,
params=chat.metadata.params,
prepared=chat.prepared,
)
Expand Down Expand Up @@ -180,6 +181,7 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat:
input=input,
source_storage=chat.source_storage,
assistant=chat.assistant,
corpus_name=chat.corpus_name,
params=chat.params,
),
messages=messages,
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_api/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Chat(Base):
)
source_storage = Column(types.String, nullable=False)
assistant = Column(types.String, nullable=False)
corpus_name = Column(types.String, nullable=True)
params = Column(Json, nullable=False)
messages = relationship(
"Message", cascade="all, delete", order_by="Message.timestamp"
Expand Down
3 changes: 2 additions & 1 deletion ragna/deploy/_api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import uuid
from typing import Any, Union
from typing import Any, Optional, Union

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -102,6 +102,7 @@ class ChatMetadata(BaseModel):
assistant: str
params: dict
input: Union[None, ragna.core.MetadataFilter, list[Document]]
corpus_name: Optional[str] = None


class Chat(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def _get_collection(self, corpus_name: Optional[str]) -> chromadb.Collection:

def store(
self,
corpus_name: Optional[str],
documents: list[Document],
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
Expand Down Expand Up @@ -123,10 +123,10 @@ def _translate_metadata_filter(

def retrieve(
self,
corpus_name: Optional[str],
metadata_filter: MetadataFilter,
prompt: str,
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
Expand Down
14 changes: 8 additions & 6 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def display_name(cls) -> str:
return "Ragna/DemoSourceStorage"

def __init__(self) -> None:
self._storage: list[dict[str, Any]] = []
self._storage: dict[Optional[str], list[dict[str, Any]]] = {None: []}

def store(self, documents: list[Document]) -> None:
self._storage.extend(
def store(self, corpus_name: Optional[str], documents: list[Document]) -> None:
self._storage[None].extend(
[
dict(
document_id=str(document.id),
Expand Down Expand Up @@ -66,7 +66,7 @@ def _apply_filter(
self, metadata_filter: Optional[MetadataFilter]
) -> list[tuple[int, dict[str, Any]]]:
if metadata_filter is None:
return list(enumerate(self._storage))
return list(enumerate(self._storage[None]))
elif metadata_filter.operator is MetadataOperator.RAW:
raise RagnaException
elif metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}:
Expand All @@ -93,7 +93,7 @@ def _apply_filter(
return [(idx, rows_map[idx]) for idx in sorted(idcs)]
else:
rows_with_idx = []
for idx, row in enumerate(self._storage):
for idx, row in enumerate(self._storage[None]):
value = row.get(metadata_filter.key)
if value is None:
continue
Expand All @@ -105,7 +105,9 @@ def _apply_filter(

return rows_with_idx

def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source]:
def retrieve(
self, corpus_name: Optional[str], metadata_filter: MetadataFilter, prompt: str
) -> list[Source]:
return [
Source(
id=row["__id__"],
Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def _get_table(self, corpus_name: Optional[str] = None) -> lancedb.table.Table:

def store(
self,
corpus_name: Optional[str],
documents: list[Document],
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
Expand Down Expand Up @@ -192,10 +192,10 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str:

def retrieve(
self,
corpus_name: Optional[str],
metadata_filter: Optional[MetadataFilter],
prompt: str,
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async def main(*, input, source_storage, assistant):
input=input,
source_storage=source_storage,
assistant=assistant,
corpus_name=None,
) as chat:
return await chat.answer("?")

Expand All @@ -30,7 +31,7 @@ async def main(*, input, source_storage, assistant):
if input_type == "documents":
input = [document]
else:
source_storage.store([document])
source_storage.store(None, [document])

if input_type == "corpus":
input = None
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def chat(
input=documents,
source_storage=source_storage,
assistant=assistant,
corpus_name="test-corpus",
**params,
)

Expand Down
1 change: 1 addition & 0 deletions tests/deploy/api/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_unknown_component(tmp_local_root):
"name": "test-chat",
"source_storage": "unknown_source_storage",
"assistant": "unknown_assistant",
"corpus_name": "test-corpus",
"params": {},
"input": [document],
},
Expand Down
4 changes: 3 additions & 1 deletion tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
@skip_on_windows
@pytest.mark.parametrize("multiple_answer_chunks", [True, False])
@pytest.mark.parametrize("stream_answer", [True, False])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
@pytest.mark.parametrize("corpus_name", ["test-corpus", None])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name):
config = Config(local_root=tmp_local_root, assistants=[TestAssistant])

document_root = config.local_root / "documents"
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
"name": "test-chat",
"source_storage": source_storage,
"assistant": assistant,
"corpus_name": corpus_name,
"params": {"multiple_answer_chunks": multiple_answer_chunks},
"input": [document],
}
Expand Down
57 changes: 54 additions & 3 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,67 @@ def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idc
)

source_storage = source_storage_cls()
source_storage.store(documents)
source_storage.store("test-corpus", documents)

prompt = "What is the secret number?"
num_tokens = 4096
sources = source_storage.retrieve(
metadata_filter=metadata_filter, prompt=prompt, num_tokens=num_tokens
corpus_name="test-corpus",
metadata_filter=metadata_filter,
prompt=prompt,
num_tokens=num_tokens,
)

actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs

# Should be able to call .store() multiple times
source_storage.store(documents)
source_storage.store("test-corpus", documents)


@pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB])
def test_corpus_names(tmp_local_root, source_storage_cls):
    """Check that documents stored under one corpus name stay out of another.

    Two single-document corpora are stored under different corpus names and
    then queried with the same prompt. The secret content must be returned
    for the secret corpus and must not appear in any source of the dummy
    corpus.
    """
    document_root = tmp_local_root / "documents"
    document_root.mkdir()

    def make_document(name, content):
        # Write a plain-text file below document_root and wrap it as a document.
        path = document_root / name
        path.write_text(content)
        return LocalDocument.from_path(path, handler=PlainTextDocumentHandler())

    secret_document = make_document("secret_doc", "The secret number is 42!\n")
    dummy_document = make_document("dummy_doc", "Dummy Doc!\n")

    source_storage = source_storage_cls()
    source_storage.store("test-corpus-secret", [secret_document])

    # NOTE(review): a second instance is created here and used for all
    # retrievals below, including the corpus stored through the first
    # instance. Presumably this is meant to show that corpora are shared
    # between instances -- confirm.
    source_storage = source_storage_cls()
    source_storage.store("test-corpus-dummy", [dummy_document])

    prompt = "What is the secret number?"
    num_tokens = 4096
    secret_text = "The secret number is 42"

    secret_sources = source_storage.retrieve(
        corpus_name="test-corpus-secret",
        prompt=prompt,
        metadata_filter=None,
        num_tokens=num_tokens,
    )
    # Fail with a clear assertion instead of an IndexError on empty results.
    assert secret_sources
    assert secret_text in secret_sources[0].content

    dummy_sources = source_storage.retrieve(
        corpus_name="test-corpus-dummy",
        prompt=prompt,
        metadata_filter=None,
        num_tokens=num_tokens,
    )
    assert dummy_sources
    # No source of the dummy corpus may leak content of the secret corpus.
    assert all(secret_text not in source.content for source in dummy_sources)

0 comments on commit bd2962c

Please sign in to comment.