From 2e18d4c4311e08ac11310d87378520dca36944ce Mon Sep 17 00:00:00 2001
From: Aisuko
Date: Sun, 24 Mar 2024 23:15:56 +0000
Subject: [PATCH] Add device support

Signed-off-by: GitHub
---
Review notes (text in this section is ignored by `git am`):
- chat.py imports get_device through the `src.` package prefix, matching the
  file's other project imports and the new unit test.
- get_device() returns "mps" only when torch reports the MPS backend as
  available, instead of for every macOS host.
- The `index` blob hash for devices.py predates this revision; regenerate the
  patch with `git format-patch` to refresh it.

 Makefile                                  |  7 ++++++-
 backend/src/repository/rag/chat.py        | 10 +++-------
 backend/src/utilities/devices/__init__.py |  0
 backend/src/utilities/devices/devices.py  | 15 +++++++++++++++
 backend/tests/unit_tests/test_src.py      |  6 ++++++
 docker-compose.yaml                       |  2 +-
 6 files changed, 31 insertions(+), 9 deletions(-)
 create mode 100644 backend/src/utilities/devices/__init__.py
 create mode 100644 backend/src/utilities/devices/devices.py

diff --git a/Makefile b/Makefile
index cc52c32..6296fed 100644
--- a/Makefile
+++ b/Makefile
@@ -8,4 +8,9 @@ build: env
 
 .PHONY: up
 up: env build
-	docker-compose up -d
\ No newline at end of file
+	docker-compose up -d
+
+
+.PHONY: stop
+stop:
+	docker-compose stop
\ No newline at end of file
diff --git a/backend/src/repository/rag/chat.py b/backend/src/repository/rag/chat.py
index 687c1b0..2e8326a 100644
--- a/backend/src/repository/rag/chat.py
+++ b/backend/src/repository/rag/chat.py
@@ -4,14 +4,13 @@
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.util import cos_sim
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
 
 from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
 from src.repository.rag.base import BaseRAGRepository
+from src.utilities.devices.devices import get_device
 
 
 class RAGChatModelRepository(BaseRAGRepository):
-    # model = SentenceTransformer(DEFAULT_MODEL, "cuda")
-    # embeddings = model.encode([], convert_to_tensor=True).to("cuda")
     model_name = "deepset/roberta-base-squad2"
 
     nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)
@@ -19,7 +18,8 @@ class RAGChatModelRepository(BaseRAGRepository):
     async def load_model(self, session_id: int, model_name: str) -> bool:
         # Init model with input model_name
         try:
-            model = SentenceTransformer(model_name, "cuda")
+            # https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
+            model = SentenceTransformer(model_name, device=get_device())
             model.max_seq_length = MAX_SQL_LENGTH
         except Exception as e:
             print(e)
@@ -29,16 +29,12 @@ async def load_model(self, session_id: int, model_name: str) -> bool:
     async def get_response(self, session_id: int, input_msg: str) -> str:
         # TODO use RAG framework to generate the response message @Aisuko
         # query_embedding = self.model.encode(input_msg, convert_to_tensor=True).to("cuda")
-        # print(self.embeddings)
-        # print(query_embedding)
         # we use cosine-similarity and torch.topk to find the highest 5 scores
         # cos_scores = cos_sim(query_embedding, self.embeddings)[0]
         # top_results = torch.topk(cos_scores, k=1)
         # response_msg = self.data[top_results[1].item()]
         QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
         res = self.nlp(QA_input)
-        print(res)
-        # response_msg = "Oh, really? It's amazing !"
         return res["answer"]
 
     async def load_csv_file(self, file_name: str, model_name: str) -> bool:
diff --git a/backend/src/utilities/devices/__init__.py b/backend/src/utilities/devices/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/src/utilities/devices/devices.py b/backend/src/utilities/devices/devices.py
new file mode 100644
index 0000000..982a9d5
--- /dev/null
+++ b/backend/src/utilities/devices/devices.py
@@ -0,0 +1,15 @@
+"""Helpers for choosing the torch device that models are loaded onto."""
+
+import torch
+
+
+def get_device() -> str:
+    """Return the preferred torch device name: "mps", "cuda" or "cpu"."""
+    # Running on macOS is not enough: Intel Macs and torch builds without MPS
+    # support would fail later when the model is moved to "mps".
+    mps_backend = getattr(torch.backends, "mps", None)
+    if mps_backend is not None and mps_backend.is_available():
+        return "mps"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
diff --git a/backend/tests/unit_tests/test_src.py b/backend/tests/unit_tests/test_src.py
index af28379..5f2c667 100644
--- a/backend/tests/unit_tests/test_src.py
+++ b/backend/tests/unit_tests/test_src.py
@@ -14,3 +14,9 @@ def test_application_is_fastapi_instance() -> None:
     assert backend_app.docs_url == "/docs"
     assert backend_app.openapi_url == "/openapi.json"
     assert backend_app.redoc_url == "/redoc"
+
+
+def test_get_device() -> None:
+    from src.utilities.devices.devices import get_device
+
+    assert get_device() in ["mps", "cuda", "cpu"]
\ No newline at end of file
diff --git a/docker-compose.yaml b/docker-compose.yaml
index c4efb9d..f80c590 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -135,7 +135,7 @@ services:
       - "minio"
 
   frontend:
-    image: ghcr.io/skywardai/rebel:v0.0.5
+    image: ghcr.io/skywardai/rebel:v0.1.1
     container_name: frontend
     restart: always
     expose: