From dcec82cb8a8cbf069e188cb405b5d44a84484eff Mon Sep 17 00:00:00 2001 From: sudan45 Date: Fri, 27 Sep 2024 10:42:56 +0545 Subject: [PATCH] Manage user history --- chatbotcore/llm.py | 15 ++++----------- content/serializers.py | 1 + content/views.py | 2 +- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/chatbotcore/llm.py b/chatbotcore/llm.py index 1b47530..3eec209 100644 --- a/chatbotcore/llm.py +++ b/chatbotcore/llm.py @@ -121,21 +121,14 @@ def create_chain(self, db_collection_name: str): rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain) return rag_chain -<<<<<<< HEAD def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME): -||||||| parent of efaacc1 (Integrate LLM service for user query) - def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME): - print("...........") -======= - def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME): ->>>>>>> efaacc1 (Integrate LLM service for user query) """ Executes the chain """ if not self.rag_chain: self.rag_chain = self.create_chain(db_collection_name=db_collection_name) - if "user_id" not in self.user_memory_mapping: + if user_id not in self.user_memory_mapping: self.user_memory_mapping[user_id] = ConversationBufferWindowMemory( k=self.conversation_max_window, memory_key=self.mem_key, return_messages=True ) @@ -155,13 +148,13 @@ def get_message_history(self, user_id: str): """ Returns the historical conversational data """ - if "user_id" in self.user_memory_mapping: + if user_id in self.user_memory_mapping: return self.user_memory_mapping[user_id].load_memory_variables({}) - return {} + return {"chat_history": []} def delete_message_history_by_user(self, user_id: str) -> bool: """Deletes the message history based on user id""" - if "user_id" in self.user_memory_mapping: + if user_id in self.user_memory_mapping: del self.user_memory_mapping[user_id] logger.info(f"Successfully delete the {user_id} conversational history.") return True diff --git a/content/serializers.py b/content/serializers.py index 292fe10..f4bbee6 100644 --- a/content/serializers.py +++ b/content/serializers.py @@ -3,3 +3,4 @@ class UserQuerySerializer(serializers.Serializer): query = serializers.CharField(required=True, allow_null=False, allow_blank=False) + user_id = serializers.UUIDField(required=True) diff --git a/content/views.py b/content/views.py index cdeed8e..7596e58 100644 --- a/content/views.py +++ b/content/views.py @@ -23,6 +23,6 @@ class UserQuery(GenericAPIView): def post(self, request, *arg, **kwargs): serializer = UserQuerySerializer(data=request.data) if serializer.is_valid(): - result = self.llm.execute_chain(request.data["query"]) + result = self.llm.execute_chain(request.data['user_id'], request.data["query"]) return Response(result) return Response(serializer.errors, 422)