Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support prompt caching and apply Anthropic's long-context prompt format #52

Merged
merged 12 commits into from
Dec 3, 2024
160 changes: 129 additions & 31 deletions src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@

import datetime
import json
from collections.abc import Callable
from dataclasses import dataclass
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Any
from xml.sax.saxutils import escape

import numpy as np
from markdown_it import MarkdownIt
from pydantic import ConfigDict
from sqlalchemy.engine import Engine, make_url
from sqlmodel import (
JSON,
Column,
Field,
Relationship,
Session,
SQLModel,
create_engine,
text,
)
from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text

from raglite._config import RAGLiteConfig
from raglite._litellm import get_embedding_dim
Expand Down Expand Up @@ -83,11 +77,7 @@ class Chunk(SQLModel, table=True):

@staticmethod
def from_body(
document_id: str,
index: int,
body: str,
headings: str = "",
**kwargs: Any,
document_id: str, index: int, body: str, headings: str = "", **kwargs: Any
) -> "Chunk":
"""Create a chunk from Markdown."""
return Chunk(
Expand Down Expand Up @@ -221,10 +211,7 @@ class Eval(SQLModel, table=True):

@staticmethod
def from_chunks(
question: str,
contexts: list[Chunk],
ground_truth: str,
**kwargs: Any,
question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any
) -> "Eval":
"""Create a chunk from Markdown."""
document_id = contexts[0].document_id
Expand Down Expand Up @@ -284,18 +271,22 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
with Session(engine) as session:
metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
session.execute(
text("""
text(
"""
CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
""")
"""
)
)
session.execute(
text(f"""
text(
f"""
CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
USING hnsw (
(embedding::halfvec({embedding_dim}))
halfvec_{metrics[config.vector_search_index_metric]}_ops
);
""")
"""
)
)
session.commit()
elif db_backend == "sqlite":
Expand All @@ -304,31 +295,138 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
# [1] https://www.sqlite.org/fts5.html#external_content_tables
with Session(engine) as session:
session.execute(
text("""
text(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
""")
"""
)
)
session.execute(
text("""
text(
"""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
""")
"""
)
)
session.execute(
text("""
text(
"""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
END;
""")
"""
)
)
session.execute(
text("""
text(
"""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
""")
"""
)
)
session.commit()
return engine


@dataclass
class ContextSegment:
    """A group of chunks taken from a single document, with one score per chunk.

    Attributes
    ----------
    document_id : str
        Identifier of the document the chunks belong to.
    chunks : list[Chunk]
        The chunks that make up this segment.
    chunk_scores : list[float]
        Scores associated with the chunks; aggregated by :meth:`score`.

    Raises
    ------
    ValueError
        If ``document_id`` is not a non-empty string, if ``chunks`` is empty,
        or if any element of ``chunks`` is not a ``Chunk`` instance.
    """

    document_id: str
    chunks: list[Chunk]
    chunk_scores: list[float]

    def __post_init__(self) -> None:
        """Validate the segment data after initialization."""
        if not isinstance(self.document_id, str) or not self.document_id.strip():
            msg = "document_id must be a non-empty string"
            raise ValueError(msg)
        if not self.chunks:
            msg = "chunks cannot be empty"
            raise ValueError(msg)
        if not all(isinstance(chunk, Chunk) for chunk in self.chunks):
            msg = "all elements in chunks must be Chunk instances"
            raise ValueError(msg)

    def to_xml(self, indent: int = 4) -> str:  # noqa: ARG002
        """Convert the segment to an XML string representation.

        The result is a single ``<document>`` element whose ``id`` and
        ``chunk_ids`` attributes hold the document ID and the comma-separated
        chunk IDs, and whose text content is the newline-joined ``str()`` of
        every chunk.

        Parameters
        ----------
        indent : int
            Accepted for interface compatibility; currently not used when
            building the output.

        Returns
        -------
        str
            XML representation of the segment.
        """
        # `escape` only handles &, < and > by default. The attribute values are
        # wrapped in double quotes, so a literal `"` must be escaped as well or
        # the resulting XML is malformed.
        attribute_entities = {'"': "&quot;"}
        document_id = escape(self.document_id, attribute_entities)
        chunk_ids = escape(",".join(self.chunk_ids), attribute_entities)
        # Element content only needs the default escaping.
        content = escape("\n".join(str(chunk) for chunk in self.chunks))
        return f'<document id="{document_id}" chunk_ids="{chunk_ids}">\n{content}\n</document>'

    def score(self, scoring_function: Callable[[list[float]], float] = sum) -> float:
        """Return an aggregated score of the segment, given a scoring function.

        Parameters
        ----------
        scoring_function : Callable[[list[float]], float]
            Function applied to ``chunk_scores``; defaults to ``sum``.

        Returns
        -------
        float
            Whatever ``scoring_function`` returns for ``chunk_scores``.
        """
        return scoring_function(self.chunk_scores)

    @property
    def chunk_ids(self) -> list[str]:
        """Return the ``id`` of every chunk, in segment order."""
        return [chunk.id for chunk in self.chunks]

    def __str__(self) -> str:
        """Return a string representation reconstructing the document with headings.

        Each chunk's ``headings`` (one heading per line) is compared against a
        stack holding the heading path emitted so far; only the part of the
        path that differs from the stack is emitted again.

        For example, for consecutive chunks with these heading paths:
        - "# A ## B" shows both headers (nothing seen yet)
        - "# A ## B" shows nothing (already seen)
        - "# A ## C" shows only "## C" (new branch)
        - "# D ## B" shows both (new path)
        """
        if not self.chunks:
            return ""

        result: list[str] = []
        stack: list[str] = []

        for chunk in self.chunks:
            headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()]

            # Length of the common prefix between this chunk's heading path
            # and the path that has already been emitted.
            i = 0
            while i < len(headers) and i < len(stack) and headers[i] == stack[i]:
                i += 1

            # Replace the diverging tail of the stack and emit only the new
            # headings, followed by a blank line separating them from the body.
            stack[i:] = headers[i:]
            if headers[i:]:
                result.extend(headers[i:])
                result.append("")

            if chunk.body.strip():
                result.append(chunk.body.strip())

        return "\n".join(result).strip()
13 changes: 6 additions & 7 deletions src/raglite/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def validate_question(cls, value: str) -> str:
num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311
config=config,
)
related_chunks = retrieve_segments(related_chunk_ids, config=config)
related_chunks = [
str(segment) for segment in retrieve_segments(related_chunk_ids, config=config)
]
# Extract a question from the seed chunk's related chunks.
try:
question_response = extract_with_llm(
Expand Down Expand Up @@ -157,9 +159,7 @@ class AnswerResponse(BaseModel):
answer = answer_response.answer
# Store the eval in the database.
eval_ = Eval.from_chunks(
question=question,
contexts=relevant_chunks,
ground_truth=answer,
question=question, contexts=relevant_chunks, ground_truth=answer
)
session.add(eval_)
session.commit()
Expand All @@ -185,7 +185,7 @@ def answer_evals(
answer = "".join(response)
answers.append(answer)
chunk_ids, _ = search(eval_.question, config=config)
contexts.append(retrieve_segments(chunk_ids))
contexts.append([str(segment) for segment in retrieve_segments(chunk_ids)])
# Collect the answered evals.
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
Expand All @@ -199,8 +199,7 @@ def answer_evals(


def evaluate(
answered_evals: pd.DataFrame | int = 100,
config: RAGLiteConfig | None = None,
answered_evals: pd.DataFrame | int = 100, config: RAGLiteConfig | None = None
) -> pd.DataFrame:
"""Evaluate the performance of a set of answered evals with Ragas."""
try:
Expand Down
Loading