From 4e364ec959f5ff88ee14fa86fa7dba52b9cb35f9 Mon Sep 17 00:00:00 2001 From: zhenyu wang Date: Fri, 31 May 2024 20:18:39 +0800 Subject: [PATCH] fix: expl_config as InputParameter --- dpgen2/op/prep_caly_model_devi.py | 8 ++++---- tests/op/test_prep_caly_model_devi.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dpgen2/op/prep_caly_model_devi.py b/dpgen2/op/prep_caly_model_devi.py index da479402..134fd848 100644 --- a/dpgen2/op/prep_caly_model_devi.py +++ b/dpgen2/op/prep_caly_model_devi.py @@ -48,7 +48,7 @@ def get_input_sign(cls): return OPIOSign( { "task_name": Parameter(str), - "model_devi_group_size": Parameter(int), + "expl_config": BigParameter(dict), "traj_results": Artifact(List[Path]), } ) @@ -74,7 +74,7 @@ def execute( ip : dict Input dict with components: - `task_name` : (`str`) - - `model_devi_group_size` : (`int`) + - `expl_config` : (`BigParameter(dict)`) - `traj_results` : (`Path`) Returns @@ -93,8 +93,8 @@ def execute( for traj_dir in traj_results_dir for traj in Path(traj_dir).rglob("*.traj") ] - model_devi_group_size = ip["model_devi_group_size"] - group_size = model_devi_group_size if model_devi_group_size != 0 else len(trajs) + expl_config = ip["expl_config"] + group_size = expl_config.get("model_devi_group_size", len(trajs)) with set_directory(work_dir): grouped_trajs_list = [ diff --git a/tests/op/test_prep_caly_model_devi.py b/tests/op/test_prep_caly_model_devi.py index 179db057..2a3587ce 100644 --- a/tests/op/test_prep_caly_model_devi.py +++ b/tests/op/test_prep_caly_model_devi.py @@ -59,7 +59,6 @@ def setUp(self): self.group_size = 5 self.ngroup = ntrajs_dir * ntrajs_per_dir / self.group_size - self.model_devi_group_size_1 = self.group_size self.model_devi_group_size_2 = 0 def tearDown(self): @@ -67,12 +66,13 @@ def tearDown(self): shutil.rmtree(self.run_dir_name) def test_00_success(self): + explore_config = {"model_devi_group_size": self.group_size} op = PrepCalyModelDevi() out = op.execute( OPIO( { "task_name": self.run_dir_name, - "model_devi_group_size": self.model_devi_group_size_1, + "expl_config": explore_config, "traj_results": self.ref_traj_results, } ) @@ -94,12 +94,13 @@ def test_00_success(self): # self.assertTrue(Path("run_dir/trajs_part_0/0.0.0.traj") in traj_list) def test_01_success(self): + explore_config = {} op = PrepCalyModelDevi() out = op.execute( OPIO( { "task_name": self.run_dir_name, - "model_devi_group_size": self.model_devi_group_size_2, + "expl_config": explore_config, "traj_results": self.ref_traj_results, } )