Skip to content

Commit

Permalink
refactor rsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Jun 3, 2024
1 parent 8a672d6 commit a032f4e
Showing 1 changed file with 114 additions and 140 deletions.
254 changes: 114 additions & 140 deletions sky/utils/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pathlib
import shlex
import time
from typing import Any, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union

from sky import sky_logging
from sky.skylet import constants
Expand All @@ -19,6 +19,10 @@
# The git exclude file to support.
GIT_EXCLUDE = '.git/info/exclude'
# Rsync options
# TODO(zhwu): The -P flag prints a per-file progress bar, which floods
# the output with messages. --info=progress2 would show a single total
# progress bar instead, but it requires rsync>=3.1.0, and macOS ships
# with rsync==2.6.9 by default (16 years old).
RSYNC_DISPLAY_OPTION = '-Pavz'
# Legend
# dir-merge: ignore file can appear in any subdir, applies to that
Expand Down Expand Up @@ -208,6 +212,92 @@ def _get_command_to_run(
command_str = ' '.join(command)
return command_str

def _rsync(
        self,
        source: str,
        target: str,
        node_destination: str,
        up: bool,
        rsh_option: str,
        # Advanced options.
        log_path: str = os.devnull,
        stream_logs: bool = True,
        max_retry: int = 1,
        prefix_command: Optional[str] = None,
        get_remote_home_dir: Callable[[], str] = lambda: '~') -> None:
    """Builds the rsync command and runs it, retrying on failure.

    Args:
        source: Path to copy from. A local path when `up` is True,
            otherwise a path on the remote node.
        target: Path to copy to. A path on the remote node when `up` is
            True, otherwise a local path.
        node_destination: The remote endpoint placed before the ':' in
            the remote path argument (e.g. 'user@ip').
        up: True to sync local -> remote, False for remote -> local.
        rsh_option: The remote shell command passed to rsync via `-e`.
        log_path: File that the command output is logged to.
        stream_logs: Whether to stream the command output.
        max_retry: Number of retries after the first failed attempt;
            must be positive.
        prefix_command: Optional shell snippet placed before `rsync` in
            the command line.
        get_remote_home_dir: Callable returning the remote home
            directory. It is only invoked when the remote path starts
            with '~', and its result replaces that leading '~'.

    Raises:
        AssertionError: If `max_retry` is not positive.
        Whatever `subprocess_utils.handle_returncode` raises when the
        last attempt exits non-zero (presumably exceptions.CommandError;
        confirm against that helper).
    """

    def _resolve_remote_path(path: str) -> str:
        """Replaces a leading '~' with the remote home directory."""
        if not path.startswith('~'):
            return path
        # Only the leading '~' denotes the home directory; any other '~'
        # in the path (e.g. '~/backup~old') is a literal character, so
        # limit the replacement to the first occurrence.
        return path.replace('~', get_remote_home_dir(), 1)

    # Build command.
    rsync_command = []
    if prefix_command is not None:
        rsync_command.append(prefix_command)
    rsync_command += ['rsync', RSYNC_DISPLAY_OPTION]

    # --filter
    rsync_command.append(RSYNC_FILTER_OPTION)

    if up:
        # The source is a local path, so we need to resolve it.
        # --exclude-from
        resolved_source = pathlib.Path(source).expanduser().resolve()
        if (resolved_source / GIT_EXCLUDE).exists():
            # Ensure file exists; otherwise, rsync will error out.
            #
            # We shlex.quote() because the path may contain spaces:
            #   'my dir/.git/info/exclude'
            # Without quoting rsync fails.
            rsync_command.append(
                RSYNC_EXCLUDE_OPTION.format(
                    shlex.quote(str(resolved_source / GIT_EXCLUDE))))

    rsync_command.append(f'-e "{rsh_option}"')

    # Source and target are quoted (via !r) to support spaces in paths.
    if up:
        resolved_target = _resolve_remote_path(target)
        full_source_str = str(resolved_source)
        if resolved_source.is_dir():
            # A trailing separator makes rsync copy the directory's
            # contents rather than the directory itself.
            full_source_str = os.path.join(full_source_str, '')
        rsync_command.extend([
            f'{full_source_str!r}',
            f'{node_destination}:{resolved_target!r}',
        ])
    else:
        resolved_remote_source = _resolve_remote_path(source)
        rsync_command.extend([
            f'{node_destination}:{resolved_remote_source!r}',
            f'{os.path.expanduser(target)!r}',
        ])
    command = ' '.join(rsync_command)
    logger.debug(f'Running rsync command: {command}')

    backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5)
    assert max_retry > 0, f'max_retry {max_retry} must be positive.'
    while max_retry >= 0:
        returncode, stdout, stderr = log_lib.run_with_log(
            command,
            log_path=log_path,
            stream_logs=stream_logs,
            shell=True,
            require_outputs=True)
        if returncode == 0:
            break
        max_retry -= 1
        # Only back off when another attempt will follow; sleeping after
        # the final failure would just delay the error report.
        if max_retry >= 0:
            time.sleep(backoff.current_backoff())

    direction = 'up' if up else 'down'
    error_msg = (f'Failed to rsync {direction}: {source} -> {target}. '
                 'Ensure that the network is stable, then retry.')
    subprocess_utils.handle_returncode(returncode,
                                       command,
                                       error_msg,
                                       stderr=stdout + stderr,
                                       stream_logs=stream_logs)

@timeline.event
def run(
self,
Expand Down Expand Up @@ -509,30 +599,6 @@ def rsync(
Raises:
exceptions.CommandError: rsync command failed.
"""
# Build command.
# TODO(zhwu): This will print a per-file progress bar (with -P),
# shooting a lot of messages to the output. --info=progress2 is used
# to get a total progress bar, but it requires rsync>=3.1.0 and Mac
# OS has a default rsync==2.6.9 (16 years old).
rsync_command = ['rsync', RSYNC_DISPLAY_OPTION]

# --filter
rsync_command.append(RSYNC_FILTER_OPTION)

if up:
# The source is a local path, so we need to resolve it.
# --exclude-from
resolved_source = pathlib.Path(source).expanduser().resolve()
if (resolved_source / GIT_EXCLUDE).exists():
# Ensure file exists; otherwise, rsync will error out.
#
# We shlex.quote() because the path may contain spaces:
# 'my dir/.git/info/exclude'
# Without quoting rsync fails.
rsync_command.append(
RSYNC_EXCLUDE_OPTION.format(
shlex.quote(str(resolved_source / GIT_EXCLUDE))))

if self._docker_ssh_proxy_command is not None:
docker_ssh_proxy_command = self._docker_ssh_proxy_command(['ssh'])
else:
Expand All @@ -545,46 +611,15 @@ def rsync(
docker_ssh_proxy_command=docker_ssh_proxy_command,
port=self.port,
disable_control_master=self.disable_control_master))
rsync_command.append(f'-e "ssh {ssh_options}"')
# To support spaces in the path, we need to quote source and target.
# rsync doesn't support '~' in a quoted local path, but it is ok to
# have '~' in a quoted remote path.
if up:
full_source_str = str(resolved_source)
if resolved_source.is_dir():
full_source_str = os.path.join(full_source_str, '')
rsync_command.extend([
f'{full_source_str!r}',
f'{self.ssh_user}@{self.ip}:{target!r}',
])
else:
rsync_command.extend([
f'{self.ssh_user}@{self.ip}:{source!r}',
f'{os.path.expanduser(target)!r}',
])
command = ' '.join(rsync_command)

backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5)
while max_retry >= 0:
returncode, stdout, stderr = log_lib.run_with_log(
command,
log_path=log_path,
stream_logs=stream_logs,
shell=True,
require_outputs=True)
if returncode == 0:
break
max_retry -= 1
time.sleep(backoff.current_backoff())

direction = 'up' if up else 'down'
error_msg = (f'Failed to rsync {direction}: {source} -> {target}. '
'Ensure that the network is stable, then retry.')
subprocess_utils.handle_returncode(returncode,
command,
error_msg,
stderr=stdout + stderr,
stream_logs=stream_logs)
rsh_option = f'ssh {ssh_options}'
self._rsync(source,
target,
node_destination=f'{self.ssh_user}@{self.ip}',
up=up,
rsh_option=rsh_option,
log_path=log_path,
stream_logs=stream_logs,
max_retry=max_retry)


class KubernetesCommandRunner(CommandRunner):
Expand Down Expand Up @@ -757,80 +792,19 @@ def get_remote_home_dir() -> str:
return remote_home_dir

# Build command.
# TODO(zhwu): This will print a per-file progress bar (with -P),
# shooting a lot of messages to the output. --info=progress2 is used
# to get a total progress bar, but it requires rsync>=3.1.0 and Mac
# OS has a default rsync==2.6.9 (16 years old).
helper_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'kubernetes', 'rsync_helper.sh')
rsync_command = [
f'chmod +x {helper_path} &&', 'rsync', RSYNC_DISPLAY_OPTION
]

# --filter
rsync_command.append(RSYNC_FILTER_OPTION)

if up:
# The source is a local path, so we need to resolve it.
# --exclude-from
resolved_source = pathlib.Path(source).expanduser().resolve()
if (resolved_source / GIT_EXCLUDE).exists():
# Ensure file exists; otherwise, rsync will error out.
#
# We shlex.quote() because the path may contain spaces:
# 'my dir/.git/info/exclude'
# Without quoting rsync fails.
rsync_command.append(
RSYNC_EXCLUDE_OPTION.format(
shlex.quote(str(resolved_source / GIT_EXCLUDE))))

rsync_command.append(f'-e "{helper_path}"')
# rsync with `kubectl` as the rsh command will cause ~/xx parsed as
# /~/xx, so we need to replace ~ with the remote home directory. We only
# need to do this when ~ is at the beginning of the path.
if up:
resolved_target = target
if target.startswith('~'):
remote_home_dir = get_remote_home_dir()
resolved_target = target.replace('~', remote_home_dir)
full_source_str = str(resolved_source)
if resolved_source.is_dir():
full_source_str = os.path.join(full_source_str, '')
rsync_command.extend([
f'{full_source_str!r}',
f'{self.pod_name}@{self.namespace}:{resolved_target!r}',
])
else:
resolved_source = source
if source.startswith('~'):
remote_home_dir = get_remote_home_dir()
resolved_source = source.replace('~', remote_home_dir)
rsync_command.extend([
f'{self.pod_name}@{self.namespace}:{resolved_source!r}',
f'{os.path.expanduser(target)!r}',
])
command = ' '.join(rsync_command)
logger.debug(f'Running rsync command: {command}')

backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5)
assert max_retry > 0, f'max_retry {max_retry} must be positive.'
while max_retry >= 0:
returncode, stdout, stderr = log_lib.run_with_log(
command,
log_path=log_path,
stream_logs=stream_logs,
shell=True,
require_outputs=True)
if returncode == 0:
break
max_retry -= 1
time.sleep(backoff.current_backoff())

direction = 'up' if up else 'down'
error_msg = (f'Failed to rsync {direction}: {source} -> {target}. '
'Ensure that the network is stable, then retry.')
subprocess_utils.handle_returncode(returncode,
command,
error_msg,
stderr=stdout + stderr,
stream_logs=stream_logs)
self._rsync(
source,
target,
node_destination=f'{self.pod_name}@{self.namespace}',
up=up,
rsh_option=helper_path,
log_path=log_path,
stream_logs=stream_logs,
max_retry=max_retry,
prefix_command=f'chmod +x {helper_path} && ',
# rsync with `kubectl` as the rsh command will cause ~/xx parsed as
# /~/xx, so we need to replace ~ with the remote home directory. We
# only need to do this when ~ is at the beginning of the path.
get_remote_home_dir=get_remote_home_dir)

0 comments on commit a032f4e

Please sign in to comment.