Skip to content

Commit

Permalink
Merge branch 'master' of github.com:skypilot-org/skypilot into intera…
Browse files Browse the repository at this point in the history
…ctive-dev-docs
  • Loading branch information
Michaelvll committed May 9, 2024
2 parents 93ddebf + 7f30ce5 commit f65a2f6
Show file tree
Hide file tree
Showing 28 changed files with 674 additions and 419 deletions.
91 changes: 39 additions & 52 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def wrap_file_mount(cls, path: str) -> str:
def make_safe_symlink_command(cls, *, source: str, target: str) -> str:
"""Returns a command that safely symlinks 'source' to 'target'.
All intermediate directories of 'source' will be owned by $USER,
All intermediate directories of 'source' will be owned by $(whoami),
excluding the root directory (/).
'source' must be an absolute path; both 'source' and 'target' must not
Expand All @@ -360,17 +360,17 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str:
target)
# Below, use sudo in case the symlink needs sudo access to create.
# Prepare to create the symlink:
# 1. make sure its dir(s) exist & are owned by $USER.
# 1. make sure its dir(s) exist & are owned by $(whoami).
dir_of_symlink = os.path.dirname(source)
commands = [
# mkdir, then loop over '/a/b/c' as /a, /a/b, /a/b/c. For each,
# chown $USER on it so user can use these intermediate dirs
# chown $(whoami) on it so user can use these intermediate dirs
# (excluding /).
f'sudo mkdir -p {dir_of_symlink}',
# p: path so far
('(p=""; '
f'for w in $(echo {dir_of_symlink} | tr "/" " "); do '
'p=${p}/${w}; sudo chown $USER $p; done)')
'p=${p}/${w}; sudo chown $(whoami) $p; done)')
]
# 2. remove any existing symlink (ln -f may throw 'cannot
# overwrite directory', if the link exists and points to a
Expand All @@ -386,7 +386,7 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str:
# Link.
f'sudo ln -s {target} {source}',
# chown. -h to affect symlinks only.
f'sudo chown -h $USER {source}',
f'sudo chown -h $(whoami) {source}',
]
return ' && '.join(commands)

Expand Down Expand Up @@ -1080,7 +1080,7 @@ def get_ready_nodes_counts(pattern, output):
def get_docker_user(ip: str, cluster_config_file: str) -> str:
"""Find docker container username."""
ssh_credentials = ssh_credential_from_yaml(cluster_config_file)
runner = command_runner.SSHCommandRunner(ip, port=22, **ssh_credentials)
runner = command_runner.SSHCommandRunner(node=(ip, 22), **ssh_credentials)
container_name = constants.DEFAULT_DOCKER_CONTAINER_NAME
whoami_returncode, whoami_stdout, whoami_stderr = runner.run(
f'sudo docker exec {container_name} whoami',
Expand Down Expand Up @@ -1113,7 +1113,7 @@ def wait_until_ray_cluster_ready(
try:
head_ip = _query_head_ip_with_retries(
cluster_config_file, max_attempts=WAIT_HEAD_NODE_IP_MAX_ATTEMPTS)
except exceptions.FetchIPError as e:
except exceptions.FetchClusterInfoError as e:
logger.error(common_utils.format_exception(e))
return False, None # failed

Expand All @@ -1129,8 +1129,7 @@ def wait_until_ray_cluster_ready(
ssh_credentials = ssh_credential_from_yaml(cluster_config_file, docker_user)
last_nodes_so_far = 0
start = time.time()
runner = command_runner.SSHCommandRunner(head_ip,
port=22,
runner = command_runner.SSHCommandRunner(node=(head_ip, 22),
**ssh_credentials)
with rich_utils.safe_status(
'[bold cyan]Waiting for workers...') as worker_status:
Expand Down Expand Up @@ -1236,7 +1235,7 @@ def ssh_credential_from_yaml(


def parallel_data_transfer_to_nodes(
runners: List[command_runner.SSHCommandRunner],
runners: List[command_runner.CommandRunner],
source: Optional[str],
target: str,
cmd: Optional[str],
Expand All @@ -1246,32 +1245,36 @@ def parallel_data_transfer_to_nodes(
# Advanced options.
log_path: str = os.devnull,
stream_logs: bool = False,
source_bashrc: bool = False,
):
"""Runs a command on all nodes and optionally runs rsync from src->dst.
Args:
runners: A list of SSHCommandRunner objects that represent multiple nodes.
runners: A list of CommandRunner objects that represent multiple nodes.
source: Optional[str]; Source for rsync on local node
target: str; Destination on remote node for rsync
cmd: str; Command to be executed on all nodes
action_message: str; Message to be printed while the command runs
log_path: str; Path to the log file
stream_logs: bool; Whether to stream logs to stdout
source_bashrc: bool; Source bashrc before running the command.
"""
fore = colorama.Fore
style = colorama.Style

origin_source = source

def _sync_node(runner: 'command_runner.SSHCommandRunner') -> None:
def _sync_node(runner: 'command_runner.CommandRunner') -> None:
if cmd is not None:
rc, stdout, stderr = runner.run(cmd,
log_path=log_path,
stream_logs=stream_logs,
require_outputs=True)
require_outputs=True,
source_bashrc=source_bashrc)
err_msg = ('Failed to run command before rsync '
f'{origin_source} -> {target}. '
'Ensure that the network is stable, then retry.')
'Ensure that the network is stable, then retry. '
f'{cmd}')
if log_path != os.devnull:
err_msg += f' See logs in {log_path}'
subprocess_utils.handle_returncode(rc,
Expand Down Expand Up @@ -1336,7 +1339,7 @@ def _query_head_ip_with_retries(cluster_yaml: str,
"""Returns the IP of the head node by querying the cloud.
Raises:
exceptions.FetchIPError: if we failed to get the head IP.
exceptions.FetchClusterInfoError: if we failed to get the head IP.
"""
backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5)
for i in range(max_attempts):
Expand Down Expand Up @@ -1365,8 +1368,8 @@ def _query_head_ip_with_retries(cluster_yaml: str,
break
except subprocess.CalledProcessError as e:
if i == max_attempts - 1:
raise exceptions.FetchIPError(
reason=exceptions.FetchIPError.Reason.HEAD) from e
raise exceptions.FetchClusterInfoError(
reason=exceptions.FetchClusterInfoError.Reason.HEAD) from e
# Retry if the cluster is not up yet.
logger.debug('Retrying to get head ip.')
time.sleep(backoff.current_backoff())
Expand All @@ -1391,7 +1394,7 @@ def get_node_ips(cluster_yaml: str,
IPs.
Raises:
exceptions.FetchIPError: if we failed to get the IPs. e.reason is
exceptions.FetchClusterInfoError: if we failed to get the IPs. e.reason is
HEAD or WORKER.
"""
ray_config = common_utils.read_yaml(cluster_yaml)
Expand All @@ -1412,11 +1415,12 @@ def get_node_ips(cluster_yaml: str,
'Failed to get cluster info for '
f'{ray_config["cluster_name"]} from the new provisioner '
f'with {common_utils.format_exception(e)}.')
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.HEAD) from e
raise exceptions.FetchClusterInfoError(
exceptions.FetchClusterInfoError.Reason.HEAD) from e
if len(metadata.instances) < expected_num_nodes:
# Simulate the exception when Ray head node is not up.
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
raise exceptions.FetchClusterInfoError(
exceptions.FetchClusterInfoError.Reason.HEAD)
return metadata.get_feasible_ips(get_internal_ips)

if get_internal_ips:
Expand Down Expand Up @@ -1446,8 +1450,8 @@ def get_node_ips(cluster_yaml: str,
break
except subprocess.CalledProcessError as e:
if retry_cnt == worker_ip_max_attempts - 1:
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.WORKER) from e
raise exceptions.FetchClusterInfoError(
exceptions.FetchClusterInfoError.Reason.WORKER) from e
# Retry if the ssh is not ready for the workers yet.
backoff_time = backoff.current_backoff()
logger.debug('Retrying to get worker ip '
Expand All @@ -1472,8 +1476,8 @@ def get_node_ips(cluster_yaml: str,
f'detected IP(s): {worker_ips[-n:]}.')
worker_ips = worker_ips[-n:]
else:
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.WORKER)
raise exceptions.FetchClusterInfoError(
exceptions.FetchClusterInfoError.Reason.WORKER)
else:
worker_ips = []
return head_ip_list + worker_ips
Expand Down Expand Up @@ -1760,42 +1764,25 @@ def _update_cluster_status_no_lock(

def run_ray_status_to_check_ray_cluster_healthy() -> bool:
try:
# TODO(zhwu): This function cannot distinguish transient network
# error in ray's get IPs vs. ray runtime failing.

# NOTE: fetching the IPs is very slow as it calls into
# `ray get head-ip/worker-ips`. Using cached IPs is safe because
# in the worst case we time out in the `ray status` SSH command
# below.
external_ips = handle.cached_external_ips
runners = handle.get_command_runners(force_cached=True)
# This happens when user interrupt the `sky launch` process before
# the first time resources handle is written back to local database.
# This is helpful when user interrupt after the provision is done
# and before the skylet is restarted. After #2304 is merged, this
# helps keep the cluster status to INIT after `sky status -r`, so
# user will be notified that any auto stop/down might not be
# triggered.
if external_ips is None or len(external_ips) == 0:
if not runners:
logger.debug(f'Refreshing status ({cluster_name!r}): No cached '
f'IPs found. Handle: {handle}')
raise exceptions.FetchIPError(
reason=exceptions.FetchIPError.Reason.HEAD)

# Potentially refresh the external SSH ports, in case the existing
# cluster before #2491 was launched without external SSH ports
# cached.
external_ssh_ports = handle.external_ssh_ports()
head_ssh_port = external_ssh_ports[0]

# Check if ray cluster status is healthy.
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml,
handle.docker_user,
handle.ssh_user)

runner = command_runner.SSHCommandRunner(external_ips[0],
**ssh_credentials,
port=head_ssh_port)
rc, output, stderr = runner.run(
raise exceptions.FetchClusterInfoError(
reason=exceptions.FetchClusterInfoError.Reason.HEAD)
head_runner = runners[0]
rc, output, stderr = head_runner.run(
instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND,
stream_logs=False,
require_outputs=True,
Expand All @@ -1815,7 +1802,7 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
f'Refreshing status ({cluster_name!r}): ray status not showing '
f'all nodes ({ready_head + ready_workers}/'
f'{total_nodes}); output: {output}; stderr: {stderr}')
except exceptions.FetchIPError:
except exceptions.FetchClusterInfoError:
logger.debug(
f'Refreshing status ({cluster_name!r}) failed to get IPs.')
except RuntimeError as e:
Expand Down Expand Up @@ -2356,9 +2343,9 @@ def is_controller_accessible(
handle.docker_user,
handle.ssh_user)

runner = command_runner.SSHCommandRunner(handle.head_ip,
**ssh_credentials,
port=handle.head_ssh_port)
runner = command_runner.SSHCommandRunner(node=(handle.head_ip,
handle.head_ssh_port),
**ssh_credentials)
if not runner.check_connection():
error_msg = controller.value.connection_error_hint
else:
Expand Down
Loading

0 comments on commit f65a2f6

Please sign in to comment.