Skip to content

Commit

Permalink
Add Type Checking to Params on Model (#676)
Browse files Browse the repository at this point in the history
Allow specifying Model and Ensemble parameters 
with number-like types. The constructors for 
parameters on Model and Ensemble now validate 
that the input is number-like and convert them to 
strings.

[ committed by @juliaputko ]
[ reviewed by @ashao]
  • Loading branch information
juliaputko authored Sep 5, 2024
1 parent c2ab99b commit 72be515
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
7 changes: 7 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ To be released at some future point in time

Description

- Allow specifying Model and Ensemble parameters with
number-like types (e.g. numpy types)
- Pin watchdog to 4.x
- Update codecov to 4.5.0
- Remove build of Redis from setup.py
Expand All @@ -31,6 +33,11 @@ Description

Detailed Notes

- The serializer would fail if a parameter for a Model or Ensemble
was specified as a numpy dtype. The constructors for these
methods now validate that the input is number-like and convert
them to strings
([SmartSim-PR676](https://github.com/CrayLabs/SmartSim/pull/676))
- Pin watchdog to 4.x because v5 introduces new types and requires
updates to the type-checking
([SmartSim-PR690](https://github.com/CrayLabs/SmartSim/pull/690))
Expand Down
22 changes: 21 additions & 1 deletion smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from __future__ import annotations

import itertools
import numbers
import re
import sys
import typing as t
Expand All @@ -46,6 +47,25 @@
logger = get_logger(__name__)


def _parse_model_parameters(params_dict: t.Dict[str, t.Any]) -> t.Dict[str, str]:
"""Convert the values in a params dict to strings
:raises TypeError: if params are of the wrong type
:return: param dictionary with values and keys cast as strings
"""
param_names: t.List[str] = []
parameters: t.List[str] = []
for name, val in params_dict.items():
param_names.append(name)
if isinstance(val, (str, numbers.Number)):
parameters.append(str(val))
else:
raise TypeError(
"Incorrect type for model parameters\n"
+ "Must be numeric value or string."
)
return dict(zip(param_names, parameters))


class Model(SmartSimEntity):
def __init__(
self,
Expand All @@ -70,7 +90,7 @@ def __init__(
model as a batch job
"""
super().__init__(name, str(path), run_settings)
self.params = params
self.params = _parse_model_parameters(params)
self.params_as_args = params_as_args
self.incoming_entities: t.List[SmartSimEntity] = []
self._key_prefixing_enabled = False
Expand Down
15 changes: 15 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@

from uuid import uuid4

import numpy as np
import pytest

from smartsim import Experiment
from smartsim._core.control.manifest import LaunchedManifestBuilder
from smartsim._core.launcher.step import SbatchStep, SrunStep
from smartsim.entity import Ensemble, Model
from smartsim.entity.model import _parse_model_parameters
from smartsim.error import EntityExistsError, SSUnsupportedError
from smartsim.settings import RunSettings, SbatchSettings, SrunSettings
from smartsim.settings.mpiSettings import _BaseMPISettings
Expand Down Expand Up @@ -176,3 +178,16 @@ def test_models_batch_settings_are_ignored_in_ensemble(
step_cmd = step.step_cmds[0]
assert any("srun" in tok for tok in step_cmd) # call the model using run settings
assert not any("sbatch" in tok for tok in step_cmd) # no sbatch in sbatch


@pytest.mark.parametrize("dtype", [int, np.float32, str])
def test_good_model_params(dtype):
print(dtype(0.6))
params = {"foo": dtype(0.6)}
assert all(isinstance(val, str) for val in _parse_model_parameters(params).values())


@pytest.mark.parametrize("bad_val", [["eggs"], {"n": 5}, object])
def test_bad_model_params(bad_val):
with pytest.raises(TypeError):
_parse_model_parameters({"foo": bad_val})
2 changes: 1 addition & 1 deletion tests/test_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_model_preview_properties(test_dir, wlmutils):
assert hw_rs == hello_world_model.run_settings.exe_args[0]
assert None == hello_world_model.batch_settings
assert "port" in list(hello_world_model.params.items())[0]
assert hw_port in list(hello_world_model.params.items())[0]
assert str(hw_port) in list(hello_world_model.params.items())[0]
assert "password" in list(hello_world_model.params.items())[1]
assert hw_password in list(hello_world_model.params.items())[1]

Expand Down

0 comments on commit 72be515

Please sign in to comment.