Skip to content

Commit

Permalink
remove ml libraries
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko authored Mar 25, 2024
1 parent 8a0e32c commit 0663ca2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 22 deletions.
4 changes: 1 addition & 3 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ python-dotenv==1.0.1
python-jose==3.3.0
python-multipart==0.0.9
python-slugify==8.0.4
sentence-transformers==2.3.1
SQLAlchemy==2.0.0b3
transformers==4.38.2
trio==0.24.0
uvicorn==0.28.0
uvicorn==0.28.0
24 changes: 5 additions & 19 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
import csv

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline


from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository
from src.utilities.devices.devices import get_device


class RAGChatModelRepository(BaseRAGRepository):
model_name = "deepset/roberta-base-squad2"

nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)
class RAGChatModelRepository(BaseRAGRepository):

async def load_model(self, session_id: int, model_name: str) -> bool:
# Init model with input model_name
try:
# https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
model = SentenceTransformer(model_name, device=get_device())
model.max_seq_length = MAX_SQL_LENGTH
pass
except Exception as e:
print(e)
return False
Expand All @@ -34,9 +23,7 @@ async def get_response(self, session_id: int, input_msg: str) -> str:
# cos_scores = cos_sim(query_embedding, self.embeddings)[0]
# top_results = torch.topk(cos_scores, k=1)
# response_msg = self.data[top_results[1].item()]
QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
res = self.nlp(QA_input)
return res["answer"]
return "response message"

async def load_csv_file(self, file_name: str, model_name: str) -> bool:
# read file named file_name and convert the content into a list of strings @Aisuko
Expand All @@ -54,8 +41,7 @@ async def load_csv_file(self, file_name: str, model_name: str) -> bool:
# Add the row to the list
self.data.extend(row)
print(self.data)
self.model = SentenceTransformer(model_name, "cuda")
row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda")
# row_embedding = self.model.encode(self.data, convert_to_tensor=True).to("cuda")
# TODO
self.embeddings.append(row_embedding)
# self.embeddings.append(row_embedding)
return True

0 comments on commit 0663ca2

Please sign in to comment.