Skip to content

Commit

Permalink
fix: improve output for empty databases
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber committed Dec 17, 2024
1 parent 2e9bfaf commit fea61a2
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_
Next, it is optional but recommended to install [an accelerated llama-cpp-python precompiled binary](https://github.com/abetlen/llama-cpp-python?tab=readme-ov-file#supported-backends) with:

```sh
# Configure which llama-cpp-python precompiled binary to install (⚠️ only v0.3.2 is supported right now):
# Configure which llama-cpp-python precompiled binary to install (⚠️ On macOS only v0.3.2 is supported right now):
LLAMA_CPP_PYTHON_VERSION=0.3.2
PYTHON_VERSION=310
ACCELERATOR=metal|cu121|cu122|cu123|cu124
Expand Down
5 changes: 3 additions & 2 deletions src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from io import StringIO
from pathlib import Path
from typing import Literal

from llama_cpp import llama_supports_gpu_offload
from platformdirs import user_data_dir
Expand Down Expand Up @@ -48,8 +49,8 @@ class RAGLiteConfig:
# Chunk config used to partition documents into chunks.
chunk_max_size: int = 1440 # Max number of characters per chunk.
# Vector search config.
vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine".
vector_search_query_adapter: bool = True
vector_search_index_metric: Literal["cosine", "dot", "l1", "l2"] = "cosine"
vector_search_query_adapter: bool = True # Only supported for "cosine" and "dot" metrics.
# Reranking config.
reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = field(
default_factory=lambda: (
Expand Down
6 changes: 4 additions & 2 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.
{context}
<context>{context}</context>
{user_prompt}
""".strip()
Expand Down Expand Up @@ -91,7 +91,9 @@ def _get_tools(
"""Get tools to search the knowledge base if no RAG context is provided in the messages."""
# Check if messages already contain RAG context or if the LLM supports tool use.
final_message = messages[-1].get("content", "")
messages_contain_rag_context = any(s in final_message for s in ("</document>", "from_chunk_id"))
messages_contain_rag_context = any(
s in final_message for s in ("<context>", "</document>", "from_chunk_id")
)
llm_supports_function_calling = supports_function_calling(config.llm)
if not messages_contain_rag_context and not llm_supports_function_calling:
error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling."
Expand Down
44 changes: 30 additions & 14 deletions src/raglite/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections import defaultdict
from collections.abc import Sequence
from itertools import groupby
from typing import cast

import numpy as np
from langdetect import LangDetectException, detect
Expand Down Expand Up @@ -66,23 +65,32 @@ def vector_search(
.order_by(distance)
.limit(oversample * num_results)
)
chunk_ids_, distance = zip(*results, strict=True)
chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
results = list(results) # type: ignore[assignment]
chunk_ids = np.asarray([result[0] for result in results])
similarity = 1.0 - np.asarray([result[1] for result in results])
elif db_backend == "sqlite":
# Load the NNDescent index.
index = index_metadata.get("index")
ids = np.asarray(index_metadata.get("chunk_ids"))
cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes")))
ids = np.asarray(index_metadata.get("chunk_ids", []))
cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes", [])))
# Find the neighbouring multi-vector indices.
from pynndescent import NNDescent

multi_vector_indices, distance = cast(NNDescent, index).query(
query_embedding[np.newaxis, :], k=oversample * num_results
)
similarity = 1 - distance[0, :]
# Transform the multi-vector indices into chunk indices, and then to chunk ids.
chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
if isinstance(index, NNDescent) and len(ids) and len(cumsum):
# Query the index.
multi_vector_indices, distance = index.query(
query_embedding[np.newaxis, :], k=oversample * num_results
)
similarity = 1 - distance[0, :]
# Transform the multi-vector indices into chunk indices, and then to chunk ids.
chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
else:
# Empty result set if there is no index or if no chunks are indexed.
chunk_ids, similarity = np.array([], dtype=np.intp), np.array([])
# Exit early if there are no search results.
if not len(chunk_ids):
return [], []
# Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
# fewer hits are padded with the minimum similarity of the result set.
unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True)
Expand Down Expand Up @@ -157,6 +165,9 @@ def reciprocal_rank_fusion(
chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
for chunk_id in chunk_ids:
chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
# Exit early if there are no results to fuse.
if not chunk_id_score:
return [], []
# Rank RRF results according to descending RRF score.
rrf_chunk_ids, rrf_score = zip(
*sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
Expand All @@ -181,6 +192,8 @@ def retrieve_chunks(
chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None
) -> list[Chunk]:
"""Retrieve chunks by their ids."""
if not chunk_ids:
return []
config = config or RAGLiteConfig()
engine = create_database_engine(config)
with Session(engine) as session:
Expand All @@ -207,8 +220,8 @@ def rerank_chunks(
if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
else chunk_ids
)
# Early exit if no reranker is configured.
if not config.reranker:
# Exit early if no reranker is configured or if the input is empty.
if not config.reranker or not chunks:
return chunks
# Select the reranker.
if isinstance(config.reranker, Sequence):
Expand Down Expand Up @@ -243,6 +256,9 @@ def retrieve_chunk_spans(
Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
determined by the order in which they are provided to this function.
"""
# Exit early if the input is empty.
if not chunk_ids:
return []
# Retrieve the chunks.
config = config or RAGLiteConfig()
chunks: list[Chunk] = (
Expand Down
17 changes: 17 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
assert [message["role"] for message in messages] == ["user", "assistant"]


def test_rag_manual_empty_database(llm: str, embedder: str) -> None:
"""Test Retrieval-Augmented Generation with manual retrieval."""
# Answer a question with manual RAG.
raglite_test_config = RAGLiteConfig(db_url="sqlite:///:memory:", llm=llm, embedder=embedder)
user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config)
messages = [create_rag_instruction(user_prompt, context=chunk_spans)]
stream = rag(messages, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "event" in answer.lower()
# Verify that no RAG context was retrieved through tool use.
assert [message["role"] for message in messages] == ["user", "assistant"]


def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation with automatic retrieval."""
# Answer a question that requires RAG.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,15 @@ def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: Se
assert len(chunk_ids) == len(scores) == num_results_expected
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)


def test_search_empty_database(llm: str, embedder: str, search_method: SearchMethod) -> None:
"""Test searching for a query with an empty database."""
raglite_test_config = RAGLiteConfig(db_url="sqlite:///:memory:", llm=llm, embedder=embedder)
query = "supercalifragilisticexpialidocious"
num_results = 5
chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
num_results_expected = 0
assert len(chunk_ids) == len(scores) == num_results_expected
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)

0 comments on commit fea61a2

Please sign in to comment.