Merge pull request #45 from SkywardAI/feat/debug
[draft]:remove ml libraries
Micost authored Mar 26, 2024
2 parents 8a0e32c + 3e34132 commit e98ca25
Showing 3 changed files with 2 additions and 20 deletions.
1 change: 0 additions & 1 deletion backend/requirements.txt
@@ -28,7 +28,6 @@ python-dotenv==1.0.1
python-jose==3.3.0
python-multipart==0.0.9
python-slugify==8.0.4
sentence-transformers==2.3.1
SQLAlchemy==2.0.0b3
transformers==4.38.2
trio==0.24.0
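Only the sentence-transformers pin is dropped; transformers==4.38.2 remains in requirements.txt. A hypothetical sanity check, not part of this repository, for confirming the package is no longer importable in a rebuilt backend environment:

import importlib.util

# find_spec returns None when a module cannot be imported in the current environment.
if importlib.util.find_spec("sentence_transformers") is None:
    print("sentence-transformers is absent, as expected after this change")
else:
    print("sentence-transformers is still importable; the environment predates this change")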
18 changes: 2 additions & 16 deletions backend/src/repository/rag/chat.py
@@ -1,27 +1,16 @@
import csv

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline


from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository
from src.utilities.devices.devices import get_device


class RAGChatModelRepository(BaseRAGRepository):
    model_name = "deepset/roberta-base-squad2"

    nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)

    async def load_model(self, session_id: int, model_name: str) -> bool:
        # Init model with input model_name
        try:
            # https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
            model = SentenceTransformer(model_name, device=get_device())
            model.max_seq_length = MAX_SQL_LENGTH
            pass
        except Exception as e:
            print(e)
            return False
@@ -34,9 +23,7 @@ async def get_response(self, session_id: int, input_msg: str) -> str:
        # cos_scores = cos_sim(query_embedding, self.embeddings)[0]
        # top_results = torch.topk(cos_scores, k=1)
        # response_msg = self.data[top_results[1].item()]
        QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
        res = self.nlp(QA_input)
        return res["answer"]
        return "response message"

    async def load_csv_file(self, file_name: str, model_name: str) -> bool:
        # read file named file_name and convert the content into a list of strings @Aisuko
@@ -54,7 +41,6 @@ async def load_csv_file(self, file_name: str, model_name: str) -> bool:
                # Add the row to the list
                self.data.extend(row)
        print(self.data)
        self.model = SentenceTransformer(model_name, "cuda")
        row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda")
        # TODO
        self.embeddings.append(row_embedding)
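Taken together, the chat.py hunks drop the SentenceTransformer and transformers pipeline usage and stub out the response path. A minimal sketch of how the affected parts of RAGChatModelRepository plausibly read after this commit, assuming nothing outside the shown hunks changed; note that load_csv_file still references self.model, which no longer appears to be assigned anywhere in this file:

import csv

from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository
from src.utilities.devices.devices import get_device


class RAGChatModelRepository(BaseRAGRepository):
    async def load_model(self, session_id: int, model_name: str) -> bool:
        # Init model with input model_name
        try:
            pass  # model initialisation removed by this commit
        except Exception as e:
            print(e)
            return False
        ...  # lines between the shown hunks are not part of this diff

    async def get_response(self, session_id: int, input_msg: str) -> str:
        # embedding lookup removed; a fixed placeholder is returned instead
        return "response message"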
3 changes: 0 additions & 3 deletions backend/src/utilities/devices/devices.py
@@ -1,9 +1,6 @@
import platform
import torch

def get_device():
    if platform.system() == 'Darwin':
        return 'mps'
    elif torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
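With the torch import and the CUDA branch removed, device selection presumably reduces to 'mps' on macOS and 'cpu' everywhere else. A sketch of the resulting get_device(), assuming only the lines shown above were deleted:

import platform


def get_device():
    # CUDA detection was dropped along with the torch dependency.
    if platform.system() == 'Darwin':
        return 'mps'
    return 'cpu'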
