Maintain the historical chat conversation per user; #21

Merged
merged 1 commit on Sep 26, 2024
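This PR replaces the single shared ConversationBufferWindowMemory on LLMBase with a mapping keyed by user_id: execute_chain and get_message_history now take a user_id, each user lazily gets an independent window of the last conversation_max_window exchanges, and a new delete_message_history_by_user method drops one user's history without touching the others.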
43 changes: 35 additions & 8 deletions chatbot-core/llm.py
@@ -22,22 +22,28 @@
class LLMBase:
    """LLM Base containing common methods"""

    mem_key: str = field(init=False)
    conversation_max_window: int = field(init=False)
    qdrant_client: QdrantClient = field(init=False)
    llm_model: Any = field(init=False)
    user_memory_mapping: dict = field(init=False)
    memory: Any = field(init=False)
    embedding_model: CustomEmbeddingsWrapper = field(init=False)
    rag_chain: Optional[Any] = None
    def __post_init__(self, mem_key: str = "chat_history", conversation_max_window: int = 3):
        self.llm_model = None
        self.qdrant_client = None
        self.memory = None

        self.mem_key = mem_key
        self.conversation_max_window = conversation_max_window

        try:
            self.qdrant_client = QdrantClient(host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT)
        except Exception as e:
            raise Exception(f"Qdrant client is not properly set up. {str(e)}")
        self.memory = ConversationBufferWindowMemory(k=conversation_max_window, memory_key=mem_key, return_messages=True)

        self.user_memory_mapping = {}

        self.embedding_model = CustomEmbeddingsWrapper(
            url=settings.EMBEDDING_MODEL_URL,
@@ -113,23 +119,44 @@ def create_chain(self, db_collection_name: str):
        rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
        return rag_chain

    def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
    def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
        """
        Executes the chain
        """
        if not self.rag_chain:
            self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

        response = self.rag_chain.invoke({"input": query, "chat_history": self.get_message_history()["chat_history"]})
        self.memory.chat_memory.add_message(HumanMessage(content=query))
        self.memory.chat_memory.add_message(AIMessage(content=response["answer"]))
        # Lazily create a dedicated conversation window the first time a user is seen.
        if user_id not in self.user_memory_mapping:
            self.user_memory_mapping[user_id] = ConversationBufferWindowMemory(
                k=self.conversation_max_window, memory_key=self.mem_key, return_messages=True
            )

        memory = self.user_memory_mapping[user_id]

        response = self.rag_chain.invoke(
            {"input": query, "chat_history": self.get_message_history(user_id=user_id)["chat_history"]}
        )
        # `memory` is a reference to the entry in user_memory_mapping, so adding
        # messages here updates the stored history in place; no reassignment needed.
        memory.chat_memory.add_message(HumanMessage(content=query))
        memory.chat_memory.add_message(AIMessage(content=response["answer"]))

        return response["answer"] if "answer" in response else ""

    def get_message_history(self):
    def get_message_history(self, user_id: str):
        """
        Returns the historical conversational data
        """
        return self.memory.load_memory_variables({})
        if user_id in self.user_memory_mapping:
            return self.user_memory_mapping[user_id].load_memory_variables({})
        # Unknown user: return an empty history under the configured memory key
        # so callers can index the result safely.
        return {self.mem_key: []}

    def delete_message_history_by_user(self, user_id: str) -> bool:
        """Deletes the message history based on user id"""
        if user_id in self.user_memory_mapping:
            del self.user_memory_mapping[user_id]
            logger.info(f"Successfully deleted the conversation history for user {user_id}.")
            return True
        return False


@dataclass
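For reference, a minimal standalone sketch of the per-user memory pattern this diff introduces, lifted out of LLMBase so it runs on its own. The import paths assume the classic (pre-0.1) LangChain layout the file appears to use; record_turn is a hypothetical helper, not part of the PR.

from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import AIMessage, HumanMessage

# One independent, bounded window per user, created lazily on first contact.
user_memory_mapping = {}

def record_turn(user_id: str, query: str, answer: str, k: int = 3):
    if user_id not in user_memory_mapping:
        user_memory_mapping[user_id] = ConversationBufferWindowMemory(
            k=k, memory_key="chat_history", return_messages=True
        )
    memory = user_memory_mapping[user_id]
    memory.chat_memory.add_message(HumanMessage(content=query))
    memory.chat_memory.add_message(AIMessage(content=answer))

record_turn("alice", "What does this repo do?", "It answers questions over a Qdrant index.")
record_turn("bob", "Hello", "Hi!")

# Each user only sees their own last k exchanges.
print(user_memory_mapping["alice"].load_memory_variables({})["chat_history"])

# Dropping one user's history leaves the others untouched.
del user_memory_mapping["bob"]

Keeping one bounded window per user both isolates conversations and caps history growth per user; note that the mapping itself grows until delete_message_history_by_user is called for a user.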