From 2347b0554d0570eaf82aec05a3de7f7412e37104 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 22 Jul 2024 21:32:06 +1000 Subject: [PATCH] Fix the key word type was used as session type (#304) * Fix the key word type was used as session type Signed-off-by: Aisuko * Add TODO and userid to dataset table Signed-off-by: Aisuko --------- Signed-off-by: Aisuko --- backend/src/api/routes/chat.py | 10 +++++---- backend/src/api/routes/rag_datasets.py | 31 ++++++++++++++++++++++---- backend/src/config/events.py | 15 +++++++++++++ backend/src/config/settings/const.py | 5 ----- backend/src/models/db/chat.py | 18 ++++++++++++++- backend/src/models/db/dataset.py | 1 + backend/src/models/schemas/chat.py | 7 ++++-- backend/src/repository/crud/chat.py | 20 +++++++++++++++-- 8 files changed, 89 insertions(+), 18 deletions(-) diff --git a/backend/src/api/routes/chat.py b/backend/src/api/routes/chat.py index 0533532..0ff619e 100644 --- a/backend/src/api/routes/chat.py +++ b/backend/src/api/routes/chat.py @@ -131,7 +131,7 @@ async def chat_uuid( "", name="chat:chatbot", response_model=ChatInResponse, - status_code=fastapi.status.HTTP_201_CREATED, + status_code=fastapi.status.HTTP_200_OK, ) async def chat( chat_in_msg: ChatInMessage, @@ -196,7 +196,7 @@ async def chat( session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id, name=chat_in_msg.message[:20] ) - match session.type: + match session.session_type: case "rag": stream_func: ContentStream = rag_chat_repo.inference_with_rag( session_id=session.id, @@ -247,7 +247,7 @@ async def get_session( res_session = Session( sessionUuid=session.uuid, name=session.name, - type=session.type, + session_type=session.session_type, created_at=session.created_at, ) sessions_list.append(res_session) @@ -392,7 +392,9 @@ async def save_chats( """ current_user = await account_repo.read_account_by_username(username=jwt_payload.username) if ( - await session_repo.verify_session_by_account_id(session_uuid=chat_in_msg.sessionUuid, 
account_id=current_user.id) + await session_repo.verify_session_by_account_id( + session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id + ) is False ): raise http_404_exc_uuid_not_found_request(uuid=chat_in_msg.sessionUuid) diff --git a/backend/src/api/routes/rag_datasets.py b/backend/src/api/routes/rag_datasets.py index e611194..72617b6 100644 --- a/backend/src/api/routes/rag_datasets.py +++ b/backend/src/api/routes/rag_datasets.py @@ -14,11 +14,20 @@ # limitations under the License. import fastapi +from fastapi.security import OAuth2PasswordBearer + +from src.api.dependencies.repository import get_repository from src.models.schemas.dataset import RagDatasetCreate, RagDatasetResponse from src.repository.rag_datasets_eng import DatasetEng +from src.repository.crud.account import AccountCRUDRepository +from src.securities.authorizations.jwt import jwt_required +from src.repository.crud.chat import SessionCRUDRepository + router = fastapi.APIRouter(prefix="/ds", tags=["datasets"]) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/verify") + @router.get( path="/list", @@ -28,7 +37,7 @@ ) async def get_dataset_list() -> list[RagDatasetResponse]: """ - Waiting for implementing + Get all the dataset list by using user's ID from pg """ pass @@ -42,7 +51,7 @@ async def get_dataset_list() -> list[RagDatasetResponse]: ) async def get_dataset_by_name(name: str) -> RagDatasetResponse: """ - Waiting for implementing + Get the dataset by using the dataset name and user's ID from pg """ pass @@ -51,14 +60,22 @@ async def get_dataset_by_name(name: str) -> RagDatasetResponse: path="/load", name="datasets:load-dataset", response_model=RagDatasetResponse, - status_code=fastapi.status.HTTP_201_CREATED, + status_code=fastapi.status.HTTP_200_OK, ) async def load_dataset( rag_ds_create: RagDatasetCreate, + token: str = fastapi.Depends(oauth2_scheme), + session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)), + account_repo: 
AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> RagDatasetResponse: """ + TODO: need to update - Loading the specific dataset into the vector db + Loading the specific dataset into the vector db. However here are some requirements: + * The dataset should be in the format of the RAG dataset. And we define the RAG dataset. + * Anonymous user can't load the dataset. The user should be authenticated. + * The dataset is related to the specific user's specific session. curl -X 'POST' \ 'http://127.0.0.1:8000/api/ds/load' \ @@ -77,6 +94,7 @@ async def load_dataset( } """ + # TODO: we can't get session when loading dataset res: dict = DatasetEng.load_dataset(rag_ds_create.name) if res.get("insert_count") > 0: @@ -84,4 +102,9 @@ async def load_dataset( else: status = "Failed" + # TODO: Save the ds to the db + + # TODO: save dataset name to the session + + # TODO: If we bind the ds to a specific user's session, we should update the ds name in the session and return the session return RagDatasetResponse(name=rag_ds_create.name, status=status) diff --git a/backend/src/config/events.py b/backend/src/config/events.py index 8074e82..849b694 100644 --- a/backend/src/config/events.py +++ b/backend/src/config/events.py @@ -1,3 +1,18 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import typing import fastapi diff --git a/backend/src/config/settings/const.py b/backend/src/config/settings/const.py index 93e4a1c..cd70056 100644 --- a/backend/src/config/settings/const.py +++ b/backend/src/config/settings/const.py @@ -7,11 +7,6 @@ DEFAULT_DIM = 384 # DEFAULT MODELS -DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2" -CROSS_ENDOCDER = "cross-encoder/ms-marco-MiniLM-L-6-v2" -DEFAULT_MODEL = "microsoft/GODEL-v1_1-base-seq2seq" -DEFAUTL_SUMMERIZE_MODEL = "Falconsai/text_summarization" -DEFAULT_MODEL_PATH = "/models/" # CONVERSATION CONVERSATION_INACTIVE_SEC = 300 RAG_NUM = 5 diff --git a/backend/src/models/db/chat.py b/backend/src/models/db/chat.py index f9cf616..3290d76 100644 --- a/backend/src/models/db/chat.py +++ b/backend/src/models/db/chat.py @@ -1,3 +1,18 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import datetime import sqlalchemy import uuid @@ -17,9 +32,10 @@ class Session(Base): # type: ignore ) account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=True) name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=True) - type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column( + session_type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column( sqlalchemy.Enum("rag", "chat", name="session_type"), nullable=False, default="chat" ) + dataset_name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=256), nullable=True) created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column( sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now() ) diff --git a/backend/src/models/db/dataset.py b/backend/src/models/db/dataset.py index ec74a76..a5f2369 100644 --- a/backend/src/models/db/dataset.py +++ b/backend/src/models/db/dataset.py @@ -12,6 +12,7 @@ class DataSet(Base): # type: ignore id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto") name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=False, unique=True) + account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(sqlalchemy.Integer) created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column( sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now() ) diff --git a/backend/src/models/schemas/chat.py b/backend/src/models/schemas/chat.py index 68dacb3..afb21ac 100644 --- a/backend/src/models/schemas/chat.py +++ b/backend/src/models/schemas/chat.py @@ -87,13 +87,16 @@ class SessionUpdate(BaseSchemaModel): sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") name: Optional[str] = Field(default=None, title="Name", description="Name") - type: Optional[Literal["rag", "chat"]] = Field(default=None, title="Type", description="Type") + 
session_type: Optional[Literal["rag", "chat"]] = Field( + default=None, title="Session Type", description="Type of current session" + ) class Session(BaseSchemaModel): sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") name: str | None = Field(..., title="Name", description="Name") - type: str | None = Field(..., title="Type", description="Type") + session_type: str | None = Field(..., title="Session Type", description="Type of current session") + dataset_name: str | None = Field(default=None, title="Dataset Name", description="Dataset Name") created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") diff --git a/backend/src/repository/crud/chat.py b/backend/src/repository/crud/chat.py index 1481018..67571e9 100644 --- a/backend/src/repository/crud/chat.py +++ b/backend/src/repository/crud/chat.py @@ -1,3 +1,18 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import typing from typing import Optional import loguru @@ -44,11 +59,12 @@ async def update_sessions_by_uuid(self, session: SessionUpdate, account_id: int) .where(Session.uuid == session.sessionUuid) .values(updated_at=sqlalchemy_functions.now()) ) # type: ignore + if session.name: update_stmt = update_stmt.values(name=session.name) - if session.type: - update_stmt = update_stmt.values(type=session.type) + if session.session_type: + update_stmt = update_stmt.values(session_type=session.session_type) await self.async_session.execute(statement=update_stmt) await self.async_session.commit() await self.async_session.refresh(instance=update_session)