Skip to content

Commit

Permalink
nit and merging _execute_storage_mounts and _csync
Browse files Browse the repository at this point in the history
  • Loading branch information
landscapepainter committed Sep 17, 2023
1 parent 785e770 commit 2bec928
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 116 deletions.
11 changes: 7 additions & 4 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2871,8 +2871,10 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str,


def wait_and_terminate_csync(cluster_name: str) -> None:
"""Terminatese all the CSYNC process running in each node after
waiting for the sync process launched by CSYNC complete.
"""Terminates all the CSYNC process running in each node.
Before terminating the CSYNC daemon, it waits until the sync process
launched by CSYNC is completed if there are any.
Args:
cluster_name: Cluster name (see `sky status`)
Expand All @@ -2885,13 +2887,14 @@ def wait_and_terminate_csync(cluster_name: str) -> None:
return
try:
ip_list = handle.external_ips()
# When cluster is in INIT mode, attempt to fetch IP fails raising an error
# When cluster is in INIT status, attempt to fetch IP fails raising an error
except exceptions.FetchIPError:
return
if not ip_list:
return
port_list = handle.external_ssh_ports()
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml)
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml,
handle.docker_user)
runners = command_runner.SSHCommandRunner.make_runner_list(
ip_list, port_list=port_list, **ssh_credentials)
csync_terminate_cmd = ('python -m sky.data.skystorage terminate -a '
Expand Down
113 changes: 24 additions & 89 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3009,8 +3009,8 @@ def _sync_file_mounts(
) -> None:
"""Mounts all user files to the remote nodes."""
self._execute_file_mounts(handle, all_file_mounts)
self._execute_storage_mounts(handle, storage_mounts)
self._execute_storage_csync(handle, storage_mounts)
self._execute_storage_mounts(handle, storage_mounts, storage_utils.StorageMode.MOUNT)
self._execute_storage_mounts(handle, storage_mounts, storage_utils.StorageMode.CSYNC)
self._set_cluster_storage_mounts_metadata(handle.cluster_name,
storage_mounts)

Expand Down Expand Up @@ -4415,15 +4415,16 @@ def _symlink_node(runner: command_runner.SSHCommandRunner):

def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle,
storage_mounts: Dict[Path,
storage_lib.Storage]):
storage_lib.Storage],
mount_mode: storage_utils.StorageMode):
"""Executes storage mounts: installing mounting tools and mounting."""
# Process only MOUNT mode objects here. COPY mode objects have been
# converted to regular copy file mounts and thus have been handled
# in the '_execute_file_mounts' method.
storage_mounts = {
path: storage_mount
for path, storage_mount in storage_mounts.items()
if storage_mount.mode == storage_utils.StorageMode.MOUNT
if storage_mount.mode == mount_mode
}

if not storage_mounts:
Expand All @@ -4437,11 +4438,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle,
f'mounting. No action will be taken.{colorama.Style.RESET_ALL}')
return

if mount_mode == storage_utils.StorageMode.MOUNT:
mode_str = 'mount'
action_message = 'Mounting'
else: # CSYNC mdoe
mode_str = 'csync'
action_message = 'Setting up CSYNC'

fore = colorama.Fore
style = colorama.Style
plural = 's' if len(storage_mounts) > 1 else ''
logger.info(f'{fore.CYAN}Processing {len(storage_mounts)} '
f'storage mount{plural}.{style.RESET_ALL}')
f'storage {mode_str}{plural}.{style.RESET_ALL}')
start = time.time()
ip_list = handle.external_ips()
port_list = handle.external_ssh_ports()
Expand All @@ -4450,14 +4458,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle,
handle.cluster_yaml, handle.docker_user)
runners = command_runner.SSHCommandRunner.make_runner_list(
ip_list, port_list=port_list, **ssh_credentials)
log_path = os.path.join(self.log_dir, 'storage_mounts.log')
log_file_name = f'storage_{mode_str}.log'
log_path = os.path.join(self.log_dir, log_file_name)

for dst, storage_obj in storage_mounts.items():
if not os.path.isabs(dst) and not dst.startswith('~/'):
dst = f'{SKY_REMOTE_WORKDIR}/{dst}'
# Get the first store and use it to mount
store = list(storage_obj.stores.values())[0]
mount_cmd = store.mount_command(dst)
if mount_mode == storage_utils.StorageMode.MOUNT:
mount_cmd = store.mount_command(dst)
else: # CSYNC mode
mount_cmd= store.csync_command(dst)
src_print = (storage_obj.source
if storage_obj.source else storage_obj.name)
if isinstance(src_print, list):
Expand All @@ -4469,7 +4481,7 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle,
target=dst,
cmd=mount_cmd,
run_rsync=False,
action_message='Mounting',
action_message=action_message,
log_path=log_path,
)
except exceptions.CommandError as e:
Expand All @@ -4486,93 +4498,18 @@ def _execute_storage_mounts(self, handle: CloudVmRayResourceHandle,
else:
if env_options.Options.SHOW_DEBUG_INFO.get():
raise exceptions.CommandError(e.returncode,
command='to mount',
command=f'to {mode_str}',
error_msg=e.error_msg)
else:
# Strip the command (a big heredoc) from the exception
raise exceptions.CommandError(
e.returncode,
command='to mount',
command=f'to {mode_str}',
error_msg=e.error_msg) from None

end = time.time()
logger.debug(f'Storage mount sync took {end - start} seconds.')

def _execute_storage_csync(
self, handle: CloudVmRayResourceHandle,
storage_mounts: Dict[Path, storage_lib.Storage]) -> None:
"""Executes CSYNC on storages.
This function only runs the CSYNC daemon on the given storage
and the files/dirs to be copied to remote node are handled in
the _execute_file_mounts.
"""
csync_storage_mounts = {
path: storage_obj
for path, storage_obj in storage_mounts.items()
if storage_obj.mode == storage_utils.StorageMode.CSYNC
}

if not csync_storage_mounts:
return

cloud = handle.launched_resources.cloud
if isinstance(cloud, clouds.Local):
logger.warning(
f'{colorama.Fore.YELLOW}Sky On-prem does not support '
'storage syncing. No action will be taken.'
f'{colorama.Style.RESET_ALL}')
return

fore = colorama.Fore
style = colorama.Style
plural = 's' if len(csync_storage_mounts) > 1 else ''
logger.info(f'{fore.CYAN}Processing {len(csync_storage_mounts)} '
f'storage CSYNC{plural}.{style.RESET_ALL}')
start = time.time()
ip_list = handle.external_ips()
port_list = handle.external_ssh_ports()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)
runners = command_runner.SSHCommandRunner.make_runner_list(
ip_list, port_list=port_list, **ssh_credentials)
log_path = os.path.join(self.log_dir, 'storage_csyncs.log')
logger.debug(f'Setting storage {mode_str} took {end - start} seconds.')

for dst, storage_obj in csync_storage_mounts.items():
if not os.path.isabs(dst) and not dst.startswith('~/'):
dst = f'{SKY_REMOTE_WORKDIR}/{dst}'
# Get the first store and use it to sync
store = list(storage_obj.stores.values())[0]
csync_cmd = store.csync_command(dst, storage_obj.interval_seconds)
src_print = (storage_obj.source
if storage_obj.source else storage_obj.name)
if isinstance(src_print, list):
src_print = ', '.join(src_print)
try:
backend_utils.parallel_data_transfer_to_nodes(
runners,
source=src_print,
target=dst,
cmd=csync_cmd,
run_rsync=False,
action_message='Setting CSYNC',
log_path=log_path,
)
except exceptions.CommandError as e:
if env_options.Options.SHOW_DEBUG_INFO.get():
raise exceptions.CommandError(e.returncode,
command='to CSYNC',
error_msg=e.error_msg)
else:
# Strip the command (a big heredoc) from the exception
raise exceptions.CommandError(
e.returncode, command='to CSYNC',
error_msg=e.error_msg) from None

end = time.time()
logger.debug(f'Storage Sync setup took {end - start} seconds.')

def _set_cluster_storage_mounts_metadata(
self, cluster_name: str,
Expand Down Expand Up @@ -4617,9 +4554,7 @@ def _get_cluster_storage_mounts_metadata(
return storage_mounts

def _has_csync(self, cluster_name: str) -> bool:
"""Chekcs if there is a storage running with CSYNC mode within the
cluster.
"""
"""Chekcs if there are CSYNC mode storages within the cluster."""
storage_mounts = self._get_cluster_storage_mounts_metadata(cluster_name)
if storage_mounts is not None:
for _, storage_obj in storage_mounts.items():
Expand Down
31 changes: 12 additions & 19 deletions sky/data/skystorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,6 @@ def main():
pass


def update_interval(interval_seconds: int, elapsed_time: int):
"""Updates the time interval for the next sync operation.
Given the originally set interval_seconds and the time elapsed during the
sync operation, this function computes and returns the remaining time to
wait before the next sync operation.
"""
diff = interval_seconds - elapsed_time
return max(0, diff)


def get_s3_upload_cmd(src_path: str, dst: str, num_threads: int, delete: bool,
no_follow_symlinks: bool):
"""Builds sync command for aws s3"""
Expand Down Expand Up @@ -242,7 +231,7 @@ def run_sync(src: str,
f'number of retries. Check {log_path} for'
'details') from None

#run necessary post-processes
# run necessary post-processes
_set_running_csync_sync_pid(csync_pid, -1)
if storetype == 's3':
# set number of threads back to its default value
Expand Down Expand Up @@ -279,9 +268,11 @@ def run_sync(src: str,
help='')
def csync(source: str, storetype: str, destination: str, num_threads: int,
interval_seconds: int, delete: bool, no_follow_symlinks: bool):
"""Syncs the source to the bucket every INTERVAL seconds. Creates an entry
of pid of the sync process in local database while sync command is runninng
and removes it when completed.
"""Runs daemon to sync the source to the bucket every INTERVAL seconds.
Creates an entry of pid of the sync process in local database while sync
command is runninng and removes it when completed.
Args:
source (str): The local path to the directory that you want to sync.
storetype (str): The type of cloud storage to sync to.
Expand Down Expand Up @@ -309,9 +300,11 @@ def csync(source: str, storetype: str, destination: str, num_threads: int,
run_sync(full_src, storetype, destination, num_threads,
interval_seconds, delete, no_follow_symlinks, csync_pid)
end_time = time.time()
# the time took to sync gets reflected to the interval_seconds
# Given the interval_seconds and the time elapsed during the sync
# operation, we compute remaining time to wait before the next
# sync operation.
elapsed_time = int(end_time - start_time)
remaining_interval = update_interval(interval_seconds, elapsed_time)
remaining_interval = max(0, interval_seconds-elapsed_time)
# sync_pid column is set to 0 when sync is not running
time.sleep(remaining_interval)

Expand Down Expand Up @@ -346,8 +339,8 @@ def _terminate(paths: List[str], all: bool = False) -> None: # pylint: disable=
"""Terminates all the CSYNC daemon running after checking if all the
sync process has completed.
"""
# TODO: Currently, this terminates all the CSYNC daemon by default.
# Make an option of --all to terminate all and make the default
# TODO(Doyoung): Currently, this terminates all the CSYNC daemon by
# default. Make an option of --all to terminate all and make the default
# behavior to take a source name to terminate only one daemon.
# Call the function to terminate the csync processes here
if all:
Expand Down
2 changes: 1 addition & 1 deletion sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class StorageMetadata(object):
- (required) Source.
- (optional) Sync every interval_seconds for CSYNC mode.
- (optional) Storage mode.
- (optional) Set of stores managed by sky added to the Storage object
- (optional) Set of stores managed by sky added to the Storage object.
"""

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,9 @@ def test_using_file_mounts_with_env_vars(generic_cloud: str):
@pytest.mark.aws
def test_aws_storage_mounts():
name = _get_cluster_name()
random_name = int(time.time())
storage_name = f'sky-test-{random_name}'
csync_storage_name = f'sky-test-{random_name + 1}'
timestamp = int(time.time())
storage_name = f'sky-test-{timestamp}'
csync_storage_name = f'sky-test-{timestamp + 1}'
template_str = pathlib.Path(
'tests/test_yamls/test_storage_mounting.yaml.j2').read_text()
template = jinja2.Template(template_str)
Expand Down

0 comments on commit 2bec928

Please sign in to comment.