
feat: support the deepmd-kit v3 (deepmodeling#207)
Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
zjgemi and pre-commit-ci[bot] authored Mar 30, 2024
1 parent e32dfa2 commit f2e1d59
Showing 14 changed files with 1,230 additions and 109 deletions.
3 changes: 2 additions & 1 deletion dpgen2/constants.py
@@ -3,7 +3,8 @@
train_script_name = "input.json"
train_log_name = "train.log"
model_name_pattern = "model.%03d.pb"
-model_name_match_pattern = r"model\.[0-9]{3,}\.pb"
+pytorch_model_name_pattern = "model.%03d.pth"
+model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"
lmp_index_pattern = "%06d"
lmp_task_pattern = "task." + lmp_index_pattern
lmp_conf_name = "conf.lmp"
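
Not part of the commit: a minimal sketch showing how the widened match pattern accepts both TensorFlow (.pb) and PyTorch (.pth) model names, using only the standard-library re module.

import re

model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"  # as defined above

# "model.000.pb" and "model.001.pth" both match; "model.01.pb" does not,
# because the pattern still requires at least three digits.
for name in ("model.000.pb", "model.001.pth", "model.01.pb"):
    print(name, bool(re.fullmatch(model_name_match_pattern, name)))
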
45 changes: 44 additions & 1 deletion dpgen2/entrypoint/args.py
@@ -270,6 +270,14 @@ def input_args():
doc_do_finetune = textwrap.dedent(doc_do_finetune)
doc_init_data_prefix = "The prefix of initial data systems"
doc_init_sys = "The inital data systems"
doc_multitask = "Do multitask training"
doc_head = "Head to use in the multitask training"
doc_multi_init_data = (
"The inital data for multitask, it should be a dict, whose keys are task names and each value is a dict"
"containing fields `prefix` and `sys` for initial data of each task"
)
doc_valid_data_prefix = "The prefix of validation data systems"
doc_valid_sys = "The validation data systems"

return [
Argument("type_map", List[str], optional=False, doc=doc_type_map),
@@ -288,10 +296,45 @@
Argument(
"init_data_sys",
[List[str], str],
-            optional=False,
+            optional=True,
+            default=None,
doc=doc_init_sys,
),
Argument(
"multitask",
bool,
optional=True,
default=False,
doc=doc_multitask,
),
Argument(
"head",
str,
optional=True,
default=None,
doc=doc_head,
),
Argument(
"multi_init_data",
dict,
optional=True,
default=None,
doc=doc_multi_init_data,
),
Argument(
"valid_data_prefix",
str,
optional=True,
default=None,
doc=doc_valid_data_prefix,
),
Argument(
"valid_data_sys",
[List[str], str],
optional=True,
default=None,
doc=doc_valid_sys,
),
]


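Not part of the commit: a hypothetical "inputs" section exercising the new arguments above. All task names and paths are made up, and other required fields keep the existing schema; the per-task dicts follow the `prefix`/`sys` layout described by doc_multi_init_data.

inputs = {
    "type_map": ["O", "H"],
    "multitask": True,
    "head": "water",  # head used for training and exploration
    "multi_init_data": {
        # one entry per task, each with `prefix` and `sys`
        "water": {"prefix": "init/water", "sys": ["data.000", "data.001"]},
        "ice": {"prefix": "init/ice", "sys": ["data.000"]},
    },
    "valid_data_prefix": "valid",
    "valid_data_sys": ["val.000", "val.001"],
}
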
60 changes: 51 additions & 9 deletions dpgen2/entrypoint/submit.py
@@ -145,6 +145,7 @@ def make_concurrent_learning_op(
collect_data_config: dict = default_config,
cl_step_config: dict = default_config,
upload_python_packages: Optional[List[os.PathLike]] = None,
valid_data: Optional[S3Artifact] = None,
):
if train_style in ("dp", "dp-dist"):
prep_run_train_op = PrepRunDPTrain(
@@ -154,6 +155,7 @@
prep_config=prep_train_config,
run_config=run_train_config,
upload_python_packages=upload_python_packages,
valid_data=valid_data,
)
else:
raise RuntimeError(f"unknown train_style {train_style}")
@@ -387,6 +389,7 @@ def make_finetune_step(
init_models,
init_data,
iter_data,
valid_data=None,
):
finetune_optional_parameter = {
"mixed_type": config["inputs"]["mixed_type"],
@@ -401,6 +404,7 @@
run_config=run_train_config,
upload_python_packages=upload_python_packages,
finetune=True,
valid_data=valid_data,
)
finetune_step = Step(
"finetune-step",
@@ -466,6 +470,15 @@
]
upload_python_packages = _upload_python_packages

valid_data = config["inputs"]["valid_data_sys"]
if valid_data is not None:
valid_data_prefix = config["inputs"]["valid_data_prefix"]
valid_data = [valid_data] if isinstance(valid_data, str) else valid_data
assert isinstance(valid_data, list)
if valid_data_prefix is not None:
valid_data = [os.path.join(valid_data_prefix, ii) for ii in valid_data]
valid_data = [expand_sys_str(ii) for ii in valid_data]
valid_data = upload_artifact(valid_data)
concurrent_learning_op = make_concurrent_learning_op(
train_style,
explore_style,
@@ -480,6 +493,7 @@
collect_data_config=collect_data_config,
cl_step_config=cl_step_config,
upload_python_packages=upload_python_packages,
valid_data=valid_data,
)
scheduler = make_naive_exploration_scheduler(config)

@@ -500,7 +514,7 @@
explore_config["teacher_model_path"]
), f"No such file: {explore_config['teacher_model_path']}"
explore_config["teacher_model_path"] = BinaryFileInput(
-            explore_config["teacher_model_path"], "pb"
+            explore_config["teacher_model_path"]
)

fp_config = {}
@@ -517,15 +531,37 @@
fp_config["run"]["teacher_model_path"]
), f"No such file: {fp_config['run']['teacher_model_path']}"
fp_config["run"]["teacher_model_path"] = BinaryFileInput(
-            fp_config["run"]["teacher_model_path"], "pb"
+            fp_config["run"]["teacher_model_path"]
)

-    init_data_prefix = config["inputs"]["init_data_prefix"]
-    init_data = config["inputs"]["init_data_sys"]
-    if init_data_prefix is not None:
-        init_data = [os.path.join(init_data_prefix, ii) for ii in init_data]
-    if isinstance(init_data, str):
-        init_data = expand_sys_str(init_data)
+    multitask = config["inputs"]["multitask"]
+    if multitask:
+        head = config["inputs"]["head"]
+        multi_init_data = config["inputs"]["multi_init_data"]
+        init_data = []
+        multi_init_data_idx = {}
+        for k, v in multi_init_data.items():
+            sys = v["sys"]
+            sys = [sys] if isinstance(sys, str) else sys
+            assert isinstance(sys, list)
+            if v["prefix"] is not None:
+                sys = [os.path.join(v["prefix"], ii) for ii in sys]
+            sys = [expand_sys_str(ii) for ii in sys]
+            istart = len(init_data)
+            init_data += sys
+            iend = len(init_data)
+            multi_init_data_idx[k] = list(range(istart, iend))
+        train_config["multitask"] = True
+        train_config["head"] = head
+        train_config["multi_init_data_idx"] = multi_init_data_idx
+        explore_config["head"] = head
+    else:
+        init_data_prefix = config["inputs"]["init_data_prefix"]
+        init_data = config["inputs"]["init_data_sys"]
+        if init_data_prefix is not None:
+            init_data = [os.path.join(init_data_prefix, ii) for ii in init_data]
+        if isinstance(init_data, str):
+            init_data = expand_sys_str(init_data)
init_data = upload_artifact(init_data)
iter_data = upload_artifact([])
if init_models_paths is not None:
@@ -550,6 +586,7 @@
init_models,
init_data,
iter_data,
valid_data=valid_data,
)

init_models = finetune_step.outputs.artifacts["models"]
@@ -734,7 +771,10 @@ def print_list_steps(

def successful_step_keys(wf):
all_step_keys = []
-    for step in wf.query_step():
+    steps = wf.query_step()
+    # For reused steps whose startedAt are identical, sort them by key
+    steps.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
+    for step in steps:
if step.key is not None and step.phase == "Succeeded":
all_step_keys.append(step.key)
return all_step_keys
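
Not part of the commit: a toy illustration of why sorting on the concatenated "startedAt-key" string gives a deterministic order when reused steps share the same start time. Plain dicts stand in for the workflow step objects used above.

steps = [
    {"startedAt": "2024-03-30T10:00:00Z", "key": "iter-000001--prep-run-train"},
    {"startedAt": "2024-03-30T10:00:00Z", "key": "iter-000000--prep-run-train"},
]
steps.sort(key=lambda x: "%s-%s" % (x["startedAt"], x["key"]))
# ties on startedAt are broken by key, so iter-000000 now comes first
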
@@ -868,6 +908,8 @@ def resubmit_concurrent_learning(
reused_folded_keys[k] = [k]
reused_keys = sum(reused_folded_keys.values(), [])
reuse_step = old_wf.query_step(key=reused_keys)
# For reused steps whose startedAt are identical, sort them by key
reuse_step.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))

wf = submit_concurrent_learning(
wf_config,
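
Not part of the commit: a stripped-down sketch of the bookkeeping performed by the multitask branch above, flattening per-task systems into one init_data list while recording which indices belong to which task. expand_sys_str and upload_artifact are omitted, and the function name is illustrative.

import os
from typing import Dict, List, Tuple

def index_multi_init_data(multi_init_data: Dict[str, dict]) -> Tuple[List[str], Dict[str, List[int]]]:
    init_data: List[str] = []
    multi_init_data_idx: Dict[str, List[int]] = {}
    for task, v in multi_init_data.items():
        # `sys` may be a single path or a list of paths
        sys = [v["sys"]] if isinstance(v["sys"], str) else list(v["sys"])
        if v.get("prefix") is not None:
            sys = [os.path.join(v["prefix"], ii) for ii in sys]
        istart = len(init_data)
        init_data += sys
        multi_init_data_idx[task] = list(range(istart, len(init_data)))
    return init_data, multi_init_data_idx

data, idx = index_multi_init_data({
    "water": {"prefix": "init/water", "sys": ["d0", "d1"]},
    "ice": {"prefix": None, "sys": "d0"},
})
# data == ["init/water/d0", "init/water/d1", "d0"]
# idx  == {"water": [0, 1], "ice": [2]}
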
5 changes: 2 additions & 3 deletions dpgen2/fp/deepmd.py
@@ -45,9 +45,6 @@
# global static variables
deepmd_temp_path = "one_frame_temp"

-# global static variables
-deepmd_teacher_model = "teacher_model.pb"
-

class DeepmdInputs:
@staticmethod
@@ -136,6 +133,8 @@ def run_task(
def _get_dp_model(self, teacher_model_path: BinaryFileInput):
from deepmd.infer import DeepPot # type: ignore

ext = os.path.splitext(teacher_model_path.file_name)[-1]
deepmd_teacher_model = "teacher_model" + ext
teacher_model_path.save_as_file(deepmd_teacher_model)
dp = DeepPot(Path(deepmd_teacher_model))

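
Not part of the commit: the extension handling above in isolation, so a PyTorch teacher model keeps its .pth suffix instead of the previously hard-coded teacher_model.pb. The function name is illustrative.

import os

def teacher_model_filename(uploaded_name: str) -> str:
    ext = os.path.splitext(uploaded_name)[-1]  # ".pb" or ".pth"
    return "teacher_model" + ext

print(teacher_model_filename("frozen_model.pb"))  # teacher_model.pb
print(teacher_model_filename("model.ckpt.pth"))   # teacher_model.pth
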
19 changes: 14 additions & 5 deletions dpgen2/op/prep_dp_train.py
@@ -107,18 +107,27 @@ def execute(
)
return op

def _set_desc_seed(self, desc):
if desc["type"] == "hybrid":
for desc in desc["list"]:
self._set_desc_seed(desc)
elif desc["type"] not in ["dpa1", "dpa2"]:
desc["seed"] = random.randrange(sys.maxsize) % (2**32)

def _script_rand_seed(
self,
input_dict,
):
jtmp = input_dict.copy()
-        if jtmp["model"]["descriptor"]["type"] == "hybrid":
-            for desc in jtmp["model"]["descriptor"]["list"]:
-                desc["seed"] = random.randrange(sys.maxsize) % (2**32)
+        if "model_dict" in jtmp["model"]:
+            for d in jtmp["model"]["model_dict"].values():
+                if isinstance(d["descriptor"], str):
+                    self._set_desc_seed(jtmp["model"]["shared_dict"][d["descriptor"]])
+                d["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
        else:
-            jtmp["model"]["descriptor"]["seed"] = random.randrange(sys.maxsize) % (
+            self._set_desc_seed(jtmp["model"]["descriptor"])
+            jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (
                2**32
            )
-        jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
jtmp["training"]["seed"] = random.randrange(sys.maxsize) % (2**32)
return jtmp
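
Not part of the commit: a self-contained sketch of the new seeding logic applied to a toy multitask training script. The shared_dict/model_dict structure here is a simplified stand-in for a real deepmd-kit v3 input, not an exact schema.

import random
import sys

def set_desc_seed(desc: dict) -> None:
    # hybrid descriptors recurse into their sub-descriptors; dpa1/dpa2 are left unseeded
    if desc["type"] == "hybrid":
        for sub in desc["list"]:
            set_desc_seed(sub)
    elif desc["type"] not in ("dpa1", "dpa2"):
        desc["seed"] = random.randrange(sys.maxsize) % (2**32)

script = {
    "model": {
        "shared_dict": {"d1": {"type": "se_e2_a"}},
        "model_dict": {"water": {"descriptor": "d1", "fitting_net": {}}},
    },
    "training": {},
}
# model_dict entries may reference a shared descriptor by name, which is why
# the string case is resolved through shared_dict before seeding.
for d in script["model"]["model_dict"].values():
    if isinstance(d["descriptor"], str):
        set_desc_seed(script["model"]["shared_dict"][d["descriptor"]])
    d["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
script["training"]["seed"] = random.randrange(sys.maxsize) % (2**32)
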

