Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recipe check vllm e2e #929

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
51 changes: 50 additions & 1 deletion tests/e2e/vLLM/test_vllm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import shutil
from pathlib import Path
from typing import Callable
Expand All @@ -23,6 +24,13 @@
HF_MODEL_HUB_NAME = "nm-testing"
TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "")

# Files that every compressed checkpoint must contain after save_pretrained.
# Entries beginning with "^" are regular expressions; the safetensors weight
# file may be sharded (e.g. "model-00001-of-00002.safetensors"), so it is
# matched with a pattern rather than an exact name.
EXPECTED_SAVED_FILES = [
    "config.json",
    r"^model(?:-\d{5}-of-\d{5})?\.safetensors$",
    "recipe.yaml",
    "tokenizer.json",
]


@pytest.fixture
def record_config_file(record_testsuite_property: Callable[[str, object], None]):
Expand All @@ -33,7 +41,7 @@ def record_config_file(record_testsuite_property: Callable[[str, object], None])
# Will run each test case in its own process through run_tests.sh
# emulating vLLM CI testing
@requires_gpu_count(1)
@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
# @pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
horheynm marked this conversation as resolved.
Show resolved Hide resolved
class TestvLLM:
"""
The following test quantizes a model using a preset scheme or recipe,
Expand Down Expand Up @@ -100,11 +108,20 @@ def test_vllm(self):
quant_type=self.quant_type,
)

# check that session contains recipe
self._check_session_contains_recipe()

logger.info("================= SAVING TO DISK ======================")
oneshot_model.save_pretrained(self.save_dir)
tokenizer.save_pretrained(self.save_dir)
recipe_path = os.path.join(self.save_dir, "recipe.yaml")

# check that expected files exist
self._check_save_dir_has_expected_files()

# Reset after session info is extracted on save -- recipe
self.session.reset()

# Use the session to fetch the recipe;
# Reset session for next test case
session = active_session()
Expand Down Expand Up @@ -146,3 +163,35 @@ def test_vllm(self):
def tear_down(self):
if self.save_dir is not None:
shutil.rmtree(self.save_dir)

def _check_session_contains_recipe(self) -> None:
horheynm marked this conversation as resolved.
Show resolved Hide resolved
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
assert recipe_yaml_str is not None

def _check_save_dir_has_expected_files(self):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
files = os.listdir(self.save_dir)
logger.debug("Saved files: ", files)

matched_patterns = set()

for expected in EXPECTED_SAVED_FILES:
# Find all files matching the expected pattern
matches = [
file
for file in files
if (
re.fullmatch(expected, file)
if expected.startswith("^")
else file == expected
)
]
if matches is not None:
matched_patterns.add(expected)

assert len(matched_patterns) == len(EXPECTED_SAVED_FILES), (
"expected: ",
EXPECTED_SAVED_FILES,
"\n saved: ",
list(matched_patterns),
)
Loading