From 092163b80f1efbb439398c0c2fb72ac5ec0ff4fd Mon Sep 17 00:00:00 2001 From: amandarichardsonn <30413257+amandarichardsonn@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:24:20 -0600 Subject: [PATCH] Fixed Typehint for RunSettings.colocated_db_settings (#462) This PR adds Python type hinting to RunSettings.colocated_db_settings. [ reviewed by @MattToast ] [ committed by @amandarichardsonn ] --- smartsim/_core/control/controller.py | 2 +- smartsim/_core/launcher/step/step.py | 2 +- smartsim/entity/model.py | 52 ++++++++++++++++++++++------ smartsim/settings/base.py | 19 +++++++++- smartsim/settings/lsfSettings.py | 2 +- 5 files changed, 62 insertions(+), 15 deletions(-) diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 3204600d6..a79f0e347 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -624,7 +624,7 @@ def _prep_entity_client_env(self, entity: Model) -> None: # Set address to local if it's a colocated model if entity.colocated and entity.run_settings.colocated_db_settings is not None: db_name_colo = entity.run_settings.colocated_db_settings["db_identifier"] - + assert isinstance(db_name_colo, str) for key in address_dict: _, db_id = unpack_db_identifier(key, "_") if db_name_colo == db_id: diff --git a/smartsim/_core/launcher/step/step.py b/smartsim/_core/launcher/step/step.py index ebbdd074e..74b0f4fdf 100644 --- a/smartsim/_core/launcher/step/step.py +++ b/smartsim/_core/launcher/step/step.py @@ -93,7 +93,7 @@ def get_colocated_launch_script(self) -> str: ) makedirs(osp.dirname(script_path), exist_ok=True) - db_settings: t.Dict[str, str] = {} + db_settings = {} if isinstance(self.step_settings, RunSettings): db_settings = self.step_settings.colocated_db_settings or {} diff --git a/smartsim/entity/model.py b/smartsim/entity/model.py index 6b97cbf2e..01d9173f9 100644 --- a/smartsim/entity/model.py +++ b/smartsim/entity/model.py @@ -258,11 +258,11 @@ def colocate_db_uds( f"Invalid name for unix socket: {unix_socket}. Must only " "contain alphanumeric characters or . : _ - /" ) - - uds_options = { + uds_options: t.Dict[str, t.Union[int, str]] = { "unix_socket": unix_socket, "socket_permissions": socket_permissions, - "port": 0, # This is hardcoded to 0 as recommended by redis for UDS + # This is hardcoded to 0 as recommended by redis for UDS + "port": 0, } common_options = { @@ -332,9 +332,18 @@ def colocate_db_tcp( def _set_colocated_db_settings( self, - connection_options: t.Dict[str, t.Any], - common_options: t.Dict[str, t.Any], - **kwargs: t.Any, + connection_options: t.Mapping[str, t.Union[int, t.List[str], str]], + common_options: t.Dict[ + str, + t.Union[ + t.Union[t.Iterable[t.Union[int, t.Iterable[int]]], None], + bool, + int, + str, + None, + ], + ], + **kwargs: t.Union[int, None], ) -> None: """ Ingest the connection-specific options (UDS/TCP) and set the final settings @@ -357,21 +366,42 @@ def _set_colocated_db_settings( ) # TODO list which db settings can be extras + custom_pinning_ = t.cast( + t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], + common_options.get("custom_pinning"), + ) + cpus_ = t.cast(int, common_options.get("cpus")) common_options["custom_pinning"] = self._create_pinning_string( - common_options["custom_pinning"], common_options["cpus"] + custom_pinning_, cpus_ ) - colo_db_config = {} + colo_db_config: t.Dict[ + str, + t.Union[ + bool, + int, + str, + None, + t.List[str], + t.Iterable[t.Union[int, t.Iterable[int]]], + t.List[DBModel], + t.List[DBScript], + t.Dict[str, t.Union[int, None]], + t.Dict[str, str], + ], + ] = {} colo_db_config.update(connection_options) colo_db_config.update(common_options) - # redisai arguments for inference settings - colo_db_config["rai_args"] = { + + redis_ai_temp = { "threads_per_queue": kwargs.get("threads_per_queue", None), "inter_op_parallelism": kwargs.get("inter_op_parallelism", None), "intra_op_parallelism": kwargs.get("intra_op_parallelism", None), } + # redisai arguments for inference settings + colo_db_config["rai_args"] = redis_ai_temp colo_db_config["extra_db_args"] = { - k: str(v) for k, v in kwargs.items() if k not in colo_db_config["rai_args"] + k: str(v) for k, v in kwargs.items() if k not in redis_ai_temp } self._check_db_objects_colo() diff --git a/smartsim/settings/base.py b/smartsim/settings/base.py index a6df4eed4..c4eefc780 100644 --- a/smartsim/settings/base.py +++ b/smartsim/settings/base.py @@ -30,6 +30,7 @@ from smartsim.settings.containers import Container from .._core.utils.helpers import expand_exe_path, fmt_dict, is_valid_cmd +from ..entity.dbobject import DBModel, DBScript from ..log import get_logger logger = get_logger(__name__) @@ -96,7 +97,23 @@ def __init__( self.container = container self._run_command = run_command self.in_batch = False - self.colocated_db_settings: t.Optional[t.Dict[str, str]] = None + self.colocated_db_settings: t.Optional[ + t.Dict[ + str, + t.Union[ + bool, + int, + str, + None, + t.List[str], + t.Iterable[t.Union[int, t.Iterable[int]]], + t.List[DBModel], + t.List[DBScript], + t.Dict[str, t.Union[int, None]], + t.Dict[str, str], + ], + ] + ] = None @property def exe_args(self) -> t.Union[str, t.List[str]]: diff --git a/smartsim/settings/lsfSettings.py b/smartsim/settings/lsfSettings.py index 47fe91802..11785f7b7 100644 --- a/smartsim/settings/lsfSettings.py +++ b/smartsim/settings/lsfSettings.py @@ -97,7 +97,7 @@ def set_cpus_per_rs(self, cpus_per_rs: int) -> None: :type cpus_per_rs: int or str """ if self.colocated_db_settings: - db_cpus = int(self.colocated_db_settings.get("db_cpus", 0)) + db_cpus = int(t.cast(int, self.colocated_db_settings.get("db_cpus", 0))) if not db_cpus: raise ValueError("db_cpus must be configured on colocated_db_settings")