From edd1be08583dc78e40692d5ecb88f1ce4ba2d01f Mon Sep 17 00:00:00 2001 From: ishachirimar Date: Mon, 18 Mar 2024 14:33:32 -0400 Subject: [PATCH] Introduce parameter for config to materialize_appdef Differential Revision: D54777218 Pull Request resolved: https://github.com/pytorch/torchx/pull/848 --- torchx/specs/builders.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/torchx/specs/builders.py b/torchx/specs/builders.py index fc93c3d27..4f0653a3b 100644 --- a/torchx/specs/builders.py +++ b/torchx/specs/builders.py @@ -18,7 +18,9 @@ def _create_args_parser( - cmpnt_fn: Callable[..., AppDef], cmpnt_defaults: Optional[Dict[str, str]] = None + cmpnt_fn: Callable[..., AppDef], + cmpnt_defaults: Optional[Dict[str, str]] = None, + config: Optional[Dict[str, Any]] = None, ) -> argparse.ArgumentParser: parameters = inspect.signature(cmpnt_fn).parameters function_desc, args_desc = get_fn_docstring(cmpnt_fn) @@ -81,15 +83,26 @@ def __call__( if len(param_name) == 1: arg_names = [f"-{param_name}"] + arg_names if "default" not in args: - args["required"] = True + if (config and param_name not in config) or not config: + args["required"] = True + script_parser.add_argument(*arg_names, **args) return script_parser +def _merge_config_values_with_args( + parsed_args: argparse.Namespace, config: Dict[str, Any] +) -> None: + for key, val in config.items(): + if key in parsed_args: + setattr(parsed_args, key, val) + + def materialize_appdef( cmpnt_fn: Callable[..., AppDef], cmpnt_args: List[str], cmpnt_defaults: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, ) -> AppDef: """ Creates an application by running user defined ``app_fn``. @@ -114,12 +127,15 @@ def materialize_appdef( cmpnt_args: Function args cmpnt_defaults: Additional default values for parameters of ``app_fn`` (overrides the defaults set on the fn declaration) + config: Optional dict containing additional configuration for the component from a passed config file Returns: An application spec """ - script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults) + script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults, config) parsed_args = script_parser.parse_args(cmpnt_args) + if config: + _merge_config_values_with_args(parsed_args, config) function_args = [] var_arg = []