Merge branch 'deepmd-v3' into fix-dp-optim-parallel
wangzyphysics committed Mar 28, 2024
2 parents 898aebd + a6f3462 commit 37792e9
Showing 8 changed files with 144 additions and 76 deletions.
5 changes: 3 additions & 2 deletions dpgen2/constants.py
@@ -2,8 +2,9 @@
train_task_pattern = "task." + train_index_pattern
train_script_name = "input.json"
train_log_name = "train.log"
model_name_pattern = "model.%03d"
model_name_match_pattern = r"model\.[0-9]{3,}"
model_name_pattern = "model.%03d.pb"
pytorch_model_name_pattern = "model.%03d.pth"
model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"
lmp_index_pattern = "%06d"
lmp_task_pattern = "task." + lmp_index_pattern
lmp_conf_name = "conf.lmp"
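For orientation, a minimal sketch of how the updated constants interact: the two %03d patterns produce backend-specific file names, and the widened match pattern accepts either extension (re.fullmatch is used here purely for illustration).

    import re

    model_name_pattern = "model.%03d.pb"
    pytorch_model_name_pattern = "model.%03d.pth"
    model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"

    # "model.001.pb" and "model.001.pth" match; a bare "model.001" no longer does.
    for name in (model_name_pattern % 1, pytorch_model_name_pattern % 1, "model.001"):
        print(name, bool(re.fullmatch(model_name_match_pattern, name)))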
9 changes: 7 additions & 2 deletions dpgen2/entrypoint/submit.py
@@ -554,7 +554,7 @@ def workflow_concurrent_learning(
train_config["multitask"] = True
train_config["head"] = head
train_config["multi_init_data_idx"] = multi_init_data_idx
lmp_config["head"] = head
explore_config["head"] = head
else:
init_data_prefix = config["inputs"]["init_data_prefix"]
init_data = config["inputs"]["init_data_sys"]
@@ -771,7 +771,10 @@ def print_list_steps(

def successful_step_keys(wf):
all_step_keys = []
for step in wf.query_step():
steps = wf.query_step()
# For reused steps whose startedAt are identical, sort them by key
steps.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
for step in steps:
if step.key is not None and step.phase == "Succeeded":
all_step_keys.append(step.key)
return all_step_keys
@@ -905,6 +908,8 @@ def resubmit_concurrent_learning(
reused_folded_keys[k] = [k]
reused_keys = sum(reused_folded_keys.values(), [])
reuse_step = old_wf.query_step(key=reused_keys)
# For reused steps whose startedAt are identical, sort them by key
reuse_step.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))

wf = submit_concurrent_learning(
wf_config,
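Both additions sort with the same composite key. A small illustration with stand-in step objects (the keys are made up) of why reused steps that share a startedAt timestamp now come back in a deterministic order:

    from types import SimpleNamespace

    steps = [
        SimpleNamespace(startedAt="2024-03-28T10:00:00Z", key="iter-000001--prep-run-train"),
        SimpleNamespace(startedAt="2024-03-28T10:00:00Z", key="iter-000000--prep-run-train"),
        SimpleNamespace(startedAt="2024-03-28T09:00:00Z", key="iter-000000--prep-run-explore"),
    ]
    # Same key function as in the diff: timestamp first, step key breaks ties.
    steps.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
    print([s.key for s in steps])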
12 changes: 7 additions & 5 deletions dpgen2/op/prep_dp_train.py
@@ -110,7 +110,7 @@ def execute(
def _set_desc_seed(self, desc):
if desc["type"] == "hybrid":
for desc in desc["list"]:
desc["seed"] = random.randrange(sys.maxsize) % (2**32)
self._set_desc_seed(desc)
elif desc["type"] not in ["dpa1", "dpa2"]:
desc["seed"] = random.randrange(sys.maxsize) % (2**32)

@@ -119,13 +119,15 @@ def _script_rand_seed(
input_dict,
):
jtmp = input_dict.copy()
if "shared_dict" in jtmp["model"] and "model_dict" in jtmp["model"]:
if "dpa1_dpau_descriptor_1" in jtmp["model"]["shared_dict"]:
self._set_desc_seed(jtmp["model"]["shared_dict"]["dpa1_dpau_descriptor_1"])
if "model_dict" in jtmp["model"]:
for d in jtmp["model"]["model_dict"].values():
if isinstance(d["descriptor"], str):
self._set_desc_seed(jtmp["model"]["shared_dict"][d["descriptor"]])
d["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
else:
self._set_desc_seed(jtmp["model"]["descriptor"])
jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (
2**32
)
jtmp["training"]["seed"] = random.randrange(sys.maxsize) % (2**32)
return jtmp
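The hybrid branch now recurses instead of seeding each sub-descriptor in place, so nested hybrid lists are handled and dpa1/dpa2 sub-descriptors keep their seeds untouched. A standalone sketch of the same logic (a plain function instead of the class method; the example descriptor is made up):

    import random
    import sys

    def set_desc_seed(desc):
        if desc["type"] == "hybrid":
            # recurse so that dpa1/dpa2 members are skipped and nested hybrids work
            for sub in desc["list"]:
                set_desc_seed(sub)
        elif desc["type"] not in ["dpa1", "dpa2"]:
            desc["seed"] = random.randrange(sys.maxsize) % (2**32)

    desc = {"type": "hybrid", "list": [{"type": "se_e2_a"}, {"type": "dpa1"}]}
    set_desc_seed(desc)
    print(desc)  # the se_e2_a entry gets a seed, the dpa1 entry does not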
99 changes: 69 additions & 30 deletions dpgen2/op/run_dp_train.py
@@ -124,10 +124,11 @@ def execute(
finetune_mode = ip["optional_parameter"]["finetune_mode"]
config = ip["config"] if ip["config"] is not None else {}
impl = ip["config"].get("impl", "tensorflow")
if impl == "tensorflow":
dp_command = ["dp"]
elif impl == "pytorch":
assert impl in ["tensorflow", "pytorch"]
if impl == "pytorch":
dp_command = ["dp", "--pt"]
else:
dp_command = ["dp"]
finetune_args = config.get("finetune_args", "")
config = RunDPTrain.normalize_config(config)
task_name = ip["task_name"]
@@ -167,7 +168,13 @@ def execute(

# update the input dict
train_dict = RunDPTrain.write_data_to_input_script(
train_dict, config, init_data, iter_data_exp, auto_prob_str, major_version, valid_data
train_dict,
config,
init_data,
iter_data_exp,
auto_prob_str,
major_version,
valid_data,
)
train_dict = RunDPTrain.write_other_to_input_script(
train_dict, config, do_init_model, major_version
@@ -198,32 +205,53 @@ def clean_before_quit():

# train model
if impl == "tensorflow" and os.path.isfile("checkpoint"):
command = dp_command + ["train", "--restart", "model.ckpt", train_script_name]
command = dp_command + [
"train",
"--restart",
"model.ckpt",
train_script_name,
]
elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0:
checkpoint = "model.ckpt-%s.pt" % max([int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")])
command = dp_command + ["train", "--restart", checkpoint, train_script_name]
elif (do_init_model or finetune_mode == "train-init") and not config["init_model_with_finetune"]:
if impl == "tensorflow":
checkpoint = "model.ckpt-%s.pt" % max(
[int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")]
)
command = dp_command + [
"train",
"--restart",
checkpoint,
train_script_name,
]
elif (do_init_model or finetune_mode == "train-init") and not config[
"init_model_with_finetune"
]:
if impl == "pytorch":
command = dp_command + [
"train",
"--init-frz-model",
"--init-model",
str(init_model),
train_script_name,
]
elif impl == "pytorch":
else:
command = dp_command + [
"train",
"--init-model",
"--init-frz-model",
str(init_model),
train_script_name,
]
elif finetune_mode == "finetune" or ((do_init_model or finetune_mode == "train-init") and config["init_model_with_finetune"]):
command = dp_command + [
"train",
train_script_name,
"--finetune",
str(init_model),
] + finetune_args.split()
elif finetune_mode == "finetune" or (
(do_init_model or finetune_mode == "train-init")
and config["init_model_with_finetune"]
):
command = (
dp_command
+ [
"train",
train_script_name,
"--finetune",
str(init_model),
]
+ finetune_args.split()
)
else:
command = dp_command + ["train", train_script_name]
ret, out, err = run_command(command)
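One detail worth spelling out from the pytorch restart branch above: the latest checkpoint is picked by parsing the step number out of model.ckpt-<step>.pt. A minimal sketch of that selection (input.json stands in for train_script_name):

    import glob

    ckpts = glob.glob("model.ckpt-[0-9]*.pt")
    if len(ckpts) > 0:
        # f[11:-3] strips the "model.ckpt-" prefix and the ".pt" suffix, leaving the step number
        checkpoint = "model.ckpt-%s.pt" % max(int(f[11:-3]) for f in ckpts)
        command = ["dp", "--pt", "train", "--restart", checkpoint, "input.json"]
        print(command)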
@@ -252,7 +280,9 @@ def clean_before_quit():
shutil.copy2("input_v2_compat.json", train_script_name)

# freeze model
if impl == "tensorflow":
if impl == "pytorch":
model_file = "model.ckpt.pt"
else:
ret, out, err = run_command(["dp", "freeze", "-o", "frozen_model.pb"])
if ret != 0:
clean_before_quit()
@@ -271,8 +301,6 @@ def clean_before_quit():
)
raise FatalError("dp freeze failed")
model_file = "frozen_model.pb"
elif impl == "pytorch":
model_file = "model.ckpt.pt"
fplog.write("#=================== freeze std out ===================\n")
fplog.write(out)
fplog.write("#=================== freeze std err ===================\n")
@@ -306,10 +334,11 @@ def write_data_to_input_script(
for k, v in odict["training"]["data_dict"].items():
v["training_data"]["systems"] = []
if k in multi_init_data_idx:
v["training_data"]["systems"] += [str(init_data[ii]) for ii in multi_init_data_idx[k]]
v["training_data"]["systems"] += [
str(init_data[ii]) for ii in multi_init_data_idx[k]
]
if k == head:
v["training_data"]["systems"] += [str(ii) for ii in iter_data]
v.pop("validation_data", None)
return odict
data_list = [str(ii) for ii in init_data] + [str(ii) for ii in iter_data]
if major_version == "1":
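In the multitask branch each head's training_data receives its own slice of init_data via multi_init_data_idx, and only the active head also gets the iteration data. A small worked example with made-up paths and head names:

    init_data = ["init/water-1", "init/water-2", "init/metal"]
    iter_data = ["iter-000000/data"]
    multi_init_data_idx = {"water_head": [0, 1], "metal_head": [2]}
    head = "water_head"

    data_dict = {k: {"training_data": {}} for k in ("water_head", "metal_head")}
    for k, v in data_dict.items():
        v["training_data"]["systems"] = []
        if k in multi_init_data_idx:
            v["training_data"]["systems"] += [str(init_data[ii]) for ii in multi_init_data_idx[k]]
        if k == head:
            v["training_data"]["systems"] += [str(ii) for ii in iter_data]
    print(data_dict)
    # water_head trains on both init water sets plus the new iteration data;
    # metal_head only sees init/metal.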
@@ -349,9 +378,16 @@ def write_other_to_input_script(
odict["training"]["disp_file"] = "lcurve.out"
if do_init_model:
odict["learning_rate"]["start_lr"] = config["init_model_start_lr"]
odict["loss"]["start_pref_e"] = config["init_model_start_pref_e"]
odict["loss"]["start_pref_f"] = config["init_model_start_pref_f"]
odict["loss"]["start_pref_v"] = config["init_model_start_pref_v"]
if "loss_dict" in odict:
for v in odict["loss_dict"].values():
if isinstance(v, dict):
v["start_pref_e"] = config["init_model_start_pref_e"]
v["start_pref_f"] = config["init_model_start_pref_f"]
v["start_pref_v"] = config["init_model_start_pref_v"]
else:
odict["loss"]["start_pref_e"] = config["init_model_start_pref_e"]
odict["loss"]["start_pref_f"] = config["init_model_start_pref_f"]
odict["loss"]["start_pref_v"] = config["init_model_start_pref_v"]
if major_version == "1":
odict["training"]["stop_batch"] = config["init_model_numb_steps"]
elif major_version == "2":
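Because a multitask input carries one loss section per head under loss_dict instead of a single loss block, the init-model start prefactors now have to be written into every head. A small sketch on a hypothetical two-head input:

    config = {
        "init_model_start_pref_e": 0.1,
        "init_model_start_pref_f": 100,
        "init_model_start_pref_v": 0.0,
    }
    odict = {
        "loss_dict": {
            "water_head": {"type": "ener", "start_pref_e": 0.02},
            "metal_head": {"type": "ener", "start_pref_e": 0.02},
        }
    }
    if "loss_dict" in odict:
        for v in odict["loss_dict"].values():
            if isinstance(v, dict):  # skip non-dict entries such as comment strings
                v["start_pref_e"] = config["init_model_start_pref_e"]
                v["start_pref_f"] = config["init_model_start_pref_f"]
                v["start_pref_v"] = config["init_model_start_pref_v"]
    else:
        odict["loss"]["start_pref_e"] = config["init_model_start_pref_e"]
    print(odict)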
@@ -420,7 +456,7 @@ def decide_init_model(

@staticmethod
def training_args():
doc_impl = "The implementation of DP. It can be 'tensorflow' or 'pytorch'. 'tensorflow' for default."
doc_impl = "The implementation/backend of DP. It can be 'tensorflow' or 'pytorch'. 'tensorflow' for default."
doc_init_model_policy = "The policy of init-model training. It can be\n\n\
- 'no': No init-model training. Training from scratch.\n\n\
- 'yes': Do init-model training.\n\n\
@@ -440,7 +476,9 @@ def training_args():
doc_finetune_args = "Extra arguments for finetuning"
doc_multitask = "Do multitask training"
doc_head = "Head to use in the multitask training"
doc_multi_init_data_idx = "A dict mapping from task name to list of indices in the init data"
doc_multi_init_data_idx = (
"A dict mapping from task name to list of indices in the init data"
)
doc_init_model_with_finetune = "Use finetune for init model"
return [
Argument(
@@ -449,6 +487,7 @@
optional=True,
default="tensorflow",
doc=doc_impl,
alias=["backend"],
),
Argument(
"init_model_policy",
@@ -534,7 +573,7 @@ def training_args():
optional=True,
default=None,
doc=doc_multi_init_data_idx,
)
),
]

@staticmethod
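On the input side, these training_args() entries surface as keys of the train config block in the dpgen2 input; a hypothetical snippet exercising the options touched by this commit (head and index values are made up):

    train_config = {
        "impl": "pytorch",          # "backend" is now accepted as an alias for this key
        "multitask": True,
        "head": "water_head",
        "multi_init_data_idx": {"water_head": [0, 1], "metal_head": [2]},
        "init_model_with_finetune": False,
        "finetune_args": "",
    }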
51 changes: 34 additions & 17 deletions dpgen2/op/run_lmp.py
@@ -38,6 +38,7 @@
model_name_match_pattern,
model_name_pattern,
plm_output_name,
pytorch_model_name_pattern,
)
from dpgen2.utils import (
BinaryFileInput,
@@ -143,30 +144,39 @@ def execute(
model_names = []
for idx, mm in enumerate(model_files):
ext = os.path.splitext(mm)[-1]
mname = model_name_pattern % (idx) + ext
model_names.append(mname)
if ext == ".pb":
mname = model_name_pattern % (idx)
Path(mname).symlink_to(mm)
elif ext == ".pt":
# freeze model
mname = pytorch_model_name_pattern % (idx)
freeze_args = "-o %s" % mname
if config.get("head") is not None:
freeze_args += " --head %s" % config["head"]
freeze_cmd = "dp --pt freeze -c %s %s" % (mm, freeze_args)
ret, out, err = run_command(freeze_cmd, shell=True)
if ret != 0:
logging.error("".join((
"freeze failed\n",
"command was",
freeze_cmd,
"out msg",
out,
"\n",
"err msg",
err,
"\n",
)))
logging.error(
"".join(
(
"freeze failed\n",
"command was",
freeze_cmd,
"out msg",
out,
"\n",
"err msg",
err,
"\n",
)
)
)
raise TransientError("freeze failed")
else:
raise RuntimeError(
"Model file with extension '%s' is not supported" % ext
)
model_names.append(mname)

if shuffle_models:
random.shuffle(model_names)
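At exploration time a .pb model is simply symlinked to model.%03d.pb, while a .pt checkpoint is first frozen into model.%03d.pth (adding --head for multitask models). A rough sketch of the command that gets built (file and head names are illustrative):

    pytorch_model_name_pattern = "model.%03d.pth"

    idx, mm = 0, "model.ckpt.pt"
    head = "water_head"  # config.get("head"); None for single-task models

    mname = pytorch_model_name_pattern % (idx)
    freeze_args = "-o %s" % mname
    if head is not None:
        freeze_args += " --head %s" % head
    freeze_cmd = "dp --pt freeze -c %s %s" % (mm, freeze_args)
    print(freeze_cmd)  # dp --pt freeze -c model.ckpt.pt -o model.000.pth --head water_head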
@@ -249,7 +259,11 @@ def set_models(lmp_input_name: str, model_names: List[str]):
with open(lmp_input_name, encoding="utf8") as f:
lmp_input_lines = f.readlines()

idx = find_only_one_key(lmp_input_lines, ["pair_style", "deepmd"])
idx = find_only_one_key(
lmp_input_lines, ["pair_style", "deepmd"], raise_not_found=False
)
if idx is None:
return
new_line_split = lmp_input_lines[idx].split()
match_first = -1
match_last = -1
@@ -275,13 +289,13 @@ def set_models(lmp_input_name: str, model_names: List[str]):
f"in line {lmp_input_lines[idx]}"
)
new_line_split[match_first:match_last] = model_names
lmp_input_lines[idx] = " ".join(new_line_split)
lmp_input_lines[idx] = " ".join(new_line_split) + "\n"

with open(lmp_input_name, "w", encoding="utf8") as f:
f.write("".join(lmp_input_lines))


def find_only_one_key(lmp_lines, key):
def find_only_one_key(lmp_lines, key, raise_not_found=True):
found = []
for idx in range(len(lmp_lines)):
words = lmp_lines[idx].split()
@@ -291,5 +305,8 @@ def find_only_one_key(lmp_lines, key):
if len(found) > 1:
raise RuntimeError("found %d keywords %s" % (len(found), key))
if len(found) == 0:
raise RuntimeError("failed to find keyword %s" % (key))
if raise_not_found:
raise RuntimeError("failed to find keyword %s" % (key))
else:
return None
return found[0]
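With raise_not_found=False the helper returns None instead of raising, so set_models() can bail out early and leave LAMMPS inputs that contain no pair_style deepmd line untouched (it also re-appends the newline that split() discards). A usage sketch, assuming this version of dpgen2 is importable and with made-up input lines:

    from dpgen2.op.run_lmp import find_only_one_key

    dp_input = ["units metal\n", "pair_style deepmd model.000.pth model.001.pth\n"]
    eam_input = ["units metal\n", "pair_style eam/alloy\n"]

    # index of the matching line for a DP input ...
    print(find_only_one_key(dp_input, ["pair_style", "deepmd"], raise_not_found=False))   # 1
    # ... and None (instead of a RuntimeError) when the keyword is absent
    print(find_only_one_key(eam_input, ["pair_style", "deepmd"], raise_not_found=False))  # None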