From b160c0592398485a12cb7139f541a687cfdf2063 Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Wed, 31 Jan 2024 09:32:20 -0800 Subject: [PATCH] Expose Typehints (#468) Add and ship `py.typed` marker to expose inline type hints. Fix type errors related to SmartRedis. [ committed by @MattToast ] [ reviewed by @al-rigazzi ] --- pyproject.toml | 18 ++++++++++++++++-- setup.cfg | 2 ++ smartsim/_core/utils/redis.py | 7 +++++-- smartsim/database/orchestrator.py | 2 +- smartsim/entity/dbobject.py | 21 +++++++++++---------- smartsim/entity/ensemble.py | 2 +- smartsim/entity/model.py | 2 +- smartsim/ml/data.py | 29 +++++++++++++++++++---------- smartsim/ml/tf/data.py | 11 ++++++++--- smartsim/py.typed | 0 10 files changed, 64 insertions(+), 30 deletions(-) create mode 100644 smartsim/py.typed diff --git a/pyproject.toml b/pyproject.toml index 60c33bee5..8c5e25a29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,18 +91,32 @@ disallow_untyped_defs = true disallow_incomplete_defs = true disallow_untyped_decorators = true +# Probably Unintended Branches/Blocks +# warn_unreachable = true + # Safety/Upgrading Mypy warn_unused_ignores = true warn_redundant_casts = true warn_unused_configs = true show_error_codes = true +# Misc Strictness Settings +strict_concatenate = false +strict_equality = true + +# Additional Error Codes +enable_error_code = [ + # "redundant-expr", + # "possibly-undefined", + # "unused-awaitable", + # "ignore-without-code", + # "mutable-override", +] + [[tool.mypy.overrides]] # Ignore packages that are not used or not typed module = [ "coloredlogs", - "smartredis", - "smartredis.error", "redis.cluster", "keras", "torch", diff --git a/setup.cfg b/setup.cfg index 43178d47a..eeac3fbe3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,5 +68,7 @@ exclude = smartredis [options.package_data] +smartsim = + py.typed smartsim._core.bin = * diff --git a/smartsim/_core/utils/redis.py b/smartsim/_core/utils/redis.py index 6c592d0f3..7d76aa1bd 100644 --- a/smartsim/_core/utils/redis.py +++ b/smartsim/_core/utils/redis.py @@ -177,6 +177,8 @@ def set_ml_model(db_model: DBModel, client: Client) -> None: outputs=db_model.outputs, ) else: + if db_model.model is None: + raise ValueError(f"No model attacted to {db_model.name}") client.set_model( name=db_model.name, model=db_model.model, @@ -203,7 +205,7 @@ def set_script(db_script: DBScript, client: Client) -> None: client.set_script_from_file( name=db_script.name, file=str(db_script.file), device=device ) - else: + elif db_script.script: if isinstance(db_script.script, str): client.set_script( name=db_script.name, script=db_script.script, device=device @@ -212,7 +214,8 @@ def set_script(db_script: DBScript, client: Client) -> None: client.set_function( name=db_script.name, function=db_script.script, device=device ) - + else: + raise ValueError(f"No script or file attached to {db_script.name}") except RedisReplyError as error: # pragma: no cover logger.error("Error while setting model on orchestrator.") raise error diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index 31bc1be6c..e586bee1c 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -573,7 +573,7 @@ def set_max_message_size(self, size: int = 1_073_741_824) -> None: """ self.set_db_conf("proto-max-bulk-len", str(size)) - def set_db_conf(self, key: str, value: t.Union[int, str]) -> None: + def set_db_conf(self, key: str, value: str) -> None: """Set any valid configuration at runtime without the need to restart the database. All configuration parameters that are set are immediately loaded by the database and diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py index bebedb12c..368864b40 100644 --- a/smartsim/entity/dbobject.py +++ b/smartsim/entity/dbobject.py @@ -33,7 +33,10 @@ __all__ = ["DBObject", "DBModel", "DBScript"] -class DBObject: +_DBObjectFuncT = t.TypeVar("_DBObjectFuncT", str, bytes) + + +class DBObject(t.Generic[_DBObjectFuncT]): """Base class for ML objects residing on DB. Should not be instantiated. """ @@ -41,14 +44,14 @@ class DBObject: def __init__( self, name: str, - func: t.Optional[str], + func: t.Optional[_DBObjectFuncT], file_path: t.Optional[str], device: t.Literal["CPU", "GPU"], devices_per_node: int, first_device: int, ) -> None: self.name = name - self.func = func + self.func: t.Optional[_DBObjectFuncT] = func self.file: t.Optional[Path] = ( None # Need to have this explicitly to check on it ) @@ -65,9 +68,7 @@ def devices(self) -> t.List[str]: @property def is_file(self) -> bool: - if self.func: - return False - return True + return not self.func @staticmethod def _check_tensor_args( @@ -153,7 +154,7 @@ def _check_devices( raise ValueError(msg) -class DBScript(DBObject): +class DBScript(DBObject[str]): def __init__( self, name: str, @@ -214,12 +215,12 @@ def __str__(self) -> str: return desc_str -class DBModel(DBObject): +class DBModel(DBObject[bytes]): def __init__( self, name: str, backend: str, - model: t.Optional[str] = None, + model: t.Optional[bytes] = None, model_file: t.Optional[str] = None, device: t.Literal["CPU", "GPU"] = "CPU", devices_per_node: int = 1, @@ -276,7 +277,7 @@ def __init__( self.inputs, self.outputs = self._check_tensor_args(inputs, outputs) @property - def model(self) -> t.Union[str, None]: + def model(self) -> t.Optional[bytes]: return self.func def __str__(self) -> str: diff --git a/smartsim/entity/ensemble.py b/smartsim/entity/ensemble.py index 28ada31de..74bdcfba4 100644 --- a/smartsim/entity/ensemble.py +++ b/smartsim/entity/ensemble.py @@ -357,7 +357,7 @@ def add_ml_model( self, name: str, backend: str, - model: t.Optional[str] = None, + model: t.Optional[bytes] = None, model_path: t.Optional[str] = None, device: t.Literal["CPU", "GPU"] = "CPU", devices_per_node: int = 1, diff --git a/smartsim/entity/model.py b/smartsim/entity/model.py index 01d9173f9..2b380a32b 100644 --- a/smartsim/entity/model.py +++ b/smartsim/entity/model.py @@ -485,7 +485,7 @@ def add_ml_model( self, name: str, backend: str, - model: t.Optional[str] = None, + model: t.Optional[bytes] = None, model_path: t.Optional[str] = None, device: t.Literal["CPU", "GPU"] = "CPU", devices_per_node: int = 1, diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index 3dfca9f0c..4a28c7bb9 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -35,6 +35,9 @@ from ..error import SSInternalError from ..log import get_logger +if t.TYPE_CHECKING: + import numpy.typing as npt + logger = get_logger(__name__) @@ -118,7 +121,7 @@ def download(self, client: Client) -> None: if "target_name" in field_names: self.target_name = info_ds.get_meta_strings("target_name")[0] if "num_classes" in field_names: - self.num_classes = info_ds.get_meta_scalars("num_classes")[0] + self.num_classes = int(info_ds.get_meta_scalars("num_classes")[0]) def __repr__(self) -> str: strings = ["DataInfo object"] @@ -311,8 +314,8 @@ def __init__( self.address = address self.cluster = cluster self.verbose = verbose - self.samples = None - self.targets = None + self.samples: t.Optional["npt.NDArray[t.Any]"] = None + self.targets: t.Optional["npt.NDArray[t.Any]"] = None self.num_samples = 0 self.indices = np.arange(0) self.shuffle = shuffle @@ -460,14 +463,20 @@ def _add_samples(self, indices: t.List[int]) -> None: if self.samples is not None: for dataset in datasets: self.samples = np.concatenate( - (self.samples, dataset.get_tensor(self.sample_name)) + ( + t.cast("npt.NDArray[t.Any]", self.samples), + dataset.get_tensor(self.sample_name), + ) ) if self.need_targets: self.targets = np.concatenate( - (self.targets, dataset.get_tensor(self.target_name)) + ( + t.cast("npt.NDArray[t.Any]", self.targets), + dataset.get_tensor(self.target_name), + ) ) - self.num_samples = self.samples.shape[0] + self.num_samples = t.cast("npt.NDArray[t.Any]", self.samples).shape[0] self.indices = np.arange(self.num_samples) self.log(f"New dataset size: {self.num_samples}, batches: {len(self)}") @@ -496,8 +505,8 @@ def update_data(self) -> None: np.random.shuffle(self.indices) def _data_generation( - self, indices: np.ndarray # type: ignore[type-arg] - ) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] + self, indices: "npt.NDArray[t.Any]" + ) -> t.Tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: # Initialization if self.samples is None: raise ValueError("Samples have not been initialized") @@ -505,10 +514,10 @@ def _data_generation( xval = self.samples[indices] if self.need_targets: - yval = self.targets[indices] + yval = t.cast("npt.NDArray[t.Any]", self.targets)[indices] elif self.autoencoding: yval = xval else: - return xval + return xval # type: ignore[no-any-return] return xval, yval diff --git a/smartsim/ml/tf/data.py b/smartsim/ml/tf/data.py index ae0b9aadd..c5e93261c 100644 --- a/smartsim/ml/tf/data.py +++ b/smartsim/ml/tf/data.py @@ -31,6 +31,9 @@ from smartsim.ml import DataDownloader +if t.TYPE_CHECKING: + import numpy.typing as npt + class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence): def __getitem__( @@ -60,7 +63,9 @@ def on_epoch_end(self) -> None: if self.shuffle: np.random.shuffle(self.indices) - def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] + def _data_generation( + self, indices: "npt.NDArray[t.Any]" + ) -> t.Tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: # Initialization if self.samples is None: raise ValueError("No samples loaded for data generation") @@ -68,13 +73,13 @@ def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarra xval = self.samples[indices] if self.need_targets: - yval = self.targets[indices] + yval = t.cast("npt.NDArray[t.Any]", self.targets)[indices] if self.num_classes is not None: yval = keras.utils.to_categorical(yval, num_classes=self.num_classes) elif self.autoencoding: yval = xval else: - return xval + return xval # type: ignore[no-any-return] return xval, yval diff --git a/smartsim/py.typed b/smartsim/py.typed new file mode 100644 index 000000000..e69de29bb