Skip to content

Commit

Permalink
fix/dataset load optimize
Browse files Browse the repository at this point in the history
Signed-off-by: micost <[email protected]>
  • Loading branch information
Micost committed Aug 25, 2024
1 parent 213c27c commit 98a8513
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
29 changes: 20 additions & 9 deletions backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,26 @@ async def chat(
)
match session.session_type:
case "rag":
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_uuid=session.session_uuid,
input_msg=chat_in_msg.message,
collection_name=session.dataset_name,
temperature=chat_in_msg.temperature,
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
)
# Verify that a dataset_name exists for this session
if session.dataset_name is None:
stream_func: ContentStream = rag_chat_repo.inference(
session_uuid=session.session_uuid,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
)
else:
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_uuid=session.session_uuid,
input_msg=chat_in_msg.message,
collection_name=session.dataset_name,
temperature=chat_in_msg.temperature,
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
)
case _: # default is chat robot
stream_func: ContentStream = rag_chat_repo.inference(
session_uuid=session.session_uuid,
Expand Down
9 changes: 6 additions & 3 deletions backend/src/api/routes/rag_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ async def load_dataset(
current_user = account_repo.read_account_by_username(username=jwt_payload.username)
# Here we don't use async because load_dataset is a sync function in HF ds
# status: bool = True if DatasetEng.load_dataset(rag_ds_create.dataset_name).get("insert_count") > 0 else False
session = session_repo.read_create_sessions_by_uuid(
session_uuid=rag_ds_create.sessionUuid, account_id=current_user.id, name="new session"
)
try:
# Here we use async because we need to update the session db
DatasetEng.load_dataset(rag_ds_create.dataset_name)
Expand All @@ -138,7 +141,7 @@ async def load_dataset(
rag_ds_create.dataset_name
)
session_repo.append_ds_name_to_session(
session_uuid=rag_ds_create.sessionUuid,
session_uuid=session.session_uuid,
account_id=current_user.id,
ds_name=table_name,  # ds_name should be the same as collection_name in the vector db
)
Expand All @@ -148,6 +151,6 @@ async def load_dataset(
des=""
))
case False:
return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)

return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)
1 change: 1 addition & 0 deletions backend/src/models/schemas/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class RagDatasetResponse(BaseSchemaModel):
class LoadRAGDSResponse(BaseSchemaModel):
    """Response schema returned by the load-dataset endpoint.

    Reports which dataset was loaded, whether the load succeeded, and the
    UUID of the chat session the dataset was attached to.
    """

    # Name of the dataset the caller asked to load.
    dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
    # True when the dataset was loaded/inserted successfully; defaults to False.
    status: bool = Field(default=False, title="Status", description="Status")
    # UUID of the session this dataset was associated with (required field).
    session_uuid: str = Field(..., title="Session UUID", description="Session UUID")
    # created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")
    # updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time")
    # ratio: Optional[float] = Field(..., title="Ratio", description="Ratio")
Expand Down
3 changes: 1 addition & 2 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ async def inference(

async def inference_with_rag(
self,
session_uuid: str,
input_msg: str,
collection_name: str,
temperature: float = 0.2,
Expand Down Expand Up @@ -146,7 +145,7 @@ async def get_context_by_question(input_msg: str):
context = f"Please answer the question based on answer {context}"
else:
context = InferenceHelper.instruction

loguru.logger.info(f"Context: {context}")
return context

current_context = await get_context_by_question(input_msg)
Expand Down

0 comments on commit 98a8513

Please sign in to comment.