Skip to content

Commit

Permalink
Fixed Typehint for RunSettings.colocated_db_settings (#462)
Browse files Browse the repository at this point in the history
This PR adds Python type hinting to RunSettings.colocated_db_settings.

[ reviewed by @MattToast ]
[ committed by @amandarichardsonn ]
  • Loading branch information
amandarichardsonn authored Jan 29, 2024
1 parent 50aa382 commit 092163b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 15 deletions.
2 changes: 1 addition & 1 deletion smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def _prep_entity_client_env(self, entity: Model) -> None:
# Set address to local if it's a colocated model
if entity.colocated and entity.run_settings.colocated_db_settings is not None:
db_name_colo = entity.run_settings.colocated_db_settings["db_identifier"]

assert isinstance(db_name_colo, str)
for key in address_dict:
_, db_id = unpack_db_identifier(key, "_")
if db_name_colo == db_id:
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/step/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_colocated_launch_script(self) -> str:
)
makedirs(osp.dirname(script_path), exist_ok=True)

db_settings: t.Dict[str, str] = {}
db_settings = {}
if isinstance(self.step_settings, RunSettings):
db_settings = self.step_settings.colocated_db_settings or {}

Expand Down
52 changes: 41 additions & 11 deletions smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ def colocate_db_uds(
f"Invalid name for unix socket: {unix_socket}. Must only "
"contain alphanumeric characters or . : _ - /"
)

uds_options = {
uds_options: t.Dict[str, t.Union[int, str]] = {
"unix_socket": unix_socket,
"socket_permissions": socket_permissions,
"port": 0, # This is hardcoded to 0 as recommended by redis for UDS
# This is hardcoded to 0 as recommended by redis for UDS
"port": 0,
}

common_options = {
Expand Down Expand Up @@ -332,9 +332,18 @@ def colocate_db_tcp(

def _set_colocated_db_settings(
self,
connection_options: t.Dict[str, t.Any],
common_options: t.Dict[str, t.Any],
**kwargs: t.Any,
connection_options: t.Mapping[str, t.Union[int, t.List[str], str]],
common_options: t.Dict[
str,
t.Union[
t.Union[t.Iterable[t.Union[int, t.Iterable[int]]], None],
bool,
int,
str,
None,
],
],
**kwargs: t.Union[int, None],
) -> None:
"""
Ingest the connection-specific options (UDS/TCP) and set the final settings
Expand All @@ -357,21 +366,42 @@ def _set_colocated_db_settings(
)

# TODO list which db settings can be extras
custom_pinning_ = t.cast(
t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]],
common_options.get("custom_pinning"),
)
cpus_ = t.cast(int, common_options.get("cpus"))
common_options["custom_pinning"] = self._create_pinning_string(
common_options["custom_pinning"], common_options["cpus"]
custom_pinning_, cpus_
)

colo_db_config = {}
colo_db_config: t.Dict[
str,
t.Union[
bool,
int,
str,
None,
t.List[str],
t.Iterable[t.Union[int, t.Iterable[int]]],
t.List[DBModel],
t.List[DBScript],
t.Dict[str, t.Union[int, None]],
t.Dict[str, str],
],
] = {}
colo_db_config.update(connection_options)
colo_db_config.update(common_options)
# redisai arguments for inference settings
colo_db_config["rai_args"] = {

redis_ai_temp = {
"threads_per_queue": kwargs.get("threads_per_queue", None),
"inter_op_parallelism": kwargs.get("inter_op_parallelism", None),
"intra_op_parallelism": kwargs.get("intra_op_parallelism", None),
}
# redisai arguments for inference settings
colo_db_config["rai_args"] = redis_ai_temp
colo_db_config["extra_db_args"] = {
k: str(v) for k, v in kwargs.items() if k not in colo_db_config["rai_args"]
k: str(v) for k, v in kwargs.items() if k not in redis_ai_temp
}

self._check_db_objects_colo()
Expand Down
19 changes: 18 additions & 1 deletion smartsim/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from smartsim.settings.containers import Container

from .._core.utils.helpers import expand_exe_path, fmt_dict, is_valid_cmd
from ..entity.dbobject import DBModel, DBScript
from ..log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -96,7 +97,23 @@ def __init__(
self.container = container
self._run_command = run_command
self.in_batch = False
self.colocated_db_settings: t.Optional[t.Dict[str, str]] = None
self.colocated_db_settings: t.Optional[
t.Dict[
str,
t.Union[
bool,
int,
str,
None,
t.List[str],
t.Iterable[t.Union[int, t.Iterable[int]]],
t.List[DBModel],
t.List[DBScript],
t.Dict[str, t.Union[int, None]],
t.Dict[str, str],
],
]
] = None

@property
def exe_args(self) -> t.Union[str, t.List[str]]:
Expand Down
2 changes: 1 addition & 1 deletion smartsim/settings/lsfSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def set_cpus_per_rs(self, cpus_per_rs: int) -> None:
:type cpus_per_rs: int or str
"""
if self.colocated_db_settings:
db_cpus = int(self.colocated_db_settings.get("db_cpus", 0))
db_cpus = int(t.cast(int, self.colocated_db_settings.get("db_cpus", 0)))
if not db_cpus:
raise ValueError("db_cpus must be configured on colocated_db_settings")

Expand Down

0 comments on commit 092163b

Please sign in to comment.