Skip to content

Commit

Permalink
remove ml libraries
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko committed Mar 26, 2024
1 parent 8a0e32c commit b77b120
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 20 deletions.
20 changes: 3 additions & 17 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
import csv

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline


from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository
from src.utilities.devices.devices import get_device


class RAGChatModelRepository(BaseRAGRepository):
model_name = "deepset/roberta-base-squad2"

nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)
class RAGChatModelRepository(BaseRAGRepository):

async def load_model(self, session_id: int, model_name: str) -> bool:
# Init model with input model_name
try:
# https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
model = SentenceTransformer(model_name, device=get_device())
model.max_seq_length = MAX_SQL_LENGTH
pass
except Exception as e:
print(e)
return False
Expand All @@ -34,9 +23,7 @@ async def get_response(self, session_id: int, input_msg: str) -> str:
# cos_scores = cos_sim(query_embedding, self.embeddings)[0]
# top_results = torch.topk(cos_scores, k=1)
# response_msg = self.data[top_results[1].item()]
QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
res = self.nlp(QA_input)
return res["answer"]
return "response message"

async def load_csv_file(self, file_name: str, model_name: str) -> bool:
# read file named file_name and convert the content into a list of strings @Aisuko
Expand All @@ -54,7 +41,6 @@ async def load_csv_file(self, file_name: str, model_name: str) -> bool:
# Add the row to the list
self.data.extend(row)
print(self.data)
self.model = SentenceTransformer(model_name, "cuda")
row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda")
# TODO
self.embeddings.append(row_embedding)
Expand Down
3 changes: 0 additions & 3 deletions backend/src/utilities/devices/devices.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import platform
import torch

def get_device():
if platform.system() == 'Darwin':
return 'mps'
elif torch.cuda.is_available():
return 'cuda'
return 'cpu'

0 comments on commit b77b120

Please sign in to comment.