diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 75be6fc..9876d6f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,9 +1,9 @@ [project] -name = "DAPSQL QWIK-Stack Application" -version = "0.0.1" -description = "A comprehensive template to start a real-world quality project." +name = "ChatBot Application" +version = "0.1" +description = "A chat bot and model training backend project" authors = [ - {name = "Aeternalis-Ingenium", email="aeternalisingenium@proton.me"}, + {name = "Rob-Zhang", email="chuan.z@hotmail.com"}, ] classifiers = [ "Topic :: Software Development" @@ -18,9 +18,9 @@ requires-python = ">=3.11" dependencies = {file = ["requirements.txt"]} [project.urls] -homepage = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template" -documentation = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template" -repository = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template" +homepage = "https://github.com/SkywardAI/chat-backend" +documentation = "https://github.com/SkywardAI/chat-backend" +repository = "https://github.com/SkywardAI/chat-backend" [tool.black] color=true diff --git a/backend/requirements.txt b/backend/requirements.txt index 527d73c..e5aa12e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -30,5 +30,6 @@ python-multipart==0.0.9 python-slugify==8.0.4 sentence-transformers==2.3.1 SQLAlchemy==2.0.0b3 +transformers==4.38.2 trio==0.24.0 uvicorn==0.28.0 diff --git a/backend/sample_files/sample.csv b/backend/sample_files/sample.csv new file mode 100644 index 0000000..8b01793 --- /dev/null +++ b/backend/sample_files/sample.csv @@ -0,0 +1,10 @@ +"Melbourne is the capital city of the Australian state of Victoria." +"It is known for its diverse and vibrant cultural scene." +"The city is famous for its coffee culture, with numerous cafes scattered throughout." +"Melbourne is home to iconic landmarks like the Royal Exhibition Building and Flinders Street Station." +"The Yarra River runs through the heart of the city, providing a picturesque setting." +"The Melbourne Cricket Ground (MCG) is a historic sports venue and a key part of the city's identity." +"The city hosts various events, including the Australian Open, Melbourne Fashion Week, and the Melbourne International Comedy Festival." +"Melbourne's street art scene is renowned, with vibrant murals adorning many laneways." +"The Queen Victoria Market is a popular spot for fresh produce, local crafts, and diverse international cuisines." +"Melbourne is often considered one of the most livable cities globally, offering a high quality of life." diff --git a/backend/src/api/routes/ai_model.py b/backend/src/api/routes/ai_model.py index 1d8ceb1..10152a7 100644 --- a/backend/src/api/routes/ai_model.py +++ b/backend/src/api/routes/ai_model.py @@ -1,8 +1,10 @@ import fastapi -from src.api.dependencies.repository import get_repository +from src.api.dependencies.repository import get_rag_repository, get_repository +from src.config.settings.const import UPLOAD_FILE_PATH from src.models.schemas.ai_model import AiModel, AiModelChooseResponse, AiModelInResponse from src.repository.crud.ai_model import AiModelCRUDRepository +from src.repository.rag.chat import RAGChatModelRepository router = fastapi.APIRouter(prefix="/models", tags=["model"]) @@ -39,9 +41,17 @@ async def get_aimodels( async def choose_aimodels( id: int, aimodel_repo: AiModelCRUDRepository = fastapi.Depends(get_repository(repo_type=AiModelCRUDRepository)), + rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)), ) -> AiModelChooseResponse: ai_model = await aimodel_repo.read_aimodel_by_id(id=id) - return AiModelChooseResponse( - name=ai_model.name, - msg="Model has been selected", - ) + result = await rag_chat_repo.load_model(session_id=id, model_name=ai_model.name) + if result: + return AiModelChooseResponse( + name=ai_model.name, + msg="Model has been selected", + ) + else: + return AiModelChooseResponse( + name=ai_model.name, + msg="Sorry Model init failed! Please try again!", + ) diff --git a/backend/src/api/routes/chat.py b/backend/src/api/routes/chat.py index 32e112c..fc0ae0b 100644 --- a/backend/src/api/routes/chat.py +++ b/backend/src/api/routes/chat.py @@ -26,12 +26,12 @@ async def chat( # chat_in_msg.accountID = 0 if not hasattr(chat_in_msg, "sessionId") or not chat_in_msg.sessionId: new_session = await session_repo.create_session( - account_id=chat_in_msg.accountID, name=chat_in_msg.message[:40] + account_id=chat_in_msg.accountID, name=chat_in_msg.message[:20] ) session_id = new_session.id else: # TODO need verify if sesson exist - # create_session = await session_repo.read_create_sessions_by_id(id=chat_in_msg.sessionId, account_id=chat_in_msg.accountID, name=chat_in_msg.message[:40]) + # create_session = await session_repo.read_create_sessions_by_id(id=chat_in_msg.sessionId, account_id=chat_in_msg.accountID, name=chat_in_msg.message[:20]) session_id = chat_in_msg.sessionId await chat_repo.create_chat_history(session_id=session_id, is_bot_msg=False, message=chat_in_msg.message) response_msg = await rag_chat_repo.get_response(session_id=session_id, input_msg=chat_in_msg.message) @@ -60,6 +60,7 @@ async def get_session( res_session = Session( id=session.id, name=session.name, + created_at=session.created_at, ) sessions_list.append(res_session) except Exception as e: @@ -84,6 +85,7 @@ async def get_all_sessions( res_session = Session( id=session.id, name=session.name, + created_at=session.created_at, ) sessions_list.append(res_session) @@ -105,7 +107,7 @@ async def get_chathistory( for chat in chats: res_session = ChatHistory( id=chat.id, - type="out" if chat.is_bot_msg else "out", + type="out" if chat.is_bot_msg else "in", message=chat.message, ) chats_list.append(res_session) diff --git a/backend/src/api/routes/file.py b/backend/src/api/routes/file.py index fefa8fd..a47a7b3 100644 --- a/backend/src/api/routes/file.py +++ b/backend/src/api/routes/file.py @@ -2,19 +2,21 @@ import random import fastapi -from src.config.settings.const import UPLOAD_FILE_PATH +from fastapi import BackgroundTasks + from src.api.dependencies.repository import get_repository +from src.config.settings.const import UPLOAD_FILE_PATH from src.models.schemas.file import FileInResponse, FileStatusInResponse from src.repository.crud.file import UploadedFileCRUDRepository -from fastapi import BackgroundTasks router = fastapi.APIRouter(prefix="/file", tags=["file"]) -async def save_upload_file(file: fastapi.UploadFile, save_file: str): + +async def save_upload_file(contents: bytes, save_file: str): with open(save_file, "wb") as f: - contents = await file.read() f.write(contents) + @router.post( "", name="file:upload file", @@ -33,11 +35,12 @@ async def upload_and_return_id( if not os.path.exists(save_path): os.mkdir(save_path) save_file = os.path.join(save_path, file.filename) - - background_tasks.add_task(save_upload_file, file, save_file) + contents = await file.read() + background_tasks.add_task(save_upload_file, contents, save_file) return FileInResponse(fileID=new_file.id) + @router.get( path="/{id}", name="file:check upload status", diff --git a/backend/src/api/routes/train.py b/backend/src/api/routes/train.py index 865be8c..b47c5d4 100644 --- a/backend/src/api/routes/train.py +++ b/backend/src/api/routes/train.py @@ -2,8 +2,11 @@ import fastapi -from src.api.dependencies.repository import get_repository +from src.api.dependencies.repository import get_rag_repository, get_repository from src.models.schemas.train import TrainFileIn, TrainFileInResponse, TrainStatusInResponse +from src.repository.crud.ai_model import AiModelCRUDRepository +from src.repository.crud.file import UploadedFileCRUDRepository +from src.repository.rag.chat import RAGChatModelRepository router = fastapi.APIRouter(prefix="/train", tags=["train"]) @@ -16,8 +19,16 @@ ) async def train( train_in_msg: TrainFileIn, + aimodel_repo: AiModelCRUDRepository = fastapi.Depends(get_repository(repo_type=AiModelCRUDRepository)), + file_repo: UploadedFileCRUDRepository = fastapi.Depends(get_repository(repo_type=UploadedFileCRUDRepository)), + rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)), ) -> TrainFileInResponse: - # TODO start process the file with model + + ai_model = await aimodel_repo.read_aimodel_by_id(id=train_in_msg.modelID) + file_csv = await file_repo.read_uploadedfiles_by_id(id=train_in_msg.fileID) + + await rag_chat_repo.load_csv_file(file_name=file_csv.name, model_name=ai_model.name) + return TrainFileInResponse( trainID=train_in_msg.fileID + train_in_msg.modelID, ) diff --git a/backend/src/config/settings/const.py b/backend/src/config/settings/const.py index f92c09a..aca18fa 100644 --- a/backend/src/config/settings/const.py +++ b/backend/src/config/settings/const.py @@ -2,3 +2,4 @@ UPLOAD_FILE_PATH = "./uploaded_files/" MAX_SQL_LENGTH = 200 DEFAULT_MODEL = "all-MiniLM-L6-v2" +CHAT_COMTEXT = "Melbourne is the capital city of the Australian state of Victoria.It is known for its diverse and vibrant cultural scene.The city is famous for its coffee culture, with numerous cafes scattered throughout.Melbourne is home to iconic landmarks like the Royal Exhibition Building and Flinders Street Station.The Yarra River runs through the heart of the city, providing a picturesque setting.The Melbourne Cricket Ground (MCG) is a historic sports venue and a key part of the city's identity.The city hosts various events, including the Australian Open, Melbourne Fashion Week, and the Melbourne International Comedy Festival.Melbourne's street art scene is renowned, with vibrant murals adorning many laneways.The Queen Victoria Market is a popular spot for fresh produce, local crafts, and diverse international cuisines.Melbourne is often considered one of the most livable cities globally, offering a high quality of life." diff --git a/backend/src/models/db/chat.py b/backend/src/models/db/chat.py index 01c8db7..dd965ae 100644 --- a/backend/src/models/db/chat.py +++ b/backend/src/models/db/chat.py @@ -13,6 +13,9 @@ class Session(Base): # type: ignore id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto") account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=True) name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=True) + created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column( + sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now() + ) __mapper_args__ = {"eager_defaults": True} diff --git a/backend/src/models/schemas/chat.py b/backend/src/models/schemas/chat.py index 967f015..152f580 100644 --- a/backend/src/models/schemas/chat.py +++ b/backend/src/models/schemas/chat.py @@ -1,3 +1,4 @@ +import datetime from typing import Optional from pydantic import Field @@ -19,6 +20,7 @@ class ChatInResponse(BaseSchemaModel): class Session(BaseSchemaModel): id: int name: str | None + created_at: datetime.datetime class ChatHistory(BaseSchemaModel): diff --git a/backend/src/repository/crud/chat.py b/backend/src/repository/crud/chat.py index 22efa38..1bcd5af 100644 --- a/backend/src/repository/crud/chat.py +++ b/backend/src/repository/crud/chat.py @@ -21,12 +21,12 @@ async def create_session(self, account_id: Optional[int], name: str) -> Session: return new_session async def read_sessions(self) -> typing.Sequence[Session]: - stmt = sqlalchemy.select(Session) + stmt = sqlalchemy.select(Session).order_by(Session.created_at.desc()) query = await self.async_session.execute(statement=stmt) return query.scalars().all() async def read_sessions_by_id(self, id: int) -> Session: - stmt = sqlalchemy.select(Session).where(Session.id == id) + stmt = sqlalchemy.select(Session).where(Session.id == id).order_by(Session.created_at.desc()) query = await self.async_session.execute(statement=stmt) if not query: diff --git a/backend/src/repository/rag/chat.py b/backend/src/repository/rag/chat.py index 0655caf..687c1b0 100644 --- a/backend/src/repository/rag/chat.py +++ b/backend/src/repository/rag/chat.py @@ -1,11 +1,20 @@ +import csv + +import torch from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim +from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline -from src.config.settings.const import DEFAULT_MODEL, MAX_SQL_LENGTH +from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH from src.repository.rag.base import BaseRAGRepository class RAGChatModelRepository(BaseRAGRepository): - model = SentenceTransformer(DEFAULT_MODEL, "cuda") + # model = SentenceTransformer(DEFAULT_MODEL, "cuda") + # embeddings = model.encode([], convert_to_tensor=True).to("cuda") + model_name = "deepset/roberta-base-squad2" + + nlp = pipeline("question-answering", model=model_name, tokenizer=model_name) async def load_model(self, session_id: int, model_name: str) -> bool: # Init model with input model_name @@ -18,6 +27,38 @@ async def load_model(self, session_id: int, model_name: str) -> bool: return True async def get_response(self, session_id: int, input_msg: str) -> str: - # TODO use RAG framework to generate the response message - response_msg = "Oh, really? It's amazing !" - return response_msg + # TODO use RAG framework to generate the response message @Aisuko + # query_embedding = self.model.encode(input_msg, convert_to_tensor=True).to("cuda") + # print(self.embeddings) + # print(query_embedding) + # we use cosine-similarity and torch.topk to find the highest 5 scores + # cos_scores = cos_sim(query_embedding, self.embeddings)[0] + # top_results = torch.topk(cos_scores, k=1) + # response_msg = self.data[top_results[1].item()] + QA_input = {"question": input_msg, "context": CHAT_COMTEXT} + res = self.nlp(QA_input) + print(res) + # response_msg = "Oh, really? It's amazing !" + return res["answer"] + + async def load_csv_file(self, file_name: str, model_name: str) -> bool: + # read file named file_name and convert the content into a list of strings @Aisuko + print(file_name) + print(model_name) + self.data = [] + self.embeddings = [] + # Open the CSV file + with open(UPLOAD_FILE_PATH + file_name, "r") as file: + # Create a CSV reader + reader = csv.reader(file) + + # Iterate over each row in the CSV + for row in reader: + # Add the row to the list + self.data.extend(row) + print(self.data) + self.model = SentenceTransformer(model_name, "cuda") + row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda") + # TODO + self.embeddings.append(row_embedding) + return True diff --git a/docker-compose.yaml b/docker-compose.yaml index cdfef72..c4efb9d 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -35,6 +35,13 @@ services: backend_app: container_name: backend_app restart: always + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] build: dockerfile: Dockerfile context: ./backend/ @@ -78,6 +85,7 @@ services: - 8001:8000 depends_on: - db + # command: ["--gpus", "all"] etcd: container_name: milvus-etcd @@ -127,7 +135,7 @@ services: - "minio" frontend: - image: ghcr.io/skywardai/rebel:v0.0.4 + image: ghcr.io/skywardai/rebel:v0.0.5 container_name: frontend restart: always expose: