feat: add mem0ai dependency to pyproject.toml #3244

Draft · wants to merge 1 commit into base: main
backend/api/quivr_api/modules/rag_service/rag_service.py (4 changes: 2 additions & 2 deletions)

@@ -163,7 +163,7 @@ async def generate_answer(
         # Format the history, sanitize the input
         chat_history = self._build_chat_history(history)

-        parsed_response = rag_pipeline.answer(question, chat_history, list_files)
+        parsed_response = rag_pipeline.answer(question, chat_history, list_files, str(self.brain.brain_id))

         # Save the answer to db
         new_chat_entry = self.save_answer(question, parsed_response)
@@ -212,7 +212,7 @@ async def generate_answer_stream(
         )
         # Initialize the rag pipline
         rag_pipeline = QuivrQARAGLangGraph(
-            rag_config=rag_config, llm=llm, vector_store=vector_store
+            rag_config=rag_config, llm=llm, vector_store=vector_store, memory_id=str(self.brain.brain_id)
         )

         full_answer = ""
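With this change the brain id doubles as the mem0 user id, so each brain accumulates its own memory namespace. A minimal sketch of the resulting call from the service layer, with a made-up UUID standing in for the real brain id:

    # Sketch only: rag_pipeline is the QuivrQARAGLangGraph instance built above,
    # and the UUID string is a hypothetical stand-in for self.brain.brain_id.
    parsed_response = rag_pipeline.answer(
        question,
        chat_history,
        list_files,
        "1f9c2a3e-0000-0000-0000-000000000000",  # forwarded as mem0_user_id
    )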
backend/core/pyproject.toml (1 change: 1 addition & 0 deletions)

@@ -16,6 +16,7 @@ dependencies = [
     "aiofiles>=23.1.0",
     "langchain-community>=0.2.12",
     "langchain-anthropic>=0.1.23",
+    "mem0ai>=0.1.15",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"
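The new mem0ai dependency provides the Memory client used throughout this PR. A minimal sketch of the API surface the diff relies on, mirroring the calls made in quivr_rag_langgraph.py below; the connection values here are placeholders:

    from mem0 import Memory

    # Same config shape as in quivr_rag_langgraph.py; values are illustrative.
    config = {
        "vector_store": {
            "provider": "pgvector",
            "config": {
                "user": "postgres",
                "password": "secret",
                "host": "localhost",
                "port": 5432,
                "dbname": "quivr",
            },
        }
    }
    m = Memory.from_config(config)
    m.add("User: I prefer short answers.\nAssistant: Noted.", user_id="brain-123")
    # The generate() loop below assumes search() returns dicts with a "memory" key.
    hits = m.search("How should answers be phrased?", user_id="brain-123")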
backend/core/quivr_core/prompts.py (1 change: 1 addition & 0 deletions)

@@ -45,6 +45,7 @@

 If not None, User instruction to follow to answer: {custom_instructions}
 Don't cite the source id in the answer objects, but you can use the source to answer the question.
+{memories}
 """
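The new {memories} placeholder is filled from the final_inputs dict built in generate() (see the quivr_rag_langgraph.py diff below). A sketch of the substitution, assuming ANSWER_PROMPT behaves like a plain format-style template, which is what the brace syntax suggests; all values here are illustrative:

    # Hypothetical values; generate() supplies the same keys via final_inputs.
    prompt_text = ANSWER_PROMPT.format(
        context=memory_context,      # the diff passes the memory block as context too
        question="What did we decide last time?",
        custom_instructions=None,
        files="report.pdf",
        memories=memory_context,
    )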
backend/core/quivr_core/quivr_rag_langgraph.py (52 changes: 50 additions & 2 deletions)

@@ -1,5 +1,9 @@
 import logging
 from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict
+import os
+from urllib.parse import urlparse
+from mem0 import Memory
+

 # TODO(@aminediro): this is the only dependency to langchain package, we should remove it
 from langchain.retrievers import ContextualCompressionRetriever
@@ -23,7 +27,6 @@
 )
 from quivr_core.prompts import ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
 from quivr_core.utils import (
-    combine_documents,
     format_file_list,
     get_chunk_metadata,
     parse_chunk_response,
@@ -32,6 +35,26 @@

 logger = logging.getLogger("quivr_core")

+# Read the database URL from the environment
+database_url = os.getenv("PG_DATABASE_ASYNC_URL")
+parsed_url = urlparse(database_url)
+
+config = {
+    "vector_store": {
+        "provider": "pgvector",
+        "config": {
+            "user": parsed_url.username,
+            "password": parsed_url.password,
+            "host": parsed_url.hostname,
+            "port": parsed_url.port,
+            "dbname": parsed_url.path[1:],  # Remove leading '/'
+        }
+    }
+}
+
+m = Memory.from_config(config)
+
+

 class AgentState(TypedDict):
     # The add_messages function defines how an update should be processed
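The pgvector settings are derived by splitting PG_DATABASE_ASYNC_URL with the standard-library urlparse. For reference, this is how the URL components map onto the config keys, using a made-up URL:

    from urllib.parse import urlparse

    url = "postgresql+asyncpg://quivr:secret@localhost:5432/quivr_db"
    p = urlparse(url)
    print(p.username)  # "quivr"
    print(p.password)  # "secret"
    print(p.hostname)  # "localhost"
    print(p.port)      # 5432
    print(p.path[1:])  # "quivr_db" (path is "/quivr_db", hence the slice)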
@@ -43,6 +66,7 @@ class AgentState(TypedDict):
     transformed_question: BaseMessage
     files: str
     final_response: dict
+    mem0_user_id: str


 class IdempotentCompressor(BaseDocumentCompressor):
@@ -69,6 +93,7 @@ def __init__(
         llm: LLMEndpoint,
         vector_store: VectorStore,
         reranker: BaseDocumentCompressor | None = None,
+        memory_id: str | None = None,
     ):
         """
         Construct a QuivrQARAGLangGraph object.
@@ -87,6 +112,8 @@ def __init__(
         self.compression_retriever = ContextualCompressionRetriever(
             base_compressor=self.reranker, base_retriever=self.retriever
         )
+        self.memory = Memory.from_config(config)
+        self.memory_id = memory_id

     @property
     def retriever(self):
@@ -181,15 +208,26 @@ def generate(self, state):
         files = state["files"]

         docs = state["docs"]
+
+        memories = self.memory.search(question, user_id=self.memory_id)
+        print(memories)
+
+        context = "Memory and relevant information from previous conversations:\n"
+        for memory in memories:
+            context += f"- {memory['memory']}\n"
+        context += "End of memory"
+
+        print(context)

         # Prompt
         prompt = self.rag_config.prompt

         final_inputs = {
-            "context": combine_documents(docs),
+            "context": context,
             "question": question,
             "custom_instructions": prompt,
             "files": files,
+            "memories": context,
         }

         # LLM
@@ -209,6 +247,11 @@ def generate(self, state):
             "answer": response,  # Assuming the last message contains the final answer
             "docs": docs,
         }
+
+        print("Adding to memory")
+        result = self.memory.add(f"User: {question}\nAssistant: {response}", user_id=self.memory_id)
+        print("Added to memory")
+        print(result)
         return {"messages": [response], "final_response": formatted_response}

     def build_langgraph_chain(self):
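Together, the two hunks above give generate() a read-then-write memory cycle: relevant memories are retrieved before answering, and the new exchange is stored afterwards. A standalone sketch of that cycle, assuming (as the loop above does) that search() returns a list of dicts with a "memory" key and that add() distills the exchange into stored facts:

    # Sketch only; m is the module-level Memory client from this diff.
    m.add("User: My company is called Acme.\nAssistant: Noted.", user_id="brain-123")
    hits = m.search("What is the user's company?", user_id="brain-123")
    for hit in hits:
        print(hit["memory"])  # e.g. a distilled fact such as "Company is called Acme"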
@@ -266,6 +309,7 @@ def answer(
         question: str,
         history: ChatHistory,
         list_files: list[QuivrKnowledge],
+        mem0_user_id: str | None = None,
         metadata: dict[str, str] = {},
     ) -> ParsedRAGResponse:
         """
@@ -288,6 +332,7 @@ def answer(
             ],
             "chat_history": history,
             "files": concat_list_files,
+            "mem0_user_id": mem0_user_id,
         }
         raw_llm_response = conversational_qa_chain.invoke(
             inputs,
@@ -303,6 +348,7 @@ async def answer_astream(
         question: str,
         history: ChatHistory,
         list_files: list[QuivrKnowledge],
+        mem0_user_id: str | None = None,
         metadata: dict[str, str] = {},
     ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
         """
@@ -324,6 +370,7 @@ async def answer_astream(
         sources = []
         prev_answer = ""
         chunk_id = 0
+

         async for event in conversational_qa_chain.astream_events(
             {
@@ -332,6 +379,7 @@ async def answer_astream(
                 ],
                 "chat_history": history,
                 "files": concat_list_files,
+                "mem0_user_id": mem0_user_id,
             },
             version="v1",
             config={"metadata": metadata},
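Both entry points accept a mem0_user_id and thread it into the graph state, while generate() itself reads self.memory_id set in the constructor; the rag_service.py diff above supplies str(self.brain.brain_id) on both paths. A hedged sketch of the streaming path as that diff wires it up; names mirror the diff and the chunk attribute access is an assumption:

    # Sketch only; rag_config, llm, vector_store and brain_id come from the caller.
    rag_pipeline = QuivrQARAGLangGraph(
        rag_config=rag_config, llm=llm, vector_store=vector_store,
        memory_id=str(brain_id),
    )
    async for chunk in rag_pipeline.answer_astream(
        question, chat_history, list_files, mem0_user_id=str(brain_id)
    ):
        print(chunk.answer, end="")  # assuming ParsedRAGChunkResponse exposes .answer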
backend/pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ dev-dependencies = [
 ]

 [tool.rye.workspace]
-members = [".", "core", "worker", "api", "docs", "core/examples/chatbot"]
+members = [".", "core", "worker", "api", "docs", "core/examples/chatbot", "diff-assistant"]

 [tool.hatch.metadata]
 allow-direct-references = true