Skip to content

Commit

Permalink
change hidden_dim to str for model_runner compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dsethz committed Nov 3, 2024
1 parent 02007a9 commit cd91de3
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/nuclai/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _get_args(mode: str) -> argparse.Namespace:

parser.add_argument(
"--hidden_dim",
type=int,
type=str,
nargs="+",
default=None,
help="Neurons per hidden layer provided as whitespace-separated int (e.g. 256 128). Default is None.",
Expand Down Expand Up @@ -269,8 +269,8 @@ def train():
), f"File {path_checkpoint} does not exist."

assert hidden_dim is None or all(
isinstance(x, int) for x in hidden_dim
), "Hidden dimensions must be a list of integers or None"
isinstance(x, str) for x in hidden_dim
), "Hidden dimensions must be a list of str or None"

assert isinstance(
bias, bool
Expand Down Expand Up @@ -318,6 +318,13 @@ def train():

assert isinstance(devices, list), "Devices must be a list."

# convert hidden_dim to list of integers
if hidden_dim is not None:
if "None" in hidden_dim:
hidden_dim = None
else:
hidden_dim = [int(x) for x in hidden_dim]

# create directories
d = date.today()
identifier = (
Expand Down

0 comments on commit cd91de3

Please sign in to comment.