Skip to content

Commit

Permalink
DBNode: MPMD RunSettings Unification (#379)
Browse files Browse the repository at this point in the history
A general refactor of the DBNode class to:
- Remove the duplicate methods with near identical functionality where one was intended to be called if the underlying RunSettings were standard and the other if they were MPMD. This removes undefined behavior when the "wrong" method was called
- Allow MPMD DBNodes to map information "per shard" by scraping output files for a serializable LaunchedShardData class.

[ committed by @MattToast ]
[ reviewed by @ashao ]
  • Loading branch information
MattToast authored Oct 5, 2023
1 parent 9471ab4 commit 7eaab83
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 267 deletions.
8 changes: 8 additions & 0 deletions smartsim/_core/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def database_cli(self) -> str:
"Specified Redis binary at REDIS_CLI_PATH could not be used"
) from e

@property
def database_file_parse_trials(self) -> int:
return int(os.getenv("SMARTSIM_DB_FILE_PARSE_TRIALS", "10"))

@property
def database_file_parse_interval(self) -> int:
return int(os.getenv("SMARTSIM_DB_FILE_PARSE_INTERVAL", "2"))

@property
def log_level(self) -> str:
return os.environ.get("SMARTSIM_LOG_LEVEL", "info")
Expand Down
130 changes: 99 additions & 31 deletions smartsim/_core/entrypoints/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,28 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import json
import os
import psutil
import signal
import textwrap
import typing as t
from subprocess import PIPE, STDOUT
from types import FrameType

import psutil

from smartsim._core.utils.network import current_ip
from smartsim.entity.dbnode import LaunchedShardData
from smartsim.error import SSInternalError
from smartsim.log import get_logger
from subprocess import PIPE, STDOUT
from types import FrameType

logger = get_logger(__name__)

"""
Redis/KeyDB entrypoint script
"""

DBPID = None
DBPID: t.Optional[int] = None

# kill is not catchable
SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT]
Expand All @@ -54,31 +58,64 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None:
cleanup()


def build_bind_args(ip_addresses: t.List[str]) -> t.List[str]:
bind_arg = f"--bind {' '.join(ip_addresses)}"
# pin source address to avoid random selection by Redis
bind_src_arg = f"--bind-source-addr {ip_addresses[0]}"
return [bind_arg, bind_src_arg]
def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]:
return (
"--bind",
source_addr,
*addrs,
# pin source address to avoid random selection by Redis
"--bind-source-addr",
source_addr,
)


def print_summary(cmd: t.List[str], ip_address: str, network_interface: str) -> None:
print("-" * 10, " Running Command ", "-" * 10, "\n", flush=True)
print(f"COMMAND: {' '.join(cmd)}\n", flush=True)
print(f"IPADDRESS: {ip_address}\n", flush=True)
print(f"NETWORK: {network_interface}\n", flush=True)
print("-" * 30, "\n\n", flush=True)
print("-" * 10, " Output ", "-" * 10, "\n\n", flush=True)
def build_cluster_args(shard_data: LaunchedShardData) -> t.Tuple[str, ...]:
if cluster_conf_file := shard_data.cluster_conf_file:
return ("--cluster-enabled", "yes", "--cluster-config-file", cluster_conf_file)
return ()


def main(network_interface: str, command: t.List[str]) -> None:
def print_summary(
cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData
) -> None:
print(
textwrap.dedent(
f"""\
----------- Running Command ----------
COMMAND: {' '.join(cmd)}
IPADDRESS: {shard_data.hostname}
NETWORK: {network_interface}
SMARTSIM_ORC_SHARD_INFO: {json.dumps(shard_data.to_dict())}
--------------------------------------
--------------- Output ---------------
"""
),
flush=True,
)


def main(args: argparse.Namespace) -> int:
global DBPID # pylint: disable=global-statement

try:
ip_addresses = [current_ip(net_if) for net_if in network_interface.split(",")]
cmd = command + build_bind_args(ip_addresses)
src_addr, *bind_addrs = (current_ip(net_if) for net_if in args.ifname.split(","))
shard_data = LaunchedShardData(
name=args.name, hostname=src_addr, port=args.port, cluster=args.cluster
)

print_summary(cmd, ip_addresses[0], network_interface)
cmd = [
args.orc_exe,
args.conf_file,
*args.rai_module,
"--port",
str(args.port),
*build_cluster_args(shard_data),
*build_bind_args(src_addr, *bind_addrs),
]
print_summary(cmd, args.ifname, shard_data)

try:
process = psutil.Popen(cmd, stdout=PIPE, stderr=STDOUT)
DBPID = process.pid

Expand All @@ -87,18 +124,17 @@ def main(network_interface: str, command: t.List[str]) -> None:
except Exception as e:
cleanup()
raise SSInternalError("Database process starter raised an exception") from e
return 0


def cleanup() -> None:
logger.debug("Cleaning up database instance")
try:
logger.debug("Cleaning up database instance")
# attempt to stop the database process
db_proc = psutil.Process(DBPID)
db_proc.terminate()

if DBPID is not None:
psutil.Process(DBPID).terminate()
except psutil.NoSuchProcess:
logger.warning("Couldn't find database process to kill.")

except OSError as e:
logger.warning(f"Failed to clean up database gracefully: {str(e)}")

Expand All @@ -110,15 +146,47 @@ def cleanup() -> None:
prefix_chars="+", description="SmartSim Process Launcher"
)
parser.add_argument(
"+ifname", type=str, help="Network Interface name", default="lo"
"+orc-exe", type=str, help="Path to the orchestrator executable", required=True
)
parser.add_argument(
"+conf-file",
type=str,
help="Path to the orchestrator configuration file",
required=True,
)
parser.add_argument(
"+rai-module",
nargs="+",
type=str,
help=(
"Command for the orcestrator to load the Redis AI module with "
"symbols seperated by whitespace"
),
required=True,
)
parser.add_argument(
"+name", type=str, help="Name to identify the shard", required=True
)
parser.add_argument(
"+port",
type=int,
help="The port on which to launch the shard of the orchestrator",
required=True,
)
parser.add_argument(
"+ifname", type=str, help="Network Interface name", required=True
)
parser.add_argument(
"+cluster",
action="store_true",
help="Specify if this orchestrator shard is part of a cluster",
)
parser.add_argument("+command", nargs="+", help="Command to run")
args = parser.parse_args()
args_ = parser.parse_args()

# make sure to register the cleanup before the start
# the proecss so our signaller will be able to stop
# the process so our signaller will be able to stop
# the database process.
for sig in SIGNALS:
signal.signal(sig, handle_signal)

main(args.ifname, args.command)
raise SystemExit(main(args_))
Loading

0 comments on commit 7eaab83

Please sign in to comment.