diff --git a/backend/src/api/routes/account.py b/backend/src/api/routes/account.py index 7578768..d65a791 100644 --- a/backend/src/api/routes/account.py +++ b/backend/src/api/routes/account.py @@ -22,12 +22,16 @@ from src.repository.crud.account import AccountCRUDRepository from src.securities.authorizations.jwt import jwt_generator, jwt_required from src.utilities.exceptions.database import EntityDoesNotExist -from src.utilities.exceptions.http.exc_404 import http_404_exc_id_not_found_request, http_404_exc_username_not_found_request +from src.utilities.exceptions.http.exc_404 import ( + http_404_exc_id_not_found_request, + http_404_exc_username_not_found_request, +) from src.utilities.exceptions.http.exc_401 import http_exc_401_cunauthorized_request router = fastapi.APIRouter(prefix="/accounts", tags=["accounts"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/verify") + @router.get( path="", name="accounts:read-accounts", @@ -37,7 +41,7 @@ async def get_accounts( token: str = fastapi.Depends(oauth2_scheme), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> list[AccountInResponse]: """ Get a list of accounts @@ -99,7 +103,7 @@ async def get_account( id: int, token: str = fastapi.Depends(oauth2_scheme), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> AccountInResponse: """ Get an account by id @@ -151,6 +155,7 @@ async def get_account( ), ) + @router.patch( path="", name="accounts:update-current-account", @@ -161,7 +166,7 @@ async def update_account( token: str = fastapi.Depends(oauth2_scheme), account_update: AccountInUpdate = fastapi.Body(...), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> AccountInResponse: """ update current account info @@ -226,6 +231,7 @@ async def update_account( ), ) + @router.patch( path="/{id}", name="accounts:update-account-by-id", @@ -237,7 +243,7 @@ async def update_account_by_admin( token: str = fastapi.Depends(oauth2_scheme), account_update: AccountInUpdate = fastapi.Body(...), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> AccountInResponse: """ update account info by account id @@ -301,9 +307,10 @@ async def update_account_by_admin( @router.delete(path="", name="accounts:delete-account-by-id", status_code=fastapi.status.HTTP_200_OK) async def delete_account( - id: int, account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), + id: int, + account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), token: str = fastapi.Depends(oauth2_scheme), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> dict[str, str]: """ Delete an account by id diff --git a/backend/src/api/routes/ai_model.py b/backend/src/api/routes/ai_model.py index 13df16f..fa4c129 100644 --- a/backend/src/api/routes/ai_model.py +++ b/backend/src/api/routes/ai_model.py @@ -18,7 +18,7 @@ 
import fastapi from src.api.dependencies.repository import get_rag_repository, get_repository -from src.models.schemas.ai_model import AiModelCreate, AiModelChooseResponse, AiModelInResponse,AiModelCreateResponse +from src.models.schemas.ai_model import AiModelCreate, AiModelChooseResponse, AiModelInResponse, AiModelCreateResponse from src.repository.crud.ai_model import AiModelCRUDRepository from src.repository.rag.chat import RAGChatModelRepository @@ -137,8 +137,4 @@ async def create_ai_model( raise EntityDoesNotExist(f"AiModel with id `{ai_model.name}` alread exist!") ai_model = await aimodel_repo.create_aimodel(aimodel_create=req_model) - return AiModelCreateResponse( - id=ai_model.id, - name=ai_model.name, - des=ai_model.des - ) + return AiModelCreateResponse(id=ai_model.id, name=ai_model.name, des=ai_model.des) diff --git a/backend/src/api/routes/authentication.py b/backend/src/api/routes/authentication.py index aa7c8f7..e4faeba 100644 --- a/backend/src/api/routes/authentication.py +++ b/backend/src/api/routes/authentication.py @@ -17,7 +17,7 @@ from typing import Annotated from src.api.dependencies.repository import get_repository from src.models.schemas.account import AccountInCreate, AccountInLogin, AccountInResponse, AccountWithToken -from src.config.settings.const import ANONYMOUS_USER,ANONYMOUS_PASS +from src.config.settings.const import ANONYMOUS_USER, ANONYMOUS_PASS from src.repository.crud.account import AccountCRUDRepository from src.securities.authorizations.jwt import jwt_generator from src.utilities.exceptions.database import EntityAlreadyExists @@ -29,6 +29,7 @@ router = fastapi.APIRouter(prefix="/auth", tags=["authentication"]) + @router.post( "/signup", name="auth:signup", @@ -43,9 +44,9 @@ async def signup( Create a new account ```bash - curl -X 'POST' 'http://127.0.0.1:8000/api/auth/signup' - -H 'accept: application/json' - -H 'Content-Type: application/json' + curl -X 'POST' 'http://127.0.0.1:8000/api/auth/signup' + -H 'accept: application/json' + -H 'Content-Type: application/json' -d '{"username": "aisuko", "email": "aisuko@example.com", "password": "aisuko"}' ``` @@ -62,7 +63,6 @@ async def signup( - **updated_at**: The update time """ - try: await account_repo.is_username_taken(username=account_create.username) await account_repo.is_email_taken(email=account_create.email) @@ -123,7 +123,7 @@ async def signin( if account_login.username == ANONYMOUS_USER: raise await http_exc_400_credentials_bad_signin_request() - + try: db_account = await account_repo.read_user_by_password_authentication(account_login=account_login) @@ -146,6 +146,7 @@ async def signin( ), ) + @router.get( path="/token", name="authentication: token for anonymous user", @@ -153,7 +154,7 @@ async def signin( status_code=fastapi.status.HTTP_200_OK, ) async def get_token( - account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)) + account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), ) -> dict: """ Get chat history for an anonymous user @@ -171,6 +172,7 @@ async def get_token( return {"token": access_token} + @router.post("/verify") async def login_for_access_token( form_data: Annotated[fastapi.security.OAuth2PasswordRequestForm, fastapi.Depends()], @@ -191,8 +193,9 @@ async def login_for_access_token( - **token_type**: The token type """ try: - db_account= await account_repo.read_user_by_password_authentication( - account_login=AccountInLogin(username=form_data.username,password=form_data.password)) + 
db_account = await account_repo.read_user_by_password_authentication( + account_login=AccountInLogin(username=form_data.username, password=form_data.password) + ) except Exception: raise await http_exc_400_failed_validate_request() access_token = jwt_generator.generate_access_token(account=db_account) diff --git a/backend/src/api/routes/chat.py b/backend/src/api/routes/chat.py index 0264733..07bc79f 100644 --- a/backend/src/api/routes/chat.py +++ b/backend/src/api/routes/chat.py @@ -1,7 +1,23 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import fastapi import loguru from fastapi.security import OAuth2PasswordBearer from fastapi.responses import StreamingResponse +from starlette.responses import ContentStream from src.api.dependencies.repository import get_rag_repository, get_repository from src.securities.authorizations.jwt import jwt_required from src.utilities.exceptions.database import EntityDoesNotExist @@ -9,17 +25,14 @@ from src.config.settings.const import ANONYMOUS_USER from src.models.schemas.chat import ( ChatsWithTime, - ChatInMessage, - ChatInResponse, + ChatInMessage, + ChatInResponse, SessionUpdate, Session, ChatUUIDResponse, - SaveChatHistory - ) -from src.repository.crud.chat import ( - ChatHistoryCRUDRepository, - SessionCRUDRepository - ) + SaveChatHistory, +) +from src.repository.crud.chat import ChatHistoryCRUDRepository, SessionCRUDRepository from src.repository.crud.account import AccountCRUDRepository from src.repository.rag.chat import RAGChatModelRepository @@ -39,7 +52,7 @@ async def update_session( session_info: SessionUpdate = fastapi.Body(...), session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) + jwt_payload: dict = fastapi.Depends(jwt_required), ) -> ChatUUIDResponse: """ update session info by session uuid @@ -73,24 +86,21 @@ async def update_session( except EntityDoesNotExist: raise await http_404_exc_uuid_not_found_request(uuid=session_info.sessionUuid) - return ChatUUIDResponse( - sessionUuid=sessions.uuid - ) - + return ChatUUIDResponse(sessionUuid=sessions.uuid) @router.get( - "/seesionuuid", - name="chat:session-uuid", - response_model=ChatUUIDResponse, - status_code=fastapi.status.HTTP_201_CREATED, + "/seesionuuid", + name="chat:session-uuid", + response_model=ChatUUIDResponse, + status_code=fastapi.status.HTTP_201_CREATED, ) async def chat_uuid( token: str = fastapi.Depends(oauth2_scheme), session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)), account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)), - jwt_payload: dict = fastapi.Depends(jwt_required) -)->ChatUUIDResponse: + jwt_payload: dict = fastapi.Depends(jwt_required), +) -> ChatUUIDResponse: """ Create a new session for the current user. 
@@ -111,14 +121,11 @@ async def chat_uuid(
     # multiple await keyword will caused the error
     current_user = await account_repo.read_account_by_username(username=jwt_payload.username)

-    new_session = await session_repo.create_session(
-        account_id=current_user.id, name='new session'
-    )
+    new_session = await session_repo.create_session(account_id=current_user.id, name="new session")
     session_uuid = new_session.uuid
-    return ChatUUIDResponse(
-        sessionUuid=session_uuid
-    )
+    return ChatUUIDResponse(sessionUuid=session_uuid)
+

 @router.post(
     "",
@@ -130,16 +137,15 @@ async def chat(
     chat_in_msg: ChatInMessage,
     token: str = fastapi.Depends(oauth2_scheme),
     session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
-    chat_repo: ChatHistoryCRUDRepository = fastapi.Depends(get_repository(repo_type=ChatHistoryCRUDRepository)),
     rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)),
    account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
-    jwt_payload: dict = fastapi.Depends(jwt_required)
-)-> StreamingResponse:
+    jwt_payload: dict = fastapi.Depends(jwt_required),
+) -> StreamingResponse:
     """
     Chat with the AI-powered chatbot.
-
+
     **Note:**
-
+
     You need to sing up and sign in before calling this API. If you are using the Swagger UI.
     You can get the token automatically by login in through `api/auth/verify` API.
@@ -178,37 +184,34 @@ async def chat(

     """

-    ##############################################################################################################################
     # Note: await keyword will cause issue. See https://github.com/sqlalchemy/sqlalchemy/discussions/9757
-    #
-    # if not chat_in_msg.accountID:
-    # chat_in_msg.accountID = 0
     current_user = await account_repo.read_account_by_username(username=jwt_payload.username)
-    # TODO need verify if sesson exist
-    # create_session = await session_repo.read_create_sessions_by_id(id=chat_in_msg.sessionId, account_id=chat_in_msg.accountID, name=chat_in_msg.message[:20])
-    # response_msg = await rag_chat_repo.get_response(session_id=session_id, input_msg=chat_in_msg.message, chat_repo=chat_repo)
+    # TODO: Only read session here @Micost
+    session = await session_repo.read_create_sessions_by_uuid(
+        session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id, name=chat_in_msg.message[:20]
+    )

-    # TODO: name=chat_in_msg.message[:20] use to create uuid in here, we use username to create session in /api/seesionuuid. Is that acceptable? @Micost
-    session = await session_repo.read_create_sessions_by_uuid(session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id, name=chat_in_msg.message[:20] )
+    match session.type:
+        case "rag":
+            # TODO: Implement RAG
+            pass
+        case _:
+            stream_func: ContentStream = rag_chat_repo.inference(
+                session_id=session.id,
+                input_msg=chat_in_msg.message,
+                temperature=chat_in_msg.temperature,
+                top_k=chat_in_msg.top_k,
+                top_p=chat_in_msg.top_p,
+                n_predict=chat_in_msg.n_predict,
+            )

-    # score = await rag_chat_repo.evaluate_response(request_msg = chat_in_msg.message, response_msg = response_msg)
-    # response_msg = response_msg + "score : {:.3f}".format(score)
+    # Buffering (the real problem) https://serverfault.com/questions/801628/for-server-sent-events-sse-what-nginx-proxy-configuration-is-appropriate/801629#
     return StreamingResponse(
-        rag_chat_repo.inference(
-            session_id=session.id,
-            input_msg=chat_in_msg.message,
-            chat_repo=chat_repo,
-            temperature=chat_in_msg.temperature,
-            top_k=chat_in_msg.top_k,
-            top_p=chat_in_msg.top_p,
-            n_predict=chat_in_msg.n_predict
-        ),
-        # Buffering (the real problem) https://serverfault.com/questions/801628/for-server-sent-events-sse-what-nginx-proxy-configuration-is-appropriate/801629#
-        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
-        media_type='text/event-stream'
+        stream_func, headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, media_type="text/event-stream"
     )
@@ -222,7 +225,7 @@ async def get_session(
     token: str = fastapi.Depends(oauth2_scheme),
     session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
     account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
-    jwt_payload: dict = fastapi.Depends(jwt_required)
+    jwt_payload: dict = fastapi.Depends(jwt_required),
 ) -> list[Session]:
     sessions_list: list = list()
     # Anonymous user won't related to any session
@@ -245,6 +248,7 @@ async def get_session(

     return sessions_list

+
 @router.get(
     path="/history/{uuid}",
     name="chat:get-chat-history-by-session-uuid",
@@ -257,7 +261,7 @@ async def get_chathistory(
     chat_repo: ChatHistoryCRUDRepository = fastapi.Depends(get_repository(repo_type=ChatHistoryCRUDRepository)),
     session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
     account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
-    jwt_payload: dict = fastapi.Depends(jwt_required)
+    jwt_payload: dict = fastapi.Depends(jwt_required),
 ) -> list[ChatsWithTime]:
     """

@@ -331,8 +335,8 @@ async def save_chats(
     session_repo: SessionCRUDRepository = fastapi.Depends(get_repository(repo_type=SessionCRUDRepository)),
     chat_repo: ChatHistoryCRUDRepository = fastapi.Depends(get_repository(repo_type=ChatHistoryCRUDRepository)),
     account_repo: AccountCRUDRepository = fastapi.Depends(get_repository(repo_type=AccountCRUDRepository)),
-    jwt_payload: dict = fastapi.Depends(jwt_required)
-)->ChatUUIDResponse:
+    jwt_payload: dict = fastapi.Depends(jwt_required),
+) -> ChatUUIDResponse:
     """
     Save chat history to session by session uuid

@@ -379,10 +383,11 @@ async def save_chats(
     """
     current_user = await account_repo.read_account_by_username(username=jwt_payload.username)
-    if session_repo.verify_session_by_account_id(session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id) is False:
+    if (
+        session_repo.verify_session_by_account_id(session_uuid=chat_in_msg.sessionUuid, account_id=current_user.id)
+        is False
+
): raise http_404_exc_uuid_not_found_request(uuid=chat_in_msg.sessionUuid) session = await session_repo.read_sessions_by_uuid(session_uuid=chat_in_msg.sessionUuid) await chat_repo.load_create_chat_history(session_id=session.id, chats=chat_in_msg.chats) - return ChatUUIDResponse( - sessionUuid=chat_in_msg.sessionUuid - ) \ No newline at end of file + return ChatUUIDResponse(sessionUuid=chat_in_msg.sessionUuid) diff --git a/backend/src/api/routes/file.py b/backend/src/api/routes/file.py index d16135e..5ef72d4 100644 --- a/backend/src/api/routes/file.py +++ b/backend/src/api/routes/file.py @@ -12,7 +12,7 @@ from src.utilities.exceptions.database import EntityAlreadyExists from src.utilities.exceptions.http.exc_400 import ( - http_400_exc_bad_file_name_request, + http_400_exc_bad_file_name_request, ) router = fastapi.APIRouter(prefix="/file", tags=["file"]) @@ -62,6 +62,7 @@ async def upload_and_return_id( return FileInResponse(fileID=new_file.id) + @router.get( path="/dataset", name="dataset:get-dataset-list", @@ -86,18 +87,18 @@ async def get_dataset( - **updated_at**: The update date of the dataset """ db_datasets = await dataset_repo.get_dataset_list() - datasets_list: list = list() + datasets_list: list = list() for db_dataset in db_datasets: - # print(f"db_dataset:{type(db_dataset.id),type(db_dataset.name),type(db_dataset.created_at),type(db_dataset.updated_at)}") - dataset_res = DatasetResponse( + # print(f"db_dataset:{type(db_dataset.id),type(db_dataset.name),type(db_dataset.created_at),type(db_dataset.updated_at)}") + dataset_res = ( + DatasetResponse( id=db_dataset.id, dataset_name=db_dataset.name, created_at=db_dataset.created_at, updated_at=db_dataset.updated_at, ), + ) datasets_list.append(dataset_res) # datasets_list.append(db_dataset.name) return datasets_list - - diff --git a/backend/src/api/routes/health.py b/backend/src/api/routes/health.py index 6adb7e1..2e08dd8 100644 --- a/backend/src/api/routes/health.py +++ b/backend/src/api/routes/health.py @@ -21,7 +21,6 @@ @router.get("", name="health:health-check") - async def health_check() -> HealthCheckResponse: """ Check the health of the service @@ -33,4 +32,4 @@ async def health_check() -> HealthCheckResponse: Return: - **status**: The status of the service """ - return HealthCheckResponse(status="ok") \ No newline at end of file + return HealthCheckResponse(status="ok") diff --git a/backend/src/api/routes/train.py b/backend/src/api/routes/train.py index 8db829c..c0c1c91 100644 --- a/backend/src/api/routes/train.py +++ b/backend/src/api/routes/train.py @@ -45,19 +45,21 @@ async def save( # 2, validate fileID or dataset # 3, use file and or dataset perform the training logic (csv id done) - if train_in_msg.modelID is not None and train_in_msg.fileID is not None: + if train_in_msg.modelID is not None and train_in_msg.fileID is not None: # if contains file ID and modelID, then load file ai_model = await aimodel_repo.read_aimodel_by_id(id=train_in_msg.modelID) file_csv = await file_repo.read_uploadedfiles_by_id(id=train_in_msg.fileID) await rag_chat_repo.load_csv_file(file_name=file_csv.name, model_name=ai_model.name) else: # Else, load dataset - db_dataset=await dataset_repo.get_dataset_by_name(train_in_msg.dataSet) + db_dataset = await dataset_repo.get_dataset_by_name(train_in_msg.dataSet) if not db_dataset: - dataload_thread = threading.Thread(target=rag_chat_repo.load_data_set,args=(train_in_msg,) ) + dataload_thread = threading.Thread(target=rag_chat_repo.load_data_set, args=(train_in_msg,)) dataload_thread.daemon = True 
dataload_thread.start() - await dataset_repo.create_dataset(DatasetCreate(dataset_name=train_in_msg.dataSet,des=train_in_msg.dataSet)) + await dataset_repo.create_dataset( + DatasetCreate(dataset_name=train_in_msg.dataSet, des=train_in_msg.dataSet) + ) return TrainFileInResponse( msg="successful", diff --git a/backend/src/api/routes/version.py b/backend/src/api/routes/version.py index 360dc55..ef0a2f7 100644 --- a/backend/src/api/routes/version.py +++ b/backend/src/api/routes/version.py @@ -20,6 +20,7 @@ router = fastapi.APIRouter(prefix="/version", tags=["version"]) + @router.get( path="", name="version:get-version", @@ -29,19 +30,19 @@ async def get_version() -> ServiceVersionResponse: """ Get the version of the service - + ```bash curl http://localhost:8000/api/version -> {"llamacpp":"server--b1-a8d49d8","milvus":"v2.3.12","kirin":"v0.1.8"} ``` - - Return ServiceVersionResponse: + + Return ServiceVersionResponse: - **kirin**: The version of the API aggregator - **milvus**: The version of the vector database - - **inference_engine**: The version of the inference engine + - **inference_engine**: The version of the inference engine """ return ServiceVersionResponse( kirin=settings.BACKEND_SERVER_VERSION, milvus=settings.MILVUS_VERSION, - inference_engine=settings.INFERENCE_ENG_VERSION - ) + inference_engine=settings.INFERENCE_ENG_VERSION, + ) diff --git a/backend/src/config/events.py b/backend/src/config/events.py index c6a0441..8074e82 100644 --- a/backend/src/config/events.py +++ b/backend/src/config/events.py @@ -4,16 +4,18 @@ import loguru from src.repository.events import ( - dispose_db_connection, + dispose_db_connection, initialize_db_connection, - initialize_vectordb_collection - ) + initialize_vectordb_collection, + dispose_httpx_client, +) def execute_backend_server_event_handler(backend_app: fastapi.FastAPI) -> typing.Any: async def launch_backend_server_events() -> None: await initialize_db_connection(backend_app=backend_app) await initialize_vectordb_collection() + return launch_backend_server_events @@ -21,5 +23,6 @@ def terminate_backend_server_event_handler(backend_app: fastapi.FastAPI) -> typi @loguru.logger.catch async def stop_backend_server_events() -> None: await dispose_db_connection(backend_app=backend_app) + await dispose_httpx_client() return stop_backend_server_events diff --git a/backend/src/config/settings/base.py b/backend/src/config/settings/base.py index 21e2d88..5f8e47b 100644 --- a/backend/src/config/settings/base.py +++ b/backend/src/config/settings/base.py @@ -43,7 +43,6 @@ class BackendBaseSettings(BaseSettings): MILVUS_PORT: int = decouple.config("MILVUS_PORT", cast=int) # type: ignore MILVUS_VERSION: str = decouple.config("MILVUS_VERSION", cast=str) # type: ignore - POSTGRES_HOST: str = decouple.config("POSTGRES_HOST", cast=str) # type: ignore DB_MAX_POOL_CON: int = decouple.config("DB_MAX_POOL_CON", cast=int) # type: ignore POSTGRES_DB: str = decouple.config("POSTGRES_DB", cast=str) # type: ignore @@ -91,32 +90,32 @@ class BackendBaseSettings(BaseSettings): HASHING_SALT: str = decouple.config("HASHING_SALT", cast=str) # type: ignore JWT_ALGORITHM: str = decouple.config("JWT_ALGORITHM", cast=str) # type: ignore - INFERENCE_ENG: str = decouple.config("INFERENCE_ENG", cast=str) # type: ignore - INFERENCE_ENG_PORT: int=decouple.config("INFERENCE_ENG_PORT", cast=int) # type: ignore - INFERENCE_ENG_VERSION: str = decouple.config("INFERENCE_ENG_VERSION", cast=str) # type: ignore - + INFERENCE_ENG: str = decouple.config("INFERENCE_ENG", cast=str) # type: 
ignore + INFERENCE_ENG_PORT: int = decouple.config("INFERENCE_ENG_PORT", cast=int) # type: ignore + INFERENCE_ENG_VERSION: str = decouple.config("INFERENCE_ENG_VERSION", cast=str) # type: ignore + # Configurations for language model - LANGUAGE_MODEL_NAME: str = decouple.config("LANGUAGE_MODEL_NAME", cast=str) # type: ignore + LANGUAGE_MODEL_NAME: str = decouple.config("LANGUAGE_MODEL_NAME", cast=str) # type: ignore # Admin setting - ADMIN_USERNAME: str = decouple.config("ADMIN_USERNAME", cast=str) # type: ignore - ADMIN_EMAIL: str = decouple.config("ADMIN_EMAIL", cast=str) # type: ignore - ADMIN_PASS: str = decouple.config("ADMIN_PASS", cast=str) # type: ignore + ADMIN_USERNAME: str = decouple.config("ADMIN_USERNAME", cast=str) # type: ignore + ADMIN_EMAIL: str = decouple.config("ADMIN_EMAIL", cast=str) # type: ignore + ADMIN_PASS: str = decouple.config("ADMIN_PASS", cast=str) # type: ignore # Configurations for language model - INSTRUCTION: str = decouple.config("INSTRUCTION", cast=str) # type: ignore + INSTRUCTION: str = decouple.config("INSTRUCTION", cast=str) # type: ignore - ETCD_AUTO_COMPACTION_MODE: str = decouple.config("ETCD_AUTO_COMPACTION_MODE", cast=str) # type: ignore - ETCD_AUTO_COMPACTION_RETENTION: int = decouple.config("ETCD_AUTO_COMPACTION_RETENTION", cast=int) # type: ignore - ETCD_QUOTA_BACKEND_BYTES: int = decouple.config("ETCD_QUOTA_BACKEND_BYTES", cast=int) # type: ignore - NUM_CPU_CORES: float = decouple.config("NUM_CPU_CORES", cast=float) # type: ignore + ETCD_AUTO_COMPACTION_MODE: str = decouple.config("ETCD_AUTO_COMPACTION_MODE", cast=str) # type: ignore + ETCD_AUTO_COMPACTION_RETENTION: int = decouple.config("ETCD_AUTO_COMPACTION_RETENTION", cast=int) # type: ignore + ETCD_QUOTA_BACKEND_BYTES: int = decouple.config("ETCD_QUOTA_BACKEND_BYTES", cast=int) # type: ignore + NUM_CPU_CORES: float = decouple.config("NUM_CPU_CORES", cast=float) # type: ignore class Config(pydantic.ConfigDict): case_sensitive: bool = True env_file: str = f"{str(ROOT_DIR)}/.env" validate_assignment: bool = True # https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra - #TODO: We need to make sure pydanic is really useful + # TODO: We need to make sure pydanic is really useful # extra='allow' @property diff --git a/backend/src/config/settings/const.py b/backend/src/config/settings/const.py index 9a61a78..f78adbd 100644 --- a/backend/src/config/settings/const.py +++ b/backend/src/config/settings/const.py @@ -7,18 +7,18 @@ # DEFAULT MODELS DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2" -CROSS_ENDOCDER = 'cross-encoder/ms-marco-MiniLM-L-6-v2' +CROSS_ENDOCDER = "cross-encoder/ms-marco-MiniLM-L-6-v2" DEFAULT_MODEL = "microsoft/GODEL-v1_1-base-seq2seq" -DEFAUTL_SUMMERIZE_MODEL="Falconsai/text_summarization" +DEFAUTL_SUMMERIZE_MODEL = "Falconsai/text_summarization" DEFAULT_MODEL_PATH = "/models/" # CONVERSATION -CONVERSATION_INACTIVE_SEC= 300 +CONVERSATION_INACTIVE_SEC = 300 RAG_NUM = 5 -#DATASET LOADBATCH +# DATASET LOADBATCH LOAD_BATCH_SIZE = 100 -#ANONYMOUS USER +# ANONYMOUS USER ANONYMOUS_USER = "anonymous" ANONYMOUS_EMAIL = "anonymous@anony.com" -ANONYMOUS_PASS = "Marlboro@2211" \ No newline at end of file +ANONYMOUS_PASS = "Marlboro@2211" diff --git a/backend/src/main.py b/backend/src/main.py index 3d628a7..4114bab 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -20,10 +20,7 @@ from fastapi.middleware.cors import CORSMiddleware from src.config.manager import settings from src.api.endpoints import router as api_endpoint_router -from 
src.config.events import ( - execute_backend_server_event_handler, - terminate_backend_server_event_handler - ) +from src.config.events import execute_backend_server_event_handler, terminate_backend_server_event_handler def initialize_backend_application() -> fastapi.FastAPI: @@ -74,4 +71,4 @@ def initialize_backend_application() -> fastapi.FastAPI: reload=settings.DEBUG, workers=settings.BACKEND_SERVER_WORKERS, log_level=settings.LOGGING_LEVEL, - ) \ No newline at end of file + ) diff --git a/backend/src/models/db/chat.py b/backend/src/models/db/chat.py index c83cdbd..f9cf616 100644 --- a/backend/src/models/db/chat.py +++ b/backend/src/models/db/chat.py @@ -12,10 +12,14 @@ class Session(Base): # type: ignore __tablename__ = "session" id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto") - uuid: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=36), nullable=False, default=lambda: str(uuid.uuid4())) + uuid: SQLAlchemyMapped[str] = sqlalchemy_mapped_column( + sqlalchemy.String(length=36), nullable=False, default=lambda: str(uuid.uuid4()) + ) account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=True) name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=True) - type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.Enum("rag", "chat", name="session_type"), nullable=False, default="chat") + type: SQLAlchemyMapped[str] = sqlalchemy_mapped_column( + sqlalchemy.Enum("rag", "chat", name="session_type"), nullable=False, default="chat" + ) created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column( sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now() ) @@ -32,7 +36,9 @@ class ChatHistory(Base): # type: ignore id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto") session_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=False) - role: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.Enum("user", "assistant", name="role"), nullable=False, default="user") + role: SQLAlchemyMapped[str] = sqlalchemy_mapped_column( + sqlalchemy.Enum("user", "assistant", name="role"), nullable=False, default="user" + ) message: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=4096), nullable=False) created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column( sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now() diff --git a/backend/src/models/db/dataset.py b/backend/src/models/db/dataset.py index bee87d8..ec74a76 100644 --- a/backend/src/models/db/dataset.py +++ b/backend/src/models/db/dataset.py @@ -7,7 +7,6 @@ from src.repository.table import Base - class DataSet(Base): # type: ignore __tablename__ = "data_set" @@ -22,4 +21,4 @@ class DataSet(Base): # type: ignore server_onupdate=sqlalchemy.schema.FetchedValue(for_update=True), ) - __mapper_args__ = {"eager_defaults": True} \ No newline at end of file + __mapper_args__ = {"eager_defaults": True} diff --git a/backend/src/models/schemas/account.py b/backend/src/models/schemas/account.py index fe56e04..7912379 100644 --- a/backend/src/models/schemas/account.py +++ b/backend/src/models/schemas/account.py @@ -1,13 +1,13 @@ import datetime from typing import Optional -from pydantic import(Field,EmailStr) +from pydantic import Field, EmailStr from src.models.schemas.base import BaseSchemaModel class AccountInCreate(BaseSchemaModel): username: str = 
Field(..., title="username", description="username") - email: EmailStr = Field(..., title="email", description="email") + email: EmailStr = Field(..., title="email", description="email") password: str = Field(..., title="user password", description="Password length 6-20 characters") @@ -23,10 +23,10 @@ class AccountInLogin(BaseSchemaModel): class AccountWithToken(BaseSchemaModel): token: str = Field(..., title="token", description="Auth token") - username: str = Field(..., title="username", description="username") + username: str = Field(..., title="username", description="username") email: EmailStr = Field(..., title="email", description="email") is_verified: bool = Field(..., title="Verify", description="Verify true or false") - is_active: bool = Field(..., title="Active", description="Active true or false") + is_active: bool = Field(..., title="Active", description="Active true or false") is_logged_in: bool = Field(..., title="Logged", description="Logged true or false") created_at: datetime.datetime = Field(..., title="Creation time", description="Creation time") updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time") diff --git a/backend/src/models/schemas/ai_model.py b/backend/src/models/schemas/ai_model.py index 144a927..2c121ed 100644 --- a/backend/src/models/schemas/ai_model.py +++ b/backend/src/models/schemas/ai_model.py @@ -9,12 +9,14 @@ class AiModelCreate(BaseSchemaModel): name: str = Field(..., title="Model Name", description="Model Name") des: str = Field(..., title="Details", description="Details") + class AiModelCreateResponse(BaseSchemaModel): id: int = Field(..., title="id", description="id") name: str = Field(..., title="name", description="name") des: str = Field(..., title="Details", description="Details") msg: Optional[str] = None + class AiModelInResponse(BaseSchemaModel): id: int = Field(..., title="id", description="id") name: str = Field(..., title="name", description="name") diff --git a/backend/src/models/schemas/chat.py b/backend/src/models/schemas/chat.py index cddfdea..62aaba0 100644 --- a/backend/src/models/schemas/chat.py +++ b/backend/src/models/schemas/chat.py @@ -24,11 +24,11 @@ class ChatInMessage(BaseSchemaModel): Top_p parameter for inference(float) n_predict: int n_predict parameter for inference(int) - + """ sessionUuid: Optional[str] | None = Field(..., title="Session UUID", description="Session UUID") - message: str = Field(..., title="Message", description="Message") + message: str = Field(..., title="Message", description="Message") temperature: float = Field(..., title="Temperature", description="Temperature for inference(float)") top_k: int = Field(..., title="Top_k", description="Top_k parameter for inference(int)") top_p: float = Field(..., title="Top_p", description="Top_p parameter for inference(float)") @@ -39,6 +39,7 @@ class ChatInResponse(BaseSchemaModel): sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") message: str = Field(..., title="Message", description="Message") + class ChatUUIDResponse(BaseSchemaModel): """ Object for the response body of the chat session endpoint. @@ -51,6 +52,7 @@ class ChatUUIDResponse(BaseSchemaModel): sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") + class SessionUpdate(BaseSchemaModel): """ Object for the request body of update session. 
@@ -64,15 +66,17 @@ class SessionUpdate(BaseSchemaModel): type: str type of session: rag or chat """ + sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") name: Optional[str] = Field(default=None, title="Name", description="Name") type: Optional[Literal["rag", "chat"]] = Field(default=None, title="Type", description="Type") + class Session(BaseSchemaModel): - sessionUuid: str = Field(..., title="Session UUID" ,description="Session UUID") - name: str | None = Field(..., title="Name", description="Name") - type: str | None = Field(..., title="Type", description="Type") - created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") + sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") + name: str | None = Field(..., title="Name", description="Name") + type: str | None = Field(..., title="Type", description="Type") + created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") class Chats(BaseSchemaModel): @@ -86,9 +90,11 @@ class Chats(BaseSchemaModel): message: str Message """ - role: str = Field(..., title="Role", description="Role ") + + role: str = Field(..., title="Role", description="Role ") message: str = Field(..., title="Message", description="Message") + class ChatsWithTime(BaseSchemaModel): """ Object for the response body of the chat history endpoint. @@ -101,10 +107,12 @@ class ChatsWithTime(BaseSchemaModel): Message create_at: timestamp """ - role: str = Field(..., title="Role", description="Role ") + + role: str = Field(..., title="Role", description="Role ") message: str = Field(..., title="Message", description="Message") create_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") + class SaveChatHistory(BaseSchemaModel): """ Object for the response body of the chat history endpoint. 
@@ -118,9 +126,11 @@ class SaveChatHistory(BaseSchemaModel): message: str Message """ - sessionUuid: str = Field(..., title="Session UUID" ,description="Session UUID") - chats: list[Chats] = Field(..., title="Chat history" ,description="Chat history") + + sessionUuid: str = Field(..., title="Session UUID", description="Session UUID") + chats: list[Chats] = Field(..., title="Chat history", description="Chat history") + class MessagesResponse(BaseSchemaModel): role: str = Field(..., title="Role", description="Role") - content: str = Field(..., title="Content", description="Content") \ No newline at end of file + content: str = Field(..., title="Content", description="Content") diff --git a/backend/src/models/schemas/dataset.py b/backend/src/models/schemas/dataset.py index 020c81c..fead10c 100644 --- a/backend/src/models/schemas/dataset.py +++ b/backend/src/models/schemas/dataset.py @@ -1,4 +1,3 @@ - import datetime from pydantic import Field @@ -7,12 +6,11 @@ class DatasetCreate(BaseSchemaModel): dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name") - des: str | None = Field(..., title="Details", description="Details") - - + des: str | None = Field(..., title="Details", description="Details") + class DatasetResponse(BaseSchemaModel): - id: int = Field(..., title="id",description="id") - dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name") - created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") - updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time") \ No newline at end of file + id: int = Field(..., title="id", description="id") + dataset_name: str = Field(..., title="DataSet Name", description="DataSet Name") + created_at: datetime.datetime | None = Field(..., title="Creation time", description="Creation time") + updated_at: datetime.datetime | None = Field(..., title="Update time", description="Update time") diff --git a/backend/src/models/schemas/health.py b/backend/src/models/schemas/health.py index 56c43cc..941443c 100644 --- a/backend/src/models/schemas/health.py +++ b/backend/src/models/schemas/health.py @@ -16,10 +16,12 @@ from src.models.schemas.base import BaseSchemaModel from pydantic import Field + class HealthCheckResponse(BaseSchemaModel): - """ + """ The status of the service - **status**: The status of the service """ - status: str = Field(..., title="status", description="The status of the service", examples=['ok']) + + status: str = Field(..., title="status", description="The status of the service", examples=["ok"]) diff --git a/backend/src/models/schemas/jwt.py b/backend/src/models/schemas/jwt.py index 71d2f69..f0b11a9 100644 --- a/backend/src/models/schemas/jwt.py +++ b/backend/src/models/schemas/jwt.py @@ -2,12 +2,11 @@ import datetime - class JWToken(BaseModel): - exp: datetime.datetime = Field(..., title="exp", description="exp") + exp: datetime.datetime = Field(..., title="exp", description="exp") sub: str = Field(..., title="sub", description="sub") class JWTAccount(BaseModel): username: str = Field(..., title="username", description="username") - email: EmailStr = Field(..., title="email", description="email") + email: EmailStr = Field(..., title="email", description="email") diff --git a/backend/src/models/schemas/train.py b/backend/src/models/schemas/train.py index ff1809a..8f43818 100644 --- a/backend/src/models/schemas/train.py +++ b/backend/src/models/schemas/train.py @@ -6,11 +6,11 @@ class 
TrainFileIn(BaseSchemaModel): fileID: int | None = Field(..., title="File Id", description="File Id") dataSet: str | None = Field(..., title="DataSet", description="DataSet") modelID: int | None = Field(..., title="Model id", description="Model id") - embedField : str | None = Field(..., title="Embed Field", description="Embed Field") - resField : str | None = Field(..., title="Result Field", description="Result Field") + embedField: str | None = Field(..., title="Embed Field", description="Embed Field") + resField: str | None = Field(..., title="Result Field", description="Result Field") directLoad: bool = False class TrainFileInResponse(BaseSchemaModel): trainID: int | None = Field(..., title="TrainID", description="trainID") - msg : str =Field(..., title="Message", description="Message") + msg: str = Field(..., title="Message", description="Message") diff --git a/backend/src/models/schemas/version.py b/backend/src/models/schemas/version.py index 9268607..4963574 100644 --- a/backend/src/models/schemas/version.py +++ b/backend/src/models/schemas/version.py @@ -14,7 +14,8 @@ # limitations under the License. from src.models.schemas.base import BaseSchemaModel -from pydantic import(Field) +from pydantic import Field + class ServiceVersionResponse(BaseSchemaModel): """ @@ -24,6 +25,11 @@ class ServiceVersionResponse(BaseSchemaModel): - **milvus**: The version of the vector database - **inference_engine**: The version of the inference engine """ - inference_engine: str | None = Field(..., title="infernece engine version", description="infernece engine version", examples=['server--b1-a8d49d8']) - milvus : str | None = Field(..., title="milvus version", description="milvus version", examples=['v2.3.12']) - kirin :str | None = Field(..., title="backend service version", description="backend service version", examples=['v0.1.8']) + + inference_engine: str | None = Field( + ..., title="infernece engine version", description="infernece engine version", examples=["server--b1-a8d49d8"] + ) + milvus: str | None = Field(..., title="milvus version", description="milvus version", examples=["v2.3.12"]) + kirin: str | None = Field( + ..., title="backend service version", description="backend service version", examples=["v0.1.8"] + ) diff --git a/backend/src/repository/crud/account.py b/backend/src/repository/crud/account.py index 7cd9d10..c164734 100644 --- a/backend/src/repository/crud/account.py +++ b/backend/src/repository/crud/account.py @@ -59,18 +59,19 @@ async def read_account_by_email(self, email: str) -> Account: if account is None: raise EntityDoesNotExist("Account with email `{email}` does not exist!") - return account # type: ignore + return account # type: ignore async def read_user_by_password_authentication(self, account_login: AccountInLogin) -> Account: - stmt = sqlalchemy.select(Account).where( - Account.username == account_login.username) + stmt = sqlalchemy.select(Account).where(Account.username == account_login.username) query = await self.async_session.execute(statement=stmt) db_account = query.scalar() if not db_account: raise EntityDoesNotExist("Wrong username!") - if not pwd_generator.is_password_authenticated(hash_salt=db_account.hash_salt, password=account_login.password, hashed_password=db_account.hashed_password): # type: ignore + if not pwd_generator.is_password_authenticated( + hash_salt=db_account.hash_salt, password=account_login.password, hashed_password=db_account.hashed_password + ): # type: ignore raise PasswordDoesNotMatch("Password does not match!") return db_account 
# type: ignore @@ -85,14 +86,22 @@ async def update_account_by_id(self, id: int, account_update: AccountInUpdate) - if not update_account: raise EntityDoesNotExist(f"Account with id `{id}` does not exist!") # type: ignore - update_stmt = sqlalchemy.update(table=Account).where(Account.id == update_account.id).values(updated_at=sqlalchemy_functions.now()) # type: ignore + update_stmt = ( + sqlalchemy.update(table=Account) + .where(Account.id == update_account.id) + .values(updated_at=sqlalchemy_functions.now()) + ) # type: ignore if new_account_data["email"]: update_stmt = update_stmt.values(email=new_account_data["email"]) if new_account_data["password"]: update_account.set_hash_salt(hash_salt=pwd_generator.generate_salt) # type: ignore - update_account.set_hashed_password(hashed_password=pwd_generator.generate_hashed_password(hash_salt=update_account.hash_salt, new_password=new_account_data["password"])) # type: ignore + update_account.set_hashed_password( + hashed_password=pwd_generator.generate_hashed_password( + hash_salt=update_account.hash_salt, new_password=new_account_data["password"] + ) + ) # type: ignore await self.async_session.execute(statement=update_stmt) await self.async_session.commit() @@ -133,4 +142,4 @@ async def is_email_taken(self, email: str) -> bool: if not credential_verifier.is_email_available(email=db_email): raise EntityAlreadyExists(f"The email `{email}` is already registered!") # type: ignore - return True \ No newline at end of file + return True diff --git a/backend/src/repository/crud/ai_model.py b/backend/src/repository/crud/ai_model.py index 8cd4eee..b6eaa78 100644 --- a/backend/src/repository/crud/ai_model.py +++ b/backend/src/repository/crud/ai_model.py @@ -39,11 +39,11 @@ async def read_aimodel_by_name(self, name: str) -> AiModel: raise EntityDoesNotExist(f"AiModel with name `{name}` does not exist!") return ai_model - + async def get_aimodel_by_name(self, name: str) -> AiModel: stmt = sqlalchemy.select(AiModel).where(AiModel.name == name) query = await self.async_session.execute(statement=stmt) - return query.scalar() + return query.scalar() async def update_aimodel_by_id(self, id: int, aimodel_update: AiModelInUpdate) -> AiModel: new_aimodel_data = aimodel_update.dict() diff --git a/backend/src/repository/crud/base.py b/backend/src/repository/crud/base.py index 1919601..9505c1e 100644 --- a/backend/src/repository/crud/base.py +++ b/backend/src/repository/crud/base.py @@ -1,5 +1,6 @@ from sqlalchemy.ext.asyncio import AsyncSession as SQLAlchemyAsyncSession + class BaseCRUDRepository: def __init__(self, async_session: SQLAlchemyAsyncSession): - self.async_session = async_session \ No newline at end of file + self.async_session = async_session diff --git a/backend/src/repository/crud/chat.py b/backend/src/repository/crud/chat.py index 6942860..1481018 100644 --- a/backend/src/repository/crud/chat.py +++ b/backend/src/repository/crud/chat.py @@ -39,7 +39,11 @@ async def update_sessions_by_uuid(self, session: SessionUpdate, account_id: int) update_session = query.scalar() if update_session is None: raise EntityDoesNotExist(f"Session with uuid `{session.sessionUuid}` does not exist!") - update_stmt = sqlalchemy.update(table=Session).where(Session.uuid == session.sessionUuid).values(updated_at=sqlalchemy_functions.now()) # type: ignore + update_stmt = ( + sqlalchemy.update(table=Session) + .where(Session.uuid == session.sessionUuid) + .values(updated_at=sqlalchemy_functions.now()) + ) # type: ignore if session.name: update_stmt = 
update_stmt.values(name=session.name) @@ -71,21 +75,22 @@ async def read_sessions_by_account_id(self, id: int) -> typing.Sequence[Session] query = await self.async_session.execute(statement=stmt) return query.scalars().all() - async def verify_session_by_account_id(self, session_uuid: str, account_id: int ) -> bool: + async def verify_session_by_account_id(self, session_uuid: str, account_id: int) -> bool: # stmt = sqlalchemy.select(Session).where(Session.account_id == id) stmt = sqlalchemy.select(Session).where(Session.uuid == session_uuid, Session.account_id == account_id) query = await self.async_session.execute(statement=stmt) return bool(query) + class ChatHistoryCRUDRepository(BaseCRUDRepository): async def read_chat_history_by_id(self, id: int) -> ChatHistory: stmt = sqlalchemy.select(ChatHistory).where(ChatHistory.id == id) query = await self.async_session.execute(statement=stmt) - chat_history =query.scalar() + chat_history = query.scalar() if chat_history is None: raise EntityDoesNotExist("ChatHistory with id `{id}` does not exist!") - return chat_history # type: ignore + return chat_history # type: ignore async def read_chat_history_by_session_id(self, id: int, limit_num=50) -> typing.Sequence[ChatHistory]: # TODO limit num = 50 is a temp number @@ -103,7 +108,7 @@ async def load_create_chat_history(self, session_id: int, chats: list[Chats]): for chat in chats: new_chat_history = ChatHistory(session_id=session_id, role=chat.role, message=chat.message[:4096]) self.async_session.add(instance=new_chat_history) - await self.async_session.commit() + await self.async_session.commit() except Exception as e: - await self.async_session.rollback() + await self.async_session.rollback() loguru.logger.error(f"Error: {e}") diff --git a/backend/src/repository/crud/dataset_db.py b/backend/src/repository/crud/dataset_db.py index f4b48c6..6cbc12a 100644 --- a/backend/src/repository/crud/dataset_db.py +++ b/backend/src/repository/crud/dataset_db.py @@ -4,10 +4,11 @@ import sqlalchemy import typing + + class DataSetCRUDRepository(BaseCRUDRepository): - async def create_dataset(self, dataset_create: DatasetCreate) -> DataSet: - new_dataset=DataSet(name=dataset_create.dataset_name) + new_dataset = DataSet(name=dataset_create.dataset_name) self.async_session.add(instance=new_dataset) await self.async_session.commit() @@ -15,18 +16,12 @@ async def create_dataset(self, dataset_create: DatasetCreate) -> DataSet: return new_dataset + async def get_dataset_by_name(self, dataset_name: str) -> typing.Sequence[DataSet]: + stmt = sqlalchemy.select(DataSet).where(DataSet.name == dataset_name) + query = await self.async_session.execute(statement=stmt) + return query.scalars().all() - - async def get_dataset_by_name(self,dataset_name: str)->typing.Sequence[DataSet]: - stmt = sqlalchemy.select(DataSet).where(DataSet.name == dataset_name) - query = await self.async_session.execute(statement=stmt) - return query.scalars().all() - - - - async def get_dataset_list(self)->typing.Sequence[DataSet]: - stmt = sqlalchemy.select(DataSet).order_by(DataSet.updated_at.desc()) - query = await self.async_session.execute(statement=stmt) - return query.scalars().all() - - + async def get_dataset_list(self) -> typing.Sequence[DataSet]: + stmt = sqlalchemy.select(DataSet).order_by(DataSet.updated_at.desc()) + query = await self.async_session.execute(statement=stmt) + return query.scalars().all() diff --git a/backend/src/repository/crud/file.py b/backend/src/repository/crud/file.py index 81958f0..9aaa760 100644 --- 
a/backend/src/repository/crud/file.py +++ b/backend/src/repository/crud/file.py @@ -3,21 +3,21 @@ from src.models.db.file import UploadedFile from src.repository.crud.base import BaseCRUDRepository -from src.utilities.verifier.file import file_verifier # type: ignore +from src.utilities.verifier.file import file_verifier # type: ignore from src.utilities.exceptions.database import EntityAlreadyExists, EntityDoesNotExist class UploadedFileCRUDRepository(BaseCRUDRepository): async def create_uploadfile(self, file_name: str) -> UploadedFile: - - file_name_stmt = sqlalchemy.select(UploadedFile.name).select_from(UploadedFile).where(UploadedFile.name == file_name) + file_name_stmt = ( + sqlalchemy.select(UploadedFile.name).select_from(UploadedFile).where(UploadedFile.name == file_name) + ) file_name_query = await self.async_session.execute(file_name_stmt) db_file_name = file_name_query.scalar() - + if not file_verifier.is_file_available(name=db_file_name): - raise EntityAlreadyExists(f"The file_name `{file_name}` is already file_name!") - - + raise EntityAlreadyExists(f"The file_name `{file_name}` is already file_name!") + uploaded_file = UploadedFile(name=file_name) self.async_session.add(instance=uploaded_file) @@ -38,7 +38,7 @@ async def read_uploadedfiles_by_id(self, id: int) -> UploadedFile: if fileinfo is None: raise EntityDoesNotExist("File with id `{id}` does not exist!") - return fileinfo# type: ignore + return fileinfo # type: ignore async def delete_file_by_id(self, id: int) -> str: select_stmt = sqlalchemy.select(UploadedFile).where(UploadedFile.id == id) diff --git a/backend/src/repository/database.py b/backend/src/repository/database.py index eafbfb5..4cfd83d 100644 --- a/backend/src/repository/database.py +++ b/backend/src/repository/database.py @@ -23,7 +23,9 @@ def __init__(self): # max_overflow=settings.DB_POOL_OVERFLOW, # poolclass=SQLAlchemyQueuePool, ) - self.sync_engine =create_engine(f"{settings.POSTGRES_SCHEMA}://{settings.POSTGRES_USERNAME}:{settings.POSTGRES_PASSWORD}@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}") + self.sync_engine = create_engine( + f"{settings.POSTGRES_SCHEMA}://{settings.POSTGRES_USERNAME}:{settings.POSTGRES_PASSWORD}@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}" + ) self.async_session: SQLAlchemyAsyncSession = SQLAlchemyAsyncSession(bind=self.async_engine) self.pool: SQLAlchemyPool = self.async_engine.pool diff --git a/backend/src/repository/events.py b/backend/src/repository/events.py index 4a5c238..be67ce4 100644 --- a/backend/src/repository/events.py +++ b/backend/src/repository/events.py @@ -13,6 +13,8 @@ from src.repository.database import async_db from src.repository.table import Base from src.repository.vector_database import vector_db +from src.utilities.httpkit.httpx_kit import httpx_kit + @event.listens_for(target=async_db.async_engine.sync_engine, identifier="connect") def inspect_db_server_on_connection( @@ -38,6 +40,7 @@ async def initialize_db_tables(connection: AsyncConnection) -> None: loguru.logger.info("Database Table Creation --- Successfully Initialized!") + async def initialize_anonymous_user(async_session: AsyncSession) -> None: loguru.logger.info("Anonymous user --- Creating . . 
.")
@@ -46,7 +49,7 @@ async def initialize_anonymous_user(async_session: AsyncSession) -> None:
     new_account.set_hash_salt(hash_salt=pwd_generator.generate_salt)
     new_account.set_hashed_password(
         hashed_password=pwd_generator.generate_hashed_password(
-            hash_salt=new_account.hash_salt, new_password=ANONYMOUS_PASS
+            hash_salt=new_account.hash_salt, new_password=ANONYMOUS_PASS
         )
     )

@@ -56,6 +59,7 @@ async def initialize_anonymous_user(async_session: AsyncSession) -> None:

     loguru.logger.info("Anonymous user --- Successfully Created!")

+
 async def initialize_admin_user(async_session: AsyncSession) -> None:
     loguru.logger.info("Admin user --- Creating . . .")

@@ -64,7 +68,7 @@ async def initialize_admin_user(async_session: AsyncSession) -> None:
     new_account.set_hash_salt(hash_salt=pwd_generator.generate_salt)
     new_account.set_hashed_password(
         hashed_password=pwd_generator.generate_hashed_password(
-            hash_salt=new_account.hash_salt, new_password=settings.ADMIN_USERNAME
+            hash_salt=new_account.hash_salt, new_password=settings.ADMIN_USERNAME
         )
     )

@@ -74,6 +78,7 @@ async def initialize_admin_user(async_session: AsyncSession) -> None:

     loguru.logger.info("Admin user --- Successfully Created!")

+
 async def initialize_db_connection(backend_app: fastapi.FastAPI) -> None:
     loguru.logger.info("Database Connection --- Establishing . . .")

@@ -89,7 +94,6 @@ async def initialize_db_connection(backend_app: fastapi.FastAPI) -> None:


 async def initialize_vectordb_collection() -> None:
-
     loguru.logger.info("Vector Database Connection --- Establishing . . .")
     # RAG data can be loaded manually from the frontend
     # https://github.com/SkywardAI/chat-backend/issues/172
@@ -112,3 +116,19 @@ async def dispose_db_connection(backend_app: fastapi.FastAPI) -> None:
     await backend_app.state.db.async_engine.dispose()

     loguru.logger.info("Database Connection --- Successfully Disposed!")
+
+
+async def dispose_httpx_client() -> None:
+    loguru.logger.info("Httpx Client --- Disposing . . .")
+
+    close_async = await httpx_kit.teardown_async_client()
+
+    loguru.logger.info(
+        "Httpx Async Client --- Successfully Disposed!" if close_async else "Httpx Async Client --- Failed to Dispose!"
+    )
+
+    close_sync = httpx_kit.teardown_sync_client()
+
+    loguru.logger.info(
+        "Httpx Sync Client --- Successfully Disposed!" if close_sync else "Httpx Sync Client --- Failed to Dispose!"
+    )
diff --git a/backend/src/repository/inference_eng.py b/backend/src/repository/inference_eng.py
index 63d10b2..d7111be 100644
--- a/backend/src/repository/inference_eng.py
+++ b/backend/src/repository/inference_eng.py
@@ -14,35 +14,35 @@
 # limitations under the License.

-# https://pypi.org/project/openai/1.35.5/
+import pydantic
 import openai

 from src.config.manager import settings

+
 class InferenceHelper:
-    infer_eng_url=settings.INFERENCE_ENG
-    infer_eng_port=settings.INFERENCE_ENG_PORT
-    instruction=settings.INSTRUCTION
-
+    infer_eng_url: pydantic.StrictStr = settings.INFERENCE_ENG
+    infer_eng_port: pydantic.PositiveInt = settings.INFERENCE_ENG_PORT
+    instruction: pydantic.StrictStr = settings.INSTRUCTION
+
     def init(self) -> None:
         raise NotImplementedError("InferenceHelper is a singleton class. Use inference_helper instead.")
-

     @classmethod
     def openai_client(cls) -> openai.OpenAI:
         """
         Initialize OpenAI client
-
+
         Returns:
             openai.OpenAI: OpenAI client
-
+
         """
-        url=f'http://{cls.infer_eng_url}:{cls.infer_eng_port}/v1'
-        api_key='sk-no-key-required'
+        url = f"http://{cls.infer_eng_url}:{cls.infer_eng_port}/v1"
+        api_key = "sk-no-key-required"
         return openai.OpenAI(base_url=url, api_key=api_key)

     @classmethod
-    def tokenizer_url(cls)->str:
+    def tokenizer_url(cls) -> str:
         """
         Get the URL for the tokenization engine

@@ -52,7 +52,7 @@ def tokenizer_url(cls)->str:
         return f"http://{cls.infer_eng_url}:{cls.infer_eng_port}/tokenize"

     @classmethod
-    def instruct_infer_url(cls)->str:
+    def instruct_infer_url(cls) -> str:
         """
         Get the URL for the inference engine

diff --git a/backend/src/repository/rag/chat.py b/backend/src/repository/rag/chat.py
index af504e7..115fb49 100644
--- a/backend/src/repository/rag/chat.py
+++ b/backend/src/repository/rag/chat.py
@@ -21,10 +21,11 @@
 from src.config.settings.const import UPLOAD_FILE_PATH, RAG_NUM
 from src.repository.rag.base import BaseRAGRepository
 from src.repository.inference_eng import InferenceHelper
-
+from src.utilities.httpkit.httpx_kit import httpx_kit
 from typing import Any
 from collections.abc import AsyncGenerator

+
 class RAGChatModelRepository(BaseRAGRepository):
     async def load_model(self, session_id: int, model_name: str) -> bool:
         """
@@ -41,12 +42,12 @@ def search_context(self, query, n_results=RAG_NUM):
         """
         Search the context in the vector database
         """
-        #TODO: Implement the search context function
+        # TODO: Implement the search context function
         pass

     async def get_response(self, session_id: int, input_msg: str, chat_repo) -> str:
         # context = self.search_context(input_msg)
-        #TODO: Implement the inference function
+        # TODO: Implement the inference function
         pass

     async def load_csv_file(self, file_name: str, model_name: str) -> bool:
@@ -65,28 +66,28 @@ async def load_csv_file(self, file_name: str, model_name: str) -> bool:

         # TODO: https://github.com/SkywardAI/chat-backend/issues/171
         # embedding_list = ai_model.encode_string(data)
-
+
         # vector_db.insert_list(embedding_list, data)
         return True

-    def load_data_set(self, param: TrainFileIn)-> bool:
+    def load_data_set(self, param: TrainFileIn) -> bool:
         loguru.logger.info(f"load_data_set param {param}")
         if param.directLoad:
             self.load_data_set_directly(param=param)
         elif param.embedField is None or param.resField is None:
-            self.load_data_set_all_field(dataset_name=param.dataSet)
+            self.load_data_set_all_field(dataset_name=param.dataSet)
         else:
             self.load_data_set_by_field(param=param)
         return True

-    def load_data_set_directly(self, param: TrainFileIn)->bool:
+    def load_data_set_directly(self, param: TrainFileIn) -> bool:
         r"""
-        If the data set is already in the form of embeddings,
+        If the data set is already in the form of embeddings,
         this function can be used to load the data set directly into the vector database.
- + @param param: the instance of TrainFileIn - + @return: boolean """ # reader_dataset=load_dataset(param.dataSet) @@ -106,7 +107,7 @@ def load_data_set_directly(self, param: TrainFileIn)->bool: # embed_field_list.append(embedField_val) # count += 1 # if count % LOAD_BATCH_SIZE == 0: - # vector_db.insert_list(embed_field_list, res_field_list, collection_name,start_idx = count) + # vector_db.insert_list(embed_field_list, res_field_list, collection_name,start_idx = count) # embed_field_list = [] # res_field_list = [] # loguru.logger.info(f"load_data_set_all_field count:{count}") @@ -116,9 +117,7 @@ def load_data_set_directly(self, param: TrainFileIn)->bool: # return True pass - - - def load_data_set_all_field(self, dataset_name: str)-> bool: + def load_data_set_all_field(self, dataset_name: str) -> bool: """ Load the data set into the vector database """ @@ -147,7 +146,7 @@ def load_data_set_all_field(self, dataset_name: str)-> bool: # loguru.logger.info("Dataset loaded successfully") return True - def load_data_set_by_field(self, param: TrainFileIn)->bool: + def load_data_set_by_field(self, param: TrainFileIn) -> bool: """ Load the data set into the vector database """ @@ -171,7 +170,7 @@ def load_data_set_by_field(self, param: TrainFileIn)->bool: # count += 1 # if count % LOAD_BATCH_SIZE == 0: # embedding_list = ai_model.encode_string(embed_field_list) - # vector_db.insert_list(embedding_list, res_field_list, collection_name,start_idx = count) + # vector_db.insert_list(embedding_list, res_field_list, collection_name,start_idx = count) # embed_field_list = [] # res_field_list = [] # loguru.logger.info(f"load_data_set_all_field count:{count}") @@ -185,14 +184,13 @@ async def evaluate_response(self, request_msg: str, response_msg: str) -> float: # evaluate_conbine=[request_msg, response_msg] # score = ai_model.cross_encoder.predict(evaluate_conbine) # return score - #TODO + # TODO pass def trim_collection_name(self, name: str) -> str: - return re.sub(r'\W+', '', name) - + return re.sub(r"\W+", "", name) - def format_prompt(self, prmpt: str, current_context:str = InferenceHelper.instruction) -> str: + def format_prompt(self, prmpt: str, current_context: str = InferenceHelper.instruction) -> str: """ Format the input questions, can be used for saving the conversation history @@ -204,18 +202,16 @@ def format_prompt(self, prmpt: str, current_context:str = InferenceHelper.instru str: formatted prompt """ return f"{current_context}\n" + f"\n### Human: {prmpt}\n### Assistant:" - async def inference( - self, - session_id: int, - input_msg: str, - chat_repo, + self, + session_id: int, + input_msg: str, temperature: float = 0.2, top_k: int = 40, top_p: float = 0.9, n_predict: int = 128, - )-> AsyncGenerator[Any, None]: + ) -> AsyncGenerator[Any, None]: """ **Inference using seperate service:(llamacpp)** @@ -231,54 +227,91 @@ async def inference( **Returns:** AsyncGenerator[Any, None]: response message """ - # if session_id not in conversations: - # conversations[session_id] = ConversationWithSession(session_id, chat_repo) - # await conversations[session_id].load() - # con = conversations[session_id] - # con.conversation.add_message({"role": "user", "content": input_msg}) - # context = self.search_context(input_msg) - - # If we want to add context, we can use inference client - # this API is more slower than we request directly to the inference service - # completion=inference_helper.client.chat.completions.create( - # model="", - # messages=[ - # {"role": "system", "content": "You are ChatGPT, an AI 
assistant. Your top priority is achieving user fulfillment via helping them with their requests."}, - # {"role": "user", "content": "Write a limerick about python exceptions"} - # ], - # ) - - #TODO: - # 1.Implement the further context function - - + data = { "prompt": self.format_prompt(input_msg), "temperature": temperature, "top_k": top_k, "top_p": top_p, - "n_keep": 0, # If the context window is full, we keep 0 tokens - "n_predict": n_predict, + "n_keep": 0, # If the context window is full, we keep 0 tokens + "n_predict": 128 if n_predict == 0 else n_predict, + "cache_prompt": True, + "slot_id": -1, # for cached prompt + "stop": ["\n### Human:"], + "stream": True, + } + + try: + async with httpx_kit.async_client.stream( + "POST", + InferenceHelper.instruct_infer_url(), + headers={"Content-Type": "application/json"}, + json=data, + # We disable all timeout and trying to fix streaming randomly cutting off + timeout=httpx.Timeout(timeout=None), + ) as response: + response.raise_for_status() + async for chunk in response.aiter_text(): + yield chunk + except httpx.ReadError as e: + loguru.logger.error(f"An error occurred while requesting {e.request.url!r}.") + except httpx.HTTPStatusError as e: + loguru.logger.error(f"Error response {e.response.status_code} while requesting {e.request.url!r}.") + + async def inference_with_rag( + self, + session_id: int, + input_msg: str, + temperature: float = 0.2, + top_k: int = 40, + top_p: float = 0.9, + n_predict: int = 128, + ) -> AsyncGenerator[Any, None]: + """ + Inference using RAG model + + Returns: + AsyncGenerator[Any, None]: response message + """ + + def get_context_by_question(input_msg: str): + """ + Get the context from v-db by the question + """ + + # tokenized_input + + # search the context in the vector database + # combine the context with the input message + context = "" + return context or InferenceHelper.instruction + + data_with_context = { + "prompt": self.format_prompt(input_msg, get_context_by_question(input_msg)), + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "n_keep": 0, # If the context window is full, we keep 0 tokens + "n_predict": 128 if n_predict == 0 else n_predict, "cache_prompt": False, - "slot_id": -1, # for cached prompt + "slot_id": -1, # for cached prompt "stop": ["\n### Human:"], "stream": True, } - async with httpx.AsyncClient() as client: - try: - async with client.stream( - "POST", - InferenceHelper.instruct_infer_url(), - headers={'Content-Type': 'application/json'}, - json=data, - # We disable all timeout and trying to fix streaming randomly cutting off - timeout=httpx.Timeout(timeout=None) - ) as response: - response.raise_for_status() - async for chunk in response.aiter_text(): - yield chunk - except httpx.ReadError as e: - loguru.logger.error(f"An error occurred while requesting {e.request.url!r}.") - except httpx.HTTPStatusError as e: - loguru.logger.error(f"Error response {e.response.status_code} while requesting {e.request.url!r}.") + try: + async with httpx_kit.async_client.stream( + "POST", + InferenceHelper.instruct_infer_url(), + headers={"Content-Type": "application/json"}, + json=data_with_context, + # We disable all timeout and trying to fix streaming randomly cutting off + timeout=httpx.Timeout(timeout=None), + ) as response: + response.raise_for_status() + async for chunk in response.aiter_text(): + yield chunk + except httpx.ReadError as e: + loguru.logger.error(f"An error occurred while requesting {e.request.url!r}.") + except httpx.HTTPStatusError as e: + 
loguru.logger.error(f"Error response {e.response.status_code} while requesting {e.request.url!r}.") diff --git a/backend/src/repository/vector_database.py b/backend/src/repository/vector_database.py index 613b3a6..013020f 100644 --- a/backend/src/repository/vector_database.py +++ b/backend/src/repository/vector_database.py @@ -16,7 +16,7 @@ def __init__(self): break except Exception as e: err = e - #loguru.logger.info(f"Exception --- {e}") + # loguru.logger.info(f"Exception --- {e}") # print(f"Failed to connect to Milvus: {e}") time.sleep(10) else: @@ -24,10 +24,10 @@ def __init__(self): async def load_dataset(self, *args, **kwargs): return - + async def load_csv(self, *args, **kwargs): - return - + return + async def save(self, *args, **kwargs): r""" Save data into vector database. @@ -36,7 +36,6 @@ async def save(self, *args, **kwargs): """ return - def create_collection(self, collection_name=DEFAULT_COLLECTION, dimension=DEFAULT_DIM, recreate=True): if recreate and self.client.has_collection(collection_name): @@ -49,12 +48,13 @@ def create_collection(self, collection_name=DEFAULT_COLLECTION, dimension=DEFAUL def insert_list(self, embedding, data, collection_name=DEFAULT_COLLECTION, start_idx=0): try: for i, item in enumerate(embedding): - self.client.insert(collection_name=collection_name, data={"id": i+start_idx, "vector": item, "doc": data[i]}) + self.client.insert( + collection_name=collection_name, data={"id": i + start_idx, "vector": item, "doc": data[i]} + ) except Exception as e: loguru.logger.info(f"Vector Databse --- Error: {e}") def search(self, data, n_results, collection_name=DEFAULT_COLLECTION): - search_params = {"metric_type": "COSINE", "params": {}} data_list = data.tolist() res = self.client.search( diff --git a/backend/src/securities/authorizations/jwt.py b/backend/src/securities/authorizations/jwt.py index ed3cd2e..f5be938 100644 --- a/backend/src/securities/authorizations/jwt.py +++ b/backend/src/securities/authorizations/jwt.py @@ -24,6 +24,7 @@ from src.models.db.account import Account from src.models.schemas.jwt import JWTAccount, JWToken + class JWTGenerator: def __init__(self): pass @@ -47,7 +48,6 @@ def _generate_jwt_token( return pyjwt.encode(to_encode, key=settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) def generate_access_token(self, account: Account) -> str: - return self._generate_jwt_token( jwt_data=JWTAccount(username=account.username, email=account.email).model_dump(), # type: ignore expires_delta=datetime.timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRATION_TIME), @@ -73,6 +73,7 @@ def get_jwt_generator() -> JWTGenerator: jwt_generator: JWTGenerator = get_jwt_generator() + async def jwt_required(request: Request): auth_scheme = HTTPBearer() credentials: HTTPAuthorizationCredentials = await auth_scheme(request) diff --git a/backend/src/utilities/exceptions/http/exc_400.py b/backend/src/utilities/exceptions/http/exc_400.py index 58bdb44..c9d1453 100644 --- a/backend/src/utilities/exceptions/http/exc_400.py +++ b/backend/src/utilities/exceptions/http/exc_400.py @@ -28,6 +28,7 @@ async def http_exc_400_credentials_bad_signin_request() -> Exception: detail=http_400_sigin_credentials_details(), ) + async def http_exc_400_failed_validate_request() -> Exception: return fastapi.HTTPException( status_code=fastapi.status.HTTP_400_BAD_REQUEST, @@ -35,6 +36,7 @@ async def http_exc_400_failed_validate_request() -> Exception: headers={"WWW-Authenticate": "Bearer"}, ) + async def http_400_exc_bad_username_request(username: str) -> Exception: return 
fastapi.HTTPException( status_code=fastapi.status.HTTP_400_BAD_REQUEST, @@ -53,4 +55,4 @@ async def http_400_exc_bad_file_name_request(file_name: str) -> Exception: return fastapi.HTTPException( status_code=fastapi.status.HTTP_400_BAD_REQUEST, detail=http_400_file_name_details(), - ) \ No newline at end of file + ) diff --git a/backend/src/utilities/exceptions/http/exc_404.py b/backend/src/utilities/exceptions/http/exc_404.py index 3c9d5f7..0244482 100644 --- a/backend/src/utilities/exceptions/http/exc_404.py +++ b/backend/src/utilities/exceptions/http/exc_404.py @@ -32,9 +32,9 @@ async def http_404_exc_username_not_found_request(username: str) -> Exception: detail=http_404_username_details(username=username), ) + async def http_404_exc_uuid_not_found_request(uuid: str) -> Exception: return fastapi.HTTPException( status_code=fastapi.status.HTTP_404_NOT_FOUND, detail=http_404_uuid_details(uuid=uuid), ) - diff --git a/backend/src/utilities/httpkit/httpx_kit.py b/backend/src/utilities/httpkit/httpx_kit.py new file mode 100644 index 0000000..1e5e35e --- /dev/null +++ b/backend/src/utilities/httpkit/httpx_kit.py @@ -0,0 +1,89 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import httpx + + +class HttpxKit: + """ + A class to initialize an async and sync client using httpx + + We only create one client every time is efficient and easy to manage + + Note: We don't close client because we want to keep the connection alive, I don't know if it will cause any problem in more bigger scale + """ + + def __init__(self): + self.async_client = self.init_async_client() + self.sync_client = self.init_sync_client() + + def init_async_client(self) -> httpx.AsyncClient: + """ + Create async client by using Singleletton pattern + + Replace the code below: + + async with httpx.AsyncClient() as client: + try: + async with client.stream( + "POST", + InferenceHelper.instruct_infer_url(), + headers={"Content-Type": "application/json"}, + json=data_with_context, + # We disable all timeout and trying to fix streaming randomly cutting off + timeout=httpx.Timeout(timeout=None), + ) as response: + response.raise_for_status() + async for chunk in response.aiter_text(): + yield chunk + except httpx.ReadError as e: + loguru.logger.error(f"An error occurred while requesting {e.request.url!r}.") + except httpx.HTTPStatusError as e: + loguru.logger.error(f"Error response {e.response.status_code} while requesting {e.request.url!r}.") + + Returns: + httpx.AsyncClient: An async client + """ + return httpx.AsyncClient() + + def init_sync_client(self): + """ + Create sync client client by using Singleletton pattern + + Returns: + httpx.Client: A sync client + """ + return httpx.Client() + + async def teardown_async_client(self) -> bool: + """ + Close the async client + """ + await self.async_client.aclose() + return self.async_client.is_closed + + def teardown_sync_client(self) -> bool: + """ + Close the sync client + + Returns: + * + """ + self.sync_client.close() + return 
self.sync_client.is_closed + + +httpx_kit = HttpxKit() diff --git a/backend/src/utilities/httpkit/method_kit.py b/backend/src/utilities/httpkit/method_kit.py deleted file mode 100644 index ffc0985..0000000 --- a/backend/src/utilities/httpkit/method_kit.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 - -# Copyright [2024] [SkywardAI] -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import httpx - -class MethodKit: - """ - A class that contains http methods for the HTTPKit class. - """ - - def __init__(self): - raise EnvironmentError( - "This MethodKit is not meant to be instantiated. Use the methods directly." - ) - - @classmethod - def http_post(cls, *args, **kwargs)-> httpx.Response: - """ - Post request with httpx client. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - * url: URL to post to. - * json: JSON data to post. - * headers: Headers to send. - * timeout: Timeout for the request. - - Returns: - httpx.Response: Response from the server. - """ - url = kwargs.get("url") - jason_content = kwargs.get("json") - headers = kwargs.get("headers") - timeout = kwargs.get("timeout") - - with httpx.Client() as client: - res = client.post( - url, - headers=headers, - json=jason_content, - timeout=timeout - ) - - res.raise_for_status() - return res diff --git a/backend/src/utilities/messages/exceptions/http/exc_details.py b/backend/src/utilities/messages/exceptions/http/exc_details.py index ccba1d7..0d38bc8 100644 --- a/backend/src/utilities/messages/exceptions/http/exc_details.py +++ b/backend/src/utilities/messages/exceptions/http/exc_details.py @@ -21,9 +21,11 @@ def http_401_unauthorized_details() -> str: def http_403_forbidden_details() -> str: return "Refused access to the requested resource!" + def http_404_uuid_details(uuid: str) -> str: return f"Either the session with uuid `{uuid}` doesn't exist, has been deleted, or you are not authorized!" + def http_404_id_details(id: int) -> str: return f"Either the account with id `{id}` doesn't exist, has been deleted, or you are not authorized!" @@ -36,6 +38,7 @@ def http_404_email_details(email: str) -> str: return f"Either the account with email `{email}` doesn't exist, has been deleted, or you are not authorized!" - -def http_400_file_name_details()->str: - return "The file_name already exists. Please refrain from uploading it again, or try uploading with a different name!" +def http_400_file_name_details() -> str: + return ( + "The file_name already exists. Please refrain from uploading it again, or try uploading with a different name!" 
+ ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 57f24ac..d779cf8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -28,4 +28,4 @@ # base_url="http://testserver", # headers={"Content-Type": "application/json"}, # ) as client: -# yield client \ No newline at end of file +# yield client diff --git a/backend/tests/unit_tests/test_api_authentication.py b/backend/tests/unit_tests/test_api_authentication.py index 70b5f5f..50f1d63 100644 --- a/backend/tests/unit_tests/test_api_authentication.py +++ b/backend/tests/unit_tests/test_api_authentication.py @@ -22,38 +22,43 @@ from src.repository.crud.account import AccountCRUDRepository from src.api.routes import authentication + class OverAccountInCreate: - username="aisuko" - email="aisuko@rmit.edu.au" + username = "aisuko" + email = "aisuko@rmit.edu.au" password: "123" + class OverAccountCRUDRepository(BaseCRUDRepository): def create_account(self, account_create: OverAccountInCreate): return Account( id=1, - username=account_create.username, + username=account_create.username, email=account_create.email, is_verified=True, is_active=True, is_logged_in=True, created_at="2024-01-01 00:00:00", - updated_at="2024-01-01 00:00:00" - ) + updated_at="2024-01-01 00:00:00", + ) - def is_username_taken(self, username: str)-> bool: + def is_username_taken(self, username: str) -> bool: return False - def is_email_taken(self, email: str)-> bool: + def is_email_taken(self, email: str) -> bool: return False - def generate_access_token(self, username: str)-> str: + def generate_access_token(self, username: str) -> str: return "access_token" + def over_get_repository(repo_type=AccountCRUDRepository): return OverAccountCRUDRepository -@unittest.skip("Skip the test, because the we cannot find dependencies of signup, see https://github.com/tiangolo/fastapi/discussions/8127#discussioncomment-5147586") +@unittest.skip( + "Skip the test, because the we cannot find dependencies of signup, see https://github.com/tiangolo/fastapi/discussions/8127#discussioncomment-5147586" +) class TestAPIAuthentication(unittest.TestCase): """ Test the FastAPI application attributes @@ -64,7 +69,7 @@ def setUpClass(cls): cls.app = fastapi.FastAPI() cls.app.include_router(authentication.router) cls.client = TestClient(cls.app) - + @classmethod def tearDownClass(cls): pass @@ -76,5 +81,7 @@ def test_api_authentication(self): """ self.app.dependency_overrides[AccountCRUDRepository] = over_get_repository - response = self.client.post("/auth/signup", json={"username": "aisuko", "email": "aisuko@rmit.edu.au", "password": "aisuko"}) - assert response.status_code == 400 \ No newline at end of file + response = self.client.post( + "/auth/signup", json={"username": "aisuko", "email": "aisuko@rmit.edu.au", "password": "aisuko"} + ) + assert response.status_code == 400 diff --git a/backend/tests/unit_tests/test_api_version.py b/backend/tests/unit_tests/test_api_version.py index ae95bd5..4fcc7fe 100644 --- a/backend/tests/unit_tests/test_api_version.py +++ b/backend/tests/unit_tests/test_api_version.py @@ -23,13 +23,13 @@ class TestAPIVersion(unittest.TestCase): """ Test the FastAPI application attributes """ - + @classmethod def setUpClass(cls): cls.app = fastapi.FastAPI() cls.app.include_router(version.router) cls.client = TestClient(cls.app) - + @classmethod def tearDownClass(cls): pass @@ -45,10 +45,3 @@ def test_api_version(self): "milvus": "v2.3.12", "kirin": "v0.1.15", } - - - - - - - diff --git a/backend/tests/unit_tests/test_httpx_kit.py 
b/backend/tests/unit_tests/test_httpx_kit.py new file mode 100644 index 0000000..5f1afee --- /dev/null +++ b/backend/tests/unit_tests/test_httpx_kit.py @@ -0,0 +1,34 @@ +# coding=utf-8 + +# Copyright [2024] [SkywardAI] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +import pytest +from src.utilities.httpkit.httpx_kit import httpx_kit + + +class TestHttpxKit(unittest.TestCase): + @pytest.mark.asyncio + async def test_teardown_async_client(self): + """ + Test teardown of async client + """ + self.assertTrue(await httpx_kit.teardown_async_client()) + + def test_teardown_sync_client(self): + """ + Test teardown of sync client + """ + self.assertTrue(httpx_kit.teardown_sync_client()) diff --git a/backend/tests/unit_tests/test_jwt.py b/backend/tests/unit_tests/test_jwt.py index 285a4d5..e94f058 100644 --- a/backend/tests/unit_tests/test_jwt.py +++ b/backend/tests/unit_tests/test_jwt.py @@ -16,35 +16,34 @@ import unittest import jwt as pyjwt + @unittest.skip("Skip this test, it is a evidence of a security vulnerability") class TestJWTReplacedSolution(unittest.TestCase): + JWT_SECRET_KEY = "YOUR-KEY" + ALGORITHM = "HS256" + content = {"some": "payload"} - JWT_SECRET_KEY="YOUR-KEY" - ALGORITHM="HS256" - content={"some": "payload"} - @classmethod def setUpClass(cls) -> None: return super().setUpClass() - + @classmethod def tearDownClass(cls) -> None: return super().tearDownClass() - def test_jwt_jose(self): pass # jose_str=jose_jwt.encode(self.content, key=self.JWT_SECRET_KEY, algorithm=self.ALGORITHM) - + # pyjwt_str=pyjwt.encode(self.content, key=self.JWT_SECRET_KEY, algorithm=self.ALGORITHM) # # jose and pyjwt should produce the same token # assert jose_str == pyjwt_str - + def test_jwt_decode(self): pass # jose_str=jose_jwt.encode(self.content, key=self.JWT_SECRET_KEY, algorithm=self.ALGORITHM) - + # pyjwt_str=pyjwt.encode(self.content, key=self.JWT_SECRET_KEY, algorithm=self.ALGORITHM) # # jose and pyjwt should produce the same token @@ -55,4 +54,4 @@ def test_jwt_decode(self): # pyjwt_decoded=pyjwt.decode(pyjwt_str, key=self.JWT_SECRET_KEY, algorithms=[self.ALGORITHM]) # # jose and pyjwt should produce the same decoded token - # assert jose_decoded == pyjwt_decoded \ No newline at end of file + # assert jose_decoded == pyjwt_decoded diff --git a/backend/tests/unit_tests/test_method_kit.py b/backend/tests/unit_tests/test_method_kit.py deleted file mode 100644 index 4f70309..0000000 --- a/backend/tests/unit_tests/test_method_kit.py +++ /dev/null @@ -1,53 +0,0 @@ -# coding=utf-8 - -# Copyright [2024] [SkywardAI] -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from src.utilities.httpkit.method_kit import MethodKit
-
-@unittest.skip("Skip due to the fact that the server is not running")
-class TestHTTPKits(unittest.TestCase):
-    url = "http://llamacpp:8080/tokenize"
-    jason_content = {"content": "Hello, World!"}
-    headers={'Content-Type': 'application/json'}
-    timeout = 10
-
-
-    @classmethod
-    def setUpClass(cls) -> None:
-        return super().setUpClass()
-
-    @classmethod
-    def tearDownClass(cls) -> None:
-        return super().tearDownClass()
-
-
-    def test_http_post(self)-> None:
-        """
-        Test the http_post method.
-
-        In this test case we calculate the the length of tokens of the content "Hello, World!".
-
-        """
-
-        res = MethodKit.http_post(
-            url=self.url,
-            jason_content=self.jason_content,
-            headers=self.headers,
-            timeout=self.timeout
-        )
-
-        assert res.status_code == 200
-        # length of token is a integer
-        assert isinstance(res.json().get('tokens'), int)
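
The changes above wire a single long-lived `httpx.AsyncClient`/`httpx.Client` pair (the `HttpxKit` singleton) into application shutdown via `dispose_httpx_client`. For illustration, a minimal self-contained sketch of that lifecycle, assuming nothing beyond `fastapi` and `httpx` (the class and handler names here are invented, not the repository's actual wiring):

```python
import contextlib

import fastapi
import httpx


class SharedHttpx:
    """Illustrative stand-in for HttpxKit: one client pair per process."""

    def __init__(self) -> None:
        # Created once at import time so connection pools are reused across requests.
        self.async_client = httpx.AsyncClient()
        self.sync_client = httpx.Client()

    async def aclose(self) -> bool:
        await self.async_client.aclose()
        self.sync_client.close()
        return self.async_client.is_closed and self.sync_client.is_closed


shared_httpx = SharedHttpx()


@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    yield
    # Mirrors dispose_httpx_client(): close the shared clients exactly once on shutdown.
    await shared_httpx.aclose()


app = fastapi.FastAPI(lifespan=lifespan)
```

Reusing one client keeps connection pools warm between requests; the trade-off is that the whole application must agree never to close it outside the shutdown hook.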
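The rewritten `inference` coroutine streams completions from the llama.cpp server over that shared client. Below is a standalone sketch of the same request; the endpoint URL is an assumption standing in for whatever `InferenceHelper.instruct_infer_url()` builds from the settings, while the JSON fields mirror the payload in the diff:

```python
import asyncio

import httpx

INFER_URL = "http://localhost:8080/completion"  # assumed llama.cpp server endpoint


async def stream_completion(prompt: str) -> None:
    data = {
        "prompt": prompt,
        "temperature": 0.2,
        "top_k": 40,
        "top_p": 0.9,
        "n_keep": 0,  # keep no old tokens when the context window fills up
        "n_predict": 128,
        "cache_prompt": True,
        "slot_id": -1,  # let the server pick a slot for the cached prompt
        "stop": ["\n### Human:"],
        "stream": True,
    }
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            INFER_URL,
            headers={"Content-Type": "application/json"},
            json=data,
            timeout=httpx.Timeout(timeout=None),  # no timeout, as in the diff
        ) as response:
            response.raise_for_status()
            async for chunk in response.aiter_text():
                # llama.cpp streams SSE-style "data: {...}" lines; forward them as-is.
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(stream_completion("### Human: Hello\n### Assistant:"))
```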
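`InferenceHelper.openai_client()` points the official `openai` package at the same server's OpenAI-compatible `/v1` API; the dummy key is what the llama.cpp server accepts. A usage sketch along the lines of the commented-out example in the diff, with placeholder host and port:

```python
import openai

client = openai.OpenAI(
    base_url="http://localhost:8080/v1",  # placeholder for infer_eng_url:infer_eng_port
    api_key="sk-no-key-required",
)

completion = client.chat.completions.create(
    model="",  # llama.cpp serves whichever model it loaded; the name is ignored
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a limerick about Python exceptions."},
    ],
)
print(completion.choices[0].message.content)
```

As the removed comment in `chat.py` notes, going through this client is slower than posting to the completion endpoint directly, which is why the streaming path above talks to the raw endpoint instead.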
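`tokenizer_url()` targets the server's `/tokenize` endpoint, which the deleted `MethodKit` test used to hit. A hedged sketch with a plain sync client; the host and the shape of the response (a `tokens` list of ids) are assumptions based on that old test and llama.cpp's server API:

```python
import httpx

TOKENIZE_URL = "http://localhost:8080/tokenize"  # assumed; tokenizer_url() builds the real URL

with httpx.Client() as client:
    res = client.post(
        TOKENIZE_URL,
        headers={"Content-Type": "application/json"},
        json={"content": "Hello, World!"},
        timeout=10,
    )
    res.raise_for_status()
    tokens = res.json().get("tokens", [])
    print(f"prompt length: {len(tokens)} tokens")
```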
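One note on the new `test_httpx_kit.py`: `@pytest.mark.asyncio` has no effect on an `async def` method of a plain `unittest.TestCase`, because pytest-asyncio does not support unittest-style classes, so the awaited assertion never actually runs. A sketch of a variant that executes under both pytest and `python -m unittest`; the `_Kit` stand-in mimics the `httpx_kit` API rather than importing the real module:

```python
import unittest

import httpx


class _Kit:
    """Stand-in for src.utilities.httpkit.httpx_kit.httpx_kit (assumed same API)."""

    def __init__(self) -> None:
        self.async_client = httpx.AsyncClient()
        self.sync_client = httpx.Client()

    async def teardown_async_client(self) -> bool:
        await self.async_client.aclose()
        return self.async_client.is_closed

    def teardown_sync_client(self) -> bool:
        self.sync_client.close()
        return self.sync_client.is_closed


class TestHttpxKit(unittest.IsolatedAsyncioTestCase):
    async def test_teardown_async_client(self):
        # The async method is actually awaited by IsolatedAsyncioTestCase.
        self.assertTrue(await _Kit().teardown_async_client())

    def test_teardown_sync_client(self):
        self.assertTrue(_Kit().teardown_sync_client())
```

Exercising a stand-in also avoids closing the real shared clients mid test run, which would leave them unusable for any test that streams from the inference engine afterwards.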