diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py
index 3b0905021..a7b1471bf 100644
--- a/smartsim/_core/config/config.py
+++ b/smartsim/_core/config/config.py
@@ -143,6 +143,14 @@ def database_cli(self) -> str:
                 "Specified Redis binary at REDIS_CLI_PATH could not be used"
             ) from e
 
+    @property
+    def database_file_parse_trials(self) -> int:
+        return int(os.getenv("SMARTSIM_DB_FILE_PARSE_TRIALS", "10"))
+
+    @property
+    def database_file_parse_interval(self) -> int:
+        return int(os.getenv("SMARTSIM_DB_FILE_PARSE_INTERVAL", "2"))
+
     @property
     def log_level(self) -> str:
         return os.environ.get("SMARTSIM_LOG_LEVEL", "info")
diff --git a/smartsim/_core/entrypoints/redis.py b/smartsim/_core/entrypoints/redis.py
index 782b4c583..7262a5996 100644
--- a/smartsim/_core/entrypoints/redis.py
+++ b/smartsim/_core/entrypoints/redis.py
@@ -25,16 +25,20 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import argparse
+import json
 import os
-import psutil
 import signal
+import textwrap
 import typing as t
+from subprocess import PIPE, STDOUT
+from types import FrameType
+
+import psutil
 
 from smartsim._core.utils.network import current_ip
+from smartsim.entity.dbnode import LaunchedShardData
 from smartsim.error import SSInternalError
 from smartsim.log import get_logger
-from subprocess import PIPE, STDOUT
-from types import FrameType
 
 logger = get_logger(__name__)
 
@@ -42,7 +46,7 @@
 Redis/KeyDB entrypoint script
 """
 
-DBPID = None
+DBPID: t.Optional[int] = None
 
 # kill is not catchable
 SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT]
@@ -54,31 +58,64 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None:
     cleanup()
 
 
-def build_bind_args(ip_addresses: t.List[str]) -> t.List[str]:
-    bind_arg = f"--bind {' '.join(ip_addresses)}"
-    # pin source address to avoid random selection by Redis
-    bind_src_arg = f"--bind-source-addr {ip_addresses[0]}"
-    return [bind_arg, bind_src_arg]
+def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]:
+    return (
+        "--bind",
+        source_addr,
+        *addrs,
+        # pin source address to avoid random selection by Redis
+        "--bind-source-addr",
+        source_addr,
+    )
 
 
-def print_summary(cmd: t.List[str], ip_address: str, network_interface: str) -> None:
-    print("-" * 10, " Running Command ", "-" * 10, "\n", flush=True)
-    print(f"COMMAND: {' '.join(cmd)}\n", flush=True)
-    print(f"IPADDRESS: {ip_address}\n", flush=True)
-    print(f"NETWORK: {network_interface}\n", flush=True)
-    print("-" * 30, "\n\n", flush=True)
-    print("-" * 10, " Output ", "-" * 10, "\n\n", flush=True)
+def build_cluster_args(shard_data: LaunchedShardData) -> t.Tuple[str, ...]:
+    if cluster_conf_file := shard_data.cluster_conf_file:
+        return ("--cluster-enabled", "yes", "--cluster-config-file", cluster_conf_file)
+    return ()
 
 
-def main(network_interface: str, command: t.List[str]) -> None:
+def print_summary(
+    cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData
+) -> None:
+    print(
+        textwrap.dedent(
+            f"""\
+            ----------- Running Command ----------
+            COMMAND: {' '.join(cmd)}
+            IPADDRESS: {shard_data.hostname}
+            NETWORK: {network_interface}
+            SMARTSIM_ORC_SHARD_INFO: {json.dumps(shard_data.to_dict())}
+            --------------------------------------
+
+            --------------- Output ---------------
+
+            """
+        ),
+        flush=True,
+    )
+
+
+def main(args: argparse.Namespace) -> int:
     global DBPID  # pylint: disable=global-statement
 
-    try:
-        ip_addresses = [current_ip(net_if) for net_if in network_interface.split(",")]
-        cmd = command + build_bind_args(ip_addresses)
+    src_addr, *bind_addrs = (current_ip(net_if) for net_if in args.ifname.split(","))
+    shard_data = LaunchedShardData(
+        name=args.name, hostname=src_addr, port=args.port, cluster=args.cluster
+    )
 
-        print_summary(cmd, ip_addresses[0], network_interface)
+    cmd = [
+        args.orc_exe,
+        args.conf_file,
+        *args.rai_module,
+        "--port",
+        str(args.port),
+        *build_cluster_args(shard_data),
+        *build_bind_args(src_addr, *bind_addrs),
+    ]
+    print_summary(cmd, args.ifname, shard_data)
 
+    try:
         process = psutil.Popen(cmd, stdout=PIPE, stderr=STDOUT)
         DBPID = process.pid
 
@@ -87,18 +124,17 @@ def main(network_interface: str, command: t.List[str]) -> None:
     except Exception as e:
         cleanup()
         raise SSInternalError("Database process starter raised an exception") from e
+    return 0
 
 
 def cleanup() -> None:
+    logger.debug("Cleaning up database instance")
     try:
-        logger.debug("Cleaning up database instance")
         # attempt to stop the database process
-        db_proc = psutil.Process(DBPID)
-        db_proc.terminate()
-
+        if DBPID is not None:
+            psutil.Process(DBPID).terminate()
     except psutil.NoSuchProcess:
         logger.warning("Couldn't find database process to kill.")
-
     except OSError as e:
         logger.warning(f"Failed to clean up database gracefully: {str(e)}")
 
@@ -110,15 +146,47 @@ def cleanup() -> None:
         prefix_chars="+", description="SmartSim Process Launcher"
     )
     parser.add_argument(
-        "+ifname", type=str, help="Network Interface name", default="lo"
+        "+orc-exe", type=str, help="Path to the orchestrator executable", required=True
+    )
+    parser.add_argument(
+        "+conf-file",
+        type=str,
+        help="Path to the orchestrator configuration file",
+        required=True,
+    )
+    parser.add_argument(
+        "+rai-module",
+        nargs="+",
+        type=str,
+        help=(
+            "Command for the orchestrator to load the Redis AI module with "
+            "symbols separated by whitespace"
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "+name", type=str, help="Name to identify the shard", required=True
+    )
+    parser.add_argument(
+        "+port",
+        type=int,
+        help="The port on which to launch the shard of the orchestrator",
+        required=True,
+    )
+    parser.add_argument(
+        "+ifname", type=str, help="Network Interface name", required=True
+    )
+    parser.add_argument(
+        "+cluster",
+        action="store_true",
+        help="Specify if this orchestrator shard is part of a cluster",
     )
-    parser.add_argument("+command", nargs="+", help="Command to run")
-    args = parser.parse_args()
+    args_ = parser.parse_args()
 
     # make sure to register the cleanup before the start
-    # the proecss so our signaller will be able to stop
+    # the process so our signaller will be able to stop
    # the database process.
     for sig in SIGNALS:
         signal.signal(sig, handle_signal)
 
-    main(args.ifname, args.command)
+    raise SystemExit(main(args_))
diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py
index a5526be06..507ebdbe9 100644
--- a/smartsim/database/orchestrator.py
+++ b/smartsim/database/orchestrator.py
@@ -24,13 +24,12 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import itertools
-import psutil
 import sys
 import typing as t
-
 from os import getcwd, getenv
 from shlex import split as sh_split
 
+import psutil
 from smartredis import Client
 from smartredis.error import RedisReplyError
 
@@ -41,7 +40,6 @@
 from ..entity import DBNode, EntityList
 from ..error import SmartSimError, SSConfigError, SSUnsupportedError
 from ..log import get_logger
-from ..settings.base import BatchSettings, RunSettings
 from ..settings import (
     AprunSettings,
     BsubBatchSettings,
@@ -55,6 +53,7 @@
     SbatchSettings,
     SrunSettings,
 )
+from ..settings.base import BatchSettings, RunSettings
 from ..settings.settings import create_batch_settings, create_run_settings
 from ..wlm import detect_launcher
 
@@ -161,6 +160,10 @@ def __init__(
         time: t.Optional[str] = None,
         alloc: t.Optional[str] = None,
         single_cmd: bool = False,
+        *,
+        threads_per_queue: t.Optional[int] = None,
+        inter_op_threads: t.Optional[int] = None,
+        intra_op_threads: t.Optional[int] = None,
         **kwargs: t.Any,
     ) -> None:
         """Initialize an Orchestrator reference for local launch
@@ -198,12 +201,13 @@ def __init__(
             interface = [interface]
         self._interfaces = interface
         self._check_network_interface()
-        self.queue_threads = kwargs.get("threads_per_queue", None)
-        self.inter_threads = kwargs.get("inter_op_threads", None)
-        self.intra_threads = kwargs.get("intra_op_threads", None)
+        self.queue_threads = threads_per_queue
+        self.inter_threads = inter_op_threads
+        self.intra_threads = intra_op_threads
+
         if self.launcher == "lsf":
-            gpus_per_shard = kwargs.pop("gpus_per_shard", 0)
-            cpus_per_shard = kwargs.pop("cpus_per_shard", 4)
+            gpus_per_shard = int(kwargs.pop("gpus_per_shard", 0))
+            cpus_per_shard = int(kwargs.pop("cpus_per_shard", 4))
         else:
             gpus_per_shard = None
             cpus_per_shard = None
@@ -221,6 +225,9 @@ def __init__(
             single_cmd=single_cmd,
             gpus_per_shard=gpus_per_shard,
             cpus_per_shard=cpus_per_shard,
+            threads_per_queue=threads_per_queue,
+            inter_op_threads=inter_op_threads,
+            intra_op_threads=intra_op_threads,
             **kwargs,
         )
 
@@ -234,11 +241,12 @@ def __init__(
             self._redis_conf  # pylint: disable=W0104
             CONFIG.database_cli  # pylint: disable=W0104
         except SSConfigError as e:
-            msg = "SmartSim not installed with pre-built extensions (Redis)\n"
-            msg += "Use the `smart` cli tool to install needed extensions\n"
-            msg += "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n"
-            msg += "See documentation for more information"
-            raise SSConfigError(msg) from e
+            raise SSConfigError(
+                "SmartSim not installed with pre-built extensions (Redis)\n"
+                "Use the `smart` cli tool to install needed extensions\n"
+                "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n"
+                "See documentation for more information"
+            ) from e
 
         if launcher != "local":
             self.batch_settings = self._build_batch_settings(
@@ -269,7 +277,19 @@ def num_shards(self) -> int:
         :returns: num_shards
         :rtype: int
         """
-        return self.db_nodes
+        return sum(node.num_shards for node in self.entities)
+
+    @property
+    def db_nodes(self) -> int:
+        """Read-only property for the number of nodes an ``Orchestrator`` is
+        launched across. Notice that SmartSim currently assumes that each shard
+        will be launched on its own node. Therefore this property is currently
+        an alias to the ``num_shards`` attribute.
+
+        :returns: Number of database nodes
+        :rtype: int
+        """
+        return self.num_shards
 
     @property
     def hosts(self) -> t.List[str]:
@@ -306,10 +326,10 @@ def get_address(self) -> t.List[str]:
         return self._get_address()
 
     def _get_address(self) -> t.List[str]:
-        addresses: t.List[str] = []
-        for ip_address, port in itertools.product(self._hosts, self.ports):
-            addresses.append(":".join((ip_address, str(port))))
-        return addresses
+        return [
+            f"{host}:{port}"
+            for host, port in itertools.product(self._hosts, self.ports)
+        ]
 
     def is_active(self) -> bool:
         """Check if the database is active
@@ -323,20 +343,21 @@ def is_active(self) -> bool:
         return db_is_active(self._hosts, self.ports, self.num_shards)
 
     @property
-    def _rai_module(self) -> str:
+    def _rai_module(self) -> t.Tuple[str, ...]:
         """Get the RedisAI module from third-party installations
 
-        :return: path to module or "" if not found
-        :rtype: str
+        :return: Tuple of args to pass to the orchestrator exe
+            to load and configure the RedisAI module
+        :rtype: tuple[str, ...]
         """
         module = ["--loadmodule", CONFIG.redisai]
         if self.queue_threads:
-            module.append(f"THREADS_PER_QUEUE {self.queue_threads}")
+            module.extend(("THREADS_PER_QUEUE", str(self.queue_threads)))
         if self.inter_threads:
-            module.append(f"INTER_OP_PARALLELISM {self.inter_threads}")
+            module.extend(("INTER_OP_PARALLELISM", str(self.inter_threads)))
         if self.intra_threads:
-            module.append(f"INTRA_OP_PARALLELISM {self.intra_threads}")
-        return " ".join(module)
+            module.extend(("INTRA_OP_PARALLELISM", str(self.intra_threads)))
+        return tuple(module)
 
     @property
     def _redis_exe(self) -> str:
@@ -407,9 +428,11 @@ def set_hosts(self, host_list: t.List[str]) -> None:
         if self.launcher == "lsf":
             for db in self.entities:
                 db.set_hosts(host_list)
-        elif (self.launcher == "pals"
-            and isinstance(self.entities[0].run_settings, PalsMpiexecSettings)
-            and self.entities[0].is_mpmd):
+        elif (
+            self.launcher == "pals"
+            and isinstance(self.entities[0].run_settings, PalsMpiexecSettings)
+            and self.entities[0].is_mpmd
+        ):
             # In this case, --hosts is a global option, we only set it to the
             # first run command
             self.entities[0].run_settings.set_hostlist(host_list)
@@ -485,9 +508,9 @@ def enable_checkpoints(self, frequency: int) -> None:
         :param frequency: the given number of seconds before the DB saves
         :type frequency: int
         """
-        self.set_db_conf("save", str(frequency) + " 1")
+        self.set_db_conf("save", f"{frequency} 1")
 
-    def set_max_memory(self, mem: int) -> None:
+    def set_max_memory(self, mem: str) -> None:
         """Sets the max memory configuration. By default there is no memory limit.
         Setting max memory to zero also results in no memory limit. Once a limit is
         surpassed, keys will be removed according to the eviction strategy. The
@@ -590,10 +613,14 @@ def _build_batch_settings(
         batch: bool,
         account: str,
         time: str,
+        *,
+        launcher: t.Optional[str] = None,
         **kwargs: t.Any,
     ) -> t.Optional[BatchSettings]:
         batch_settings = None
-        launcher = kwargs.pop("launcher")
+
+        if launcher is None:
+            raise ValueError("Expected param `launcher` of type `str`")
 
         # enter this conditional if user has not specified an allocation to run
         # on or if user specified batch=False (alloc will be found through env)
@@ -605,11 +632,16 @@ def _build_batch_settings(
         return batch_settings
 
     def _build_run_settings(
-        self, exe: str, exe_args: t.List[t.List[str]], **kwargs: t.Any
+        self,
+        exe: str,
+        exe_args: t.List[t.List[str]],
+        *,
+        run_args: t.Optional[t.Dict[str, t.Any]] = None,
+        db_nodes: int = 1,
+        single_cmd: bool = True,
+        **kwargs: t.Any,
     ) -> RunSettings:
-        run_args = kwargs.pop("run_args", {})
-        db_nodes = kwargs.get("db_nodes", 1)
-        single_cmd = kwargs.get("single_cmd", True)
+        run_args = {} if run_args is None else run_args
         mpmd_nodes = single_cmd and db_nodes > 1
 
         if mpmd_nodes:
@@ -638,20 +670,29 @@ def _build_run_settings(
         if self.launcher != "local":
             run_settings.set_tasks_per_node(1)
 
-        # Put it back in case it is needed again
-        kwargs["run_args"] = run_args
-
         return run_settings
 
     @staticmethod
     def _build_run_settings_lsf(
-        exe: str, exe_args: t.List[t.List[str]], **kwargs: t.Any
+        exe: str,
+        exe_args: t.List[t.List[str]],
+        *,
+        run_args: t.Optional[t.Dict[str, t.Any]] = None,
+        cpus_per_shard: t.Optional[int] = None,
+        gpus_per_shard: t.Optional[int] = None,
+        **_kwargs: t.Any  # Needed to ensure no API break; we do not want to
+                          # introduce that possibility, even if this method is
+                          # protected, without running the test suite.
+                          # TODO: Test against an LSF system before merge!
     ) -> t.Optional[JsrunSettings]:
-        run_args = kwargs.pop("run_args", {})
-        cpus_per_shard = kwargs.get("cpus_per_shard", None)
-        gpus_per_shard = kwargs.get("gpus_per_shard", None)
+        run_args = {} if run_args is None else run_args
         erf_rs: t.Optional[JsrunSettings] = None
 
+        if cpus_per_shard is None:
+            raise ValueError("Expected an integer number of cpus per shard")
+        if gpus_per_shard is None:
+            raise ValueError("Expected an integer number of gpus per shard")
+
         # We always run the DB on cpus 0:cpus_per_shard-1
         # and gpus 0:gpus_per_shard-1
         for shard_id, args in enumerate(exe_args):
@@ -672,9 +713,9 @@ def _build_run_settings_lsf(
             }
 
             if gpus_per_shard > 1:  # pragma: no-cover
-                erf_sets["gpu"] = "{" + f"0-{gpus_per_shard-1}" + "}"
+                erf_sets["gpu"] = f"{{0-{gpus_per_shard-1}}}"
             elif gpus_per_shard > 0:
-                erf_sets["gpu"] = "{" + str(0) + "}"
+                erf_sets["gpu"] = "{0}"
 
             run_settings.set_erf_sets(erf_sets)
 
@@ -684,31 +725,38 @@ def _build_run_settings_lsf(
                 erf_rs.make_mpmd(run_settings)
 
-        kwargs["run_args"] = run_args
-
         return erf_rs
 
-    def _initialize_entities(self, **kwargs: t.Any) -> None:
-        self.db_nodes = int(kwargs.get("db_nodes", 1))
-        single_cmd = kwargs.get("single_cmd", True)
-
-        if int(self.db_nodes) == 2:
+    # Old pylint from TF 2.6.x does not understand that this argument list is
+    # equivalent to `(self, **kwargs)`
+    #
+    # pylint: disable-next=arguments-differ
+    def _initialize_entities(
+        self,
+        *,
+        db_nodes: int = 1,
+        single_cmd: bool = True,
+        port: int = 6379,
+        **kwargs: t.Any,
+    ) -> None:
+        db_nodes = int(db_nodes)
+        if db_nodes == 2:
             raise SSUnsupportedError("Orchestrator does not support clusters of size 2")
 
-        if self.launcher == "local" and self.db_nodes > 1:
+        if self.launcher == "local" and db_nodes > 1:
             raise ValueError(
                 "Local Orchestrator does not support multiple database shards"
             )
 
-        mpmd_nodes = (single_cmd and self.db_nodes > 1) or self.launcher == "lsf"
+        mpmd_nodes = (single_cmd and db_nodes > 1) or self.launcher == "lsf"
 
         if mpmd_nodes:
-            self._initialize_entities_mpmd(**kwargs)
+            self._initialize_entities_mpmd(
+                db_nodes=db_nodes, single_cmd=single_cmd, port=port, **kwargs
+            )
         else:
-            port = kwargs.get("port", 6379)
-            cluster = not bool(self.db_nodes < 3)
+            cluster = db_nodes >= 3
 
-            for db_id in range(self.db_nodes):
+            for db_id in range(db_nodes):
                 db_node_name = "_".join((self.name, str(db_id)))
 
                 # create the exe_args list for launching multiple databases
@@ -720,7 +768,7 @@ def _initialize_entities(self, **kwargs: t.Any) -> None:
                 # if only launching 1 db per command, we don't need a
                 # list of exe args lists
                 run_settings = self._build_run_settings(
-                    sys.executable, [start_script_args], **kwargs
+                    sys.executable, [start_script_args], port=port, **kwargs
                 )
 
                 node = DBNode(
@@ -734,13 +782,14 @@ def _initialize_entities(self, **kwargs: t.Any) -> None:
 
         self.ports = [port]
 
-    def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None:
-        port = kwargs.get("port", 6379)
-        cluster = not bool(self.db_nodes < 3)
+    def _initialize_entities_mpmd(
+        self, *, db_nodes: int = 1, port: int = 6379, **kwargs: t.Any
+    ) -> None:
+        cluster = db_nodes >= 3
 
         exe_args_mpmd: t.List[t.List[str]] = []
 
-        for db_id in range(self.db_nodes):
+        for db_id in range(db_nodes):
             db_shard_name = "_".join((self.name, str(db_id)))
             # create the exe_args list for launching multiple databases
             # per node. also collect port range for dbnode
@@ -754,15 +803,14 @@ def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None:
 
         if self.launcher == "lsf":
             run_settings = self._build_run_settings_lsf(
-                sys.executable, exe_args_mpmd, **kwargs
+                sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs
             )
             output_files = [
-                "_".join((self.name, str(db_id))) + ".out"
-                for db_id in range(self.db_nodes)
+                f"{self.name}_{db_id}.out" for db_id in range(db_nodes)
             ]
         else:
             run_settings = self._build_run_settings(
-                sys.executable, exe_args_mpmd, **kwargs
+                sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs
             )
             output_files = [self.name + ".out"]
 
@@ -770,37 +818,27 @@ def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None:
             raise ValueError(f"Could not build run settings for {self.launcher}")
 
         node = DBNode(self.name, self.path, run_settings, [port], output_files)
-        node.is_mpmd = True
-        node.num_shards = self.db_nodes
         self.entities.append(node)
 
         self.ports = [port]
 
-    @staticmethod
-    def _get_cluster_args(name: str, port: int) -> t.List[str]:
-        """Create the arguments necessary for cluster creation"""
-        cluster_conf = "".join(("nodes-", name, "-", str(port), ".conf"))
-        db_args = ["--cluster-enabled yes", "--cluster-config-file", cluster_conf]
-        return db_args
-
     def _get_start_script_args(
         self, name: str, port: int, cluster: bool
     ) -> t.List[str]:
-        start_script_args = [
+        cmd = [
             "-m",
             "smartsim._core.entrypoints.redis",  # entrypoint
-            "+ifname=" + ",".join(self._interfaces),  # pass interface to start script
-            "+command",  # command flag for argparser
-            self._redis_exe,  # redis-server
-            self._redis_conf,  # redis.conf file
-            self._rai_module,  # redisai.so
-            "--port",  # redis port
-            str(port),  # port number
+            f"+orc-exe={self._redis_exe}",  # redis-server
+            f"+conf-file={self._redis_conf}",  # redis.conf file
+            "+rai-module",  # load redisai.so
+            *self._rai_module,
+            f"+name={name}",  # name of node
+            f"+port={port}",  # redis port
f"+ifname={','.join(self._interfaces)}", # pass interface to start script ] if cluster: - start_script_args += self._get_cluster_args(name, port) - - return start_script_args + cmd.append("+cluster") # is the shard part of a cluster + return cmd def _get_db_hosts(self) -> t.List[str]: hosts = [] diff --git a/smartsim/entity/dbnode.py b/smartsim/entity/dbnode.py index bc04df834..b3740301f 100644 --- a/smartsim/entity/dbnode.py +++ b/smartsim/entity/dbnode.py @@ -24,16 +24,20 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import fileinput +import itertools +import json import os import os.path as osp import time import typing as t +from dataclasses import dataclass +from .._core.config import CONFIG from ..error import SmartSimError from ..log import get_logger -from .entity import SmartSimEntity from ..settings.base import RunSettings - +from .entity import SmartSimEntity logger = get_logger(__name__) @@ -56,11 +60,8 @@ def __init__( output_files: t.List[str], ) -> None: """Initialize a database node within an orchestrator.""" - self.ports = ports - self._host: t.Optional[str] = None super().__init__(name, path, run_settings) - self._mpmd = False - self._num_shards: int = 0 + self.ports = ports self._hosts: t.Optional[t.List[str]] = None if not output_files: @@ -73,17 +74,20 @@ def __init__( @property def num_shards(self) -> int: - return self._num_shards - - @num_shards.setter - def num_shards(self, value: int) -> None: - self._num_shards = value + try: + return len(self.run_settings.mpmd) + 1 # type: ignore[attr-defined] + except AttributeError: + return 1 @property def host(self) -> str: - if not self._host: - self._host = self._parse_db_host() - return self._host + try: + (host,) = self.hosts + except ValueError: + raise ValueError( + f"Multiple hosts detected for this DB Node: {', '.join(self.hosts)}" + ) from None + return host @property def hosts(self) -> t.List[str]: @@ -93,14 +97,10 @@ def hosts(self) -> t.List[str]: @property def is_mpmd(self) -> bool: - return self._mpmd - - @is_mpmd.setter - def is_mpmd(self, value: bool) -> None: - self._mpmd = value - - def set_host(self, host: str) -> None: - self._host = str(host) + try: + return bool(self.run_settings.mpmd) # type: ignore[attr-defined] + except AttributeError: + return False def set_hosts(self, hosts: t.List[str]) -> None: self._hosts = [str(host) for host in hosts] @@ -112,42 +112,27 @@ def remove_stale_dbnode_files(self) -> None: """ for port in self.ports: - if not self._mpmd: - conf_file = osp.join(self.path, self._get_cluster_conf_filename(port)) + for conf_file in ( + osp.join(self.path, filename) + for filename in self._get_cluster_conf_filenames(port) + ): if osp.exists(conf_file): os.remove(conf_file) - else: # cov-lsf - conf_files = [ - osp.join(self.path, filename) - for filename in self._get_cluster_conf_filenames(port) - ] - for conf_file in conf_files: - if osp.exists(conf_file): - os.remove(conf_file) for file_ending in [".err", ".out", ".mpmd"]: file_name = osp.join(self.path, self.name + file_ending) if osp.exists(file_name): os.remove(file_name) - if self._mpmd: + + if self.is_mpmd: for file_ending in [".err", ".out"]: - for shard_id in range(self._num_shards): + for shard_id in range(self.num_shards): file_name = osp.join( self.path, self.name + "_" + str(shard_id) + file_ending ) if osp.exists(file_name): os.remove(file_name) - def _get_cluster_conf_filename(self, port: int) -> str: - """Returns 
the .conf file name for the given port number - - :param port: port number - :type port: int - :return: the dbnode configuration file name - :rtype: str - """ - return "".join(("nodes-", self.name, "-", str(port), ".conf")) - def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: # cov-lsf """Returns the .conf file name for the given port number @@ -158,108 +143,98 @@ def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: # cov-lsf :return: the dbnode configuration file name :rtype: str """ + if self.num_shards == 1: + return [f"nodes-{self.name}-{port}.conf"] return [ - "".join(("nodes-", self.name + f"_{shard_id}", "-", str(port), ".conf")) - for shard_id in range(self._num_shards) + f"nodes-{self.name}_{shard_id}-{port}.conf" + for shard_id in range(self.num_shards) ] @staticmethod - def _parse_ips(filepath: str, num_ips: t.Optional[int] = None) -> t.List[str]: - ips = [] - with open(filepath, "r", encoding="utf-8") as dbnode_file: - lines = dbnode_file.readlines() - for line in lines: - content = line.split() - if "IPADDRESS:" in content: - ips.append(content[-1]) - if num_ips and len(ips) == num_ips: - break - - return ips - - def _parse_db_host(self, filepath: t.Optional[str] = None) -> str: - """Parse the database host/IP from the output file - - If no file is passed as argument, then the first - file in self._output_files is used. - - :param filepath: Path to file to parse - :type filepath: str, optional - :raises SmartSimError: if host/ip could not be found - :return: ip address | hostname - :rtype: str + def _parse_launched_shard_info_from_iterable( + stream: t.Iterable[str], num_shards: t.Optional[int] = None + ) -> "t.List[LaunchedShardData]": + lines = (line.strip() for line in stream) + lines = (line for line in lines if line) + tokenized = (line.split(maxsplit=1) for line in lines) + tokenized = (tokens for tokens in tokenized if len(tokens) > 1) + shard_data_jsons = ( + kwjson for first, kwjson in tokenized if "SMARTSIM_ORC_SHARD_INFO" in first + ) + shard_data_kwargs = (json.loads(kwjson) for kwjson in shard_data_jsons) + shard_data: "t.Iterable[LaunchedShardData]" = ( + LaunchedShardData(**kwargs) for kwargs in shard_data_kwargs + ) + if num_shards: + shard_data = itertools.islice(shard_data, num_shards) + return list(shard_data) + + @classmethod + def _parse_launched_shard_info_from_files( + cls, file_paths: t.List[str], num_shards: t.Optional[int] = None + ) -> "t.List[LaunchedShardData]": + with fileinput.FileInput(file_paths) as ifstream: + return cls._parse_launched_shard_info_from_iterable(ifstream, num_shards) + + def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": + """Parse the launched database shard info from the output files + + :raises SmartSimError: if all shard info could not be found + :return: The found launched shard info + :rtype: list[LaunchedShardData] """ - if not filepath: - filepath = osp.join(self.path, self._output_files[0]) - trials = 5 - ip_address = None - - # try a few times to give the database files time to - # populate on busy systems. 
-        while not ip_address and trials > 0:
+        ips: "t.List[LaunchedShardData]" = []
+        trials = CONFIG.database_file_parse_trials
+        interval = CONFIG.database_file_parse_interval
+        output_files = [osp.join(self.path, file) for file in self._output_files]
+
+        while len(ips) < self.num_shards and trials > 0:
             try:
-                if ip_addresses := self._parse_ips(filepath, 1):
-                    ip_address = ip_addresses[0]
-            # suppress error
+                ips = self._parse_launched_shard_info_from_files(
+                    output_files, self.num_shards
+                )
             except FileNotFoundError:
-                pass
-
-            logger.debug("Waiting for Redis output files to populate...")
-            if not ip_address:
-                time.sleep(1)
+                ...
+            if len(ips) < self.num_shards:
+                logger.debug("Waiting for output files to populate...")
+                time.sleep(interval)
                 trials -= 1
 
-        if not ip_address:
-            logger.error(f"IP address lookup strategy failed for file {filepath}.")
-            raise SmartSimError("Failed to obtain database hostname")
-
-        return ip_address
+        if len(ips) < self.num_shards:
+            msg = (
+                f"Failed to parse the launched DB shard information from file(s) "
+                f"{', '.join(output_files)}. Found the information for "
+                f"{len(ips)} out of {self.num_shards} DB shards."
+            )
+            logger.error(msg)
+            raise SmartSimError(msg)
+        return ips
 
     def _parse_db_hosts(self) -> t.List[str]:
         """Parse the database hosts/IPs from the output files
-        this uses the RedisIP module that is built as a dependency
 
         The IP address is preferred, but if hostname is only present
         then a lookup to /etc/hosts is done through the socket library.
-        This function must be called only if ``_mpmd==True``.
 
         :raises SmartSimError: if host/ip could not be found
         :return: ip addresses | hostnames
         :rtype: list[str]
         """
-        ips: t.List[str] = []
-
-        # Find out if all shards' output streams are piped to separate files
-        if len(self._output_files) > 1:
-            for output_file in self._output_files:
-                filepath = osp.join(self.path, output_file)
-                _ = self._parse_db_host(filepath)
-        else:
-            filepath = osp.join(self.path, self._output_files[0])
-            trials = 10
-            ips = []
-            while len(ips) < self._num_shards and trials > 0:
-                ips = []
-                try:
-                    ip_address = self._parse_ips(filepath, self._num_shards)
-                    ips.extend(ip_address)
-
-                # suppress error
-                except FileNotFoundError:
-                    pass
-
-                if len(ips) < self._num_shards:
-                    logger.debug("Waiting for RedisIP files to populate...")
-                    # Larger sleep time, as this seems to be needed for
-                    # multihost setups
-                    time.sleep(2)
-                    trials -= 1
-
-            if len(ips) < self._num_shards:
-                msg = f"IP address lookup strategy failed for file {filepath}. "
-                msg += f"Found {len(ips)} out of {self._num_shards} IPs."
-                logger.error(msg)
-                raise SmartSimError("Failed to obtain database hostname")
-
-        ips = list(dict.fromkeys(ips))
-        return ips
+        return list({shard.hostname for shard in self.get_launched_shard_info()})
+
+
+@dataclass(frozen=True)
+class LaunchedShardData:
+    """Data class to write and parse data about a launched database shard"""
+
+    name: str
+    hostname: str
+    port: int
+    cluster: bool
+
+    @property
+    def cluster_conf_file(self) -> t.Optional[str]:
+        return f"nodes-{self.name}-{self.port}.conf" if self.cluster else None
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return dict(self.__dict__)
diff --git a/tests/test_dbnode.py b/tests/test_dbnode.py
index a4f4f641e..5b849d76e 100644
--- a/tests/test_dbnode.py
+++ b/tests/test_dbnode.py
@@ -24,11 +24,17 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import io
+import json
+import random
+import string
+import textwrap
+
 import pytest
 
 from smartsim import Experiment
 from smartsim.database import Orchestrator
+from smartsim.entity.dbnode import DBNode, LaunchedShardData
 from smartsim.error.errors import SmartSimError
 
@@ -49,21 +55,71 @@ def test_hosts(fileutils, wlmutils):
     orc.set_path(test_dir)
 
     exp.start(orc)
 
-    thrown = False
     hosts = []
     try:
         hosts = orc.hosts
-    except SmartSimError:
-        thrown = True
+        assert len(hosts) == orc.db_nodes == 1
     finally:
         # stop the database even if there is an error raised
         exp.stop(orc)
         orc.remove_stale_files()
-    assert not thrown
-    assert hosts == orc.hosts
+
+
+def _random_shard_info():
+    rand_string = lambda: ''.join(random.choices(string.ascii_letters, k=10))
+    rand_num = lambda: random.randint(1000, 9999)
+    flip_coin = lambda: random.choice((True, False))
+
+    return LaunchedShardData(
+        name=rand_string(),
+        hostname=rand_string(),
+        port=rand_num(),
+        cluster=flip_coin(),
+    )
+
+
+def test_launched_shard_info_can_be_serialized():
+    shard_data = _random_shard_info()
+    shard_data_from_str = LaunchedShardData(
+        **json.loads(json.dumps(shard_data.to_dict()))
+    )
+
+    assert shard_data is not shard_data_from_str
+    assert shard_data == shard_data_from_str
+
+
+@pytest.mark.parametrize("limit", [None, 1])
+def test_db_node_can_parse_launched_shard_info(limit):
+    rand_shards = [_random_shard_info() for _ in range(3)]
+    with io.StringIO(
+        textwrap.dedent(
+            """\
+            This is some file like str
+            --------------------------
+
+            SMARTSIM_ORC_SHARD_INFO: {}
+            ^^^^^^^^^^^^^^^^^^^^^^^
+            We should be able to parse the serialized
+            launched db info from this file if the line is
+            prefixed with this tag.
+
+            Here are two more for good measure:
+            SMARTSIM_ORC_SHARD_INFO: {}
+            SMARTSIM_ORC_SHARD_INFO: {}
+
+            All other lines should be ignored.
+            """
+        ).format(*(json.dumps(s.to_dict()) for s in rand_shards))
+    ) as stream:
+        parsed_shards = DBNode._parse_launched_shard_info_from_iterable(
+            stream, limit
+        )
+
+    if limit is not None:
+        rand_shards = rand_shards[:limit]
+    assert rand_shards == parsed_shards
 
 
 def test_set_host():
     orc = Orchestrator()
-    orc.entities[0].set_host("host")
-    assert orc.entities[0]._host == "host"
+    orc.entities[0].set_hosts(["host"])
+    assert orc.entities[0].host == "host"
diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py
index a634009d9..e8156a4ee 100644
--- a/tests/test_orchestrator.py
+++ b/tests/test_orchestrator.py
@@ -226,6 +226,31 @@ def test_slurm_set_batch_arg(wlmutils):
     assert orc2.batch_settings.batch_args["account"] == "ACCOUNT"
 
 
+@pytest.mark.parametrize("single_cmd", [
+    pytest.param(True, id="Single MPMD `srun`"),
+    pytest.param(False, id="Multiple `srun`s"),
+])
+def test_orc_results_in_correct_number_of_shards(single_cmd):
+    num_shards = 5
+    orc = Orchestrator(
+        port=12345,
+        launcher="slurm",
+        run_command="srun",
+        db_nodes=num_shards,
+        batch=False,
+        single_cmd=single_cmd,
+    )
+    if single_cmd:
+        assert len(orc.entities) == 1
+        (node,) = orc.entities
+        assert len(node.run_settings.mpmd) == num_shards - 1
+    else:
+        assert len(orc.entities) == num_shards
+        assert all(node.run_settings.mpmd == [] for node in orc.entities)
+    assert orc.num_shards == orc.db_nodes == sum(
+        node.num_shards for node in orc.entities)
+
+
 ###### Cobalt ######
diff --git a/tests/test_smartredis.py b/tests/test_smartredis.py
index c44969ce1..f27ac5de4 100644
--- a/tests/test_smartredis.py
+++ b/tests/test_smartredis.py
@@ -93,14 +93,11 @@ def test_exchange(fileutils, wlmutils):
 
     # get and confirm statuses
     statuses = exp.get_status(ensemble)
-    if not all([stat == status.STATUS_COMPLETED for stat in statuses]):
+    try:
+        assert all([stat == status.STATUS_COMPLETED for stat in statuses])
+    finally:
+        # stop the orchestrator
         exp.stop(orc)
-        assert False  # client ensemble failed
-
-    # stop the orchestrator
-    exp.stop(orc)
-
-    print(exp.summary())
 
 
 def test_consumer(fileutils, wlmutils):
@@ -145,11 +142,8 @@ def test_consumer(fileutils, wlmutils):
 
     # get and confirm statuses
     statuses = exp.get_status(ensemble)
-    if not all([stat == status.STATUS_COMPLETED for stat in statuses]):
+    try:
+        assert all([stat == status.STATUS_COMPLETED for stat in statuses])
+    finally:
+        # stop the orchestrator
         exp.stop(orc)
-        assert False  # client ensemble failed
-
-    # stop the orchestrator
-    exp.stop(orc)
-
-    print(exp.summary())
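For context, here is a minimal sketch (not part of the patch) of the round trip these changes introduce: the redis entrypoint prints one tagged JSON line per launched shard, and DBNode recovers the shard data by scanning its output for that tag. The shard values below are made up for illustration, and the snippet assumes the patched smartsim package is importable.

    import io
    import json

    from smartsim.entity.dbnode import DBNode, LaunchedShardData

    # What the entrypoint writes to stdout for a shard (hypothetical values)
    shard = LaunchedShardData(name="orc_0", hostname="10.0.0.1", port=6379, cluster=False)
    line = f"SMARTSIM_ORC_SHARD_INFO: {json.dumps(shard.to_dict())}"

    # How DBNode recovers the shard data when scanning an output stream
    (parsed,) = DBNode._parse_launched_shard_info_from_iterable(io.StringIO(line))
    assert parsed == shard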