Skip to content

Commit

Permalink
Accept non-string complex values in defaults argument in materialize_…
Browse files Browse the repository at this point in the history
…appdef

Differential Revision: D54507627

Pull Request resolved: #846
  • Loading branch information
ishachirimar committed Mar 7, 2024
1 parent cc9e217 commit 4c2eee5
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 13 deletions.
15 changes: 3 additions & 12 deletions torchx/specs/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@

from torchx.specs.api import BindMount, MountType, VolumeMount
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
from torchx.util.types import (
decode_from_string,
decode_optional,
get_argparse_param_type,
is_bool,
is_primitive,
)
from torchx.util.types import decode, decode_optional, get_argparse_param_type, is_bool

from .api import AppDef, DeviceMount

Expand Down Expand Up @@ -93,7 +87,7 @@ def __call__(
def materialize_appdef(
cmpnt_fn: Callable[..., AppDef],
cmpnt_args: List[str],
cmpnt_defaults: Optional[Dict[str, str]] = None,
cmpnt_defaults: Optional[Dict[str, Any]] = None,
) -> AppDef:
"""
Creates an application by running user defined ``app_fn``.
Expand Down Expand Up @@ -134,10 +128,7 @@ def materialize_appdef(
arg_value = getattr(parsed_args, param_name)
parameter_type = parameter.annotation
parameter_type = decode_optional(parameter_type)
if is_bool(parameter_type):
arg_value = arg_value and arg_value.lower() == "true"
elif not is_primitive(parameter_type):
arg_value = decode_from_string(arg_value, parameter_type)
arg_value = decode(arg_value, parameter_type)
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg = arg_value
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
Expand Down
53 changes: 53 additions & 0 deletions torchx/specs/test/builders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _test_complex_fn(
num_gpus: Optional[Dict[str, int]] = None,
nnodes: int = 4,
first_arg: Optional[str] = None,
nested_arg: Optional[Dict[str, List[str]]] = None,
*roles_args: str,
) -> AppDef:
"""Creates complex application, testing all possible complex types
Expand All @@ -110,6 +111,8 @@ def _test_complex_fn(
gpus = num_gpus[role_name]
if first_arg:
args = [first_arg, *roles_args]
elif nested_arg:
args = nested_arg[role_name]
else:
args = [*roles_args]
role = Role(
Expand Down Expand Up @@ -171,6 +174,9 @@ def assert_apps(self, expected_app: AppDef, actual_app: AppDef) -> None:
def _get_role_args(self) -> List[str]:
return ["--train", "data_source", "random", "--epochs", "128"]

def _get_nested_arg(self) -> Dict[str, List[str]]:
return {"worker": ["1", "2"], "master": ["3", "4"]}

def _get_expected_app_with_default(self) -> AppDef:
role_args = self._get_role_args()
return _test_complex_fn(
Expand All @@ -181,6 +187,7 @@ def _get_expected_app_with_default(self) -> AppDef:
None,
4,
None,
None,
*role_args,
)

Expand All @@ -207,6 +214,7 @@ def _get_expected_app_with_all_args(self) -> AppDef:
{"worker": 1, "master": 4},
8,
"first_arg",
None,
*role_args,
)

Expand All @@ -231,6 +239,45 @@ def _get_app_args(self) -> List[str]:
*role_args,
]

def _get_expected_app_with_nested_objects(self) -> AppDef:
role_args = self._get_role_args()
defaults = self._get_nested_arg()
return _test_complex_fn(
"test_app",
["img1", "img2"],
{"worker": "worker.py", "master": "master.py"},
[1, 2],
{"worker": 1, "master": 4},
8,
"first_arg",
defaults,
*role_args,
)

def _get_app_args_and_defaults_with_nested_objects(
self,
) -> Tuple[List[str], Dict[str, List[str]]]:
role_args = self._get_role_args()
defaults = self._get_nested_arg()
return [
"--app_name",
"test_app",
"--containers",
"img1,img2",
"--roles_scripts",
"worker=worker.py,master=master.py",
"--num_cpus",
"1,2",
"--num_gpus",
"worker=1,master=4",
"--nnodes",
"8",
"--first_arg",
"first_arg",
"--",
*role_args,
], defaults

def test_load_from_fn_empty(self) -> None:
actual_app = materialize_appdef(test_empty_fn, [])
expected_app = get_dummy_application("trainer")
Expand All @@ -257,6 +304,12 @@ def test_load_from_fn_with_default(self) -> None:
actual_app = materialize_appdef(_test_complex_fn, app_args)
self.assert_apps(expected_app, actual_app)

def test_with_nested_object(self) -> None:
expected_app = self._get_expected_app_with_nested_objects()
app_args, defaults = self._get_app_args_and_defaults_with_nested_objects()
actual_app = materialize_appdef(_test_complex_fn, app_args, defaults)
self.assert_apps(expected_app, actual_app)

def test_varargs(self) -> None:
materialize_appdef(
_test_var_args,
Expand Down
28 changes: 27 additions & 1 deletion torchx/util/test/types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import typing_inspect
from torchx.util.types import (
decode,
decode_from_string,
decode_optional,
get_argparse_param_type,
Expand All @@ -22,7 +23,9 @@


def _test_complex_args(
arg1: int, arg2: Optional[List[str]], arg3: Union[float, int]
arg1: int,
arg2: Optional[List[str]],
arg3: Union[float, int],
) -> int:
return 42

Expand All @@ -31,6 +34,10 @@ def _test_dict(arg1: Dict[int, float]) -> int:
return 42


def _test_nested_object(arg1: Dict[str, List[str]]) -> int:
return 42


def _test_list(arg1: List[float]) -> int:
return 42

Expand Down Expand Up @@ -98,6 +105,25 @@ def test_decode_from_string_list(self) -> None:
self.assertEqual(float(42.2), value[1])
self.assertEqual(float(3.9), value[2])

def test_decode(self) -> None:
encoded_value = "1.0,42.2,3.9"

dict_parameters = inspect.signature(_test_nested_object).parameters
list_parameters = inspect.signature(_test_list).parameters

value = {"a": ["1", "2"], "b": ["3", "4"]}
self.assertDictEqual(value, decode(value, dict_parameters["arg1"].annotation))

self.assertEqual(decode("true", bool), True)
self.assertEqual(decode("false", bool), False)

self.assertEqual(decode(None, int), None)

self.assertEqual(
decode_from_string(encoded_value, list_parameters["arg1"].annotation),
decode(encoded_value, list_parameters["arg1"].annotation),
)

def test_decode_from_string_empty(self) -> None:
parameters = inspect.signature(_test_list).parameters
value = decode_from_string("", parameters["arg1"].annotation)
Expand Down
10 changes: 10 additions & 0 deletions torchx/util/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ def _decode_string_to_list(
return arg_values


def decode(encoded_value: Any, annotation: Any):
if encoded_value is None:
return None
if is_bool(annotation):
return encoded_value and encoded_value.lower() == "true"
if not is_primitive(annotation) and type(encoded_value) == str:
return decode_from_string(encoded_value, annotation)
return encoded_value


def decode_from_string(
encoded_value: str, annotation: Any
) -> Union[Dict[Any, Any], List[Any], None]:
Expand Down

0 comments on commit 4c2eee5

Please sign in to comment.