Commit

feat: simplify RAG API
lsorber committed Dec 3, 2024

1 parent 851311c commit a27ab99
Showing 11 changed files with 356 additions and 335 deletions.
75 changes: 61 additions & 14 deletions README.md
@@ -157,38 +157,85 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)

### 3. Searching and Retrieval-Augmented Generation (RAG)

Now, you can search for chunks with vector search, keyword search, or a hybrid of the two. You can also rerank the search results with the configured reranker. And you can use any search method of your choice (`hybrid_search` is the default) together with reranking to answer questions with RAG:
#### 3.1 Simple RAG pipeline

Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:

```python
from raglite import create_rag_instruction, rag, retrieve_rag_context

# Retrieve relevant chunk spans with hybrid search and reranking:
user_prompt = "How is intelligence measured?"
chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_config)

# Append a RAG instruction based on the user prompt and context to the message history:
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Stream the RAG response:
stream = rag(messages, config=my_config)
for update in stream:
    print(update, end="")

# Access the documents cited in the RAG response:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
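
The stream yields plain-text tokens, and each cited `Document` exposes the `filename` and `url` fields defined in `src/raglite/_database.py`. A minimal sketch of post-processing the response:

```python
# Collect the streamed tokens into a single answer string:
answer = "".join(rag(messages, config=my_config))

# Print each cited source, preferring its URL over its filename:
for document in documents:
    print(document.url or document.filename)
```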

#### 3.2 Advanced RAG pipeline

> [!TIP]
> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.

In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps:

1. Searching for relevant chunks with keyword, vector, or hybrid search
2. Retrieving the chunks from the database
3. Reranking the chunks and truncating the results to the top 5
4. Extending the chunks with their neighbors and grouping them into chunk spans
5. Converting the user prompt to a RAG instruction and appending it to the message history
6. Streaming an LLM response to the message history
7. Accessing the cited documents from the chunk spans

```python
# Search for chunks:
from raglite import hybrid_search, keyword_search, vector_search

prompt = "How is intelligence measured?"
chunk_ids_vector, _ = vector_search(prompt, num_results=20, config=my_config)
chunk_ids_keyword, _ = keyword_search(prompt, num_results=20, config=my_config)
chunk_ids_hybrid, _ = hybrid_search(prompt, num_results=20, config=my_config)
user_prompt = "How is intelligence measured?"
chunk_ids_vector, _ = vector_search(user_prompt, num_results=20, config=my_config)
chunk_ids_keyword, _ = keyword_search(user_prompt, num_results=20, config=my_config)
chunk_ids_hybrid, _ = hybrid_search(user_prompt, num_results=20, config=my_config)

# Retrieve chunks:
from raglite import retrieve_chunks

chunks_hybrid = retrieve_chunks(chunk_ids_hybrid, config=my_config)

# Rerank chunks:
# Rerank chunks and keep the top 5 (optional, but recommended):
from raglite import rerank_chunks

chunks_reranked = rerank_chunks(prompt, chunks_hybrid, config=my_config)
chunks_reranked = rerank_chunks(user_prompt, chunks_hybrid, config=my_config)
chunks_reranked = chunks_reranked[:5]

# Extend chunks with their neighbors and group them into chunk spans:
from raglite import retrieve_chunk_spans

chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)

# Append a RAG instruction based on the user prompt and context to the message history:
from raglite import create_rag_instruction

messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Answer questions with RAG:
# Stream the RAG response:
from raglite import rag

prompt = "What does it mean for two events to be simultaneous?"
stream = rag(prompt, config=my_config)
stream = rag(messages, config=my_config)
for update in stream:
    print(update, end="")

# You can also pass a search method or search results directly:
stream = rag(prompt, search=hybrid_search, config=my_config)
stream = rag(prompt, search=chunks_reranked, config=my_config)
# Access the documents cited in the RAG response:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
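
Since `retrieve_rag_context` accepts any `SearchMethod` via its `search` parameter, you can also swap the retrieval step of the simple pipeline without touching the rest. A sketch using keyword-only retrieval (hybrid search remains the default):

```python
from raglite import create_rag_instruction, keyword_search, rag, retrieve_rag_context

# Retrieve chunk spans with BM25 keyword search instead of hybrid search:
chunk_spans = retrieve_rag_context(
    query=user_prompt, num_chunks=5, search=keyword_search, config=my_config
)

# The rest of the pipeline is unchanged:
messages = [create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)]
answer = "".join(rag(messages, config=my_config))
```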

### 4. Computing and using an optimal query adapter
@@ -200,7 +247,7 @@ RAGLite can compute and apply an [optimal closed-form query adapter](src/raglite
from raglite import insert_evals, update_query_adapter

insert_evals(num_evals=100, config=my_config)
update_query_adapter(config=my_config) # From here, simply call vector_search to use the query adapter.
update_query_adapter(config=my_config) # From here, every vector search will use the query adapter.
```
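
From that point on, the adapter is applied inside `vector_search` itself, so both direct vector search and anything built on top of it (such as `hybrid_search` or `retrieve_rag_context`) benefit without further changes. For example:

```python
from raglite import vector_search

# Vector search now transparently applies the query adapter:
chunk_ids, scores = vector_search(
    "How is intelligence measured?", num_results=5, config=my_config
)
```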

### 5. Evaluation of retrieval and generation
35 changes: 18 additions & 17 deletions src/raglite/__init__.py
@@ -5,38 +5,39 @@
from raglite._eval import answer_evals, evaluate, insert_evals
from raglite._insert import insert_document
from raglite._query_adapter import update_query_adapter
from raglite._rag import async_generate, generate, get_context_segments
from raglite._rag import async_rag, create_rag_instruction, rag, retrieve_rag_context
from raglite._search import (
hybrid_search,
keyword_search,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
retrieve_segments,
vector_search,
)

__all__ = [
# Config
"RAGLiteConfig",
"answer_evals",
"async_generate",
# CLI
"cli",
"evaluate",
# RAG
"generate",
"get_context_segments",
# Search
"hybrid_search",
# Insert
"insert_document",
# Evaluate
"insert_evals",
# Search
"hybrid_search",
"keyword_search",
"rerank_chunks",
"vector_search",
"retrieve_chunks",
"retrieve_segments",
"retrieve_chunk_spans",
"rerank_chunks",
# RAG
"retrieve_rag_context",
"create_rag_instruction",
"async_rag",
"rag",
# Query adapter
"update_query_adapter",
"vector_search",
# Evaluate
"insert_evals",
"answer_evals",
"evaluate",
# CLI
"cli",
]
38 changes: 20 additions & 18 deletions src/raglite/_chainlit.py
@@ -8,18 +8,20 @@

from raglite import (
RAGLiteConfig,
async_generate,
get_context_segments,
async_rag,
create_rag_instruction,
hybrid_search,
insert_document,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
)
from raglite._markdown import document_to_markdown

async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)


@@ -85,35 +87,35 @@ async def handle_message(user_message: cl.Message) -> None:
step.input = Path(file.path).name
await async_insert_document(Path(file.path), config=config)
# Append any inline attachments to the user prompt.
user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
for i, attachment in enumerate(inline_attachments)
user_prompt = (
"\n\n".join(
f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
for i, attachment in enumerate(inline_attachments)
)
+ f"\n\n{user_message.content}"
)
# Search for relevant contexts for RAG.
async with cl.Step(name="search", type="retrieval") as step:
step.input = user_message.content
chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
step.output = chunks
step.elements = [ # Show the top 3 chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
step.elements = [ # Show the top chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
]
# Rerank the chunks.
# Rerank the chunks and group them into chunk spans.
async with cl.Step(name="rerank", type="rerank") as step:
step.input = chunks
chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
step.output = chunks
step.elements = [ # Show the top 3 chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
step.output = chunk_spans
step.elements = [ # Show the top chunk spans inline.
cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
]
# Stream the LLM response.
assistant_message = cl.Message(content="")
context_segments = get_context_segments(user_prompt, config=config)
async for token in async_generate(
prompt=user_prompt,
messages=cl.chat_context.to_openai()[-5:-1], # type: ignore[no-untyped-call]
context_segments=context_segments,
config=config,
):
messages: list[dict[str, str]] = cl.chat_context.to_openai() # type: ignore[no-untyped-call]
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
async for token in async_rag(messages, config=config):
await assistant_message.stream_token(token)
await assistant_message.update() # type: ignore[no-untyped-call]
140 changes: 65 additions & 75 deletions src/raglite/_database.py
@@ -2,7 +2,6 @@

import datetime
import json
from collections.abc import Callable
from dataclasses import dataclass
from functools import lru_cache
from hashlib import sha256
@@ -18,7 +17,16 @@

from raglite._config import RAGLiteConfig
from raglite._litellm import get_embedding_dim
from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject
from raglite._typing import (
ChunkId,
DocumentId,
Embedding,
EvalId,
FloatMatrix,
FloatVector,
IndexId,
PickledObject,
)


def hash_bytes(data: bytes, max_len: int = 16) -> str:
@@ -33,7 +41,7 @@ class Document(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]

# Table columns.
id: str = Field(..., primary_key=True)
id: DocumentId = Field(..., primary_key=True)
filename: str
url: str | None = Field(default=None)
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
@@ -64,8 +72,8 @@ class Chunk(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]

# Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
id: ChunkId = Field(..., primary_key=True)
document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
index: int = Field(..., index=True)
headings: str
body: str
@@ -129,10 +137,55 @@ def __repr__(self) -> str:
indent=4,
)

def __str__(self) -> str:
"""Context representation of this chunk."""
@property
def content(self) -> str:
"""Return this chunk's contextual heading and body."""
return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()

def __str__(self) -> str:
"""Return this chunk's content."""
return self.content


@dataclass
class ChunkSpan:
"""A consecutive sequence of chunks from a single document."""

chunks: list[Chunk]
document: Document

def to_xml(self, index: int | None = None) -> str:
"""Convert this chunk span to an XML representation.
The XML representation follows Anthropic's best practices [1].
[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
if not self.chunks:
return ""
index_attribute = f' index="{index}"' if index is not None else ""
xml = "\n".join(
[
f'<document{index_attribute} id="{self.document.id}" from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
f"<source>{self.document.url if self.document.url else self.document.filename}</source>"
f"<span_heading>{escape(self.chunks[0].headings.strip())}</span_heading>"
f"<span_content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</span_content>",
"</document>",
]
)
return xml

@property
def content(self) -> str:
"""Return this chunk span's contextual heading and chunk bodies."""
heading = self.chunks[0].headings.strip() if self.chunks else ""
bodies = "".join(chunk.body for chunk in self.chunks)
return f"{heading}\n\n{bodies}".strip()

def __str__(self) -> str:
"""Return this chunk span's content."""
return self.content


class ChunkEmbedding(SQLModel, table=True):
"""A (sub-)chunk embedding."""
@@ -144,7 +197,7 @@ class ChunkEmbedding(SQLModel, table=True):

# Table columns.
id: int = Field(..., primary_key=True)
chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True)
embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))

# Add relationship so we can access embedding.chunk.
@@ -165,7 +218,7 @@ class IndexMetadata(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]

# Table columns.
id: str = Field(..., primary_key=True)
id: IndexId = Field(..., primary_key=True)
version: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
)
@@ -198,9 +251,9 @@ class Eval(SQLModel, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]

# Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
id: EvalId = Field(..., primary_key=True)
document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
chunk_ids: list[ChunkId] = Field(default_factory=list, sa_column=Column(JSON))
question: str
contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
ground_truth: str
@@ -331,66 +384,3 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
)
session.commit()
return engine


@dataclass
class ContextSegment:
"""A class representing a segment of context from a document.
This class holds information about a specific segment of a document,
including its document ID and associated chunks of text with their IDs and scores.
Attributes
----------
document_id (str): The unique identifier for the document.
chunks (list[Chunk]): List of chunks for this segment.
chunk_scores (list[float]): List of scores for each chunk.
Raises
------
ValueError: If document_id is empty or if chunks is empty.
"""

document_id: str
chunks: list[Chunk]
chunk_scores: list[float]

def __str__(self) -> str:
"""Return a string representation of the segment."""
return self.as_xml

@property
def as_xml(self) -> str:
"""Returns the segment as an XML string representation.
Returns
-------
str: XML representation of the segment.
"""
chunk_ids = ",".join(self.chunk_ids)
xml = "\n".join(
[
f'<document id="{escape(self.document_id)}" chunk_ids="{escape(chunk_ids)}">',
escape(self.reconstructed_str),
"</document>",
]
)

return xml

def score(self, scoring_function: Callable[[list[float]], float] = sum) -> float:
"""Return an aggregated score of the segment, given a scoring function."""
return scoring_function(self.chunk_scores)

@property
def chunk_ids(self) -> list[str]:
"""Return a list of chunk IDs."""
return [chunk.id for chunk in self.chunks]

@property
def reconstructed_str(self) -> str:
"""Return a string representation reconstructing the document with headings."""
heading = self.chunks[0].headings if self.chunks else ""
bodies = "\n".join(chunk.body for chunk in self.chunks)

return f"{heading}\n\n{bodies}".strip()
22 changes: 13 additions & 9 deletions src/raglite/_eval.py
@@ -12,8 +12,8 @@
from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, Eval, create_database_engine
from raglite._extract import extract_with_llm
from raglite._rag import generate, get_context_segments
from raglite._search import hybrid_search, retrieve_segments, vector_search
from raglite._rag import create_rag_instruction, rag, retrieve_rag_context
from raglite._search import hybrid_search, retrieve_chunk_spans, vector_search
from raglite._typing import SearchMethod


@@ -74,12 +74,13 @@ def validate_question(cls, value: str) -> str:
continue
# Expand the seed chunk into a set of related chunks.
related_chunk_ids, _ = vector_search(
np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
query=np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311
config=config,
)
related_chunks = [
str(segment) for segment in retrieve_segments(related_chunk_ids, config=config)
str(chunk_span)
for chunk_span in retrieve_chunk_spans(related_chunk_ids, config=config)
]
# Extract a question from the seed chunk's related chunks.
try:
@@ -92,7 +93,7 @@ def validate_question(cls, value: str) -> str:
question = question_response.question
# Search for candidate chunks to answer the generated question.
candidate_chunk_ids, _ = hybrid_search(
question, num_results=max_contexts_per_eval, config=config
query=question, num_results=max_contexts_per_eval, config=config
)
candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]

@@ -181,12 +182,15 @@ def answer_evals(
answers: list[str] = []
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
segments = get_context_segments(eval_.question, search=search, config=config)
response = generate(eval_.question, context_segments=segments, config=config)
chunk_spans = retrieve_rag_context(query=eval_.question, search=search, config=config)
messages = [create_rag_instruction(user_prompt=eval_.question, context=chunk_spans)]
response = rag(messages, config=config)
answer = "".join(response)
answers.append(answer)
chunk_ids, _ = search(eval_.question, config=config)
contexts.append([str(segment) for segment in retrieve_segments(chunk_ids)])
chunk_ids, _ = search(query=eval_.question, config=config)
contexts.append(
[str(chunk_span) for chunk_span in retrieve_chunk_spans(chunk_ids, config=config)]
)
# Collect the answered evals.
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
2 changes: 1 addition & 1 deletion src/raglite/_extract.py
@@ -61,7 +61,7 @@ class MyNameResponse(BaseModel):
# Concatenate the user prompt if it is a list of strings.
if isinstance(user_prompt, list):
user_prompt = "\n\n".join(
f'<context index="{i}">\n{chunk.strip()}\n</context>'
f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
for i, chunk in enumerate(user_prompt)
)
# Enable JSON schema validation.
183 changes: 65 additions & 118 deletions src/raglite/_rag.py
@@ -2,148 +2,95 @@

from collections.abc import AsyncIterator, Iterator

import numpy as np
from litellm import acompletion, completion

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ContextSegment
from raglite._database import ChunkSpan
from raglite._litellm import get_context_size
from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
from raglite._search import hybrid_search, rerank_chunks, retrieve_chunk_spans
from raglite._typing import SearchMethod

RAG_SYSTEM_PROMPT = """
# The default RAG instruction template follows Anthropic's best practices [1].
# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
RAG_INSTRUCTION_TEMPLATE = """
You are a friendly and knowledgeable assistant that provides complete and insightful answers.
Answer the user's question using only the context below.
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.
You MUST observe the following rules:
1. Whenever possible, use only the provided context below to answer the question at the end.
2. Cite your sources with inline numerical citations of the form "[n]", where n is the document index.
Use commas to separate citations as "[a], [b], [c]" when citing multiple sources consecutively.
3. Do not print a list of sources at the end.
{context}
{user_prompt}
""".strip()


def _max_contexts(
prompt: str,
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
messages: list[dict[str, str]] | None = None,
config: RAGLiteConfig | None = None,
) -> int:
"""Determine the maximum number of contexts for RAG."""
# Get the model's context size.
config = config or RAGLiteConfig()
max_tokens = get_context_size(config)
# Reduce the maximum number of contexts to take into account the LLM's context size.
max_context_tokens = (
max_tokens
- sum(len(message["content"]) // 3 for message in messages or []) # Previous messages.
- len(RAG_SYSTEM_PROMPT) // 3 # System prompt.
- len(prompt) // 3 # User prompt.
)
max_tokens_per_context = config.chunk_max_size // 3
max_tokens_per_context *= 1 + len(context_neighbors or [])
max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context)
if max_contexts <= 0:
error_message = "Not enough context tokens available for RAG."
raise ValueError(error_message)
return max_contexts


def get_context_segments( # noqa: PLR0913
prompt: str,
def retrieve_rag_context(
query: str,
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
messages: list[dict[str, str]] | None = None,
num_chunks: int = 5,
chunk_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod = hybrid_search,
config: RAGLiteConfig | None = None,
) -> list[ContextSegment]:
"""Retrieve contexts for RAG."""
# Determine the maximum number of contexts.
max_contexts = _max_contexts(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
messages=messages,
config=config,
)
# Retrieve the top chunks.
) -> list[ChunkSpan]:
"""Retrieve context for RAG."""
# If the user has configured a reranker, we retrieve extra contexts to rerank.
config = config or RAGLiteConfig()
chunks: list[str] | list[Chunk]
if callable(search):
# If the user has configured a reranker, we retrieve extra contexts to rerank.
extra_contexts = 3 * max_contexts if config.reranker else 0
# Retrieve relevant contexts.
chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
# Rerank the relevant contexts.
chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
else:
# The user has passed a list of chunk_ids or chunks directly.
chunks = search
extra_chunks = 3 * num_chunks if config.reranker else 0
# Search for relevant chunks.
chunk_ids, _ = search(query, num_results=num_chunks + extra_chunks, config=config)
# Rerank the chunks from most to least relevant.
chunks = rerank_chunks(query, chunk_ids=chunk_ids, config=config)
# Extend the top contexts with their neighbors and group chunks into contiguous segments.
segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
return segments
context = retrieve_chunk_spans(chunks[:num_chunks], neighbors=chunk_neighbors, config=config)
return context


def generate(
prompt: str,
def create_rag_instruction(
user_prompt: str,
context: list[ChunkSpan],
*,
system_prompt: str = RAG_SYSTEM_PROMPT,
messages: list[dict[str, str]] | None = None,
context_segments: list[ContextSegment],
config: RAGLiteConfig,
) -> Iterator[str]:
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
)
# Stream the LLM response.
stream = completion(
model=config.llm,
messages=_compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
) -> dict[str, str]:
"""Convert a user prompt to a RAG instruction.
The RAG instruction's format follows Anthropic's best practices [1].
[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
message = {
"role": "user",
"content": rag_instruction_template.format(
user_prompt=user_prompt,
context="\n".join(
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
),
),
stream=True,
)
}
return message


def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]:
# Truncate the oldest messages so we don't hit the context limit.
max_tokens = get_context_size(config)
cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
# Stream the LLM response.
stream = completion(model=config.llm, messages=messages, stream=True)
for output in stream:
token: str = output["choices"][0]["delta"].get("content") or ""
yield token


async def async_generate(
prompt: str,
*,
system_prompt: str = RAG_SYSTEM_PROMPT,
messages: list[dict[str, str]] | None = None,
context_segments: list[ContextSegment],
config: RAGLiteConfig,
) -> AsyncIterator[str]:
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
)
# Stream the LLM response.
async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]:
# Truncate the oldest messages so we don't hit the context limit.
max_tokens = get_context_size(config)
cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
# Asynchronously stream the LLM response.
async_stream = await acompletion(model=config.llm, messages=messages, stream=True)
async for output in async_stream:
token: str = output["choices"][0]["delta"].get("content") or ""
yield token


def _compose_messages(
prompt: str,
system_prompt: str,
messages: list[dict[str, str]] | None,
segments: list[ContextSegment] | None,
) -> list[dict[str, str]]:
"""Compose the messages for the LLM, placing the context in the user position."""
# Using the format recommended by Anthropic for documents in RAG
# (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts
if not segments:
return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "user", "content": prompt},
]

context_content = "<documents>\n" + "\n".join(str(seg) for seg in segments) + "\n</documents>"

return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "user", "content": prompt + "\n\n" + context_content},
]
139 changes: 84 additions & 55 deletions src/raglite/_search.py
@@ -1,33 +1,33 @@
"""Query documents."""
"""Search and retrieve chunks."""

import re
import string
from collections import defaultdict
from collections.abc import Sequence
from itertools import groupby
from operator import attrgetter, methodcaller
from typing import cast

import numpy as np
from langdetect import detect
from sqlalchemy.engine import make_url
from sqlalchemy.orm import joinedload
from sqlmodel import Session, and_, col, or_, select, text

from raglite._config import RAGLiteConfig
from raglite._database import (
Chunk,
ChunkEmbedding,
ContextSegment,
ChunkSpan,
IndexMetadata,
create_database_engine,
)
from raglite._embed import embed_sentences
from raglite._typing import FloatMatrix
from raglite._typing import ChunkId, FloatMatrix


def vector_search(
query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
) -> tuple[list[ChunkId], list[float]]:
"""Search chunks using ANN vector search."""
# Read the config.
config = config or RAGLiteConfig()
@@ -94,7 +94,7 @@ def vector_search(

def keyword_search(
query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
) -> tuple[list[ChunkId], list[float]]:
"""Search chunks using BM25 keyword search."""
# Read the config.
config = config or RAGLiteConfig()
@@ -144,8 +144,8 @@ def keyword_search(


def reciprocal_rank_fusion(
rankings: list[list[str]], *, k: int = 60
) -> tuple[list[str], list[float]]:
rankings: list[list[ChunkId]], *, k: int = 60
) -> tuple[list[ChunkId], list[float]]:
"""Reciprocal Rank Fusion."""
# Compute the RRF score.
chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
@@ -163,7 +163,7 @@ def reciprocal_rank_fusion(

def hybrid_search(
query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
) -> tuple[list[ChunkId], list[float]]:
"""Search chunks by combining ANN vector search with BM25 keyword search."""
# Run both searches.
vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
@@ -174,67 +174,34 @@ def hybrid_search(
return chunk_ids, hybrid_score


def retrieve_chunks(chunk_ids: list[str], *, config: RAGLiteConfig | None = None) -> list[Chunk]:
def retrieve_chunks(
chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None
) -> list[Chunk]:
"""Retrieve chunks by their ids."""
config = config or RAGLiteConfig()
engine = create_database_engine(config)
with Session(engine) as session:
chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
chunks = list(
session.exec(
select(Chunk)
.where(col(Chunk.id).in_(chunk_ids))
# Eagerly load chunk.document.
.options(joinedload(Chunk.document)) # type: ignore[arg-type]
).all()
)
chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
return chunks


def retrieve_segments(
chunk_ids: list[str] | list[Chunk],
*,
neighbors: tuple[int, ...] | None = (-1, 1),
config: RAGLiteConfig | None = None,
) -> list[ContextSegment]:
"""Group chunks into contiguous segments and retrieve them."""
# Retrieve the chunks.
config = config or RAGLiteConfig()
chunks: list[Chunk] = (
retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
else chunk_ids
)
# Assign a reciprocal ranking score to each chunk based on its position in the original list.
chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
# Extend the chunks with their neighbouring chunks.
if neighbors:
engine = create_database_engine(config)
with Session(engine) as session:
neighbor_conditions = [
and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
for chunk in chunks
for offset in neighbors
]
chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all())
# Deduplicate and sort the chunks by document_id and index (needed for groupby).
unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index))
# Group the chunks into contiguous segments.
context_segments: list[ContextSegment] = [
ContextSegment(
document_id=doc_id,
chunks=(doc_chunks := list(group)),
chunk_scores=[chunk_id_to_score.get(chunk.id, 0.0) for chunk in doc_chunks],
)
for doc_id, group in groupby(unique_chunks, key=attrgetter("document_id"))
]
# Rank segments according to the aggregate relevance of their chunks.
context_segments.sort(key=methodcaller("score", scoring_function=sum), reverse=True)
return context_segments


def rerank_chunks(
query: str, chunk_ids: list[str] | list[Chunk], *, config: RAGLiteConfig | None = None
query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None
) -> list[Chunk]:
"""Rerank chunks according to their relevance to a given query."""
# Retrieve the chunks.
config = config or RAGLiteConfig()
chunks: list[Chunk] = (
retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
else chunk_ids
)
# Early exit if no reranker is configured.
@@ -259,3 +226,65 @@ def rerank_chunks(
results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
chunks = [chunks[result.doc_id] for result in results.results]
return chunks


def retrieve_chunk_spans(
chunk_ids: list[ChunkId] | list[Chunk],
*,
neighbors: tuple[int, ...] | None = (-1, 1),
config: RAGLiteConfig | None = None,
) -> list[ChunkSpan]:
"""Group chunks into contiguous chunk spans and retrieve them.
Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
determined by the order in which they are provided to this function.
"""
# Retrieve the chunks.
config = config or RAGLiteConfig()
chunks: list[Chunk] = (
retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
else chunk_ids
)
# Assign a reciprocal ranking score to each chunk based on its position in the original list.
chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
# Extend the chunks with their neighbouring chunks.
engine = create_database_engine(config)
with Session(engine) as session:
if neighbors:
neighbor_conditions = [
and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
for chunk in chunks
for offset in neighbors
]
chunks += list(
session.exec(
select(Chunk)
.where(or_(*neighbor_conditions))
# Eagerly load chunk.document.
.options(joinedload(Chunk.document)) # type: ignore[arg-type]
).all()
)
# Deduplicate and sort the chunks by document_id and index (needed for groupby).
unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index))
# Group the chunks into contiguous chunk spans.
chunk_spans: list[ChunkSpan] = []
for _, group in groupby(unique_chunks, key=lambda chunk: chunk.document_id):
chunk_sequence: list[Chunk] = []
for chunk in group:
if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1:
chunk_sequence.append(chunk)
else:
chunk_spans.append(
ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document)
)
chunk_sequence = [chunk]
chunk_spans.append(ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document))
# Rank the chunk spans according to the aggregate relevance of their chunks.
chunk_spans.sort(
key=lambda chunk_span: sum(
chunk_id_to_score.get(chunk.id, 0.0) for chunk in chunk_span.chunks
),
reverse=True,
)
return chunk_spans
5 changes: 5 additions & 0 deletions src/raglite/_typing.py
@@ -12,6 +12,11 @@

from raglite._config import RAGLiteConfig

ChunkId = str
DocumentId = str
EvalId = str
IndexId = str

FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
39 changes: 15 additions & 24 deletions tests/test_rag.py
@@ -1,17 +1,16 @@
"""Test RAGLite's RAG functionality."""

import os
from typing import TYPE_CHECKING

import pytest
from llama_cpp import llama_supports_gpu_offload

from raglite import RAGLiteConfig, hybrid_search, retrieve_chunks
from raglite._rag import generate, get_context_segments

if TYPE_CHECKING:
from raglite._database import Chunk
from raglite._typing import SearchMethod
from raglite import (
RAGLiteConfig,
create_rag_instruction,
retrieve_rag_context,
)
from raglite._rag import rag


def is_accelerator_available() -> bool:
@@ -22,21 +21,13 @@ def is_accelerator_available() -> bool:
@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation."""
# Assemble different types of search inputs for RAG.
prompt = "What does it mean for two events to be simultaneous?"
search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
hybrid_search, # A search method as input.
hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input.
retrieve_chunks( # Chunks as input.
hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
),
]
# Answer a question with RAG.
for search_input in search_inputs:
segments = get_context_segments(prompt, search=search_input, config=raglite_test_config)
stream = generate(prompt, context_segments=segments, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
user_prompt = "What does it mean for two events to be simultaneous?"
chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config)
messages = [create_rag_instruction(user_prompt, context=chunk_spans)]
stream = rag(messages, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
13 changes: 9 additions & 4 deletions tests/test_search.py
@@ -6,11 +6,11 @@
RAGLiteConfig,
hybrid_search,
keyword_search,
retrieve_chunk_spans,
retrieve_chunks,
retrieve_segments,
vector_search,
)
from raglite._database import Chunk, ContextSegment
from raglite._database import Chunk, ChunkSpan, Document
from raglite._typing import SearchMethod


@@ -43,9 +43,14 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
assert all(isinstance(chunk.document, Document) for chunk in chunks)
# Extend the chunks with their neighbours and group them into contiguous segments.
segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
assert all(isinstance(segment, ContextSegment) for segment in segments)
chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=raglite_test_config)
assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)


def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
