From 8fb287efe98b9ff867da9634aa24c19e7925ff5c Mon Sep 17 00:00:00 2001 From: Xinzijian Liu Date: Tue, 3 Sep 2024 20:56:43 +0800 Subject: [PATCH] Support valid data for multitask training (#257) ## Summary by CodeRabbit - **New Features** - Introduced support for multitask validation data, allowing users to specify multiple validation datasets through new arguments. - Enhanced flexibility in handling validation data, accommodating both single and multitask configurations. - Added dynamic model freezing capability based on configuration parameters. - Improved configurability by allowing external configuration data to be integrated into function execution. - **Bug Fixes** - Improved error handling for validation data inputs, ensuring robust processing of various data structures. - **Documentation** - Updated documentation to clarify the usage of new arguments related to multitask validation data. --------- Signed-off-by: zjgemi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpgen2/entrypoint/args.py | 19 +++++++++++ dpgen2/entrypoint/submit.py | 31 +++++++++++++----- dpgen2/op/run_dp_train.py | 22 ++++++++++--- dpgen2/op/run_lmp.py | 52 +++++++++++++++++------------- dpgen2/op/run_relax.py | 37 +++++++++++++++++++++ dpgen2/superop/prep_run_diffcsp.py | 1 + tests/op/test_run_relax.py | 1 + 7 files changed, 127 insertions(+), 36 deletions(-) diff --git a/dpgen2/entrypoint/args.py b/dpgen2/entrypoint/args.py index 3445e42e..645e8c6b 100644 --- a/dpgen2/entrypoint/args.py +++ b/dpgen2/entrypoint/args.py @@ -529,6 +529,11 @@ def input_args(): doc_valid_data_prefix = "The prefix of validation data systems" doc_valid_sys = "The validation data systems" doc_valid_data_uri = "The URI of validation data" + doc_multi_valid_data = ( + "The validation data for multitask, it should be a dict, whose keys are task names and each value is a dict" + "containing fields `prefix` and `sys` for initial data of each task" + ) + doc_multi_valid_data_uri = "The URI of validation data for multitask" return [ Argument("type_map", List[str], optional=False, doc=doc_type_map), @@ -607,6 +612,20 @@ def input_args(): default=None, doc=doc_valid_data_uri, ), + Argument( + "multi_valid_data", + dict, + optional=True, + default=None, + doc=doc_multi_valid_data, + ), + Argument( + "multi_valid_data_uri", + str, + optional=True, + default=None, + doc=doc_multi_valid_data_uri, + ), ] diff --git a/dpgen2/entrypoint/submit.py b/dpgen2/entrypoint/submit.py index 3f1b2825..ce414cfb 100644 --- a/dpgen2/entrypoint/submit.py +++ b/dpgen2/entrypoint/submit.py @@ -513,13 +513,28 @@ def workflow_concurrent_learning( ] upload_python_packages = _upload_python_packages - valid_data = config["inputs"]["valid_data_sys"] - if config["inputs"]["valid_data_uri"] is not None: - valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"]) - elif valid_data is not None: - valid_data_prefix = config["inputs"]["valid_data_prefix"] - valid_data = get_systems_from_data(valid_data, valid_data_prefix) - valid_data = upload_artifact_and_print_uri(valid_data, "valid_data") + multitask = config["inputs"]["multitask"] + valid_data = None + if multitask: + if config["inputs"]["multi_valid_data_uri"] is not None: + valid_data = get_artifact_from_uri(config["inputs"]["multi_valid_data_uri"]) + elif config["inputs"]["multi_valid_data"] is not None: + multi_valid_data = config["inputs"]["multi_valid_data"] + valid_data = {} + for k, v in multi_valid_data.items(): + sys = v["sys"] + sys = get_systems_from_data(sys, v.get("prefix", None)) + valid_data[k] = sys + valid_data = upload_artifact_and_print_uri(valid_data, "multi_valid_data") + else: + if config["inputs"]["valid_data_uri"] is not None: + valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"]) + elif config["inputs"]["valid_data_prefix"] is not None: + valid_data_prefix = config["inputs"]["valid_data_prefix"] + valid_data = config["inputs"]["valid_data_sys"] + valid_data = get_systems_from_data(valid_data, valid_data_prefix) + valid_data = upload_artifact_and_print_uri(valid_data, "valid_data") + concurrent_learning_op = make_concurrent_learning_op( train_style, explore_style, @@ -591,7 +606,7 @@ def workflow_concurrent_learning( init_data = upload_artifact_and_print_uri(init_data, "multi_init_data") train_config["multitask"] = True train_config["head"] = head - explore_config["head"] = head + explore_config["model_frozen_head"] = head else: if config["inputs"]["init_data_uri"] is not None: init_data = get_artifact_from_uri(config["inputs"]["init_data_uri"]) diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py index 872f60b8..dccbc518 100644 --- a/dpgen2/op/run_dp_train.py +++ b/dpgen2/op/run_dp_train.py @@ -129,7 +129,7 @@ def get_input_sign(cls): "init_model": Artifact(Path, optional=True), "init_data": Artifact(NestedDict[Path]), "iter_data": Artifact(List[Path]), - "valid_data": Artifact(List[Path], optional=True), + "valid_data": Artifact(NestedDict[Path], optional=True), "optional_files": Artifact(List[Path], optional=True), } ) @@ -182,11 +182,10 @@ def execute( finetune_mode = ip["optional_parameter"]["finetune_mode"] config = ip["config"] if ip["config"] is not None else {} impl = ip["config"].get("impl", "tensorflow") + dp_command = ip["config"].get("command", "dp").split() assert impl in ["tensorflow", "pytorch"] if impl == "pytorch": - dp_command = ["dp", "--pt"] - else: - dp_command = ["dp"] + dp_command.append("--pt") finetune_args = config.get("finetune_args", "") train_args = config.get("train_args", "") config = RunDPTrain.normalize_config(config) @@ -356,7 +355,7 @@ def write_data_to_input_script( iter_data: List[Path], auto_prob_str: str = "prob_sys_size", major_version: str = "1", - valid_data: Optional[List[Path]] = None, + valid_data: Optional[Union[List[Path], Dict[str, List[Path]]]] = None, ): odict = idict.copy() if config["multitask"]: @@ -368,6 +367,11 @@ def write_data_to_input_script( if k == head: v["training_data"]["systems"] += [str(ii) for ii in iter_data] v["training_data"]["auto_prob"] = auto_prob_str + if valid_data is None: + v.pop("validation_data", None) + else: + v["validation_data"] = v.get("validation_data", {"batch_size": 1}) + v["validation_data"]["systems"] = [str(ii) for ii in valid_data[k]] return odict data_list = [str(ii) for ii in init_data] + [str(ii) for ii in iter_data] if major_version == "1": @@ -490,6 +494,7 @@ def decide_init_model( @staticmethod def training_args(): + doc_command = "The command for DP, 'dp' for default" doc_impl = "The implementation/backend of DP. It can be 'tensorflow' or 'pytorch'. 'tensorflow' for default." doc_init_model_policy = "The policy of init-model training. It can be\n\n\ - 'no': No init-model training. Traing from scratch.\n\n\ @@ -513,6 +518,13 @@ def training_args(): doc_init_model_with_finetune = "Use finetune for init model" doc_train_args = "Extra arguments for dp train" return [ + Argument( + "command", + str, + optional=True, + default="dp", + doc=doc_command, + ), Argument( "impl", str, diff --git a/dpgen2/op/run_lmp.py b/dpgen2/op/run_lmp.py index 997407d8..2f60631d 100644 --- a/dpgen2/op/run_lmp.py +++ b/dpgen2/op/run_lmp.py @@ -150,28 +150,7 @@ def execute( elif ext == ".pt": # freeze model mname = pytorch_model_name_pattern % (idx) - freeze_args = "-o %s" % mname - if config.get("head") is not None: - freeze_args += " --head %s" % config["head"] - freeze_cmd = "dp --pt freeze -c %s %s" % (mm, freeze_args) - ret, out, err = run_command(freeze_cmd, shell=True) - if ret != 0: - logging.error( - "".join( - ( - "freeze failed\n", - "command was", - freeze_cmd, - "out msg", - out, - "\n", - "err msg", - err, - "\n", - ) - ) - ) - raise TransientError("freeze failed") + freeze_model(mm, mname, config.get("model_frozen_head")) else: raise RuntimeError( "Model file with extension '%s' is not supported" % ext @@ -240,7 +219,9 @@ def lmp_args(): default=False, doc=doc_shuffle_models, ), - Argument("head", str, optional=True, default=None, doc=doc_head), + Argument( + "model_frozen_head", str, optional=True, default=None, doc=doc_head + ), ] @staticmethod @@ -310,3 +291,28 @@ def find_only_one_key(lmp_lines, key, raise_not_found=True): else: return None return found[0] + + +def freeze_model(input_model, frozen_model, head=None): + freeze_args = "-o %s" % frozen_model + if head is not None: + freeze_args += " --head %s" % head + freeze_cmd = "dp --pt freeze -c %s %s" % (input_model, freeze_args) + ret, out, err = run_command(freeze_cmd, shell=True) + if ret != 0: + logging.error( + "".join( + ( + "freeze failed\n", + "command was", + freeze_cmd, + "out msg", + out, + "\n", + "err msg", + err, + "\n", + ) + ) + ) + raise TransientError("freeze failed") diff --git a/dpgen2/op/run_relax.py b/dpgen2/op/run_relax.py index 110ad4d8..8876eb14 100644 --- a/dpgen2/op/run_relax.py +++ b/dpgen2/op/run_relax.py @@ -1,3 +1,4 @@ +import logging import os from pathlib import ( Path, @@ -6,6 +7,9 @@ List, ) +from dargs import ( + Argument, +) from dflow.python import ( OP, OPIO, @@ -14,6 +18,9 @@ OPIOSign, ) +from dpgen2.constants import ( + pytorch_model_name_pattern, +) from dpgen2.exploration.task import ( DiffCSPTaskGroup, ) @@ -21,6 +28,9 @@ from .run_caly_model_devi import ( atoms2lmpdump, ) +from .run_lmp import ( + freeze_model, +) class RunRelax(OP): @@ -29,6 +39,7 @@ def get_input_sign(cls): return OPIOSign( { "diffcsp_task_grp": BigParameter(DiffCSPTaskGroup), + "expl_config": dict, "task_path": Artifact(Path), "models": Artifact(List[Path]), } @@ -73,6 +84,15 @@ def execute( task_group = ip["diffcsp_task_grp"] task = next(iter(task_group)) # Only support single task models = ip["models"] + config = ip["expl_config"] + config = RunRelax.normalize_config(config) + if config["model_frozen_head"] is not None: + frozen_models = [] + for idx in range(len(models)): + mname = pytorch_model_name_pattern % (idx) + freeze_model(models[idx], mname, config["model_frozen_head"]) + frozen_models.append(Path(mname)) + models = frozen_models relaxer = Relaxer(models[0]) type_map = relaxer.calculator.dp.get_type_map() fmax = task.fmax @@ -178,3 +198,20 @@ def execute( "model_devis": model_devis, } ) + + @staticmethod + def relax_args(): + doc_head = "Select a head from multitask" + return [ + Argument( + "model_frozen_head", str, optional=True, default=None, doc=doc_head + ), + ] + + @staticmethod + def normalize_config(data={}): + ta = RunRelax.relax_args() + base = Argument("base", dict, ta) + data = base.normalize_value(data, trim_pattern="_*") + base.check_value(data, strict=False) + return data diff --git a/dpgen2/superop/prep_run_diffcsp.py b/dpgen2/superop/prep_run_diffcsp.py index c4d3bd87..c44851fe 100644 --- a/dpgen2/superop/prep_run_diffcsp.py +++ b/dpgen2/superop/prep_run_diffcsp.py @@ -196,6 +196,7 @@ def _prep_run_diffcsp( ), parameters={ "diffcsp_task_grp": expl_task_grp, + "expl_config": expl_config, }, artifacts={ "models": models, diff --git a/tests/op/test_run_relax.py b/tests/op/test_run_relax.py index 847056b6..9f96fe45 100644 --- a/tests/op/test_run_relax.py +++ b/tests/op/test_run_relax.py @@ -276,6 +276,7 @@ def testRunRelax(self, mocked_run): op_in = OPIO( { "diffcsp_task_grp": task_group, + "expl_config": {}, "task_path": Path("task.000000"), "models": [Path("model_0.pt"), Path("model_1.pt")], }