From b309304c0f99347b1939a09c0924e06f7d732d4e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 19 Jul 2024 10:43:07 -0400 Subject: [PATCH] fix: fix PyTorch model extension in simplify (#1596) ## Summary by CodeRabbit - **Bug Fixes** - Improved model file search logic to dynamically construct file patterns, ensuring better compatibility with different model suffixes. - Added an assertion to ensure at least one model file is found, providing clearer error messaging when no models are present. - **Tests** - Enhanced test setup by generating fake model files, improving robustness and reliability of tests. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> --- dpgen/simplify/simplify.py | 4 +++- tests/simplify/test_run_model_devi.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 24205fda3..30b3472ac 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -221,7 +221,9 @@ def run_model_devi(iter_index, jdata, mdata): commands = [] run_tasks = ["."] # get models - models = glob.glob(os.path.join(work_path, "graph*pb")) + suffix = _get_model_suffix(jdata) + models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) + assert len(models) > 0, "No model file found." model_names = [os.path.basename(ii) for ii in models] task_model_list = [] for ii in model_names: diff --git a/tests/simplify/test_run_model_devi.py b/tests/simplify/test_run_model_devi.py index e928afa8e..28d5732e5 100644 --- a/tests/simplify/test_run_model_devi.py +++ b/tests/simplify/test_run_model_devi.py @@ -17,6 +17,9 @@ class TestOneH5(unittest.TestCase): def setUp(self): work_path = Path("iter.000000") / "01.model_devi" work_path.mkdir(parents=True, exist_ok=True) + # fake models + for ii in range(4): + (work_path / f"graph.{ii:03d}.pb").touch() with tempfile.TemporaryDirectory() as tmpdir: with open(Path(tmpdir) / "test.xyz", "w") as f: f.write(