From a032f4ec06e7bf1a8208baac2699d40bbe7dd35c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 3 Jun 2024 07:16:26 +0000 Subject: [PATCH] refactor rsync --- sky/utils/command_runner.py | 254 ++++++++++++++++-------------------- 1 file changed, 114 insertions(+), 140 deletions(-) diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index cb755c40beb..e9b79b1ced9 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -5,7 +5,7 @@ import pathlib import shlex import time -from typing import Any, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union from sky import sky_logging from sky.skylet import constants @@ -19,6 +19,10 @@ # The git exclude file to support. GIT_EXCLUDE = '.git/info/exclude' # Rsync options +# TODO(zhwu): This will print a per-file progress bar (with -P), +# shooting a lot of messages to the output. --info=progress2 is used +# to get a total progress bar, but it requires rsync>=3.1.0 and Mac +# OS has a default rsync==2.6.9 (16 years old). RSYNC_DISPLAY_OPTION = '-Pavz' # Legend # dir-merge: ignore file can appear in any subdir, applies to that @@ -208,6 +212,92 @@ def _get_command_to_run( command_str = ' '.join(command) return command_str + def _rsync( + self, + source: str, + target: str, + node_destination: str, + up: bool, + rsh_option: str, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = True, + max_retry: int = 1, + prefix_command: Optional[str] = None, + get_remote_home_dir: Callable[[], str] = lambda: '~'): + """Builds the rsync command.""" + # Build command. + rsync_command = [] + if prefix_command is not None: + rsync_command.append(prefix_command) + rsync_command += ['rsync', RSYNC_DISPLAY_OPTION] + + # --filter + rsync_command.append(RSYNC_FILTER_OPTION) + + if up: + # The source is a local path, so we need to resolve it. + # --exclude-from + resolved_source = pathlib.Path(source).expanduser().resolve() + if (resolved_source / GIT_EXCLUDE).exists(): + # Ensure file exists; otherwise, rsync will error out. + # + # We shlex.quote() because the path may contain spaces: + # 'my dir/.git/info/exclude' + # Without quoting rsync fails. + rsync_command.append( + RSYNC_EXCLUDE_OPTION.format( + shlex.quote(str(resolved_source / GIT_EXCLUDE)))) + + rsync_command.append(f'-e "{rsh_option}"') + + if up: + resolved_target = target + if target.startswith('~'): + remote_home_dir = get_remote_home_dir() + resolved_target = target.replace('~', remote_home_dir) + full_source_str = str(resolved_source) + if resolved_source.is_dir(): + full_source_str = os.path.join(full_source_str, '') + rsync_command.extend([ + f'{full_source_str!r}', + f'{node_destination}:{resolved_target!r}', + ]) + else: + resolved_source = source + if source.startswith('~'): + remote_home_dir = get_remote_home_dir() + resolved_source = source.replace('~', remote_home_dir) + rsync_command.extend([ + f'{node_destination}:{resolved_source!r}', + f'{os.path.expanduser(target)!r}', + ]) + command = ' '.join(rsync_command) + logger.debug(f'Running rsync command: {command}') + + backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) + assert max_retry > 0, f'max_retry {max_retry} must be positive.' + while max_retry >= 0: + returncode, stdout, stderr = log_lib.run_with_log( + command, + log_path=log_path, + stream_logs=stream_logs, + shell=True, + require_outputs=True) + if returncode == 0: + break + max_retry -= 1 + time.sleep(backoff.current_backoff()) + + direction = 'up' if up else 'down' + error_msg = (f'Failed to rsync {direction}: {source} -> {target}. ' + 'Ensure that the network is stable, then retry.') + subprocess_utils.handle_returncode(returncode, + command, + error_msg, + stderr=stdout + stderr, + stream_logs=stream_logs) + @timeline.event def run( self, @@ -509,30 +599,6 @@ def rsync( Raises: exceptions.CommandError: rsync command failed. """ - # Build command. - # TODO(zhwu): This will print a per-file progress bar (with -P), - # shooting a lot of messages to the output. --info=progress2 is used - # to get a total progress bar, but it requires rsync>=3.1.0 and Mac - # OS has a default rsync==2.6.9 (16 years old). - rsync_command = ['rsync', RSYNC_DISPLAY_OPTION] - - # --filter - rsync_command.append(RSYNC_FILTER_OPTION) - - if up: - # The source is a local path, so we need to resolve it. - # --exclude-from - resolved_source = pathlib.Path(source).expanduser().resolve() - if (resolved_source / GIT_EXCLUDE).exists(): - # Ensure file exists; otherwise, rsync will error out. - # - # We shlex.quote() because the path may contain spaces: - # 'my dir/.git/info/exclude' - # Without quoting rsync fails. - rsync_command.append( - RSYNC_EXCLUDE_OPTION.format( - shlex.quote(str(resolved_source / GIT_EXCLUDE)))) - if self._docker_ssh_proxy_command is not None: docker_ssh_proxy_command = self._docker_ssh_proxy_command(['ssh']) else: @@ -545,46 +611,15 @@ def rsync( docker_ssh_proxy_command=docker_ssh_proxy_command, port=self.port, disable_control_master=self.disable_control_master)) - rsync_command.append(f'-e "ssh {ssh_options}"') - # To support spaces in the path, we need to quote source and target. - # rsync doesn't support '~' in a quoted local path, but it is ok to - # have '~' in a quoted remote path. - if up: - full_source_str = str(resolved_source) - if resolved_source.is_dir(): - full_source_str = os.path.join(full_source_str, '') - rsync_command.extend([ - f'{full_source_str!r}', - f'{self.ssh_user}@{self.ip}:{target!r}', - ]) - else: - rsync_command.extend([ - f'{self.ssh_user}@{self.ip}:{source!r}', - f'{os.path.expanduser(target)!r}', - ]) - command = ' '.join(rsync_command) - - backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) - while max_retry >= 0: - returncode, stdout, stderr = log_lib.run_with_log( - command, - log_path=log_path, - stream_logs=stream_logs, - shell=True, - require_outputs=True) - if returncode == 0: - break - max_retry -= 1 - time.sleep(backoff.current_backoff()) - - direction = 'up' if up else 'down' - error_msg = (f'Failed to rsync {direction}: {source} -> {target}. ' - 'Ensure that the network is stable, then retry.') - subprocess_utils.handle_returncode(returncode, - command, - error_msg, - stderr=stdout + stderr, - stream_logs=stream_logs) + rsh_option = f'ssh {ssh_options}' + self._rsync(source, + target, + node_destination=f'{self.ssh_user}@{self.ip}', + up=up, + rsh_option=rsh_option, + log_path=log_path, + stream_logs=stream_logs, + max_retry=max_retry) class KubernetesCommandRunner(CommandRunner): @@ -757,80 +792,19 @@ def get_remote_home_dir() -> str: return remote_home_dir # Build command. - # TODO(zhwu): This will print a per-file progress bar (with -P), - # shooting a lot of messages to the output. --info=progress2 is used - # to get a total progress bar, but it requires rsync>=3.1.0 and Mac - # OS has a default rsync==2.6.9 (16 years old). helper_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'kubernetes', 'rsync_helper.sh') - rsync_command = [ - f'chmod +x {helper_path} &&', 'rsync', RSYNC_DISPLAY_OPTION - ] - - # --filter - rsync_command.append(RSYNC_FILTER_OPTION) - - if up: - # The source is a local path, so we need to resolve it. - # --exclude-from - resolved_source = pathlib.Path(source).expanduser().resolve() - if (resolved_source / GIT_EXCLUDE).exists(): - # Ensure file exists; otherwise, rsync will error out. - # - # We shlex.quote() because the path may contain spaces: - # 'my dir/.git/info/exclude' - # Without quoting rsync fails. - rsync_command.append( - RSYNC_EXCLUDE_OPTION.format( - shlex.quote(str(resolved_source / GIT_EXCLUDE)))) - - rsync_command.append(f'-e "{helper_path}"') - # rsync with `kubectl` as the rsh command will cause ~/xx parsed as - # /~/xx, so we need to replace ~ with the remote home directory. We only - # need to do this when ~ is at the beginning of the path. - if up: - resolved_target = target - if target.startswith('~'): - remote_home_dir = get_remote_home_dir() - resolved_target = target.replace('~', remote_home_dir) - full_source_str = str(resolved_source) - if resolved_source.is_dir(): - full_source_str = os.path.join(full_source_str, '') - rsync_command.extend([ - f'{full_source_str!r}', - f'{self.pod_name}@{self.namespace}:{resolved_target!r}', - ]) - else: - resolved_source = source - if source.startswith('~'): - remote_home_dir = get_remote_home_dir() - resolved_source = source.replace('~', remote_home_dir) - rsync_command.extend([ - f'{self.pod_name}@{self.namespace}:{resolved_source!r}', - f'{os.path.expanduser(target)!r}', - ]) - command = ' '.join(rsync_command) - logger.debug(f'Running rsync command: {command}') - - backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) - assert max_retry > 0, f'max_retry {max_retry} must be positive.' - while max_retry >= 0: - returncode, stdout, stderr = log_lib.run_with_log( - command, - log_path=log_path, - stream_logs=stream_logs, - shell=True, - require_outputs=True) - if returncode == 0: - break - max_retry -= 1 - time.sleep(backoff.current_backoff()) - - direction = 'up' if up else 'down' - error_msg = (f'Failed to rsync {direction}: {source} -> {target}. ' - 'Ensure that the network is stable, then retry.') - subprocess_utils.handle_returncode(returncode, - command, - error_msg, - stderr=stdout + stderr, - stream_logs=stream_logs) + self._rsync( + source, + target, + node_destination=f'{self.pod_name}@{self.namespace}', + up=up, + rsh_option=helper_path, + log_path=log_path, + stream_logs=stream_logs, + max_retry=max_retry, + prefix_command=f'chmod +x {helper_path} && ', + # rsync with `kubectl` as the rsh command will cause ~/xx parsed as + # /~/xx, so we need to replace ~ with the remote home directory. We + # only need to do this when ~ is at the beginning of the path. + get_remote_home_dir=get_remote_home_dir)