Skip to content

Commit

Permalink
Add device support
Browse files Browse the repository at this point in the history
Signed-off-by: GitHub <[email protected]>
  • Loading branch information
Aisuko authored Mar 24, 2024
1 parent f73fff4 commit b5cbbd4
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 12 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci-backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ on:
- 'v*'
pull_request:
branches:
- 'feature/**'
- 'feat/**'
- 'fix/**'
- 'main'

jobs:
build:
Expand Down
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,9 @@ build: env

.PHONY: up
up: env build
docker-compose up -d
docker-compose up -d


.PHONY: stop
stop:
docker-compose stop
10 changes: 3 additions & 7 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from utilities.devices.devices import get_device

from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository


class RAGChatModelRepository(BaseRAGRepository):
# model = SentenceTransformer(DEFAULT_MODEL, "cuda")
# embeddings = model.encode([], convert_to_tensor=True).to("cuda")
model_name = "deepset/roberta-base-squad2"

nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)

async def load_model(self, session_id: int, model_name: str) -> bool:
# Init model with input model_name
try:
model = SentenceTransformer(model_name, "cuda")
# https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
model = SentenceTransformer(model_name, device=get_device())
model.max_seq_length = MAX_SQL_LENGTH
except Exception as e:
print(e)
Expand All @@ -29,16 +29,12 @@ async def load_model(self, session_id: int, model_name: str) -> bool:
async def get_response(self, session_id: int, input_msg: str) -> str:
# TODO use RAG framework to generate the response message @Aisuko
# query_embedding = self.model.encode(input_msg, convert_to_tensor=True).to("cuda")
# print(self.embeddings)
# print(query_embedding)
# we use cosine-similarity and torch.topk to find the highest 5 scores
# cos_scores = cos_sim(query_embedding, self.embeddings)[0]
# top_results = torch.topk(cos_scores, k=1)
# response_msg = self.data[top_results[1].item()]
QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
res = self.nlp(QA_input)
print(res)
# response_msg = "Oh, really? It's amazing !"
return res["answer"]

async def load_csv_file(self, file_name: str, model_name: str) -> bool:
Expand Down
Empty file.
9 changes: 9 additions & 0 deletions backend/src/utilities/devices/devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import platform
import torch

def get_device():
if platform.system() == 'Darwin':
return 'mps'
elif torch.cuda.is_available():
return 'cuda'
return 'cpu'
6 changes: 6 additions & 0 deletions backend/tests/unit_tests/test_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ def test_application_is_fastapi_instance() -> None:
assert backend_app.docs_url == "/docs"
assert backend_app.openapi_url == "/openapi.json"
assert backend_app.redoc_url == "/redoc"


def test_get_device() -> None:
from src.utilities.devices.devices import get_device

assert get_device() in ["mps", "cuda", "cpu"]
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ services:
- "minio"

frontend:
image: ghcr.io/skywardai/rebel:v0.0.5
image: ghcr.io/skywardai/rebel:v0.1.1
container_name: frontend
restart: always
expose:
Expand Down

0 comments on commit b5cbbd4

Please sign in to comment.