feat: Align configuration for inference and evaluation
undo76 committed Dec 9, 2024
1 parent 0c5b7b5 commit 0d21062
Showing 11 changed files with 128 additions and 92 deletions.
9 changes: 6 additions & 3 deletions src/raglite/_chainlit.py
@@ -97,7 +97,7 @@ async def handle_message(user_message: cl.Message) -> None:
     # Search for relevant contexts for RAG.
     async with cl.Step(name="search", type="retrieval") as step:
         step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
+        chunk_ids, _ = await async_hybrid_search(query=user_prompt, config=config)
         chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
         step.output = chunks
         step.elements = [  # Show the top chunks inline.
@@ -116,8 +116,11 @@ async def handle_message(user_message: cl.Message) -> None:
         await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
-    messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]
-    messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
+    messages: list[dict[str, str]] = [
+        *([{"role": "system", "content": config.system_prompt}] if config.system_prompt else []),
+        *(cl.chat_context.to_openai()[:-1]),  # type: ignore[no-untyped-call]
+        create_rag_instruction(user_prompt=user_prompt, context=chunk_spans, config=config),
+    ]
     async for token in async_rag(messages, config=config):
         await assistant_message.stream_token(token)
     await assistant_message.update()  # type: ignore[no-untyped-call]
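Usage note (a sketch, not part of the diff): when RAGLiteConfig.system_prompt is set, the frontend now prepends it as a system message ahead of the chat history and the RAG instruction. The prompt text and history below are illustrative.

from raglite._config import RAGLiteConfig

config = RAGLiteConfig(system_prompt="Answer using only the provided documents.")
history: list[dict[str, str]] = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
rag_instruction = {"role": "user", "content": "What is late chunking?"}  # stands in for create_rag_instruction(...)
messages = [
    *([{"role": "system", "content": config.system_prompt}] if config.system_prompt else []),
    *history,
    rag_instruction,
]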
20 changes: 20 additions & 0 deletions src/raglite/_config.py
@@ -3,18 +3,32 @@
 import contextlib
 import os
 from dataclasses import dataclass, field
+from functools import partial
 from io import StringIO
+from typing import TYPE_CHECKING

 from llama_cpp import llama_supports_gpu_offload
 from sqlalchemy.engine import URL

+from raglite._prompts import RAG_INSTRUCTION_TEMPLATE
+
+if TYPE_CHECKING:
+    from raglite._typing import SearchMethod
+
 # Suppress rerankers output on import until [1] is fixed.
 # [1] https://github.com/AnswerDotAI/rerankers/issues/36
 with contextlib.redirect_stdout(StringIO()):
     from rerankers.models.flashrank_ranker import FlashRankRanker
     from rerankers.models.ranker import BaseRanker


+def _default_search_method() -> "SearchMethod":
+    """Get the default search method."""
+    from raglite._search import hybrid_search
+
+    return partial(hybrid_search, oversample=4)
+
+
 @dataclass(frozen=True)
 class RAGLiteConfig:
     """Configuration for RAGLite."""
@@ -53,6 +67,12 @@ class RAGLiteConfig:
         ),
         compare=False,  # Exclude the reranker from comparison to avoid lru_cache misses.
     )
+    search_method: "SearchMethod" = field(default_factory=_default_search_method, compare=False)
+    system_prompt: str | None = None
+    rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE
+    num_chunks: int = 5
+    chunk_neighbors: tuple[int, ...] | None = (-1, 1)  # Neighbors to include in the context.
+    reranker_oversample: int = 4  # How many extra chunks to retrieve for reranking (multiplied).

     def __post_init__(self) -> None:
         # Late chunking with llama-cpp-python does not apply sentence windowing.
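In effect, the retrieval and generation knobs that used to be function arguments now live on RAGLiteConfig. A minimal sketch with illustrative values (all other fields keep their defaults):

from raglite._config import RAGLiteConfig

config = RAGLiteConfig(
    system_prompt="You are a concise documentation assistant.",  # prepended as a system message when set
    num_chunks=8,  # chunks returned by search and kept after reranking
    chunk_neighbors=(-2, -1, 1),  # neighbouring chunks to include around each retrieved chunk
    reranker_oversample=4,  # retrieve reranker_oversample * num_chunks candidates when a reranker is configured
)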
32 changes: 18 additions & 14 deletions src/raglite/_eval.py
@@ -1,5 +1,6 @@
"""Generation and evaluation of evals."""

from dataclasses import replace
from random import randint
from typing import ClassVar

@@ -13,8 +14,7 @@
 from raglite._database import Chunk, Document, Eval, create_database_engine
 from raglite._extract import extract_with_llm
 from raglite._rag import create_rag_instruction, rag, retrieve_rag_context
-from raglite._search import hybrid_search, retrieve_chunk_spans, vector_search
-from raglite._typing import SearchMethod
+from raglite._search import retrieve_chunk_spans, search, vector_search


 def insert_evals(  # noqa: C901
@@ -76,8 +76,10 @@ def validate_question(cls, value: str) -> str:
             # Expand the seed chunk into a set of related chunks.
             related_chunk_ids, _ = vector_search(
                 query=np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
-                num_results=randint(2, max_contexts_per_eval // 2),  # noqa: S311
-                config=config,
+                config=replace(
+                    config,
+                    num_chunks=randint(2, max_contexts_per_eval // 2),  # noqa: S311
+                ),
             )
             related_chunks = [
                 str(chunk_spans)
@@ -93,8 +95,8 @@ def validate_question(cls, value: str) -> str:
         else:
             question = question_response.question
         # Search for candidate chunks to answer the generated question.
-        candidate_chunk_ids, _ = hybrid_search(
-            query=question, num_results=max_contexts_per_eval, config=config
+        candidate_chunk_ids, _ = search(
+            query=question, config=replace(config, num_chunks=max_contexts_per_eval)
         )
         candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]

@@ -173,12 +175,7 @@ class AnswerResponse(BaseModel):
         session.commit()


-def answer_evals(
-    num_evals: int = 100,
-    search: SearchMethod = hybrid_search,
-    *,
-    config: RAGLiteConfig | None = None,
-) -> pd.DataFrame:
+def answer_evals(num_evals: int = 100, *, config: RAGLiteConfig | None = None) -> pd.DataFrame:
     """Read evals from the database and answer them with RAG."""
     # Read evals from the database.
     config = config or RAGLiteConfig()
@@ -189,8 +186,15 @@
     answers: list[str] = []
     contexts: list[list[str]] = []
     for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
-        chunk_spans = retrieve_rag_context(query=eval_.question, search=search, config=config)
-        messages = [create_rag_instruction(user_prompt=eval_.question, context=chunk_spans)]
+        chunk_spans = retrieve_rag_context(query=eval_.question, config=config)
+        messages = [
+            *(
+                [{"role": "system", "content": config.system_prompt}]
+                if config.system_prompt
+                else []
+            ),
+            create_rag_instruction(user_prompt=eval_.question, context=chunk_spans),
+        ]
         response = rag(messages, config=config)
         answer = "".join(response)
         answers.append(answer)
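Usage sketch (not part of the diff): answer_evals no longer accepts a search callable; the search method and retrieval depth are taken from the config. The values below are illustrative.

from raglite._config import RAGLiteConfig
from raglite._eval import answer_evals

config = RAGLiteConfig(num_chunks=10)
answered_evals_df = answer_evals(num_evals=50, config=config)  # returns a pandas DataFrame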
12 changes: 12 additions & 0 deletions src/raglite/_prompts.py
@@ -0,0 +1,12 @@
+# The default RAG instruction template follows Anthropic's best practices [1].
+# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+RAG_INSTRUCTION_TEMPLATE = """
+You are a friendly and knowledgeable assistant that provides complete and insightful answers.
+Whenever possible, use only the provided context to respond to the question at the end.
+When responding, you MUST NOT reference the existence of the context, directly or indirectly.
+Instead, you MUST treat the context as if its contents are entirely part of your working memory.
+
+{context}
+{user_prompt}
+""".strip()
16 changes: 10 additions & 6 deletions src/raglite/_query_adapter.py
@@ -1,5 +1,7 @@
"""Compute and update an optimal query adapter."""

from dataclasses import replace

import numpy as np
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, col, select
@@ -64,8 +66,8 @@ def update_query_adapter(  # noqa: PLR0915, C901
     C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
     """
     config = config or RAGLiteConfig()
-    config_no_query_adapter = RAGLiteConfig(
-        **{**config.__dict__, "vector_search_query_adapter": False}
+    config_no_query_adapter = replace(
+        config, vector_search_query_adapter=False, num_chunks=optimize_top_k
     )
     engine = create_database_engine(config)
     with Session(engine) as session:
@@ -90,9 +92,7 @@
             # Embed the question.
             question_embedding = embed_sentences([eval_.question], config=config)
             # Retrieve chunks that would be used to answer the question.
-            chunk_ids, _ = vector_search(
-                question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
-            )
+            chunk_ids, _ = vector_search(question_embedding, config=config_no_query_adapter)
             retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
             # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
             num_triplets = 0
@@ -131,7 +131,11 @@ def update_query_adapter(  # noqa: PLR0915, C901
                     break
             # Check if we have sufficient triplets to compute the query adapter.
             if Q.shape[0] > max_triplets:
-                Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :]  # noqa: N806
+                Q, P, N = (  # noqa: N806
+                    Q[:max_triplets, :],
+                    P[:max_triplets, :],
+                    N[:max_triplets, :],
+                )
                 break
         # Normalise the rows of Q, P, N.
         Q /= np.linalg.norm(Q, axis=1, keepdims=True)  # noqa: N806
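Design note: because RAGLiteConfig is a frozen dataclass, per-call overrides throughout this commit use dataclasses.replace rather than rebuilding the config from __dict__. A minimal sketch (the num_chunks value is illustrative):

from dataclasses import replace

from raglite._config import RAGLiteConfig

config = RAGLiteConfig()
config_no_query_adapter = replace(config, vector_search_query_adapter=False, num_chunks=20)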
42 changes: 11 additions & 31 deletions src/raglite/_rag.py
@@ -1,66 +1,46 @@
"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator
from dataclasses import replace

import numpy as np
from litellm import acompletion, completion

from raglite._config import RAGLiteConfig
from raglite._database import ChunkSpan
from raglite._litellm import get_context_size
from raglite._search import hybrid_search, rerank_chunks, retrieve_chunk_spans
from raglite._typing import SearchMethod
from raglite._search import rerank_chunks, retrieve_chunk_spans, search

# The default RAG instruction template follows Anthropic's best practices [1].
# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
RAG_INSTRUCTION_TEMPLATE = """
You are a friendly and knowledgeable assistant that provides complete and insightful answers.
Whenever possible, use only the provided context to respond to the question at the end.
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.

{context}
{user_prompt}
""".strip()


def retrieve_rag_context(
query: str,
*,
num_chunks: int = 5,
chunk_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod = hybrid_search,
config: RAGLiteConfig | None = None,
) -> list[ChunkSpan]:
def retrieve_rag_context(query: str, *, config: RAGLiteConfig | None = None) -> list[ChunkSpan]:
"""Retrieve context for RAG."""
# If the user has configured a reranker, we retrieve extra contexts to rerank.
config = config or RAGLiteConfig()
extra_chunks = 3 * num_chunks if config.reranker else 0
oversampled_num_chunks = (
config.reranker_oversample * config.num_chunks if config.reranker else config.num_chunks
)
# Search for relevant chunks.
chunk_ids, _ = search(query, num_results=num_chunks + extra_chunks, config=config)
chunk_ids, _ = search(query, config=replace(config, num_chunks=oversampled_num_chunks))
# Rerank the chunks from most to least relevant.
chunks = rerank_chunks(query, chunk_ids=chunk_ids, config=config)
# Extend the top contexts with their neighbors and group chunks into contiguous segments.
context = retrieve_chunk_spans(chunks[:num_chunks], neighbors=chunk_neighbors, config=config)
context = retrieve_chunk_spans(chunks[: config.num_chunks], config=config)
return context


def create_rag_instruction(
user_prompt: str,
context: list[ChunkSpan],
*,
rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
user_prompt: str, context: list[ChunkSpan], *, config: RAGLiteConfig | None = None
) -> dict[str, str]:
"""Convert a user prompt to a RAG instruction.
The RAG instruction's format follows Anthropic's best practices [1].
[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
config = config or RAGLiteConfig()
message = {
"role": "user",
"content": rag_instruction_template.format(
"content": config.rag_instruction_template.format(
user_prompt=user_prompt.strip(),
context="\n".join(
chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
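End-to-end usage sketch under the new signatures (the query text and chunk counts are illustrative): retrieval depth and reranking oversampling now come from the config, so callers only pass a query and a config.

from raglite._config import RAGLiteConfig
from raglite._rag import create_rag_instruction, rag, retrieve_rag_context

config = RAGLiteConfig(num_chunks=5, reranker_oversample=4)
question = "How does RAGLite rerank search results?"
# With a reranker configured, this searches 4 * 5 = 20 candidates and keeps the top 5.
chunk_spans = retrieve_rag_context(query=question, config=config)
messages = [create_rag_instruction(user_prompt=question, context=chunk_spans, config=config)]
answer = "".join(rag(messages, config=config))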