feat: decouple generation from context retrieval
undo76 committed Nov 27, 2024
1 parent 4e32914 commit c978e0e
Showing 6 changed files with 50 additions and 104 deletions.
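This commit splits the old `rag()` / `async_rag()` entry points into an explicit retrieval step (`get_context_segments`) and a generation step (`generate` / `async_generate`). A minimal sketch of the new calling pattern, mirroring the updated test further down and assuming a default `RAGLiteConfig` and an already-populated database:

```python
from raglite import RAGLiteConfig, generate, get_context_segments

config = RAGLiteConfig()  # assumption: default database, LLM, and embedder settings

prompt = "What is the document about?"
# Step 1: retrieve context segments, independently of generation.
segments = get_context_segments(prompt, config=config)
# Step 2: stream the LLM answer from the pre-fetched context.
answer = "".join(generate(prompt, context_segments=segments, config=config))
print(answer)
```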
29 changes: 15 additions & 14 deletions src/raglite/__init__.py
@@ -5,7 +5,7 @@
from raglite._eval import answer_evals, evaluate, insert_evals
from raglite._insert import insert_document
from raglite._query_adapter import update_query_adapter
from raglite._rag import async_rag, rag
from raglite._rag import async_generate, generate, get_context_segments
from raglite._search import (
hybrid_search,
keyword_search,
@@ -18,24 +18,25 @@
__all__ = [
# Config
"RAGLiteConfig",
# Insert
"insert_document",
"answer_evals",
"async_generate",
# CLI
"cli",
"evaluate",
# RAG
"generate",
"get_context_segments",
# Search
"hybrid_search",
# Insert
"insert_document",
# Evaluate
"insert_evals",
"keyword_search",
"vector_search",
"rerank_chunks",
"retrieve_chunks",
"retrieve_segments",
"rerank_chunks",
# RAG
"async_rag",
"rag",
# Query adapter
"update_query_adapter",
# Evaluate
"insert_evals",
"answer_evals",
"evaluate",
# CLI
"cli",
"vector_search",
]
10 changes: 6 additions & 4 deletions src/raglite/_chainlit.py
@@ -8,7 +8,8 @@

from raglite import (
RAGLiteConfig,
async_rag,
async_generate,
get_context_segments,
hybrid_search,
insert_document,
rerank_chunks,
@@ -107,10 +108,11 @@ async def handle_message(user_message: cl.Message) -> None:
]
# Stream the LLM response.
assistant_message = cl.Message(content="")
async for token in async_rag(
context_segments = get_context_segments(user_prompt, config=config)
async for token in async_generate(
prompt=user_prompt,
search=chunks,
messages=cl.chat_context.to_openai()[-5:], # type: ignore[no-untyped-call]
messages=cl.chat_context.to_openai()[-5:-1], # type: ignore[no-untyped-call]
context_segments=context_segments,
config=config,
):
await assistant_message.stream_token(token)
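The Chainlit handler above now retrieves context synchronously before streaming the response; the same pattern works outside Chainlit. A sketch of the async variant, assuming a populated database and default config:

```python
import asyncio

from raglite import RAGLiteConfig, async_generate, get_context_segments


async def answer(prompt: str, config: RAGLiteConfig) -> str:
    # Retrieval stays synchronous; only token generation is streamed asynchronously.
    segments = get_context_segments(prompt, config=config)
    tokens: list[str] = []
    async for token in async_generate(prompt=prompt, context_segments=segments, config=config):
        tokens.append(token)
    return "".join(tokens)


print(asyncio.run(answer("Summarise the document.", RAGLiteConfig())))
```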
42 changes: 6 additions & 36 deletions src/raglite/_database.py
@@ -371,7 +371,7 @@ def as_xml(self) -> str:
xml = "\n".join(
[
f'<document id="{escape(self.document_id)}" chunk_ids="{escape(chunk_ids)}">',
escape(self.as_str),
escape(self.reconstructed_str),
"</document>",
]
)
@@ -388,39 +388,9 @@ def chunk_ids(self) -> list[str]:
return [chunk.id for chunk in self.chunks]

@property
def as_str(self) -> str:
"""Return a string representation reconstructing the document with headings.
def reconstructed_str(self) -> str:
"""Return a string representation reconstructing the document with headings."""
heading = self.chunks[0].headings if self.chunks else ""
bodies = "\n".join(chunk.body for chunk in self.chunks)

Treats headings as a stack, showing headers only when they differ from
the current stack path.
For example:
- "# A ## B" shows both headers
- "# A ## B" shows nothing (already seen)
- "# A ## C" shows only "## C" (new branch)
- "# D ## B" shows both (new path)
"""
if not self.chunks:
return ""

result: list[str] = []
stack: list[str] = []

for chunk in self.chunks:
headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()]

# Find first differing header
i = 0
while i < len(headers) and i < len(stack) and headers[i] == stack[i]:
i += 1

# Update stack and show new headers
stack[i:] = headers[i:]
if headers[i:]:
result.extend(headers[i:])
result.append("")

if chunk.body.strip():
result.append(chunk.body.strip())

return "\n".join(result).strip()
return f"{heading}\n\n{bodies}".strip()
9 changes: 5 additions & 4 deletions src/raglite/_eval.py
@@ -12,7 +12,7 @@
from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, Eval, create_database_engine
from raglite._extract import extract_with_llm
from raglite._rag import rag
from raglite._rag import generate, get_context_segments
from raglite._search import hybrid_search, retrieve_segments, vector_search
from raglite._typing import SearchMethod

@@ -181,7 +181,8 @@ def answer_evals(
answers: list[str] = []
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
response = rag(eval_.question, search=search, config=config)
segments = get_context_segments(eval_.question, search=search, config=config)
response = generate(eval_.question, context_segments=segments, config=config)
answer = "".join(response)
answers.append(answer)
chunk_ids, _ = search(eval_.question, config=config)
@@ -233,13 +234,13 @@ def evaluate(
verbose=llm.verbose,
)
else:
lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg]
lc_llm = ChatLiteLLM(model=config.llm)
# Load the embedder.
if not config.embedder.startswith("llama-cpp-python"):
error_message = "Currently, only `llama-cpp-python` embedders are supported."
raise NotImplementedError(error_message)
embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg]
lc_embedder = LlamaCppEmbeddings(
model_path=embedder.model_path,
n_batch=embedder.n_batch,
n_ctx=embedder.n_ctx(),
58 changes: 14 additions & 44 deletions src/raglite/_rag.py
@@ -1,7 +1,6 @@
"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator
from typing import cast

from litellm import acompletion, completion

@@ -47,7 +46,7 @@ def _max_contexts(
return max_contexts


def context_segments( # noqa: PLR0913
def get_context_segments( # noqa: PLR0913
prompt: str,
*,
max_contexts: int = 5,
@@ -83,35 +82,22 @@ def context_segments( # noqa: PLR0913
return segments


def rag( # noqa: PLR0913
def generate(
prompt: str,
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
messages: list[dict[str, str]] | None = None,
context_segments: list[ContextSegment],
config: RAGLiteConfig,
) -> Iterator[str]:
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments: list[ContextSegment]
if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
segments = cast(list[ContextSegment], search)
else:
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search, # type: ignore[arg-type]
config=config,
)
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
)
# Stream the LLM response.
stream = completion(
model=config.llm,
messages=_compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
),
stream=True,
)
@@ -120,32 +106,16 @@ def rag( # noqa: PLR0913
yield token


async def async_rag( # noqa: PLR0913
async def async_generate(
prompt: str,
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
messages: list[dict[str, str]] | None = None,
context_segments: list[ContextSegment],
config: RAGLiteConfig,
) -> AsyncIterator[str]:
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments: list[ContextSegment]
if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
segments = cast(list[ContextSegment], search)
else:
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search, # type: ignore[arg-type]
config=config,
)
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
)
# Stream the LLM response.
async_stream = await acompletion(model=config.llm, messages=messages, stream=True)
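With the retrieval knobs (`max_contexts`, `context_neighbors`, `search`) now living on `get_context_segments` rather than on the generation functions, a custom search method is chosen at retrieval time. A sketch assuming the exported `keyword_search`:

```python
from raglite import RAGLiteConfig, generate, get_context_segments, keyword_search

config = RAGLiteConfig()  # assumption: default settings
prompt = "Which section covers installation?"
# Retrieval-time options are configured here...
segments = get_context_segments(prompt, search=keyword_search, max_contexts=3, config=config)
# ...while generate() only needs the prompt, the segments, and the config.
answer = "".join(generate(prompt, context_segments=segments, config=config))
```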
6 changes: 4 additions & 2 deletions tests/test_rag.py
@@ -6,7 +6,8 @@
import pytest
from llama_cpp import llama_supports_gpu_offload

from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks
from raglite import RAGLiteConfig, hybrid_search, retrieve_chunks
from raglite._rag import generate, get_context_segments

if TYPE_CHECKING:
from raglite._database import Chunk
@@ -32,7 +33,8 @@ def test_rag(raglite_test_config: RAGLiteConfig) -> None:
]
# Answer a question with RAG.
for search_input in search_inputs:
stream = rag(prompt, search=search_input, config=raglite_test_config)
segments = get_context_segments(prompt, search=search_input, config=raglite_test_config)
stream = generate(prompt, context_segments=segments, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
