Skip to content

Commit

Permalink
Add optional collection name field for chatting (#283)
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko authored Jul 21, 2024
1 parent 0ca48f7 commit d5f44e8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
3 changes: 2 additions & 1 deletion backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ async def chat(
top_k=chat_in_msg.top_k,
top_p=chat_in_msg.top_p,
n_predict=chat_in_msg.n_predict,
collection_name=chat_in_msg.collection_name,
)
case _: # default is chat robot
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
stream_func: ContentStream = rag_chat_repo.inference(
session_id=session.id,
input_msg=chat_in_msg.message,
temperature=chat_in_msg.temperature,
Expand Down
17 changes: 17 additions & 0 deletions backend/src/models/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import datetime
from typing import Optional
from typing import Literal
Expand Down Expand Up @@ -33,6 +49,7 @@ class ChatInMessage(BaseSchemaModel):
top_k: int = Field(..., title="Top_k", description="Top_k parameter for inference(int)")
top_p: float = Field(..., title="Top_p", description="Top_p parameter for inference(float)")
n_predict: int = Field(..., title="n_predict", description="n_predict parameter for inference(int)")
collection_name: Optional[str] = Field(..., title="Collection Name", description="Collection Name")


class ChatInResponse(BaseSchemaModel):
Expand Down
4 changes: 3 additions & 1 deletion backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from src.repository.inference_eng import InferenceHelper
from src.utilities.httpkit.httpx_kit import httpx_kit
from src.repository.vector_database import vector_db
from src.config.settings.const import DEFAULT_COLLECTION


class RAGChatModelRepository(BaseRAGRepository):
Expand Down Expand Up @@ -112,6 +113,7 @@ async def inference_with_rag(
top_k: int = 40,
top_p: float = 0.9,
n_predict: int = 128,
collection_name: str = DEFAULT_COLLECTION,
) -> AsyncGenerator[Any, None]:
"""
Inference using RAG
Expand All @@ -137,7 +139,7 @@ async def get_context_by_question(input_msg: str):
except Exception as e:
loguru.logger.error(e)
# collection name for testing
context = vector_db.search(list(embedd_input), 1, collection_name="aisuko_squad01")
context = vector_db.search(list(embedd_input), 1, collection_name=collection_name)
return context or InferenceHelper.instruction

current_context = await get_context_by_question(input_msg)
Expand Down

0 comments on commit d5f44e8

Please sign in to comment.