Skip to content

Commit

Permalink
Merge pull request #40 from SkywardAI/fix/device
Browse files Browse the repository at this point in the history
Add device support
  • Loading branch information
Aisuko authored Mar 25, 2024
2 parents f73fff4 + c0e7c9e commit 905f3bf
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 13 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci-backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ on:
- 'v*'
pull_request:
branches:
- 'feature/**'
- 'feat/**'
- 'fix/**'
- 'main'

jobs:
build:
Expand Down
31 changes: 29 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,37 @@
env:
@cp .env.example .env

############################################################################################################
# For development; requires an Nvidia GPU
.PHONY: build
build: env
docker-compose build
docker-compose -f docker-compose.yaml build


.PHONY: up
up: env build
docker-compose up -d
docker-compose -f docker-compose.yaml up -d


.PHONY: stop
stop:
docker-compose -f docker-compose.yaml stop

############################################################################################################
# For demo; runs without GPU acceleration, so inference is slow. Might include some bugs.
.PHONY: demo
demo: env
docker-compose -f docker-compose.demo.yaml build
docker-compose -f docker-compose.demo.yaml up -d

.PHONY: demo-stop
demo-stop:
docker-compose -f docker-compose.demo.yaml stop

.PHONY: demo-logs
demo-logs:
docker-compose -f docker-compose.demo.yaml logs -f

.PHONY: demo-remove
demo-remove:
docker-compose -f docker-compose.demo.yaml down
11 changes: 4 additions & 7 deletions backend/src/repository/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@
from sentence_transformers.util import cos_sim
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline


from src.config.settings.const import CHAT_COMTEXT, DEFAULT_MODEL, MAX_SQL_LENGTH, UPLOAD_FILE_PATH
from src.repository.rag.base import BaseRAGRepository
from src.utilities.devices.devices import get_device


class RAGChatModelRepository(BaseRAGRepository):
# model = SentenceTransformer(DEFAULT_MODEL, "cuda")
# embeddings = model.encode([], convert_to_tensor=True).to("cuda")
model_name = "deepset/roberta-base-squad2"

nlp = pipeline("question-answering", model=model_name, tokenizer=model_name)

async def load_model(self, session_id: int, model_name: str) -> bool:
# Init model with input model_name
try:
model = SentenceTransformer(model_name, "cuda")
# https://github.com/UKPLab/sentence-transformers/blob/85810ead37d02ef706da39e4a1757702d1b9f7c5/sentence_transformers/SentenceTransformer.py#L47
model = SentenceTransformer(model_name, device=get_device())
model.max_seq_length = MAX_SQL_LENGTH
except Exception as e:
print(e)
Expand All @@ -29,16 +30,12 @@ async def load_model(self, session_id: int, model_name: str) -> bool:
async def get_response(self, session_id: int, input_msg: str) -> str:
# TODO use RAG framework to generate the response message @Aisuko
# query_embedding = self.model.encode(input_msg, convert_to_tensor=True).to("cuda")
# print(self.embeddings)
# print(query_embedding)
# we use cosine-similarity and torch.topk to find the highest 5 scores
# cos_scores = cos_sim(query_embedding, self.embeddings)[0]
# top_results = torch.topk(cos_scores, k=1)
# response_msg = self.data[top_results[1].item()]
QA_input = {"question": input_msg, "context": CHAT_COMTEXT}
res = self.nlp(QA_input)
print(res)
# response_msg = "Oh, really? It's amazing !"
return res["answer"]

async def load_csv_file(self, file_name: str, model_name: str) -> bool:
Expand Down
Empty file.
9 changes: 9 additions & 0 deletions backend/src/utilities/devices/devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import platform
import torch

def get_device():
    """Select the best available torch device string.

    Preference order: Apple Metal ("mps") -> NVIDIA CUDA ("cuda") -> "cpu".

    Returns:
        str: one of "mps", "cuda" or "cpu".
    """
    # Check real MPS availability instead of assuming every Darwin host
    # supports Metal (Intel Macs and older torch builds do not); this also
    # lets non-MPS macOS hosts fall through to the CUDA/CPU checks.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    elif torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
6 changes: 6 additions & 0 deletions backend/tests/unit_tests/test_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ def test_application_is_fastapi_instance() -> None:
assert backend_app.docs_url == "/docs"
assert backend_app.openapi_url == "/openapi.json"
assert backend_app.redoc_url == "/redoc"


def test_get_device() -> None:
    """The device helper must report one of the supported torch backends."""
    from src.utilities.devices.devices import get_device

    device = get_device()
    assert device in ("mps", "cuda", "cpu")
141 changes: 141 additions & 0 deletions docker-compose.demo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
services:
db:
image: postgres:latest
container_name: db
restart: always
environment:
- POSTGRES_USER=${POSTGRES_USERNAME}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_DB=${POSTGRES_DB}
- PGDATA=/var/lib/postgresql/data/
volumes:
- postgresql_db_data:/var/lib/postgresql/data/
expose:
- 5432
ports:
- 5433:5432

db_editor:
image: adminer
container_name: db_editor
restart: always
environment:
- POSTGRES_USER=${POSTGRES_USERNAME}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_HOST=${POSTGRES_HOST}
- POSTGRES_PORT=${POSTGRES_PORT}
- POSTGRES_DB=${POSTGRES_DB}
expose:
- 8080
ports:
- 8081:8080
depends_on:
- db

backend_app:
container_name: backend_app
restart: always

build:
dockerfile: Dockerfile
context: ./backend/
environment:
- ENVIRONMENT=${ENVIRONMENT}
- DEBUG=${DEBUG}
- POSTGRES_DB=${POSTGRES_DB}
- POSTGRES_HOST=${POSTGRES_HOST}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_PORT=${POSTGRES_PORT}
- POSTGRES_SCHEMA=${POSTGRES_SCHEMA}
- POSTGRES_USERNAME=${POSTGRES_USERNAME}
- BACKEND_SERVER_HOST=${BACKEND_SERVER_HOST}
- BACKEND_SERVER_PORT=${BACKEND_SERVER_PORT}
- BACKEND_SERVER_WORKERS=${BACKEND_SERVER_WORKERS}
- DB_TIMEOUT=${DB_TIMEOUT}
- DB_POOL_SIZE=${DB_POOL_SIZE}
- DB_MAX_POOL_CON=${DB_MAX_POOL_CON}
- DB_POOL_OVERFLOW=${DB_POOL_OVERFLOW}
- IS_DB_ECHO_LOG=${IS_DB_ECHO_LOG}
- IS_DB_EXPIRE_ON_COMMIT=${IS_DB_EXPIRE_ON_COMMIT}
- IS_DB_FORCE_ROLLBACK=${IS_DB_FORCE_ROLLBACK}
- IS_ALLOWED_CREDENTIALS=${IS_ALLOWED_CREDENTIALS}
- API_TOKEN=${API_TOKEN}
- AUTH_TOKEN=${AUTH_TOKEN}
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
- JWT_SUBJECT=${JWT_SUBJECT}
- JWT_TOKEN_PREFIX=${JWT_TOKEN_PREFIX}
- JWT_ALGORITHM=${JWT_ALGORITHM}
- JWT_MIN=${JWT_MIN}
- JWT_HOUR=${JWT_HOUR}
- JWT_DAY=${JWT_DAY}
- HASHING_ALGORITHM_LAYER_1=${HASHING_ALGORITHM_LAYER_1}
- HASHING_ALGORITHM_LAYER_2=${HASHING_ALGORITHM_LAYER_2}
- HASHING_SALT=${HASHING_SALT}
volumes:
- ./backend/:/usr/backend/
expose:
- 8000
ports:
- 8001:8000
depends_on:
- db
# command: ["--gpus", "all"]

etcd:
container_name: milvus-etcd
image: quay.io/coreos/etcd:v3.5.0
environment:
- ETCD_AUTO_COMPACTION_MODE=${ETCD_AUTO_COMPACTION_MODE}
- ETCD_AUTO_COMPACTION_RETENTION=${ETCD_AUTO_COMPACTION_RETENTION}
- ETCD_QUOTA_BACKEND_BYTES=${ETCD_QUOTA_BACKEND_BYTES}
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd

minio:
container_name: milvus-minio
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
command: minio server /minio_data
healthcheck:
test:
[
"CMD",
"curl",
"-f",
"http://localhost:9000/minio/health/live"
]
interval: 30s
timeout: 20s
retries: 3

standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.0.2
command: [ "milvus", "run", "standalone" ]
environment:
ETCD_ENDPOINTS: etcd:2379
MINIO_ADDRESS: minio:9000
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
ports:
- "19530:19530"
depends_on:
- "etcd"
- "minio"

frontend:
image: ghcr.io/skywardai/rebel:v0.1.1
container_name: frontend
restart: always
expose:
- 80
ports:
- 80:80

volumes:
postgresql_db_data:
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ services:
- "minio"

frontend:
image: ghcr.io/skywardai/rebel:v0.0.5
image: ghcr.io/skywardai/rebel:v0.1.1
container_name: frontend
restart: always
expose:
Expand Down

0 comments on commit 905f3bf

Please sign in to comment.