Embeddings Model and Chunking Engine (Preliminary PR for evaluation purposes only) #354

Closed
wants to merge 7 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -6,6 +6,10 @@ __pycache__/
*.py[cod]
*$py.class

# Ignore my entrypoint
test.py
doctest/
Comment on lines +9 to +11
Member: Flag to revert them before merge.


# C extensions
*.so

56 changes: 55 additions & 1 deletion ragna/core/_components.py
@@ -4,7 +4,18 @@
import enum
import functools
import inspect
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union
import warnings
from typing import (
AsyncIterable,
AsyncIterator,
Iterator,
Optional,
Type,
Union,
get_args,
get_origin,
get_type_hints,
)

import pydantic
import pydantic.utils
@@ -76,6 +87,12 @@ def _protocol_model(cls) -> Type[pydantic.BaseModel]:
return merge_models(cls.display_name(), *cls._protocol_models().values())


# Just for demo purposes. We need to move the actual class here.
# See https://github.com/Quansight/ragna/pull/354#discussion_r1526235318
class Embedding:
pass


class Source(pydantic.BaseModel):
"""Data class for sources stored inside a source storage.

@@ -98,6 +115,43 @@ class Source(pydantic.BaseModel):

class SourceStorage(Component, abc.ABC):
__ragna_protocol_methods__ = ["store", "retrieve"]
__ragna_input_type__: Union[Document, Embedding]

def __init_subclass__(cls):
if inspect.isabstract(cls):
return

valid_input_types = get_args(get_type_hints(cls)["__ragna_input_type__"])

input_parameter_name = list(inspect.signature(cls.store).parameters.keys())[1]
input_parameter_annotation = get_type_hints(cls.store).get(input_parameter_name)

if input_parameter_annotation is None:
input_type = None
else:

def extract_input_type():
origin = get_origin(input_parameter_annotation)
if origin is None:
return None

args = get_args(input_parameter_annotation)
if len(args) != 1:
return None

input_type = args[0]
if not issubclass(input_type, valid_input_types):
return None

return input_type

input_type = extract_input_type()

if input_type is None:
warnings.warn("ADDME")
input_type = Document

cls.__ragna_input_type__ = input_type

@abc.abstractmethod
def store(self, documents: list[Document]) -> None:
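Not from this diff, but a rough sketch of how a concrete source storage would opt into embeddings under the __init_subclass__ logic above. It assumes the real Embedding class (which, judging from the _chroma.py usage further down, pairs an embedding vector with a chunk carrying text, document_id, page_numbers, and num_tokens) eventually becomes importable from ragna.core, as the placeholder comment suggests; class and parameter names are illustrative.

    # Sketch only; assumes Embedding is moved into (and re-exported from) ragna.core.
    from ragna.core import Document, Embedding, Source, SourceStorage


    class EmbeddingStore(SourceStorage):
        # Annotating the first parameter of store() with list[Embedding] is what
        # __init_subclass__ inspects; it sets __ragna_input_type__ to Embedding.
        # An un-annotated or unrecognized signature falls back to Document with
        # a warning.
        def store(self, embeddings: list[Embedding]) -> None:
            ...

        def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
            ...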
28 changes: 26 additions & 2 deletions ragna/core/_rag.py
@@ -25,6 +25,9 @@
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

from ragna.source_storages._embedding_model import GenericEmbeddingModel
from ragna.source_storages._chunking_model import GenericChunkingModel

T = TypeVar("T")
C = TypeVar("C", bound=Component)

@@ -80,6 +83,8 @@ def chat(
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
@@ -89,13 +94,17 @@
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
embedding_model: Embedding model to use.
chunking_model: Chunking model to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
embedding_model=embedding_model,
chunking_model=chunking_model,
**params,
)
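Not from this diff, but a rough usage sketch of the extended chat() signature, using the demo assistant and a placeholder document path; the source storage, embedding model, and chunking model classes are the ones exported from ragna.source_storages in this PR.

    from ragna import Rag
    from ragna.assistants import RagnaDemoAssistant
    from ragna.source_storages import Chroma, MiniLML6v2, NLTKChunkingModel

    # The new embedding_model / chunking_model arguments are passed alongside
    # the existing source_storage and assistant; prepare() and answer() are
    # then used as before.
    chat = Rag().chat(
        documents=["document.txt"],
        source_storage=Chroma,
        assistant=RagnaDemoAssistant,
        embedding_model=MiniLML6v2,
        chunking_model=NLTKChunkingModel,
    )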

@@ -148,10 +157,15 @@ def __init__(
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
**params: Any,
) -> None:
self._rag = rag

self.embedding_model = cast(GenericEmbeddingModel, self._rag._load_component(embedding_model))
self.chunking_model = cast(GenericChunkingModel, self._rag._load_component(chunking_model))

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
@@ -188,7 +202,14 @@ async def prepare(self) -> Message:
detail=RagnaException.EVENT,
)

await self._run(self.source_storage.store, self.documents)
if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
Member:
1. We should have this kind of logic as a class attribute on the SourceStorage itself. Otherwise, how are we going to communicate this to the REST API / web UI? This needs to be known, because it makes no sense to force the user to select an embedding model when it is unused by the backend.
2. This check needs to be more strict. We should only check the first argument rather than the whole signature.
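As a sketch of the first point (not code from this PR): with __ragna_input_type__ available as a class attribute, as the _components.py change above provides, the REST API / web UI could decide up front whether an embedding model needs to be selected at all. The helper name and the import path are hypothetical.

    from ragna.core import Embedding, SourceStorage  # assumes Embedding is re-exported


    def needs_embedding_model(source_storage_cls: type[SourceStorage]) -> bool:
        # Only storages whose store() takes embeddings need an embedding model;
        # storages that accept raw documents handle (or skip) embedding themselves.
        return source_storage_cls.__ragna_input_type__ is Embedding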

await self._run(self.source_storage.store, self.documents)
else:
# Here we need to generate the list of embeddings
chunks = self.chunking_model.chunk_documents(self.documents)
embeddings = self.embedding_model.embed_chunks(chunks)
Comment on lines +208 to +210
Member: The source storage should also be able to just request chunks. Meaning, we have three distinct cases and cannot group these two. However, if we split this PR as suggested above, this distinction will only come in the follow-up PR.

Contributor Author: So something like?

        if type(self.source_storage).__ragna_input_type__ == Document:
            await self._run(self.source_storage.store, self.documents)
        else:
            chunks = self.chunking_model.chunk_documents(documents=self.documents)
            if type(self.source_storage).__ragna_input_type__ == Chunk:
                await self._run(self.source_storage.store, chunks)
            else:
                embeddings = self.embedding_model.embed_chunks(chunks)
                await self._run(self.source_storage.store, embeddings)

await self._run(self.source_storage.store, embeddings)

self._prepared = True

welcome = Message(
@@ -218,7 +239,10 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:

self._messages.append(Message(content=prompt, role=MessageRole.USER))

sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
Member: This hits a point that I didn't consider before: we are currently passing the documents again to the retrieve function. See the part about BC in #256 (comment) for a reason why. This will likely change when we implement #256. However, in the meantime we need to decide if we want the same "input switching" here as for store. I think this is ok, but want to hear your thoughts.

Member: Ooh, I only understand the logic here in hindsight. Of course we need to be able to embed the prompt. So this is correct.

sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
else:
sources = await self._run(self.source_storage.retrieve, self.documents, self.embedding_model.embed_text(prompt))

answer = Message(
content=self._run_gen(self.assistant.answer, prompt, sources),
6 changes: 6 additions & 0 deletions ragna/source_storages/__init__.py
@@ -2,11 +2,17 @@
"Chroma",
"LanceDB",
"RagnaDemoSourceStorage",
"MiniLML6v2",
"NLTKChunkingModel",
"SpacyChunkingModel",
"TokenChunkingModel"
]

from ._chroma import Chroma
from ._demo import RagnaDemoSourceStorage
from ._lancedb import LanceDB
from ._embedding_model import MiniLML6v2
from ._chunking_model import NLTKChunkingModel, SpacyChunkingModel, TokenChunkingModel

# isort: split

55 changes: 31 additions & 24 deletions ragna/source_storages/_chroma.py
@@ -8,6 +8,9 @@

from ._vector_database import VectorDatabaseSourceStorage

from ._embedding_model import MiniLML6v2, Embedding
from ._chunking_model import NLTKChunkingModel


class Chroma(VectorDatabaseSourceStorage):
"""[Chroma vector database](https://www.trychroma.com/)
@@ -25,6 +28,9 @@ def __init__(self) -> None:

import chromadb

self._embedding_model = MiniLML6v2()
self._chunking_model = NLTKChunkingModel()
Comment on lines +31 to +32
Member: Why would the source storage need any of these?

Contributor Author: It doesn't, and these have been removed.


self._client = chromadb.Client(
chromadb.config.Settings(
is_persistent=True,
@@ -33,58 +39,59 @@
)
)

self._tokens = 0
self._embeddings = 0
Comment on lines +42 to +43
Member: These will change with every .store call. Why are they instance attributes rather than local variables?

Contributor Author: These variables represent the average length of a chunk within the source storage; it is aggregated across calls to store. I'm not sure what scope you're referring to, but I don't think they can be local for them to work.
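In other words (a sketch of the arithmetic, not code from the PR), the two counters give a running average chunk size, which stands in for the fixed chunk_size previously used to size the query in retrieve():

    # Average tokens per stored chunk, aggregated across store() calls.
    avg_chunk_size = self._tokens / self._embeddings

    # retrieve() keeps the same factor-of-two overestimate as before:
    #     max(int(num_tokens * 2 / chunk_size), 100)
    # becomes
    #     max(int(num_tokens * 2 / self._tokens * self._embeddings), 100)
    # which is equivalent to
    n_results = max(int(num_tokens * 2 / avg_chunk_size), 100)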


def store(
self,
documents: list[Document],
documents: list[Embedding],
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
collection = self._client.create_collection(
str(chat_id), embedding_function=self._embedding_function
str(chat_id)
)

ids = []
texts = []
embeddings = []
metadatas = []
for document in documents:
for chunk in self._chunk_pages(
document.extract_pages(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
):
ids.append(str(uuid.uuid4()))
texts.append(chunk.text)
metadatas.append(
{
"document_id": str(document.id),
"page_numbers": self._page_numbers_to_str(chunk.page_numbers),
"num_tokens": chunk.num_tokens,
}
)
for embedding in documents:
self._tokens += embedding.chunk.num_tokens
self._embeddings += 1

ids.append(str(uuid.uuid4()))
texts.append(embedding.chunk.text)
metadatas.append(
{
"document_id": str(embedding.chunk.document_id),
"page_numbers": self._page_numbers_to_str(embedding.chunk.page_numbers),
"num_tokens": embedding.chunk.num_tokens,
}
)
embeddings.append(embedding.embedding)

collection.add(
ids=ids,
embeddings=embeddings,
documents=texts,
metadatas=metadatas, # type: ignore[arg-type]
)

def retrieve(
self,
documents: list[Document],
prompt: str,
prompt: list[float],
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
collection = self._client.get_collection(
str(chat_id), embedding_function=self._embedding_function
str(chat_id)
)

result = collection.query(
query_texts=prompt,
query_embeddings=prompt,
n_results=min(
# We cannot retrieve source by a maximum number of tokens. Thus, we
# estimate how many sources we have to query. We overestimate by a
@@ -97,7 +104,7 @@ def retrieve(
# Instead of just querying more documents here, we should use the
# appropriate index parameters when creating the collection. However,
# they are undocumented for now.
max(int(num_tokens * 2 / chunk_size), 100),
max(int(num_tokens * 2 / self._tokens * self._embeddings), 100),
collection.count(),
),
include=["distances", "metadatas", "documents"],