Skip to content

Commit

Permalink
Check version info in ray start for non-head nodes. (#1264)
Browse files Browse the repository at this point in the history
* Check version info in ray start for non-head nodes.

* Small fix.

* Fix

* Push error to all drivers when worker has version mismatch.

* Linting

* Linting

* Fix

* Unify methods.

* Fix bug.
  • Loading branch information
robertnishihara authored and pcmoritz committed Nov 28, 2017
1 parent 2c0d554 commit c1496b8
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 88 deletions.
9 changes: 5 additions & 4 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import ray.signature as signature
import ray.worker
from ray.utils import (binary_to_hex, FunctionProperties, random_string,
release_gpus_in_use, select_local_scheduler, is_cython)
release_gpus_in_use, select_local_scheduler, is_cython,
push_error_to_driver)


def random_actor_id():
Expand Down Expand Up @@ -252,9 +253,9 @@ def temporary_actor_method(*xs):
# traceback and notify the scheduler of the failure.
traceback_str = ray.worker.format_error_message(traceback.format_exc())
# Log the error message.
worker.push_error_to_driver(driver_id, "register_actor_signatures",
traceback_str,
data={"actor_id": actor_id_str})
push_error_to_driver(worker.redis_client, "register_actor_signatures",
traceback_str, driver_id,
data={"actor_id": actor_id_str})
# TODO(rkn): In the future, it might make sense to have the worker exit
# here. However, currently that would lead to hanging if someone calls
# ray.get on a method invoked on the actor.
Expand Down
17 changes: 11 additions & 6 deletions python/ray/scripts/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
from __future__ import print_function

import click
import redis
import subprocess

import ray.services as services


def check_no_existing_redis_clients(node_ip_address, redis_address):
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
def check_no_existing_redis_clients(node_ip_address, redis_client):
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
Expand Down Expand Up @@ -158,17 +154,26 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
raise Exception("If --head is not passed in, the --no-ui flag is "
"not relevant.")
redis_ip_address, redis_port = redis_address.split(":")

# Wait for the Redis server to be started. And throw an exception if we
# can't connect to it.
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))

# Create a Redis client.
redis_client = services.create_redis_client(redis_address)

# Check that the version information on this node matches the version
# information that the cluster was started with.
services.check_version_info(redis_client)

# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
print("Using IP address {} for this node.".format(node_ip_address))
# Check that there aren't already Redis clients with the same IP
# address connected with this Redis instance. This raises an exception
# if the Redis server already has clients on this node.
check_no_existing_redis_clients(node_ip_address, redis_address)
check_no_existing_redis_clients(node_ip_address, redis_client)
address_info = services.start_ray_node(
node_ip_address=node_ip_address,
redis_address=redis_address,
Expand Down
15 changes: 15 additions & 0 deletions python/ray/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,21 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
redis_client.rpush(log_file_list_key, log_file.name)


def create_redis_client(redis_address):
    """Build a Redis client for the server at the given address.

    Args:
        redis_address: The IP address and port of the Redis server, as a
            single string of the form "<ip address>:<port>".

    Returns:
        A redis.StrictRedis client configured with that host and port.
    """
    host, port = redis_address.split(":")
    # For this command to work, some other client (on the same machine
    # as Redis) must have run "CONFIG SET protected-mode no".
    client = redis.StrictRedis(host=host, port=int(port))
    return client


def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"""Wait for a Redis server to be available.
Expand Down
31 changes: 31 additions & 0 deletions python/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,37 @@

import ray.local_scheduler

# Prefix of the Redis keys under which error messages are stored.
ERROR_KEY_PREFIX = b"Error:"
# Length in bytes of a driver ID. A driver ID made of this many zero bytes
# is used to address an error message to all drivers.
DRIVER_ID_LENGTH = 20


def _random_string():
return np.random.bytes(20)


def push_error_to_driver(redis_client, error_type, message, driver_id=None,
                         data=None):
    """Record an error in Redis so that it can be reported to a driver.

    The error is written as a Redis hash under a freshly generated key, and
    that key is then appended to the "ErrorKeys" list.

    Args:
        redis_client: The redis client to use.
        error_type (str): The type of the error.
        message (str): The message that will be printed in the background
            on the driver.
        driver_id: The ID of the driver to push the error message to. If
            this is None, an all-zero ID is used so that the message is
            pushed to all drivers.
        data: A dictionary with extra information about the error,
            expected to map strings to strings. Defaults to an empty
            dictionary. NOTE(review): it is handed to hmset as is; this
            function does not serialize it with json itself.
    """
    if driver_id is None:
        # An ID consisting only of zero bytes addresses every driver.
        target_driver_id = DRIVER_ID_LENGTH * b"\x00"
    else:
        target_driver_id = driver_id
    key_prefix = ERROR_KEY_PREFIX + target_driver_id + b":"
    # The random suffix gives every pushed error its own key.
    error_key = key_prefix + _random_string()
    extra_data = data if data is not None else {}
    error_fields = {"type": error_type,
                    "message": message,
                    "data": extra_data}
    redis_client.hmset(error_key, error_fields)
    redis_client.rpush("ErrorKeys", error_key)


def is_cython(obj):
"""Check if an object is a Cython function or method"""
Expand Down
75 changes: 38 additions & 37 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,10 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
"object store. This may be fine, or it "
"may be a bug.")
if not warning_sent:
self.push_error_to_driver(self.task_driver_id.id(),
"wait_for_class",
warning_message)
ray.utils.push_error_to_driver(
self.redis_client, "wait_for_class",
warning_message,
driver_id=self.task_driver_id.id())
warning_sent = True

def get_object(self, object_ids):
Expand Down Expand Up @@ -599,24 +600,6 @@ def run_function_on_all_workers(self, function):
# operations into a transaction (or by implementing a custom
# command that does all three things).

def push_error_to_driver(self, driver_id, error_type, message, data=None):
    """Record an error in Redis so that it can be reported to a driver.

    The error is written as a Redis hash under a freshly generated key, and
    that key is then appended to the "ErrorKeys" list.

    Args:
        driver_id: The ID of the driver to push the error message to.
        error_type (str): The type of the error.
        message (str): The message that will be printed in the background
            on the driver.
        data: A dictionary with extra information about the error,
            expected to map strings to strings. Defaults to an empty
            dictionary. NOTE(review): it is handed to hmset as is; this
            method does not serialize it with json itself.
    """
    key_prefix = ERROR_KEY_PREFIX + driver_id + b":"
    # The random suffix gives every pushed error its own key.
    error_key = key_prefix + random_string()
    extra_data = data if data is not None else {}
    error_fields = {"type": error_type,
                    "message": message,
                    "data": extra_data}
    self.redis_client.hmset(error_key, error_fields)
    self.redis_client.rpush("ErrorKeys", error_key)

def _wait_for_function(self, function_id, driver_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
Expand Down Expand Up @@ -651,9 +634,10 @@ def _wait_for_function(self, function_id, driver_id, timeout=10):
"registered. You may have to restart "
"Ray.")
if not warning_sent:
self.push_error_to_driver(driver_id,
"wait_for_function",
warning_message)
ray.utils.push_error_to_driver(self.redis_client,
"wait_for_function",
warning_message,
driver_id=driver_id)
warning_sent = True
time.sleep(0.001)

Expand Down Expand Up @@ -808,10 +792,12 @@ def _handle_process_task_failure(self, function_id, return_object_ids,
range(len(return_object_ids))]
self._store_outputs_in_objstore(return_object_ids, failure_objects)
# Log the error message.
self.push_error_to_driver(self.task_driver_id.id(), "task",
str(failure_object),
data={"function_id": function_id.id(),
"function_name": function_name})
ray.utils.push_error_to_driver(self.redis_client,
"task",
str(failure_object),
driver_id=self.task_driver_id.id(),
data={"function_id": function_id.id(),
"function_name": function_name})

def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
Expand Down Expand Up @@ -1552,10 +1538,12 @@ def f():
# record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
# Log the error message.
worker.push_error_to_driver(driver_id, "register_remote_function",
traceback_str,
data={"function_id": function_id.id(),
"function_name": function_name})
ray.utils.push_error_to_driver(worker.redis_client,
"register_remote_function",
traceback_str,
driver_id=driver_id,
data={"function_id": function_id.id(),
"function_name": function_name})
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
Expand All @@ -1582,8 +1570,11 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
# Log the error message.
name = function.__name__ if ("function" in locals() and
hasattr(function, "__name__")) else ""
worker.push_error_to_driver(driver_id, "function_to_run",
traceback_str, data={"name": name})
ray.utils.push_error_to_driver(worker.redis_client,
"function_to_run",
traceback_str,
driver_id=driver_id,
data={"name": name})


def import_thread(worker, mode):
Expand Down Expand Up @@ -1714,9 +1705,19 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
worker.redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))

# Check that the version information matches the version information that
# the Ray cluster was started with.
ray.services.check_version_info(worker.redis_client)
# For drivers, check that the version information matches the version
# information that the Ray cluster was started with.
try:
ray.services.check_version_info(worker.redis_client)
except Exception as e:
if mode in [SCRIPT_MODE, SILENT_MODE]:
raise e
elif mode == WORKER_MODE:
traceback_str = traceback.format_exc()
ray.utils.push_error_to_driver(worker.redis_client,
"version_mismatch",
traceback_str,
driver_id=None)

worker.lock = threading.Lock()

Expand Down
44 changes: 3 additions & 41 deletions python/ray/workers/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import argparse
import binascii
import numpy as np
import redis
import traceback

import ray
Expand All @@ -30,36 +28,6 @@
"mode"))


def random_string():
    """Return 20 random bytes generated with np.random.bytes."""
    num_bytes = 20
    return np.random.bytes(num_bytes)


def create_redis_client(redis_address):
    """Build a Redis client for the server at the given address.

    Args:
        redis_address: The IP address and port of the Redis server, as a
            single string of the form "<ip address>:<port>".

    Returns:
        A redis.StrictRedis client configured with that host and port.
    """
    host, port = redis_address.split(":")
    # For this command to work, some other client (on the same machine
    # as Redis) must have run "CONFIG SET protected-mode no".
    client = redis.StrictRedis(host=host, port=int(port))
    return client


def push_error_to_all_drivers(redis_client, message, error_type):
    """Record an error in Redis that is addressed to every driver.

    The error is written as a Redis hash under a freshly generated key, and
    that key is then appended to the "ErrorKeys" list.

    Args:
        redis_client: The redis client to use.
        message: The error message to push.
        error_type: The type of the error.
    """
    DRIVER_ID_LENGTH = 20
    # A driver ID of all zeros addresses the error message to all drivers.
    all_drivers_id = DRIVER_ID_LENGTH * b"\x00"
    key_prefix = b"Error:" + all_drivers_id + b":"
    # The random suffix gives every pushed error its own key.
    error_key = key_prefix + random_string()
    # Store the error itself, then publish its key in the shared list.
    error_fields = {"type": error_type,
                    "message": message}
    redis_client.hmset(error_key, error_fields)
    redis_client.rpush("ErrorKeys", error_key)


if __name__ == "__main__":
args = parser.parse_args()

Expand All @@ -80,13 +48,6 @@ def push_error_to_all_drivers(redis_client, message, error_type):

ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)

try:
ray.services.check_version_info(ray.worker.global_worker.redis_client)
except Exception as e:
traceback_str = traceback.format_exc()
push_error_to_all_drivers(ray.worker.global_worker.redis_client,
traceback_str, "version_mismatch")

error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
crashed in an unanticipated way causing the main_loop to throw an exception,
Expand All @@ -103,8 +64,9 @@ def push_error_to_all_drivers(redis_client, message, error_type):
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = create_redis_client(args.redis_address)
push_error_to_all_drivers(redis_client, traceback_str, "worker_crash")
redis_client = ray.services.create_redis_client(args.redis_address)
ray.utils.push_error_to_driver(redis_client, "worker_crash",
traceback_str, driver_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to
Expand Down

0 comments on commit c1496b8

Please sign in to comment.