Skip to content

Commit

Permalink
Send shutdown command to DB (#355)
Browse files Browse the repository at this point in the history
Adds explicit shutdown of DB shards. Previously, DBs
were terminated by simply terminating their processes,
but that does not work in certain settings.

[ committed by @al-rigazzi ]
[ reviewed by @MattToast ]
  • Loading branch information
al-rigazzi authored Sep 18, 2023
1 parent a332823 commit bc7b232
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 26 deletions.
27 changes: 22 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
AprunSettings,
JsrunSettings,
MpirunSettings,
PalsMpiexecSettings,
RunSettings,
)
from smartsim._core.config import CONFIG
Expand Down Expand Up @@ -83,15 +84,15 @@ def print_test_configuration() -> None:
print("TEST_ALLOC_SPEC_SHEET_PATH:", test_alloc_specs_path)
print("TEST_DIR:", test_dir)
print("Test output will be located in TEST_DIR if there is a failure")
print("TEST_PORT", test_port)
print("TEST_PORT + 1", test_port + 1)
print("TEST_PORTS", ", ".join(str(port) for port in range(test_port, test_port+3)))


def pytest_configure() -> None:
    """Publish the test-session settings as attributes of the ``pytest`` module.

    Sets ``test_launcher``, ``wlm_options``, ``test_account`` (the value
    returned by ``get_account()``) and ``test_device`` on ``pytest``.
    """
    pytest.test_launcher = test_launcher
    pytest.wlm_options = ["slurm", "pbs", "cobalt", "lsf", "pals"]
    pytest.test_account = get_account()
    pytest.test_device = test_device


def pytest_sessionstart(
Expand Down Expand Up @@ -144,6 +145,12 @@ def get_hostlist() -> t.Optional[t.List[str]]:
return _parse_hostlist_file(os.environ["COBALT_NODEFILE"])
except FileNotFoundError:
return None
elif "PBS_NODEFILE" in os.environ and test_launcher=="pals":
# with PALS, we need a hostfile even if `aprun` is available
try:
return _parse_hostlist_file(os.environ["PBS_NODEFILE"])
except FileNotFoundError:
return None
elif "PBS_NODEFILE" in os.environ and not shutil.which("aprun"):
try:
return _parse_hostlist_file(os.environ["PBS_NODEFILE"])
Expand Down Expand Up @@ -320,7 +327,7 @@ def get_run_settings(

@staticmethod
def get_orchestrator(nodes: int = 1, batch: bool = False) -> Orchestrator:
if test_launcher in ["pbs", "cobalt", "pals"]:
if test_launcher in ["pbs", "cobalt"]:
if not shutil.which("aprun"):
hostlist = get_hostlist()
else:
Expand All @@ -333,6 +340,16 @@ def get_orchestrator(nodes: int = 1, batch: bool = False) -> Orchestrator:
launcher=test_launcher,
hosts=hostlist,
)
if test_launcher == "pals":
hostlist = get_hostlist()
return Orchestrator(
db_nodes=nodes,
port=test_port,
batch=batch,
interface=test_nic,
launcher=test_launcher,
hosts=hostlist,
)
if test_launcher == "slurm":
return Orchestrator(
db_nodes=nodes,
Expand Down Expand Up @@ -579,8 +596,8 @@ def make_test_dir(
:type caller_function: str, optional
:param caller_fspath: absolute path to file containing caller, defaults to None
:type caller_fspath: str or Path, optional
:param level: indicate depth in the call stack relative to test method.
:type level: int, optional
:param level: indicate depth in the call stack relative to test method.
:type level: int, optional
:param sub_dir: a relative path to create in the test directory
:type sub_dir: str or Path, optional
Expand Down
36 changes: 25 additions & 11 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,26 @@
from smartredis import Client

from ..._core.launcher.step import Step
from ..._core.utils.redis import db_is_active, set_ml_model, set_script
from ..._core.utils.redis import db_is_active, set_ml_model, set_script, shutdown_db
from ...database import Orchestrator
from ...entity import EntityList, SmartSimEntity, Model, Ensemble
from ...entity import Ensemble, EntityList, Model, SmartSimEntity
from ...error import LauncherError, SmartSimError, SSInternalError, SSUnsupportedError
from ...log import get_logger
from ...status import STATUS_RUNNING, TERMINAL_STATUSES
from ...settings.base import BatchSettings
from ...status import STATUS_CANCELLED, STATUS_RUNNING, TERMINAL_STATUSES
from ..config import CONFIG
from ..launcher import (
SlurmLauncher,
PBSLauncher,
LocalLauncher,
CobaltLauncher,
LocalLauncher,
LSFLauncher,
PBSLauncher,
SlurmLauncher,
)
from ..launcher.launcher import Launcher
from ..utils import check_cluster_status, create_cluster
from .job import Job
from .jobmanager import JobManager
from .manifest import Manifest
from .job import Job
from ...settings.base import BatchSettings


logger = get_logger(__name__)

Expand Down Expand Up @@ -189,6 +188,21 @@ def stop_entity(self, entity: t.Union[SmartSimEntity, EntityList]) -> None:
)
self._jobs.move_to_completed(job)

def stop_db(self, db: Orchestrator) -> None:
    """Stop an orchestrator

    A batch orchestrator is stopped through ``stop_entity``. Otherwise a
    shutdown command is sent to the database hosts and ports, and the job
    of every entity in the orchestrator is marked as cancelled and moved
    to the completed jobs.

    :param db: orchestrator to be stopped
    :type db: Orchestrator
    """
    if db.batch:
        self.stop_entity(db)
        return

    # Non-batch case: ask the database itself to shut down
    shutdown_db(db.hosts, db.ports)
    with JM_LOCK:
        for db_node in db:
            node_job = self._jobs[db_node.name]
            node_job.set_status(STATUS_CANCELLED, "", 0, output=None, error=None)
            self._jobs.move_to_completed(node_job)

def stop_entity_list(self, entity_list: EntityList) -> None:
"""Stop an instance of an entity list
Expand Down Expand Up @@ -550,15 +564,15 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None:
# TODO remove in favor of by node status check
time.sleep(CONFIG.jm_interval)
elif any(stat in TERMINAL_STATUSES for stat in statuses):
self.stop_entity_list(orchestrator)
self.stop_db(orchestrator)
msg = "Orchestrator failed during startup"
msg += f" See {orchestrator.path} for details"
raise SmartSimError(msg)
else:
logger.debug("Waiting for orchestrator instances to spin up...")
except KeyboardInterrupt:
logger.info("Orchestrator launch cancelled - requesting to stop")
self.stop_entity_list(orchestrator)
self.stop_db(orchestrator)

# re-raise keyboard interrupt so the job manager will display
# any running and un-killed jobs as this method is only called
Expand Down
31 changes: 31 additions & 0 deletions smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import itertools
import logging
import redis
import time
Expand Down Expand Up @@ -214,3 +215,33 @@ def set_script(db_script: DBScript, client: Client) -> None:
except RedisReplyError as error: # pragma: no cover
logger.error("Error while setting model on orchestrator.")
raise error


def shutdown_db(hosts: t.List[str], ports: t.List[int]) -> None:  # cov-wlm
    """Send a shutdown command to database instances.

    The database CLI ``shutdown`` command is issued once for every
    combination of host and port. A command that exits with a non-zero
    return code is logged as an error; the remaining instances are still
    contacted.

    Should only be used in the case where database deallocation
    needs to occur manually. Usually, the SmartSim task manager
    will take care of this automatically.

    :param hosts: List of hostnames to connect to
    :type hosts: List[str]
    :param ports: List of ports to contact on each host
    :type ports: List[int]
    """
    for host_ip, port in itertools.product(
        (get_ip_from_host(host) for host in hosts), ports
    ):
        # issue the shutdown command to this host/port through the database CLI
        redis_cli = CONFIG.database_cli
        cmd = [redis_cli, "-h", host_ip, "-p", str(port), "shutdown"]
        # NOTE(review): proc_input is passed through unchanged; presumably it
        # answers a confirmation prompt -- confirm against execute_cmd
        returncode, out, err = execute_cmd(
            cmd, proc_input="yes", shell=False, timeout=10
        )

        if returncode != 0:
            # say which instance failed; the raw output alone does not
            logger.error(f"Shutdown command failed for database at {host_ip}:{port}")
            logger.error(out)
            logger.error(err)
        else:
            logger.debug(out)
6 changes: 6 additions & 0 deletions smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,12 @@ def set_hosts(self, host_list: t.List[str]) -> None:
if self.launcher == "lsf":
for db in self.dbnodes:
db.set_hosts(host_list)
elif (self.launcher == "pals"
and isinstance(self.dbnodes[0].run_settings, PalsMpiexecSettings)
and self.dbnodes[0].is_mpmd):
# In this case, --hosts is a global option, so we only set it on the
# first run command
self.dbnodes[0].run_settings.set_hostlist(host_list)
else:
for host, db in zip(host_list, self.dbnodes):
if isinstance(db.run_settings, AprunSettings):
Expand Down
7 changes: 5 additions & 2 deletions smartsim/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,15 @@ def stop(self, *args: t.Any) -> None:
:raises TypeError: if wrong type
:raises SmartSimError: if stop request fails
"""
stop_manifest = Manifest(*args)
try:
stop_manifest = Manifest(*args)
for entity in stop_manifest.models:
self._control.stop_entity(entity)
for entity_list in stop_manifest.all_entity_lists:
for entity_list in stop_manifest.ensembles:
self._control.stop_entity_list(entity_list)
db = stop_manifest.db
if db:
self._control.stop_db(db)
except SmartSimError as e:
logger.error(e)
raise
Expand Down
17 changes: 17 additions & 0 deletions smartsim/settings/palsSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,20 @@ def format_env_vars(self) -> t.List[str]:
formatted += ["--envlist", ",".join(export_vars)]

return formatted

def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
    """Set the hostlist for the PALS ``mpiexec`` command

    This sets ``--hosts``

    :param host_list: a single host name or a list of host names
    :type host_list: str | list[str]
    :raises TypeError: if not str or list of str
    """
    # a bare string is treated as a single, whitespace-stripped host name
    hosts = [host_list.strip()] if isinstance(host_list, str) else host_list
    if not isinstance(hosts, list):
        raise TypeError("host_list argument must be a list of strings")
    for host in hosts:
        if not isinstance(host, str):
            raise TypeError("host_list argument must be list of strings")
    self.run_args["hosts"] = ",".join(hosts)
24 changes: 16 additions & 8 deletions tests/backends/test_dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@


import sys
import time

import pytest

import smartsim
from smartsim import Experiment, status
from smartsim._core.utils import installed_redisai_backends
from smartsim.entity import Ensemble
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger

Expand All @@ -45,8 +44,10 @@

# Check TensorFlow is available for tests
try:
import tensorflow.keras as keras
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input

except ImportError:
should_run_tf = False
else:
Expand All @@ -59,7 +60,14 @@ def __init__(self):
def call(self, x):
y = self.conv(x)
return y

if pytest.test_device == "GPU":
try:
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.set_logical_device_configuration(
physical_devices[0],
[tf.config.LogicalDeviceConfiguration(memory_limit=5_000)])
except:
logger.warning("Could not set TF max memory limit for GPU")

should_run_tf &= "tensorflow" in installed_redisai_backends()

Expand Down Expand Up @@ -342,7 +350,7 @@ def test_db_model_ensemble(fileutils, wlmutils, mlutils):
smartsim_ensemble.add_model(smartsim_model)

# Add the second ML model to the newly added entity. This is
# because the test script run both ML models for all entities.
# because the test script runs both ML models for all entities.
smartsim_model.add_ml_model(
"cnn2",
"TF",
Expand Down Expand Up @@ -520,7 +528,7 @@ def test_colocated_db_model_ensemble(fileutils, wlmutils, mlutils):
colo_settings.set_tasks_per_node(1)

# Create ensemble of two identical models
colo_ensemble = exp.create_ensemble(
colo_ensemble: Ensemble = exp.create_ensemble(
"colocated_ens", run_settings=colo_settings, replicas=2
)
colo_ensemble.set_path(test_dir)
Expand Down Expand Up @@ -795,7 +803,7 @@ def test_colocated_db_model_errors(fileutils, wlmutils, mlutils):
@pytest.mark.skipif(not should_run_tf, reason="Test needs TensorFlow to run")
def test_inconsistent_params_db_model():
"""Test error when devices_per_node parameter>1 when devices is set to CPU in DBModel"""

# Create and save ML model to filesystem
model, inputs, outputs = create_tf_cnn()
with pytest.raises(SSUnsupportedError) as ex:
Expand All @@ -810,6 +818,6 @@ def test_inconsistent_params_db_model():
outputs=outputs,
)
assert (
ex.value.args[0]
ex.value.args[0]
== "Cannot set devices_per_node>1 if CPU is specified under devices"
)
2 changes: 2 additions & 0 deletions tests/test_configs/cov/local_cov.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ omit =
*mpirun*
*alps*
*lsf*
*pals*
*redis_starter.py*
*/_cli/*
*/_install/*
Expand Down Expand Up @@ -47,3 +48,4 @@ exclude_lines=
launcher == "pbs"
launcher == "cobalt"
launcher == "lsf"
launcher == "pals"

0 comments on commit bc7b232

Please sign in to comment.