From 98a8513bdc0407d427532d27863c765f4029e675 Mon Sep 17 00:00:00 2001
From: micost
Date: Sun, 25 Aug 2024 18:59:11 +0800
Subject: [PATCH] fix/dataset load optimize

Signed-off-by: micost
---
 backend/src/api/routes/chat.py         | 29 ++++++++++++++++++
 backend/src/api/routes/rag_datasets.py |  9 +++++---
 backend/src/models/schemas/dataset.py  |  1 +
 backend/src/repository/rag/chat.py     |  3 +--
 4 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/backend/src/api/routes/chat.py b/backend/src/api/routes/chat.py
index 5a9878c..e591289 100644
--- a/backend/src/api/routes/chat.py
+++ b/backend/src/api/routes/chat.py
@@ -246,15 +246,26 @@ async def chat(
     )
     match session.session_type:
         case "rag":
-            stream_func: ContentStream = rag_chat_repo.inference_with_rag(
-                session_uuid=session.session_uuid,
-                input_msg=chat_in_msg.message,
-                collection_name=session.dataset_name,
-                temperature=chat_in_msg.temperature,
-                top_k=chat_in_msg.top_k,
-                top_p=chat_in_msg.top_p,
-                n_predict=chat_in_msg.n_predict,
-            )
+            # Verify that dataset_name exists
+            if session.dataset_name is None:
+                stream_func: ContentStream = rag_chat_repo.inference(
+                    session_uuid=session.session_uuid,
+                    input_msg=chat_in_msg.message,
+                    temperature=chat_in_msg.temperature,
+                    top_k=chat_in_msg.top_k,
+                    top_p=chat_in_msg.top_p,
+                    n_predict=chat_in_msg.n_predict,
+                )
+            else:
+                stream_func: ContentStream = rag_chat_repo.inference_with_rag(
+                    session_uuid=session.session_uuid,
+                    input_msg=chat_in_msg.message,
+                    collection_name=session.dataset_name,
+                    temperature=chat_in_msg.temperature,
+                    top_k=chat_in_msg.top_k,
+                    top_p=chat_in_msg.top_p,
+                    n_predict=chat_in_msg.n_predict,
+                )
         case _:  # default is chat robot
             stream_func: ContentStream = rag_chat_repo.inference(
                 session_uuid=session.session_uuid,
diff --git a/backend/src/api/routes/rag_datasets.py b/backend/src/api/routes/rag_datasets.py
index 4dfd380..ce6087f 100644
--- a/backend/src/api/routes/rag_datasets.py
+++ b/backend/src/api/routes/rag_datasets.py
@@ -124,6 +124,9 @@ async def load_dataset(
     current_user = account_repo.read_account_by_username(username=jwt_payload.username)
     # Here we don't use async because load_dataset is a sync function in HF ds
     # status: bool = True if DatasetEng.load_dataset(rag_ds_create.dataset_name).get("insert_count") > 0 else False
+    session = session_repo.read_create_sessions_by_uuid(
+        session_uuid=rag_ds_create.sessionUuid, account_id=current_user.id, name="new session"
+    )
     try:
         # Here we use async because we need to update the session db
         DatasetEng.load_dataset(rag_ds_create.dataset_name)
@@ -138,7 +141,7 @@ async def load_dataset(
                 rag_ds_create.dataset_name
             )
             session_repo.append_ds_name_to_session(
-                session_uuid=rag_ds_create.sessionUuid,
+                session_uuid=session.session_uuid,
                 account_id=current_user.id,
                 ds_name=table_name,  # ds_name should be the same as the collection name in the vector db
             )
@@ -148,6 +151,6 @@ async def load_dataset(
                 des=""
             ))
         case False:
-            return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
+            return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)

-    return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
+    return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)
diff --git a/backend/src/models/schemas/dataset.py b/backend/src/models/schemas/dataset.py
index 8ce3f36..27c2e7e 100644
--- a/backend/src/models/schemas/dataset.py
+++ b/backend/src/models/schemas/dataset.py
@@ -51,6 +51,7 @@ class RagDatasetResponse(BaseSchemaModel):
 class LoadRAGDSResponse(BaseSchemaModel):
     dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
     status: bool = Field(default=False, title="Status", description="Status")
+    session_uuid: str = Field(..., title="Session UUID", description="Session UUID")
     # created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")
     # updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time")
     # ratio: Optional[float] = Field(..., title="Ratio", description="Ratio")
diff --git a/backend/src/repository/rag/chat.py b/backend/src/repository/rag/chat.py
index 9447e2f..29a8adc 100644
--- a/backend/src/repository/rag/chat.py
+++ b/backend/src/repository/rag/chat.py
@@ -107,7 +107,6 @@ async def inference(

     async def inference_with_rag(
         self,
-        session_uuid: str,
         input_msg: str,
         collection_name: str,
         temperature: float = 0.2,
@@ -146,7 +145,7 @@ async def get_context_by_question(input_msg: str):
                 context = f"Please answer the question based on answer {context}"
             else:
                 context = InferenceHelper.instruction
-
+            loguru.logger.info(f"Context: {context}")
             return context

         current_context = await get_context_by_question(input_msg)
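
Reviewer note on the schema change: the new `session_uuid` field is declared with `Field(...)`, i.e. required, which is why both `LoadRAGDSResponse(...)` return sites in `load_dataset` now pass it. The sketch below is illustrative only and substitutes plain `pydantic.BaseModel` for the project's `BaseSchemaModel`, which is not part of this diff.

```python
# Illustrative sketch only: pydantic.BaseModel stands in for BaseSchemaModel
# (not shown in this patch). It demonstrates that session_uuid is now a
# required field, so the old call shape without it fails validation.
from pydantic import BaseModel, Field, ValidationError


class LoadRAGDSResponse(BaseModel):
    dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
    status: bool = Field(default=False, title="Status", description="Status")
    session_uuid: str = Field(..., title="Session UUID", description="Session UUID")


# New call shape used by the patched load_dataset route.
ok = LoadRAGDSResponse(
    dataset_name="example_ds",
    session_uuid="2b7c1f0e-0000-0000-0000-000000000000",
    status=True,
)
print(ok)

try:
    # Old call shape (no session_uuid), as it looked before this patch.
    LoadRAGDSResponse(dataset_name="example_ds", status=False)
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('session_uuid',)
```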
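
For reviewers who want to exercise the changed flow end to end, here is a minimal client sketch. The route paths, base URL, auth header, and request payload keys are assumptions made for illustration: the patch only shows the handler bodies, not the router decorators or request schemas, so adjust them to the actual API before running.

```python
# Hypothetical walk-through of the patched flow. Paths, payload keys, and the
# auth header are guesses for illustration; only the handler bodies appear in
# the diff, so check the real router definitions first.
import httpx

BASE_URL = "http://localhost:8000/api"       # assumed backend base URL
HEADERS = {"Authorization": "Bearer <jwt>"}  # load_dataset resolves the user from the JWT

with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=None) as client:
    # 1) Load a dataset. The response now always carries session_uuid, taken from
    #    session_repo.read_create_sessions_by_uuid(), so the client no longer has
    #    to track which session the dataset was attached to.
    load_resp = client.post(
        "/rag/datasets/load",  # assumed route for the load_dataset handler
        json={
            "dataset_name": "example_ds",
            "sessionUuid": "2b7c1f0e-0000-0000-0000-000000000000",
        },
    ).json()
    session_uuid = load_resp["session_uuid"]

    # 2) Chat in that session. With this patch, a "rag" session whose dataset_name
    #    is still unset streams a plain inference() answer instead of calling
    #    inference_with_rag() with collection_name=None.
    with client.stream(
        "POST",
        f"/chat/{session_uuid}",  # assumed route for the chat handler
        json={
            "message": "Summarize the dataset.",
            "temperature": 0.2,
            "top_k": 40,
            "top_p": 0.9,
            "n_predict": 256,
        },
    ) as resp:
        for chunk in resp.iter_text():
            print(chunk, end="")
```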