feat: much improved RAG, added LLM post-processing of results #435

Merged
10 commits merged on Feb 5, 2025
53 changes: 50 additions & 3 deletions gptme/config.py
@@ -37,14 +37,56 @@ def dict(self) -> dict:
         }
 
 
+default_post_process_prompt = """
+You are an intelligent knowledge retrieval assistant designed to analyze context chunks and extract relevant information based on user queries. Your primary goal is to provide accurate and helpful information while adhering to specific guidelines.
+
+You will be provided with a user query inside <user_query> tags and a list of potentially relevant context chunks inside <chunks> tags.
+
+When a user submits a query, follow these steps:
+
+1. Analyze the user's query carefully to identify key concepts and requirements.
+
+2. Search through the provided context chunks for relevant information.
+
+3. If you find relevant information:
+   a. Extract the most pertinent parts.
+   b. Summarize the relevant context inside <context_summary> tags.
+   c. Output the exact relevant context chunks, including the complete <chunks path="...">...</chunks> tags.
+
+4. If you cannot find any relevant information, respond with exactly: "No relevant context found".
+
+Important guidelines:
+- Do not make assumptions beyond the available data.
+- Maintain objectivity in source selection.
+- When returning context chunks, include the entire content of the <chunks> tag. Do not modify or truncate it in any way.
+- Ensure that you're providing complete information from the chunks, not partial or summarized versions within the tags.
+- When no relevant context is found, do not return anything other than exactly "No relevant context found".
+- Do not output anything else than the <context_summary> and <chunks> tags.
+
+Please provide your response, starting with the summary and followed by the relevant chunks (if any).
+"""
+
+
+@dataclass
+class RagConfig:
+    enabled: bool = False
+    max_tokens: int | None = None
+    min_relevance: float | None = None
+    post_process: bool = True
+    post_process_model: str | None = None
+    post_process_prompt: str = default_post_process_prompt
+    workspace_only: bool = True
+    paths: list[str] = field(default_factory=list)
+
+
 @dataclass
 class ProjectConfig:
     """Project-level configuration, such as which files to include in the context by default."""
 
     base_prompt: str | None = None
     prompt: str | None = None
     files: list[str] = field(default_factory=list)
-    rag: dict = field(default_factory=dict)
+    rag: RagConfig = field(default_factory=RagConfig)
 
 
 ABOUT_ACTIVITYWATCH = """ActivityWatch is a free and open-source automated time-tracker that helps you track how you spend your time on your devices."""
@@ -146,8 +188,13 @@ def get_project_config(workspace: Path | None) -> ProjectConfig | None:
         )
         # load project config
         with open(project_config_path) as f:
-            project_config = tomlkit.load(f)
-        return ProjectConfig(**project_config)  # type: ignore
+            config_data = dict(tomlkit.load(f))
+
+        # Handle RAG config conversion before creating ProjectConfig
+        if "rag" in config_data:
+            config_data["rag"] = RagConfig(**config_data["rag"])  # type: ignore
+
+        return ProjectConfig(**config_data)  # type: ignore
     return None
 
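For illustration, a minimal sketch of what this new parsing path does (not part of the diff; the TOML content is invented and the gptme.config import path is assumed):

    import tomlkit

    from gptme.config import ProjectConfig, RagConfig

    raw = """
    [rag]
    enabled = true
    post_process = false
    """

    config_data = dict(tomlkit.parse(raw))
    # The nested [rag] table is converted to the RagConfig dataclass first;
    # ProjectConfig(**config_data) would otherwise leave it as a plain dict.
    if "rag" in config_data:
        config_data["rag"] = RagConfig(**config_data["rag"])

    project = ProjectConfig(**config_data)
    assert project.rag.enabled
    assert not project.rag.post_process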
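To make the post-processing contract concrete: given the default prompt above, a compliant model response would look roughly like this (illustrative only; the summary text, chunk content, and path are invented):

    <context_summary>
    The project config documents a [rag] section with enabled, post_process,
    and workspace_only options.
    </context_summary>
    <chunks path="docs/config.md">
    [rag]
    enabled = true
    post_process = false
    </chunks>

When nothing matches, the model must instead reply with exactly "No relevant context found".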
112 changes: 83 additions & 29 deletions gptme/tools/rag.py
Expand Up @@ -15,6 +15,11 @@

[rag]
enabled = true
post_process = false # Whether to post-process the context with an LLM to extract the most relevant information
post_process_model = "openai/gpt-4o-mini" # Which model to use for post-processing
post_process_prompt = "" # Optional prompt to use for post-processing (overrides default prompt)
workspace_only = true # Whether to only search in the workspace directory, or the whole RAG index
paths = [] # List of paths to include in the RAG index. Has no effect if workspace_only is true.

.. rubric:: Features

@@ -36,9 +41,10 @@
 from functools import lru_cache
 from pathlib import Path
 
-from ..config import get_project_config
+from ..config import RagConfig, get_project_config
 from ..message import Message
 from ..util import get_project_dir
+from ..llm import _chat_complete
 from .base import ToolSpec, ToolUse
 
 logger = logging.getLogger(__name__)
@@ -106,7 +112,7 @@ def rag_search(query: str, return_full: bool = False) -> str:
     cmd = ["gptme-rag", "search", query]
     if return_full:
         # shows full context of the search results
-        cmd.append("--show-context")
+        cmd.extend(["--format", "full", "--score"])
 
     result = _run_rag_cmd(cmd)
     return result.stdout.strip()
@@ -129,7 +135,7 @@ def init() -> ToolSpec:
     # Check project configuration
     project_dir = get_project_dir()
     if project_dir and (config := get_project_config(project_dir)):
-        enabled = config.rag.get("enabled", False)
+        enabled = config.rag.enabled
         if not enabled:
             logger.debug("RAG not enabled in the project configuration")
             return replace(tool, available=False)
@@ -140,41 +146,89 @@
     return tool
 
 
-def rag_enhance_messages(messages: list[Message]) -> list[Message]:
+def get_rag_context(
+    query: str,
+    rag_config: RagConfig,
+    workspace: Path | None = None,
+) -> Message:
+    """Get relevant context chunks from RAG for the user query."""
+
+    should_post_process = (
+        rag_config.post_process and rag_config.post_process_model is not None
+    )
+
+    cmd = [
+        "gptme-rag",
+        "search",
+        query,
+    ]
+    if workspace and rag_config.workspace_only:
+        cmd.append(workspace.as_posix())
+    elif rag_config.paths:
+        cmd.extend(rag_config.paths)
+    if not should_post_process:
+        cmd.append("--score")
+    cmd.extend(["--format", "full"])
+
+    if rag_config.max_tokens:
+        cmd.extend(["--max-tokens", str(rag_config.max_tokens)])
+    if rag_config.min_relevance:
+        cmd.extend(["--min-relevance", str(rag_config.min_relevance)])
+    rag_result = _run_rag_cmd(cmd).stdout
+
+    # Post-process the context with an LLM (if enabled)
+    if should_post_process:
+        post_process_msgs = [
+            Message(role="system", content=rag_config.post_process_prompt),
+            Message(role="system", content=rag_result),
+            Message(
+                role="user",
+                content=f"<user_query>\n{query}\n</user_query>",
+            ),
+        ]
+        start = time.monotonic()
+        rag_result = _chat_complete(
+            messages=post_process_msgs,
+            model=rag_config.post_process_model,  # type: ignore
+            tools=[],
+        )
+        logger.info(f"Ran RAG post-process in {time.monotonic() - start:.2f}s")
+
+    # Create the context message
+    msg = Message(
+        role="system",
+        content=f"Relevant context retrieved using `gptme-rag search`:\n\n{rag_result}",
+        hide=True,
+    )
+    return msg
+
+
+def rag_enhance_messages(
+    messages: list[Message], workspace: Path | None = None
+) -> list[Message]:
     """Enhance messages with context from RAG."""
     if not _has_gptme_rag():
         return messages
 
     # Load config
     config = get_project_config(Path.cwd())
-    rag_config = config.rag if config and config.rag else {}
+    rag_config = config.rag if config and config.rag else RagConfig()
 
-    if not rag_config.get("enabled", False):
+    if not rag_config.enabled:
         return messages
 
-    enhanced_messages = []
-    for msg in messages:
-        if msg.role == "user":
-            try:
-                # Get context using gptme-rag CLI
-                cmd = ["gptme-rag", "search", msg.content, "--show-context"]
-                if max_tokens := rag_config.get("max_tokens"):
-                    cmd.extend(["--max-tokens", str(max_tokens)])
-                if min_relevance := rag_config.get("min_relevance"):
-                    cmd.extend(["--min-relevance", str(min_relevance)])
-                enhanced_messages.append(
-                    Message(
-                        role="system",
-                        content=f"Relevant context retrieved using `gptme-rag search`:\n\n{_run_rag_cmd(cmd).stdout}",
-                        hide=True,
-                    )
-                )
-            except Exception as e:
-                logger.warning(f"Error getting context: {e}")
-
-        enhanced_messages.append(msg)
-
-    return enhanced_messages
+    last_msg = messages[-1] if messages else None
+    if last_msg and last_msg.role == "user":
+        try:
+            # Get context using gptme-rag CLI
+            msg = get_rag_context(last_msg.content, rag_config, workspace)
+
+            # Append context message right before the last user message
+            messages.insert(-1, msg)
+        except Exception as e:
+            logger.warning(f"Error getting context: {e}")
+
+    return messages
 
 
 tool = ToolSpec(
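A hypothetical usage sketch of the new flow (message content invented; assumes gptme-rag is installed, `enabled = true` under `[rag]` in the project config, and the shown import paths):

    from pathlib import Path

    from gptme.message import Message
    from gptme.tools.rag import rag_enhance_messages

    msgs = [
        Message(role="user", content="How do I enable RAG post-processing?"),
    ]
    msgs = rag_enhance_messages(msgs, workspace=Path.cwd())

    # Unlike the old version, which queried once per user message, only the
    # last user message is used as the query; the retrieved context arrives
    # as a hidden system message inserted *before* it:
    # [system("Relevant context retrieved using `gptme-rag search`: ..."),
    #  user("How do I enable RAG post-processing?")]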
2 changes: 1 addition & 1 deletion gptme/util/context.py
@@ -348,7 +348,7 @@ def enrich_messages_with_context(
     msgs = copy(msgs)
 
     # First enhance messages with context, if gptme-rag is available
-    msgs = rag_enhance_messages(msgs)
+    msgs = rag_enhance_messages(msgs, workspace)
 
     msgs = [
         append_file_content(msg, workspace, check_modified=use_fresh_context())