diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py index b9097eae54ea4..e79866420e03b 100644 --- a/python/ray/_private/node.py +++ b/python/ray/_private/node.py @@ -27,6 +27,7 @@ from ray._private.resource_spec import ResourceSpec from ray._private.services import serialize_config, get_address from ray._private.utils import open_log, try_to_create_directory, try_to_symlink +from ray.ha import RedisBasedLeaderSelector # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray configures it by default automatically @@ -1412,6 +1413,8 @@ def start_head_processes(self): assert self._gcs_address is None assert self._gcs_client is None + self.start_head_ha_mode() + self.start_gcs_server() assert self.get_gcs_client() is not None self._write_cluster_info_to_kv() @@ -1860,3 +1863,27 @@ def _record_stats(self): # so we truncate it to the first 50 characters # to avoid any issues. record_hardware_usage(cpu_model_name[:50]) + + def check_leadership_downgrade(self): + """[ha feature] Check if the role is downgraded from the active.""" + if self._ray_params.enable_head_ha and hasattr(self, "leader_selector"): + if ( + self.leader_selector is not None + and not self.leader_selector.is_leader() + ): + msg = ( + "This head node will be killed " + "as it has changed from active to standby." + ) + logger.error(msg) + return True + return False + + def start_head_ha_mode(self): + if self._ray_params.enable_head_ha: + logger.info("The head high-availability mode is enabled.") + self.leader_selector = RedisBasedLeaderSelector( + self._ray_params, self._redis_address, self.node_ip_address + ) + self.leader_selector.start() + self.leader_selector.node_wait_to_be_active() diff --git a/python/ray/_private/parameter.py b/python/ray/_private/parameter.py index 1185df149f766..3334761cf253a 100644 --- a/python/ray/_private/parameter.py +++ b/python/ray/_private/parameter.py @@ -265,6 +265,9 @@ def __init__( self.cluster_id = cluster_id self.node_id = node_id self.enable_physical_mode = enable_physical_mode + self.enable_head_ha = ( + os.environ.get("RAY_ENABLE_HEAD_HA", "false").lower() == "true" + ) # Set the internal config options for object reconstruction. if enable_object_reconstruction: diff --git a/python/ray/_private/ray_constants.py b/python/ray/_private/ray_constants.py index 4ea316f2c2cd5..c558641664564 100644 --- a/python/ray/_private/ray_constants.py +++ b/python/ray/_private/ray_constants.py @@ -554,3 +554,17 @@ def gcs_actor_scheduling_enabled(): ) RAY_EXPORT_EVENT_MAX_BACKUP_COUNT = env_bool("RAY_EXPORT_EVENT_MAX_BACKUP_COUNT", 20) + +# head high-availability feature +STORAGE_NAMESPACE = ( + "RAY" + os.environ.get("RAY_external_storage_namespace", "default") + "@" +) +HEAD_NODE_LEADER_ELECTION_KEY = STORAGE_NAMESPACE + "head_node_leader_election_key" +HEAD_ROLE_ACTIVE = "active_head" +HEAD_ROLE_STANDBY = "standby_head" +GCS_ADDRESS_KEY = STORAGE_NAMESPACE + "GcsServerAddress" + +# Number of attempts to ping the Redis server. 
See +# `services.py::wait_for_redis_to_start()` and +# `services.py::create_redis_client()` +START_REDIS_WAIT_RETRIES = env_integer("RAY_START_REDIS_WAIT_RETRIES", 60) diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index ba83e47bd4404..8e623beb156e7 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -804,7 +804,7 @@ def create_redis_client(redis_address, password=None, username=None): redis_ip_address, redis_port = extract_ip_port( canonicalize_bootstrap_address_or_die(redis_address) ) - cli = redis.StrictRedis( + cli = redis.Redis( host=redis_ip_address, port=int(redis_port), username=username, @@ -2280,3 +2280,77 @@ def start_ray_client_server( fate_share=fate_share, ) return process_info + + +def wait_for_redis_to_start(redis_ip_address, redis_port, username=None, password=None): + """Wait for a Redis server to be available. + + This is accomplished by creating a Redis client and sending a random + command to the server until the command gets through. + + Args: + redis_ip_address (str): The IP address of the redis server. + redis_port (int): The port of the redis server. + username (str): The username of the Redis server. + password (str): The password of the Redis server. + + Raises: + Exception: An exception is raised if we could not connect with Redis. + """ + import redis + + redis_client = create_redis_client( + "%s:%s" % (redis_ip_address, redis_port), password=password, username=username + ) + # Wait for the Redis server to start. + num_retries = ray_constants.START_REDIS_WAIT_RETRIES + delay = 0.001 + for i in range(num_retries): + try: + # Run some random command and see if it worked. + logger.debug( + "Waiting for redis server at {}:{} to respond...".format( + redis_ip_address, redis_port + ) + ) + redis_client.ping() + # If the Redis service is delayed getting set up for any reason, we may + # get a redis.ConnectionError: Error 111 connecting to host:port. + # Connection refused. + # Unfortunately, redis.ConnectionError is also the base class of + # redis.AuthenticationError. We *don't* want to obscure a + # redis.AuthenticationError, because that indicates the user provided a + # bad password. Thus a double except clause to ensure a + # redis.AuthenticationError isn't trapped here. + except redis.AuthenticationError as authEx: + raise RuntimeError( + f"Unable to connect to Redis at {redis_ip_address}:{redis_port}." + ) from authEx + except redis.ConnectionError as connEx: + if i >= num_retries - 1: + raise RuntimeError( + f"Unable to connect to Redis at {redis_ip_address}:" + f"{redis_port} after {num_retries} retries. Check that " + f"{redis_ip_address}:{redis_port} is reachable from this " + "machine. If it is not, your firewall may be blocking " + "this port. If the problem is a flaky connection, try " + "setting the environment variable " + "`RAY_START_REDIS_WAIT_RETRIES` to increase the number of" + " attempts to ping the Redis server." + ) from connEx + # Wait a little bit. + time.sleep(delay) + # Make sure the retry interval doesn't increase too large, which + # will affect the delivery time of the Ray cluster. + delay = min(1, delay * 2) + else: + break + else: + raise RuntimeError( + f"Unable to connect to Redis (after {num_retries} retries). " + "If the Redis instance is on a different machine, check that " + "your firewall and relevant Ray ports are configured properly. 
" + "You can also set the environment variable " + "`RAY_START_REDIS_WAIT_RETRIES` to increase the number of " + "attempts to ping the Redis server." + ) diff --git a/python/ray/ha/__init__.py b/python/ray/ha/__init__.py new file mode 100644 index 0000000000000..0bed245dd6faf --- /dev/null +++ b/python/ray/ha/__init__.py @@ -0,0 +1,13 @@ +from ray.ha.leader_selector import HeadNodeLeaderSelector +from ray.ha.leader_selector import HeadNodeLeaderSelectorConfig +from ray.ha.redis_leader_selector import RedisBasedLeaderSelector +from ray.ha.redis_leader_selector import is_service_available +from ray.ha.redis_leader_selector import waiting_for_server_stopped + +__all__ = [ + "RedisBasedLeaderSelector", + "HeadNodeLeaderSelector", + "HeadNodeLeaderSelectorConfig", + "is_service_available", + "waiting_for_server_stopped", +] diff --git a/python/ray/ha/leader_selector.py b/python/ray/ha/leader_selector.py new file mode 100644 index 0000000000000..b9b86297697af --- /dev/null +++ b/python/ray/ha/leader_selector.py @@ -0,0 +1,93 @@ +import logging +import os +import ray._private.ray_constants as ray_constants + +logger = logging.getLogger(__name__) + + +class HeadNodeLeaderSelectorConfig: + # Will exit after N consecutive failures + max_failure_count = None + # The interval time for selecting the leader or expire leader + check_interval_s = None + # The expiration time of the leader key + key_expire_time_ms = None + # The time to wait after becoming the active allows the gcs client to + # have enough time to find out that the gcs address has changed. + wait_time_after_be_active_s = None + # Maximum time to wait pre gcs stop serving + wait_pre_gcs_stop_max_time_s = None + # Redis/socket connect time out(s) + connect_timeout_s = None + + def __init__(self): + self.max_failure_count = int( + os.environ.get("RAY_HA_CHECK_MAX_FAILURE_COUNT", "5") + ) + # Default 3s + self.check_interval_s = ( + int(os.environ.get("RAY_HA_CHECK_INTERVAL_MS", "3000")) / 1000 + ) + # Default 60s + self.key_expire_time_ms = int( + os.environ.get("RAY_HA_KEY_EXPIRE_TIME_MS", "60000") + ) + # Default 2s + self.wait_time_after_be_active_s = ( + int(os.environ.get("RAY_HA_WAIT_TIME_AFTER_BE_ACTIVE_MS", "2000")) / 1000 + ) + # Default 3 day (1000 * 60 * 60 * 24 * 3). + self.wait_pre_gcs_stop_max_time_s = ( + int(os.environ.get("RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS", "259200000")) + / 1000 + ) + # Default redis/socket connect time out is 5s. 
+        self.connect_timeout_s = (
+            int(os.environ.get("RAY_HA_CONNECT_TIMEOUT_MS", "5000")) / 1000
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"HeadNodeLeaderSelectorConfig["
+            f"max_failure_count={self.max_failure_count},"
+            f"check_interval_s={self.check_interval_s},"
+            f"key_expire_time_ms={self.key_expire_time_ms},"
+            f"wait_time_after_be_active_s={self.wait_time_after_be_active_s},"
+            f"wait_pre_gcs_stop_max_time_s="
+            f"{self.wait_pre_gcs_stop_max_time_s},"
+            f"connect_timeout_s={self.connect_timeout_s}]"
+        )
+
+
+class HeadNodeLeaderSelector:
+    _role_type = ray_constants.HEAD_ROLE_STANDBY
+    _config = None
+    _is_running = False
+
+    def __init__(self):
+        self._config = HeadNodeLeaderSelectorConfig()
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+    def set_role_type(self, role_type):
+        if role_type != self._role_type:
+            logger.info(
+                "This head node changed from %s to %s.", self._role_type, role_type
+            )
+            self._role_type = role_type
+
+    def get_role_type(self):
+        return self._role_type
+
+    def is_leader(self):
+        return self.get_role_type() == ray_constants.HEAD_ROLE_ACTIVE
+
+    def do_action_after_be_active(self):
+        pass
+
+    def is_running(self):
+        return self._is_running
diff --git a/python/ray/ha/redis_leader_selector.py b/python/ray/ha/redis_leader_selector.py
new file mode 100644
index 0000000000000..794ba47ba5396
--- /dev/null
+++ b/python/ray/ha/redis_leader_selector.py
@@ -0,0 +1,298 @@
+import threading
+import logging
+import socket
+import time
+import redis
+import ray._private.services
+import ray._private.gcs_utils
+from ray.ha import HeadNodeLeaderSelector
+import ray._private.ray_constants as ray_constants
+import os
+import requests
+import json
+
+logger = logging.getLogger(__name__)
+
+
+def is_service_available(ip, port):
+    sock = None
+    try:
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(0.2)
+        result = sock.connect_ex((ip, int(port)))
+        sock.close()
+        return 0 == result
+    except Exception:
+        if sock is not None:
+            sock.close()
+        return False
+
+
+def waiting_for_server_stopped(address, max_time):
+    if address is None or len(address) <= 0:
+        # Nothing to wait for; return zero elapsed time to keep the
+        # (stopped, elapsed_seconds) return signature consistent.
+        return True, 0
+    start_time = time.time()
+    use_time = 0
+    ip, port = address.split(":")
+    while use_time <= max_time:
+        if not is_service_available(ip, port):
+            return True, time.time() - start_time
+        time.sleep(0.5)
+        use_time = time.time() - start_time
+    return False, use_time
+
+
+def parse_head_num(data):
+    head_num = 0
+    groups = data.get("data", {}).get("head", [])
+    for group in groups:
+        for group_name in group.keys():
+            head_num = head_num + group.get(group_name, {}).get("replicasTotal", 0)
+
+    return head_num
+
+
+def get_cluster_head_num(timeout):
+    cluster_name = os.environ.get("CLUSTER_NAME", "")
+    ray_operator_address = os.environ.get("RAY_OPERATOR_ADDRESS", "")
+    k8s_namespace = os.environ.get("NAMESPACE", "")
+    head_group_name = os.environ.get("RAY_SHAPE_GROUP", "default")
+    url = f"http://{ray_operator_address}/elasticity/v2/workerNodes/status"
+    data = {
+        "name": cluster_name,
+        "namespace": k8s_namespace,
+        "containsPodInfo": False,
+        "headNode": {"group": head_group_name},
+    }
+    try:
+        response = requests.post(url, data=json.dumps(data), timeout=timeout)
+        if response.status_code >= 200 and response.status_code < 300:
+            resp_json = response.json()
+            if resp_json.get("success", False):
+                return parse_head_num(resp_json)
+            else:
+                raise ConnectionError(
+                    f"The request was not successful, resp:{resp_json}."
+                )
+        else:
+            raise ConnectionError(
+                f"The request returned an error status code, "
+                f"status code:{response.status_code}, resp:{response}."
+            )
+    except Exception as err:
+        logger.info(
+            f"Failed to get cluster head num, "
+            f"url:{url}, data:{data}, exception:{str(err)}"
+        )
+        return -1
+
+
+class RedisBasedLeaderSelector(HeadNodeLeaderSelector):
+    _real_failure_count = 0
+    _timer = None
+    _leader_name = None
+
+    def __init__(self, ray_params, redis_address, node_ip_address):
+        super().__init__()
+        self._redis_address = redis_address
+        self._redis_username = ray_params.redis_username
+        self._redis_password = ray_params.redis_password
+        self._node_ip_address = node_ip_address
+        gcs_server_port = (
+            0 if ray_params.gcs_server_port is None else ray_params.gcs_server_port
+        )
+        self._init_gcs_address = node_ip_address + ":" + str(gcs_server_port)
+        suffix = str(int(time.time()))
+        self.name = (self._init_gcs_address + "-" + suffix).encode("UTF-8")
+        logger.info("Initialized redis leader selector with %s.", self._config)
+
+    def start(self):
+        self._is_running = True
+        redis_ip_address, redis_port = self._redis_address.split(":")
+        ray._private.services.wait_for_redis_to_start(
+            redis_ip_address,
+            redis_port,
+            username=self._redis_username,
+            password=self._redis_password,
+        )
+        self._redis_client = redis.Redis(
+            host=redis_ip_address,
+            port=int(redis_port),
+            username=self._redis_username,
+            password=self._redis_password,
+            socket_timeout=self._config.connect_timeout_s,
+            socket_connect_timeout=self._config.connect_timeout_s,
+        )
+        logger.info(
+            "Successfully started the redis leader selector. name:%s", self.name
+        )
+        self.do_election()
+
+    def stop(self):
+        self.set_role_type(ray_constants.HEAD_ROLE_STANDBY)
+        self._is_running = False
+        if self._timer is not None:
+            self._timer.cancel()
+            self._timer = None
+        if self._redis_client is not None:
+            self._redis_client.close()
+            self._redis_client = None
+
+    def do_election(self):
+        if self._is_running:
+            self.check_leader()
+            if self._is_running:
+                self._timer = threading.Timer(
+                    self._config.check_interval_s, self.do_election
+                )
+                self._timer.start()
+        else:
+            logger.warning("Timer still running while leader selector stopped.")
+
+    def check_leader(self):
+        try:
+            isSetSuccess = self._redis_client.set(
+                ray_constants.HEAD_NODE_LEADER_ELECTION_KEY,
+                self.name,
+                px=self._config.key_expire_time_ms,
+                nx=True,
+            )
+            if isSetSuccess:
+                self.set_role_type(ray_constants.HEAD_ROLE_ACTIVE)
+                self._real_failure_count = 0
+                logger.info(
+                    "This head node acquired the leader lock and "
+                    "became the active head node, name:%s.",
+                    self.name,
+                )
+            else:
+                self.expire_leader_key()
+        except Exception:
+            logger.error("Node %s failed to check the leader.", self.name, exc_info=True)
+            self._real_failure_count += 1
+            if self._real_failure_count > self._config.max_failure_count:
+                head_num = get_cluster_head_num(self._config.connect_timeout_s)
+                if self.is_leader() and head_num == 1:
+                    if self._real_failure_count % self._config.max_failure_count == 1:
+                        logger.warning(
+                            "Because the current node is the leader and "
+                            "the cluster has only one head node, "
+                            "the role will not change to standby."
+                        )
+                else:
+                    logger.error(
+                        "This leader selector will stop because it failed "
+                        "%d times, more than the limit of %d. "
+                        "The role will change from %s to standby, "
+                        "head num: %d.",
+                        self._real_failure_count,
+                        self._config.max_failure_count,
+                        self.get_role_type(),
+                        head_num,
+                    )
+                    self.stop()
+
+    def expire_leader_key(self):
+        leader_name = self._redis_client.get(
+            ray_constants.HEAD_NODE_LEADER_ELECTION_KEY
+        )
+        if leader_name == self.name:
+            isSuccess = self._redis_client.pexpire(
+                ray_constants.HEAD_NODE_LEADER_ELECTION_KEY,
+                self._config.key_expire_time_ms,
+            )
+            if not isSuccess:
+                self._real_failure_count += 1
+                logger.error(
+                    "This active head node failed to renew the leader "
+                    "key, failure count:%d.",
+                    self._real_failure_count,
+                )
+                if self._real_failure_count > self._config.max_failure_count:
+                    logger.error(
+                        "This active head node is downgraded to standby "
+                        "because renewing the leader key failed "
+                        "%d times, more than the limit of %d.",
+                        self._real_failure_count,
+                        self._config.max_failure_count,
+                    )
+                    self.stop()
+                    self._real_failure_count = 0
+            self.set_role_type(ray_constants.HEAD_ROLE_ACTIVE)
+        else:
+            if self.is_leader():
+                logger.error(
+                    "This active head node is downgraded to standby "
+                    "because the leader key changed, active:%s.",
+                    str(leader_name),
+                )
+                self.stop()
+            else:
+                if self._leader_name != str(leader_name):
+                    self._leader_name = str(leader_name)
+                    logger.info(
+                        "This head node is standby now. "
+                        "The active node is %s.",
+                        str(leader_name),
+                    )
+                else:
+                    logger.debug(
+                        "This head node is standby now. "
+                        "The active node is %s.",
+                        str(leader_name),
+                    )
+                self.set_role_type(ray_constants.HEAD_ROLE_STANDBY)
+
+    def do_action_after_be_active(self):
+        expired_gcs_address = self._redis_client.get(ray_constants.GCS_ADDRESS_KEY)
+        if expired_gcs_address:
+            pre_gcs_address = str(expired_gcs_address, encoding="utf-8")
+            self._redis_client.set(
+                ray_constants.GCS_ADDRESS_KEY, self._init_gcs_address
+            )
+            logger.info(
+                "Reset the gcs address from %s to this node's (not yet "
+                "serving) gcs address:%s.",
+                pre_gcs_address,
+                self._init_gcs_address,
+            )
+            time.sleep(self._config.wait_time_after_be_active_s)
+            max_wait_time = self._config.wait_pre_gcs_stop_max_time_s
+            is_disconnect, use_time = waiting_for_server_stopped(
+                pre_gcs_address, max_wait_time
+            )
+            if is_disconnect:
+                logger.info(
+                    "After waiting for %f s, the previous gcs(%s) has "
+                    "stopped, proceeding with the startup process.",
+                    use_time,
+                    pre_gcs_address,
+                )
+                return True
+            else:
+                logger.error(
+                    "After waiting for %d s, the previous gcs(%s) is "
+                    "still serving; this may cause problems.",
+                    max_wait_time,
+                    pre_gcs_address,
+                )
+                return False
+        return True
+
+    def node_wait_to_be_active(self):
+        logger.info("This head node is waiting to be active...")
+        while True:
+            if not self.is_running():
+                logger.error(
+                    "The leader selector stopped while waiting; restarting it."
+                )
+                self.start()
+                continue
+
+            if self.is_leader():
+                logger.info(
+                    "This head node changed from standby to active, "
+                    "starting the startup process..."
+                )
+                self.do_action_after_be_active()
+                break
+            logger.debug(
+                "This head node role is %s, waiting to become the active node.",
+                self.get_role_type(),
+            )
+            time.sleep(0.2)
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index ac7519c1f3e8d..500e09ed40501 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -88,6 +88,34 @@ def _check_ray_version(gcs_client):
     )
 
 
+def handle_process_fo(node, process_name):
+    logger.info("Performing failover for process {}.".format(process_name))
+    if process_name == ray.ray_constants.PROCESS_TYPE_GCS_SERVER:
+        node.kill_gcs_server(check_alive=False)
+        node.start_gcs_server()
+    elif process_name == ray.ray_constants.PROCESS_TYPE_DASHBOARD:
+        node.kill_dashboard(check_alive=False)
+        node.start_dashboard(require_dashboard=True)
+    elif process_name == ray.ray_constants.PROCESS_TYPE_MONITOR:
+        node.kill_monitor(check_alive=False)
+        node.start_monitor()
+    else:
+        logger.error("No failover policy defined for {}.".format(process_name))
+
+
+def check_ray_processes(node):
+    for name, infos in node.all_processes.items():
+        for info in infos:
+            # Find processes that have exited.
+            if info.process.poll() is not None:
+                logger.error(
+                    "Process {} (pid={}) is dead, trying to restart it.".format(
+                        name, info.process.pid
+                    )
+                )
+                handle_process_fo(node, name)
+
+
 @click.group()
 @click.option(
     "--logging-level",
@@ -811,6 +839,8 @@ def start(
         include_log_monitor=include_log_monitor,
     )
 
+    clean_processes_at_exit = block or (ray_params.enable_head_ha and head)
+
     if ray_constants.RAY_START_HOOK in os.environ:
         _load_class(os.environ[ray_constants.RAY_START_HOOK])(ray_params, head)
 
@@ -908,7 +938,10 @@ def start(
         )
 
         node = ray._private.node.Node(
-            ray_params, head=True, shutdown_at_exit=block, spawn_reaper=block
+            ray_params,
+            head=True,
+            shutdown_at_exit=clean_processes_at_exit,
+            spawn_reaper=clean_processes_at_exit,
         )
 
         bootstrap_address = node.address
@@ -1089,9 +1122,11 @@ def start(
         assert ray_params.gcs_address is not None
         ray._private.utils.write_ray_address(ray_params.gcs_address, temp_dir)
 
-    if block:
+    if clean_processes_at_exit:
         cli_logger.newline()
-        with cli_logger.group(cf.bold("--block")):
+        msg = "--block" if block else ""
+        msg = msg + ("\nenable_head_ha" if head and ray_params.enable_head_ha else "")
+        with cli_logger.group(cf.bold(msg)):
             cli_logger.print(
                 "This command will now block forever until terminated by a signal."
             )
@@ -1104,44 +1139,55 @@ def start(
         while True:
             time.sleep(1)
-            deceased = node.dead_processes()
-
-            # Report unexpected exits of subprocesses with unexpected return codes.
-            # We are explicitly expecting SIGTERM because this is how `ray stop` sends
-            # shutdown signal to subprocesses, i.e. log_monitor, raylet...
-            # NOTE(rickyyx): We are treating 128+15 as an expected return code since
-            # this is what autoscaler/_private/monitor.py does upon SIGTERM
-            # handling.
- expected_return_codes = [ - 0, - signal.SIGTERM, - -1 * signal.SIGTERM, - 128 + signal.SIGTERM, - ] - unexpected_deceased = [ - (process_type, process) - for process_type, process in deceased - if process.returncode not in expected_return_codes - ] - if len(unexpected_deceased) > 0: - cli_logger.newline() - cli_logger.error("Some Ray subprocesses exited unexpectedly:") - - with cli_logger.indented(): - for process_type, process in unexpected_deceased: - cli_logger.error( - "{}", - cf.bold(str(process_type)), - _tags={"exit code": str(process.returncode)}, - ) - cli_logger.newline() - cli_logger.error("Remaining processes will be killed.") - # explicitly kill all processes since atexit handlers - # will not exit with errors. - node.kill_all_processes(check_alive=False, allow_graceful=False) - os._exit(1) - # not-reachable + # Head HA + if head and ray_params.enable_head_ha: + if node.check_leadership_downgrade(): + raise RuntimeError("leadership downgrade") + else: + check_ray_processes(node) + + if block: + deceased = node.dead_processes() + + # Report unexpected exits of subprocesses with unexpected return codes. + # We are explicitly expecting SIGTERM because this is how `ray stop` sends + # shutdown signal to subprocesses, i.e. log_monitor, raylet... + # NOTE(rickyyx): We are treating 128+15 as an expected return code since + # this is what autoscaler/_private/monitor.py does upon SIGTERM + # handling. + expected_return_codes = [ + 0, + signal.SIGTERM, + -1 * signal.SIGTERM, + 128 + signal.SIGTERM, + ] + unexpected_deceased = [ + (process_type, process) + for process_type, process in deceased + if process.returncode not in expected_return_codes + ] + if len(unexpected_deceased) > 0: + cli_logger.newline() + cli_logger.error("Some Ray subprocesses exited unexpectedly:") + + with cli_logger.indented(): + for process_type, process in unexpected_deceased: + cli_logger.error( + "{}", + cf.bold(str(process_type)), + _tags={"exit code": str(process.returncode)}, + ) + + cli_logger.newline() + cli_logger.error("Remaining processes will be killed.") + # explicitly kill all processes since atexit handlers + # will not exit with errors. 
+ node.kill_all_processes(check_alive=False, allow_graceful=False) + os._exit(1) + + +# not-reachable @cli.command() diff --git a/python/ray/tests/test_head_node_ha/test_ha_leader_selector_config.py b/python/ray/tests/test_head_node_ha/test_ha_leader_selector_config.py new file mode 100644 index 0000000000000..f30a3a3a29a69 --- /dev/null +++ b/python/ray/tests/test_head_node_ha/test_ha_leader_selector_config.py @@ -0,0 +1,87 @@ +import os +import pytest +import ray._private.parameter as parameter +from ray.ha.redis_leader_selector import RedisBasedLeaderSelector + + +def test_leader_selector_init(): + ray_params = parameter.RayParams() + leader_selector = RedisBasedLeaderSelector(ray_params, "0.0.0.0:0", "0.0.0.0") + assert not ray_params.enable_head_ha + assert str(leader_selector.name, "utf-8").startswith("0.0.0.0:0") + assert leader_selector._config.check_interval_s == 3 + assert leader_selector._config.key_expire_time_ms == 60000 + assert leader_selector._config.max_failure_count == 5 + assert leader_selector._config.wait_pre_gcs_stop_max_time_s == 3600 * 24 * 3 + assert leader_selector._config.wait_time_after_be_active_s == 2 + + +data_dict = [ + { + "is_enable": ["true", True], + "max_failure": ["4", 4], + "check_inverval": ["600", 0.6], + "expire_time": ["5000", 5000], + "wait_time": ["3000", 3.0], + "wait_gcs_max_time": ["20000", 20], + }, + { + "is_enable": ["false", False], + "max_failure": ["6", 6], + "check_inverval": ["550", 0.55], + "expire_time": ["4111", 4111], + "wait_time": ["3111", 3.111], + "wait_gcs_max_time": ["22222", 22.222], + }, + { + "is_enable": ["True", True], + "max_failure": ["10", 10], + "check_inverval": ["10000", 10.0], + "expire_time": ["30000", 30000], + "wait_time": ["10000", 10.0], + "wait_gcs_max_time": ["100000", 100.0], + }, +] + + +@pytest.fixture(params=data_dict) +def check_ha_env_config(request): + os.environ["RAY_ENABLE_HEAD_HA"] = request.param["is_enable"][0] + os.environ["RAY_HA_CHECK_MAX_FAILURE_COUNT"] = request.param["max_failure"][0] + os.environ["RAY_HA_CHECK_INTERVAL_MS"] = request.param["check_inverval"][0] + os.environ["RAY_HA_KEY_EXPIRE_TIME_MS"] = request.param["expire_time"][0] + os.environ["RAY_HA_WAIT_TIME_AFTER_BE_ACTIVE_MS"] = request.param["wait_time"][0] + os.environ["RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS"] = request.param[ + "wait_gcs_max_time" + ][0] + yield request.param + os.environ.pop("RAY_ENABLE_HEAD_HA") + os.environ.pop("RAY_HA_CHECK_MAX_FAILURE_COUNT") + os.environ.pop("RAY_HA_CHECK_INTERVAL_MS") + os.environ.pop("RAY_HA_KEY_EXPIRE_TIME_MS") + os.environ.pop("RAY_HA_WAIT_TIME_AFTER_BE_ACTIVE_MS") + os.environ.pop("RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS") + + +def test_leader_selector_init_param(check_ha_env_config): + result = check_ha_env_config + ray_params = parameter.RayParams() + assert ray_params.enable_head_ha == result["is_enable"][1] + + leader_selector = RedisBasedLeaderSelector(ray_params, "0.0.0.0:0", "0.0.0.0") + assert str(leader_selector.name, "utf-8").startswith("0.0.0.0:0") + assert leader_selector._config.check_interval_s == result["check_inverval"][1] + assert leader_selector._config.key_expire_time_ms == result["expire_time"][1] + assert leader_selector._config.max_failure_count == result["max_failure"][1] + assert ( + leader_selector._config.wait_pre_gcs_stop_max_time_s + == result["wait_gcs_max_time"][1] + ) + assert leader_selector._config.wait_time_after_be_active_s == result["wait_time"][1] + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main(["-v", __file__])) 
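The config test above only exercises the environment-variable parsing in HeadNodeLeaderSelectorConfig. A minimal standalone sketch of that mapping (assuming the ray.ha module from this patch is importable; the values below are illustrative, not defaults):

# Sketch: how the RAY_HA_* environment variables (milliseconds) map onto
# HeadNodeLeaderSelectorConfig fields (mostly seconds). Values are illustrative.
import os

from ray.ha import HeadNodeLeaderSelectorConfig

os.environ["RAY_HA_CHECK_MAX_FAILURE_COUNT"] = "4"
os.environ["RAY_HA_CHECK_INTERVAL_MS"] = "600"
os.environ["RAY_HA_KEY_EXPIRE_TIME_MS"] = "5000"
os.environ["RAY_HA_WAIT_TIME_AFTER_BE_ACTIVE_MS"] = "3000"

config = HeadNodeLeaderSelectorConfig()
assert config.max_failure_count == 4
assert config.check_interval_s == 0.6          # 600 ms -> 0.6 s
assert config.key_expire_time_ms == 5000       # stays in milliseconds
assert config.wait_time_after_be_active_s == 3.0
print(config)  # __repr__ lists every field
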
diff --git a/python/ray/tests/test_head_node_ha/test_ha_redis_leader_selector.py b/python/ray/tests/test_head_node_ha/test_ha_redis_leader_selector.py new file mode 100644 index 0000000000000..5feac462d8102 --- /dev/null +++ b/python/ray/tests/test_head_node_ha/test_ha_redis_leader_selector.py @@ -0,0 +1,145 @@ +import time +import os +import ray._private.parameter as parameter +from ray.cluster_utils import Cluster +from ray.ha import RedisBasedLeaderSelector +from ray.ha.redis_leader_selector import get_cluster_head_num +from ray._private.test_utils import wait_for_condition +from unittest.mock import patch +import requests + +become_active_wait_time = 65 + + +def test_leader_selector_start(ray_start_cluster_head_with_external_redis): + cluster: Cluster = ray_start_cluster_head_with_external_redis + ray_params = parameter.RayParams() + ray_params.redis_password = cluster.redis_password + leader_selector = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "0.0.0.0" + ) + leader_selector.start() + time.sleep(1) + leader_selector_2 = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "1.1.1.1" + ) + leader_selector_2.start() + wait_for_condition(lambda: leader_selector.is_leader(), 5) + wait_for_condition(lambda: not leader_selector_2.is_leader(), 5) + leader_selector.stop() + assert not leader_selector.is_leader() + + # test active/standby switch + wait_for_condition(lambda: leader_selector_2.is_leader(), become_active_wait_time) + leader_selector.start() + wait_for_condition(lambda: not leader_selector.is_leader(), 5) + leader_selector_2.stop() + wait_for_condition(lambda: leader_selector.is_leader(), become_active_wait_time) + leader_selector.stop() + # test stop repeat + leader_selector.stop() + + +def test_leader_selector_keep_leader(ray_start_cluster_head_with_external_redis): + cluster: Cluster = ray_start_cluster_head_with_external_redis + ray_params = parameter.RayParams() + ray_params.redis_password = cluster.redis_password + leader_selector = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "0.0.0.0" + ) + leader_selector.start() + time.sleep(1) + leader_selector_2 = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "1.1.1.1" + ) + leader_selector_2.start() + leader_selector_3 = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "2.2.2.2" + ) + leader_selector_3.start() + leader_selector_4 = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "2.2.2.2" + ) + leader_selector_4.start() + wait_for_condition(lambda: leader_selector.is_leader(), 5) + wait_for_condition(lambda: not leader_selector_2.is_leader(), 5) + wait_for_condition(lambda: not leader_selector_3.is_leader(), 5) + wait_for_condition(lambda: not leader_selector_4.is_leader(), 5) + time.sleep(10) + assert leader_selector.is_leader() + assert not leader_selector_2.is_leader() + assert not leader_selector_3.is_leader() + assert not leader_selector_4.is_leader() + leader_selector.stop() + leader_selector_2.stop() + leader_selector_3.stop() + leader_selector_4.stop() + + +def test_selector_while_disconnect_redis(external_redis): + redis_cluster = external_redis + ray_params = parameter.RayParams() + ray_params.redis_password = redis_cluster.redis_password + leader_selector = RedisBasedLeaderSelector( + ray_params, redis_cluster.redis_address, "0.0.0.0" + ) + leader_selector.start() + leader_selector_2 = RedisBasedLeaderSelector( + ray_params, redis_cluster.redis_address, "1.1.1.1" + ) + leader_selector_2.start() + wait_for_condition(lambda: 
leader_selector.is_leader(), 5) + wait_for_condition(lambda: not leader_selector_2.is_leader(), 5) + redis_cluster.shutdown() + wait_for_condition(lambda: not leader_selector.is_leader(), become_active_wait_time) + wait_for_condition( + lambda: not leader_selector_2.is_leader(), become_active_wait_time + ) + + +def test_special_action(ray_start_cluster_head_with_external_redis): + os.environ["RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS"] = "10000" + cluster: Cluster = ray_start_cluster_head_with_external_redis + ray_params = parameter.RayParams() + ray_params.redis_password = cluster.redis_password + leader_selector = RedisBasedLeaderSelector( + ray_params, cluster.redis_address, "0.0.0.0" + ) + leader_selector.start() + start_time = time.time() + assert not leader_selector.do_action_after_be_active() + assert (time.time() - start_time) > 9 + os.environ.pop("RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS") + + +def test_get_cluster_head_num(): + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = b'{"success":true,"message":"try to get cluster pod information: ray-test-zhiyu in namespace: arconkube","data":{"head":[{"default":{"replicasUpdated":1,"replicasTotal":1}}]}}\n' # noqa: E501 + + with patch("requests.post", return_value=mock_response): + assert get_cluster_head_num(5) == 1 + + mock_response.status_code = 200 + mock_response._content = b'{"success":true,"message":"try to get cluster pod information: ray-test-zhiyu in namespace: arconkube","data":{"head":[{"default":{"replicasUpdated":2,"replicasTotal":2}}]}}\n' # noqa: E501 + with patch("requests.post", return_value=mock_response): + assert get_cluster_head_num(5) == 2 + + mock_response.status_code = 200 + mock_response._content = ( + b'{"success":false,"message":"invalid json :","data":null}\n' + ) + with patch("requests.post", return_value=mock_response): + assert get_cluster_head_num(5) == -1 + + mock_response.status_code = 404 + mock_response._content = b"OK" + with patch("requests.post", return_value=mock_response): + assert get_cluster_head_num(5) == -1 + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_head_node_ha/test_ha_utils.py b/python/ray/tests/test_head_node_ha/test_ha_utils.py new file mode 100644 index 0000000000000..57ae479cb6e73 --- /dev/null +++ b/python/ray/tests/test_head_node_ha/test_ha_utils.py @@ -0,0 +1,35 @@ +import time +import ray +import ray._private.ray_constants as ray_constants +from ray.ha import is_service_available, waiting_for_server_stopped + + +def test_until_disconnect_with_address(ray_start_cluster_head_with_external_redis): + cluster = ray_start_cluster_head_with_external_redis + redis_client = ray._private.services.create_redis_client( + cluster.redis_address, cluster.redis_password + ) + gcs_address = str(redis_client.get(ray_constants.GCS_ADDRESS_KEY), encoding="utf-8") + ip, port = gcs_address.split(":") + assert is_service_available(ip, port) + max_time = 3 + is_disconnect, user_time = waiting_for_server_stopped(gcs_address, max_time) + assert not is_disconnect + assert abs(user_time - max_time) < 0.5 + cluster.head_node.kill_gcs_server() + time.sleep(2) + is_disconnect, user_time = waiting_for_server_stopped(gcs_address, max_time) + assert is_disconnect + assert user_time < 1 + assert not is_service_available(ip, port) + + +def test_is_service_available_exception(): + assert not is_service_available("aaaaa", 0) + + +if __name__ == "__main__": + import sys + import pytest 
+ + sys.exit(pytest.main(["-v", __file__]))
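
For context, a minimal usage sketch of the election flow that these tests exercise. It assumes a reachable Redis instance; the Redis address, node IPs, and password below are placeholders, not values from this patch.

import time

import ray._private.parameter as parameter
from ray.ha import RedisBasedLeaderSelector

ray_params = parameter.RayParams()
ray_params.redis_password = None  # set this if the Redis instance requires auth

# Two would-be head nodes pointing at the same Redis (placeholder addresses).
first = RedisBasedLeaderSelector(ray_params, "127.0.0.1:6379", "10.0.0.1")
second = RedisBasedLeaderSelector(ray_params, "127.0.0.1:6379", "10.0.0.2")

first.start()   # acquires HEAD_NODE_LEADER_ELECTION_KEY via SET NX PX
time.sleep(1)
second.start()  # key already held, so this one stays standby and keeps polling

assert first.is_leader()
assert not second.is_leader()

# Stopping the active selector cancels its renewal timer; the standby takes
# over once the leader key expires and a later check_leader() tick acquires it.
first.stop()
second.stop()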