Skip to content

Commit

Permalink
support specifying URI of init_data, multi_init_data, valid_data, init_models (#227)
Browse files Browse the repository at this point in the history

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced new arguments for handling model and data URIs to enhance
data management capabilities.
- Added functions for fetching and uploading artifacts, improving
workflow efficiency.

- **Refactor**
- Refactored data handling logic to support concurrent learning
workflows and better manage input data structures.

- **Documentation**
- Updated documentation for new arguments and functions to guide users
on their usage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zjgemi and pre-commit-ci[bot] authored Jun 4, 2024
1 parent d3f52d8 commit ff2aed7
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 52 deletions.
40 changes: 40 additions & 0 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def dp_dist_train_args():
doc_config = "Configuration of training"
doc_template_script = "File names of the template training script. It can be a `List[str]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
dock_student_model_path = "The path of student model"
doc_student_model_uri = "The URI of student model"

return [
Argument(
Expand All @@ -59,6 +60,13 @@ def dp_dist_train_args():
"template_script", [List[str], str], optional=False, doc=doc_template_script
),
Argument("student_model_path", str, optional=True, doc=dock_student_model_path),
Argument(
"student_model_uri",
str,
optional=True,
default=None,
doc=doc_student_model_uri,
),
]


Expand All @@ -67,6 +75,7 @@ def dp_train_args():
doc_config = "Configuration of training"
doc_template_script = "File names of the template training script. It can be a `List[str]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
doc_init_models_paths = "the paths to initial models"
doc_init_models_uri = "The URI of initial models"

return [
Argument(
Expand All @@ -89,6 +98,13 @@ def dp_train_args():
doc=doc_init_models_paths,
alias=["training_iter0_model_path"],
),
Argument(
"init_models_uri",
str,
optional=True,
default=None,
doc=doc_init_models_uri,
),
]


Expand Down Expand Up @@ -354,14 +370,17 @@ def input_args():
doc_do_finetune = textwrap.dedent(doc_do_finetune)
doc_init_data_prefix = "The prefix of initial data systems"
doc_init_sys = "The inital data systems"
doc_init_data_uri = "The URI of initial data"
doc_multitask = "Do multitask training"
doc_head = "Head to use in the multitask training"
doc_multi_init_data = (
"The inital data for multitask, it should be a dict, whose keys are task names and each value is a dict"
"containing fields `prefix` and `sys` for initial data of each task"
)
doc_multi_init_data_uri = "The URI of initial data for multitask"
doc_valid_data_prefix = "The prefix of validation data systems"
doc_valid_sys = "The validation data systems"
doc_valid_data_uri = "The URI of validation data"

return [
Argument("type_map", List[str], optional=False, doc=doc_type_map),
Expand All @@ -384,6 +403,13 @@ def input_args():
default=None,
doc=doc_init_sys,
),
Argument(
"init_data_uri",
str,
optional=True,
default=None,
doc=doc_init_data_uri,
),
Argument(
"multitask",
bool,
Expand All @@ -405,6 +431,13 @@ def input_args():
default=None,
doc=doc_multi_init_data,
),
Argument(
"multi_init_data_uri",
str,
optional=True,
default=None,
doc=doc_multi_init_data_uri,
),
Argument(
"valid_data_prefix",
str,
Expand All @@ -419,6 +452,13 @@ def input_args():
default=None,
doc=doc_valid_sys,
),
Argument(
"valid_data_uri",
str,
optional=True,
default=None,
doc=doc_valid_data_uri,
),
]


Expand Down
70 changes: 39 additions & 31 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,13 @@
BinaryFileInput,
bohrium_config_from_dict,
dump_object_to_file,
get_artifact_from_uri,
get_subkey,
load_object_from_file,
matched_step_key,
print_keys_in_nice_format,
sort_slice_ops,
upload_artifact_and_print_uri,
workflow_config_from_dict,
)
from dpgen2.utils.step_config import normalize as normalize_step_dict
Expand Down Expand Up @@ -457,6 +459,15 @@ def make_finetune_step(
return finetune_step


def get_systems_from_data(data, data_prefix=None):
    """Normalize a data specification into a flat list of system paths.

    Parameters
    ----------
    data : str or list of str
        A single system path/pattern, or a list of them.
    data_prefix : str, optional
        If given, joined (``os.path.join``) in front of every entry.

    Returns
    -------
    list of str
        All system paths produced by expanding each entry with
        ``expand_sys_str``, concatenated in input order.

    Raises
    ------
    TypeError
        If ``data`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(data, str):
        data = [data]
    # Explicit raise instead of `assert` so validation survives `python -O`.
    if not isinstance(data, list):
        raise TypeError("data must be a str or a list of str, got %s" % type(data))
    if data_prefix is not None:
        data = [os.path.join(data_prefix, ii) for ii in data]
    # expand_sys_str may yield several systems per entry; flatten them.
    systems = []
    for ii in data:
        systems.extend(expand_sys_str(ii))
    return systems


def workflow_concurrent_learning(
config: Dict,
) -> Tuple[Step, Optional[Step]]:
Expand Down Expand Up @@ -505,14 +516,12 @@ def workflow_concurrent_learning(
upload_python_packages = _upload_python_packages

valid_data = config["inputs"]["valid_data_sys"]
if valid_data is not None:
if config["inputs"]["valid_data_uri"] is not None:
valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"])
elif valid_data is not None:
valid_data_prefix = config["inputs"]["valid_data_prefix"]
valid_data = [valid_data] if isinstance(valid_data, str) else valid_data
assert isinstance(valid_data, list)
if valid_data_prefix is not None:
valid_data = [os.path.join(valid_data_prefix, ii) for ii in valid_data]
valid_data = [expand_sys_str(ii) for ii in valid_data]
valid_data = upload_artifact(valid_data)
valid_data = get_systems_from_data(valid_data, valid_data_prefix)
valid_data = upload_artifact_and_print_uri(valid_data, "valid_data")
concurrent_learning_op = make_concurrent_learning_op(
train_style,
explore_style,
Expand Down Expand Up @@ -570,35 +579,34 @@ def workflow_concurrent_learning(
multitask = config["inputs"]["multitask"]
if multitask:
head = config["inputs"]["head"]
multi_init_data = config["inputs"]["multi_init_data"]
init_data = []
multi_init_data_idx = {}
for k, v in multi_init_data.items():
sys = v["sys"]
sys = [sys] if isinstance(sys, str) else sys
assert isinstance(sys, list)
if v["prefix"] is not None:
sys = [os.path.join(v["prefix"], ii) for ii in sys]
sys = [expand_sys_str(ii) for ii in sys]
istart = len(init_data)
init_data += sys
iend = len(init_data)
multi_init_data_idx[k] = list(range(istart, iend))
if config["inputs"]["multi_init_data_uri"] is not None:
init_data = get_artifact_from_uri(config["inputs"]["multi_init_data_uri"])
else:
multi_init_data = config["inputs"]["multi_init_data"]
init_data = {}
for k, v in multi_init_data.items():
sys = v["sys"]
sys = get_systems_from_data(sys, v.get("prefix", None))
init_data[k] = sys
init_data = upload_artifact_and_print_uri(init_data, "multi_init_data")
train_config["multitask"] = True
train_config["head"] = head
train_config["multi_init_data_idx"] = multi_init_data_idx
explore_config["head"] = head
else:
init_data_prefix = config["inputs"]["init_data_prefix"]
init_data = config["inputs"]["init_data_sys"]
if init_data_prefix is not None:
init_data = [os.path.join(init_data_prefix, ii) for ii in init_data]
if isinstance(init_data, str):
init_data = expand_sys_str(init_data)
init_data = upload_artifact(init_data)
if config["inputs"]["init_data_uri"] is not None:
init_data = get_artifact_from_uri(config["inputs"]["init_data_uri"])
else:
init_data_prefix = config["inputs"]["init_data_prefix"]
init_data = config["inputs"]["init_data_sys"]
init_data = get_systems_from_data(init_data, init_data_prefix)
init_data = upload_artifact_and_print_uri(init_data, "init_data")
iter_data = upload_artifact([])
if init_models_paths is not None:
init_models = upload_artifact(init_models_paths)
if train_style == "dp" and config["train"]["init_models_uri"] is not None:
init_models = get_artifact_from_uri(config["train"]["init_models_uri"])
elif train_style == "dp-dist" and config["train"]["student_model_uri"] is not None:
init_models = get_artifact_from_uri(config["train"]["student_model_uri"])
elif init_models_paths is not None:
init_models = upload_artifact_and_print_uri(init_models_paths, "init_models")
else:
init_models = None

Expand Down
35 changes: 14 additions & 21 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Tuple,
Union,
)

import dpdata
Expand All @@ -26,6 +27,7 @@
Artifact,
BigParameter,
FatalError,
NestedDict,
OPIOSign,
Parameter,
TransientError,
Expand Down Expand Up @@ -187,7 +189,7 @@ def get_input_sign(cls):
),
"task_path": Artifact(Path),
"init_model": Artifact(Path, optional=True),
"init_data": Artifact(List[Path]),
"init_data": Artifact(NestedDict[Path]),
"iter_data": Artifact(List[Path]),
"valid_data": Artifact(List[Path], optional=True),
}
Expand Down Expand Up @@ -220,7 +222,7 @@ def execute(
- `task_name`: (`str`) The name of training task.
- `task_path`: (`Artifact(Path)`) The path that contains all input files prepareed by `PrepDPTrain`.
- `init_model`: (`Artifact(Path)`) A frozen model to initialize the training.
- `init_data`: (`Artifact(List[Path])`) Initial training data.
- `init_data`: (`Artifact(NestedDict[Path])`) Initial training data.
- `iter_data`: (`Artifact(List[Path])`) Training data generated in the DPGEN iterations.
Returns
Expand Down Expand Up @@ -282,8 +284,7 @@ def execute(
old_ratio = config["init_model_old_ratio"]
if config["multitask"]:
head = config["head"]
multi_init_data_idx = config["multi_init_data_idx"]
len_init = len(multi_init_data_idx[head])
len_init = len(init_data[head])
else:
len_init = len(init_data)
numb_old = len_init + len(iter_data_old_exp)
Expand Down Expand Up @@ -406,7 +407,7 @@ def clean_before_quit():
def write_data_to_input_script(
idict: dict,
config,
init_data: List[Path],
init_data: Union[List[Path], Dict[str, List[Path]]],
iter_data: List[Path],
auto_prob_str: str = "prob_sys_size",
major_version: str = "1",
Expand All @@ -415,13 +416,10 @@ def write_data_to_input_script(
odict = idict.copy()
if config["multitask"]:
head = config["head"]
multi_init_data_idx = config["multi_init_data_idx"]
for k, v in odict["training"]["data_dict"].items():
v["training_data"]["systems"] = []
if k in multi_init_data_idx:
v["training_data"]["systems"] += [
str(init_data[ii]) for ii in multi_init_data_idx[k]
]
if k in init_data:
v["training_data"]["systems"] += [str(ii) for ii in init_data[k]]
if k == head:
v["training_data"]["systems"] += [str(ii) for ii in iter_data]
v["training_data"]["auto_prob"] = auto_prob_str
Expand Down Expand Up @@ -531,7 +529,12 @@ def decide_init_model(
do_init_model = True
elif "old_data_larger_than" in config["init_model_policy"]:
old_data_size_level = int(config["init_model_policy"].split(":")[-1])
init_data_size = _get_data_size_of_all_systems(init_data)
if isinstance(init_data, dict):
init_data_size = _get_data_size_of_all_systems(
sum(init_data.values(), [])
)
else:
init_data_size = _get_data_size_of_all_systems(init_data)
iter_data_old_size = _get_data_size_of_all_mult_sys(
iter_data[:-1], mixed_type=mixed_type
)
Expand Down Expand Up @@ -562,9 +565,6 @@ def training_args():
doc_finetune_args = "Extra arguments for finetuning"
doc_multitask = "Do multitask training"
doc_head = "Head to use in the multitask training"
doc_multi_init_data_idx = (
"A dict mapping from task name to list of indices in the init data"
)
doc_init_model_with_finetune = "Use finetune for init model"
return [
Argument(
Expand Down Expand Up @@ -653,13 +653,6 @@ def training_args():
default=None,
doc=doc_head,
),
Argument(
"multi_init_data_idx",
dict,
optional=True,
default=None,
doc=doc_multi_init_data_idx,
),
]

@staticmethod
Expand Down
4 changes: 4 additions & 0 deletions dpgen2/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .artifact_uri import (
get_artifact_from_uri,
upload_artifact_and_print_uri,
)
from .binary_file_input import (
BinaryFileInput,
)
Expand Down
23 changes: 23 additions & 0 deletions dpgen2/utils/artifact_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dflow import (
S3Artifact,
s3_config,
upload_artifact,
)


def get_artifact_from_uri(uri):
    """Resolve an ``s3://`` or ``oss://`` URI into an ``S3Artifact``.

    Parameters
    ----------
    uri : str
        Artifact URI; the scheme prefix is stripped and the remainder
        is used as the artifact key.

    Raises
    ------
    ValueError
        If the URI uses an unrecognized scheme.
    """
    for scheme in ("s3://", "oss://"):
        if uri.startswith(scheme):
            return S3Artifact(uri[len(scheme):])
    raise ValueError("Unrecognized scheme of URI: %s" % uri)


def upload_artifact_and_print_uri(files, name):
    """Upload ``files`` as a workflow artifact and print its URI.

    The URI is only printed when the configured repository type is
    ``s3`` or ``oss`` and the uploaded artifact exposes a ``key``
    (otherwise the upload is silent).

    Parameters
    ----------
    files
        Files (or structure of files) accepted by ``upload_artifact``.
    name : str
        Human-readable label used in the printed message.

    Returns
    -------
    The artifact object returned by ``upload_artifact``.
    """
    art = upload_artifact(files)
    repo_type = s3_config["repo_type"]
    has_key = hasattr(art, "key")
    if has_key and repo_type == "s3":
        print("%s has been uploaded to s3://%s" % (name, art.key))
    elif has_key and repo_type == "oss":
        print("%s has been uploaded to oss://%s" % (name, art.key))
    return art

0 comments on commit ff2aed7

Please sign in to comment.