diff --git a/src/nuclai/cls.py b/src/nuclai/cls.py index 1f01558..9c163a5 100644 --- a/src/nuclai/cls.py +++ b/src/nuclai/cls.py @@ -65,7 +65,7 @@ def _get_args(mode: str) -> argparse.Namespace: parser.add_argument( "--hidden_dim", - type=int, + type=str, nargs="+", default=None, help="Neurons per hidden layer provided as whitespace-separated int (e.g. 256 128). Default is None.", @@ -269,8 +269,8 @@ def train(): ), f"File {path_checkpoint} does not exist." assert hidden_dim is None or all( - isinstance(x, int) for x in hidden_dim - ), "Hidden dimensions must be a list of integers or None" + isinstance(x, str) for x in hidden_dim + ), "Hidden dimensions must be a list of str or None" assert isinstance( bias, bool @@ -318,6 +318,13 @@ def train(): assert isinstance(devices, list), "Devices must be a list." + # convert hidden_dim to list of integers + if hidden_dim is not None: + if "None" in hidden_dim: + hidden_dim = None + else: + hidden_dim = [int(x) for x in hidden_dim] + # create directories d = date.today() identifier = (