Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Head HA #499

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ray._private.resource_spec import ResourceSpec
from ray._private.services import serialize_config, get_address
from ray._private.utils import open_log, try_to_create_directory, try_to_symlink
from ray.ha import RedisBasedLeaderSelector

# Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray configures it by default automatically
Expand Down Expand Up @@ -1412,6 +1413,8 @@ def start_head_processes(self):
assert self._gcs_address is None
assert self._gcs_client is None

self.start_head_ha_mode()

self.start_gcs_server()
assert self.get_gcs_client() is not None
self._write_cluster_info_to_kv()
Expand Down Expand Up @@ -1860,3 +1863,27 @@ def _record_stats(self):
# so we truncate it to the first 50 characters
# to avoid any issues.
record_hardware_usage(cpu_model_name[:50])

def check_leadership_downgrade(self):
"""[ha feature] Check if the role is downgraded from the active."""
if self._ray_params.enable_head_ha and hasattr(self, "leader_selector"):
if (
self.leader_selector is not None
and not self.leader_selector.is_leader()
):
msg = (
"This head node will be killed "
"as it has changed from active to standby."
)
logger.error(msg)
return True
return False

def start_head_ha_mode(self):
if self._ray_params.enable_head_ha:
logger.info("The head high-availability mode is enabled.")
self.leader_selector = RedisBasedLeaderSelector(
self._ray_params, self._redis_address, self.node_ip_address
)
self.leader_selector.start()
self.leader_selector.node_wait_to_be_active()
3 changes: 3 additions & 0 deletions python/ray/_private/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ def __init__(
self.cluster_id = cluster_id
self.node_id = node_id
self.enable_physical_mode = enable_physical_mode
self.enable_head_ha = (
os.environ.get("RAY_ENABLE_HEAD_HA", "false").lower() == "true"
)

# Set the internal config options for object reconstruction.
if enable_object_reconstruction:
Expand Down
14 changes: 14 additions & 0 deletions python/ray/_private/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,17 @@ def gcs_actor_scheduling_enabled():
)

RAY_EXPORT_EVENT_MAX_BACKUP_COUNT = env_bool("RAY_EXPORT_EVENT_MAX_BACKUP_COUNT", 20)

# head high-availability feature
STORAGE_NAMESPACE = (
"RAY" + os.environ.get("RAY_external_storage_namespace", "default") + "@"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use RAY_

)
HEAD_NODE_LEADER_ELECTION_KEY = STORAGE_NAMESPACE + "head_node_leader_election_key"
HEAD_ROLE_ACTIVE = "active_head"
HEAD_ROLE_STANDBY = "standby_head"
GCS_ADDRESS_KEY = STORAGE_NAMESPACE + "GcsServerAddress"

# Number of attempts to ping the Redis server. See
# `services.py::wait_for_redis_to_start()` and
# `services.py::create_redis_client()`
START_REDIS_WAIT_RETRIES = env_integer("RAY_START_REDIS_WAIT_RETRIES", 60)
76 changes: 75 additions & 1 deletion python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def create_redis_client(redis_address, password=None, username=None):
redis_ip_address, redis_port = extract_ip_port(
canonicalize_bootstrap_address_or_die(redis_address)
)
cli = redis.StrictRedis(
cli = redis.Redis(
host=redis_ip_address,
port=int(redis_port),
username=username,
Expand Down Expand Up @@ -2280,3 +2280,77 @@ def start_ray_client_server(
fate_share=fate_share,
)
return process_info


def wait_for_redis_to_start(redis_ip_address, redis_port, username=None, password=None):
"""Wait for a Redis server to be available.

This is accomplished by creating a Redis client and sending a random
command to the server until the command gets through.

Args:
redis_ip_address (str): The IP address of the redis server.
redis_port (int): The port of the redis server.
username (str): The username of the Redis server.
password (str): The password of the Redis server.

Raises:
Exception: An exception is raised if we could not connect with Redis.
"""
import redis

redis_client = create_redis_client(
"%s:%s" % (redis_ip_address, redis_port), password=password, username=username
)
# Wait for the Redis server to start.
num_retries = ray_constants.START_REDIS_WAIT_RETRIES
delay = 0.001
for i in range(num_retries):
try:
# Run some random command and see if it worked.
logger.debug(
"Waiting for redis server at {}:{} to respond...".format(
redis_ip_address, redis_port
)
)
redis_client.ping()
# If the Redis service is delayed getting set up for any reason, we may
# get a redis.ConnectionError: Error 111 connecting to host:port.
# Connection refused.
# Unfortunately, redis.ConnectionError is also the base class of
# redis.AuthenticationError. We *don't* want to obscure a
# redis.AuthenticationError, because that indicates the user provided a
# bad password. Thus a double except clause to ensure a
# redis.AuthenticationError isn't trapped here.
except redis.AuthenticationError as authEx:
raise RuntimeError(
f"Unable to connect to Redis at {redis_ip_address}:{redis_port}."
) from authEx
except redis.ConnectionError as connEx:
if i >= num_retries - 1:
raise RuntimeError(
f"Unable to connect to Redis at {redis_ip_address}:"
f"{redis_port} after {num_retries} retries. Check that "
f"{redis_ip_address}:{redis_port} is reachable from this "
"machine. If it is not, your firewall may be blocking "
"this port. If the problem is a flaky connection, try "
"setting the environment variable "
"`RAY_START_REDIS_WAIT_RETRIES` to increase the number of"
" attempts to ping the Redis server."
) from connEx
# Wait a little bit.
time.sleep(delay)
# Make sure the retry interval doesn't increase too large, which
# will affect the delivery time of the Ray cluster.
delay = min(1, delay * 2)
else:
break
else:
raise RuntimeError(
f"Unable to connect to Redis (after {num_retries} retries). "
"If the Redis instance is on a different machine, check that "
"your firewall and relevant Ray ports are configured properly. "
"You can also set the environment variable "
"`RAY_START_REDIS_WAIT_RETRIES` to increase the number of "
"attempts to ping the Redis server."
)
13 changes: 13 additions & 0 deletions python/ray/ha/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ray.ha.leader_selector import HeadNodeLeaderSelector
from ray.ha.leader_selector import HeadNodeLeaderSelectorConfig
from ray.ha.redis_leader_selector import RedisBasedLeaderSelector
from ray.ha.redis_leader_selector import is_service_available
from ray.ha.redis_leader_selector import waiting_for_server_stopped

__all__ = [
"RedisBasedLeaderSelector",
"HeadNodeLeaderSelector",
"HeadNodeLeaderSelectorConfig",
"is_service_available",
"waiting_for_server_stopped",
]
93 changes: 93 additions & 0 deletions python/ray/ha/leader_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
import os
import ray._private.ray_constants as ray_constants

logger = logging.getLogger(__name__)


class HeadNodeLeaderSelectorConfig:
# Will exit after N consecutive failures
max_failure_count = None
# The interval time for selecting the leader or expire leader
check_interval_s = None
# The expiration time of the leader key
key_expire_time_ms = None
# The time to wait after becoming the active allows the gcs client to
# have enough time to find out that the gcs address has changed.
wait_time_after_be_active_s = None
# Maximum time to wait pre gcs stop serving
wait_pre_gcs_stop_max_time_s = None
# Redis/socket connect time out(s)
connect_timeout_s = None

def __init__(self):
self.max_failure_count = int(
os.environ.get("RAY_HA_CHECK_MAX_FAILURE_COUNT", "5")
)
# Default 3s
self.check_interval_s = (
int(os.environ.get("RAY_HA_CHECK_INTERVAL_MS", "3000")) / 1000
)
# Default 60s
self.key_expire_time_ms = int(
os.environ.get("RAY_HA_KEY_EXPIRE_TIME_MS", "60000")
)
# Default 2s
self.wait_time_after_be_active_s = (
int(os.environ.get("RAY_HA_WAIT_TIME_AFTER_BE_ACTIVE_MS", "2000")) / 1000
)
# Default 3 day (1000 * 60 * 60 * 24 * 3).
self.wait_pre_gcs_stop_max_time_s = (
int(os.environ.get("RAY_HA_WAIT_PRE_GCS_STOP_MAX_TIME_MS", "259200000"))
/ 1000
)
# Default redis/socket connect time out is 5s.
self.connect_timeout_s = (
int(os.environ.get("RAY_HA_CONNECT_TIMEOUT_MS", "5000")) / 1000
)

def __repr__(self) -> str:
return (
f"HeadNodeLeaderSelectorConfig["
f"max_failure_count={self.max_failure_count},"
f"check_interval_s={self.check_interval_s},"
f"key_expire_time_ms={self.key_expire_time_ms},"
f"wait_time_after_be_active_s={self.wait_time_after_be_active_s},"
f"wait_pre_gcs_stop_max_time_s="
f"{self.wait_pre_gcs_stop_max_time_s},"
f"connect_timeout_s={self.connect_timeout_s}]"
)


class HeadNodeLeaderSelector:
_role_type = ray_constants.HEAD_ROLE_STANDBY
_config = None
_is_running = False

def __init__(self):
self._config = HeadNodeLeaderSelectorConfig()

def start(self):
pass

def stop(self):
pass

def set_role_type(self, role_type):
if role_type != self._role_type:
logger.info(
"This head node changed from %s to %s.", self._role_type, role_type
)
self._role_type = role_type

def get_role_type(self):
return self._role_type

def is_leader(self):
return self.get_role_type() == ray_constants.HEAD_ROLE_ACTIVE

def do_action_after_be_active(self):
pass

def is_running(self):
return self._is_running
Loading