diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 696320be2f9..ab971a22407 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -2871,8 +2871,10 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str, def wait_and_terminate_csync(cluster_name: str) -> None: - """Terminatese all the CSYNC process running in each node after - waiting for the sync process launched by CSYNC complete. + """Terminates all the CSYNC process running in each node. + + Before terminating the CSYNC daemon, it waits until the sync process + launched by CSYNC is completed if there are any. Args: cluster_name: Cluster name (see `sky status`) @@ -2885,13 +2887,14 @@ def wait_and_terminate_csync(cluster_name: str) -> None: return try: ip_list = handle.external_ips() - # When cluster is in INIT mode, attempt to fetch IP fails raising an error + # When cluster is in INIT status, attempt to fetch IP fails raising an error except exceptions.FetchIPError: return if not ip_list: return port_list = handle.external_ssh_ports() - ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml) + ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml, + handle.docker_user) runners = command_runner.SSHCommandRunner.make_runner_list( ip_list, port_list=port_list, **ssh_credentials) csync_terminate_cmd = ('python -m sky.data.skystorage terminate -a ' diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 0ddf5638641..ec8ce356026 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3009,8 +3009,8 @@ def _sync_file_mounts( ) -> None: """Mounts all user files to the remote nodes.""" self._execute_file_mounts(handle, all_file_mounts) - self._execute_storage_mounts(handle, storage_mounts) - self._execute_storage_csync(handle, storage_mounts) + self._execute_storage_mounts(handle, storage_mounts, storage_utils.StorageMode.MOUNT) + self._execute_storage_mounts(handle, storage_mounts, storage_utils.StorageMode.CSYNC) self._set_cluster_storage_mounts_metadata(handle.cluster_name, storage_mounts) @@ -4415,7 +4415,8 @@ def _symlink_node(runner: command_runner.SSHCommandRunner): def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, storage_mounts: Dict[Path, - storage_lib.Storage]): + storage_lib.Storage], + mount_mode: storage_utils.StorageMode): """Executes storage mounts: installing mounting tools and mounting.""" # Process only MOUNT mode objects here. COPY mode objects have been # converted to regular copy file mounts and thus have been handled @@ -4423,7 +4424,7 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, storage_mounts = { path: storage_mount for path, storage_mount in storage_mounts.items() - if storage_mount.mode == storage_utils.StorageMode.MOUNT + if storage_mount.mode == mount_mode } if not storage_mounts: @@ -4437,11 +4438,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, f'mounting. No action will be taken.{colorama.Style.RESET_ALL}') return + if mount_mode == storage_utils.StorageMode.MOUNT: + mode_str = 'mount' + action_message = 'Mounting' + else: # CSYNC mdoe + mode_str = 'csync' + action_message = 'Setting up CSYNC' + fore = colorama.Fore style = colorama.Style plural = 's' if len(storage_mounts) > 1 else '' logger.info(f'{fore.CYAN}Processing {len(storage_mounts)} ' - f'storage mount{plural}.{style.RESET_ALL}') + f'storage {mode_str}{plural}.{style.RESET_ALL}') start = time.time() ip_list = handle.external_ips() port_list = handle.external_ssh_ports() @@ -4450,14 +4458,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, handle.cluster_yaml, handle.docker_user) runners = command_runner.SSHCommandRunner.make_runner_list( ip_list, port_list=port_list, **ssh_credentials) - log_path = os.path.join(self.log_dir, 'storage_mounts.log') + log_file_name = f'storage_{mode_str}.log' + log_path = os.path.join(self.log_dir, log_file_name) for dst, storage_obj in storage_mounts.items(): if not os.path.isabs(dst) and not dst.startswith('~/'): dst = f'{SKY_REMOTE_WORKDIR}/{dst}' # Get the first store and use it to mount store = list(storage_obj.stores.values())[0] - mount_cmd = store.mount_command(dst) + if mount_mode == storage_utils.StorageMode.MOUNT: + mount_cmd = store.mount_command(dst) + else: # CSYNC mode + mount_cmd= store.csync_command(dst) src_print = (storage_obj.source if storage_obj.source else storage_obj.name) if isinstance(src_print, list): @@ -4469,7 +4481,7 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, target=dst, cmd=mount_cmd, run_rsync=False, - action_message='Mounting', + action_message=action_message, log_path=log_path, ) except exceptions.CommandError as e: @@ -4486,93 +4498,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle, else: if env_options.Options.SHOW_DEBUG_INFO.get(): raise exceptions.CommandError(e.returncode, - command='to mount', + command=f'to {mode_str}', error_msg=e.error_msg) else: # Strip the command (a big heredoc) from the exception raise exceptions.CommandError( e.returncode, - command='to mount', + command=f'to {mode_str}', error_msg=e.error_msg) from None end = time.time() - logger.debug(f'Storage mount sync took {end - start} seconds.') - - def _execute_storage_csync( - self, handle: CloudVmRayResourceHandle, - storage_mounts: Dict[Path, storage_lib.Storage]) -> None: - """Executes CSYNC on storages. - - This function only runs the CSYNC daemon on the given storage - and the files/dirs to be copied to remote node are handled in - the _execute_file_mounts. - - """ - csync_storage_mounts = { - path: storage_obj - for path, storage_obj in storage_mounts.items() - if storage_obj.mode == storage_utils.StorageMode.CSYNC - } - - if not csync_storage_mounts: - return - - cloud = handle.launched_resources.cloud - if isinstance(cloud, clouds.Local): - logger.warning( - f'{colorama.Fore.YELLOW}Sky On-prem does not support ' - 'storage syncing. No action will be taken.' - f'{colorama.Style.RESET_ALL}') - return - - fore = colorama.Fore - style = colorama.Style - plural = 's' if len(csync_storage_mounts) > 1 else '' - logger.info(f'{fore.CYAN}Processing {len(csync_storage_mounts)} ' - f'storage CSYNC{plural}.{style.RESET_ALL}') - start = time.time() - ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) - log_path = os.path.join(self.log_dir, 'storage_csyncs.log') + logger.debug(f'Setting storage {mode_str} took {end - start} seconds.') - for dst, storage_obj in csync_storage_mounts.items(): - if not os.path.isabs(dst) and not dst.startswith('~/'): - dst = f'{SKY_REMOTE_WORKDIR}/{dst}' - # Get the first store and use it to sync - store = list(storage_obj.stores.values())[0] - csync_cmd = store.csync_command(dst, storage_obj.interval_seconds) - src_print = (storage_obj.source - if storage_obj.source else storage_obj.name) - if isinstance(src_print, list): - src_print = ', '.join(src_print) - try: - backend_utils.parallel_data_transfer_to_nodes( - runners, - source=src_print, - target=dst, - cmd=csync_cmd, - run_rsync=False, - action_message='Setting CSYNC', - log_path=log_path, - ) - except exceptions.CommandError as e: - if env_options.Options.SHOW_DEBUG_INFO.get(): - raise exceptions.CommandError(e.returncode, - command='to CSYNC', - error_msg=e.error_msg) - else: - # Strip the command (a big heredoc) from the exception - raise exceptions.CommandError( - e.returncode, command='to CSYNC', - error_msg=e.error_msg) from None - - end = time.time() - logger.debug(f'Storage Sync setup took {end - start} seconds.') def _set_cluster_storage_mounts_metadata( self, cluster_name: str, @@ -4617,9 +4554,7 @@ def _get_cluster_storage_mounts_metadata( return storage_mounts def _has_csync(self, cluster_name: str) -> bool: - """Chekcs if there is a storage running with CSYNC mode within the - cluster. - """ + """Chekcs if there are CSYNC mode storages within the cluster.""" storage_mounts = self._get_cluster_storage_mounts_metadata(cluster_name) if storage_mounts is not None: for _, storage_obj in storage_mounts.items(): diff --git a/sky/data/skystorage.py b/sky/data/skystorage.py index 06541def710..f92df8eeb57 100644 --- a/sky/data/skystorage.py +++ b/sky/data/skystorage.py @@ -141,17 +141,6 @@ def main(): pass -def update_interval(interval_seconds: int, elapsed_time: int): - """Updates the time interval for the next sync operation. - - Given the originally set interval_seconds and the time elapsed during the - sync operation, this function computes and returns the remaining time to - wait before the next sync operation. - """ - diff = interval_seconds - elapsed_time - return max(0, diff) - - def get_s3_upload_cmd(src_path: str, dst: str, num_threads: int, delete: bool, no_follow_symlinks: bool): """Builds sync command for aws s3""" @@ -242,7 +231,7 @@ def run_sync(src: str, f'number of retries. Check {log_path} for' 'details') from None - #run necessary post-processes + # run necessary post-processes _set_running_csync_sync_pid(csync_pid, -1) if storetype == 's3': # set number of threads back to its default value @@ -279,9 +268,11 @@ def run_sync(src: str, help='') def csync(source: str, storetype: str, destination: str, num_threads: int, interval_seconds: int, delete: bool, no_follow_symlinks: bool): - """Syncs the source to the bucket every INTERVAL seconds. Creates an entry - of pid of the sync process in local database while sync command is runninng - and removes it when completed. + """Runs daemon to sync the source to the bucket every INTERVAL seconds. + + Creates an entry of pid of the sync process in local database while sync + command is runninng and removes it when completed. + Args: source (str): The local path to the directory that you want to sync. storetype (str): The type of cloud storage to sync to. @@ -309,9 +300,11 @@ def csync(source: str, storetype: str, destination: str, num_threads: int, run_sync(full_src, storetype, destination, num_threads, interval_seconds, delete, no_follow_symlinks, csync_pid) end_time = time.time() - # the time took to sync gets reflected to the interval_seconds + # Given the interval_seconds and the time elapsed during the sync + # operation, we compute remaining time to wait before the next + # sync operation. elapsed_time = int(end_time - start_time) - remaining_interval = update_interval(interval_seconds, elapsed_time) + remaining_interval = max(0, interval_seconds-elapsed_time) # sync_pid column is set to 0 when sync is not running time.sleep(remaining_interval) @@ -346,8 +339,8 @@ def _terminate(paths: List[str], all: bool = False) -> None: # pylint: disable= """Terminates all the CSYNC daemon running after checking if all the sync process has completed. """ - # TODO: Currently, this terminates all the CSYNC daemon by default. - # Make an option of --all to terminate all and make the default + # TODO(Doyoung): Currently, this terminates all the CSYNC daemon by + # default. Make an option of --all to terminate all and make the default # behavior to take a source name to terminate only one daemon. # Call the function to terminate the csync processes here if all: diff --git a/sky/data/storage.py b/sky/data/storage.py index 9e6459e34df..c652cbbe7d4 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -365,7 +365,7 @@ class StorageMetadata(object): - (required) Source. - (optional) Sync every interval_seconds for CSYNC mode. - (optional) Storage mode. - - (optional) Set of stores managed by sky added to the Storage object + - (optional) Set of stores managed by sky added to the Storage object. """ def __init__( diff --git a/tests/test_smoke.py b/tests/test_smoke.py index d8760f24d4e..f21836868d9 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -797,9 +797,9 @@ def test_using_file_mounts_with_env_vars(generic_cloud: str): @pytest.mark.aws def test_aws_storage_mounts(): name = _get_cluster_name() - random_name = int(time.time()) - storage_name = f'sky-test-{random_name}' - csync_storage_name = f'sky-test-{random_name + 1}' + timestamp = int(time.time()) + storage_name = f'sky-test-{timestamp}' + csync_storage_name = f'sky-test-{timestamp + 1}' template_str = pathlib.Path( 'tests/test_yamls/test_storage_mounting.yaml.j2').read_text() template = jinja2.Template(template_str)