diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index b0856ed1909..5a9663b1275 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -333,7 +333,7 @@ def wrap_file_mount(cls, path: str) -> str: def make_safe_symlink_command(cls, *, source: str, target: str) -> str: """Returns a command that safely symlinks 'source' to 'target'. - All intermediate directories of 'source' will be owned by $USER, + All intermediate directories of 'source' will be owned by $(whoami), excluding the root directory (/). 'source' must be an absolute path; both 'source' and 'target' must not @@ -360,17 +360,17 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str: target) # Below, use sudo in case the symlink needs sudo access to create. # Prepare to create the symlink: - # 1. make sure its dir(s) exist & are owned by $USER. + # 1. make sure its dir(s) exist & are owned by $(whoami). dir_of_symlink = os.path.dirname(source) commands = [ # mkdir, then loop over '/a/b/c' as /a, /a/b, /a/b/c. For each, - # chown $USER on it so user can use these intermediate dirs + # chown $(whoami) on it so user can use these intermediate dirs # (excluding /). f'sudo mkdir -p {dir_of_symlink}', # p: path so far ('(p=""; ' f'for w in $(echo {dir_of_symlink} | tr "/" " "); do ' - 'p=${p}/${w}; sudo chown $USER $p; done)') + 'p=${p}/${w}; sudo chown $(whoami) $p; done)') ] # 2. remove any existing symlink (ln -f may throw 'cannot # overwrite directory', if the link exists and points to a @@ -386,7 +386,7 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str: # Link. f'sudo ln -s {target} {source}', # chown. -h to affect symlinks only. 
- f'sudo chown -h $USER {source}', + f'sudo chown -h $(whoami) {source}', ] return ' && '.join(commands) @@ -1080,7 +1080,7 @@ def get_ready_nodes_counts(pattern, output): def get_docker_user(ip: str, cluster_config_file: str) -> str: """Find docker container username.""" ssh_credentials = ssh_credential_from_yaml(cluster_config_file) - runner = command_runner.SSHCommandRunner(ip, port=22, **ssh_credentials) + runner = command_runner.SSHCommandRunner(node=(ip, 22), **ssh_credentials) container_name = constants.DEFAULT_DOCKER_CONTAINER_NAME whoami_returncode, whoami_stdout, whoami_stderr = runner.run( f'sudo docker exec {container_name} whoami', @@ -1113,7 +1113,7 @@ def wait_until_ray_cluster_ready( try: head_ip = _query_head_ip_with_retries( cluster_config_file, max_attempts=WAIT_HEAD_NODE_IP_MAX_ATTEMPTS) - except exceptions.FetchIPError as e: + except exceptions.FetchClusterInfoError as e: logger.error(common_utils.format_exception(e)) return False, None # failed @@ -1129,8 +1129,7 @@ def wait_until_ray_cluster_ready( ssh_credentials = ssh_credential_from_yaml(cluster_config_file, docker_user) last_nodes_so_far = 0 start = time.time() - runner = command_runner.SSHCommandRunner(head_ip, - port=22, + runner = command_runner.SSHCommandRunner(node=(head_ip, 22), **ssh_credentials) with rich_utils.safe_status( '[bold cyan]Waiting for workers...') as worker_status: @@ -1236,7 +1235,7 @@ def ssh_credential_from_yaml( def parallel_data_transfer_to_nodes( - runners: List[command_runner.SSHCommandRunner], + runners: List[command_runner.CommandRunner], source: Optional[str], target: str, cmd: Optional[str], @@ -1246,32 +1245,36 @@ def parallel_data_transfer_to_nodes( # Advanced options. log_path: str = os.devnull, stream_logs: bool = False, + source_bashrc: bool = False, ): """Runs a command on all nodes and optionally runs rsync from src->dst. Args: - runners: A list of SSHCommandRunner objects that represent multiple nodes. 
+ runners: A list of CommandRunner objects that represent multiple nodes. source: Optional[str]; Source for rsync on local node target: str; Destination on remote node for rsync cmd: str; Command to be executed on all nodes action_message: str; Message to be printed while the command runs log_path: str; Path to the log file stream_logs: bool; Whether to stream logs to stdout + source_bashrc: bool; Source bashrc before running the command. """ fore = colorama.Fore style = colorama.Style origin_source = source - def _sync_node(runner: 'command_runner.SSHCommandRunner') -> None: + def _sync_node(runner: 'command_runner.CommandRunner') -> None: if cmd is not None: rc, stdout, stderr = runner.run(cmd, log_path=log_path, stream_logs=stream_logs, - require_outputs=True) + require_outputs=True, + source_bashrc=source_bashrc) err_msg = ('Failed to run command before rsync ' f'{origin_source} -> {target}. ' - 'Ensure that the network is stable, then retry.') + 'Ensure that the network is stable, then retry. ' + f'{cmd}') if log_path != os.devnull: err_msg += f' See logs in {log_path}' subprocess_utils.handle_returncode(rc, @@ -1336,7 +1339,7 @@ def _query_head_ip_with_retries(cluster_yaml: str, """Returns the IP of the head node by querying the cloud. Raises: - exceptions.FetchIPError: if we failed to get the head IP. + exceptions.FetchClusterInfoError: if we failed to get the head IP. """ backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) for i in range(max_attempts): @@ -1365,8 +1368,8 @@ def _query_head_ip_with_retries(cluster_yaml: str, break except subprocess.CalledProcessError as e: if i == max_attempts - 1: - raise exceptions.FetchIPError( - reason=exceptions.FetchIPError.Reason.HEAD) from e + raise exceptions.FetchClusterInfoError( + reason=exceptions.FetchClusterInfoError.Reason.HEAD) from e # Retry if the cluster is not up yet. 
logger.debug('Retrying to get head ip.') time.sleep(backoff.current_backoff()) @@ -1391,7 +1394,7 @@ def get_node_ips(cluster_yaml: str, IPs. Raises: - exceptions.FetchIPError: if we failed to get the IPs. e.reason is + exceptions.FetchClusterInfoError: if we failed to get the IPs. e.reason is HEAD or WORKER. """ ray_config = common_utils.read_yaml(cluster_yaml) @@ -1412,11 +1415,12 @@ def get_node_ips(cluster_yaml: str, 'Failed to get cluster info for ' f'{ray_config["cluster_name"]} from the new provisioner ' f'with {common_utils.format_exception(e)}.') - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.HEAD) from e + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) from e if len(metadata.instances) < expected_num_nodes: # Simulate the exception when Ray head node is not up. - raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD) + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) return metadata.get_feasible_ips(get_internal_ips) if get_internal_ips: @@ -1446,8 +1450,8 @@ def get_node_ips(cluster_yaml: str, break except subprocess.CalledProcessError as e: if retry_cnt == worker_ip_max_attempts - 1: - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.WORKER) from e + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.WORKER) from e # Retry if the ssh is not ready for the workers yet. 
backoff_time = backoff.current_backoff() logger.debug('Retrying to get worker ip ' @@ -1472,8 +1476,8 @@ def get_node_ips(cluster_yaml: str, f'detected IP(s): {worker_ips[-n:]}.') worker_ips = worker_ips[-n:] else: - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.WORKER) + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.WORKER) else: worker_ips = [] return head_ip_list + worker_ips @@ -1760,14 +1764,11 @@ def _update_cluster_status_no_lock( def run_ray_status_to_check_ray_cluster_healthy() -> bool: try: - # TODO(zhwu): This function cannot distinguish transient network - # error in ray's get IPs vs. ray runtime failing. - # NOTE: fetching the IPs is very slow as it calls into # `ray get head-ip/worker-ips`. Using cached IPs is safe because # in the worst case we time out in the `ray status` SSH command # below. - external_ips = handle.cached_external_ips + runners = handle.get_command_runners(force_cached=True) # This happens when user interrupt the `sky launch` process before # the first time resources handle is written back to local database. # This is helpful when user interrupt after the provision is done @@ -1775,27 +1776,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: # helps keep the cluster status to INIT after `sky status -r`, so # user will be notified that any auto stop/down might not be # triggered. - if external_ips is None or len(external_ips) == 0: + if not runners: logger.debug(f'Refreshing status ({cluster_name!r}): No cached ' f'IPs found. Handle: {handle}') - raise exceptions.FetchIPError( - reason=exceptions.FetchIPError.Reason.HEAD) - - # Potentially refresh the external SSH ports, in case the existing - # cluster before #2491 was launched without external SSH ports - # cached. - external_ssh_ports = handle.external_ssh_ports() - head_ssh_port = external_ssh_ports[0] - - # Check if ray cluster status is healthy. 
- ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml, - handle.docker_user, - handle.ssh_user) - - runner = command_runner.SSHCommandRunner(external_ips[0], - **ssh_credentials, - port=head_ssh_port) - rc, output, stderr = runner.run( + raise exceptions.FetchClusterInfoError( + reason=exceptions.FetchClusterInfoError.Reason.HEAD) + head_runner = runners[0] + rc, output, stderr = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, stream_logs=False, require_outputs=True, @@ -1815,7 +1802,7 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: f'Refreshing status ({cluster_name!r}): ray status not showing ' f'all nodes ({ready_head + ready_workers}/' f'{total_nodes}); output: {output}; stderr: {stderr}') - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: logger.debug( f'Refreshing status ({cluster_name!r}) failed to get IPs.') except RuntimeError as e: @@ -2356,9 +2343,9 @@ def is_controller_accessible( handle.docker_user, handle.ssh_user) - runner = command_runner.SSHCommandRunner(handle.head_ip, - **ssh_credentials, - port=handle.head_ssh_port) + runner = command_runner.SSHCommandRunner(node=(handle.head_ip, + handle.head_ssh_port), + **ssh_credentials) if not runner.check_connection(): error_msg = controller.value.connection_error_hint else: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 55b1ec4ab2a..5f1930123b0 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2,6 +2,7 @@ import base64 import copy import enum +import functools import getpass import inspect import json @@ -1478,10 +1479,12 @@ def _retry_zones( if zones and len(zones) == 1: launched_resources = launched_resources.copy(zone=zones[0].name) - prev_cluster_ips, prev_ssh_ports = None, None + prev_cluster_ips, prev_ssh_ports, prev_cluster_info = (None, None, + None) if prev_handle is not None: prev_cluster_ips = 
prev_handle.stable_internal_external_ips prev_ssh_ports = prev_handle.stable_ssh_ports + prev_cluster_info = prev_handle.cached_cluster_info # Record early, so if anything goes wrong, 'sky status' will show # the cluster name and users can appropriately 'sky down'. It also # means a second 'sky launch -c ' will attempt to reuse. @@ -1500,7 +1503,9 @@ def _retry_zones( # optimize the case where the cluster is restarted, i.e., no # need to query IPs and ports from the cloud provider. stable_internal_external_ips=prev_cluster_ips, - stable_ssh_ports=prev_ssh_ports) + stable_ssh_ports=prev_ssh_ports, + cluster_info=prev_cluster_info, + ) usage_lib.messages.usage.update_final_cluster_status( status_lib.ClusterStatus.INIT) @@ -1581,14 +1586,14 @@ def _retry_zones( # manually or by the cloud provider. # Optimize the case where the cluster's head IPs can be parsed # from the output of 'ray up'. - kwargs = {} if handle.launched_nodes == 1: - kwargs = { - 'internal_ips': [head_internal_ip], - 'external_ips': [head_external_ip] - } - handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS, - **kwargs) + handle.update_cluster_ips( + max_attempts=_FETCH_IP_MAX_ATTEMPTS, + internal_ips=[head_internal_ip], + external_ips=[head_external_ip]) + else: + handle.update_cluster_ips( + max_attempts=_FETCH_IP_MAX_ATTEMPTS) handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS) if cluster_exists: # Guard against the case where there's an existing cluster @@ -1994,10 +1999,9 @@ def provision_with_retries( requested_features = self._requested_features.copy() # Skip stop feature for Kubernetes jobs controller. 
- if isinstance(to_provision.cloud, clouds.Kubernetes - ) and controller_utils.Controllers.from_name( - cluster_name - ) == controller_utils.Controllers.JOBS_CONTROLLER: + if (isinstance(to_provision.cloud, clouds.Kubernetes) and + controller_utils.Controllers.from_name(cluster_name) + == controller_utils.Controllers.JOBS_CONTROLLER): assert (clouds.CloudImplementationFeatures.STOP in requested_features), requested_features requested_features.remove( @@ -2125,8 +2129,8 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle): - (optional) If TPU(s) are managed, a path to a deletion script. """ # Bump if any fields get added/removed/changed, and add backward - # compatibility logic in __setstate__. - _VERSION = 7 + # compatibility logic in __setstate__. + _VERSION = 8 def __init__( self, @@ -2139,6 +2143,7 @@ def __init__( stable_internal_external_ips: Optional[List[Tuple[str, str]]] = None, stable_ssh_ports: Optional[List[int]] = None, + cluster_info: Optional[provision_common.ClusterInfo] = None, # The following 2 fields are deprecated. SkyPilot new provisioner # API handles the TPU node creation/deletion. # Backward compatibility for TPU nodes created before #2943. @@ -2155,10 +2160,10 @@ def __init__( # internal or external ips, depending on the use_internal_ips flag. self.stable_internal_external_ips = stable_internal_external_ips self.stable_ssh_ports = stable_ssh_ports + self.cached_cluster_info = cluster_info self.launched_nodes = launched_nodes self.launched_resources = launched_resources self.docker_user: Optional[str] = None - self.ssh_user: Optional[str] = None # Deprecated. SkyPilot new provisioner API handles the TPU node # creation/deletion. # Backward compatibility for TPU nodes created before #2943. @@ -2222,13 +2227,8 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: Use this method to use any cloud-specific port fetching logic. """ del max_attempts # Unused. 
- if isinstance(self.launched_resources.cloud, clouds.RunPod): - cluster_info = provision_lib.get_cluster_info( - str(self.launched_resources.cloud).lower(), - region=self.launched_resources.region, - cluster_name_on_cloud=self.cluster_name_on_cloud, - provider_config=None) - self.stable_ssh_ports = cluster_info.get_ssh_ports() + if self.cached_cluster_info is not None: + self.stable_ssh_ports = self.cached_cluster_info.get_ssh_ports() return head_ssh_port = 22 @@ -2236,11 +2236,49 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: [head_ssh_port] + [22] * (self.num_ips_per_node * self.launched_nodes - 1)) + def _update_cluster_info(self): + # When a cluster is on a cloud that does not support the new + # provisioner, we should skip updating cluster_info. + if (self.launched_resources.cloud.PROVISIONER_VERSION >= + clouds.ProvisionerVersion.SKYPILOT): + provider_name = str(self.launched_resources.cloud).lower() + config = {} + if os.path.exists(self.cluster_yaml): + # It is possible that the cluster yaml is not available when + # the handle is unpickled for service replicas from the + # controller with older version. + config = common_utils.read_yaml(self.cluster_yaml) + try: + cluster_info = provision_lib.get_cluster_info( + provider_name, + region=self.launched_resources.region, + cluster_name_on_cloud=self.cluster_name_on_cloud, + provider_config=config.get('provider', None)) + except Exception as e: # pylint: disable=broad-except + # This could happen when the VM is not fully launched, and a + # user is trying to terminate it with `sky down`. 
+ logger.debug('Failed to get cluster info for ' + f'{self.cluster_name} from the new provisioner ' + f'with {common_utils.format_exception(e)}.') + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) from e + if cluster_info.num_instances != self.launched_nodes: + logger.debug( + f'Available nodes in the cluster {self.cluster_name} ' + 'do not match the number of nodes requested (' + f'{cluster_info.num_instances} != ' + f'{self.launched_nodes}).') + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) + self.cached_cluster_info = cluster_info + def update_cluster_ips( self, max_attempts: int = 1, internal_ips: Optional[List[Optional[str]]] = None, - external_ips: Optional[List[Optional[str]]] = None) -> None: + external_ips: Optional[List[Optional[str]]] = None, + cluster_info: Optional[provision_common.ClusterInfo] = None + ) -> None: """Updates the cluster IPs cached in the handle. We cache the cluster IPs in the handle to avoid having to retrieve @@ -2266,61 +2304,74 @@ def update_cluster_ips( external IPs from the cloud provider. Raises: - exceptions.FetchIPError: if we failed to get the IPs. e.reason is - HEAD or WORKER. + exceptions.FetchClusterInfoError: if we failed to get the cluster + infos. e.reason is HEAD or WORKER. """ - - def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: - return (ips is not None and - len(ips) == self.num_ips_per_node * self.launched_nodes and - all(ip is not None for ip in ips)) - - use_internal_ips = self._use_internal_ips() - - # cluster_feasible_ips is the list of IPs of the nodes in the cluster - # which can be used to connect to the cluster. It is a list of external - # IPs if the cluster is assigned public IPs, otherwise it is a list of - # internal IPs. 
- cluster_feasible_ips: List[str] - if is_provided_ips_valid(external_ips): - logger.debug(f'Using provided external IPs: {external_ips}') - cluster_feasible_ips = typing.cast(List[str], external_ips) + if cluster_info is not None: + self.cached_cluster_info = cluster_info + use_internal_ips = self._use_internal_ips() + cluster_feasible_ips = self.cached_cluster_info.get_feasible_ips( + use_internal_ips) + cluster_internal_ips = self.cached_cluster_info.get_feasible_ips( + force_internal_ips=True) else: - cluster_feasible_ips = backend_utils.get_node_ips( - self.cluster_yaml, - self.launched_nodes, - head_ip_max_attempts=max_attempts, - worker_ip_max_attempts=max_attempts, - get_internal_ips=use_internal_ips) - - if self.cached_external_ips == cluster_feasible_ips: - logger.debug('Skipping the fetching of internal IPs as the cached ' - 'external IPs matches the newly fetched ones.') - # Optimization: If the cached external IPs are the same as the - # retrieved feasible IPs, then we can skip retrieving internal - # IPs since the cached IPs are up-to-date. - return - logger.debug( - 'Cached external IPs do not match with the newly fetched ones: ' - f'cached ({self.cached_external_ips}), new ({cluster_feasible_ips})' - ) + # For clouds that do not support the SkyPilot Provisioner API. + # TODO(zhwu): once all the clouds are migrated to SkyPilot + # Provisioner API, we should remove this else block + def is_provided_ips_valid( + ips: Optional[List[Optional[str]]]) -> bool: + return (ips is not None and len(ips) + == self.num_ips_per_node * self.launched_nodes and + all(ip is not None for ip in ips)) + + use_internal_ips = self._use_internal_ips() + + # cluster_feasible_ips is the list of IPs of the nodes in the + # cluster which can be used to connect to the cluster. It is a list + # of external IPs if the cluster is assigned public IPs, otherwise + # it is a list of internal IPs. 
+ if is_provided_ips_valid(external_ips): + logger.debug(f'Using provided external IPs: {external_ips}') + cluster_feasible_ips = typing.cast(List[str], external_ips) + else: + cluster_feasible_ips = backend_utils.get_node_ips( + self.cluster_yaml, + self.launched_nodes, + head_ip_max_attempts=max_attempts, + worker_ip_max_attempts=max_attempts, + get_internal_ips=use_internal_ips) + + if self.cached_external_ips == cluster_feasible_ips: + logger.debug( + 'Skipping the fetching of internal IPs as the cached ' + 'external IPs matches the newly fetched ones.') + # Optimization: If the cached external IPs are the same as the + # retrieved feasible IPs, then we can skip retrieving internal + # IPs since the cached IPs are up-to-date. + return - if use_internal_ips: - # Optimization: if we know use_internal_ips is True (currently - # only exposed for AWS and GCP), then our provisioner is guaranteed - # to not assign public IPs, thus the first list of IPs returned - # above are already private IPs. So skip the second query. - cluster_internal_ips = list(cluster_feasible_ips) - elif is_provided_ips_valid(internal_ips): - logger.debug(f'Using provided internal IPs: {internal_ips}') - cluster_internal_ips = typing.cast(List[str], internal_ips) - else: - cluster_internal_ips = backend_utils.get_node_ips( - self.cluster_yaml, - self.launched_nodes, - head_ip_max_attempts=max_attempts, - worker_ip_max_attempts=max_attempts, - get_internal_ips=True) + logger.debug( + 'Cached external IPs do not match with the newly fetched ones: ' + f'cached ({self.cached_external_ips}), new ' + f'({cluster_feasible_ips})') + + if use_internal_ips: + # Optimization: if we know use_internal_ips is True (currently + # only exposed for AWS and GCP), then our provisioner is + # guaranteed to not assign public IPs, thus the first list of + # IPs returned above are already private IPs. So skip the second + # query. 
+ cluster_internal_ips = list(cluster_feasible_ips) + elif is_provided_ips_valid(internal_ips): + logger.debug(f'Using provided internal IPs: {internal_ips}') + cluster_internal_ips = typing.cast(List[str], internal_ips) + else: + cluster_internal_ips = backend_utils.get_node_ips( + self.cluster_yaml, + self.launched_nodes, + head_ip_max_attempts=max_attempts, + worker_ip_max_attempts=max_attempts, + get_internal_ips=True) assert len(cluster_feasible_ips) == len(cluster_internal_ips), ( f'Cluster {self.cluster_name!r}:' @@ -2339,6 +2390,39 @@ def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: internal_external_ips[1:], key=lambda x: x[1]) self.stable_internal_external_ips = stable_internal_external_ips + @functools.lru_cache() + @timeline.event + def get_command_runners(self, + force_cached: bool = False, + avoid_ssh_control: bool = False + ) -> List[command_runner.CommandRunner]: + """Returns a list of command runners for the cluster.""" + ssh_credentials = backend_utils.ssh_credential_from_yaml( + self.cluster_yaml, self.docker_user, self.ssh_user) + if avoid_ssh_control: + ssh_credentials.pop('ssh_control_name', None) + if (clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR >= + self.launched_resources.cloud.PROVISIONER_VERSION): + ip_list = (self.cached_external_ips + if force_cached else self.external_ips()) + if ip_list is None: + return [] + # Potentially refresh the external SSH ports, in case the existing + # cluster before #2491 was launched without external SSH ports + # cached. + port_list = self.external_ssh_ports() + runners = command_runner.SSHCommandRunner.make_runner_list( + zip(ip_list, port_list), **ssh_credentials) + return runners + if self.cached_cluster_info is None: + assert not force_cached, 'cached_cluster_info is None.' 
+ self._update_cluster_info() + assert self.cached_cluster_info is not None, self + runners = provision_lib.get_command_runners( + self.cached_cluster_info.provider_name, self.cached_cluster_info, + **ssh_credentials) + return runners + @property def cached_internal_ips(self) -> Optional[List[str]]: if self.stable_internal_external_ips is not None: @@ -2404,6 +2488,16 @@ def setup_docker_user(self, cluster_config_file: str): def cluster_yaml(self): return os.path.expanduser(self._cluster_yaml) + @property + def ssh_user(self): + if self.cached_cluster_info is not None: + # Overload ssh_user with the user stored in cluster_info, which is + # useful for kubernetes case, where the ssh_user can depend on the + # container image used. For those clusters launched with ray + # autoscaler, we directly use the ssh_user in yaml config. + return self.cached_cluster_info.ssh_user + return None + @property def head_ip(self): external_ips = self.cached_external_ips @@ -2449,8 +2543,8 @@ def __setstate__(self, state): if version < 6: state['cluster_name_on_cloud'] = state['cluster_name'] - if version < 7: - self.ssh_user = None + if version < 8: + self.cached_cluster_info = None self.__dict__.update(state) @@ -2460,7 +2554,7 @@ def __setstate__(self, state): if version < 3 and head_ip is not None: try: self.update_cluster_ips() - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: # This occurs when an old cluster from was autostopped, # so the head IP in the database is not updated. pass @@ -2469,6 +2563,14 @@ def __setstate__(self, state): self._update_cluster_region() + if version < 8: + try: + self._update_cluster_info() + except exceptions.FetchClusterInfoError: + # This occurs when an old cluster was autostopped, + # so its cluster info cannot be fetched. + pass + class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']): """Backend: runs on cloud virtual machines, managed by Ray. 
@@ -2760,22 +2862,17 @@ def _provision( provision_record=provision_record, custom_resource=resources_vars.get('custom_resources'), log_dir=self.log_dir) - # We must query the IPs from the cloud provider, when the - # provisioning is done, to make sure the cluster IPs are - # up-to-date. + # We use the IPs from the cluster_info to update_cluster_ips, + # when the provisioning is done, to make sure the cluster IPs + # are up-to-date. # The staled IPs may be caused by the node being restarted # manually or by the cloud provider. # Optimize the case where the cluster's IPs can be retrieved # from cluster_info. - internal_ips, external_ips = zip(*cluster_info.ip_tuples()) - if not cluster_info.has_external_ips(): - external_ips = internal_ips + handle.docker_user = cluster_info.docker_user handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS, - internal_ips=list(internal_ips), - external_ips=list(external_ips)) + cluster_info=cluster_info) handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS) - handle.docker_user = cluster_info.docker_user - handle.ssh_user = cluster_info.ssh_user # Update launched resources. handle.launched_resources = handle.launched_resources.copy( @@ -2807,11 +2904,7 @@ def _provision( handle.launched_resources.cloud.get_zone_shell_cmd()) # zone is None for Azure if get_zone_cmd is not None: - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, - handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=ssh_port_list, **ssh_credentials) + runners = handle.get_command_runners() def _get_zone(runner): retry_count = 0 @@ -2852,8 +2945,11 @@ def _get_zone(runner): logger.debug('Checking if skylet is running on the head node.') with rich_utils.safe_status( '[bold cyan]Preparing SkyPilot runtime'): + # We need to source bashrc for skylet to make sure the autostop + # event can access the path to the cloud CLIs. 
self.run_on_head(handle, - instance_setup.MAYBE_SKYLET_RESTART_CMD) + instance_setup.MAYBE_SKYLET_RESTART_CMD, + source_bashrc=True) self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, @@ -2952,7 +3048,6 @@ def _sync_workdir(self, handle: CloudVmRayResourceHandle, fore = colorama.Fore style = colorama.Style ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() assert ip_list is not None, 'external_ips is not cached in handle' full_workdir = os.path.abspath(os.path.expanduser(workdir)) @@ -2977,14 +3072,10 @@ def _sync_workdir(self, handle: CloudVmRayResourceHandle, log_path = os.path.join(self.log_dir, 'workdir_sync.log') - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - # TODO(zhwu): refactor this with backend_utils.parallel_cmd_with_rsync - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() - def _sync_workdir_node(runner: command_runner.SSHCommandRunner) -> None: + def _sync_workdir_node(runner: command_runner.CommandRunner) -> None: runner.rsync( source=workdir, target=SKY_REMOTE_WORKDIR, @@ -3031,29 +3122,18 @@ def _setup(self, handle: CloudVmRayResourceHandle, task: task_lib.Task, return setup = task.setup # Sync the setup script up and run it. - ip_list = handle.external_ips() internal_ips = handle.internal_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - # Disable connection sharing for setup script to avoid old - # connections being reused, which may cause stale ssh agent - # forwarding. 
- ssh_credentials.pop('ssh_control_name', None) - remote_setup_file_name = f'/tmp/sky_setup_{self.run_timestamp}' # Need this `-i` option to make sure `source ~/.bashrc` work setup_cmd = f'/bin/bash -i {remote_setup_file_name} 2>&1' + runners = handle.get_command_runners(avoid_ssh_control=True) def _setup_node(node_id: int) -> None: setup_envs = task.envs.copy() setup_envs.update(self._skypilot_predefined_env_vars(handle)) setup_envs['SKYPILOT_SETUP_NODE_IPS'] = '\n'.join(internal_ips) setup_envs['SKYPILOT_SETUP_NODE_RANK'] = str(node_id) - runner = command_runner.SSHCommandRunner(ip_list[node_id], - port=port_list[node_id], - **ssh_credentials) + runner = runners[node_id] setup_script = log_lib.make_task_bash_script(setup, env_vars=setup_envs) with tempfile.NamedTemporaryFile('w', prefix='sky_setup_') as f: @@ -3067,11 +3147,12 @@ def _setup_node(node_id: int) -> None: if detach_setup: return setup_log_path = os.path.join(self.log_dir, - f'setup-{runner.ip}.log') + f'setup-{runner.node_id}.log') returncode = runner.run( setup_cmd, log_path=setup_log_path, process_stream=False, + source_bashrc=True, ) def error_message() -> str: @@ -3101,7 +3182,7 @@ def error_message() -> str: command=setup_cmd, error_msg=error_message) - num_nodes = len(ip_list) + num_nodes = len(runners) plural = 's' if num_nodes > 1 else '' if not detach_setup: logger.info(f'{fore.CYAN}Running setup on {num_nodes} node{plural}.' @@ -3167,12 +3248,8 @@ def _exec_code_on_head( # execute it. # We use 120KB as a threshold to be safe for other arguments that # might be added during ssh. 
- ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - head_ssh_port = handle.head_ssh_port - runner = command_runner.SSHCommandRunner(handle.head_ip, - port=head_ssh_port, - **ssh_credentials) + runners = handle.get_command_runners() + head_runner = runners[0] with tempfile.NamedTemporaryFile('w', prefix='sky_app_') as fp: fp.write(codegen) fp.flush() @@ -3181,10 +3258,10 @@ def _exec_code_on_head( # We choose to sync code + exec, because the alternative of 'ray # submit' may not work as it may use system python (python2) to # execute the script. Happens for AWS. - runner.rsync(source=fp.name, - target=script_path, - up=True, - stream_logs=False) + head_runner.rsync(source=fp.name, + target=script_path, + up=True, + stream_logs=False) job_submit_cmd = f'{mkdir_code} && {code}' if managed_job_dag is not None: @@ -3540,15 +3617,7 @@ def sync_down_logs( logger.info(f'{fore.CYAN}Job {job_id} logs: {log_dir}' f'{style.RESET_ALL}') - ip_list = handle.external_ips() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_port_list = handle.external_ssh_ports() - assert ssh_port_list is not None, 'external_ssh_ports is not cached ' \ - 'in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=ssh_port_list, **ssh_credentials) + runners = handle.get_command_runners() def _rsync_down(args) -> None: """Rsync down logs from remote nodes. @@ -3742,7 +3811,7 @@ def teardown_no_lock(self, # even when the command was executed successfully. self.run_on_head(handle, f'{constants.SKY_RAY_CMD} stop --force') - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: # This error is expected if the previous cluster IP is # failed to be found, # i.e., the cluster is already stopped/terminated. 
@@ -4066,7 +4135,8 @@ def set_autostop(self, # cloud and resources support requested autostop. if idle_minutes_to_autostop is not None: # Skip auto-stop for Kubernetes clusters. - if isinstance(handle.launched_resources.cloud, clouds.Kubernetes): + if (isinstance(handle.launched_resources.cloud, clouds.Kubernetes) + and not down and idle_minutes_to_autostop >= 0): # We should hit this code path only for the jobs controller on # Kubernetes clusters. assert (controller_utils.Controllers.from_name( @@ -4137,6 +4207,7 @@ def run_on_head( require_outputs: bool = False, separate_stderr: bool = False, process_stream: bool = True, + source_bashrc: bool = False, **kwargs, ) -> Union[int, Tuple[int, str, str]]: """Runs 'cmd' on the cluster's head node. @@ -4162,6 +4233,9 @@ def run_on_head( process_stream: Whether to post-process the stdout/stderr of the command, such as replacing or skipping lines on the fly. If enabled, lines are printed only when '\r' or '\n' is found. + source_bashrc: Whether to source bashrc when running on the command + on the VM. If it is a user-related commands, it would always be + good to source bashrc to make sure the env vars are set. Returns: returncode @@ -4169,24 +4243,17 @@ def run_on_head( A tuple of (returncode, stdout, stderr). Raises: - exceptions.FetchIPError: If the head node IP cannot be fetched. + exceptions.FetchClusterInfoError: If the cluster info cannot be + fetched. """ # This will try to fetch the head node IP if it is not cached. 
- external_ips = handle.external_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS) - head_ip = external_ips[0] - external_ssh_ports = handle.external_ssh_ports( - max_attempts=_FETCH_IP_MAX_ATTEMPTS) - head_ssh_port = external_ssh_ports[0] - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runner = command_runner.SSHCommandRunner(head_ip, - port=head_ssh_port, - **ssh_credentials) + runners = handle.get_command_runners() + head_runner = runners[0] if under_remote_workdir: cmd = f'cd {SKY_REMOTE_WORKDIR} && {cmd}' - return runner.run( + return head_runner.run( cmd, port_forward=port_forward, log_path=log_path, @@ -4195,6 +4262,7 @@ def run_on_head( ssh_mode=ssh_mode, require_outputs=require_outputs, separate_stderr=separate_stderr, + source_bashrc=source_bashrc, **kwargs, ) @@ -4338,13 +4406,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, style = colorama.Style logger.info(f'{fore.CYAN}Processing file mounts.{style.RESET_ALL}') start = time.time() - ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'file_mounts.log') # Check the files and warn @@ -4440,12 +4502,21 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, action_message='Syncing', log_path=log_path, stream_logs=False, + # Need to source bashrc, as the cloud specific CLI or SDK may + # require PATH in bashrc. + source_bashrc=True, ) # (2) Run the commands to create symlinks on all the nodes. 
symlink_command = ' && '.join(symlink_commands) if symlink_command: - - def _symlink_node(runner: command_runner.SSHCommandRunner): + # ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD sets sudo to empty string for + # root. We need this as we do not source bashrc for the command for + # better performance, and our sudo handling is only in bashrc. + symlink_command = ( + f'{command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} && ' + f'{symlink_command}') + + def _symlink_node(runner: command_runner.CommandRunner): returncode = runner.run(symlink_command, log_path=log_path) subprocess_utils.handle_returncode( returncode, symlink_command, @@ -4485,13 +4556,7 @@ def _execute_storage_mounts( logger.info(f'{fore.CYAN}Processing {len(storage_mounts)} ' f'storage mount{plural}.{style.RESET_ALL}') start = time.time() - ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'storage_mounts.log') for dst, storage_obj in storage_mounts.items(): @@ -4522,6 +4587,9 @@ def _execute_storage_mounts( run_rsync=False, action_message='Mounting', log_path=log_path, + # Need to source bashrc, as the cloud specific CLI or SDK + # may require PATH in bashrc. 
+ source_bashrc=True, ) except exceptions.CommandError as e: if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE: diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py index 2f4e37a1b66..043799e5ab5 100644 --- a/sky/data/mounting_utils.py +++ b/sky/data/mounting_utils.py @@ -4,6 +4,7 @@ from typing import Optional from sky import exceptions +from sky.utils import command_runner # Values used to construct mounting commands _STAT_CACHE_TTL = '5s' @@ -129,6 +130,8 @@ def get_mounting_script( script = textwrap.dedent(f""" #!/usr/bin/env bash set -e + + {command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} MOUNT_PATH={mount_path} MOUNT_BINARY={mount_binary} diff --git a/sky/exceptions.py b/sky/exceptions.py index e3b33ea3e5e..4fced20ce4e 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -190,8 +190,8 @@ class StorageExternalDeletionError(StorageBucketGetError): pass -class FetchIPError(Exception): - """Raised when fetching the IP fails.""" +class FetchClusterInfoError(Exception): + """Raised when fetching the cluster info fails.""" class Reason(enum.Enum): HEAD = 'HEAD' diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 2f9a5bda44c..8371fb8ad83 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -5,7 +5,7 @@ """ import functools import inspect -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from sky import sky_logging from sky import status_lib @@ -21,6 +21,7 @@ from sky.provision import kubernetes from sky.provision import runpod from sky.provision import vsphere +from sky.utils import command_runner logger = sky_logging.init_logger(__name__) @@ -42,7 +43,7 @@ def _wrapper(*args, **kwargs): assert module is not None, f'Unknown provider: {module_name}' impl = getattr(module, func.__name__, None) - if impl: + if impl is not None: return impl(*args, **kwargs) # If implementation does not exist, fall back to default implementation @@ -175,3 +176,18 @@ def 
get_cluster_info( provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """Get the metadata of instances in a cluster.""" raise NotImplementedError + + +@_route_to_cloud_impl +def get_command_runners( + provider_name: str, + cluster_info: common.ClusterInfo, + **crednetials: Dict[str, Any], +) -> List[command_runner.CommandRunner]: + """Get a command runner for the given cluster.""" + ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() + return command_runner.SSHCommandRunner.make_runner_list( + node_list=zip(ip_list, port_list), + **crednetials, + ) diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index bdf1650665f..e279b30c74b 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -843,7 +843,6 @@ def get_cluster_info( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - del provider_config # unused ec2 = _default_ec2_resource(region) filters = [ { @@ -875,4 +874,6 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='aws', + provider_config=provider_config, ) diff --git a/sky/provision/common.py b/sky/provision/common.py index dbcb9e659e6..7c1bcb32652 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -104,6 +104,10 @@ class ClusterInfo: # The unique identifier of the head instance, i.e., the # `instance_info.instance_id` of the head node. head_instance_id: Optional[InstanceId] + # Provider related information. + provider_name: str + provider_config: Optional[Dict[str, Any]] = None + docker_user: Optional[str] = None # Override the ssh_user from the cluster config. 
ssh_user: Optional[str] = None @@ -151,6 +155,19 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: other_ips.append(pair) return head_instance_ip + other_ips + def instance_ids(self) -> List[str]: + """Return the instance ids in the same order of ip_tuples.""" + id_list = [] + if self.head_instance_id is not None: + id_list.append(self.head_instance_id + '-0') + for inst_id, instances in self.instances.items(): + start_idx = 0 + if inst_id == self.head_instance_id: + start_idx = 1 + id_list.extend( + [f'{inst_id}-{i}' for i in range(start_idx, len(instances))]) + return id_list + def has_external_ips(self) -> bool: """True if the cluster has external IP.""" ip_tuples = self.ip_tuples() @@ -186,8 +203,10 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: def get_ssh_ports(self) -> List[int]: """Get the SSH port of all the instances.""" head_instance = self.get_head_instance() - assert head_instance is not None, self - head_instance_port = [head_instance.ssh_port] + + head_instance_port = [] + if head_instance is not None: + head_instance_port = [head_instance.ssh_port] worker_instances = self.get_worker_instances() worker_instance_ports = [ diff --git a/sky/provision/cudo/instance.py b/sky/provision/cudo/instance.py index e4a2db722e4..39d4bc6b3d1 100644 --- a/sky/provision/cudo/instance.py +++ b/sky/provision/cudo/instance.py @@ -162,7 +162,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config + del region nodes = _filter_instances(cluster_name_on_cloud, ['runn', 'pend']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -178,10 +178,10 @@ def get_cluster_info( if node_info['name'].endswith('-head'): head_instance_id = node_id - return common.ClusterInfo( - instances=instances, - head_instance_id=head_instance_id, - ) + return common.ClusterInfo(instances=instances, + 
head_instance_id=head_instance_id, + provider_name='cudo', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index 8de7beab2e7..10ae5dafc07 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -3,16 +3,13 @@ import dataclasses import shlex import time -import typing from typing import Any, Dict, List from sky import sky_logging from sky.skylet import constants +from sky.utils import command_runner from sky.utils import subprocess_utils -if typing.TYPE_CHECKING: - from sky.utils import command_runner - logger = sky_logging.init_logger(__name__) DOCKER_PERMISSION_DENIED_STR = ('permission denied while trying to connect to ' @@ -117,7 +114,7 @@ class DockerInitializer: """Initializer for docker containers on a remote node.""" def __init__(self, docker_config: Dict[str, Any], - runner: 'command_runner.SSHCommandRunner', log_path: str): + runner: 'command_runner.CommandRunner', log_path: str): self.docker_config = docker_config self.container_name = docker_config['container_name'] self.runner = runner @@ -255,7 +252,8 @@ def initialize(self) -> str: # Disable apt-get from asking user input during installation. # see https://askubuntu.com/questions/909277/avoiding-user-interaction-with-tzdata-when-installing-certbot-in-a-docker-contai # pylint: disable=line-too-long self._run( - 'echo \'[ "$(whoami)" == "root" ] && alias sudo=""\' >> ~/.bashrc;' + f'echo \'{command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD}\' ' + '>> ~/.bashrc;' 'echo "export DEBIAN_FRONTEND=noninteractive" >> ~/.bashrc;', run_env='docker') # Install dependencies. 
diff --git a/sky/provision/fluidstack/instance.py b/sky/provision/fluidstack/instance.py index 2c0d836fadc..b37519a8458 100644 --- a/sky/provision/fluidstack/instance.py +++ b/sky/provision/fluidstack/instance.py @@ -273,7 +273,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['running']) instances: Dict[str, List[common.InstanceInfo]] = {} @@ -296,7 +296,9 @@ def get_cluster_info( return common.ClusterInfo(instances=instances, head_instance_id=head_instance_id, - custom_ray_options={'use_external_ip': True}) + custom_ray_options={'use_external_ip': True}, + provider_name='fluidstack', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 35c8ae44dc8..a4996fc4d4b 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -428,6 +428,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='gcp', + provider_config=provider_config, ) diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 81e13f54fd0..1e5e6285fef 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -8,6 +8,7 @@ import time from typing import Any, Dict, List, Optional, Tuple +from sky import provision from sky import sky_logging from sky.provision import common from sky.provision import docker_utils @@ -123,24 +124,22 @@ def _parallel_ssh_with_cache(func, max_workers = subprocess_utils.get_parallel_threads() with futures.ThreadPoolExecutor(max_workers=max_workers) as pool: results = [] - for instance_id, metadatas in cluster_info.instances.items(): - for i, metadata in enumerate(metadatas): - cache_id = f'{instance_id}-{i}' - runner = 
command_runner.SSHCommandRunner( - metadata.get_feasible_ip(), - port=metadata.ssh_port, - **ssh_credentials) - wrapper = metadata_utils.cache_func(cluster_name, cache_id, - stage_name, digest) - if (cluster_info.head_instance_id == instance_id and i == 0): - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) - else: - log_dir_abs = metadata_utils.get_instance_log_dir( - cluster_name, cache_id) - log_path_abs = str(log_dir_abs / (stage_name + '.log')) - results.append( - pool.submit(wrapper(func), runner, metadata, log_path_abs)) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + # instance_ids is guaranteed to be in the same order as runners. + instance_ids = cluster_info.instance_ids() + for i, runner in enumerate(runners): + cache_id = instance_ids[i] + wrapper = metadata_utils.cache_func(cluster_name, cache_id, + stage_name, digest) + if i == 0: + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + else: + log_dir_abs = metadata_utils.get_instance_log_dir( + cluster_name, cache_id) + log_path_abs = str(log_dir_abs / (stage_name + '.log')) + results.append(pool.submit(wrapper(func), runner, log_path_abs)) return [future.result() for future in results] @@ -155,9 +154,7 @@ def initialize_docker(cluster_name: str, docker_config: Dict[str, Any], _hint_worker_log_path(cluster_name, cluster_info, 'initialize_docker') @_auto_retry - def _initialize_docker(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata # Unused. 
+ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): docker_user = docker_utils.DockerInitializer(docker_config, runner, log_path).initialize() logger.debug(f'Initialized docker user: {docker_user}') @@ -194,14 +191,16 @@ def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], digest = hasher.hexdigest() @_auto_retry - def _setup_node(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata + def _setup_node(runner: command_runner.CommandRunner, log_path: str): for cmd in setup_commands: - returncode, stdout, stderr = runner.run(cmd, - stream_logs=False, - log_path=log_path, - require_outputs=True) + returncode, stdout, stderr = runner.run( + cmd, + stream_logs=False, + log_path=log_path, + require_outputs=True, + # Installing depencies requires source bashrc to access the PATH + # in bashrc. + source_bashrc=True) retry_cnt = 0 while returncode == 255 and retry_cnt < _MAX_RETRY: # Got network connection issue occur during setup. 
This could @@ -215,7 +214,8 @@ def _setup_node(runner: command_runner.SSHCommandRunner, returncode, stdout, stderr = runner.run(cmd, stream_logs=False, log_path=log_path, - require_outputs=True) + require_outputs=True, + source_bashrc=True) if not returncode: break @@ -256,11 +256,9 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the head node.""" - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() - ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] assert cluster_info.head_instance_id is not None, (cluster_name, cluster_info) @@ -297,10 +295,14 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND) logger.info(f'Running command on head node: {cmd}') # TODO(zhwu): add the output to log files. - returncode, stdout, stderr = ssh_runner.run(cmd, - stream_logs=False, - log_path=log_path_abs, - require_outputs=True) + returncode, stdout, stderr = head_runner.run( + cmd, + stream_logs=False, + log_path=log_path_abs, + require_outputs=True, + # Source bashrc for starting ray cluster to make sure actors started by + # ray will have the correct PATH. + source_bashrc=True) if returncode: raise RuntimeError('Failed to start ray on the head node ' f'(exit code {returncode}). 
Error: \n' @@ -318,11 +320,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, if cluster_info.num_instances <= 1: return _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') - ip_list = cluster_info.get_feasible_ips() - ssh_runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list[1:], - port_list=cluster_info.get_ssh_ports()[1:], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + worker_runners = runners[1:] worker_instances = cluster_info.get_worker_instances() cache_ids = [] prev_instance_id = None @@ -374,11 +374,11 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || ' f'{{ {cmd} }}') else: - cmd = 'ray stop; ' + cmd + cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd logger.info(f'Running command on worker nodes: {cmd}') - def _setup_ray_worker(runner_and_id: Tuple[command_runner.SSHCommandRunner, + def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, str]): # for cmd in config_from_yaml['worker_start_ray_commands']: # cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0]) @@ -386,13 +386,17 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.SSHCommandRunner, runner, instance_id = runner_and_id log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id) log_path_abs = str(log_dir / ('ray_cluster' + '.log')) - return runner.run(cmd, - stream_logs=False, - require_outputs=True, - log_path=log_path_abs) + return runner.run( + cmd, + stream_logs=False, + require_outputs=True, + log_path=log_path_abs, + # Source bashrc for starting ray cluster to make sure actors started + # by ray will have the correct PATH. 
+ source_bashrc=True) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(ssh_runners, cache_ids))) + _setup_ray_worker, list(zip(worker_runners, cache_ids))) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): @@ -410,18 +414,19 @@ def start_skylet_on_head_node(cluster_name: str, ssh_credentials: Dict[str, Any]) -> None: """Start skylet on the head node.""" del cluster_name - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() - ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] assert cluster_info.head_instance_id is not None, cluster_info log_path_abs = str(provision_logging.get_log_path()) logger.info(f'Running command on head node: {MAYBE_SKYLET_RESTART_CMD}') - returncode, stdout, stderr = ssh_runner.run(MAYBE_SKYLET_RESTART_CMD, - stream_logs=False, - require_outputs=True, - log_path=log_path_abs) + # We need to source bashrc for skylet to make sure the autostop event can + # access the path to the cloud CLIs. + returncode, stdout, stderr = head_runner.run(MAYBE_SKYLET_RESTART_CMD, + stream_logs=False, + require_outputs=True, + log_path=log_path_abs, + source_bashrc=True) if returncode: raise RuntimeError('Failed to start skylet on the head node ' f'(exit code {returncode}). 
Error: ' @@ -431,7 +436,7 @@ def start_skylet_on_head_node(cluster_name: str, @_auto_retry def _internal_file_mounts(file_mounts: Dict, - runner: command_runner.SSHCommandRunner, + runner: command_runner.CommandRunner, log_path: str) -> None: if file_mounts is None or not file_mounts: return @@ -493,9 +498,7 @@ def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], """Executes file mounts - rsyncing internal local files""" _hint_worker_log_path(cluster_name, cluster_info, 'internal_file_mounts') - def _setup_node(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata + def _setup_node(runner: command_runner.CommandRunner, log_path: str): _internal_file_mounts(common_file_mounts, runner, log_path) _parallel_ssh_with_cache( diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 9129d39e586..9068079701f 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -730,7 +730,9 @@ def get_cluster_info( custom_ray_options={ 'object-store-memory': 500000000, 'num-cpus': cpu_request, - }) + }, + provider_name='kubernetes', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 4c26c0c2199..3b3608947ad 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -715,23 +715,18 @@ def parse_memory_resource(resource_qty_str: str, class KubernetesInstanceType: """Class to represent the "Instance Type" in a Kubernetes. - Since Kubernetes does not have a notion of instances, we generate virtual instance types that represent the resources requested by a pod ("node"). - This name captures the following resource requests: - CPU - Memory - Accelerators - The name format is "{n}CPU--{k}GB" where n is the number of vCPUs and k is the amount of memory in GB. 
Accelerators can be specified by appending "--{a}{type}" where a is the number of accelerators and type is the accelerator type. - CPU and memory can be specified as floats. Accelerator count must be int. - Examples: - 4CPU--16GB - 0.5CPU--1.5GB @@ -770,7 +765,6 @@ def _parse_instance_type( cls, name: str) -> Tuple[float, float, Optional[int], Optional[str]]: """Parses and returns resources from the given InstanceType name - Returns: cpus | float: Number of CPUs memory | float: Amount of memory in GB @@ -815,7 +809,6 @@ def from_resources(cls, accelerator_count: Union[float, int] = 0, accelerator_type: str = '') -> 'KubernetesInstanceType': """Returns an instance name object from the given resources. - If accelerator_count is not an int, it will be rounded up since GPU requests in Kubernetes must be int. """ diff --git a/sky/provision/paperspace/instance.py b/sky/provision/paperspace/instance.py index 12c581c8314..ce1a4768c24 100644 --- a/sky/provision/paperspace/instance.py +++ b/sky/provision/paperspace/instance.py @@ -251,7 +251,7 @@ def get_cluster_info( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None, ) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['ready']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -271,6 +271,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='paperspace', + provider_config=provider_config, ) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 764d197493a..df9a9fcc58a 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -24,7 +24,6 @@ from sky.provision import logging as provision_logging from sky.provision import metadata_utils from sky.skylet import constants -from sky.utils import command_runner from sky.utils import common_utils from sky.utils 
import rich_utils from sky.utils import ux_utils @@ -444,8 +443,6 @@ def _post_provision_setup( 'status with: sky status -r; and retry provisioning.') # TODO(suquark): Move wheel build here in future PRs. - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() # We don't set docker_user here, as we are configuring the VM itself. ssh_credentials = backend_utils.ssh_credential_from_yaml( cluster_yaml, ssh_user=cluster_info.ssh_user) @@ -505,9 +502,9 @@ def _post_provision_setup( cluster_name.name_on_cloud, config_from_yaml['setup_commands'], cluster_info, ssh_credentials) - head_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cloud_name, cluster_info, + **ssh_credentials) + head_runner = runners[0] status.update( runtime_preparation_str.format(step=3, step_name='runtime')) @@ -544,7 +541,7 @@ def _post_provision_setup( # if provision_record.is_instance_just_booted(inst.instance_id): # worker_ips.append(inst.public_ip) - if len(ip_list) > 1: + if cluster_info.num_instances > 1: instance_setup.start_ray_on_worker_nodes( cluster_name.name_on_cloud, no_restart=not full_ray_setup, diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 3ae99dae8d5..d7cb20b57a6 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -154,7 +154,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -174,6 +174,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='runpod', + provider_config=provider_config, ) diff --git 
a/sky/provision/vsphere/instance.py b/sky/provision/vsphere/instance.py index 69a544210b3..787d8c97f62 100644 --- a/sky/provision/vsphere/instance.py +++ b/sky/provision/vsphere/instance.py @@ -571,8 +571,6 @@ def get_cluster_info( cluster_name: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - if provider_config: - del provider_config # unused logger.info('New provision of Vsphere: get_cluster_info().') # Init the vsphere client @@ -610,4 +608,6 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='vsphere', + provider_config=provider_config, ) diff --git a/sky/serve/core.py b/sky/serve/core.py index 9680b90de0c..79aa53f7b58 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -474,7 +474,7 @@ def down( code, require_outputs=True, stream_logs=False) - except exceptions.FetchIPError as e: + except exceptions.FetchClusterInfoError as e: raise RuntimeError( 'Failed to fetch controller IP. Please refresh controller status ' f'by `sky status -r {serve_utils.SKY_SERVE_CONTROLLER_NAME}` ' diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 22e86778570..c63b42cc438 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -17,6 +17,7 @@ from sky.jobs import utils as managed_job_utils from sky.serve import serve_utils from sky.skylet import autostop_lib +from sky.skylet import constants from sky.skylet import job_lib from sky.utils import cluster_yaml_utils from sky.utils import common_utils @@ -197,16 +198,19 @@ def _stop_cluster(self, autostop_config): logger.info('Running ray down.') # Stop the workers first to avoid orphan workers. subprocess.run( - ['ray', 'down', '-y', '--workers-only', config_path], + f'{constants.SKY_RAY_CMD} down -y --workers-only ' + f'{config_path}', check=True, + shell=True, # We pass env inherited from os.environ due to calling `ray # `. 
env=env) logger.info('Running final ray down.') subprocess.run( - ['ray', 'down', '-y', config_path], + f'{constants.SKY_RAY_CMD} down -y {config_path}', check=True, + shell=True, # We pass env inherited from os.environ due to calling `ray # `. env=env) @@ -228,7 +232,7 @@ def _stop_cluster_with_new_provisioner(self, autostop_config, # Stop the ray autoscaler to avoid scaling up, during # stopping/terminating of the cluster. logger.info('Stopping the ray cluster.') - subprocess.run('ray stop', shell=True, check=True) + subprocess.run(f'{constants.SKY_RAY_CMD} stop', shell=True, check=True) operation_fn = provision_lib.stop_instances if autostop_config.down: diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index ceed5a26024..93bbe99b3ce 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -883,7 +883,7 @@ def tail_logs(cls, follow: bool = True) -> str: # pylint: disable=line-too-long code = [ - f'job_id = {job_id} if {job_id} is not None else job_lib.get_latest_job_id()', + f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()', 'run_timestamp = job_lib.get_run_timestamp(job_id)', f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)', f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, ' diff --git a/sky/skylet/providers/lambda_cloud/node_provider.py b/sky/skylet/providers/lambda_cloud/node_provider.py index 8a9c5997a0b..bb8d40da62e 100644 --- a/sky/skylet/providers/lambda_cloud/node_provider.py +++ b/sky/skylet/providers/lambda_cloud/node_provider.py @@ -155,9 +155,10 @@ def _get_internal_ip(node: Dict[str, Any]): if node['external_ip'] is None or node['status'] != 'active': node['internal_ip'] = None return - runner = command_runner.SSHCommandRunner(node['external_ip'], - 'ubuntu', - self.ssh_key_path) + runner = command_runner.SSHCommandRunner( + node=(node['external_ip'], 22), + ssh_user='ubuntu', + ssh_private_key=self.ssh_key_path) rc, stdout, stderr = 
runner.run(_GET_INTERNAL_IP_CMD, require_outputs=True, stream_logs=False) diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index c66b5dfe032..3aa87eda138 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -5,13 +5,14 @@ import pathlib import shlex import time -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, Union from sky import sky_logging from sky.skylet import constants from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import subprocess_utils +from sky.utils import timeline logger = sky_logging.init_logger(__name__) @@ -41,6 +42,12 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]: return path +# Disable sudo for root user. This is useful when the command is running in a +# docker container, i.e. image_id is a docker image. +ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD = ( + '{ [ "$(whoami)" == "root" ] && function sudo() { "$@"; } || true; }') + + def ssh_options_list( ssh_private_key: Optional[str], ssh_control_name: Optional[str], @@ -134,17 +141,155 @@ class SshMode(enum.Enum): LOGIN = 2 -class SSHCommandRunner: +class CommandRunner: + """Runner for commands to be executed on the cluster.""" + + def __init__(self, node: Tuple[Any, Any], **kwargs): + del kwargs # Unused. + self.node = node + + @property + def node_id(self) -> str: + return '-'.join(str(x) for x in self.node) + + def _get_command_to_run( + self, + cmd: Union[str, List[str]], + process_stream: bool, + separate_stderr: bool, + skip_lines: int, + source_bashrc: bool = False, + ) -> str: + """Returns the command to run.""" + if isinstance(cmd, list): + cmd = ' '.join(cmd) + + # We need this to correctly run the cmd, and get the output. + command = [ + 'bash', + '--login', + '-c', + ] + if source_bashrc: + command += [ + # Need this `-i` option to make sure `source ~/.bashrc` work. 
+ # Sourcing bashrc may take a few seconds causing overheads. + '-i', + shlex.quote( + f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' + f'PYTHONWARNINGS=ignore && ({cmd})'), + ] + else: + # Optimization: this reduces the time for connecting to the remote + # cluster by 1 second. + # sourcing ~/.bashrc is not required for internal executions + command += [ + 'true && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore' + f' && ({cmd})' + ] + if not separate_stderr: + command.append('2>&1') + if not process_stream and skip_lines: + command += [ + # A hack to remove the following bash warnings (twice): + # bash: cannot set terminal process group + # bash: no job control in this shell + f'| stdbuf -o0 tail -n +{skip_lines}', + # This is required to make sure the executor of command can get + # correct returncode, since linux pipe is used. + '; exit ${PIPESTATUS[0]}' + ] + + command_str = ' '.join(command) + return command_str + + @timeline.event + def run( + self, + cmd: Union[str, List[str]], + *, + require_outputs: bool = False, + # Advanced options. + log_path: str = os.devnull, + # If False, do not redirect stdout/stderr to optimize performance. + process_stream: bool = True, + stream_logs: bool = True, + ssh_mode: SshMode = SshMode.NON_INTERACTIVE, + separate_stderr: bool = False, + source_bashrc: bool = False, + **kwargs) -> Union[int, Tuple[int, str, str]]: + """Runs the command on the cluster. + + Args: + cmd: The command to run. + require_outputs: Whether to return the stdout/stderr of the command. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + ssh_mode: The mode to use for ssh. + See SSHMode for more details. + separate_stderr: Whether to separate stderr from stdout. + + Returns: + returncode + or + A tuple of (returncode, stdout, stderr). + """ + raise NotImplementedError + + @timeline.event + def rsync( + self, + source: str, + target: str, + *, + up: bool, + # Advanced options. 
+ log_path: str = os.devnull, + stream_logs: bool = True, + max_retry: int = 1, + ) -> None: + """Uses 'rsync' to sync 'source' to 'target'. + + Args: + source: The source path. + target: The target path. + up: The direction of the sync, True for local to cluster, False + for cluster to local. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + max_retry: The maximum number of retries for the rsync command. + This value should be non-negative. + + Raises: + exceptions.CommandError: rsync command failed. + """ + raise NotImplementedError + + @classmethod + def make_runner_list( + cls: Type['CommandRunner'], + node_list: Iterable[Any], + **kwargs, + ) -> List['CommandRunner']: + """Helper function for creating runners with the same credentials""" + return [cls(node, **kwargs) for node in node_list] + + def check_connection(self) -> bool: + """Check if the connection to the remote machine is successful.""" + returncode = self.run('true', connect_timeout=5, stream_logs=False) + return returncode == 0 + + +class SSHCommandRunner(CommandRunner): """Runner for SSH commands.""" def __init__( self, - ip: str, + node: Tuple[str, int], ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = '__default__', ssh_proxy_command: Optional[str] = None, - port: int = 22, docker_user: Optional[str] = None, disable_control_master: Optional[bool] = False, ): @@ -156,7 +301,7 @@ def __init__( runner.rsync(source, target, up=True) Args: - ip: The IP address of the remote machine. + node: (ip, port) The IP address and port of the remote machine. ssh_private_key: The path to the private key to use for ssh. ssh_user: The user to use for ssh. ssh_control_name: The files name of the ssh_control to use. This is @@ -174,6 +319,8 @@ def __init__( command will utilize ControlMaster. We currently disable it for k8s instance. 
""" + super().__init__(node) + ip, port = node self.ssh_private_key = ssh_private_key self.ssh_control_name = ( None if ssh_control_name is None else hashlib.md5( @@ -198,27 +345,6 @@ def __init__( self.port = port self._docker_ssh_proxy_command = None - @staticmethod - def make_runner_list( - ip_list: List[str], - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = None, - ssh_proxy_command: Optional[str] = None, - disable_control_master: Optional[bool] = False, - port_list: Optional[List[int]] = None, - docker_user: Optional[str] = None, - ) -> List['SSHCommandRunner']: - """Helper function for creating runners with the same ssh credentials""" - if not port_list: - port_list = [22] * len(ip_list) - return [ - SSHCommandRunner(ip, ssh_user, ssh_private_key, ssh_control_name, - ssh_proxy_command, port, docker_user, - disable_control_master) - for ip, port in zip(ip_list, port_list) - ] - def _ssh_base_command(self, *, ssh_mode: SshMode, port_forward: Optional[List[int]], connect_timeout: Optional[int]) -> List[str]: @@ -251,6 +377,7 @@ def _ssh_base_command(self, *, ssh_mode: SshMode, f'{self.ssh_user}@{self.ip}' ] + @timeline.event def run( self, cmd: Union[str, List[str]], @@ -265,11 +392,11 @@ def run( ssh_mode: SshMode = SshMode.NON_INTERACTIVE, separate_stderr: bool = False, connect_timeout: Optional[int] = None, + source_bashrc: bool = False, **kwargs) -> Union[int, Tuple[int, str, str]]: """Uses 'ssh' to run 'cmd' on a node with ip. Args: - ip: The IP address of the node. cmd: The command to run. port_forward: A list of ports to forward from the localhost to the remote host. 
@@ -299,39 +426,20 @@ def run( command = base_ssh_command + cmd proc = subprocess_utils.run(command, shell=False, check=False) return proc.returncode, '', '' - if isinstance(cmd, list): - cmd = ' '.join(cmd) + + command_str = self._get_command_to_run( + cmd, + process_stream, + separate_stderr, + # A hack to remove the following bash warnings (twice): + # bash: cannot set terminal process group + # bash: no job control in this shell + skip_lines=5 if source_bashrc else 0, + source_bashrc=source_bashrc) + command = base_ssh_command + [shlex.quote(command_str)] log_dir = os.path.expanduser(os.path.dirname(log_path)) os.makedirs(log_dir, exist_ok=True) - # We need this to correctly run the cmd, and get the output. - command = [ - 'bash', - '--login', - '-c', - # Need this `-i` option to make sure `source ~/.bashrc` work. - '-i', - ] - - command += [ - shlex.quote(f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' - f'PYTHONWARNINGS=ignore && ({cmd})'), - ] - if not separate_stderr: - command.append('2>&1') - if not process_stream and ssh_mode == SshMode.NON_INTERACTIVE: - command += [ - # A hack to remove the following bash warnings (twice): - # bash: cannot set terminal process group - # bash: no job control in this shell - '| stdbuf -o0 tail -n +5', - # This is required to make sure the executor of command can get - # correct returncode, since linux pipe is used. 
- '; exit ${PIPESTATUS[0]}' - ] - - command_str = ' '.join(command) - command = base_ssh_command + [shlex.quote(command_str)] executable = None if not process_stream: @@ -354,6 +462,7 @@ def run( executable=executable, **kwargs) + @timeline.event def rsync( self, source: str, @@ -456,10 +565,3 @@ def rsync( error_msg, stderr=stdout + stderr, stream_logs=stream_logs) - - def check_connection(self) -> bool: - """Check if the connection to the remote machine is successful.""" - returncode = self.run('true', connect_timeout=5, stream_logs=False) - if returncode: - return False - return True diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index f1b547927d3..e8f12ef6ebe 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -6,7 +6,7 @@ determine the return type based on the value of require_outputs. """ import enum import typing -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union from typing_extensions import Literal @@ -18,6 +18,7 @@ GIT_EXCLUDE: str RSYNC_DISPLAY_OPTION: str RSYNC_FILTER_OPTION: str RSYNC_EXCLUDE_OPTION: str +ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD: str def ssh_options_list( @@ -39,40 +40,91 @@ class SshMode(enum.Enum): LOGIN: int -class SSHCommandRunner: +class CommandRunner: + node_id: str + + def __init__( + self, + node: Tuple[Any, ...], + **kwargs, + ) -> None: + ... + + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: Literal[False] = ..., + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> int: + ... + + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: Literal[True], + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> Tuple[int, str, str]: + ... 
+ + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: bool = ..., + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> Union[Tuple[int, str, str], int]: + ... + + def rsync(self, + source: str, + target: str, + *, + up: bool, + log_path: str = ..., + stream_logs: bool = ...) -> None: + ... + + @classmethod + def make_runner_list(cls: typing.Type[CommandRunner], + node_list: Iterable[Tuple[Any, ...]], + **kwargs) -> List[CommandRunner]: + ... + + def check_connection(self) -> bool: + ... + + +class SSHCommandRunner(CommandRunner): ip: str + port: int ssh_user: str ssh_private_key: str ssh_control_name: Optional[str] docker_user: str - port: int disable_control_master: Optional[bool] def __init__( self, - ip: str, + node: Tuple[str, int], ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = ..., - port: int = ..., docker_user: Optional[str] = ..., disable_control_master: Optional[bool] = ..., ) -> None: ... - @staticmethod - def make_runner_list( - ip_list: List[str], - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = ..., - ssh_proxy_command: Optional[str] = ..., - port_list: Optional[List[int]] = ..., - docker_user: Optional[str] = ..., - disable_control_master: Optional[bool] = ..., - ) -> List['SSHCommandRunner']: - ... - @typing.overload def run(self, cmd: Union[str, List[str]], @@ -123,6 +175,3 @@ class SSHCommandRunner: log_path: str = ..., stream_logs: bool = ...) -> None: ... - - def check_connection(self) -> bool: - ... diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index b5b9027de65..0dc78d8427c 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -452,9 +452,9 @@ def class_fullname(cls, skip_builtins: bool = True): """Get the full name of a class. 
Example: - >>> e = sky.exceptions.FetchIPError() + >>> e = sky.exceptions.FetchClusterInfoError() >>> class_fullname(e.__class__) - 'sky.exceptions.FetchIPError' + 'sky.exceptions.FetchClusterInfoError' Args: cls: The class to get the full name. diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 340c9f9826e..2156057953c 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -73,7 +73,8 @@ s=$(sky launch --cloud ${CLOUD} -d -c ${CLUSTER_NAME} examples/minimal.yaml) sky logs ${CLUSTER_NAME} 2 --status | grep RUNNING || exit 1 # remove color and find the job id echo "$s" | sed -r "s/\x1B\[([0-9]{1,3}(;[0-9]{1,2})?)?[mGK]//g" | grep "Job ID: 4" || exit 1 -sleep 45 +# wait for ready +sky logs ${CLUSTER_NAME} 2 q=$(sky queue ${CLUSTER_NAME}) echo "$q" echo "$q" | grep "SUCCEEDED" | wc -l | grep 4 || exit 1 diff --git a/tests/skyserve/auto_restart.yaml b/tests/skyserve/auto_restart.yaml index 2a3a31051b9..f7dc2a13f07 100644 --- a/tests/skyserve/auto_restart.yaml +++ b/tests/skyserve/auto_restart.yaml @@ -7,7 +7,6 @@ service: resources: ports: 8080 - cloud: gcp cpus: 2+ workdir: examples/serve/http_server diff --git a/tests/test_smoke.py b/tests/test_smoke.py index c0469abb109..7dbc1e2e86d 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1920,8 +1920,9 @@ def test_azure_start_stop(): f'sky start -y {name} -i 1', f'sky exec {name} examples/azure_start_stop.yaml', f'sky logs {name} 3 --status', # Ensure the job succeeded. - 'sleep 200', + 'sleep 260', f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', timeout=30 * 60, # 30 mins @@ -2933,9 +2934,12 @@ def test_azure_start_stop_two_nodes(): f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. 
f'sky stop -y {name}', - f'sky start -y {name}', + f'sky start -y {name} -i 1', f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 2 --status', # Ensure the job succeeded. + 'sleep 200', + f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', timeout=30 * 60, # 30 mins (it takes around ~23 mins)