Embeddings Model and Chunking Engine (Preliminary PR for evaluation purposes only) #354

Closed
wants to merge 7 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -6,6 +6,10 @@ __pycache__/
*.py[cod]
*$py.class

# Ignore my entrypoint
test.py
doctest/
Comment on lines +9 to +11
Member

Flag to revert them before merge.


# C extensions
*.so

28 changes: 26 additions & 2 deletions ragna/core/_rag.py
@@ -25,6 +25,9 @@
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

from ragna.source_storages._embedding_model import GenericEmbeddingModel
from ragna.source_storages._chunking_model import GenericChunkingModel

T = TypeVar("T")
C = TypeVar("C", bound=Component)

@@ -80,6 +83,8 @@ def chat(
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
@@ -89,13 +94,17 @@
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
embedding_model: Embedding model to use.
chunking_model: Chunking model to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
embedding_model=embedding_model,
chunking_model=chunking_model,
**params,
)
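
For orientation, a usage sketch of the extended signature. The component names come from this PR; the assistant and the import paths are assumed from the existing ragna API and may differ.

# Hypothetical usage sketch; MiniLML6v2 and NLTKChunkingModel are introduced in this PR.
from ragna import Rag
from ragna.assistants import RagnaDemoAssistant
from ragna.source_storages import Chroma, MiniLML6v2, NLTKChunkingModel

chat = Rag().chat(
    documents=["ragna.txt"],  # paths are parsed via LocalDocument.from_path
    source_storage=Chroma,
    assistant=RagnaDemoAssistant,
    embedding_model=MiniLML6v2,
    chunking_model=NLTKChunkingModel,
)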

@@ -148,10 +157,15 @@ def __init__(
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Union[Type[GenericEmbeddingModel], GenericEmbeddingModel],
chunking_model: Union[Type[GenericChunkingModel], GenericChunkingModel],
**params: Any,
) -> None:
self._rag = rag

self.embedding_model = cast(GenericEmbeddingModel, self._rag._load_component(embedding_model))
self.chunking_model = cast(GenericChunkingModel, self._rag._load_component(chunking_model))

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
@@ -188,7 +202,14 @@ async def prepare(self) -> Message:
detail=RagnaException.EVENT,
)

await self._run(self.source_storage.store, self.documents)
if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
Member

  1. We should have this kind of logic as a class attribute on the SourceStorage itself. Otherwise, how are we going to communicate this to the REST API / web UI? It needs to be known, because it makes no sense to force the user to select an embedding model when the backend doesn't use it.
  2. This check needs to be more strict. We should only check the first argument rather than the whole signature.

await self._run(self.source_storage.store, self.documents)
else:
# Here we need to generate the list of embeddings
chunks = self.chunking_model.chunk_documents(self.documents)
embeddings = self.embedding_model.embed_chunks(chunks)
Comment on lines +208 to +210
Member

The source storage should also be able to request just chunks. That means we have three distinct cases and cannot group these two. However, if we split this PR as suggested above, this distinction will only come into play in the follow-up PR.

Contributor Author

So something like?

        if type(self.source_storage).__ragna_input_type__ == Document:
            await self._run(self.source_storage.store, self.documents)
        else:
            chunks = self.chunking_model.chunk_documents(documents=self.documents)
            if type(self.source_storage).__ragna_input_type__ == Chunk:
                await self._run(self.source_storage.store, chunks)
            else:
                embeddings = self.embedding_model.embed_chunks(chunks)
                await self._run(self.source_storage.store, embeddings)
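
For illustration, a minimal sketch of how such an attribute could be declared on the source storage side, assuming the dunder name from the snippet above (none of this exists in the code base yet):

# Sketch only: advertise the expected store()/retrieve() input as a class
# attribute, so the REST API / web UI can tell up front whether an embedding
# model is needed for a given source storage.
from ragna.core import Document
from ragna.source_storages._embedding_model import Embedding


class EmbeddingBackedStorage:  # hypothetical
    __ragna_input_type__ = Embedding


class DocumentBackedStorage:  # hypothetical
    __ragna_input_type__ = Document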

await self._run(self.source_storage.store, embeddings)

self._prepared = True

welcome = Message(
@@ -218,7 +239,10 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:

self._messages.append(Message(content=prompt, role=MessageRole.USER))

sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
if list[Document] in inspect.signature(self.source_storage.store).parameters.values():
Member

This hits a point that I didn't consider before: we are currently passing the documents again to the retrieve function. See the part about BC in #256 (comment) for a reason why. This will likely change when we implement #256. However, in the meantime we need to decide if we want the same "input switching" here as for store. I think this is ok, but I want to hear your thoughts.

Member

Ooh, I only understood the logic here in hindsight. Of course we need to be able to embed the prompt. So this is correct.

sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
else:
sources = await self._run(self.source_storage.retrieve, self.documents, self.embedding_model.embed_text(prompt))

answer = Message(
content=self._run_gen(self.assistant.answer, prompt, sources),
6 changes: 6 additions & 0 deletions ragna/source_storages/__init__.py
@@ -2,11 +2,17 @@
"Chroma",
"LanceDB",
"RagnaDemoSourceStorage",
"MiniLML6v2",
"NLTKChunkingModel",
"SpacyChunkingModel",
"TokenChunkingModel"
]

from ._chroma import Chroma
from ._demo import RagnaDemoSourceStorage
from ._lancedb import LanceDB
from ._embedding_model import MiniLML6v2
from ._chunking_model import NLTKChunkingModel, SpacyChunkingModel, TokenChunkingModel

# isort: split

55 changes: 31 additions & 24 deletions ragna/source_storages/_chroma.py
@@ -8,6 +8,9 @@

from ._vector_database import VectorDatabaseSourceStorage

from ._embedding_model import MiniLML6v2, Embedding
from ._chunking_model import NLTKChunkingModel


class Chroma(VectorDatabaseSourceStorage):
"""[Chroma vector database](https://www.trychroma.com/)
@@ -25,6 +28,9 @@ def __init__(self) -> None:

import chromadb

self._embedding_model = MiniLML6v2()
self._chunking_model = NLTKChunkingModel()
Comment on lines +31 to +32
Member

Why would the source storage need any of these?

Contributor Author

It doesn't, and these have been removed.


self._client = chromadb.Client(
chromadb.config.Settings(
is_persistent=True,
@@ -33,58 +39,59 @@
)
)

self._tokens = 0
self._embeddings = 0
Comment on lines +42 to +43
Member

These will change with every .store call. Why are they instance attributes rather than local variables?

Contributor Author

These variables represent the average length of a chunk within the source storage; it is aggregated across calls to store. I'm not sure what scope you're referring to, but I don't think they can be local and still work.
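
Concretely, the two counters form a running average of tokens per stored chunk, aggregated across store calls, and retrieve then uses that average in its result-count estimate. A small sketch of the arithmetic with made-up numbers:

# Illustrative numbers only; mirrors the counters kept on the Chroma instance.
tokens = 0       # total tokens over all stored chunks (self._tokens)
embeddings = 0   # number of stored chunks (self._embeddings)

for chunk_tokens in (480, 510, 495):  # chunks seen across store() calls
    tokens += chunk_tokens
    embeddings += 1

num_tokens = 1024  # retrieval token budget passed to retrieve()
# Same estimate as in retrieve() below: overestimate by a factor of two and
# divide by the average chunk length (tokens / embeddings).
n_results = max(int(num_tokens * 2 / tokens * embeddings), 100)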


def store(
self,
documents: list[Document],
documents: list[Embedding],
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
collection = self._client.create_collection(
str(chat_id), embedding_function=self._embedding_function
str(chat_id)
)

ids = []
texts = []
embeddings = []
metadatas = []
for document in documents:
for chunk in self._chunk_pages(
document.extract_pages(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
):
ids.append(str(uuid.uuid4()))
texts.append(chunk.text)
metadatas.append(
{
"document_id": str(document.id),
"page_numbers": self._page_numbers_to_str(chunk.page_numbers),
"num_tokens": chunk.num_tokens,
}
)
for embedding in documents:
self._tokens += embedding.chunk.num_tokens
self._embeddings += 1

ids.append(str(uuid.uuid4()))
texts.append(embedding.chunk.text)
metadatas.append(
{
"document_id": str(embedding.chunk.document_id),
"page_numbers": self._page_numbers_to_str(embedding.chunk.page_numbers),
"num_tokens": embedding.chunk.num_tokens,
}
)
embeddings.append(embedding.embedding)

collection.add(
ids=ids,
embeddings=embeddings,
documents=texts,
metadatas=metadatas, # type: ignore[arg-type]
)

def retrieve(
self,
documents: list[Document],
prompt: str,
prompt: list[float],
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
collection = self._client.get_collection(
str(chat_id), embedding_function=self._embedding_function
str(chat_id)
)

result = collection.query(
query_texts=prompt,
query_embeddings=prompt,
n_results=min(
# We cannot retrieve source by a maximum number of tokens. Thus, we
# estimate how many sources we have to query. We overestimate by a
@@ -97,7 +104,7 @@ def retrieve(
# Instead of just querying more documents here, we should use the
# appropriate index parameters when creating the collection. However,
# they are undocumented for now.
max(int(num_tokens * 2 / chunk_size), 100),
max(int(num_tokens * 2 / self._tokens * self._embeddings), 100),
collection.count(),
),
include=["distances", "metadatas", "documents"],
119 changes: 119 additions & 0 deletions ragna/source_storages/_chunking_model.py
@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque, Iterable, Iterator, TypeVar

import tiktoken

from ragna.core import Component, Document, Page
from ragna.source_storages._vector_database import Chunk


class GenericChunkingModel(Component, ABC):
def __init__(self):
# we need a way of estimating tokens that is common to all chunking models
self._tokenizer = tiktoken.get_encoding("cl100k_base")

@abstractmethod
def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
raise NotImplementedError

def generate_chunks_from_pages(self, chunks: list[str], pages: Iterable[Page], document_id: int) -> list[Chunk]:

return [Chunk(page_numbers=[1], text=chunks[i], document_id=document_id,
num_tokens=len(self._tokenizer.encode(chunks[i]))) for i in range(len(chunks))]


class NLTKChunkingModel(GenericChunkingModel):
def __init__(self):
super().__init__()

# our text splitter goes here
from langchain.text_splitter import NLTKTextSplitter
self.text_splitter = NLTKTextSplitter()

def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
# This is not perfect, but it's the only way I could get this to somewhat work
chunks = []
for document in documents:
pages = list(document.extract_pages())
text = "".join([page.text for page in pages])

chunks += self.generate_chunks_from_pages(self.text_splitter.split_text(text), pages, document.id)

return chunks


class SpacyChunkingModel(GenericChunkingModel):
def __init__(self):
super().__init__()

from langchain_text_splitters import SpacyTextSplitter
self.text_splitter = SpacyTextSplitter()

# TODO: This needs to keep track of the pages
def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
# Problem: chunks need to preserve their page numbers
chunks = []
for document in documents:
pages = list(document.extract_pages())
text = "".join([page.text for page in pages])

chunks += self.generate_chunks_from_pages(self.text_splitter.split_text(text), pages, document.id)

return chunks


T = TypeVar("T")


class TokenChunkingModel(GenericChunkingModel):
def chunk_documents(self, documents: list[Document], chunk_size: int = 512, chunk_overlap: int = 128) -> list[Chunk]:
chunks = []
for document in documents:
chunks += self._chunk_pages(document.id, document.extract_pages(), chunk_size=chunk_size, chunk_overlap=chunk_overlap)
return chunks

def _chunk_pages(
self, document_id: int, pages: Iterable[Page], *, chunk_size: int, chunk_overlap: int
) -> Iterator[Chunk]:
for window in TokenChunkingModel._windowed_ragged(
(
(tokens, page.number)
for page in pages
for tokens in self._tokenizer.encode(page.text)
),
n=chunk_size,
step=chunk_size - chunk_overlap,
):
tokens, page_numbers = zip(*window)
yield Chunk(
document_id=document_id,
text=self._tokenizer.decode(tokens), # type: ignore[arg-type]
page_numbers=list(filter(lambda n: n is not None, page_numbers))
or None,
num_tokens=len(tokens),
)

# The function is adapted from more_itertools.windowed to allow a ragged last window
# https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed
@classmethod
def _windowed_ragged(
cls, iterable: Iterable[T], *, n: int, step: int
) -> Iterator[tuple[T, ...]]:
window: Deque[T] = deque(maxlen=n)
i = n
for _ in map(window.append, iterable):
i -= 1
if not i:
i = step
yield tuple(window)

if len(window) < n:
yield tuple(window)
elif 0 < i < min(step, n):
yield tuple(window)[i:]
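
To make the windowing concrete, a quick sketch that calls the private helper directly (for illustration only):

# Overlapping token windows with a ragged final window.
from ragna.source_storages import TokenChunkingModel

windows = list(TokenChunkingModel._windowed_ragged(range(9), n=4, step=2))
# -> [(0, 1, 2, 3), (2, 3, 4, 5), (4, 5, 6, 7), (6, 7, 8)]
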
47 changes: 47 additions & 0 deletions ragna/source_storages/_embedding_model.py
@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod

from sentence_transformers import SentenceTransformer
import torch
Comment on lines +3 to +4
Member

PyTorch is a massive dependency that we cannot pull in by default. This has to be optional.
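
One way to keep it optional is to defer the heavy imports to instantiation time; a minimal sketch, not what the PR currently does, with error type and wording purely illustrative:

from ragna.source_storages._embedding_model import GenericEmbeddingModel


class MiniLML6v2(GenericEmbeddingModel):
    _EMBEDDING_DIMENSIONS = 384

    def __init__(self):
        # Lazy, optional imports so that `import ragna` works without
        # torch / sentence-transformers installed.
        try:
            import torch
            from sentence_transformers import SentenceTransformer
        except ImportError as error:
            raise RuntimeError(
                "MiniLML6v2 requires the optional dependencies "
                "'sentence-transformers' and 'torch'."
            ) from error
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2", device=device
        )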


from ragna.core._components import Component

from ragna.source_storages._vector_database import Chunk

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Embedding:
embedding: list[float]
chunk: Chunk

def __init__(self, embedding: list[float], chunk: Chunk):
super().__init__()
self.embedding = embedding
self.chunk = chunk


class GenericEmbeddingModel(Component, ABC):
Member

This class, as well as the Embedding class above, should be moved into ragna.core._components, given that we want to separate them from the source storages.

_EMBEDDING_DIMENSIONS: int

@abstractmethod
def embed_chunks(self, chunks: list[Chunk]) -> list[Embedding]:
raise NotImplementedError

def embed_text(self, text: str) -> list[float]:
raise NotImplementedError
Comment on lines +30 to +31
Member

Is there ever a use case to do that? Aren't we always going to embed chunks?

Contributor Author

No: when you query the database, you have to convert the prompt into an embedding. Either we turn the prompt into a chunk, which I don't like, or we just separate the logic like this.
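
Roughly, the two entry points correspond to the two call sites in Chat above; a sketch, with the helper names being illustrative only:

from ragna.source_storages._embedding_model import Embedding, GenericEmbeddingModel
from ragna.source_storages._vector_database import Chunk


def embed_for_store(model: GenericEmbeddingModel, chunks: list[Chunk]) -> list[Embedding]:
    # Indexing path (Chat.prepare): whole chunks in, Embedding objects out.
    return model.embed_chunks(chunks)


def embed_for_query(model: GenericEmbeddingModel, prompt: str) -> list[float]:
    # Query path (Chat.answer): a bare string in, a raw vector out; no Chunk
    # wrapper is needed for the prompt.
    return model.embed_text(prompt)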


def get_embedding_dimensions(self):
return self._EMBEDDING_DIMENSIONS


class MiniLML6v2(GenericEmbeddingModel):
_EMBEDDING_DIMENSIONS = 384

def __init__(self):
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)

def embed_chunks(self, chunks: list[Chunk]) -> list[Embedding]:
return [Embedding(self.embed_text(chunk.text), chunk) for chunk in chunks]

def embed_text(self, text: str) -> list[float]:
return self.model.encode(text).tolist()
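
A rough usage sketch of the model on its own, assuming sentence-transformers is installed; the Chunk field values are made up for illustration:

import uuid

from ragna.source_storages import MiniLML6v2
from ragna.source_storages._vector_database import Chunk

model = MiniLML6v2()

chunk = Chunk(
    text="Ragna is a RAG orchestration framework.",
    document_id=uuid.uuid4(),
    page_numbers=[1],
    num_tokens=9,
)

[embedding] = model.embed_chunks([chunk])
assert len(embedding.embedding) == model.get_embedding_dimensions()  # 384

query_vector = model.embed_text("What is Ragna?")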