Embeddings Model and Chunking Engine (Preliminary PR for evaluation purposes only) #354
Changes from 6 commits: ff64425, ac2adb0, 0da2472, 75f27b8, 579906d, 1f591aa, 34fd3ee
.gitignore

@@ -6,6 +6,10 @@ __pycache__/
 *.py[cod]
 *$py.class

+# Ignore my entrypoint
+test.py
+doctest/
+
 # C extensions
 *.so
ragna/core/_rag.py

@@ -25,6 +25,9 @@
 from ._document import Document, LocalDocument
 from ._utils import RagnaException, default_user, merge_models

+from ragna.source_storages._embedding_model import GenericEmbeddingModel
+from ragna.source_storages._chunking_model import GenericChunkingModel
+
 T = TypeVar("T")
 C = TypeVar("C", bound=Component)

@@ -80,6 +83,8 @@ def chat(
         documents: Iterable[Any],
         source_storage: Union[Type[SourceStorage], SourceStorage],
         assistant: Union[Type[Assistant], Assistant],
+        embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
+        chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
         **params: Any,
     ) -> Chat:
         """Create a new [ragna.core.Chat][].

@@ -89,13 +94,17 @@ def chat(
                 [ragna.core.LocalDocument.from_path][] is invoked on it.
             source_storage: Source storage to use.
             assistant: Assistant to use.
+            embedding_model: Embedding model to use.
+            chunking_model: Chunking model to use.
             **params: Additional parameters passed to the source storage and assistant.
         """
         return Chat(
             self,
             documents=documents,
             source_storage=source_storage,
             assistant=assistant,
+            embedding_model=embedding_model,
+            chunking_model=chunking_model,
             **params,
         )

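For orientation, a rough sketch of what creating a chat would look like with the two new arguments. The document path and the choice of assistant are placeholders; Chroma, MiniLML6v2, and NLTKChunkingModel are the components touched or added in this PR:

    from ragna import Rag
    from ragna.assistants import RagnaDemoAssistant  # placeholder choice of assistant
    from ragna.source_storages import Chroma
    from ragna.source_storages._chunking_model import NLTKChunkingModel
    from ragna.source_storages._embedding_model import MiniLML6v2

    chat = Rag().chat(
        documents=["example.pdf"],  # hypothetical document
        source_storage=Chroma,
        assistant=RagnaDemoAssistant,
        embedding_model=MiniLML6v2,
        chunking_model=NLTKChunkingModel,
    )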
@@ -148,10 +157,15 @@ def __init__(
         documents: Iterable[Any],
         source_storage: Union[Type[SourceStorage], SourceStorage],
         assistant: Union[Type[Assistant], Assistant],
+        embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
+        chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
         **params: Any,
     ) -> None:
         self._rag = rag

+        self.embedding_model = cast(GenericEmbeddingModel, self._rag._load_component(embedding_model))
+        self.chunking_model = cast(GenericChunkingModel, self._rag._load_component(chunking_model))
+
         self.documents = self._parse_documents(documents)
         self.source_storage = cast(
             SourceStorage, self._rag._load_component(source_storage)

@@ -188,7 +202,14 @@ async def prepare(self) -> Message:
                 detail=RagnaException.EVENT,
             )

-        await self._run(self.source_storage.store, self.documents)
+        if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
+            await self._run(self.source_storage.store, self.documents)
+        else:
+            # Here we need to generate the list of embeddings
+            chunks = self.chunking_model.chunk_documents(self.documents)
+            embeddings = self.embedding_model.embed_chunks(chunks)
+            await self._run(self.source_storage.store, embeddings)

Review comment on lines +208 to +210: The source storage should also be able to just request chunks. Meaning, we have three distinct cases and cannot group these two. However, if we split this PR as suggested above, this distinction will only come in the follow-up PR.

Reply: So something like?

    if type(self.source_storage).__ragna_input_type__ == Document:
        await self._run(self.source_storage.store, self.documents)
    else:
        chunks = self.chunking_model.chunk_documents(documents=self.documents)
        if type(self.source_storage).__ragna_input_type__ == Chunk:
            await self._run(self.source_storage.store, chunks)
        else:
            embeddings = self.embedding_model.embed_chunks(chunks)
            await self._run(self.source_storage.store, embeddings)

         self._prepared = True

         welcome = Message(

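As an aside, inspect.signature(...).parameters.values() yields inspect.Parameter objects rather than annotations, so the membership test in the diff would most likely never be true and the else branch would always run. A sketch of an annotation-based check (assuming annotations are evaluated objects rather than strings):

    import inspect

    store_params = inspect.signature(self.source_storage.store).parameters
    # true if any parameter of store() is annotated as list[Document]
    takes_documents = any(
        param.annotation == list[Document] for param in store_params.values()
    )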
@@ -218,7 +239,10 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:

         self._messages.append(Message(content=prompt, role=MessageRole.USER))

-        sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
+        if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
+            sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
+        else:
+            sources = await self._run(self.source_storage.retrieve, self.documents, self.embedding_model.embed_text(prompt))

Review comment: This hits a point that I didn't consider before: we are currently passing the documents again to the retrieve function. See the part about BC in #256 (comment) for a reason why. This will likely change when we implement #256. However, in the meantime we need to decide if we want the same "input switching" here as for store. I think this is OK, but I want to hear your thoughts.

Reply: Ooh, I only understand the logic here in hindsight. Of course we need to be able to embed the prompt. So this is correct.

         answer = Message(
             content=self._run_gen(self.assistant.answer, prompt, sources),

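Following the thread above, a minimal sketch of how the same input switching could look on the retrieve side, assuming the __ragna_input_type__ attribute proposed by the reviewer (it does not exist in ragna yet):

    if type(self.source_storage).__ragna_input_type__ == Document:
        sources = await self._run(
            self.source_storage.retrieve, self.documents, prompt
        )
    else:
        # embed the prompt so it can be compared against stored chunk embeddings
        prompt_embedding = self.embedding_model.embed_text(prompt)
        sources = await self._run(
            self.source_storage.retrieve, self.documents, prompt_embedding
        )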
ragna/source_storages/_chroma.py

@@ -8,6 +8,9 @@

 from ._vector_database import VectorDatabaseSourceStorage

+from ._embedding_model import MiniLML6v2, Embedding
+from ._chunking_model import NLTKChunkingModel
+

 class Chroma(VectorDatabaseSourceStorage):
     """[Chroma vector database](https://www.trychroma.com/)

@@ -25,6 +28,9 @@ def __init__(self) -> None:

         import chromadb

+        self._embedding_model = MiniLML6v2()
+        self._chunking_model = NLTKChunkingModel()
+
Review comment on lines +31 to +32: Why would the source storage need any of these?

Reply: It doesn't, and these have been removed.

         self._client = chromadb.Client(
             chromadb.config.Settings(
                 is_persistent=True,

@@ -33,58 +39,59 @@ def __init__(self) -> None:
             )
         )

+        self._tokens = 0
+        self._embeddings = 0
+
Review comment on lines +42 to +43: These will change with every [...]

Reply: These variables represent the average length of a chunk within the source storage. They are aggregated across calls to store. I'm not sure what scope you're referring to, but I don't think they can be local for them to work.

     def store(
         self,
-        documents: list[Document],
+        documents: list[Embedding],
         *,
         chat_id: uuid.UUID,
         chunk_size: int = 500,
         chunk_overlap: int = 250,
     ) -> None:
         collection = self._client.create_collection(
-            str(chat_id), embedding_function=self._embedding_function
+            str(chat_id)
         )

         ids = []
         texts = []
+        embeddings = []
         metadatas = []
-        for document in documents:
-            for chunk in self._chunk_pages(
-                document.extract_pages(),
-                chunk_size=chunk_size,
-                chunk_overlap=chunk_overlap,
-            ):
-                ids.append(str(uuid.uuid4()))
-                texts.append(chunk.text)
-                metadatas.append(
-                    {
-                        "document_id": str(document.id),
-                        "page_numbers": self._page_numbers_to_str(chunk.page_numbers),
-                        "num_tokens": chunk.num_tokens,
-                    }
-                )
+        for embedding in documents:
+            self._tokens += embedding.chunk.num_tokens
+            self._embeddings += 1
+
+            ids.append(str(uuid.uuid4()))
+            texts.append(embedding.chunk.text)
+            metadatas.append(
+                {
+                    "document_id": str(embedding.chunk.document_id),
+                    "page_numbers": self._page_numbers_to_str(embedding.chunk.page_numbers),
+                    "num_tokens": embedding.chunk.num_tokens,
+                }
+            )
+            embeddings.append(embedding.embedding)

         collection.add(
             ids=ids,
+            embeddings=embeddings,
             documents=texts,
             metadatas=metadatas,  # type: ignore[arg-type]
         )

     def retrieve(
         self,
         documents: list[Document],
-        prompt: str,
+        prompt: list[float],
         *,
         chat_id: uuid.UUID,
         chunk_size: int = 500,
         num_tokens: int = 1024,
     ) -> list[Source]:
         collection = self._client.get_collection(
-            str(chat_id), embedding_function=self._embedding_function
+            str(chat_id)
         )

         result = collection.query(
-            query_texts=prompt,
+            query_embeddings=prompt,
             n_results=min(
                 # We cannot retrieve source by a maximum number of tokens. Thus, we
                 # estimate how many sources we have to query. We overestimate by a

@@ -97,7 +104,7 @@ def retrieve(
                 # Instead of just querying more documents here, we should use the
                 # appropriate index parameters when creating the collection. However,
                 # they are undocumented for now.
-                max(int(num_tokens * 2 / chunk_size), 100),
+                max(int(num_tokens * 2 / self._tokens * self._embeddings), 100),
                 collection.count(),
             ),
             include=["distances", "metadatas", "documents"],

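A note on the n_results change above: because of left-to-right evaluation, num_tokens * 2 / self._tokens * self._embeddings is the same as dividing by the running average chunk length (self._tokens / self._embeddings), but it raises ZeroDivisionError if retrieve runs before anything has been stored. A sketch that spells the intent out, using only names from the diff:

    # average chunk length observed across store() calls; fall back to the
    # configured chunk_size before anything has been stored
    avg_chunk_tokens = self._tokens / self._embeddings if self._embeddings else chunk_size
    n_results = min(
        max(int(num_tokens * 2 / avg_chunk_tokens), 100),
        collection.count(),
    )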
ragna/source_storages/_chunking_model.py (new file)

@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque, Iterable, Iterator, TypeVar

import tiktoken

from ragna.core import Component, Document, Page
from ragna.source_storages._vector_database import Chunk


class GenericChunkingModel(Component, ABC):
    def __init__(self):
        # we need a way of estimating tokens that is common to all chunking models
        self._tokenizer = tiktoken.get_encoding("cl100k_base")

    @abstractmethod
    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        raise NotImplementedError

    def generate_chunks_from_pages(
        self, chunks: list[str], pages: Iterable[Page], document_id: int
    ) -> list[Chunk]:
        # NOTE: page numbers are hard-coded to [1] and the pages argument is
        # currently unused; mapping chunks back to their real pages is an open TODO
        return [
            Chunk(
                page_numbers=[1],
                text=chunk,
                document_id=document_id,
                num_tokens=len(self._tokenizer.encode(chunk)),
            )
            for chunk in chunks
        ]


class NLTKChunkingModel(GenericChunkingModel):
    def __init__(self):
        super().__init__()

        # our text splitter goes here
        from langchain.text_splitter import NLTKTextSplitter

        self.text_splitter = NLTKTextSplitter()

    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        # This is not perfect, but it's the only way I could get this to somewhat work
        chunks = []
        for document in documents:
            pages = list(document.extract_pages())
            text = "".join([page.text for page in pages])

            chunks += self.generate_chunks_from_pages(
                self.text_splitter.split_text(text), pages, document.id
            )

        return chunks


class SpacyChunkingModel(GenericChunkingModel):
    def __init__(self):
        super().__init__()

        from langchain_text_splitters import SpacyTextSplitter

        self.text_splitter = SpacyTextSplitter()

    # TODO: this needs to keep track of the pages
    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        # Problem: a chunk needs to preserve its page numbers
        chunks = []
        for document in documents:
            pages = list(document.extract_pages())
            text = "".join([page.text for page in pages])

            chunks += self.generate_chunks_from_pages(
                self.text_splitter.split_text(text), pages, document.id
            )

        return chunks


T = TypeVar("T")


class TokenChunkingModel(GenericChunkingModel):
    def chunk_documents(
        self, documents: list[Document], chunk_size: int = 512, chunk_overlap: int = 128
    ) -> list[Chunk]:
        chunks = []
        for document in documents:
            chunks += self._chunk_pages(
                document.id,
                document.extract_pages(),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        return chunks

    def _chunk_pages(
        self, document_id: int, pages: Iterable[Page], *, chunk_size: int, chunk_overlap: int
    ) -> Iterator[Chunk]:
        for window in TokenChunkingModel._windowed_ragged(
            (
                (tokens, page.number)
                for page in pages
                for tokens in self._tokenizer.encode(page.text)
            ),
            n=chunk_size,
            step=chunk_size - chunk_overlap,
        ):
            tokens, page_numbers = zip(*window)
            yield Chunk(
                document_id=document_id,
                text=self._tokenizer.decode(tokens),  # type: ignore[arg-type]
                page_numbers=list(filter(lambda n: n is not None, page_numbers))
                or None,
                num_tokens=len(tokens),
            )

    # The function is adapted from more_itertools.windowed to allow a ragged last window
    # https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed
    @classmethod
    def _windowed_ragged(
        cls, iterable: Iterable[T], *, n: int, step: int
    ) -> Iterator[tuple[T, ...]]:
        window: Deque[T] = deque(maxlen=n)
        i = n
        for _ in map(window.append, iterable):
            i -= 1
            if not i:
                i = step
                yield tuple(window)

        if len(window) < n:
            yield tuple(window)
        elif 0 < i < min(step, n):
            yield tuple(window)[i:]
ragna/source_storages/_embedding_model.py (new file)

@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod

from sentence_transformers import SentenceTransformer
import torch

Review comment on lines +3 to +4: PyTorch is a massive dependency that we cannot pull in by default. This has to be optional.

from ragna.core._components import Component
from ragna.source_storages._vector_database import Chunk

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Embedding:
    embedding: list[float]
    chunk: Chunk

    def __init__(self, embedding: list[float], chunk: Chunk):
        super().__init__()
        self.embedding = embedding
        self.chunk = chunk


class GenericEmbeddingModel(Component, ABC):
    _EMBEDDING_DIMENSIONS: int

    @abstractmethod
    def embed_chunks(self, chunks: list[Chunk]) -> list[Embedding]:
        raise NotImplementedError

    def embed_text(self, text: str) -> list[float]:
        raise NotImplementedError

    def get_embedding_dimensions(self):
        return self._EMBEDDING_DIMENSIONS

Review comment on GenericEmbeddingModel: This class as well as the [...]

Review comment on lines +30 to +31 (embed_text): Is there ever a use case to do that? Aren't we always going to embed chunks?

Reply: No, because when you prompt the database you must convert your prompt into an embedding when you query. Either we turn the prompt into a chunk, which I don't like, or we just separate the logic like this.


class MiniLML6v2(GenericEmbeddingModel):
    _EMBEDDING_DIMENSIONS = 384

    def __init__(self):
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)

    def embed_chunks(self, chunks: list[Chunk]) -> list[Embedding]:
        return [Embedding(self.embed_text(chunk.text), chunk) for chunk in chunks]

    def embed_text(self, text: str) -> list[float]:
        return self.model.encode(text).tolist()
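Regarding the reviewer's note that PyTorch must stay optional, one possible shape (an illustration only, not part of the diff) is to defer the heavy imports until an embedding model is actually instantiated:

    class MiniLML6v2(GenericEmbeddingModel):
        _EMBEDDING_DIMENSIONS = 384

        def __init__(self):
            # imported lazily so a base install does not pull in torch or
            # sentence-transformers; they become an optional extra instead
            import torch
            from sentence_transformers import SentenceTransformer

            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = SentenceTransformer(
                "sentence-transformers/all-MiniLM-L6-v2", device=device
            )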
Review comment: Flag to revert them before merge.