support the resubmission of workflows with finetuning (#165)
- skip the init model training in the first iteration (iter0) if we do finetuning.

---------

Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by 3 people on Sep 29, 2023
1 parent 39f43ce · commit dfbdc0a
Showing 5 changed files with 39 additions and 14 deletions.
dpgen2/entrypoint/submit.py: 4 additions & 1 deletion

@@ -713,7 +713,9 @@ def submit_concurrent_learning(
"conf_selector",
selector,
)
wf_config["inputs"]["do_finetune"] = False
# the modify-train-script step will be added as reuse step.
# the following hack is not needed anymore.
# wf_config["inputs"]["do_finetune"] = False
# finetune will not be done again if the old process is reused.

wf = Workflow(name="dpgen")
@@ -759,6 +761,7 @@ def get_resubmit_keys(
     [
         "prep-train",
         "run-train",
+        "modify-train-script",
         "prep-lmp",
         "run-lmp",
         "select-confs",
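The new "modify-train-script" entry is what makes finetune reuse work on resubmission: only steps whose names appear in this list are picked up from the old workflow, so without it the modified training script step would be re-executed instead of reused. A minimal sketch of the idea, assuming a hypothetical step_name helper (the real get_resubmit_keys queries the old workflow through dflow and may differ):

    import re

    def step_name(key):
        # "iter-000000--run-train-0001"   -> "run-train"
        # "finetune--modify-train-script" -> "modify-train-script"
        m = re.match(r"(?:iter-[0-9]+|finetune|init)--([a-z-]+?)(?:-[0-9]+)?$", key)
        return m.group(1) if m else None

    resubmit_steps = [
        "prep-train",
        "run-train",
        "modify-train-script",
        "prep-lmp",
        "run-lmp",
        "select-confs",
    ]
    old_keys = [
        "finetune--modify-train-script",
        "iter-000000--run-train-0000",
    ]
    reused = [k for k in old_keys if step_name(k) in resubmit_steps]
    print(reused)  # both keys are kept, so their steps are reused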
dpgen2/op/run_dp_train.py: 3 additions & 4 deletions

@@ -298,11 +298,10 @@ def skip_training(
     iter_data,
     finetune_mode,
 ):
-    # we have init model and no iter data, skip training
-    if finetune_mode is not None and (
-        finetune_mode == "train-init" or finetune_mode == "finetune"
-    ):
+    # do not skip if we do finetuning
+    if finetune_mode is not None and finetune_mode == "finetune":
         return False
+    # we have init model and no iter data, skip training
     if (init_model is not None) and (iter_data is None or len(iter_data) == 0):
         with set_directory(work_dir):
             with open(train_script_name, "w") as fp:
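The net effect matches the commit message: a finetune step always trains, while a train-init step that has an init model and no iteration data can now be skipped in iter0. A standalone sketch of the new decision logic (the real skip_training also writes the train script into the work directory before reporting the skip, as the context lines above show):

    def should_skip_training(init_model, iter_data, finetune_mode):
        # do not skip if we do finetuning
        if finetune_mode is not None and finetune_mode == "finetune":
            return False
        # we have an init model and no iter data: training can be skipped
        return (init_model is not None) and (iter_data is None or len(iter_data) == 0)

    assert should_skip_training("init.pb", [], "finetune") is False
    assert should_skip_training("init.pb", [], "train-init") is True   # newly skipped
    assert should_skip_training(None, [], None) is False               # no init model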
dpgen2/utils/dflow_query.py: 9 additions & 2 deletions

@@ -37,6 +37,8 @@ def matched_step_key(
         if (
             re.match(f"iter-[0-9]*--{jj}-[0-9]*", kk)
             or re.match(f"iter-[0-9]*--{jj}", kk)
+            or re.match(f"finetune--{jj}-[0-9]*", kk)
+            or re.match(f"finetune--{jj}", kk)
             or re.match(f"init--{jj}", kk)
         ):
             ret.append(kk)
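A quick illustration of what the two added patterns accept, using a standalone reimplementation of the matching loop above (the real matched_step_key signature may differ):

    import re

    def matched_step_key(all_keys, step_names):
        ret = []
        for kk in all_keys:
            for jj in step_names:
                if (
                    re.match(f"iter-[0-9]*--{jj}-[0-9]*", kk)
                    or re.match(f"iter-[0-9]*--{jj}", kk)
                    or re.match(f"finetune--{jj}-[0-9]*", kk)
                    or re.match(f"finetune--{jj}", kk)
                    or re.match(f"init--{jj}", kk)
                ):
                    ret.append(kk)
                    break
        return ret

    keys = ["finetune--run-train-0001", "iter-000000--run-train-0000", "init--scheduler"]
    print(matched_step_key(keys, ["run-train", "scheduler"]))
    # all three keys match; before this commit the finetune-- key was dropped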
@@ -116,11 +118,16 @@ def find_slice_ranges(
     status = "not-found"
     for idx, ii in enumerate(keys):
         if status == "not-found":
-            if re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii):
+            if re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii) or re.match(
+                f"finetune--{sliced_subkey}-[0-9]*", ii
+            ):
                 status = "found"
                 tmp_range.append(idx)
         elif status == "found":
-            if not re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii):
+            if not (
+                re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii)
+                or re.match(f"finetune--{sliced_subkey}-[0-9]*", ii)
+            ):
                 status = "not-found"
                 tmp_range.append(idx)
                 found_range.append(tmp_range)
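Replayed on a toy key list, the extended state machine now reports the finetune slice as well. A condensed standalone sketch (the match test is factored out, and handling of a slice that runs to the end of the key list is omitted):

    import re

    def find_slice_ranges(keys, sliced_subkey):
        found_range, tmp_range = [], []
        status = "not-found"
        for idx, ii in enumerate(keys):
            matched = re.match(f"iter-[0-9]*--{sliced_subkey}-[0-9]*", ii) or re.match(
                f"finetune--{sliced_subkey}-[0-9]*", ii
            )
            if status == "not-found" and matched:
                status = "found"
                tmp_range.append(idx)
            elif status == "found" and not matched:
                status = "not-found"
                tmp_range.append(idx)
                found_range.append(tmp_range)
                tmp_range = []
        return found_range  # half-open [start, stop) index pairs

    keys = [
        "finetune--prep-train",
        "finetune--run-train-0000",
        "finetune--run-train-0001",
        "finetune--prep-run-train",
        "iter-000000--run-train-0000",
        "iter-000000--prep-run-train",
    ]
    print(find_slice_ranges(keys, "run-train"))  # [[1, 3], [4, 5]]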
tests/fake_data_set.py: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ def fake_system(
     ss = dpdata.LabeledSystem()
     ss.data["atom_names"] = [atom_name]
     ss.data["atom_numbs"] = [natoms]
-    ss.data["atom_types"] = [0 for ii in range(natoms)]
+    ss.data["atom_types"] = np.array([0 for ii in range(natoms)]).astype(int)
     # ss.data['cells'] = np.zeros([nframes, 3, 3])
     ss.data["cells"] = np.tile(np.eye(3), [nframes, 1, 1])
     ss.data["coords"] = np.zeros([nframes, natoms, 3])
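This test fixture fix is independent of finetuning: atom_types becomes a numpy integer array instead of a Python list, presumably to satisfy stricter type expectations in dpdata. For illustration:

    import numpy as np

    natoms = 4
    as_list = [0 for ii in range(natoms)]     # old: plain Python list
    as_array = np.array(as_list).astype(int)  # new: integer ndarray
    assert isinstance(as_array, np.ndarray)
    assert as_array.dtype == np.dtype(int)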
tests/utils/test_dflow_query.py: 22 additions & 6 deletions

@@ -61,6 +61,12 @@
 # isort: on
 
 dpgen_keys = [
+    "finetune--prep-train",
+    "finetune--run-train-0002",
+    "finetune--run-train-0000",
+    "finetune--run-train-0001",
+    "finetune--modify-train-script",
+    "finetune--prep-run-train",
     "init--scheduler",
     "init--id",
     "iter-000000--prep-train",
@@ -222,6 +228,12 @@ def test_sort_slice_ops(self):

     def test_sort_slice_ops(self):
         expected_output = [
+            "finetune--prep-train",
+            "finetune--run-train-0000",
+            "finetune--run-train-0001",
+            "finetune--run-train-0002",
+            "finetune--modify-train-script",
+            "finetune--prep-run-train",
             "init--scheduler",
             "init--id",
             "iter-000000--prep-train",
@@ -260,16 +272,20 @@ def test_sort_slice_ops(self):

     def test_print_keys(self):
         expected_output = [
-            " 0 : init--scheduler",
-            " 1 : init--id",
-            " 2 : iter-000000--prep-train",
-            " 3 -> 5 : iter-000000--run-train-0000 -> iter-000000--run-train-0002",
-            " 6 : iter-000000--prep-run-train",
+            " 0 : finetune--prep-train",
+            " 1 -> 3 : finetune--run-train-0000 -> finetune--run-train-0002",
+            " 4 : finetune--modify-train-script",
+            " 5 : finetune--prep-run-train",
+            " 6 : init--scheduler",
+            " 7 : init--id",
+            " 8 : iter-000000--prep-train",
+            " 9 -> 11 : iter-000000--run-train-0000 -> iter-000000--run-train-0002",
+            " 12 : iter-000000--prep-run-train",
         ]
         expected_output = "\n".join(expected_output + [""])
 
         ret = print_keys_in_nice_format(
-            dpgen_keys[:7],
+            dpgen_keys[:13],
             ["run-train", "run-lmp", "run-fp"],
             idx_fmt_len=8,
         )
