Skip to content

Commit

Permalink
[Provisioner] Fix cache for internal file mounts (#2715)
Browse files Browse the repository at this point in the history
* Fix file mounts

* print out the skylet version for better debugging ability

* Add skypilot version
  • Loading branch information
Michaelvll authored Oct 28, 2023
1 parent acd5ab7 commit 1ec056e
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 29 deletions.
2 changes: 0 additions & 2 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2904,8 +2904,6 @@ def _provision(
provisioner.ClusterName(handle.cluster_name,
handle.cluster_name_on_cloud),
handle.cluster_yaml,
local_wheel_path=local_wheel_path,
wheel_hash=wheel_hash,
provision_record=provision_record,
custom_resource=resources_vars.get('custom_resources'),
log_dir=self.log_dir)
Expand Down
25 changes: 15 additions & 10 deletions sky/provision/instance_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def _hint_worker_log_path(cluster_name: str, cluster_info: common.ClusterInfo,


def _parallel_ssh_with_cache(func, cluster_name: str, stage_name: str,
digest: str, cluster_info: common.ClusterInfo,
digest: Optional[str],
cluster_info: common.ClusterInfo,
ssh_credentials: Dict[str, Any]) -> List[Any]:
with futures.ThreadPoolExecutor(max_workers=32) as pool:
results = []
Expand Down Expand Up @@ -140,7 +141,7 @@ def _initialize_docker(runner: command_runner.SSHCommandRunner,
stage_name='initialize_docker',
# Should not cache docker setup, as it needs to be
# run every time a cluster is restarted.
digest=str(time.time()),
digest=None,
cluster_info=cluster_info,
ssh_credentials=ssh_credentials)
logger.debug(f'All docker users: {docker_users}')
Expand Down Expand Up @@ -372,8 +373,7 @@ def _internal_file_mounts(file_mounts: Dict,
@_log_start_end
def internal_file_mounts(cluster_name: str, common_file_mounts: Dict,
cluster_info: common.ClusterInfo,
ssh_credentials: Dict[str,
str], wheel_hash: str) -> None:
ssh_credentials: Dict[str, str]) -> None:
"""Executes file mounts - rsyncing internal local files"""
_hint_worker_log_path(cluster_name, cluster_info, 'internal_file_mounts')

Expand All @@ -382,9 +382,14 @@ def _setup_node(runner: command_runner.SSHCommandRunner,
del metadata
_internal_file_mounts(common_file_mounts, runner, log_path)

_parallel_ssh_with_cache(_setup_node,
cluster_name,
stage_name='internal_file_mounts',
digest=wheel_hash,
cluster_info=cluster_info,
ssh_credentials=ssh_credentials)
_parallel_ssh_with_cache(
_setup_node,
cluster_name,
stage_name='internal_file_mounts',
# Do not cache the file mounts, as the cloud
# credentials may change, and we should always
# update the remote files. The internal file_mounts
# is minimal and should not take too much time.
digest=None,
cluster_info=cluster_info,
ssh_credentials=ssh_credentials)
8 changes: 6 additions & 2 deletions sky/provision/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import pathlib
import shutil
from typing import Optional

from sky import sky_logging

Expand All @@ -30,7 +31,7 @@ def _get_instance_metadata_dir(cluster_name: str,


def cache_func(cluster_name: str, instance_id: str, stage_name: str,
hash_str: str):
hash_str: Optional[str]):
"""A helper function for caching function execution."""

def decorator(function):
Expand All @@ -51,8 +52,11 @@ def wrapper(*args, **kwargs):

@contextlib.contextmanager
def check_cache_hash_or_update(cluster_name: str, instance_id: str,
stage_name: str, hash_str: str):
stage_name: str, hash_str: Optional[str]):
"""A decorator for 'cache_func'."""
if hash_str is None:
yield True
return
path = get_instance_cache_dir(cluster_name, instance_id) / stage_name
if path.exists():
with open(path) as f:
Expand Down
17 changes: 3 additions & 14 deletions sky/provision/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import json
import os
import pathlib
import shlex
import socket
import subprocess
Expand Down Expand Up @@ -311,7 +310,6 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo,

def _post_provision_setup(
cloud_name: str, cluster_name: ClusterName, cluster_yaml: str,
local_wheel_path: pathlib.Path, wheel_hash: str,
provision_record: provision_common.ProvisionRecord,
custom_resource: Optional[str]) -> provision_common.ClusterInfo:
cluster_info = provision.get_cluster_info(cloud_name,
Expand Down Expand Up @@ -388,21 +386,15 @@ def _post_provision_setup(
# (3) all instances need permission to mount storage for all clouds
# It is possible to have a "smaller" permission model, but we leave that
# for later.
file_mounts = {
backend_utils.SKY_REMOTE_PATH + '/' + wheel_hash:
str(local_wheel_path),
**config_from_yaml.get('file_mounts', {})
}
file_mounts = config_from_yaml.get('file_mounts', {})

runtime_preparation_str = ('[bold cyan]Preparing SkyPilot '
'runtime ({step}/3 - {step_name})')
status.update(
runtime_preparation_str.format(step=1, step_name='initializing'))
instance_setup.internal_file_mounts(cluster_name.name_on_cloud,
file_mounts,
cluster_info,
ssh_credentials,
wheel_hash=wheel_hash)
file_mounts, cluster_info,
ssh_credentials)

status.update(
runtime_preparation_str.format(step=2, step_name='dependencies'))
Expand Down Expand Up @@ -464,7 +456,6 @@ def _post_provision_setup(

def post_provision_runtime_setup(
cloud_name: str, cluster_name: ClusterName, cluster_yaml: str,
local_wheel_path: pathlib.Path, wheel_hash: str,
provision_record: provision_common.ProvisionRecord,
custom_resource: Optional[str],
log_dir: str) -> provision_common.ClusterInfo:
Expand All @@ -483,8 +474,6 @@ def post_provision_runtime_setup(
return _post_provision_setup(cloud_name,
cluster_name,
cluster_yaml=cluster_yaml,
local_wheel_path=local_wheel_path,
wheel_hash=wheel_hash,
provision_record=provision_record,
custom_resource=custom_resource)
except Exception: # pylint: disable=broad-except
Expand Down
5 changes: 4 additions & 1 deletion sky/skylet/skylet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import time

import sky
from sky import sky_logging
from sky.skylet import constants
from sky.skylet import events

# Use the explicit logger name so that the logger is under the
# `sky.skylet.skylet` namespace when executed directly, so as
# to inherit the setup from the `sky` logger.
logger = sky_logging.init_logger('sky.skylet.skylet')
logger.info('skylet started')
logger.info(f'Skylet started with version {constants.SKYLET_VERSION}; '
f'SkyPilot v{sky.__version__} (commit: {sky.__commit__})')

EVENTS = [
events.AutostopEvent(),
Expand Down

0 comments on commit 1ec056e

Please sign in to comment.