diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py index 51514657..b7f260ab 100644 --- a/dpgen2/op/run_dp_train.py +++ b/dpgen2/op/run_dp_train.py @@ -280,7 +280,13 @@ def execute( auto_prob_str = "prob_sys_size" if do_init_model: old_ratio = config["init_model_old_ratio"] - numb_old = len(init_data) + len(iter_data_old_exp) + if config["multitask"]: + head = config["head"] + multi_init_data_idx = config["multi_init_data_idx"] + len_init = len(multi_init_data_idx[head]) + else: + len_init = len(init_data) + numb_old = len_init + len(iter_data_old_exp) numb_new = numb_old + len(iter_data_new_exp) auto_prob_str = f"prob_sys_size; 0:{numb_old}:{old_ratio}; {numb_old}:{numb_new}:{1.-old_ratio:g}" @@ -418,6 +424,7 @@ def write_data_to_input_script( ] if k == head: v["training_data"]["systems"] += [str(ii) for ii in iter_data] + v["training_data"]["auto_prob"] = auto_prob_str return odict data_list = [str(ii) for ii in init_data] + [str(ii) for ii in iter_data] if major_version == "1":