-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor model server hardware config + add unit tests to load/reques…
…t to the server (#189) * remove mode/hardware * add test and pre commit hook * add pytest dependencies * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
- Loading branch information
Showing
13 changed files
with
480 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# CI workflow: installs the model server's Python dependencies and runs its
# pytest suite on pushes and pull requests that target `main`.
name: Run Model Server tests

on:
  push:
    branches:
      - main  # Run tests on pushes to the main branch
  pull_request:
    branches:
      - main  # Run tests on pull requests to the main branch

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      # Step 1: Check out the code from your repository
      - name: Checkout code
        uses: actions/checkout@v3

      # Step 2: Set up Python (specify the version)
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"  # Adjust to your Python version

      # Step 3: Install dependencies (from requirements.txt or Pipfile)
      - name: Install dependencies
        working-directory: model_server
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt  # Or use pipenv install
          pip install pytest

      # Step 4: Set PYTHONPATH and run tests
      - name: Run model server tests with pytest
        working-directory: model_server
        run: |
          PYTHONPATH=. pytest --maxfail=5 --disable-warnings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,5 +2,3 @@ | |
|
||
|
||
# Value returned by utils.get_device(), evaluated once when this module is imported.
DEVICE = utils.get_device() | ||
# NOTE(review): the hunk header (-2,5 +2,3) drops two lines and the commit title says
# "remove mode/hardware", so MODE and HARDWARE below are presumably the deleted lines.
# The +/- markers are not visible here — confirm against the raw diff.
MODE = utils.get_serving_mode() | ||
HARDWARE = utils.get_hardware(MODE) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import pytest | ||
import httpx | ||
from fastapi.testclient import TestClient | ||
from app.main import app # Assuming your FastAPI app is in main.py | ||
from unittest.mock import patch | ||
import app.commons.globals as glb | ||
import logging | ||
|
||
# Module-level logger named after this test module.
logger = logging.getLogger(__name__) | ||
|
||
# Synchronous test client bound to the FastAPI app; shared by the tests in this module.
client = TestClient(app) | ||
|
||
# Emitted once, when this module is imported, with the DEVICE value held by the globals module.
logger.info(f"Model will be loaded on device: {glb.DEVICE}") | ||
|
||
|
||
# Unit test for the health check endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_healthz():
    """GET /healthz answers 200 with the ``{"status": "ok"}`` payload."""
    # TestClient is synchronous, so the test needs neither a coroutine nor
    # the pytest-asyncio marker.
    response = client.get("/healthz")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
|
||
|
||
# Unit test for the models endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_models():
    """GET /models answers 200 with a non-empty ``list`` object."""
    # TestClient is synchronous, so no asyncio marker is required.
    response = client.get("/models")
    assert response.status_code == 200

    payload = response.json()
    assert payload["object"] == "list"
    assert len(payload["data"]) > 0
|
||
|
||
# Unit test for the embeddings endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_embedding():
    """POST /embeddings with the supported model answers 200 with a ``list`` object."""
    request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
    response = client.post("/embeddings", json=request_data)

    # The model name is fixed above, so only the success path is reachable;
    # the former `else: assert 400` branch could never execute and was removed.
    assert response.status_code == 200
    payload = response.json()
    assert payload["object"] == "list"
    assert "data" in payload
|
||
|
||
# Unit test for the guard endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_guard():
    """POST /guard for the jailbreak task answers 200 and reports a jailbreak verdict."""
    # TestClient is synchronous, so no asyncio marker is required.
    request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
    response = client.post("/guard", json=request_data)
    assert response.status_code == 200
    assert "jailbreak_verdict" in response.json()
|
||
|
||
# Unit test for the zero-shot endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_zeroshot():
    """POST /zeroshot with the supported model answers 200 with a predicted class."""
    request_data = {
        "input": "Test input",
        "labels": ["label1", "label2"],
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/zeroshot", json=request_data)

    # The model name is fixed above, so only the success path is reachable;
    # the former `else: assert 400` branch could never execute and was removed.
    assert response.status_code == 200
    assert "predicted_class" in response.json()
|
||
|
||
# Unit test for the hallucination endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@patch("app.loader.glb.DEVICE", glb.DEVICE)
def test_hallucination():
    """POST /hallucination with the supported model answers 200 with parameter scores."""
    request_data = {
        "prompt": "Test hallucination",
        "parameters": {"param1": "value1"},
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/hallucination", json=request_data)

    # The model name is fixed above, so only the success path is reachable;
    # the former `else: assert 400` branch could never execute and was removed.
    assert response.status_code == 200
    assert "params_scores" in response.json()
|
||
|
||
# Unit test for the chat completion endpoint.
# The patch re-applies the DEVICE value captured when this module was imported;
# it does not force the device to "cpu".
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_chat_completion():
    """POST /v1/chat/completions answers 200 and returns a ``choices`` field."""
    request_data = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "Arch-Function-1.5B",
        "tools": [],  # Assuming tools is part of the req as per the function
        "metadata": {"x-arch-state": "[]"},  # Assuming metadata is needed
    }

    # Route requests straight into the ASGI app through an explicit transport;
    # the `AsyncClient(app=...)` shortcut is deprecated in httpx. The client
    # gets its own name so it does not shadow the module-level TestClient.
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as async_client:
        response = await async_client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "choices" in response.json()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
import app.commons.globals as glb | ||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard | ||
|
||
# NOTE: the DEVICE assignment below mutates the shared globals module when this test
# module is imported, so any code reading glb.DEVICE afterwards in the same process
# sees "cpu".
# Mock constants | ||
glb.DEVICE = "cpu" # Adjust as needed for your test case | ||
# Maps a device name to the Arch-Guard model name that test_get_prompt_guard passes
# to get_prompt_guard.
arch_guard_model_type = { | ||
"cpu": "katanemo/Arch-Guard-cpu", | ||
"cuda": "katanemo/Arch-Guard", | ||
"mps": "katanemo/Arch-Guard", | ||
} | ||
|
||
|
||
@pytest.fixture
def mock_env(monkeypatch):
    """Set the MODELS and ZERO_SHOT_MODELS environment variables for one test.

    The values are applied through pytest's ``monkeypatch`` fixture, which
    restores the previous environment on teardown, so they no longer leak
    into tests that run afterwards.
    """
    monkeypatch.setenv("MODELS", "katanemo/bge-large-en-v1.5")
    monkeypatch.setenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli")
|
||
|
||
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
    """get_embedding_model reports the configured model and loads it via the patched factories.

    Decorators apply bottom-up, so the mock arguments arrive in the order
    tokenizer, AutoModel, ORT model; ``mock_env`` is the pytest fixture.
    """
    embedding_model = get_embedding_model()

    assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"

    # Use the real Mock assertion helpers: `assert mock.called_once_with(...)`
    # only reads an auto-created attribute, which is always truthy.
    # NOTE(review): the expected arguments were never actually checked before;
    # confirm them against app.loader if one of these assertions fails.
    mock_tokenizer.assert_called_once_with(
        "katanemo/bge-large-en-v1.5", trust_remote_code=True
    )
    if glb.DEVICE != "cuda":
        mock_ort_model.assert_called_once_with(
            "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
        )
    else:
        mock_automodel.assert_called_once_with(
            "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
        )
|
||
|
||
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
    """get_zero_shot_model reports the configured model and loads it via the patched factories.

    Decorators apply bottom-up, so the mock arguments arrive in the order
    tokenizer, pipeline, ORT model; ``mock_env`` is the pytest fixture.
    """
    zero_shot_model = get_zero_shot_model()

    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"

    # Use the real Mock assertion helpers: `assert mock.called_once_with(...)`
    # only reads an auto-created attribute, which is always truthy.
    # NOTE(review): the expected arguments were never actually checked before;
    # confirm them against app.loader if one of these assertions fails.
    mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
    if glb.DEVICE != "cuda":
        mock_ort_model.assert_called_once_with(
            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
        )
    else:
        mock_pipeline.assert_called_once()
|
||
|
||
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
    """get_prompt_guard reports the requested model and loads it via the device-specific factory.

    Decorators apply bottom-up, so the first argument is the
    AutoModelForSequenceClassification patch, the second the
    OVModelForSequenceClassification patch, and the third the tokenizer patch.
    The parameter names follow that order (they were previously swapped).
    """
    model_name = arch_guard_model_type[glb.DEVICE]

    prompt_guard = get_prompt_guard(model_name)

    assert prompt_guard["model_name"] == model_name

    # Use the real Mock assertion helpers: `assert mock.called_once_with(...)`
    # only reads an auto-created attribute, which is always truthy.
    # NOTE(review): the expected arguments were never actually checked before;
    # confirm them against app.loader if one of these assertions fails.
    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)

    # The CPU path is expected to go through the OpenVINO loader; every other
    # device goes through the regular transformers loader.
    expected_loader = mock_ov_model if glb.DEVICE == "cpu" else mock_auto_model
    expected_loader.assert_called_once_with(
        model_name,
        device_map=glb.DEVICE,
        low_cpu_mem_usage=True,
    )
Oops, something went wrong.