Skip to content

Commit

Permalink
Fix the key word type was used as session type (#304)
Browse files Browse the repository at this point in the history
* Fix the key word type was used as session type

Signed-off-by: Aisuko <[email protected]>

* Add TODO and userid to dataset table

Signed-off-by: Aisuko <[email protected]>

---------

Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko authored Jul 22, 2024
1 parent a3fe0a5 commit 2347b05
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 18 deletions.
10 changes: 6 additions & 4 deletions backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def chat_uuid(
"",
name="chat:chatbot",
response_model=ChatInResponse,
status_code=fastapi.status.HTTP_201_CREATED,
status_code=fastapi.status.HTTP_200_OK,
)
async def chat(
chat_in_msg: ChatInMessage,
Expand Down Expand Up @@ -196,7 +196,7 @@ async def chat(
session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id, name=chat_in_msg.message[:20]
)

match session.type:
match session.session_type:
case "rag":
stream_func: ContentStream = rag_chat_repo.inference_with_rag(
session_id=session.id,
Expand Down Expand Up @@ -247,7 +247,7 @@ async def get_session(
res_session = Session(
sessionUuid=session.uuid,
name=session.name,
type=session.type,
session_type=session.session_type,
created_at=session.created_at,
)
sessions_list.append(res_session)
Expand Down Expand Up @@ -392,7 +392,9 @@ async def save_chats(
"""
current_user = await account_repo.read_account_by_username(username=jwt_payload.username)
if (
await session_repo.verify_session_by_account_id(session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id)
await session_repo.verify_session_by_account_id(
session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id
)
is False
):
raise http_404_exc_uuid_not_found_request(uuid=chat_in_msg.sessionUuid)
Expand Down
31 changes: 27 additions & 4 deletions backend/src/api/routes/rag_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,20 @@
# limitations under the License.

import fastapi
from fastapi.security import OAuth2PasswordBearer

from src.api.dependencies.repository import get_repository
from src.models.schemas.dataset import RagDatasetCreate, RagDatasetResponse
from src.repository.rag_datasets_eng import DatasetEng
from src.repository.crud.account import AccountCRUDRepository
from src.securities.authorizations.jwt import jwt_required
from src.repository.crud.chat import SessionCRUDRepository


router = fastapi.APIRouter(prefix="/ds", tags=["datasets"])

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/verify")


@router.get(
path="/list",
Expand All @@ -28,7 +37,7 @@
)
async def get_dataset_list() -> list[RagDatasetResponse]:
"""
Waiting for implementing
Get all the dataset list by using user's ID from pg
"""
pass
Expand All @@ -42,7 +51,7 @@ async def get_dataset_list() -> list[RagDatasetResponse]:
)
async def get_dataset_by_name(name: str) -> RagDatasetResponse:
"""
Waiting for implementing
Get the dataset by using the dataset name and user's ID from pg
"""
pass

Expand All @@ -51,14 +60,22 @@ async def get_dataset_by_name(name: str) -> RagDatasetResponse:
path="/load",
name="datasets:load-dataset",
response_model=RagDatasetResponse,
status_code=fastapi.status.HTTP_201_CREATED,
status_code=fastapi.status.HTTP_200_OK,
)
async def load_dataset(
rag_ds_create: RagDatasetCreate,
token: str = fastapi.Depends(oauth2_scheme),
session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
jwt_payload: dict = fastapi.Depends(jwt_required),
) -> RagDatasetResponse:
"""
TODO: need to update
Loading the specific dataset into the vector db
Loading the specific dataset into the vector db. However here are some requirements:
* The dataset should be in the format of the RAG dataset. And we define the RAG dataset.
* Anonymous user can't load the dataset. The user should be authenticated.
* The dataset related to the specific user's specific session.
curl -X 'POST' \
'http://127.0.0.1:8000/api/ds/load' \
Expand All @@ -77,11 +94,17 @@ async def load_dataset(
}
"""

# TODO: we can't get session when loading dataset
res: dict = DatasetEng.load_dataset(rag_ds_create.name)

if res.get("insert_count") > 0:
status = "Success"
else:
status = "Failed"

# TODO: Save the ds to the db

# TODO: save dataset name to the session

# TODO If we bounding ds to specific user's session, we should upadte ds name to the session and return the session
return RagDatasetResponse(name=rag_ds_create.name, status=status)
15 changes: 15 additions & 0 deletions backend/src/config/events.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

import fastapi
Expand Down
5 changes: 0 additions & 5 deletions backend/src/config/settings/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
DEFAULT_DIM = 384

# DEFAULT MODELS
DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENDOCDER = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DEFAULT_MODEL = "microsoft/GODEL-v1_1-base-seq2seq"
DEFAUTL_SUMMERIZE_MODEL = "Falconsai/text_summarization"
DEFAULT_MODEL_PATH = "/models/"
# CONVERSATION
CONVERSATION_INACTIVE_SEC = 300
RAG_NUM = 5
Expand Down
18 changes: 17 additions & 1 deletion backend/src/models/db/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import sqlalchemy
import uuid
Expand All @@ -17,9 +32,10 @@ class Session(Base): # type: ignore
)
account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=True)
name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=True)
type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(
session_type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(
sqlalchemy.Enum("rag", "chat", name="session_type"), nullable=False, default="chat"
)
dataset_name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=256), nullable=True)
created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column(
sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now()
)
Expand Down
1 change: 1 addition & 0 deletions backend/src/models/db/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DataSet(Base): # type: ignore

id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto")
name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=False, unique=True)
account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(sqlalchemy.Integer)
created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column(
sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now()
)
Expand Down
7 changes: 5 additions & 2 deletions backend/src/models/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ class SessionUpdate(BaseSchemaModel):

sessionUuid: str = Field(..., title="Session UUID", description="Session UUID")
name: Optional[str] = Field(default=None, title="Name", description="Name")
type: Optional[Literal["rag", "chat"]] = Field(default=None, title="Type", description="Type")
session_type: Optional[Literal["rag", "chat"]] = Field(
default=None, title="Session Type", description="Type of current session"
)


class Session(BaseSchemaModel):
sessionUuid: str = Field(..., title="Session UUID", description="Session UUID")
name: str | None = Field(..., title="Name", description="Name")
type: str | None = Field(..., title="Type", description="Type")
session_type: str | None = Field(..., title="Session Type", description="Type of current session")
dataset_name: str | None = Field(default=None, title="Dataset Name", description="Dataset Name")
created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time")


Expand Down
20 changes: 18 additions & 2 deletions backend/src/repository/crud/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# coding=utf-8

# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
from typing import Optional
import loguru
Expand Down Expand Up @@ -44,11 +59,12 @@ async def update_sessions_by_uuid(self, session: SessionUpdate, account_id: int)
.where(Session.uuid == session.sessionUuid)
.values(updated_at=sqlalchemy_functions.now())
) # type: ignore

if session.name:
update_stmt = update_stmt.values(name=session.name)

if session.type:
update_stmt = update_stmt.values(type=session.type)
if session.session_type:
update_stmt = update_stmt.values(session_type=session.session_type)
await self.async_session.execute(statement=update_stmt)
await self.async_session.commit()
await self.async_session.refresh(instance=update_session)
Expand Down

0 comments on commit 2347b05

Please sign in to comment.