diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py index b31bfb00..3d96fcbe 100644 --- a/tests/e2e/vLLM/test_vllm.py +++ b/tests/e2e/vLLM/test_vllm.py @@ -1,4 +1,5 @@ import os +import re import shutil from pathlib import Path from typing import Callable @@ -23,6 +24,13 @@ HF_MODEL_HUB_NAME = "nm-testing" TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "") +EXPECTED_SAVED_FILES = [ + "config.json", + r"^model(?:-\d{5}-of-\d{5})?\.safetensors$", + "recipe.yaml", + "tokenizer.json", +] + @pytest.fixture def record_config_file(record_testsuite_property: Callable[[str, object], None]): @@ -100,11 +108,17 @@ def test_vllm(self): quant_type=self.quant_type, ) + # check that session contains recipe + self._check_session_contains_recipe() + logger.info("================= SAVING TO DISK ======================") oneshot_model.save_pretrained(self.save_dir) tokenizer.save_pretrained(self.save_dir) recipe_path = os.path.join(self.save_dir, "recipe.yaml") + # check that expected files exist + self._check_save_dir_has_expected_files() + # Use the session to fetch the recipe; # Reset session for next test case session = active_session() @@ -146,3 +160,35 @@ def test_vllm(self): def tear_down(self): if self.save_dir is not None: shutil.rmtree(self.save_dir) + + def _check_session_contains_recipe(self) -> None: + session = active_session() + recipe_yaml_str = session.get_serialized_recipe() + assert recipe_yaml_str is not None + + def _check_save_dir_has_expected_files(self): + files = os.listdir(self.save_dir) + logger.debug("Saved files: ", files) + + matched_patterns = set() + + for expected in EXPECTED_SAVED_FILES: + # Find all files matching the expected pattern + matches = [ + file + for file in files + if ( + re.fullmatch(expected, file) + if expected.startswith("^") + else file == expected + ) + ] + if matches is not None: + matched_patterns.add(expected) + + assert len(matched_patterns) == len(EXPECTED_SAVED_FILES), ( + "expected: ", + EXPECTED_SAVED_FILES, + "\n saved: ", + list(matched_patterns), + )