Skip to content

Commit

Permalink
Merge pull request #31 from SkywardAI/dev
Browse files Browse the repository at this point in the history
Develop for v0.1
  • Loading branch information
Micost authored Mar 23, 2024
2 parents 042960d + c261e49 commit 9dc1ad2
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 31 deletions.
14 changes: 7 additions & 7 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[project]
name = "DAPSQL QWIK-Stack Application"
version = "0.0.1"
description = "A comprehensive template to start a real-world quality project."
name = "ChatBot Application"
version = "0.1"
description = "A chat bot and model training backend project"
authors = [
{name = "Aeternalis-Ingenium", email="[email protected]"},
{name = "Rob-Zhang", email="[email protected]"},
]
classifiers = [
"Topic :: Software Development"
Expand All @@ -18,9 +18,9 @@ requires-python = ">=3.11"
dependencies = {file = ["requirements.txt"]}

[project.urls]
homepage = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template"
documentation = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template"
repository = "https://github.com/Aeternalis-Ingenium/DAPSQL-FAQ-Stack-Template"
homepage = "https://github.com/SkywardAI/chat-backend"
documentation = "https://github.com/SkywardAI/chat-backend"
repository = "https://github.com/SkywardAI/chat-backend"

[tool.black]
color=true
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ python-multipart==0.0.9
python-slugify==8.0.4
sentence-transformers==2.3.1
SQLAlchemy==2.0.0b3
transformers==4.38.2
trio==0.24.0
uvicorn==0.28.0
10 changes: 10 additions & 0 deletions backend/sample_files/sample.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"Melbourne is the capital city of the Australian state of Victoria."
"It is known for its diverse and vibrant cultural scene."
"The city is famous for its coffee culture, with numerous cafes scattered throughout."
"Melbourne is home to iconic landmarks like the Royal Exhibition Building and Flinders Street Station."
"The Yarra River runs through the heart of the city, providing a picturesque setting."
"The Melbourne Cricket Ground (MCG) is a historic sports venue and a key part of the city's identity."
"The city hosts various events, including the Australian Open, Melbourne Fashion Week, and the Melbourne International Comedy Festival."
"Melbourne's street art scene is renowned, with vibrant murals adorning many laneways."
"The Queen Victoria Market is a popular spot for fresh produce, local crafts, and diverse international cuisines."
"Melbourne is often considered one of the most livable cities globally, offering a high quality of life."
20 changes: 15 additions & 5 deletions backend/src/api/routes/ai_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import fastapi

from src.api.dependencies.repository import get_repository
from src.api.dependencies.repository import get_rag_repository, get_repository
from src.config.settings.const import UPLOAD_FILE_PATH
from src.models.schemas.ai_model import AiModel, AiModelChooseResponse, AiModelInResponse
from src.repository.crud.ai_model import AiModelCRUDRepository
from src.repository.rag.chat import RAGChatModelRepository

router = fastapi.APIRouter(prefix="/models", tags=["model"])

Expand Down Expand Up @@ -39,9 +41,17 @@ async def get_aimodels(
async def choose_aimodels(
    id: int,
    aimodel_repo: AiModelCRUDRepository = fastapi.Depends(get_repository(repo_type=AiModelCRUDRepository)),
    rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)),
) -> AiModelChooseResponse:
    """Select the AI model with the given id and ask the RAG repository to load it.

    Args:
        id: Identifier of the model record to select.
        aimodel_repo: CRUD repository used to look up the model record.
        rag_chat_repo: RAG repository that performs the actual model load.

    Returns:
        An ``AiModelChooseResponse`` carrying the model name and a message that
        reports whether ``load_model`` returned a truthy result.
    """
    ai_model = await aimodel_repo.read_aimodel_by_id(id=id)

    # NOTE(review): the model id is forwarded as ``session_id`` here — confirm
    # that this is what ``load_model`` expects.
    loaded = await rag_chat_repo.load_model(session_id=id, model_name=ai_model.name)

    # Single return path: only the message depends on the load outcome.
    msg = "Model has been selected" if loaded else "Sorry Model init failed! Please try again!"
    return AiModelChooseResponse(
        name=ai_model.name,
        msg=msg,
    )
8 changes: 5 additions & 3 deletions backend/src/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ async def chat(
# chat_in_msg.accountID = 0
if not hasattr(chat_in_msg, "sessionId") or not chat_in_msg.sessionId:
new_session = await session_repo.create_session(
account_id=chat_in_msg.accountID, name=chat_in_msg.message[:40]
account_id=chat_in_msg.accountID, name=chat_in_msg.message[:20]
)
session_id = new_session.id
else:
# TODO: need to verify that the session exists
# create_session = await session_repo.read_create_sessions_by_id(id=chat_in_msg.sessionId, account_id=chat_in_msg.accountID, name=chat_in_msg.message[:40])
# create_session = await session_repo.read_create_sessions_by_id(id=chat_in_msg.sessionId, account_id=chat_in_msg.accountID, name=chat_in_msg.message[:20])
session_id = chat_in_msg.sessionId
await chat_repo.create_chat_history(session_id=session_id, is_bot_msg=False, message=chat_in_msg.message)
response_msg = await rag_chat_repo.get_response(session_id=session_id, input_msg=chat_in_msg.message)
Expand Down Expand Up @@ -60,6 +60,7 @@ async def get_session(
res_session = Session(
id=session.id,
name=session.name,
created_at=session.created_at,
)
sessions_list.append(res_session)
except Exception as e:
Expand All @@ -84,6 +85,7 @@ async def get_all_sessions(
res_session = Session(
id=session.id,
name=session.name,
created_at=session.created_at,
)
sessions_list.append(res_session)

Expand All @@ -105,7 +107,7 @@ async def get_chathistory(
for chat in chats:
res_session = ChatHistory(
id=chat.id,
type="out" if chat.is_bot_msg else "out",
type="out" if chat.is_bot_msg else "in",
message=chat.message,
)
chats_list.append(res_session)
Expand Down
15 changes: 9 additions & 6 deletions backend/src/api/routes/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import random

import fastapi
from src.config.settings.const import UPLOAD_FILE_PATH
from fastapi import BackgroundTasks

from src.api.dependencies.repository import get_repository
from src.config.settings.const import UPLOAD_FILE_PATH
from src.models.schemas.file import FileInResponse, FileStatusInResponse
from src.repository.crud.file import UploadedFileCRUDRepository
from fastapi import BackgroundTasks

router = fastapi.APIRouter(prefix="/file", tags=["file"])

async def save_upload_file(contents: bytes, save_file: str) -> None:
    """Write the bytes of an uploaded file to ``save_file``.

    The function takes the raw bytes, not the upload object: the caller reads
    the upload while handling the request and schedules this function as a
    background task, so nothing here depends on the upload object still being
    readable when the task runs.

    Args:
        contents: Complete payload of the uploaded file.
        save_file: Destination path; an existing file is overwritten.
    """
    with open(save_file, "wb") as f:
        f.write(contents)


@router.post(
"",
name="file:upload file",
Expand All @@ -33,11 +35,12 @@ async def upload_and_return_id(
if not os.path.exists(save_path):
os.mkdir(save_path)
save_file = os.path.join(save_path, file.filename)

background_tasks.add_task(save_upload_file, file, save_file)
contents = await file.read()
background_tasks.add_task(save_upload_file, contents, save_file)

return FileInResponse(fileID=new_file.id)


@router.get(
path="/{id}",
name="file:check upload status",
Expand Down
15 changes: 13 additions & 2 deletions backend/src/api/routes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import fastapi

from src.api.dependencies.repository import get_repository
from src.api.dependencies.repository import get_rag_repository, get_repository
from src.models.schemas.train import TrainFileIn, TrainFileInResponse, TrainStatusInResponse
from src.repository.crud.ai_model import AiModelCRUDRepository
from src.repository.crud.file import UploadedFileCRUDRepository
from src.repository.rag.chat import RAGChatModelRepository

router = fastapi.APIRouter(prefix="/train", tags=["train"])

Expand All @@ -16,8 +19,16 @@
)
async def train(
    train_in_msg: TrainFileIn,
    aimodel_repo: AiModelCRUDRepository = fastapi.Depends(get_repository(repo_type=AiModelCRUDRepository)),
    file_repo: UploadedFileCRUDRepository = fastapi.Depends(get_repository(repo_type=UploadedFileCRUDRepository)),
    rag_chat_repo: RAGChatModelRepository = fastapi.Depends(get_rag_repository(repo_type=RAGChatModelRepository)),
) -> TrainFileInResponse:
    """Feed an uploaded file to the RAG repository for the requested model.

    Looks up the model record and the uploaded-file record named in the
    request, hands their names to ``load_csv_file``, and answers with a train
    id computed as ``fileID + modelID``.
    """
    model_id = train_in_msg.modelID
    file_id = train_in_msg.fileID

    # Resolve both records first; only their names are passed on.
    ai_model = await aimodel_repo.read_aimodel_by_id(id=model_id)
    file_csv = await file_repo.read_uploadedfiles_by_id(id=file_id)

    await rag_chat_repo.load_csv_file(
        file_name=file_csv.name,
        model_name=ai_model.name,
    )

    return TrainFileInResponse(trainID=file_id + model_id)
Expand Down
1 change: 1 addition & 0 deletions backend/src/config/settings/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
UPLOAD_FILE_PATH = "./uploaded_files/"
MAX_SQL_LENGTH = 200
DEFAULT_MODEL = "all-MiniLM-L6-v2"
CHAT_COMTEXT = "Melbourne is the capital city of the Australian state of Victoria.It is known for its diverse and vibrant cultural scene.The city is famous for its coffee culture, with numerous cafes scattered throughout.Melbourne is home to iconic landmarks like the Royal Exhibition Building and Flinders Street Station.The Yarra River runs through the heart of the city, providing a picturesque setting.The Melbourne Cricket Ground (MCG) is a historic sports venue and a key part of the city's identity.The city hosts various events, including the Australian Open, Melbourne Fashion Week, and the Melbourne International Comedy Festival.Melbourne's street art scene is renowned, with vibrant murals adorning many laneways.The Queen Victoria Market is a popular spot for fresh produce, local crafts, and diverse international cuisines.Melbourne is often considered one of the most livable cities globally, offering a high quality of life."
3 changes: 3 additions & 0 deletions backend/src/models/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class Session(Base): # type: ignore
id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(primary_key=True, autoincrement="auto")
account_id: SQLAlchemyMapped[int] = sqlalchemy_mapped_column(nullable=True)
name: SQLAlchemyMapped[str] = sqlalchemy_mapped_column(sqlalchemy.String(length=64), nullable=True)
created_at: SQLAlchemyMapped[datetime.datetime] = sqlalchemy_mapped_column(
sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy_functions.now()
)

__mapper_args__ = {"eager_defaults": True}

Expand Down
2 changes: 2 additions & 0 deletions backend/src/models/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Optional

from pydantic import Field
Expand All @@ -19,6 +20,7 @@ class ChatInResponse(BaseSchemaModel):
class Session(BaseSchemaModel):
    # Response schema for one chat session, as built by the session routes.
    # Identifier of the session.
    id: int
    # Session name; may be null.
    name: str | None
    # Creation timestamp, copied from the session's database row.
    created_at: datetime.datetime


class ChatHistory(BaseSchemaModel):
Expand Down
4 changes: 2 additions & 2 deletions backend/src/repository/crud/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ async def create_session(self, account_id: Optional[int], name: str) -> Session:
return new_session

async def read_sessions(self) -> typing.Sequence[Session]:
    """Return all chat sessions, most recently created first."""
    # One statement only: the unordered select that used to precede this one
    # was overwritten before it was ever executed.
    stmt = sqlalchemy.select(Session).order_by(Session.created_at.desc())
    query = await self.async_session.execute(statement=stmt)
    return query.scalars().all()

async def read_sessions_by_id(self, id: int) -> Session:
stmt = sqlalchemy.select(Session).where(Session.id == id)
stmt = sqlalchemy.select(Session).where(Session.id == id).order_by(Session.created_at.desc())
query = await self.async_session.execute(statement=stmt)

if not query:
Expand Down
51 changes: 46 additions & 5 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import csv

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

from src.config.settings.const import DEFAULT_MODEL, MAX_SQL_LENGTH
from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository


class RAGChatModelRepository(BaseRAGRepository):
model = SentenceTransformer(DEFAULT_MODEL, "cuda")
# model = SentenceTransformer(DEFAULT_MODEL, "cuda")
# embeddings = model.encode([], convert_to_tensor=True).to("cuda")
model_name = "deepset/roberta-base-squad2"

nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)

async def load_model(self, session_id: int, model_name: str) -> bool:
# Init model with input model_name
Expand All @@ -18,6 +27,38 @@ async def load_model(self, session_id: int, model_name: str) -> bool:
return True

async def get_response(self, session_id: int, input_msg: str) -> str:
    """Answer ``input_msg`` with the class-level question-answering pipeline.

    The question is answered against the fixed ``CHAT_COMTEXT`` passage.

    Args:
        session_id: Chat session the message belongs to (currently unused).
        input_msg: The user's question.

    Returns:
        The ``"answer"`` entry of the pipeline result.
    """
    # TODO(@Aisuko): use the RAG framework to generate the response — embed
    # the query, rank the loaded embeddings by cosine similarity and answer
    # from the best-matching row, replacing the fixed context.
    QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
    res = self.nlp(QA_input)
    print(res)
    return res["answer"]

async def load_csv_file(self, file_name: str, model_name: str) -> bool:
    """Read an uploaded CSV file and embed its cells with ``model_name``.

    Every cell of every row is appended to ``self.data`` as one string. The
    whole list is then encoded with a ``SentenceTransformer`` and the
    resulting tensor is stored as the single element of ``self.embeddings``.

    Args:
        file_name: Name of the file, appended to ``UPLOAD_FILE_PATH``.
        model_name: Name of the sentence-transformers model to load.

    Returns:
        ``True`` once the file has been read and encoded.
    """
    print(file_name)
    print(model_name)
    self.data = []
    self.embeddings = []

    # newline="" is what the csv module asks for, so that newlines inside
    # quoted fields are handled by the reader and not by the file object.
    with open(UPLOAD_FILE_PATH + file_name, "r", newline="", encoding="utf-8") as file:
        for row in csv.reader(file):
            # Flatten: each cell becomes one entry in the list.
            self.data.extend(row)
    print(self.data)

    # The device is passed by keyword so that it cannot land in a different
    # positional parameter of the constructor.
    self.model = SentenceTransformer(model_name, device="cuda")
    row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda")
    # TODO
    self.embeddings.append(row_embedding)
    return True
10 changes: 9 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ services:
backend_app:
container_name: backend_app
restart: always
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
build:
dockerfile: Dockerfile
context: ./backend/
Expand Down Expand Up @@ -78,6 +85,7 @@ services:
- 8001:8000
depends_on:
- db
# command: ["--gpus", "all"]

etcd:
container_name: milvus-etcd
Expand Down Expand Up @@ -127,7 +135,7 @@ services:
- "minio"

frontend:
image: ghcr.io/skywardai/rebel:v0.0.4
image: ghcr.io/skywardai/rebel:v0.0.5
container_name: frontend
restart: always
expose:
Expand Down

0 comments on commit 9dc1ad2

Please sign in to comment.