Embeddings Model and Chunking Engine (Preliminary PR for evaluation purposes only) #354
Changes from all commits: ff64425, ac2adb0, 0da2472, 75f27b8, 579906d, 1f591aa, 34fd3ee
@@ -6,6 +6,10 @@ __pycache__/
 *.py[cod]
 *$py.class

+# Ignore my entrypoint
+test.py
+doctest/
+
 # C extensions
 *.so
@@ -25,6 +25,9 @@
 from ._document import Document, LocalDocument
 from ._utils import RagnaException, default_user, merge_models

+from ragna.source_storages._embedding_model import GenericEmbeddingModel
+from ragna.source_storages._chunking_model import GenericChunkingModel
+
 T = TypeVar("T")
 C = TypeVar("C", bound=Component)
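The two interfaces imported here are defined elsewhere in the PR and are not part of this diff. Judging only from how they are called further down (chunk_documents, embed_chunks, embed_text), a rough sketch of their shape might look like the following; the method signatures and the Chunk/Embedding placeholder types are assumptions, not the PR's actual code. Chunk and Embedding fields are sketched separately after the Chroma store hunk below.

# Hedged sketch, inferred from the call sites in this diff; not the PR's real definitions.
import abc


class Chunk:  # placeholder; assumed fields are sketched after the Chroma store hunk below
    ...


class Embedding:  # placeholder; assumed fields are sketched after the Chroma store hunk below
    ...


class GenericChunkingModel(abc.ABC):
    @abc.abstractmethod
    def chunk_documents(self, documents: list) -> list[Chunk]:
        """Split documents into chunks suitable for embedding."""


class GenericEmbeddingModel(abc.ABC):
    @abc.abstractmethod
    def embed_chunks(self, chunks: list[Chunk]) -> list[Embedding]:
        """Embed prepared chunks."""

    @abc.abstractmethod
    def embed_text(self, text: str) -> list[float]:
        """Embed free-form text, e.g. the user prompt in Chat.answer()."""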
@@ -80,6 +83,8 @@ def chat(
         documents: Iterable[Any],
         source_storage: Union[Type[SourceStorage], SourceStorage],
         assistant: Union[Type[Assistant], Assistant],
+        embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
+        chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
         **params: Any,
     ) -> Chat:
         """Create a new [ragna.core.Chat][].

@@ -89,13 +94,17 @@ def chat(
                 [ragna.core.LocalDocument.from_path][] is invoked on it.
             source_storage: Source storage to use.
             assistant: Assistant to use.
+            embedding_model: Embedding model to use.
+            chunking_model: Chunking model to use.
             **params: Additional parameters passed to the source storage and assistant.
         """
         return Chat(
             self,
             documents=documents,
             source_storage=source_storage,
             assistant=assistant,
+            embedding_model=embedding_model,
+            chunking_model=chunking_model,
             **params,
         )
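For orientation, constructing a chat with the two new parameters might look like the sketch below. The import paths for MiniLML6v2 and NLTKChunkingModel are taken from the Chroma diff further down; the document path and the choice of assistant are placeholders, not part of this PR.

# Sketch of the extended Rag.chat() call, assuming the components added in this PR.
from ragna import Rag
from ragna.assistants import RagnaDemoAssistant
from ragna.source_storages import Chroma
from ragna.source_storages._chunking_model import NLTKChunkingModel
from ragna.source_storages._embedding_model import MiniLML6v2

chat = Rag().chat(
    documents=["ragna.txt"],           # plain paths are wrapped via LocalDocument.from_path
    source_storage=Chroma,             # classes and instances are both accepted
    assistant=RagnaDemoAssistant,
    embedding_model=MiniLML6v2,        # new parameter from this PR
    chunking_model=NLTKChunkingModel,  # new parameter from this PR
)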
@@ -148,10 +157,15 @@ def __init__(
         documents: Iterable[Any],
         source_storage: Union[Type[SourceStorage], SourceStorage],
         assistant: Union[Type[Assistant], Assistant],
+        embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
+        chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
         **params: Any,
     ) -> None:
         self._rag = rag

+        self.embedding_model = cast(GenericEmbeddingModel, self._rag._load_component(embedding_model))
+        self.chunking_model = cast(GenericChunkingModel, self._rag._load_component(chunking_model))
+
         self.documents = self._parse_documents(documents)
         self.source_storage = cast(
             SourceStorage, self._rag._load_component(source_storage)
@@ -188,7 +202,14 @@ async def prepare(self) -> Message:
                 detail=RagnaException.EVENT,
             )

-        await self._run(self.source_storage.store, self.documents)
+        if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
+            await self._run(self.source_storage.store, self.documents)
+        else:
+            # Here we need to generate the list of embeddings
+            chunks = self.chunking_model.chunk_documents(self.documents)
+            embeddings = self.embedding_model.embed_chunks(chunks)
+            await self._run(self.source_storage.store, embeddings)

         self._prepared = True

         welcome = Message(

Comment on lines +208 to +210: The source storage should also be able to just request chunks. Meaning, we have three distinct cases and cannot group these two. However, if we split this PR as suggested above, this distinction will only come in the follow-up PR.

Reply: So something like?

if type(self.source_storage).__ragna_input_type__ == Document:
    await self._run(self.source_storage.store, self.documents)
else:
    chunks = self.chunking_model.chunk_documents(documents=self.documents)
    if type(self.source_storage).__ragna_input_type__ == Chunk:
        await self._run(self.source_storage.store, chunks)
    else:
        embeddings = self.embedding_model.embed_chunks(chunks)
        await self._run(self.source_storage.store, embeddings)
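As an aside on the signature check used above: inspect.signature(...).parameters.values() yields inspect.Parameter objects rather than annotations, so a membership test against list[Document] will not match them directly. If the dispatch stays signature-based instead of using a __ragna_input_type__ attribute, reading the annotation explicitly could look like the sketch below; the helper name and the first-positional-parameter assumption are mine, not the PR's.

# Hypothetical helper, not part of the PR: determine what SourceStorage.store
# expects by reading the annotation of its first positional parameter.
import inspect


def store_input_annotation(store_fn) -> object:
    positional = [
        p
        for p in inspect.signature(store_fn).parameters.values()
        if p.name != "self" and p.kind is not inspect.Parameter.KEYWORD_ONLY
    ]
    # Fall back to "empty" if the signature has no positional parameters.
    return positional[0].annotation if positional else inspect.Parameter.empty

With that, the prepare() branch above would compare store_input_annotation(self.source_storage.store) against list[Document], or against Chunk/Embedding for the three-way split the review suggests.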
@@ -218,7 +239,10 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:

         self._messages.append(Message(content=prompt, role=MessageRole.USER))

-        sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
+        if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
+            sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
+        else:
+            sources = await self._run(self.source_storage.retrieve, self.documents, self.embedding_model.embed_text(prompt))

         answer = Message(
             content=self._run_gen(self.assistant.answer, prompt, sources),

Comment: This hits a point that I didn't consider before: we are currently passing the documents again to the retrieve function. See the part about BC in #256 (comment) for a reason why. This will likely change when we implement #256. However, in the meantime we need to decide if we want the same "input switching" here as for store. I think this is ok, but want to hear your thoughts.

Reply: Ooh, I only in hindsight understand the logic here. Of course we need to be able to embed the prompt. So this is correct.
@@ -8,6 +8,9 @@

 from ._vector_database import VectorDatabaseSourceStorage

+from ._embedding_model import MiniLML6v2, Embedding
+from ._chunking_model import NLTKChunkingModel
+

 class Chroma(VectorDatabaseSourceStorage):
     """[Chroma vector database](https://www.trychroma.com/)
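Neither MiniLML6v2 nor NLTKChunkingModel is defined in this diff. The class name suggests the all-MiniLM-L6-v2 sentence-transformers model; purely as an assumption about what such an embedding model could look like (not the PR's implementation), a minimal sketch:

# Assumed sketch only; the PR's MiniLML6v2 implementation is not shown in this diff.
from sentence_transformers import SentenceTransformer


class MiniLML6v2Sketch:
    def __init__(self) -> None:
        self._model = SentenceTransformer("all-MiniLM-L6-v2")

    def embed_text(self, text: str) -> list[float]:
        # encode() returns a numpy array; retrieve() below expects a plain list of floats
        return self._model.encode(text).tolist()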
@@ -25,6 +28,9 @@ def __init__(self) -> None:

         import chromadb

+        self._embedding_model = MiniLML6v2()
+        self._chunking_model = NLTKChunkingModel()
+
         self._client = chromadb.Client(
             chromadb.config.Settings(
                 is_persistent=True,

Comment on lines +31 to +32: Why would the source storage need any of these?

Reply: It doesn't, and these have been removed.
@@ -33,58 +39,59 @@ def __init__(self) -> None:
             )
         )

+        self._tokens = 0
+        self._embeddings = 0
+

Comment on lines +42 to +43: These will change with every …

Reply: These variables represent the average length of a chunk within the source storage; it is aggregated across calls to store. I'm not sure what scope you're referring to, but I don't think they can be local for them to work.
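To make the reply concrete: the counters only yield a meaningful average if they accumulate across store calls, which is why they are instance attributes rather than locals. A toy illustration with invented numbers, not taken from the PR:

# Invented numbers: two store() calls contribute chunks of different sizes;
# the totals persist on the instance so the average reflects both calls.
tokens, embeddings = 0, 0

for chunk_sizes in ([400] * 50, [600] * 50):  # two hypothetical store() calls
    for num_tokens in chunk_sizes:
        tokens += num_tokens
        embeddings += 1

average_chunk_tokens = tokens / embeddings    # (400*50 + 600*50) / 100 == 500.0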
     def store(
         self,
-        documents: list[Document],
+        documents: list[Embedding],
         *,
         chat_id: uuid.UUID,
         chunk_size: int = 500,
         chunk_overlap: int = 250,
     ) -> None:
         collection = self._client.create_collection(
-            str(chat_id), embedding_function=self._embedding_function
+            str(chat_id)
         )

         ids = []
         texts = []
+        embeddings = []
         metadatas = []
-        for document in documents:
-            for chunk in self._chunk_pages(
-                document.extract_pages(),
-                chunk_size=chunk_size,
-                chunk_overlap=chunk_overlap,
-            ):
-                ids.append(str(uuid.uuid4()))
-                texts.append(chunk.text)
-                metadatas.append(
-                    {
-                        "document_id": str(document.id),
-                        "page_numbers": self._page_numbers_to_str(chunk.page_numbers),
-                        "num_tokens": chunk.num_tokens,
-                    }
-                )
+        for embedding in documents:
+            self._tokens += embedding.chunk.num_tokens
+            self._embeddings += 1
+
+            ids.append(str(uuid.uuid4()))
+            texts.append(embedding.chunk.text)
+            metadatas.append(
+                {
+                    "document_id": str(embedding.chunk.document_id),
+                    "page_numbers": self._page_numbers_to_str(embedding.chunk.page_numbers),
+                    "num_tokens": embedding.chunk.num_tokens,
+                }
+            )
+            embeddings.append(embedding.embedding)

         collection.add(
             ids=ids,
+            embeddings=embeddings,
             documents=texts,
             metadatas=metadatas,  # type: ignore[arg-type]
         )
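The new loop reads embedding.embedding plus several fields from embedding.chunk. The actual Embedding and Chunk definitions are not included in this diff (Embedding is imported from ._embedding_model above); inferred purely from these attribute accesses, they could plausibly look like the sketch below, with field types assumed.

# Inferred sketch only, not the PR's real definitions.
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
    text: str
    document_id: uuid.UUID
    page_numbers: Optional[list[int]]
    num_tokens: int


@dataclass
class Embedding:
    chunk: Chunk
    embedding: list[float]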
     def retrieve(
         self,
         documents: list[Document],
-        prompt: str,
+        prompt: list[float],
         *,
         chat_id: uuid.UUID,
         chunk_size: int = 500,
         num_tokens: int = 1024,
     ) -> list[Source]:
         collection = self._client.get_collection(
-            str(chat_id), embedding_function=self._embedding_function
+            str(chat_id)
         )

         result = collection.query(
-            query_texts=prompt,
+            query_embeddings=prompt,
             n_results=min(
                 # We cannot retrieve source by a maximum number of tokens. Thus, we
                 # estimate how many sources we have to query. We overestimate by a

@@ -97,7 +104,7 @@ def retrieve(
                 # Instead of just querying more documents here, we should use the
                 # appropriate index parameters when creating the collection. However,
                 # they are undocumented for now.
-                max(int(num_tokens * 2 / chunk_size), 100),
+                max(int(num_tokens * 2 / self._tokens * self._embeddings), 100),
                 collection.count(),
             ),
             include=["distances", "metadatas", "documents"],
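The changed estimate keeps the old over-fetching heuristic but swaps the fixed chunk_size for the running average tracked in store. Since the expression evaluates left to right, num_tokens * 2 / self._tokens * self._embeddings is the same as dividing num_tokens * 2 by the average chunk length. A worked example with invented numbers:

# Invented values: 200 stored chunks totalling 100_000 tokens.
tokens, embeddings = 100_000, 200
num_tokens = 1024

average_chunk_tokens = tokens / embeddings             # 500.0
estimate = int(num_tokens * 2 / tokens * embeddings)   # int(4.096) == 4
assert estimate == int(num_tokens * 2 / average_chunk_tokens)

n_results = max(estimate, 100)                         # still query at least 100 chunks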
Comment: Flag to revert them before merge.