Skip to content

Commit

Permalink
Manage user history
Browse files Browse the repository at this point in the history
  • Loading branch information
sudan45 committed Sep 27, 2024
1 parent c9abc28 commit dcec82c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
15 changes: 4 additions & 11 deletions chatbotcore/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,14 @@ def create_chain(self, db_collection_name: str):
rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
return rag_chain

<<<<<<< HEAD
def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
||||||| parent of efaacc1 (Integrate LLM service for user query)
def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
print("...........")
=======
def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
>>>>>>> efaacc1 (Integrate LLM service for user query)
"""
Executes the chain
"""
if not self.rag_chain:
self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

if "user_id" not in self.user_memory_mapping:
if user_id not in self.user_memory_mapping:
self.user_memory_mapping[user_id] = ConversationBufferWindowMemory(
k=self.conversation_max_window, memory_key=self.mem_key, return_messages=True
)
Expand All @@ -155,13 +148,13 @@ def get_message_history(self, user_id: str):
"""
Returns the historical conversational data
"""
if "user_id" in self.user_memory_mapping:
if user_id in self.user_memory_mapping:
return self.user_memory_mapping[user_id].load_memory_variables({})
return {}
return {"chat_history": []}

def delete_message_history_by_user(self, user_id: str) -> bool:
"""Deletes the message history based on user id"""
if "user_id" in self.user_memory_mapping:
if user_id in self.user_memory_mapping:
del self.user_memory_mapping[user_id]
logger.info(f"Successfully delete the {user_id} conversational history.")
return True
Expand Down
1 change: 1 addition & 0 deletions content/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

class UserQuerySerializer(serializers.Serializer):
query = serializers.CharField(required=True, allow_null=False, allow_blank=False)
user_id = serializers.UUIDField(required=True)
2 changes: 1 addition & 1 deletion content/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ class UserQuery(GenericAPIView):
def post(self, request, *arg, **kwargs):
serializer = UserQuerySerializer(data=request.data)
if serializer.is_valid():
result = self.llm.execute_chain(request.data["query"])
result = self.llm.execute_chain(request.data['user_id'], request.data["query"])
return Response(result)
return Response(serializer.errors, 422)

0 comments on commit dcec82c

Please sign in to comment.