From 7eaab839aae3c8ee1826f12f663f507e6048c049 Mon Sep 17 00:00:00 2001
From: Matt Drozt
Date: Thu, 5 Oct 2023 10:51:56 -0700
Subject: [PATCH] DBNode: MPMD RunSettings Unification (#379)

A general refactor of the DBNode class to:

- Remove the duplicated methods with near-identical functionality, where one
  was intended to be called when the underlying RunSettings were standard and
  the other when they were MPMD. This removes the undefined behavior that
  occurred when the "wrong" method was called.
- Allow MPMD DBNodes to map information "per shard" by scraping output files
  for a serializable LaunchedShardData class (a condensed sketch of this
  round trip follows the patch).

[ committed by @MattToast ]
[ reviewed by @ashao ]
---
 smartsim/_core/config/config.py     |   8 +
 smartsim/_core/entrypoints/redis.py | 130 ++++++++++++----
 smartsim/database/orchestrator.py   | 210 +++++++++++++++----------
 smartsim/entity/dbnode.py           | 233 +++++++++++++---------------
 tests/test_dbnode.py                |  70 ++++++++-
 tests/test_orchestrator.py          |  25 +++
 tests/test_smartredis.py            |  22 +--
 7 files changed, 431 insertions(+), 267 deletions(-)

diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py
index 3b0905021..a7b1471bf 100644
--- a/smartsim/_core/config/config.py
+++ b/smartsim/_core/config/config.py
@@ -143,6 +143,14 @@ def database_cli(self) -> str:
                 "Specified Redis binary at REDIS_CLI_PATH could not be used"
             ) from e
 
+    @property
+    def database_file_parse_trials(self) -> int:
+        return int(os.getenv("SMARTSIM_DB_FILE_PARSE_TRIALS", "10"))
+
+    @property
+    def database_file_parse_interval(self) -> int:
+        return int(os.getenv("SMARTSIM_DB_FILE_PARSE_INTERVAL", "2"))
+
     @property
     def log_level(self) -> str:
         return os.environ.get("SMARTSIM_LOG_LEVEL", "info")
diff --git a/smartsim/_core/entrypoints/redis.py b/smartsim/_core/entrypoints/redis.py
index 782b4c583..7262a5996 100644
--- a/smartsim/_core/entrypoints/redis.py
+++ b/smartsim/_core/entrypoints/redis.py
@@ -25,16 +25,20 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse +import json import os -import psutil import signal +import textwrap import typing as t +from subprocess import PIPE, STDOUT +from types import FrameType + +import psutil from smartsim._core.utils.network import current_ip +from smartsim.entity.dbnode import LaunchedShardData from smartsim.error import SSInternalError from smartsim.log import get_logger -from subprocess import PIPE, STDOUT -from types import FrameType logger = get_logger(__name__) @@ -42,7 +46,7 @@ Redis/KeyDB entrypoint script """ -DBPID = None +DBPID: t.Optional[int] = None # kill is not catchable SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT] @@ -54,31 +58,64 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: cleanup() -def build_bind_args(ip_addresses: t.List[str]) -> t.List[str]: - bind_arg = f"--bind {' '.join(ip_addresses)}" - # pin source address to avoid random selection by Redis - bind_src_arg = f"--bind-source-addr {ip_addresses[0]}" - return [bind_arg, bind_src_arg] +def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]: + return ( + "--bind", + source_addr, + *addrs, + # pin source address to avoid random selection by Redis + "--bind-source-addr", + source_addr, + ) -def print_summary(cmd: t.List[str], ip_address: str, network_interface: str) -> None: - print("-" * 10, " Running Command ", "-" * 10, "\n", flush=True) - print(f"COMMAND: {' '.join(cmd)}\n", flush=True) - print(f"IPADDRESS: {ip_address}\n", flush=True) - print(f"NETWORK: {network_interface}\n", flush=True) - print("-" * 30, "\n\n", flush=True) - print("-" * 10, " Output ", "-" * 10, "\n\n", flush=True) +def build_cluster_args(shard_data: LaunchedShardData) -> t.Tuple[str, ...]: + if cluster_conf_file := shard_data.cluster_conf_file: + return ("--cluster-enabled", "yes", "--cluster-config-file", cluster_conf_file) + return () -def main(network_interface: str, command: t.List[str]) -> None: +def print_summary( + cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData +) -> None: + print( + textwrap.dedent( + f"""\ + ----------- Running Command ---------- + COMMAND: {' '.join(cmd)} + IPADDRESS: {shard_data.hostname} + NETWORK: {network_interface} + SMARTSIM_ORC_SHARD_INFO: {json.dumps(shard_data.to_dict())} + -------------------------------------- + + --------------- Output --------------- + + """ + ), + flush=True, + ) + + +def main(args: argparse.Namespace) -> int: global DBPID # pylint: disable=global-statement - try: - ip_addresses = [current_ip(net_if) for net_if in network_interface.split(",")] - cmd = command + build_bind_args(ip_addresses) + src_addr, *bind_addrs = (current_ip(net_if) for net_if in args.ifname.split(",")) + shard_data = LaunchedShardData( + name=args.name, hostname=src_addr, port=args.port, cluster=args.cluster + ) - print_summary(cmd, ip_addresses[0], network_interface) + cmd = [ + args.orc_exe, + args.conf_file, + *args.rai_module, + "--port", + str(args.port), + *build_cluster_args(shard_data), + *build_bind_args(src_addr, *bind_addrs), + ] + print_summary(cmd, args.ifname, shard_data) + try: process = psutil.Popen(cmd, stdout=PIPE, stderr=STDOUT) DBPID = process.pid @@ -87,18 +124,17 @@ def main(network_interface: str, command: t.List[str]) -> None: except Exception as e: cleanup() raise SSInternalError("Database process starter raised an exception") from e + return 0 def cleanup() -> None: + logger.debug("Cleaning up database instance") try: - logger.debug("Cleaning up database instance") # attempt to stop the database 
process
-        db_proc = psutil.Process(DBPID)
-        db_proc.terminate()
-
+        if DBPID is not None:
+            psutil.Process(DBPID).terminate()
     except psutil.NoSuchProcess:
         logger.warning("Couldn't find database process to kill.")
-
     except OSError as e:
         logger.warning(f"Failed to clean up database gracefully: {str(e)}")
 
@@ -110,15 +146,47 @@ def cleanup() -> None:
         prefix_chars="+", description="SmartSim Process Launcher"
     )
     parser.add_argument(
-        "+ifname", type=str, help="Network Interface name", default="lo"
+        "+orc-exe", type=str, help="Path to the orchestrator executable", required=True
+    )
+    parser.add_argument(
+        "+conf-file",
+        type=str,
+        help="Path to the orchestrator configuration file",
+        required=True,
+    )
+    parser.add_argument(
+        "+rai-module",
+        nargs="+",
+        type=str,
+        help=(
+            "Command for the orchestrator to load the Redis AI module with "
+            "symbols separated by whitespace"
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "+name", type=str, help="Name to identify the shard", required=True
+    )
+    parser.add_argument(
+        "+port",
+        type=int,
+        help="The port on which to launch the shard of the orchestrator",
+        required=True,
+    )
+    parser.add_argument(
+        "+ifname", type=str, help="Network Interface name", required=True
+    )
+    parser.add_argument(
+        "+cluster",
+        action="store_true",
+        help="Specify if this orchestrator shard is part of a cluster",
     )
-    parser.add_argument("+command", nargs="+", help="Command to run")
-    args = parser.parse_args()
+    args_ = parser.parse_args()
 
     # make sure to register the cleanup before the start
-    # the proecss so our signaller will be able to stop
+    # of the process so our signaller will be able to stop
     # the database process.
     for sig in SIGNALS:
         signal.signal(sig, handle_signal)
-    main(args.ifname, args.command)
+    raise SystemExit(main(args_))
diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py
index a5526be06..507ebdbe9 100644
--- a/smartsim/database/orchestrator.py
+++ b/smartsim/database/orchestrator.py
@@ -24,13 +24,12 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import itertools -import psutil import sys import typing as t - from os import getcwd, getenv from shlex import split as sh_split +import psutil from smartredis import Client from smartredis.error import RedisReplyError @@ -41,7 +40,6 @@ from ..entity import DBNode, EntityList from ..error import SmartSimError, SSConfigError, SSUnsupportedError from ..log import get_logger -from ..settings.base import BatchSettings, RunSettings from ..settings import ( AprunSettings, BsubBatchSettings, @@ -55,6 +53,7 @@ SbatchSettings, SrunSettings, ) +from ..settings.base import BatchSettings, RunSettings from ..settings.settings import create_batch_settings, create_run_settings from ..wlm import detect_launcher @@ -161,6 +160,10 @@ def __init__( time: t.Optional[str] = None, alloc: t.Optional[str] = None, single_cmd: bool = False, + *, + threads_per_queue: t.Optional[int] = None, + inter_op_threads: t.Optional[int] = None, + intra_op_threads: t.Optional[int] = None, **kwargs: t.Any, ) -> None: """Initialize an Orchestrator reference for local launch @@ -198,12 +201,13 @@ def __init__( interface = [interface] self._interfaces = interface self._check_network_interface() - self.queue_threads = kwargs.get("threads_per_queue", None) - self.inter_threads = kwargs.get("inter_op_threads", None) - self.intra_threads = kwargs.get("intra_op_threads", None) + self.queue_threads = threads_per_queue + self.inter_threads = inter_op_threads + self.intra_threads = intra_op_threads + if self.launcher == "lsf": - gpus_per_shard = kwargs.pop("gpus_per_shard", 0) - cpus_per_shard = kwargs.pop("cpus_per_shard", 4) + gpus_per_shard = int(kwargs.pop("gpus_per_shard", 0)) + cpus_per_shard = int(kwargs.pop("cpus_per_shard", 4)) else: gpus_per_shard = None cpus_per_shard = None @@ -221,6 +225,9 @@ def __init__( single_cmd=single_cmd, gpus_per_shard=gpus_per_shard, cpus_per_shard=cpus_per_shard, + threads_per_queue=threads_per_queue, + inter_op_threads=inter_op_threads, + intra_op_threads=intra_op_threads, **kwargs, ) @@ -234,11 +241,12 @@ def __init__( self._redis_conf # pylint: disable=W0104 CONFIG.database_cli # pylint: disable=W0104 except SSConfigError as e: - msg = "SmartSim not installed with pre-built extensions (Redis)\n" - msg += "Use the `smart` cli tool to install needed extensions\n" - msg += "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n" - msg += "See documentation for more information" - raise SSConfigError(msg) from e + raise SSConfigError( + "SmartSim not installed with pre-built extensions (Redis)\n" + "Use the `smart` cli tool to install needed extensions\n" + "or set REDIS_PATH and REDIS_CLI_PATH in your environment\n" + "See documentation for more information" + ) from e if launcher != "local": self.batch_settings = self._build_batch_settings( @@ -269,7 +277,19 @@ def num_shards(self) -> int: :returns: num_shards :rtype: int """ - return self.db_nodes + return sum(node.num_shards for node in self.entities) + + @property + def db_nodes(self) -> int: + """Read only property for the number of nodes an ``Orchestrator`` is + launched across. Notice that SmartSim currently assumes that each shard + will be launched on its own node. Therefore this property is currently + an alias to the ``num_shards`` attribute. 
+ + :returns: Number of database nodes + :rtype: int + """ + return self.num_shards @property def hosts(self) -> t.List[str]: @@ -306,10 +326,10 @@ def get_address(self) -> t.List[str]: return self._get_address() def _get_address(self) -> t.List[str]: - addresses: t.List[str] = [] - for ip_address, port in itertools.product(self._hosts, self.ports): - addresses.append(":".join((ip_address, str(port)))) - return addresses + return [ + f"{host}:{port}" + for host, port in itertools.product(self._hosts, self.ports) + ] def is_active(self) -> bool: """Check if the database is active @@ -323,20 +343,21 @@ def is_active(self) -> bool: return db_is_active(self._hosts, self.ports, self.num_shards) @property - def _rai_module(self) -> str: + def _rai_module(self) -> t.Tuple[str, ...]: """Get the RedisAI module from third-party installations - :return: path to module or "" if not found - :rtype: str + :return: Tuple of args to pass to the orchestrator exe + to load and configure the RedisAI + :rtype: tuple[str] """ module = ["--loadmodule", CONFIG.redisai] if self.queue_threads: - module.append(f"THREADS_PER_QUEUE {self.queue_threads}") + module.extend(("THREADS_PER_QUEUE", str(self.queue_threads))) if self.inter_threads: - module.append(f"INTER_OP_PARALLELISM {self.inter_threads}") + module.extend(("INTER_OP_PARALLELISM", str(self.inter_threads))) if self.intra_threads: - module.append(f"INTRA_OP_PARALLELISM {self.intra_threads}") - return " ".join(module) + module.extend(("INTRA_OP_PARALLELISM", str(self.intra_threads))) + return tuple(module) @property def _redis_exe(self) -> str: @@ -407,9 +428,11 @@ def set_hosts(self, host_list: t.List[str]) -> None: if self.launcher == "lsf": for db in self.entities: db.set_hosts(host_list) - elif (self.launcher == "pals" - and isinstance(self.entities[0].run_settings, PalsMpiexecSettings) - and self.entities[0].is_mpmd): + elif ( + self.launcher == "pals" + and isinstance(self.entities[0].run_settings, PalsMpiexecSettings) + and self.entities[0].is_mpmd + ): # In this case, --hosts is a global option, we only set it to the # first run command self.entities[0].run_settings.set_hostlist(host_list) @@ -485,9 +508,9 @@ def enable_checkpoints(self, frequency: int) -> None: :param frequency: the given number of seconds before the DB saves :type frequency: int """ - self.set_db_conf("save", str(frequency) + " 1") + self.set_db_conf("save", f"{frequency} 1") - def set_max_memory(self, mem: int) -> None: + def set_max_memory(self, mem: str) -> None: """Sets the max memory configuration. By default there is no memory limit. Setting max memory to zero also results in no memory limit. Once a limit is surpassed, keys will be removed according to the eviction strategy. 
The
@@ -590,10 +613,14 @@ def _build_batch_settings(
         batch: bool,
         account: str,
         time: str,
+        *,
+        launcher: t.Optional[str] = None,
         **kwargs: t.Any,
     ) -> t.Optional[BatchSettings]:
         batch_settings = None
-        launcher = kwargs.pop("launcher")
+
+        if launcher is None:
+            raise ValueError("Expected param `launcher` of type `str`")
 
         # enter this conditional if user has not specified an allocation to run
         # on or if user specified batch=False (alloc will be found through env)
@@ -605,11 +632,16 @@
         return batch_settings
 
     def _build_run_settings(
-        self, exe: str, exe_args: t.List[t.List[str]], **kwargs: t.Any
+        self,
+        exe: str,
+        exe_args: t.List[t.List[str]],
+        *,
+        run_args: t.Optional[t.Dict[str, t.Any]] = None,
+        db_nodes: int = 1,
+        single_cmd: bool = True,
+        **kwargs: t.Any,
     ) -> RunSettings:
-        run_args = kwargs.pop("run_args", {})
-        db_nodes = kwargs.get("db_nodes", 1)
-        single_cmd = kwargs.get("single_cmd", True)
+        run_args = {} if run_args is None else run_args
         mpmd_nodes = single_cmd and db_nodes > 1
 
         if mpmd_nodes:
@@ -638,20 +670,29 @@
         if self.launcher != "local":
             run_settings.set_tasks_per_node(1)
 
-        # Put it back in case it is needed again
-        kwargs["run_args"] = run_args
-
         return run_settings
 
     @staticmethod
     def _build_run_settings_lsf(
-        exe: str, exe_args: t.List[t.List[str]], **kwargs: t.Any
+        exe: str,
+        exe_args: t.List[t.List[str]],
+        *,
+        run_args: t.Optional[t.Dict[str, t.Any]] = None,
+        cpus_per_shard: t.Optional[int] = None,
+        gpus_per_shard: t.Optional[int] = None,
+        **_kwargs: t.Any  # Needed to ensure no API break; we do not want to
+                          # risk introducing one, even though this method is
+                          # protected, without running the test suite.
+                          # TODO: Test against an LSF system before merge!
     ) -> t.Optional[JsrunSettings]:
-        run_args = kwargs.pop("run_args", {})
-        cpus_per_shard = kwargs.get("cpus_per_shard", None)
-        gpus_per_shard = kwargs.get("gpus_per_shard", None)
+        run_args = {} if run_args is None else run_args
         erf_rs: t.Optional[JsrunSettings] = None
 
+        if cpus_per_shard is None:
+            raise ValueError("Expected an integer number of cpus per shard")
+        if gpus_per_shard is None:
+            raise ValueError("Expected an integer number of gpus per shard")
+
         # We always run the DB on cpus 0:cpus_per_shard-1
         # and gpus 0:gpus_per_shard-1
         for shard_id, args in enumerate(exe_args):
@@ -672,9 +713,9 @@
             }
 
             if gpus_per_shard > 1:  # pragma: no-cover
-                erf_sets["gpu"] = "{" + f"0-{gpus_per_shard-1}" + "}"
+                erf_sets["gpu"] = f"{{0-{gpus_per_shard-1}}}"
             elif gpus_per_shard > 0:
-                erf_sets["gpu"] = "{" + str(0) + "}"
+                erf_sets["gpu"] = "{0}"
 
             run_settings.set_erf_sets(erf_sets)
 
@@ -684,31 +725,38 @@
 
                 erf_rs.make_mpmd(run_settings)
 
-        kwargs["run_args"] = run_args
-
         return erf_rs
 
-    def _initialize_entities(self, **kwargs: t.Any) -> None:
-        self.db_nodes = int(kwargs.get("db_nodes", 1))
-        single_cmd = kwargs.get("single_cmd", True)
-
-        if int(self.db_nodes) == 2:
+    # Old pylint from TF 2.6.x does not understand that this argument list is
+    # equivalent to `(self, **kwargs)`
+    #
+    # pylint: disable-next=arguments-differ
+    def _initialize_entities(
+        self,
+        *,
+        db_nodes: int = 1,
+        single_cmd: bool = True,
+        port: int = 6379,
+        **kwargs: t.Any,
+    ) -> None:
+        db_nodes = int(db_nodes)
+        if db_nodes == 2:
             raise SSUnsupportedError("Orchestrator does not support clusters of size 2")
 
-        if self.launcher == "local" and self.db_nodes > 1:
+        if self.launcher == "local" and db_nodes > 1:
             raise ValueError(
                 "Local
Orchestrator does not support multiple database shards" ) - mpmd_nodes = (single_cmd and self.db_nodes > 1) or self.launcher == "lsf" + mpmd_nodes = (single_cmd and db_nodes > 1) or self.launcher == "lsf" if mpmd_nodes: - self._initialize_entities_mpmd(**kwargs) + self._initialize_entities_mpmd( + db_nodes=db_nodes, single_cmd=single_cmd, port=port, **kwargs + ) else: - port = kwargs.get("port", 6379) - cluster = not bool(self.db_nodes < 3) + cluster = db_nodes >= 3 - for db_id in range(self.db_nodes): + for db_id in range(db_nodes): db_node_name = "_".join((self.name, str(db_id))) # create the exe_args list for launching multiple databases @@ -720,7 +768,7 @@ def _initialize_entities(self, **kwargs: t.Any) -> None: # if only launching 1 db per command, we don't need a # list of exe args lists run_settings = self._build_run_settings( - sys.executable, [start_script_args], **kwargs + sys.executable, [start_script_args], port=port, **kwargs ) node = DBNode( @@ -734,13 +782,14 @@ def _initialize_entities(self, **kwargs: t.Any) -> None: self.ports = [port] - def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None: - port = kwargs.get("port", 6379) - cluster = not bool(self.db_nodes < 3) + def _initialize_entities_mpmd( + self, *, db_nodes: int = 1, port: int = 6379, **kwargs: t.Any + ) -> None: + cluster = db_nodes >= 3 exe_args_mpmd: t.List[t.List[str]] = [] - for db_id in range(self.db_nodes): + for db_id in range(db_nodes): db_shard_name = "_".join((self.name, str(db_id))) # create the exe_args list for launching multiple databases # per node. also collect port range for dbnode @@ -754,15 +803,14 @@ def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None: if self.launcher == "lsf": run_settings = self._build_run_settings_lsf( - sys.executable, exe_args_mpmd, **kwargs + sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs ) output_files = [ - "_".join((self.name, str(db_id))) + ".out" - for db_id in range(self.db_nodes) + f"{self.name}_{db_id}.out" for db_id in range(db_nodes) ] else: run_settings = self._build_run_settings( - sys.executable, exe_args_mpmd, **kwargs + sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs ) output_files = [self.name + ".out"] @@ -770,37 +818,27 @@ def _initialize_entities_mpmd(self, **kwargs: t.Any) -> None: raise ValueError(f"Could not build run settings for {self.launcher}") node = DBNode(self.name, self.path, run_settings, [port], output_files) - node.is_mpmd = True - node.num_shards = self.db_nodes self.entities.append(node) self.ports = [port] - @staticmethod - def _get_cluster_args(name: str, port: int) -> t.List[str]: - """Create the arguments necessary for cluster creation""" - cluster_conf = "".join(("nodes-", name, "-", str(port), ".conf")) - db_args = ["--cluster-enabled yes", "--cluster-config-file", cluster_conf] - return db_args - def _get_start_script_args( self, name: str, port: int, cluster: bool ) -> t.List[str]: - start_script_args = [ + cmd = [ "-m", "smartsim._core.entrypoints.redis", # entrypoint - "+ifname=" + ",".join(self._interfaces), # pass interface to start script - "+command", # command flag for argparser - self._redis_exe, # redis-server - self._redis_conf, # redis.conf file - self._rai_module, # redisai.so - "--port", # redis port - str(port), # port number + f"+orc-exe={self._redis_exe}", # redis-server + f"+conf-file={self._redis_conf}", # redis.conf file + "+rai-module", # load redisai.so + *self._rai_module, + f"+name={name}", # name of node + f"+port={port}", # redis port + 
f"+ifname={','.join(self._interfaces)}", # pass interface to start script ] if cluster: - start_script_args += self._get_cluster_args(name, port) - - return start_script_args + cmd.append("+cluster") # is the shard part of a cluster + return cmd def _get_db_hosts(self) -> t.List[str]: hosts = [] diff --git a/smartsim/entity/dbnode.py b/smartsim/entity/dbnode.py index bc04df834..b3740301f 100644 --- a/smartsim/entity/dbnode.py +++ b/smartsim/entity/dbnode.py @@ -24,16 +24,20 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import fileinput +import itertools +import json import os import os.path as osp import time import typing as t +from dataclasses import dataclass +from .._core.config import CONFIG from ..error import SmartSimError from ..log import get_logger -from .entity import SmartSimEntity from ..settings.base import RunSettings - +from .entity import SmartSimEntity logger = get_logger(__name__) @@ -56,11 +60,8 @@ def __init__( output_files: t.List[str], ) -> None: """Initialize a database node within an orchestrator.""" - self.ports = ports - self._host: t.Optional[str] = None super().__init__(name, path, run_settings) - self._mpmd = False - self._num_shards: int = 0 + self.ports = ports self._hosts: t.Optional[t.List[str]] = None if not output_files: @@ -73,17 +74,20 @@ def __init__( @property def num_shards(self) -> int: - return self._num_shards - - @num_shards.setter - def num_shards(self, value: int) -> None: - self._num_shards = value + try: + return len(self.run_settings.mpmd) + 1 # type: ignore[attr-defined] + except AttributeError: + return 1 @property def host(self) -> str: - if not self._host: - self._host = self._parse_db_host() - return self._host + try: + (host,) = self.hosts + except ValueError: + raise ValueError( + f"Multiple hosts detected for this DB Node: {', '.join(self.hosts)}" + ) from None + return host @property def hosts(self) -> t.List[str]: @@ -93,14 +97,10 @@ def hosts(self) -> t.List[str]: @property def is_mpmd(self) -> bool: - return self._mpmd - - @is_mpmd.setter - def is_mpmd(self, value: bool) -> None: - self._mpmd = value - - def set_host(self, host: str) -> None: - self._host = str(host) + try: + return bool(self.run_settings.mpmd) # type: ignore[attr-defined] + except AttributeError: + return False def set_hosts(self, hosts: t.List[str]) -> None: self._hosts = [str(host) for host in hosts] @@ -112,42 +112,27 @@ def remove_stale_dbnode_files(self) -> None: """ for port in self.ports: - if not self._mpmd: - conf_file = osp.join(self.path, self._get_cluster_conf_filename(port)) + for conf_file in ( + osp.join(self.path, filename) + for filename in self._get_cluster_conf_filenames(port) + ): if osp.exists(conf_file): os.remove(conf_file) - else: # cov-lsf - conf_files = [ - osp.join(self.path, filename) - for filename in self._get_cluster_conf_filenames(port) - ] - for conf_file in conf_files: - if osp.exists(conf_file): - os.remove(conf_file) for file_ending in [".err", ".out", ".mpmd"]: file_name = osp.join(self.path, self.name + file_ending) if osp.exists(file_name): os.remove(file_name) - if self._mpmd: + + if self.is_mpmd: for file_ending in [".err", ".out"]: - for shard_id in range(self._num_shards): + for shard_id in range(self.num_shards): file_name = osp.join( self.path, self.name + "_" + str(shard_id) + file_ending ) if osp.exists(file_name): os.remove(file_name) - def _get_cluster_conf_filename(self, port: int) -> str: - """Returns 
the .conf file name for the given port number - - :param port: port number - :type port: int - :return: the dbnode configuration file name - :rtype: str - """ - return "".join(("nodes-", self.name, "-", str(port), ".conf")) - def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: # cov-lsf """Returns the .conf file name for the given port number @@ -158,108 +143,98 @@ def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: # cov-lsf :return: the dbnode configuration file name :rtype: str """ + if self.num_shards == 1: + return [f"nodes-{self.name}-{port}.conf"] return [ - "".join(("nodes-", self.name + f"_{shard_id}", "-", str(port), ".conf")) - for shard_id in range(self._num_shards) + f"nodes-{self.name}_{shard_id}-{port}.conf" + for shard_id in range(self.num_shards) ] @staticmethod - def _parse_ips(filepath: str, num_ips: t.Optional[int] = None) -> t.List[str]: - ips = [] - with open(filepath, "r", encoding="utf-8") as dbnode_file: - lines = dbnode_file.readlines() - for line in lines: - content = line.split() - if "IPADDRESS:" in content: - ips.append(content[-1]) - if num_ips and len(ips) == num_ips: - break - - return ips - - def _parse_db_host(self, filepath: t.Optional[str] = None) -> str: - """Parse the database host/IP from the output file - - If no file is passed as argument, then the first - file in self._output_files is used. - - :param filepath: Path to file to parse - :type filepath: str, optional - :raises SmartSimError: if host/ip could not be found - :return: ip address | hostname - :rtype: str + def _parse_launched_shard_info_from_iterable( + stream: t.Iterable[str], num_shards: t.Optional[int] = None + ) -> "t.List[LaunchedShardData]": + lines = (line.strip() for line in stream) + lines = (line for line in lines if line) + tokenized = (line.split(maxsplit=1) for line in lines) + tokenized = (tokens for tokens in tokenized if len(tokens) > 1) + shard_data_jsons = ( + kwjson for first, kwjson in tokenized if "SMARTSIM_ORC_SHARD_INFO" in first + ) + shard_data_kwargs = (json.loads(kwjson) for kwjson in shard_data_jsons) + shard_data: "t.Iterable[LaunchedShardData]" = ( + LaunchedShardData(**kwargs) for kwargs in shard_data_kwargs + ) + if num_shards: + shard_data = itertools.islice(shard_data, num_shards) + return list(shard_data) + + @classmethod + def _parse_launched_shard_info_from_files( + cls, file_paths: t.List[str], num_shards: t.Optional[int] = None + ) -> "t.List[LaunchedShardData]": + with fileinput.FileInput(file_paths) as ifstream: + return cls._parse_launched_shard_info_from_iterable(ifstream, num_shards) + + def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": + """Parse the launched database shard info from the output files + + :raises SmartSimError: if all shard info could not be found + :return: The found launched shard info + :rtype: list[LaunchedShardData] """ - if not filepath: - filepath = osp.join(self.path, self._output_files[0]) - trials = 5 - ip_address = None - - # try a few times to give the database files time to - # populate on busy systems. 
-        while not ip_address and trials > 0:
+        ips: "t.List[LaunchedShardData]" = []
+        trials = CONFIG.database_file_parse_trials
+        interval = CONFIG.database_file_parse_interval
+        output_files = [osp.join(self.path, file) for file in self._output_files]
+
+        while len(ips) < self.num_shards and trials > 0:
             try:
-                if ip_addresses := self._parse_ips(filepath, 1):
-                    ip_address = ip_addresses[0]
-            # suppress error
+                ips = self._parse_launched_shard_info_from_files(
+                    output_files, self.num_shards
+                )
             except FileNotFoundError:
-                pass
-
-            logger.debug("Waiting for Redis output files to populate...")
-            if not ip_address:
-                time.sleep(1)
+                ...
+            if len(ips) < self.num_shards:
+                logger.debug("Waiting for output files to populate...")
+                time.sleep(interval)
                 trials -= 1
 
-        if not ip_address:
-            logger.error(f"IP address lookup strategy failed for file {filepath}.")
-            raise SmartSimError("Failed to obtain database hostname")
-
-        return ip_address
+        if len(ips) < self.num_shards:
+            msg = (
+                f"Failed to parse the launched DB shard information from file(s) "
+                f"{', '.join(output_files)}. Found the information for "
+                f"{len(ips)} out of {self.num_shards} DB shards."
+            )
+            logger.error(msg)
+            raise SmartSimError(msg)
+        return ips
 
     def _parse_db_hosts(self) -> t.List[str]:
         """Parse the database hosts/IPs from the output files
-        this uses the RedisIP module that is built as a dependency
 
         The IP address is preferred, but if hostname is only present
         then a lookup to /etc/hosts is done through the socket library.
-        This function must be called only if ``_mpmd==True``.
 
         :raises SmartSimError: if host/ip could not be found
         :return: ip addresses | hostnames
         :rtype: list[str]
         """
-        ips: t.List[str] = []
-
-        # Find out if all shards' output streams are piped to separate files
-        if len(self._output_files) > 1:
-            for output_file in self._output_files:
-                filepath = osp.join(self.path, output_file)
-                _ = self._parse_db_host(filepath)
-        else:
-            filepath = osp.join(self.path, self._output_files[0])
-            trials = 10
-            ips = []
-            while len(ips) < self._num_shards and trials > 0:
-                ips = []
-                try:
-                    ip_address = self._parse_ips(filepath, self._num_shards)
-                    ips.extend(ip_address)
-
-                # suppress error
-                except FileNotFoundError:
-                    pass
-
-                if len(ips) < self._num_shards:
-                    logger.debug("Waiting for RedisIP files to populate...")
-                    # Larger sleep time, as this seems to be needed for
-                    # multihost setups
-                    time.sleep(2)
-                    trials -= 1
-
-            if len(ips) < self._num_shards:
-                msg = f"IP address lookup strategy failed for file {filepath}. "
-                msg += f"Found {len(ips)} out of {self._num_shards} IPs."
-                logger.error(msg)
-                raise SmartSimError("Failed to obtain database hostname")
-
-            ips = list(dict.fromkeys(ips))
-        return ips
+        return list({shard.hostname for shard in self.get_launched_shard_info()})
+
+
+@dataclass(frozen=True)
+class LaunchedShardData:
+    """Data class to write and parse data about a launched database shard"""
+
+    name: str
+    hostname: str
+    port: int
+    cluster: bool
+
+    @property
+    def cluster_conf_file(self) -> t.Optional[str]:
+        return f"nodes-{self.name}-{self.port}.conf" if self.cluster else None
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return dict(self.__dict__)
diff --git a/tests/test_dbnode.py b/tests/test_dbnode.py
index a4f4f641e..5b849d76e 100644
--- a/tests/test_dbnode.py
+++ b/tests/test_dbnode.py
@@ -24,11 +24,17 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import io +import json +import random +import string +import textwrap import pytest from smartsim import Experiment from smartsim.database import Orchestrator +from smartsim.entity.dbnode import DBNode, LaunchedShardData from smartsim.error.errors import SmartSimError @@ -49,21 +55,71 @@ def test_hosts(fileutils, wlmutils): orc.set_path(test_dir) exp.start(orc) - thrown = False hosts = [] try: hosts = orc.hosts - except SmartSimError: - thrown = True + assert len(hosts) == orc.db_nodes == 1 finally: # stop the database even if there is an error raised exp.stop(orc) orc.remove_stale_files() - assert not thrown - assert hosts == orc.hosts + + +def _random_shard_info(): + rand_string = lambda: ''.join(random.choices(string.ascii_letters, k=10)) + rand_num = lambda: random.randint(1000, 9999) + flip_coin = lambda: random.choice((True, False)) + + return LaunchedShardData( + name=rand_string(), + hostname=rand_string(), + port=rand_num(), + cluster=flip_coin(), + ) + + +def test_launched_shard_info_can_be_serialized(): + shard_data = _random_shard_info() + shard_data_from_str = LaunchedShardData( + **json.loads(json.dumps(shard_data.to_dict())) + ) + + assert shard_data is not shard_data_from_str + assert shard_data == shard_data_from_str + + +@pytest.mark.parametrize("limit", [None, 1]) +def test_db_node_can_parse_launched_shard_info(limit): + rand_shards = [_random_shard_info() for _ in range(3)] + with io.StringIO( + textwrap.dedent( + """\ + This is some file like str + -------------------------- + + SMARTSIM_ORC_SHARD_INFO: {} + ^^^^^^^^^^^^^^^^^^^^^^^ + We should be able to parse the serialized + launched db info from this file if the line is + prefixed with this tag. + + Here are two more for good measure: + SMARTSIM_ORC_SHARD_INFO: {} + SMARTSIM_ORC_SHARD_INFO: {} + + All other lines should be ignored. 
+            """
+        ).format(*(json.dumps(s.to_dict()) for s in rand_shards))
+    ) as stream:
+        parsed_shards = DBNode._parse_launched_shard_info_from_iterable(
+            stream, limit
+        )
+    if limit is not None:
+        rand_shards = rand_shards[:limit]
+    assert rand_shards == parsed_shards
 
 
 def test_set_host():
     orc = Orchestrator()
-    orc.entities[0].set_host("host")
-    assert orc.entities[0]._host == "host"
+    orc.entities[0].set_hosts(["host"])
+    assert orc.entities[0].host == "host"
diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py
index a634009d9..e8156a4ee 100644
--- a/tests/test_orchestrator.py
+++ b/tests/test_orchestrator.py
@@ -226,6 +226,31 @@ def test_slurm_set_batch_arg(wlmutils):
     assert orc2.batch_settings.batch_args["account"] == "ACCOUNT"
 
 
+@pytest.mark.parametrize("single_cmd", [
+    pytest.param(True, id="Single MPMD `srun`"),
+    pytest.param(False, id="Multiple `srun`s"),
+])
+def test_orc_results_in_correct_number_of_shards(single_cmd):
+    num_shards = 5
+    orc = Orchestrator(
+        port=12345,
+        launcher="slurm",
+        run_command="srun",
+        db_nodes=num_shards,
+        batch=False,
+        single_cmd=single_cmd,
+    )
+    if single_cmd:
+        assert len(orc.entities) == 1
+        (node,) = orc.entities
+        assert len(node.run_settings.mpmd) == num_shards - 1
+    else:
+        assert len(orc.entities) == num_shards
+        assert all(node.run_settings.mpmd == [] for node in orc.entities)
+    assert orc.num_shards == orc.db_nodes == sum(
+        node.num_shards for node in orc.entities
+    )
+
+
 ###### Cobalt ######
diff --git a/tests/test_smartredis.py b/tests/test_smartredis.py
index c44969ce1..f27ac5de4 100644
--- a/tests/test_smartredis.py
+++ b/tests/test_smartredis.py
@@ -93,14 +93,11 @@ def test_exchange(fileutils, wlmutils):
 
     # get and confirm statuses
     statuses = exp.get_status(ensemble)
-    if not all([stat == status.STATUS_COMPLETED for stat in statuses]):
+    try:
+        assert all([stat == status.STATUS_COMPLETED for stat in statuses])
+    finally:
+        # stop the orchestrator
         exp.stop(orc)
-        assert False  # client ensemble failed
-
-    # stop the orchestrator
-    exp.stop(orc)
-
-    print(exp.summary())
 
 
 def test_consumer(fileutils, wlmutils):
@@ -145,11 +142,8 @@ def test_consumer(fileutils, wlmutils):
 
     # get and confirm statuses
     statuses = exp.get_status(ensemble)
-    if not all([stat == status.STATUS_COMPLETED for stat in statuses]):
+    try:
+        assert all([stat == status.STATUS_COMPLETED for stat in statuses])
+    finally:
+        # stop the orchestrator
         exp.stop(orc)
-        assert False  # client ensemble failed
-
-    # stop the orchestrator
-    exp.stop(orc)
-
-    print(exp.summary())
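
Reviewer note (not part of the patch): the heart of this change is the tagged
JSON line that the entrypoint now prints and that DBNode later scrapes back
out of the shard output files. A condensed, standalone sketch of that round
trip, mirroring the `LaunchedShardData` dataclass and `SMARTSIM_ORC_SHARD_INFO`
tag added above (hostname/port values invented for illustration):

    import json
    import typing as t
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class LaunchedShardData:
        # Mirrors the dataclass added in smartsim/entity/dbnode.py
        name: str
        hostname: str
        port: int
        cluster: bool

        def to_dict(self) -> t.Dict[str, t.Any]:
            return dict(self.__dict__)

    # What print_summary() in the entrypoint writes to a shard's output file
    shard = LaunchedShardData(name="orc_0", hostname="10.0.0.1", port=6379, cluster=True)
    line = f"SMARTSIM_ORC_SHARD_INFO: {json.dumps(shard.to_dict())}"

    # What DBNode._parse_launched_shard_info_from_iterable() recovers from it
    tag, kwjson = line.split(maxsplit=1)
    assert "SMARTSIM_ORC_SHARD_INFO" in tag
    assert LaunchedShardData(**json.loads(kwjson)) == shard

Because the record is a plain dict round-tripped through JSON, each MPMD shard
can be identified from a shared output stream without per-shard file
bookkeeping.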
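For reference, the start command assembled by `_get_start_script_args` and
handed to each DBNode's RunSettings now takes roughly the following shape.
This is a sketch with invented paths; the real executable, conf file, and
RedisAI module come from CONFIG, and the optional pieces appear only when
configured (`THREADS_PER_QUEUE` via `_rai_module`, `+cluster` only when
`db_nodes >= 3`):

    import sys

    start_cmd = [
        sys.executable,
        "-m", "smartsim._core.entrypoints.redis",  # entrypoint module
        "+orc-exe=/path/to/redis-server",          # hypothetical path
        "+conf-file=/path/to/redis.conf",          # hypothetical path
        "+rai-module", "--loadmodule", "/path/to/redisai.so",
        "THREADS_PER_QUEUE", "4",                  # optional RedisAI args
        "+name=orchestrator_0",
        "+port=6379",
        "+ifname=lo",
        "+cluster",                                # clustered DBs only
    ]
    print(" ".join(start_cmd))

The parser's `prefix_chars="+"` is what lets the Redis and RedisAI arguments
(`--loadmodule`, `THREADS_PER_QUEUE`, ...) ride along inside the `+rai-module`
value without being mistaken for entrypoint options.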
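Similarly, the output-file polling that replaces the old hard-coded retry
loops is now tunable through the two new Config properties backed by
SMARTSIM_DB_FILE_PARSE_TRIALS (default 10) and
SMARTSIM_DB_FILE_PARSE_INTERVAL (default 2 seconds). A minimal runnable
sketch of the bounded loop `get_launched_shard_info` performs, with a
stand-in parser in place of the real file scraping:

    import os
    import time
    import typing as t

    def poll_for_shard_info(
        parse: t.Callable[[], t.List[dict]], num_shards: int
    ) -> t.List[dict]:
        # Bounded retry mirroring DBNode.get_launched_shard_info()
        trials = int(os.getenv("SMARTSIM_DB_FILE_PARSE_TRIALS", "10"))
        interval = int(os.getenv("SMARTSIM_DB_FILE_PARSE_INTERVAL", "2"))
        found: t.List[dict] = []
        while len(found) < num_shards and trials > 0:
            try:
                found = parse()  # swallow FileNotFoundError, like the real loop
            except FileNotFoundError:
                pass
            if len(found) < num_shards:
                time.sleep(interval)
                trials -= 1
        if len(found) < num_shards:
            raise RuntimeError(f"Found {len(found)} of {num_shards} shards")
        return found

    # Stand-in parser that "finds" both shards on the first try:
    print(poll_for_shard_info(lambda: [{"name": "orc_0"}, {"name": "orc_1"}], 2))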