fix/dataset load opimize #373

Merged
merged 1 commit on Aug 25, 2024
29 changes: 20 additions & 9 deletions backend/src/api/routes/chat.py
@@ -246,15 +246,26 @@ async def chat(
     )
     match session.session_type:
         case "rag":
-            stream_func: ContentStream = rag_chat_repo.inference_with_rag(
-                session_uuid=session.session_uuid,
-                input_msg=chat_in_msg.message,
-                collection_name=session.dataset_name,
-                temperature=chat_in_msg.temperature,
-                top_k=chat_in_msg.top_k,
-                top_p=chat_in_msg.top_p,
-                n_predict=chat_in_msg.n_predict,
-            )
+            # Verify dataset_name exists
+            if session.dataset_name is None:
+                stream_func: ContentStream = rag_chat_repo.inference(
+                    session_uuid=session.session_uuid,
+                    input_msg=chat_in_msg.message,
+                    temperature=chat_in_msg.temperature,
+                    top_k=chat_in_msg.top_k,
+                    top_p=chat_in_msg.top_p,
+                    n_predict=chat_in_msg.n_predict,
+                )
+            else:
+                stream_func: ContentStream = rag_chat_repo.inference_with_rag(
+                    session_uuid=session.session_uuid,
+                    input_msg=chat_in_msg.message,
+                    collection_name=session.dataset_name,
+                    temperature=chat_in_msg.temperature,
+                    top_k=chat_in_msg.top_k,
+                    top_p=chat_in_msg.top_p,
+                    n_predict=chat_in_msg.n_predict,
+                )
         case _: # default is chat robot
             stream_func: ContentStream = rag_chat_repo.inference(
                 session_uuid=session.session_uuid,
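In effect, the "rag" branch now falls back to plain chat inference when the session has no dataset attached, instead of calling inference_with_rag with a None collection name. A minimal standalone sketch of that selection logic, assuming only the field and method names visible in the diff (the real Session/ChatInMessage objects and repository live elsewhere in the codebase):

from typing import Any


def select_stream(session: Any, chat_in_msg: Any, rag_chat_repo: Any):
    """Pick plain inference when no dataset is attached, RAG inference otherwise (illustrative)."""
    common = dict(
        session_uuid=session.session_uuid,
        input_msg=chat_in_msg.message,
        temperature=chat_in_msg.temperature,
        top_k=chat_in_msg.top_k,
        top_p=chat_in_msg.top_p,
        n_predict=chat_in_msg.n_predict,
    )
    if session.dataset_name is None:
        # No collection was loaded for this session: degrade to the plain chat path.
        return rag_chat_repo.inference(**common)
    # A dataset was loaded earlier: route through the retrieval-augmented path.
    return rag_chat_repo.inference_with_rag(collection_name=session.dataset_name, **common)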
9 changes: 6 additions & 3 deletions backend/src/api/routes/rag_datasets.py
@@ -124,6 +124,9 @@ async def load_dataset(
     current_user = account_repo.read_account_by_username(username=jwt_payload.username)
     # Here we don't use async because load_dataset is a sync function in HF ds
     # status: bool = True if DatasetEng.load_dataset(rag_ds_create.dataset_name).get("insert_count") > 0 else False
+    session = session_repo.read_create_sessions_by_uuid(
+        session_uuid=rag_ds_create.sessionUuid, account_id=current_user.id, name="new session"
+    )
     try:
         # Here we use async because we need to update the session db
         DatasetEng.load_dataset(rag_ds_create.dataset_name)
@@ -138,7 +141,7 @@
                 rag_ds_create.dataset_name
             )
             session_repo.append_ds_name_to_session(
-                session_uuid=rag_ds_create.sessionUuid,
+                session_uuid=session.session_uuid,
                 account_id=current_user.id,
                 ds_name=table_name,  # ds_name should be same as collection name in vector db
             )
@@ -148,6 +151,6 @@
                 des=""
             ))
         case False:
-            return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
+            return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)
 
-    return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, status=status)
+    return LoadRAGDSResponse(dataset_name=rag_ds_create.dataset_name, session_uuid=session.session_uuid, status=status)
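The new read_create_sessions_by_uuid call suggests a get-or-create step: resolve the session for this account by the client-supplied UUID, creating one named "new session" if it does not exist, so that both append_ds_name_to_session and the response always have a valid session_uuid. A rough, in-memory sketch of that pattern (the ChatSession shape and the repository internals are illustrative assumptions, not the project's actual models):

import uuid
from dataclasses import dataclass


@dataclass
class ChatSession:  # illustrative stand-in for the real session model
    session_uuid: str
    account_id: int
    name: str
    dataset_name: str | None = None


class SessionRepoSketch:
    """In-memory stand-in for session_repo; the real repository is database-backed."""

    def __init__(self) -> None:
        self._sessions: dict[str, ChatSession] = {}

    def read_create_sessions_by_uuid(self, session_uuid: str | None, account_id: int, name: str) -> ChatSession:
        # Get-or-create: reuse the session if the client sent a known UUID, otherwise make a new one.
        if session_uuid and session_uuid in self._sessions:
            return self._sessions[session_uuid]
        session = ChatSession(session_uuid=session_uuid or str(uuid.uuid4()), account_id=account_id, name=name)
        self._sessions[session.session_uuid] = session
        return session

    def append_ds_name_to_session(self, session_uuid: str, account_id: int, ds_name: str) -> None:
        # Record the loaded collection name so the chat route can later pick the RAG path.
        self._sessions[session_uuid].dataset_name = ds_name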
1 change: 1 addition & 0 deletions backend/src/models/schemas/dataset.py
@@ -51,6 +51,7 @@ class RagDatasetResponse(BaseSchemaModel):
 class LoadRAGDSResponse(BaseSchemaModel):
     dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
     status: bool = Field(default=False, title="Status", description="Status")
+    session_uuid: str = Field(..., title="Session UUID", description="Session UUID")
     # created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")
     # updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time")
     # ratio: Optional[float] = Field(..., title="Ratio", description="Ratio")
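With the added field, a load_dataset response now carries the session identifier back to the client along with the dataset name and status. A small pydantic sketch of the extended model and the payload it produces (plain BaseModel stands in for BaseSchemaModel, the example values are made up, and pydantic v1 would use .json() instead of model_dump_json()):

from pydantic import BaseModel, Field


class LoadRAGDSResponse(BaseModel):  # stand-in for BaseSchemaModel
    dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name")
    status: bool = Field(default=False, title="Status", description="Status")
    session_uuid: str = Field(..., title="Session UUID", description="Session UUID")


resp = LoadRAGDSResponse(
    dataset_name="example_dataset",  # illustrative value
    status=True,
    session_uuid="11111111-2222-3333-4444-555555555555",  # illustrative UUID
)
print(resp.model_dump_json())  # JSON now includes session_uuid alongside dataset_name and status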
3 changes: 1 addition & 2 deletions backend/src/repository/rag/chat.py
@@ -107,7 +107,6 @@ async def inference(
 
     async def inference_with_rag(
         self,
-        session_uuid: str,
         input_msg: str,
         collection_name: str,
         temperature: float = 0.2,
@@ -146,7 +145,7 @@ async def get_context_by_question(input_msg: str):
                 context = f"Please answer the question based on answer {context}"
             else:
                 context = InferenceHelper.instruction
-
+            loguru.logger.info(f"Context: {context}")
             return context
 
         current_context = await get_context_by_question(input_msg)
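The second hunk lives inside get_context_by_question: retrieved passages are wrapped into a prompt context, the helper falls back to InferenceHelper.instruction when nothing relevant is found, and the chosen context is now logged with loguru before being returned. A self-contained sketch of that retrieve-or-fallback shape (the search_fn callable, the threshold, and the default instruction are placeholders, not the project's real API):

import loguru

DEFAULT_INSTRUCTION = "You are a helpful assistant."  # placeholder for InferenceHelper.instruction


async def get_context_by_question(input_msg: str, search_fn, score_threshold: float = 0.5) -> str:
    """Build the prompt context: retrieved passages if relevant, a plain instruction otherwise."""
    hits = await search_fn(input_msg)  # assumed async vector search returning (text, score) pairs
    relevant = [text for text, score in hits if score >= score_threshold]
    if relevant:
        context = "Please answer the question based on answer " + " ".join(relevant)
    else:
        context = DEFAULT_INSTRUCTION
    loguru.logger.info(f"Context: {context}")  # mirrors the logging line added in this PR
    return context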