From ee811e32f8c5ae58f7cdff40392000f69e66d765 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 10 Mar 2024 01:17:29 -0800 Subject: [PATCH] [Clouds] Remove old node providers for GCP and Kubernetes (#3287) * remove unused providers * Remove code that should be removed in 0.5 * lint * remove interactive nodes * remove test for interactive node * remove cpu/gpunode in comments * fix node type * Address comments --- .github/workflows/format.yml | 5 +- format.sh | 3 - sky/__init__.py | 2 - sky/backends/backend_utils.py | 6 +- sky/backends/cloud_vm_ray_backend.py | 15 +- sky/cli.py | 697 +--------- sky/core.py | 11 - sky/global_user_state.py | 4 +- sky/skylet/constants.py | 2 - sky/skylet/providers/gcp/__init__.py | 2 - sky/skylet/providers/gcp/config.py | 1152 ----------------- sky/skylet/providers/gcp/constants.py | 138 -- sky/skylet/providers/gcp/node.py | 905 ------------- sky/skylet/providers/gcp/node_provider.py | 400 ------ sky/skylet/providers/kubernetes/__init__.py | 1 - sky/skylet/providers/kubernetes/config.py | 337 ----- .../providers/kubernetes/node_provider.py | 657 ---------- sky/spot/controller.py | 1 - sky/spot/spot_utils.py | 3 - sky/task.py | 2 +- sky/utils/common_utils.py | 1 - tests/test_cli.py | 33 - 22 files changed, 24 insertions(+), 4353 deletions(-) delete mode 100644 sky/skylet/providers/gcp/__init__.py delete mode 100644 sky/skylet/providers/gcp/config.py delete mode 100644 sky/skylet/providers/gcp/constants.py delete mode 100644 sky/skylet/providers/gcp/node.py delete mode 100644 sky/skylet/providers/gcp/node_provider.py delete mode 100644 sky/skylet/providers/kubernetes/__init__.py delete mode 100644 sky/skylet/providers/kubernetes/config.py delete mode 100644 sky/skylet/providers/kubernetes/node_provider.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 9352681fc64..f1259f422f8 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -33,13 +33,11 @@ jobs: - name: Running yapf run: | yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \ - --exclude 'sky/skylet/providers/gcp/**' \ --exclude 'sky/skylet/providers/azure/**' \ --exclude 'sky/skylet/providers/ibm/**' - name: Running black run: | - black --diff --check sky/skylet/providers/gcp/ \ - sky/skylet/providers/azure/ \ + black --diff --check sky/skylet/providers/azure/ \ sky/skylet/providers/ibm/ - name: Running isort for black formatted files run: | @@ -48,6 +46,5 @@ jobs: - name: Running isort for yapf formatted files run: | isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \ - --sg 'sky/skylet/providers/gcp/**' \ --sg 'sky/skylet/providers/azure/**' \ --sg 'sky/skylet/providers/ibm/**' diff --git a/format.sh b/format.sh index 2cb4ecbc545..e3bcfde0f18 100755 --- a/format.sh +++ b/format.sh @@ -48,20 +48,17 @@ YAPF_FLAGS=( YAPF_EXCLUDES=( '--exclude' 'build/**' - '--exclude' 'sky/skylet/providers/gcp/**' '--exclude' 'sky/skylet/providers/azure/**' '--exclude' 'sky/skylet/providers/ibm/**' ) ISORT_YAPF_EXCLUDES=( '--sg' 'build/**' - '--sg' 'sky/skylet/providers/gcp/**' '--sg' 'sky/skylet/providers/azure/**' '--sg' 'sky/skylet/providers/ibm/**' ) BLACK_INCLUDES=( - 'sky/skylet/providers/gcp' 'sky/skylet/providers/azure' 'sky/skylet/providers/ibm' ) diff --git a/sky/__init__.py b/sky/__init__.py index ae4bc99b69d..a2068315131 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -51,7 +51,6 @@ def get_git_commit(): from sky.core import queue from sky.core import spot_cancel from sky.core import spot_queue -from sky.core import 
spot_status from sky.core import start from sky.core import status from sky.core import stop @@ -135,7 +134,6 @@ def get_git_commit(): 'job_status', # core APIs Spot Job Management 'spot_queue', - 'spot_status', # Deprecated (alias for spot_queue) 'spot_cancel', # core APIs Storage Management 'storage_ls', diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 9d8c8366f8e..cb507d6121c 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -947,11 +947,7 @@ def write_cluster_config( with open(tmp_yaml_path, 'w', encoding='utf-8') as f: f.write(restored_yaml_content) - # Read the cluster name from the tmp yaml file, to take the backward - # compatbility restortion above into account. - # TODO: remove this after 2 minor releases, 0.5.0. - yaml_config = common_utils.read_yaml(tmp_yaml_path) - config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name'] + config_dict['cluster_name_on_cloud'] = cluster_name_on_cloud # Optimization: copy the contents of source files in file_mounts to a # special dir, and upload that as the only file_mount instead. Delay diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 41d5f0977d4..7a1492b2a7b 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -539,10 +539,6 @@ def add_ray_task(self, sky_env_vars_dict_str += [ f'sky_env_vars_dict[{constants.TASK_ID_ENV_VAR!r}]' f' = {job_run_id!r}', - # TODO(zhwu): remove this deprecated env var in later release - # (after 0.5). - f'sky_env_vars_dict[{constants.TASK_ID_ENV_VAR_DEPRECATED!r}]' - f' = {job_run_id!r}' ] sky_env_vars_dict_str = '\n'.join(sky_env_vars_dict_str) @@ -3433,13 +3429,10 @@ def cancel_jobs(self, code = job_lib.JobLibCodeGen.cancel_jobs(jobs, cancel_all) # All error messages should have been redirected to stdout. - returncode, stdout, stderr = self.run_on_head(handle, - code, - stream_logs=False, - require_outputs=True) - # TODO(zongheng): remove after >=0.5.0, 2 minor versions after. - backend_utils.check_stale_runtime_on_remote(returncode, stdout + stderr, - handle.cluster_name) + returncode, stdout, _ = self.run_on_head(handle, + code, + stream_logs=False, + require_outputs=True) subprocess_utils.handle_returncode( returncode, code, f'Failed to cancel jobs on cluster {handle.cluster_name}.', stdout) diff --git a/sky/cli.py b/sky/cli.py index 70993a8e0e7..f0b43c0b1ad 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -66,7 +66,6 @@ from sky.skylet import job_lib from sky.skylet import log_lib from sky.usage import usage_lib -from sky.utils import command_runner from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils @@ -89,22 +88,6 @@ A cluster name. If provided, either reuse an existing cluster with that name or provision a new cluster with that name. Otherwise provision a new cluster with an autogenerated name.""" -_INTERACTIVE_NODE_TYPES = ('cpunode', 'gpunode', 'tpunode') -_INTERACTIVE_NODE_DEFAULT_RESOURCES = { - 'cpunode': sky.Resources(cloud=None, - instance_type=None, - accelerators=None, - use_spot=False), - 'gpunode': sky.Resources(cloud=None, - instance_type=None, - accelerators={'K80': 1}, - use_spot=False), - 'tpunode': sky.Resources(cloud=sky.GCP(), - instance_type=None, - accelerators={'tpu-v2-8': 1}, - accelerator_args={'runtime_version': '2.12.0'}, - use_spot=False), -} # The maximum number of in-progress spot jobs to show in the status # command. 
@@ -144,186 +127,6 @@ def _get_glob_storages(storages: List[str]) -> List[str]: return list(set(glob_storages)) -def _interactive_node_cli_command(cli_func): - """Click command decorator for interactive node commands.""" - assert cli_func.__name__ in _INTERACTIVE_NODE_TYPES, cli_func.__name__ - - cluster_option = click.option('--cluster', - '-c', - default=None, - type=str, - required=False, - help=_CLUSTER_FLAG_HELP) - port_forward_option = click.option( - '--port-forward', - '-p', - multiple=True, - default=[], - type=int, - required=False, - help=('Port to be forwarded. To forward multiple ports, ' - 'use this option multiple times.')) - screen_option = click.option('--screen', - default=False, - is_flag=True, - help='If true, attach using screen.') - tmux_option = click.option('--tmux', - default=False, - is_flag=True, - help='If true, attach using tmux.') - cloud_option = click.option('--cloud', - default=None, - type=str, - help='Cloud provider to use.') - instance_type_option = click.option('--instance-type', - '-t', - default=None, - type=str, - help='Instance type to use.') - cpus = click.option( - '--cpus', - default=None, - type=str, - help=('Number of vCPUs each instance must have ' - '(e.g., ``--cpus=4`` (exactly 4) or ``--cpus=4+`` (at least 4)). ' - 'This is used to automatically select the instance type.')) - memory = click.option( - '--memory', - default=None, - type=str, - required=False, - help=('Amount of memory each instance must have in GB (e.g., ' - '``--memory=16`` (exactly 16GB), ``--memory=16+`` (at least ' - '16GB))')) - gpus = click.option('--gpus', - default=None, - type=str, - help=('Type and number of GPUs to use ' - '(e.g., ``--gpus=V100:8`` or ``--gpus=V100``).')) - tpus = click.option( - '--tpus', - default=None, - type=str, - help=('Type and number of TPUs to use (e.g., ``--tpus=tpu-v3-8:4`` or ' - '``--tpus=tpu-v3-8``).')) - - spot_option = click.option('--use-spot', - default=None, - is_flag=True, - help='If true, use spot instances.') - - tpuvm_option = click.option('--tpu-vm/--no-tpu-vm', - default=True, - is_flag=True, - help='If true, use TPU VMs.') - - disk_size = click.option('--disk-size', - default=None, - type=int, - required=False, - help=('OS disk size in GBs.')) - disk_tier = click.option('--disk-tier', - default=None, - type=click.Choice( - resources_utils.DiskTier.supported_tiers(), - case_sensitive=False), - required=False, - help=resources_utils.DiskTier.cli_help_message()) - ports = click.option( - '--ports', - required=False, - type=str, - multiple=True, - help=('Ports to open on the cluster. ' - 'If specified, overrides the "ports" config in the YAML. '), - ) - no_confirm = click.option('--yes', - '-y', - is_flag=True, - default=False, - required=False, - help='Skip confirmation prompt.') - idle_autostop = click.option('--idle-minutes-to-autostop', - '-i', - default=None, - type=int, - required=False, - help=('Automatically stop the cluster after ' - 'this many minutes of idleness, i.e. ' - 'no running or pending jobs in the ' - 'cluster\'s job queue. Idleness gets ' - 'reset whenever setting-up/running/' - 'pending jobs are found in the job ' - 'queue. If not set, the cluster ' - 'will not be auto-stopped.')) - autodown = click.option('--down', - default=False, - is_flag=True, - required=False, - help=('Autodown the cluster: tear down the ' - 'cluster after all jobs finish ' - '(successfully or abnormally). If ' - '--idle-minutes-to-autostop is also set, ' - 'the cluster will be torn down after the ' - 'specified idle time. 
Note that if errors ' - 'occur during provisioning/data syncing/' - 'setting up, the cluster will not be torn ' - 'down for debugging purposes.')) - retry_until_up = click.option('--retry-until-up', - '-r', - is_flag=True, - default=False, - required=False, - help=('Whether to retry provisioning ' - 'infinitely until the cluster is up ' - 'if we fail to launch the cluster on ' - 'any possible region/cloud due to ' - 'unavailability errors.')) - region_option = click.option('--region', - default=None, - type=str, - required=False, - help='The region to use.') - zone_option = click.option('--zone', - default=None, - type=str, - required=False, - help='The zone to use.') - - click_decorators = [ - cli.command(cls=_DocumentedCodeCommand), - cluster_option, - no_confirm, - port_forward_option, - idle_autostop, - autodown, - retry_until_up, - - # Resource options - *([cloud_option] if cli_func.__name__ != 'tpunode' else []), - region_option, - zone_option, - instance_type_option, - cpus, - memory, - disk_size, - disk_tier, - ports, - *([gpus] if cli_func.__name__ == 'gpunode' else []), - *([tpus] if cli_func.__name__ == 'tpunode' else []), - spot_option, - *([tpuvm_option] if cli_func.__name__ == 'tpunode' else []), - - # Attach options - screen_option, - tmux_option, - ] - decorator = functools.reduce(lambda res, f: f(res), - reversed(click_decorators), cli_func) - - return decorator - - def _parse_env_var(env_var: str) -> Tuple[str, str]: """Parse env vars into a (KEY, VAL) pair.""" if '=' not in env_var: @@ -725,67 +528,6 @@ def _parse_override_params( return override_params -def _default_interactive_node_name(node_type: str): - """Returns a deterministic name to refer to the same node.""" - # FIXME: this technically can collide in Azure/GCP with another - # same-username user. E.g., sky-gpunode-ubuntu. Not a problem on AWS - # which is the current cloud for interactive nodes. - assert node_type in _INTERACTIVE_NODE_TYPES, node_type - return f'sky-{node_type}-{common_utils.get_cleaned_username()}' - - -def _infer_interactive_node_type(resources: sky.Resources): - """Determine interactive node type from resources.""" - accelerators = resources.accelerators - cloud = resources.cloud - if accelerators: - # We only support homogenous accelerators for now. - assert len(accelerators) == 1, resources - acc, _ = list(accelerators.items())[0] - is_gcp = cloud is not None and cloud.is_same_cloud(sky.GCP()) - if is_gcp and 'tpu' in acc: - return 'tpunode' - return 'gpunode' - return 'cpunode' - - -def _check_resources_match(backend: backends.Backend, - cluster_name: str, - task: 'sky.Task', - node_type: Optional[str] = None) -> None: - """Check matching resources when reusing an existing cluster. - - The only exception is when [cpu|tpu|gpu]node -c cluster_name is used with no - additional arguments, then login succeeds. - - Args: - cluster_name: The name of the cluster. - task: The task requested to be run on the cluster. - node_type: Only used for interactive node. Node type to attach to VM. 
- """ - handle = global_user_state.get_handle_from_cluster_name(cluster_name) - if handle is None: - return - - if node_type is not None: - assert isinstance(handle, - backends.CloudVmRayResourceHandle), (node_type, - handle) - inferred_node_type = _infer_interactive_node_type( - handle.launched_resources) - if node_type != inferred_node_type: - name_arg = '' - if cluster_name != _default_interactive_node_name( - inferred_node_type): - name_arg = f' -c {cluster_name}' - raise click.UsageError( - f'Failed to attach to interactive node {cluster_name}. ' - f'Please use: {colorama.Style.BRIGHT}' - f'sky {inferred_node_type}{name_arg}{colorama.Style.RESET_ALL}') - return - backend.check_resources_fit_cluster(handle, task) - - def _launch_with_confirm( task: sky.Task, backend: backends.Backend, @@ -799,7 +541,6 @@ def _launch_with_confirm( down: bool = False, # pylint: disable=redefined-outer-name retry_until_up: bool = False, no_setup: bool = False, - node_type: Optional[str] = None, clone_disk_from: Optional[str] = None, ): """Launch a cluster with a Task.""" @@ -815,7 +556,7 @@ def _launch_with_confirm( with sky.Dag() as dag: dag.add(task) - maybe_status, _ = backend_utils.refresh_cluster_status_handle(cluster) + maybe_status, handle = backend_utils.refresh_cluster_status_handle(cluster) if maybe_status is None: # Show the optimize log before the prompt if the cluster does not exist. try: @@ -828,7 +569,8 @@ def _launch_with_confirm( dag = sky.optimize(dag) task = dag.tasks[0] - _check_resources_match(backend, cluster, task, node_type=node_type) + if handle is not None: + backend.check_resources_fit_cluster(handle, task) confirm_shown = False if not no_confirm: @@ -846,161 +588,23 @@ def _launch_with_confirm( confirm_shown = True click.confirm(prompt, default=True, abort=True, show_default=True) - if node_type is not None: - if maybe_status != status_lib.ClusterStatus.UP: - click.secho(f'Setting up interactive node {cluster}...', - fg='yellow') - - # We do not sky.launch if interactive node is already up, so we need - # to update idle timeout and autodown here. - elif idle_minutes_to_autostop is not None: - core.autostop(cluster, idle_minutes_to_autostop, down) - elif down: - core.autostop(cluster, 1, down) - - elif not confirm_shown: + if not confirm_shown: click.secho(f'Running task on cluster {cluster}...', fg='yellow') - if node_type is None or maybe_status != status_lib.ClusterStatus.UP: - # No need to sky.launch again when interactive node is already up. - sky.launch( - dag, - dryrun=dryrun, - stream_logs=True, - cluster_name=cluster, - detach_setup=detach_setup, - detach_run=detach_run, - backend=backend, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - retry_until_up=retry_until_up, - no_setup=no_setup, - clone_disk_from=clone_disk_from, - ) - - -# TODO: skip installing ray to speed up provisioning. -def _create_and_ssh_into_node( - node_type: str, - resources: sky.Resources, - cluster_name: str, - backend: Optional['backend_lib.Backend'] = None, - port_forward: Optional[List[int]] = None, - session_manager: Optional[str] = None, - user_requested_resources: Optional[bool] = False, - no_confirm: bool = False, - idle_minutes_to_autostop: Optional[int] = None, - down: bool = False, # pylint: disable=redefined-outer-name - retry_until_up: bool = False, -): - """Creates and attaches to an interactive node. - - Args: - node_type: Type of the interactive node: { 'cpunode', 'gpunode' }. - resources: Resources to attach to VM. 
- cluster_name: a cluster name to identify the interactive node. - backend: the Backend to use (currently only CloudVmRayBackend). - port_forward: List of ports to forward. - session_manager: Attach session manager: { 'screen', 'tmux' }. - user_requested_resources: If true, user requested resources explicitly. - no_confirm: If true, skips confirmation prompt presented to user. - idle_minutes_to_autostop: Automatically stop the cluster after - specified minutes of idleness. Idleness gets - reset whenever setting-up/running/pending - jobs are found in the job queue. - down: If true, autodown the cluster after all jobs finish. If - idle_minutes_to_autostop is also set, the cluster will be torn - down after the specified idle time. - retry_until_up: Whether to retry provisioning infinitely until the - cluster is up if we fail to launch due to - unavailability errors. - """ - assert node_type in _INTERACTIVE_NODE_TYPES, node_type - assert session_manager in (None, 'screen', 'tmux'), session_manager - - backend = backend if backend is not None else backends.CloudVmRayBackend() - if not isinstance(backend, backends.CloudVmRayBackend): - raise click.UsageError('Interactive nodes are only supported for ' - f'{backends.CloudVmRayBackend.__name__} ' - f'backend. Got {type(backend).__name__}.') - - maybe_status, handle = backend_utils.refresh_cluster_status_handle( - cluster_name) - if maybe_status is not None: - if user_requested_resources: - if not resources.less_demanding_than(handle.launched_resources): - name_arg = '' - if cluster_name != _default_interactive_node_name(node_type): - name_arg = f' -c {cluster_name}' - raise click.UsageError( - f'Relaunching interactive node {cluster_name!r} with ' - 'mismatched resources.\n ' - f'Requested resources: {resources}\n ' - f'Launched resources: {handle.launched_resources}\n' - 'To login to existing cluster, use ' - f'{colorama.Style.BRIGHT}sky {node_type}{name_arg}' - f'{colorama.Style.RESET_ALL}. To launch a new cluster, ' - f'use {colorama.Style.BRIGHT}sky {node_type} -c NEW_NAME ' - f'{colorama.Style.RESET_ALL}') - else: - # Use existing interactive node if it exists and no user - # resources were specified. - resources = handle.launched_resources - - # TODO: Add conda environment replication - # should be setup = - # 'conda env export | grep -v "^prefix: " > environment.yml' - # && conda env create -f environment.yml - task = sky.Task( - node_type, - workdir=None, - setup=None, - ) - task.set_resources(resources) - - _launch_with_confirm( - task, - backend, - cluster_name, - dryrun=False, - detach_run=True, - no_confirm=no_confirm, + sky.launch( + dag, + dryrun=dryrun, + stream_logs=True, + cluster_name=cluster, + detach_setup=detach_setup, + detach_run=detach_run, + backend=backend, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, retry_until_up=retry_until_up, - node_type=node_type, + no_setup=no_setup, + clone_disk_from=clone_disk_from, ) - handle = global_user_state.get_handle_from_cluster_name(cluster_name) - assert isinstance(handle, backends.CloudVmRayResourceHandle), handle - - # Use ssh rather than 'ray attach' to suppress ray messages, speed up - # connection, and for allowing adding 'cd workdir' in the future. - # Disable check, since the returncode could be non-zero if the user Ctrl-D. 
- commands = [] - if session_manager == 'screen': - commands += ['screen', '-D', '-R'] - elif session_manager == 'tmux': - commands += ['tmux', 'attach', '||', 'tmux', 'new'] - backend.run_on_head(handle, - commands, - port_forward=port_forward, - ssh_mode=command_runner.SshMode.LOGIN) - cluster_name = handle.cluster_name - - click.echo('To attach to it again: ', nl=False) - if cluster_name == _default_interactive_node_name(node_type): - option = '' - else: - option = f' -c {cluster_name}' - click.secho(f'sky {node_type}{option}', bold=True) - click.echo('To stop the node:\t', nl=False) - click.secho(f'sky stop {cluster_name}', bold=True) - click.echo('To tear down the node:\t', nl=False) - click.secho(f'sky down {cluster_name}', bold=True) - click.echo('To upload a folder:\t', nl=False) - click.secho(f'rsync -rP /local/path {cluster_name}:/remote/path', bold=True) - click.echo('To download a folder:\t', nl=False) - click.secho(f'rsync -rP {cluster_name}:/remote/path /local/path', bold=True) def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]: @@ -2746,7 +2350,7 @@ def start( # # INIT - ok to restart: # 1. It can be a failed-to-provision cluster, so it isn't up - # (Ex: gpunode --gpus=A100:8). Running `sky start` enables + # (Ex: launch --gpus=A100:8). Running `sky start` enables # retrying the provisioning - without setup steps being # completed. (Arguably the original command that failed should # be used instead; but using start isn't harmful - after it @@ -3197,264 +2801,6 @@ def _down_or_stop(name: str): progress.refresh() -@_interactive_node_cli_command -@usage_lib.entrypoint -# pylint: disable=redefined-outer-name -def gpunode(cluster: str, yes: bool, port_forward: Optional[List[int]], - cloud: Optional[str], region: Optional[str], zone: Optional[str], - instance_type: Optional[str], cpus: Optional[str], - memory: Optional[str], gpus: Optional[str], - use_spot: Optional[bool], screen: Optional[bool], - tmux: Optional[bool], disk_size: Optional[int], - disk_tier: Optional[str], ports: Tuple[str], - idle_minutes_to_autostop: Optional[int], down: bool, - retry_until_up: bool): - """Launch or attach to an interactive GPU node. - - Examples: - - .. code-block:: bash - - # Launch a default gpunode. - sky gpunode - \b - # Do work, then log out. The node is kept running. Attach back to the - # same node and do more work. - sky gpunode - \b - # Create many interactive nodes by assigning names via --cluster (-c). - sky gpunode -c node0 - sky gpunode -c node1 - \b - # Port forward. - sky gpunode --port-forward 8080 --port-forward 4650 -c cluster_name - sky gpunode -p 8080 -p 4650 -c cluster_name - \b - # Sync current working directory to ~/workdir on the node. - rsync -r . cluster_name:~/workdir - - """ - # TODO: Factor out the shared logic below for [gpu|cpu|tpu]node. - if screen and tmux: - raise click.UsageError('Cannot use both screen and tmux.') - - session_manager = None - if screen or tmux: - session_manager = 'tmux' if tmux else 'screen' - name = cluster - if name is None: - name = _default_interactive_node_name('gpunode') - - user_requested_resources = not (cloud is None and region is None and - zone is None and instance_type is None and - cpus is None and memory is None and - gpus is None and use_spot is None) - default_resources = _INTERACTIVE_NODE_DEFAULT_RESOURCES['gpunode'] - cloud_provider = clouds.CLOUD_REGISTRY.from_str(cloud) - if gpus is None and instance_type is None: - # Use this request if both gpus and instance_type are not specified. 
- gpus = default_resources.accelerators - instance_type = default_resources.instance_type - if use_spot is None: - use_spot = default_resources.use_spot - resources = sky.Resources(cloud=cloud_provider, - region=region, - zone=zone, - instance_type=instance_type, - cpus=cpus, - memory=memory, - accelerators=gpus, - use_spot=use_spot, - disk_size=disk_size, - disk_tier=disk_tier, - ports=ports) - - _create_and_ssh_into_node( - 'gpunode', - resources, - cluster_name=name, - port_forward=port_forward, - session_manager=session_manager, - user_requested_resources=user_requested_resources, - no_confirm=yes, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - retry_until_up=retry_until_up, - ) - - -@_interactive_node_cli_command -@usage_lib.entrypoint -# pylint: disable=redefined-outer-name -def cpunode(cluster: str, yes: bool, port_forward: Optional[List[int]], - cloud: Optional[str], region: Optional[str], zone: Optional[str], - instance_type: Optional[str], cpus: Optional[str], - memory: Optional[str], use_spot: Optional[bool], - screen: Optional[bool], tmux: Optional[bool], - disk_size: Optional[int], disk_tier: Optional[str], - ports: Tuple[str], idle_minutes_to_autostop: Optional[int], - down: bool, retry_until_up: bool): - """Launch or attach to an interactive CPU node. - - Examples: - - .. code-block:: bash - - # Launch a default cpunode. - sky cpunode - \b - # Do work, then log out. The node is kept running. Attach back to the - # same node and do more work. - sky cpunode - \b - # Create many interactive nodes by assigning names via --cluster (-c). - sky cpunode -c node0 - sky cpunode -c node1 - \b - # Port forward. - sky cpunode --port-forward 8080 --port-forward 4650 -c cluster_name - sky cpunode -p 8080 -p 4650 -c cluster_name - \b - # Sync current working directory to ~/workdir on the node. - rsync -r . 
cluster_name:~/workdir - - """ - if screen and tmux: - raise click.UsageError('Cannot use both screen and tmux.') - - session_manager = None - if screen or tmux: - session_manager = 'tmux' if tmux else 'screen' - name = cluster - if name is None: - name = _default_interactive_node_name('cpunode') - - user_requested_resources = not (cloud is None and region is None and - zone is None and instance_type is None and - cpus is None and memory is None and - use_spot is None) - default_resources = _INTERACTIVE_NODE_DEFAULT_RESOURCES['cpunode'] - cloud_provider = clouds.CLOUD_REGISTRY.from_str(cloud) - if instance_type is None: - instance_type = default_resources.instance_type - if use_spot is None: - use_spot = default_resources.use_spot - resources = sky.Resources(cloud=cloud_provider, - region=region, - zone=zone, - instance_type=instance_type, - cpus=cpus, - memory=memory, - use_spot=use_spot, - disk_size=disk_size, - disk_tier=disk_tier, - ports=ports) - - _create_and_ssh_into_node( - 'cpunode', - resources, - cluster_name=name, - port_forward=port_forward, - session_manager=session_manager, - user_requested_resources=user_requested_resources, - no_confirm=yes, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - retry_until_up=retry_until_up, - ) - - -@_interactive_node_cli_command -@usage_lib.entrypoint -# pylint: disable=redefined-outer-name -def tpunode(cluster: str, yes: bool, port_forward: Optional[List[int]], - region: Optional[str], zone: Optional[str], - instance_type: Optional[str], cpus: Optional[str], - memory: Optional[str], tpus: Optional[str], - use_spot: Optional[bool], tpu_vm: Optional[bool], - screen: Optional[bool], tmux: Optional[bool], - disk_size: Optional[int], disk_tier: Optional[str], - ports: Tuple[str], idle_minutes_to_autostop: Optional[int], - down: bool, retry_until_up: bool): - """Launch or attach to an interactive TPU node. - - Examples: - - .. code-block:: bash - - # Launch a default tpunode. - sky tpunode - \b - # Do work, then log out. The node is kept running. Attach back to the - # same node and do more work. - sky tpunode - \b - # Create many interactive nodes by assigning names via --cluster (-c). - sky tpunode -c node0 - sky tpunode -c node1 - \b - # Port forward. - sky tpunode --port-forward 8080 --port-forward 4650 -c cluster_name - sky tpunode -p 8080 -p 4650 -c cluster_name - \b - # Sync current working directory to ~/workdir on the node. - rsync -r . 
cluster_name:~/workdir - - """ - if screen and tmux: - raise click.UsageError('Cannot use both screen and tmux.') - - session_manager = None - if screen or tmux: - session_manager = 'tmux' if tmux else 'screen' - name = cluster - if name is None: - name = _default_interactive_node_name('tpunode') - - user_requested_resources = not (region is None and zone is None and - instance_type is None and cpus is None and - memory is None and tpus is None and - use_spot is None) - default_resources = _INTERACTIVE_NODE_DEFAULT_RESOURCES['tpunode'] - accelerator_args = default_resources.accelerator_args - if tpu_vm: - accelerator_args['tpu_vm'] = True - accelerator_args['runtime_version'] = 'tpu-vm-base' - else: - accelerator_args['tpu_vm'] = False - if instance_type is None: - instance_type = default_resources.instance_type - if tpus is None: - tpus = default_resources.accelerators - if use_spot is None: - use_spot = default_resources.use_spot - resources = sky.Resources(cloud=sky.GCP(), - region=region, - zone=zone, - instance_type=instance_type, - cpus=cpus, - memory=memory, - accelerators=tpus, - accelerator_args=accelerator_args, - use_spot=use_spot, - disk_size=disk_size, - disk_tier=disk_tier, - ports=ports) - - _create_and_ssh_into_node( - 'tpunode', - resources, - cluster_name=name, - port_forward=port_forward, - session_manager=session_manager, - user_requested_resources=user_requested_resources, - no_confirm=yes, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - retry_until_up=retry_until_up, - ) - - @cli.command() @click.option('--verbose', '-v', @@ -5565,19 +4911,6 @@ def local_down(): f'{colorama.Fore.GREEN}Local cluster removed.{style.RESET_ALL}') -# TODO(skypilot): remove the below in v0.5. -_add_command_alias_to_group(spot, spot_queue, 'status', hidden=True) -_deprecate_and_hide_command(group=None, - command_to_deprecate=cpunode, - alternative_command='sky launch') -_deprecate_and_hide_command(group=None, - command_to_deprecate=gpunode, - alternative_command='sky launch --gpus ') -_deprecate_and_hide_command(group=None, - command_to_deprecate=tpunode, - alternative_command='sky launch --gpus ') - - def main(): return cli() diff --git a/sky/core.py b/sky/core.py index 688644744c5..8927dcce6fc 100644 --- a/sky/core.py +++ b/sky/core.py @@ -1,6 +1,5 @@ """SDK functions for cluster/job management.""" import getpass -import sys import typing from typing import Any, Dict, List, Optional, Union @@ -768,16 +767,6 @@ def job_status(cluster_name: str, # ======================= -@usage_lib.entrypoint -def spot_status(refresh: bool) -> List[Dict[str, Any]]: - """[Deprecated] (alias of spot_queue) Get statuses of managed spot jobs.""" - sky_logging.print( - f'{colorama.Fore.YELLOW}WARNING: `spot_status()` is deprecated. ' - f'Instead, use: spot_queue(){colorama.Style.RESET_ALL}', - file=sys.stderr) - return spot_queue(refresh=refresh) - - @usage_lib.entrypoint def spot_queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: diff --git a/sky/global_user_state.py b/sky/global_user_state.py index a8f8c2fdf79..2522ef639d5 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -328,8 +328,8 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: handle = get_handle_from_cluster_name(cluster_name) if handle is None: return - # Must invalidate IP list: otherwise 'sky cpunode' - # on a stopped cpunode will directly try to ssh, which leads to timeout. 
+ # Must invalidate IP list to avoid directly trying to ssh into a + # stopped VM, which leads to timeout. if hasattr(handle, 'stable_internal_external_ips'): handle.stable_internal_external_ips = None _DB.cursor.execute( diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 19d5e953728..0c15adc8dd0 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -26,8 +26,6 @@ # The name for the environment variable that stores the unique ID of the # current task. This will stay the same across multiple recoveries of the # same spot task. -# TODO(zhwu): Remove SKYPILOT_JOB_ID after 0.5.0. -TASK_ID_ENV_VAR_DEPRECATED = 'SKYPILOT_JOB_ID' TASK_ID_ENV_VAR = 'SKYPILOT_TASK_ID' # This environment variable stores a '\n'-separated list of task IDs that # are within the same spot job (DAG). This can be used by the user to diff --git a/sky/skylet/providers/gcp/__init__.py b/sky/skylet/providers/gcp/__init__.py deleted file mode 100644 index bab89456b61..00000000000 --- a/sky/skylet/providers/gcp/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""GCP node provider""" -from sky.skylet.providers.gcp.node_provider import GCPNodeProvider diff --git a/sky/skylet/providers/gcp/config.py b/sky/skylet/providers/gcp/config.py deleted file mode 100644 index 704cb34f157..00000000000 --- a/sky/skylet/providers/gcp/config.py +++ /dev/null @@ -1,1152 +0,0 @@ -import copy -import json -import logging -import os -import time -from functools import partial -import typing -from typing import Dict, List, Set, Tuple - -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from google.oauth2 import service_account -from google.oauth2.credentials import Credentials as OAuthCredentials -from googleapiclient import discovery, errors - -if typing.TYPE_CHECKING: - import google - -from sky.skylet.providers.gcp.node import ( - MAX_POLLS, - POLL_INTERVAL, - GCPNodeType, - GCPCompute, -) -from sky.skylet.providers.gcp.constants import ( - SKYPILOT_VPC_NAME, - VPC_TEMPLATE, - FIREWALL_RULES_TEMPLATE, - FIREWALL_RULES_REQUIRED, - VM_MINIMAL_PERMISSIONS, - TPU_MINIMAL_PERMISSIONS, -) -from sky.utils import common_utils -from ray.autoscaler._private.util import check_legacy_fields - -logger = logging.getLogger(__name__) - -VERSION = "v1" -TPU_VERSION = "v2alpha" # change once v2 is stable - -RAY = "ray-autoscaler" -DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION -SERVICE_ACCOUNT_EMAIL_TEMPLATE = "{account_id}@{project_id}.iam.gserviceaccount.com" -DEFAULT_SERVICE_ACCOUNT_CONFIG = { - "displayName": "Ray Autoscaler Service Account ({})".format(VERSION), -} - -SKYPILOT = "skypilot" -SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + "-" + VERSION -SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( - "{account_id}@{project_id}.iam.gserviceaccount.com" -) -SKYPILOT_SERVICE_ACCOUNT_CONFIG = { - "displayName": "SkyPilot Service Account ({})".format(VERSION), -} - -# Those roles will be always added. -# NOTE: `serviceAccountUser` allows the head node to create workers with -# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. -DEFAULT_SERVICE_ACCOUNT_ROLES = [ - "roles/storage.objectAdmin", - "roles/compute.admin", - "roles/iam.serviceAccountUser", - "roles/iam.roleViewer", -] -# Those roles will only be added if there are TPU nodes defined in config. -TPU_SERVICE_ACCOUNT_ROLES = ["roles/tpu.admin"] - -# If there are TPU nodes in config, this field will be set -# to True in config["provider"]. 
-HAS_TPU_PROVIDER_FIELD = "_has_tpus" - -# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes -# with ServiceAccounts. - - -def _skypilot_log_error_and_exit_for_failover(error: str) -> None: - """Logs an message then raises a specific RuntimeError to trigger failover. - - Mainly used for handling VPC/subnet errors before nodes are launched. - """ - # NOTE: keep. The backend looks for this to know no nodes are launched. - prefix = "SKYPILOT_ERROR_NO_NODES_LAUNCHED: " - raise RuntimeError(prefix + error) - - -def get_node_type(node: dict) -> GCPNodeType: - """Returns node type based on the keys in ``node``. - - This is a very simple check. If we have a ``machineType`` key, - this is a Compute instance. If we don't have a ``machineType`` key, - but we have ``acceleratorType``, this is a TPU. Otherwise, it's - invalid and an exception is raised. - - This works for both node configs and API returned nodes. - """ - - if "machineType" not in node and "acceleratorType" not in node: - raise ValueError( - "Invalid node. For a Compute instance, 'machineType' is " - "required. " - "For a TPU instance, 'acceleratorType' and no 'machineType' " - "is required. " - f"Got {list(node)}" - ) - - if "machineType" not in node and "acceleratorType" in node: - return GCPNodeType.TPU - return GCPNodeType.COMPUTE - - -def wait_for_crm_operation(operation, crm): - """Poll for cloud resource manager operation until finished.""" - logger.info( - "wait_for_crm_operation: " - "Waiting for operation {} to finish...".format(operation) - ) - - for _ in range(MAX_POLLS): - result = crm.operations().get(name=operation["name"]).execute() - if "error" in result: - raise Exception(result["error"]) - - if "done" in result and result["done"]: - logger.info("wait_for_crm_operation: Operation done.") - break - - time.sleep(POLL_INTERVAL) - - return result - - -def wait_for_compute_global_operation(project_name, operation, compute): - """Poll for global compute operation until finished.""" - logger.info( - "wait_for_compute_global_operation: " - "Waiting for operation {} to finish...".format(operation["name"]) - ) - - for _ in range(MAX_POLLS): - result = ( - compute.globalOperations() - .get( - project=project_name, - operation=operation["name"], - ) - .execute() - ) - if "error" in result: - raise Exception(result["error"]) - - if result["status"] == "DONE": - logger.info("wait_for_compute_global_operation: Operation done.") - break - - time.sleep(POLL_INTERVAL) - - return result - - -def key_pair_name(i, region, project_id, ssh_user): - """Returns the ith default gcp_key_pair_name.""" - key_name = "{}_gcp_{}_{}_{}_{}".format(SKYPILOT, region, project_id, ssh_user, i) - return key_name - - -def key_pair_paths(key_name): - """Returns public and private key paths for a given key_name.""" - public_key_path = os.path.expanduser("~/.ssh/{}.pub".format(key_name)) - private_key_path = os.path.expanduser("~/.ssh/{}.pem".format(key_name)) - return public_key_path, private_key_path - - -def generate_rsa_key_pair(): - """Create public and private ssh-keys.""" - - key = rsa.generate_private_key( - backend=default_backend(), public_exponent=65537, key_size=2048 - ) - - public_key = ( - key.public_key() - .public_bytes( - serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH - ) - .decode("utf-8") - ) - - pem = key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ).decode("utf-8") - - return 
public_key, pem - - -def _has_tpus_in_node_configs(config: dict) -> bool: - """Check if any nodes in config are TPUs.""" - node_configs = [ - node_type["node_config"] - for node_type in config["available_node_types"].values() - ] - return any(get_node_type(node) == GCPNodeType.TPU for node in node_configs) - - -def _is_head_node_a_tpu(config: dict) -> bool: - """Check if the head node is a TPU.""" - node_configs = { - node_id: node_type["node_config"] - for node_id, node_type in config["available_node_types"].items() - } - return get_node_type(node_configs[config["head_node_type"]]) == GCPNodeType.TPU - - -def _create_crm(gcp_credentials=None): - return discovery.build( - "cloudresourcemanager", "v1", credentials=gcp_credentials, cache_discovery=False - ) - - -def _create_iam(gcp_credentials=None): - return discovery.build( - "iam", "v1", credentials=gcp_credentials, cache_discovery=False - ) - - -def _create_compute(gcp_credentials=None): - return discovery.build( - "compute", "v1", credentials=gcp_credentials, cache_discovery=False - ) - - -def _create_tpu(gcp_credentials=None): - return discovery.build( - "tpu", - TPU_VERSION, - credentials=gcp_credentials, - cache_discovery=False, - discoveryServiceUrl="https://tpu.googleapis.com/$discovery/rest", - ) - - -def construct_clients_from_provider_config(provider_config): - """ - Attempt to fetch and parse the JSON GCP credentials from the provider - config yaml file. - - tpu resource (the last element of the tuple) will be None if - `_has_tpus` in provider config is not set or False. - """ - gcp_credentials = provider_config.get("gcp_credentials") - if gcp_credentials is None: - logger.debug( - "gcp_credentials not found in cluster yaml file. " - "Falling back to GOOGLE_APPLICATION_CREDENTIALS " - "environment variable." - ) - tpu_resource = ( - _create_tpu() - if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) - else None - ) - # If gcp_credentials is None, then discovery.build will search for - # credentials in the local environment. - return _create_crm(), _create_iam(), _create_compute(), tpu_resource - - assert ( - "type" in gcp_credentials - ), "gcp_credentials cluster yaml field missing 'type' field." - assert ( - "credentials" in gcp_credentials - ), "gcp_credentials cluster yaml field missing 'credentials' field." - - cred_type = gcp_credentials["type"] - credentials_field = gcp_credentials["credentials"] - - if cred_type == "service_account": - # If parsing the gcp_credentials failed, then the user likely made a - # mistake in copying the credentials into the config yaml. - try: - service_account_info = json.loads(credentials_field) - except json.decoder.JSONDecodeError: - raise RuntimeError( - "gcp_credentials found in cluster yaml file but " - "formatted improperly." - ) - credentials = service_account.Credentials.from_service_account_info( - service_account_info - ) - elif cred_type == "credentials_token": - # Otherwise the credentials type must be credentials_token. - credentials = OAuthCredentials(credentials_field) - - tpu_resource = ( - _create_tpu(credentials) - if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) - else None - ) - - return ( - _create_crm(credentials), - _create_iam(credentials), - _create_compute(credentials), - tpu_resource, - ) - - -def bootstrap_gcp(config): - config = copy.deepcopy(config) - check_legacy_fields(config) - # Used internally to store head IAM role. 
- config["head_node"] = {} - - # Check if we have any TPUs defined, and if so, - # insert that information into the provider config - if _has_tpus_in_node_configs(config): - config["provider"][HAS_TPU_PROVIDER_FIELD] = True - - crm, iam, compute, tpu = construct_clients_from_provider_config(config["provider"]) - - config = _configure_project(config, crm) - config = _configure_iam_role(config, crm, iam) - config = _configure_key_pair(config, compute) - config = _configure_subnet(config, compute) - - return config - - -def _configure_project(config, crm): - """Setup a Google Cloud Platform Project. - - Google Compute Platform organizes all the resources, such as storage - buckets, users, and instances under projects. This is different from - aws ec2 where everything is global. - """ - config = copy.deepcopy(config) - - project_id = config["provider"].get("project_id") - assert config["provider"]["project_id"] is not None, ( - "'project_id' must be set in the 'provider' section of the autoscaler" - " config. Notice that the project id must be globally unique." - ) - project = _get_project(project_id, crm) - - if project is None: - # Project not found, try creating it - _create_project(project_id, crm) - project = _get_project(project_id, crm) - - assert project is not None, "Failed to create project" - assert ( - project["lifecycleState"] == "ACTIVE" - ), "Project status needs to be ACTIVE, got {}".format(project["lifecycleState"]) - - config["provider"]["project_id"] = project["projectId"] - - return config - - -def _is_permission_satisfied( - service_account, crm, iam, required_permissions, required_roles -): - """Check if either of the roles or permissions are satisfied.""" - if service_account is None: - return False, None - - project_id = service_account["projectId"] - email = service_account["email"] - - member_id = "serviceAccount:" + email - - required_permissions = set(required_permissions) - policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute() - original_policy = copy.deepcopy(policy) - already_configured = True - - logger.info(f"_configure_iam_role: Checking permissions for {email}...") - - # Check the roles first, as checking the permission requires more API calls and - # permissions. - for role in required_roles: - role_exists = False - for binding in policy["bindings"]: - if binding["role"] == role: - if member_id not in binding["members"]: - logger.info( - f"_configure_iam_role: role {role} is not attached to {member_id}..." - ) - binding["members"].append(member_id) - already_configured = False - role_exists = True - - if not role_exists: - logger.info(f"_configure_iam_role: role {role} does not exist.") - already_configured = False - policy["bindings"].append( - { - "members": [member_id], - "role": role, - } - ) - - if already_configured: - # In some managed environments, an admin needs to grant the - # roles, so only call setIamPolicy if needed. - return True, policy - - for binding in original_policy["bindings"]: - if member_id in binding["members"]: - role = binding["role"] - try: - role_definition = iam.projects().roles().get(name=role).execute() - except TypeError as e: - if "does not match the pattern" in str(e): - logger.info( - f"_configure_iam_role: fail to check permission for built-in role {role}. skipped." 
- ) - permissions = [] - else: - raise - else: - permissions = role_definition["includedPermissions"] - required_permissions -= set(permissions) - if not required_permissions: - break - if not required_permissions: - # All required permissions are already granted. - return True, policy - logger.info(f"_configure_iam_role: missing permisisons {required_permissions}") - - return False, policy - - -def _configure_iam_role(config, crm, iam): - """Setup a gcp service account with IAM roles. - - Creates a gcp service acconut and binds IAM roles which allow it to control - control storage/compute services. Specifically, the head node needs to have - an IAM role that allows it to create further gce instances and store items - in google cloud storage. - - TODO: Allow the name/id of the service account to be configured - """ - config = copy.deepcopy(config) - - email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( - account_id=SKYPILOT_SERVICE_ACCOUNT_ID, - project_id=config["provider"]["project_id"], - ) - service_account = _get_service_account(email, config, iam) - - permissions = VM_MINIMAL_PERMISSIONS - roles = DEFAULT_SERVICE_ACCOUNT_ROLES - if config["provider"].get(HAS_TPU_PROVIDER_FIELD, False): - roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES - permissions = VM_MINIMAL_PERMISSIONS + TPU_MINIMAL_PERMISSIONS - - satisfied, policy = _is_permission_satisfied( - service_account, crm, iam, permissions, roles - ) - - if not satisfied: - # SkyPilot: Fallback to the old ray service account name for - # backwards compatibility. Users using GCP before #2112 have - # the old service account setup setup in their GCP project, - # and the user may not have the permissions to create the - # new service account. This is to ensure that the old service - # account is still usable. - email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( - account_id=DEFAULT_SERVICE_ACCOUNT_ID, - project_id=config["provider"]["project_id"], - ) - logger.info(f"_configure_iam_role: Fallback to service account {email}") - - ray_service_account = _get_service_account(email, config, iam) - ray_satisfied, _ = _is_permission_satisfied( - ray_service_account, crm, iam, permissions, roles - ) - logger.info( - "_configure_iam_role: " - f"Fallback to service account {email} succeeded? {ray_satisfied}" - ) - - if ray_satisfied: - service_account = ray_service_account - satisfied = ray_satisfied - elif service_account is None: - logger.info( - "_configure_iam_role: " - "Creating new service account {}".format(SKYPILOT_SERVICE_ACCOUNT_ID) - ) - # SkyPilot: a GCP user without the permission to create a service - # account will fail here. - service_account = _create_service_account( - SKYPILOT_SERVICE_ACCOUNT_ID, - SKYPILOT_SERVICE_ACCOUNT_CONFIG, - config, - iam, - ) - satisfied, policy = _is_permission_satisfied( - service_account, crm, iam, permissions, roles - ) - - assert service_account is not None, "Failed to create service account" - - if not satisfied: - logger.info( - "_configure_iam_role: " f"Adding roles to service account {email}..." - ) - _add_iam_policy_binding(service_account, policy, crm, iam) - - account_dict = { - "email": service_account["email"], - # NOTE: The amount of access is determined by the scope + IAM - # role of the service account. Even if the cloud-platform scope - # gives (scope) access to the whole cloud-platform, the service - # account is limited by the IAM rights specified below. 
- "scopes": ["https://www.googleapis.com/auth/cloud-platform"], - } - if _is_head_node_a_tpu(config): - # SKY: The API for TPU VM is slightly different from normal compute instances. - # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node - account_dict["scope"] = account_dict["scopes"] - account_dict.pop("scopes") - config["head_node"]["serviceAccount"] = account_dict - else: - config["head_node"]["serviceAccounts"] = [account_dict] - - return config - - -def _configure_key_pair(config, compute): - """Configure SSH access, using an existing key pair if possible. - - Creates a project-wide ssh key that can be used to access all the instances - unless explicitly prohibited by instance config. - - The ssh-keys created by ray are of format: - - [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME] - - where: - - [USERNAME] is the user for the SSH key, specified in the config. - [KEY_VALUE] is the public SSH key value. - """ - config = copy.deepcopy(config) - - if "ssh_private_key" in config["auth"]: - return config - - ssh_user = config["auth"]["ssh_user"] - - project = compute.projects().get(project=config["provider"]["project_id"]).execute() - - # Key pairs associated with project meta data. The key pairs are general, - # and not just ssh keys. - ssh_keys_str = next( - ( - item - for item in project["commonInstanceMetadata"].get("items", []) - if item["key"] == "ssh-keys" - ), - {}, - ).get("value", "") - - ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else [] - - # Try a few times to get or create a good key pair. - key_found = False - for i in range(10): - key_name = key_pair_name( - i, config["provider"]["region"], config["provider"]["project_id"], ssh_user - ) - public_key_path, private_key_path = key_pair_paths(key_name) - - for ssh_key in ssh_keys: - key_parts = ssh_key.split(" ") - if len(key_parts) != 3: - continue - - if key_parts[2] == ssh_user and os.path.exists(private_key_path): - # Found a key - key_found = True - break - - # Writing the new ssh key to the filesystem fails if the ~/.ssh - # directory doesn't already exist. - os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True) - - # Create a key since it doesn't exist locally or in GCP - if not key_found and not os.path.exists(private_key_path): - logger.info( - "_configure_key_pair: Creating new key pair {}".format(key_name) - ) - public_key, private_key = generate_rsa_key_pair() - - _create_project_ssh_key_pair(project, public_key, ssh_user, compute) - - # Create the directory if it doesn't exists - private_key_dir = os.path.dirname(private_key_path) - os.makedirs(private_key_dir, exist_ok=True) - - # We need to make sure to _create_ the file with the right - # permissions. In order to do that we need to change the default - # os.open behavior to include the mode we want. 
- with open( - private_key_path, - "w", - opener=partial(os.open, mode=0o600), - ) as f: - f.write(private_key) - - with open(public_key_path, "w") as f: - f.write(public_key) - - key_found = True - - break - - if key_found: - break - - assert key_found, "SSH keypair for user {} not found for {}".format( - ssh_user, private_key_path - ) - assert os.path.exists( - private_key_path - ), "Private key file {} not found for user {}".format(private_key_path, ssh_user) - - logger.info( - "_configure_key_pair: " - "Private key not specified in config, using" - "{}".format(private_key_path) - ) - - config["auth"]["ssh_private_key"] = private_key_path - - return config - - -def _check_firewall_rules(vpc_name, config, compute): - """Check if the firewall rules in the VPC are sufficient.""" - required_rules = FIREWALL_RULES_REQUIRED.copy() - - operation = compute.networks().getEffectiveFirewalls( - project=config["provider"]["project_id"], network=vpc_name - ) - response = operation.execute() - if len(response) == 0: - return False - effective_rules = response["firewalls"] - - def _merge_and_refine_rule(rules): - """Returns the reformatted rules from the firewall rules - - The function translates firewall rules fetched from the cloud provider - to a format for simple comparison. - - Example of firewall rules from the cloud: - [ - { - ... - "direction": "INGRESS", - "allowed": [ - {"IPProtocol": "tcp", "ports": ['80', '443']}, - {"IPProtocol": "udp", "ports": ['53']}, - ], - "sourceRanges": ["10.128.0.0/9"], - }, - { - ... - "direction": "INGRESS", - "allowed": [{ - "IPProtocol": "tcp", - "ports": ["22"], - }], - "sourceRanges": ["0.0.0.0/0"], - }, - ] - - Returns: - source2rules: Dict[(direction, sourceRanges) -> Dict(protocol -> Set[ports])] - Example { - ("INGRESS", "10.128.0.0/9"): {"tcp": {80, 443}, "udp": {53}}, - ("INGRESS", "0.0.0.0/0"): {"tcp": {22}}, - } - """ - source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {} - source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {} - for rule in rules: - # Rules applied to specific VM (targetTags) may not work for the - # current VM, so should be skipped. 
- # Filter by targetTags == ['cluster_name'] - # See https://developers.google.com/resources/api-libraries/documentation/compute/alpha/python/latest/compute_alpha.networks.html#getEffectiveFirewalls # pylint: disable=line-too-long - tags = rule.get("targetTags", None) - if tags is not None: - if len(tags) != 1: - continue - if tags[0] != config["cluster_name"]: - continue - direction = rule.get("direction", "") - sources = rule.get("sourceRanges", []) - allowed = rule.get("allowed", []) - for source in sources: - key = (direction, source) - source2allowed_list[key] = source2allowed_list.get(key, []) + allowed - for direction_source, allowed_list in source2allowed_list.items(): - source2rules[direction_source] = {} - for allowed in allowed_list: - # Example of port_list: ["20", "50-60"] - # If list is empty, it means all ports - port_list = allowed.get("ports", []) - port_set = set() - if port_list == []: - port_set.update(set(range(1, 65536))) - else: - for port_range in port_list: - parse_ports = port_range.split("-") - if len(parse_ports) == 1: - port_set.add(int(parse_ports[0])) - else: - assert ( - len(parse_ports) == 2 - ), f"Failed to parse the port range: {port_range}" - port_set.update( - set(range(int(parse_ports[0]), int(parse_ports[1]) + 1)) - ) - if allowed["IPProtocol"] not in source2rules[direction_source]: - source2rules[direction_source][allowed["IPProtocol"]] = set() - source2rules[direction_source][allowed["IPProtocol"]].update(port_set) - return source2rules - - effective_rules = _merge_and_refine_rule(effective_rules) - required_rules = _merge_and_refine_rule(required_rules) - - for direction_source, allowed_req in required_rules.items(): - if direction_source not in effective_rules: - return False - allowed_eff = effective_rules[direction_source] - # Special case: "all" means allowing all traffic - if "all" in allowed_eff: - continue - # Check if the required ports are a subset of the effective ports - for protocol, ports_req in allowed_req.items(): - ports_eff = allowed_eff.get(protocol, set()) - if not ports_req.issubset(ports_eff): - return False - return True - - -def _create_rules(config, compute, rules, VPC_NAME, PROJ_ID): - opertaions = [] - for rule in rules: - # Query firewall rule by its name (unique in a project). - # If the rule already exists, delete it first. - rule_name = rule["name"].format(VPC_NAME=VPC_NAME) - rule_list = _list_firewall_rules(config, compute, filter=f"(name={rule_name})") - if len(rule_list) > 0: - _delete_firewall_rule(config, compute, rule_name) - - body = rule.copy() - body["name"] = body["name"].format(VPC_NAME=VPC_NAME) - body["network"] = body["network"].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) - body["selfLink"] = body["selfLink"].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) - op = _create_firewall_rule_submit(config, compute, body) - opertaions.append(op) - for op in opertaions: - wait_for_compute_global_operation(config["provider"]["project_id"], op, compute) - - -def _network_interface_to_vpc_name(network_interface: Dict[str, str]) -> str: - """Returns the VPC name of a network interface.""" - return network_interface["network"].split("/")[-1] - - -def get_usable_vpc_and_subnet( - config, -) -> Tuple[str, "google.cloud.compute_v1.types.compute.Subnetwork"]: - """Return a usable VPC and the subnet in it. - - If config['provider']['vpc_name'] is set, return the VPC with the name - (errors out if not found). 
When this field is set, no firewall rules - checking or overrides will take place; it is the user's responsibility to - properly set up the VPC. - - If not found, create a new one with sufficient firewall rules. - - Returns: - vpc_name: The name of the VPC network. - subnet_name: The name of the subnet in the VPC network for the specific - region. - - Raises: - RuntimeError: if the user has specified a VPC name but the VPC is not found. - """ - _, _, compute, _ = construct_clients_from_provider_config(config["provider"]) - - # For existing cluster, it is ok to return a VPC and subnet not used by - # the cluster, as AWS will ignore them. - # There is a corner case where the multi-node cluster was partially - # launched, launching the cluster again can cause the nodes located on - # different VPCs, if VPCs in the project have changed. It should be fine to - # not handle this special case as we don't want to sacrifice the performance - # for every launch just for this rare case. - - specific_vpc_to_use = config["provider"].get("vpc_name", None) - if specific_vpc_to_use is not None: - vpcnets_all = _list_vpcnets( - config, compute, filter=f"name={specific_vpc_to_use}" - ) - # On GCP, VPC names are unique, so it'd be 0 or 1 VPC found. - assert ( - len(vpcnets_all) <= 1 - ), f"{len(vpcnets_all)} VPCs found with the same name {specific_vpc_to_use}" - if len(vpcnets_all) == 1: - # Skip checking any firewall rules if the user has specified a VPC. - logger.info(f"Using user-specified VPC {specific_vpc_to_use!r}.") - subnets = _list_subnets(config, compute, network=specific_vpc_to_use) - if not subnets: - _skypilot_log_error_and_exit_for_failover( - f"No subnet for region {config['provider']['region']} found for specified VPC {specific_vpc_to_use!r}. " - f"Check the subnets of VPC {specific_vpc_to_use!r} at https://console.cloud.google.com/networking/networks" - ) - return specific_vpc_to_use, subnets[0] - else: - # VPC with this name not found. Error out and let SkyPilot failover. - _skypilot_log_error_and_exit_for_failover( - f"No VPC with name {specific_vpc_to_use!r} is found. " - "To fix: specify a correct VPC name." - ) - # Should not reach here. - - subnets_all = _list_subnets(config, compute) - - # Check if VPC for subnet has sufficient firewall rules. - insufficient_vpcs = set() - for subnet in subnets_all: - vpc_name = _network_interface_to_vpc_name(subnet) - if vpc_name in insufficient_vpcs: - continue - if _check_firewall_rules(vpc_name, config, compute): - logger.info(f"get_usable_vpc: Found a usable VPC network {vpc_name!r}.") - return vpc_name, subnet - else: - insufficient_vpcs.add(vpc_name) - - # No usable VPC found. Try to create one. 
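To make the selection loop above easier to follow, here is a minimal standalone sketch (not part of this diff); vpc_of_subnet and rules_ok are hypothetical stand-ins for _network_interface_to_vpc_name and _check_firewall_rules:

def pick_usable_vpc(subnets, vpc_of_subnet, rules_ok):
    # Each VPC is checked at most once; rejected VPCs are memoized so further
    # subnets in the same VPC are skipped cheaply.
    insufficient = set()
    for subnet in subnets:
        vpc = vpc_of_subnet(subnet)
        if vpc in insufficient:
            continue
        if rules_ok(vpc):
            return vpc, subnet
        insufficient.add(vpc)
    return None, None  # caller falls back to creating the SkyPilot VPC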
- proj_id = config["provider"]["project_id"] - logger.info(f"Creating a default VPC network, {SKYPILOT_VPC_NAME}...") - - # Create a SkyPilot VPC network if it doesn't exist - vpc_list = _list_vpcnets(config, compute, filter=f"name={SKYPILOT_VPC_NAME}") - if len(vpc_list) == 0: - body = VPC_TEMPLATE.copy() - body["name"] = body["name"].format(VPC_NAME=SKYPILOT_VPC_NAME) - body["selfLink"] = body["selfLink"].format( - PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME - ) - _create_vpcnet(config, compute, body) - - _create_rules(config, compute, FIREWALL_RULES_TEMPLATE, SKYPILOT_VPC_NAME, proj_id) - - usable_vpc_name = SKYPILOT_VPC_NAME - subnets = _list_subnets(config, compute, network=usable_vpc_name) - if not subnets: - _skypilot_log_error_and_exit_for_failover( - f"No subnet for region {config['provider']['region']} found for generated VPC {usable_vpc_name!r}. " - "This is probably due to the region being disabled in the account/project_id." - ) - usable_subnet = subnets[0] - logger.info(f"A VPC network {SKYPILOT_VPC_NAME} created.") - - return usable_vpc_name, usable_subnet - - -def _configure_subnet(config, compute): - """Pick a reasonable subnet if not specified by the config.""" - config = copy.deepcopy(config) - - node_configs = [ - node_type["node_config"] - for node_type in config["available_node_types"].values() - ] - # Rationale: avoid subnet lookup if the network is already - # completely manually configured - - # networkInterfaces is compute, networkConfig is TPU - if all( - "networkInterfaces" in node_config or "networkConfig" in node_config - for node_config in node_configs - ): - return config - - # SkyPilot: make sure there's a usable VPC - _, default_subnet = get_usable_vpc_and_subnet(config) - - default_interfaces = [ - { - "subnetwork": default_subnet["selfLink"], - "accessConfigs": [ - { - "name": "External NAT", - "type": "ONE_TO_ONE_NAT", - } - ], - } - ] - if config["provider"].get("use_internal_ips", False): - # Removing this key means the VM will not be assigned an external IP. 
- default_interfaces[0].pop("accessConfigs") - - for node_config in node_configs: - # The not applicable key will be removed during node creation - - # compute - if "networkInterfaces" not in node_config: - node_config["networkInterfaces"] = copy.deepcopy(default_interfaces) - # TPU - if "networkConfig" not in node_config: - node_config["networkConfig"] = copy.deepcopy(default_interfaces)[0] - # TPU doesn't have accessConfigs - node_config["networkConfig"].pop("accessConfigs", None) - if config["provider"].get("use_internal_ips", False): - node_config["networkConfig"]["enableExternalIps"] = False - - return config - - -def _create_firewall_rule_submit(config, compute, body): - operation = ( - compute.firewalls() - .insert(project=config["provider"]["project_id"], body=body) - .execute() - ) - return operation - - -def _delete_firewall_rule(config, compute, name): - operation = ( - compute.firewalls() - .delete(project=config["provider"]["project_id"], firewall=name) - .execute() - ) - response = wait_for_compute_global_operation( - config["provider"]["project_id"], operation, compute - ) - return response - - -def _list_firewall_rules(config, compute, filter=None): - response = ( - compute.firewalls() - .list( - project=config["provider"]["project_id"], - filter=filter, - ) - .execute() - ) - return response["items"] if "items" in response else [] - - -def _create_vpcnet(config, compute, body): - operation = ( - compute.networks() - .insert(project=config["provider"]["project_id"], body=body) - .execute() - ) - response = wait_for_compute_global_operation( - config["provider"]["project_id"], operation, compute - ) - return response - - -def _list_vpcnets(config, compute, filter=None): - response = ( - compute.networks() - .list( - project=config["provider"]["project_id"], - filter=filter, - ) - .execute() - ) - - return ( - list(sorted(response["items"], key=lambda x: x["name"])) - if "items" in response - else [] - ) - - -def _list_subnets( - config, compute, network=None -) -> List["google.cloud.compute_v1.types.compute.Subnetwork"]: - response = ( - compute.subnetworks() - .list( - project=config["provider"]["project_id"], - region=config["provider"]["region"], - ) - .execute() - ) - - items = response["items"] if "items" in response else [] - if network is None: - return items - - # Filter by network (VPC) name. 
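A small standalone sketch (not part of this diff) of the interface defaulting just above; default_interface is a hypothetical name and the subnetwork link in the usage line is an example value:

def default_interface(subnet_self_link: str, use_internal_ips: bool) -> dict:
    interface = {
        "subnetwork": subnet_self_link,
        "accessConfigs": [{"name": "External NAT", "type": "ONE_TO_ONE_NAT"}],
    }
    if use_internal_ips:
        # Without accessConfigs the VM is not assigned an external IP.
        interface.pop("accessConfigs")
    return interface

print(default_interface("projects/p/regions/r/subnetworks/default", True))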
- # - # Note we do not directly use the filter (network=<...>) arg of the list() - # call above, because it'd involve constructing a long URL of the following - # format and passing it as the filter value: - # 'https://www.googleapis.com/compute/v1/projects//global/networks/' - matched_items = [] - for item in items: - if network == _network_interface_to_vpc_name(item): - matched_items.append(item) - return matched_items - - -def _get_subnet(config, subnet_id, compute): - subnet = ( - compute.subnetworks() - .get( - project=config["provider"]["project_id"], - region=config["provider"]["region"], - subnetwork=subnet_id, - ) - .execute() - ) - - return subnet - - -def _get_project(project_id, crm): - try: - project = crm.projects().get(projectId=project_id).execute() - except errors.HttpError as e: - if e.resp.status != 403: - raise - project = None - - return project - - -def _create_project(project_id, crm): - operation = ( - crm.projects() - .create(body={"projectId": project_id, "name": project_id}) - .execute() - ) - - result = wait_for_crm_operation(operation, crm) - - return result - - -def _get_service_account(account, config, iam): - project_id = config["provider"]["project_id"] - full_name = "projects/{project_id}/serviceAccounts/{account}".format( - project_id=project_id, account=account - ) - try: - service_account = iam.projects().serviceAccounts().get(name=full_name).execute() - except errors.HttpError as e: - if e.resp.status not in [403, 404]: - # SkyPilot: added 403, which means the service account doesn't exist, - # or not accessible by the current account, which is fine, as we do the - # fallback in the caller. - raise - service_account = None - - return service_account - - -def _create_service_account(account_id, account_config, config, iam): - project_id = config["provider"]["project_id"] - - service_account = ( - iam.projects() - .serviceAccounts() - .create( - name="projects/{project_id}".format(project_id=project_id), - body={ - "accountId": account_id, - "serviceAccount": account_config, - }, - ) - .execute() - ) - - return service_account - - -def _add_iam_policy_binding(service_account, policy, crm, iam): - """Add new IAM roles for the service account.""" - project_id = service_account["projectId"] - - result = ( - crm.projects() - .setIamPolicy( - resource=project_id, - body={ - "policy": policy, - }, - ) - .execute() - ) - - return result - - -def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): - """Inserts an ssh-key into project commonInstanceMetadata""" - - key_parts = public_key.split(" ") - - # Sanity checks to make sure that the generated key matches expectation - assert len(key_parts) == 2, key_parts - assert key_parts[0] == "ssh-rsa", key_parts - - new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format( - ssh_user=ssh_user, key_value=key_parts[1] - ) - - common_instance_info = project["commonInstanceMetadata"] - items = common_instance_info.get("items", []) - - ssh_keys_i = next( - (i for i, item in enumerate(items) if item["key"] == "ssh-keys"), None - ) - - if ssh_keys_i is None: - items.append({"key": "ssh-keys", "value": new_ssh_meta}) - else: - ssh_keys = items[ssh_keys_i] - ssh_keys["value"] += "\n" + new_ssh_meta - items[ssh_keys_i] = ssh_keys - - common_instance_info["items"] = items - - operation = ( - compute.projects() - .setCommonInstanceMetadata(project=project["name"], body=common_instance_info) - .execute() - ) - - response = wait_for_compute_global_operation(project["name"], operation, compute) - - return 
response diff --git a/sky/skylet/providers/gcp/constants.py b/sky/skylet/providers/gcp/constants.py deleted file mode 100644 index 8f6343c49f6..00000000000 --- a/sky/skylet/providers/gcp/constants.py +++ /dev/null @@ -1,138 +0,0 @@ -from sky import skypilot_config - -SKYPILOT_VPC_NAME = "skypilot-vpc" - -# Below parameters are from the default VPC on GCP. -# https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc -VPC_TEMPLATE = { - "name": "{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "autoCreateSubnetworks": True, - "mtu": 1460, - "routingConfig": {"routingMode": "GLOBAL"}, -} - -# Required firewall rules for SkyPilot to work. -FIREWALL_RULES_REQUIRED = [ - # Allow internal connections between GCP VMs for Ray multi-node cluster. - { - "direction": "INGRESS", - "allowed": [ - {"IPProtocol": "tcp", "ports": ["0-65535"]}, - {"IPProtocol": "udp", "ports": ["0-65535"]}, - ], - "sourceRanges": ["10.128.0.0/9"], - }, - # Allow ssh connection from anywhere. - { - "direction": "INGRESS", - "allowed": [ - { - "IPProtocol": "tcp", - "ports": ["22"], - } - ], - # Some users have reported that this conflicts with their network - # security policy. A custom VPC can be specified in ~/.sky/config.yaml - # allowing for restriction of source ranges bypassing this requirement. - "sourceRanges": ["0.0.0.0/0"], - }, -] - -# Template when creating firewall rules for a new VPC. -FIREWALL_RULES_TEMPLATE = [ - { - "name": "{VPC_NAME}-allow-custom", - "description": "Allows connection from any source to any instance on the network using custom protocols.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - {"IPProtocol": "tcp", "ports": ["0-65535"]}, - {"IPProtocol": "udp", "ports": ["0-65535"]}, - {"IPProtocol": "icmp"}, - ], - "sourceRanges": ["10.128.0.0/9"], - }, - { - "name": "{VPC_NAME}-allow-ssh", - "description": "Allows TCP connections from any source to any instance on the network using port 22.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - { - "IPProtocol": "tcp", - "ports": ["22"], - } - ], - # TODO(skypilot): some users reported that this should be relaxed (e.g., - # allowlisting only certain IPs to have ssh access). - "sourceRanges": ["0.0.0.0/0"], - }, - { - "name": "{VPC_NAME}-allow-icmp", - "description": "Allows ICMP connections from any source to any instance on the network.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - { - "IPProtocol": "icmp", - } - ], - "sourceRanges": ["0.0.0.0/0"], - }, -] - -# A list of permissions required to run SkyPilot on GCP. 
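For clarity, a standalone sketch (not part of this diff) of how the "{VPC_NAME}" and "{PROJ_ID}" placeholders in these templates get filled in before the insert call, mirroring _create_rules in the removed config.py; render_rule is a hypothetical name:

def render_rule(rule_template: dict, vpc_name: str, proj_id: str) -> dict:
    body = rule_template.copy()
    for field in ("name", "network", "selfLink"):
        # str.format() ignores keyword arguments a template does not use.
        body[field] = body[field].format(VPC_NAME=vpc_name, PROJ_ID=proj_id)
    return body

rule = {
    "name": "{VPC_NAME}-allow-ssh",
    "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}",
    "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh",
}
print(render_rule(rule, "skypilot-vpc", "my-project")["name"])
# -> skypilot-vpc-allow-ssh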
-# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html # pylint: disable=line-too-long -VM_MINIMAL_PERMISSIONS = [ - "compute.disks.create", - "compute.disks.list", - "compute.firewalls.create", - "compute.firewalls.delete", - "compute.firewalls.get", - "compute.instances.create", - "compute.instances.delete", - "compute.instances.get", - "compute.instances.list", - "compute.instances.setLabels", - "compute.instances.setServiceAccount", - "compute.instances.start", - "compute.instances.stop", - "compute.networks.get", - "compute.networks.list", - "compute.networks.getEffectiveFirewalls", - "compute.globalOperations.get", - "compute.subnetworks.use", - "compute.subnetworks.list", - "compute.subnetworks.useExternalIp", - "compute.projects.get", - "compute.zoneOperations.get", - "iam.roles.get", - "iam.serviceAccounts.actAs", - "iam.serviceAccounts.get", - "serviceusage.services.enable", - "serviceusage.services.list", - "serviceusage.services.use", - "resourcemanager.projects.get", - "resourcemanager.projects.getIamPolicy", -] -# If specifying custom VPC, permissions to modify network are not necessary -# unless opening ports (e.g., via `resources.ports`). -if skypilot_config.get_nested(("gcp", "vpc_name"), ""): - remove = ("compute.firewalls.create", "compute.firewalls.delete") - VM_MINIMAL_PERMISSIONS = [p for p in VM_MINIMAL_PERMISSIONS if p not in remove] - -TPU_MINIMAL_PERMISSIONS = [ - "tpu.nodes.create", - "tpu.nodes.delete", - "tpu.nodes.list", - "tpu.nodes.get", - "tpu.nodes.update", - "tpu.operations.get", -] diff --git a/sky/skylet/providers/gcp/node.py b/sky/skylet/providers/gcp/node.py deleted file mode 100644 index 301206ab922..00000000000 --- a/sky/skylet/providers/gcp/node.py +++ /dev/null @@ -1,905 +0,0 @@ -"""Abstractions around GCP resources and nodes. - -The logic has been abstracted away here to allow for different GCP resources -(API endpoints), which can differ widely, making it impossible to use -the same logic for everything. - -Classes inheriting from ``GCPResource`` represent different GCP resources - -API endpoints that allow for nodes to be created, removed, listed and -otherwise managed. Those classes contain methods abstracting GCP REST API -calls. -Each resource has a corresponding node type, represented by a -class inheriting from ``GCPNode``. Those classes are essentially dicts -with some extra methods. The instances of those classes will be created -from API responses. - -The ``GCPNodeType`` enum is a lightweight way to classify nodes. - -Currently, Compute and TPU resources & nodes are supported. - -In order to add support for new resources, create classes inheriting from -``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum, -update the ``_generate_node_name`` method and finally update the -node provider. 
-""" - -import abc -import logging -import re -import time - -from collections import UserDict -from copy import deepcopy -from enum import Enum -from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union -from uuid import uuid4 - -from googleapiclient.discovery import Resource -from googleapiclient.errors import HttpError - -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME - -logger = logging.getLogger(__name__) - -INSTANCE_NAME_MAX_LEN = 64 -INSTANCE_NAME_UUID_LEN = 8 -MAX_POLLS = 12 -# TPUs take a long while to respond, so we increase the MAX_POLLS -# considerably - this probably could be smaller -# TPU deletion uses MAX_POLLS -MAX_POLLS_TPU = MAX_POLLS * 8 -# Stopping instances can take several minutes, so we increase the timeout -MAX_POLLS_STOP = MAX_POLLS * 8 -POLL_INTERVAL = 5 - - -def _retry_on_exception( - exception: Union[Exception, Tuple[Exception]], - regex: Optional[str] = None, - max_retries: int = MAX_POLLS, - retry_interval_s: int = POLL_INTERVAL, -): - """Retry a function call n-times for as long as it throws an exception.""" - - def dec(func): - @wraps(func) - def wrapper(*args, **kwargs): - def try_catch_exc(): - try: - value = func(*args, **kwargs) - return value - except Exception as e: - if not isinstance(e, exception) or ( - regex and not re.search(regex, str(e)) - ): - raise e - return e - - for _ in range(max_retries): - ret = try_catch_exc() - if not isinstance(ret, Exception): - break - time.sleep(retry_interval_s) - if isinstance(ret, Exception): - raise ret - return ret - - return wrapper - - return dec - - -def _generate_node_name(labels: dict, node_suffix: str) -> str: - """Generate node name from labels and suffix. - - This is required so that the correct resource can be selected - when the only information autoscaler has is the name of the node. - - The suffix is expected to be one of 'compute' or 'tpu' - (as in ``GCPNodeType``). - """ - name_label = labels[TAG_RAY_NODE_NAME] - assert len(name_label) <= (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1), ( - name_label, - len(name_label), - ) - return f"{name_label}-{uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}" - - -class GCPNodeType(Enum): - """Enum for GCP node types (compute & tpu)""" - - COMPUTE = "compute" - TPU = "tpu" - - @staticmethod - def from_gcp_node(node: "GCPNode"): - """Return GCPNodeType based on ``node``'s class""" - if isinstance(node, GCPTPUNode): - return GCPNodeType.TPU - if isinstance(node, GCPComputeNode): - return GCPNodeType.COMPUTE - raise TypeError(f"Wrong GCPNode type {type(node)}.") - - @staticmethod - def name_to_type(name: str): - """Provided a node name, determine the type. - - This expects the name to be in format '[NAME]-[UUID]-[TYPE]', - where [TYPE] is either 'compute' or 'tpu'. 
- """ - return GCPNodeType(name.split("-")[-1]) - - -class GCPNode(UserDict, metaclass=abc.ABCMeta): - """Abstraction around compute and tpu nodes""" - - NON_TERMINATED_STATUSES = None - RUNNING_STATUSES = None - STOPPED_STATUSES = None - STOPPING_STATUSES = None - STATUS_FIELD = None - - def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None: - super().__init__(base_dict, **kwargs) - self.resource = resource - assert isinstance(self.resource, GCPResource) - - def is_running(self) -> bool: - return self.get(self.STATUS_FIELD) in self.RUNNING_STATUSES - - def is_terminated(self) -> bool: - return self.get(self.STATUS_FIELD) not in self.NON_TERMINATED_STATUSES - - def is_stopped(self) -> bool: - return self.get(self.STATUS_FIELD) in self.STOPPED_STATUSES - - def is_stopping(self) -> bool: - return self.get(self.STATUS_FIELD) in self.STOPPING_STATUSES - - @property - def id(self) -> str: - return self["name"] - - @abc.abstractmethod - def get_labels(self) -> dict: - return - - @abc.abstractmethod - def get_external_ip(self) -> str: - return - - @abc.abstractmethod - def get_internal_ip(self) -> str: - return - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}: {self.get('name')}>" - - -class GCPComputeNode(GCPNode): - """Abstraction around compute nodes""" - - # https://cloud.google.com/compute/docs/instances/instance-life-cycle - NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"} - RUNNING_STATUSES = {"RUNNING"} - STOPPED_STATUSES = {"TERMINATED"} - STOPPING_STATUSES = {"STOPPING"} - STATUS_FIELD = "status" - - def get_labels(self) -> dict: - return self.get("labels", {}) - - def get_external_ip(self) -> str: - return ( - self.get("networkInterfaces", [{}])[0] - .get("accessConfigs", [{}])[0] - .get("natIP", None) - ) - - def get_internal_ip(self) -> str: - return self.get("networkInterfaces", [{}])[0].get("networkIP") - - -class GCPTPUNode(GCPNode): - """Abstraction around tpu nodes""" - - # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State - - NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"} - RUNNING_STATUSES = {"READY"} - STOPPED_STATUSES = {"STOPPED"} - STOPPING_STATUSES = {"STOPPING"} - STATUS_FIELD = "state" - - # SKY: get status of TPU VM for status filtering - def get_status(self) -> str: - return self.get(self.STATUS_FIELD) - - def get_labels(self) -> dict: - return self.get("labels", {}) - - def get_external_ip(self) -> str: - return ( - self.get("networkEndpoints", [{}])[0] - .get("accessConfig", {}) - .get("externalIp", None) - ) - - def get_internal_ip(self) -> str: - return self.get("networkEndpoints", [{}])[0].get("ipAddress", None) - - -class GCPResource(metaclass=abc.ABCMeta): - """Abstraction around compute and TPU resources""" - - def __init__( - self, - resource: Resource, - project_id: str, - availability_zone: str, - cluster_name: str, - ) -> None: - self.resource = resource - self.project_id = project_id - self.availability_zone = availability_zone - self.cluster_name = cluster_name - - @abc.abstractmethod - def wait_for_operation( - self, - operation: dict, - max_polls: int = MAX_POLLS, - poll_interval: int = POLL_INTERVAL, - ) -> dict: - """Waits a preset amount of time for operation to complete.""" - return None - - @abc.abstractmethod - def list_instances(self, label_filters: Optional[dict] = None) -> List["GCPNode"]: - """Returns a filtered list of all instances. 
- - The filter removes all terminated instances and, if ``label_filters`` - are provided, all instances which labels are not matching the - ones provided. - """ - return - - @abc.abstractmethod - def get_instance(self, node_id: str) -> "GCPNode": - """Returns a single instance.""" - return - - @abc.abstractmethod - def set_labels( - self, node: GCPNode, labels: dict, wait_for_operation: bool = True - ) -> dict: - """Sets labels on an instance and returns result. - - Completely replaces the labels dictionary.""" - return - - @abc.abstractmethod - def create_instance( - self, base_config: dict, labels: dict, wait_for_operation: bool = True - ) -> Tuple[dict, str]: - """Creates a single instance and returns result. - - Returns a tuple of (result, node_name). - """ - return - - @abc.abstractmethod - def resize_disk( - self, base_config: dict, instance_name: str, wait_for_operation: bool = True - ) -> dict: - """Resize a Google Cloud disk based on the provided configuration. - - Returns the response of resize operation. - """ - return - - def create_instances( - self, - base_config: dict, - labels: dict, - count: int, - wait_for_operation: bool = True, - ) -> List[Tuple[dict, str]]: - """Creates multiple instances and returns result. - - Returns a list of tuples of (result, node_name). - """ - operations = [ - self.create_instance(base_config, labels, wait_for_operation=False) - for i in range(count) - ] - - if wait_for_operation: - results = [ - (self.wait_for_operation(operation), node_name) - for operation, node_name in operations - ] - else: - results = operations - - return results - - @abc.abstractmethod - def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - """Start an instance and return result.""" - - @abc.abstractmethod - def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - """Stop an instance and return result.""" - - @abc.abstractmethod - def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - """Deletes an instance and returns result.""" - return - - -class GCPCompute(GCPResource): - """Abstraction around GCP compute resource""" - - def wait_for_operation( - self, - operation: dict, - max_polls: int = MAX_POLLS, - poll_interval: int = POLL_INTERVAL, - ) -> dict: - """Poll for compute zone operation until finished.""" - logger.info( - "wait_for_compute_zone_operation: " - f"Waiting for operation {operation['name']} to finish..." - ) - - for _ in range(max_polls): - result = ( - self.resource.zoneOperations() - .get( - project=self.project_id, - operation=operation["name"], - zone=self.availability_zone, - ) - .execute() - ) - if "error" in result: - raise Exception(result["error"]) - - if result["status"] == "DONE": - logger.info( - "wait_for_compute_zone_operation: " - f"Operation {operation['name']} finished." 
- ) - break - - time.sleep(poll_interval) - - return result - - def list_instances( - self, - label_filters: Optional[dict] = None, - ) -> List[GCPComputeNode]: - non_terminated_status = list(GCPComputeNode.NON_TERMINATED_STATUSES) - return self._list_instances(label_filters, non_terminated_status) - - def _list_instances( - self, label_filters: Optional[dict], status_filter: Optional[List[str]] - ) -> List[GCPComputeNode]: - label_filters = label_filters or {} - - if label_filters: - label_filter_expr = ( - "(" - + " AND ".join( - [ - "(labels.{key} = {value})".format(key=key, value=value) - for key, value in label_filters.items() - ] - ) - + ")" - ) - else: - label_filter_expr = "" - - if status_filter: - instance_state_filter_expr = ( - "(" - + " OR ".join( - [ - "(status = {status})".format(status=status) - for status in status_filter - ] - ) - + ")" - ) - else: - instance_state_filter_expr = "" - - cluster_name_filter_expr = "(labels.{key} = {value})".format( - key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name - ) - - not_empty_filters = [ - f - for f in [ - label_filter_expr, - instance_state_filter_expr, - cluster_name_filter_expr, - ] - if f - ] - - filter_expr = " AND ".join(not_empty_filters) - - response = ( - self.resource.instances() - .list( - project=self.project_id, - zone=self.availability_zone, - filter=filter_expr, - ) - .execute() - ) - - instances = response.get("items", []) - return [GCPComputeNode(i, self) for i in instances] - - def get_instance(self, node_id: str) -> GCPComputeNode: - instance = ( - self.resource.instances() - .get( - project=self.project_id, - zone=self.availability_zone, - instance=node_id, - ) - .execute() - ) - - return GCPComputeNode(instance, self) - - def set_labels( - self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True - ) -> dict: - body = { - "labels": dict(node["labels"], **labels), - "labelFingerprint": node["labelFingerprint"], - } - node_id = node["name"] - operation = ( - self.resource.instances() - .setLabels( - project=self.project_id, - zone=self.availability_zone, - instance=node_id, - body=body, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - def _convert_resources_to_urls( - self, configuration_dict: Dict[str, Any] - ) -> Dict[str, Any]: - """Ensures that resources are in their full URL form. - - GCP expects machineType and accleratorType to be a full URL (e.g. - `zones/us-west1/machineTypes/n1-standard-2`) instead of just the - type (`n1-standard-2`) - - Args: - configuration_dict: Dict of options that will be passed to GCP - Returns: - Input dictionary, but with possibly expanding `machineType` and - `acceleratorType`. 
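A standalone sketch (not part of this diff) of how the label/status filter string for the compute instances().list() call is composed above; the function name and the label keys in the usage line are hypothetical examples:

def build_list_filter(label_filters: dict, status_filter: list,
                      cluster_label_key: str, cluster_name: str) -> str:
    clauses = []
    if label_filters:
        clauses.append("(" + " AND ".join(
            f"(labels.{k} = {v})" for k, v in label_filters.items()) + ")")
    if status_filter:
        clauses.append("(" + " OR ".join(
            f"(status = {s})" for s in status_filter) + ")")
    # The cluster-name label is always appended so only this cluster's
    # instances are returned.
    clauses.append(f"(labels.{cluster_label_key} = {cluster_name})")
    return " AND ".join(clauses)

print(build_list_filter({"ray-node-type": "worker"}, ["RUNNING", "STAGING"],
                        "ray-cluster-name", "sky-demo"))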
- """ - configuration_dict = deepcopy(configuration_dict) - existing_machine_type = configuration_dict["machineType"] - if not re.search(".*/machineTypes/.*", existing_machine_type): - configuration_dict[ - "machineType" - ] = "zones/{zone}/machineTypes/{machine_type}".format( - zone=self.availability_zone, - machine_type=configuration_dict["machineType"], - ) - - for accelerator in configuration_dict.get("guestAccelerators", []): - gpu_type = accelerator["acceleratorType"] - if not re.search(".*/acceleratorTypes/.*", gpu_type): - accelerator[ - "acceleratorType" - ] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501 - project=self.project_id, - zone=self.availability_zone, - accelerator=gpu_type, - ) - - return configuration_dict - - def create_instance( - self, base_config: dict, labels: dict, wait_for_operation: bool = True - ) -> Tuple[dict, str]: - config = self._convert_resources_to_urls(base_config) - # removing TPU-specific default key set in config.py - config.pop("networkConfig", None) - name = _generate_node_name(labels, GCPNodeType.COMPUTE.value) - - labels = dict(config.get("labels", {}), **labels) - - config.update( - { - "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), - "name": name, - } - ) - - # Allow Google Compute Engine instance templates. - # - # Config example: - # - # ... - # node_config: - # sourceInstanceTemplate: global/instanceTemplates/worker-16 - # machineType: e2-standard-16 - # ... - # - # node_config parameters override matching template parameters, if any. - # - # https://cloud.google.com/compute/docs/instance-templates - # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert - source_instance_template = config.pop("sourceInstanceTemplate", None) - - operation = ( - self.resource.instances() - .insert( - project=self.project_id, - zone=self.availability_zone, - sourceInstanceTemplate=source_instance_template, - body=config, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result, name - - def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.instances() - .start( - project=self.project_id, - zone=self.availability_zone, - instance=node_id, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.instances() - .stop( - project=self.project_id, - zone=self.availability_zone, - instance=node_id, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.instances() - .delete( - project=self.project_id, - zone=self.availability_zone, - instance=node_id, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - def resize_disk( - self, base_config: dict, instance_name: str, wait_for_operation: bool = True - ) -> dict: - """Resize a Google Cloud disk based on the provided configuration.""" - - # Extract the specified disk size from the configuration - new_size_gb = base_config["disks"][0]["initializeParams"]["diskSizeGb"] - - # Fetch the instance details to get the 
disk name and current disk size - response = ( - self.resource.instances() - .get( - project=self.project_id, - zone=self.availability_zone, - instance=instance_name, - ) - .execute() - ) - disk_name = response["disks"][0]["source"].split("/")[-1] - - try: - # Execute the resize request and return the response - operation = ( - self.resource.disks() - .resize( - project=self.project_id, - zone=self.availability_zone, - disk=disk_name, - body={ - "sizeGb": str(new_size_gb), - }, - ) - .execute() - ) - except HttpError as e: - # Catch HttpError when provided with invalid value for new disk size. - # Allowing users to create instances with the same size as the image - logger.warning(f"googleapiclient.errors.HttpError: {e.reason}") - return {} - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - -class GCPTPU(GCPResource): - """Abstraction around GCP TPU resource""" - - # node names already contain the path, but this is required for `parent` - # arguments - @property - def path(self): - return f"projects/{self.project_id}/locations/{self.availability_zone}" - - def wait_for_operation( - self, - operation: dict, - max_polls: int = MAX_POLLS_TPU, - poll_interval: int = POLL_INTERVAL, - ) -> dict: - """Poll for TPU operation until finished.""" - logger.info( - "wait_for_tpu_operation: " - f"Waiting for operation {operation['name']} to finish..." - ) - - for _ in range(max_polls): - result = ( - self.resource.projects() - .locations() - .operations() - .get(name=f"{operation['name']}") - .execute() - ) - if "error" in result: - raise Exception(result["error"]) - - if "response" in result: - logger.info( - "wait_for_tpu_operation: " - f"Operation {operation['name']} finished." - ) - break - - time.sleep(poll_interval) - - return result - - def list_instances(self, label_filters: Optional[dict] = None) -> List[GCPTPUNode]: - non_terminated_status = list(GCPTPUNode.NON_TERMINATED_STATUSES) - return self._list_instances(label_filters, non_terminated_status) - - def _list_instances( - self, label_filters: Optional[dict], status_filter: Optional[List[str]] - ) -> List[GCPTPUNode]: - try: - response = ( - self.resource.projects() - .locations() - .nodes() - .list(parent=self.path) - .execute() - ) - except HttpError as e: - # SKY: Catch HttpError when accessing unauthorized region. - # Return empty list instead of raising exception to not break - # ray down. 
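The wait_for_operation helpers above share one bounded-polling pattern (the TPU variant checks for a "response" key rather than status == "DONE"); a minimal standalone sketch (not part of this diff), where fetch is a hypothetical callable returning the current operation dict:

import time

def wait_until_done(fetch, max_polls: int = 12, poll_interval: int = 5) -> dict:
    result = {}
    for _ in range(max_polls):
        result = fetch()
        if "error" in result:
            raise RuntimeError(result["error"])
        if result.get("status") == "DONE":
            break
        time.sleep(poll_interval)
    return result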
- logger.warning(f"googleapiclient.errors.HttpError: {e.reason}") - return [] - - instances = response.get("nodes", []) - instances = [GCPTPUNode(i, self) for i in instances] - - # filter_expr cannot be passed directly to API - # so we need to filter the results ourselves - - # same logic as in GCPCompute.list_instances - label_filters = label_filters or {} - label_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name - - def filter_instance(instance: GCPTPUNode) -> bool: - labels = instance.get_labels() - if label_filters: - for key, value in label_filters.items(): - if key not in labels: - return False - if value != labels[key]: - return False - - if status_filter: - if instance.get_status() not in status_filter: - return False - - return True - - instances = list(filter(filter_instance, instances)) - - return instances - - def get_instance(self, node_id: str) -> GCPTPUNode: - instance = ( - self.resource.projects().locations().nodes().get(name=node_id).execute() - ) - - return GCPTPUNode(instance, self) - - # this sometimes fails without a clear reason, so we retry it - # MAX_POLLS times - @_retry_on_exception(HttpError, "unable to queue the operation") - def set_labels( - self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True - ) -> dict: - body = { - "labels": dict(node["labels"], **labels), - } - update_mask = "labels" - - operation = ( - self.resource.projects() - .locations() - .nodes() - .patch( - name=node["name"], - updateMask=update_mask, - body=body, - ) - .execute() - ) - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result - - def create_instance( - self, base_config: dict, labels: dict, wait_for_operation: bool = True - ) -> Tuple[dict, str]: - config = base_config.copy() - # removing Compute-specific default key set in config.py - config.pop("networkInterfaces", None) - name = _generate_node_name(labels, GCPNodeType.TPU.value) - - labels = dict(config.get("labels", {}), **labels) - - config.update( - { - "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), - } - ) - - if "networkConfig" not in config: - config["networkConfig"] = {} - if "enableExternalIps" not in config["networkConfig"]: - # this is required for SSH to work, per google documentation - # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl - config["networkConfig"]["enableExternalIps"] = True - - try: - operation = ( - self.resource.projects() - .locations() - .nodes() - .create( - parent=self.path, - body=config, - nodeId=name, - ) - .execute() - ) - except HttpError as e: - # SKY: Catch HttpError when accessing unauthorized region. 
- logger.error(f"googleapiclient.errors.HttpError: {e.reason}") - raise e - - if wait_for_operation: - result = self.wait_for_operation(operation) - else: - result = operation - - return result, name - - def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.projects().locations().nodes().start(name=node_id).execute() - ) - - # No need to increase MAX_POLLS for deletion - if wait_for_operation: - result = self.wait_for_operation(operation, max_polls=MAX_POLLS) - else: - result = operation - - return result - - def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.projects().locations().nodes().stop(name=node_id).execute() - ) - - # No need to increase MAX_POLLS for deletion - if wait_for_operation: - result = self.wait_for_operation(operation, max_polls=MAX_POLLS) - else: - result = operation - - return result - - def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: - operation = ( - self.resource.projects().locations().nodes().delete(name=node_id).execute() - ) - - # No need to increase MAX_POLLS for deletion - if wait_for_operation: - result = self.wait_for_operation(operation, max_polls=MAX_POLLS) - else: - result = operation - - return result - - def resize_disk( - self, base_config: dict, instance_name: str, wait_for_operation: bool = True - ) -> dict: - """ - TODO: Implement the feature to attach persistent disks for TPU VMs. - The boot disk of TPU VMs is not resizable, and users need to add a - persistent disk to expand disk capacity. Related issue: #2387 - """ diff --git a/sky/skylet/providers/gcp/node_provider.py b/sky/skylet/providers/gcp/node_provider.py deleted file mode 100644 index 1f0be55a03a..00000000000 --- a/sky/skylet/providers/gcp/node_provider.py +++ /dev/null @@ -1,400 +0,0 @@ -import logging -import time -import copy -from functools import wraps -from threading import RLock -from typing import Dict, List - -import googleapiclient - -from sky.skylet.providers.gcp.config import ( - bootstrap_gcp, - construct_clients_from_provider_config, - get_node_type, -) -from sky.skylet.providers.command_runner import SkyDockerCommandRunner -from sky.provision import docker_utils - -from ray.autoscaler._private.command_runner import SSHCommandRunner -from ray.autoscaler.tags import ( - TAG_RAY_LAUNCH_CONFIG, - TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, -) -from ray.autoscaler._private.cli_logger import cf, cli_logger - - -# The logic has been abstracted away here to allow for different GCP resources -# (API endpoints), which can differ widely, making it impossible to use -# the same logic for everything. -from sky.skylet.providers.gcp.node import ( # noqa - GCPCompute, - GCPNode, - GCPNodeType, - GCPResource, - GCPTPU, - # Added by SkyPilot - INSTANCE_NAME_MAX_LEN, - INSTANCE_NAME_UUID_LEN, - MAX_POLLS_STOP, - POLL_INTERVAL, -) -from ray.autoscaler.node_provider import NodeProvider - -logger = logging.getLogger(__name__) - - -def _retry(method, max_tries=5, backoff_s=1): - """Retry decorator for methods of GCPNodeProvider. - - Upon catching BrokenPipeError, API clients are rebuilt and - decorated methods are retried. - - Work-around for https://github.com/ray-project/ray/issues/16072. - Based on https://github.com/kubeflow/pipelines/pull/5250/files. 
- """ - - @wraps(method) - def method_with_retries(self, *args, **kwargs): - try_count = 0 - while try_count < max_tries: - try: - return method(self, *args, **kwargs) - except BrokenPipeError: - logger.warning("Caught a BrokenPipeError. Retrying.") - try_count += 1 - if try_count < max_tries: - self._construct_clients() - time.sleep(backoff_s) - else: - raise - - return method_with_retries - - -class GCPNodeProvider(NodeProvider): - def __init__(self, provider_config: dict, cluster_name: str): - NodeProvider.__init__(self, provider_config, cluster_name) - self.lock = RLock() - self._construct_clients() - - # Cache of node objects from the last nodes() call. This avoids - # excessive DescribeInstances requests. - self.cached_nodes: Dict[str, GCPNode] = {} - self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True) - - def _construct_clients(self): - _, _, compute, tpu = construct_clients_from_provider_config( - self.provider_config - ) - - # Dict of different resources provided by GCP. - # At this moment - Compute and TPUs - self.resources: Dict[GCPNodeType, GCPResource] = {} - - # Compute is always required - self.resources[GCPNodeType.COMPUTE] = GCPCompute( - compute, - self.provider_config["project_id"], - self.provider_config["availability_zone"], - self.cluster_name, - ) - - # if there are no TPU nodes defined in config, tpu will be None. - if tpu is not None: - self.resources[GCPNodeType.TPU] = GCPTPU( - tpu, - self.provider_config["project_id"], - self.provider_config["availability_zone"], - self.cluster_name, - ) - - def _get_resource_depending_on_node_name(self, node_name: str) -> GCPResource: - """Return the resource responsible for the node, based on node_name. - - This expects the name to be in format '[NAME]-[UUID]-[TYPE]', - where [TYPE] is either 'compute' or 'tpu' (see ``GCPNodeType``). - """ - return self.resources[GCPNodeType.name_to_type(node_name)] - - @_retry - def non_terminated_nodes(self, tag_filters: dict): - with self.lock: - instances = [] - - for resource in self.resources.values(): - node_instances = resource.list_instances(tag_filters) - instances += node_instances - - # Note: All the operations use "name" as the unique instance id - self.cached_nodes = {i["name"]: i for i in instances} - return [i["name"] for i in instances] - - def is_running(self, node_id: str): - with self.lock: - node = self._get_cached_node(node_id) - return node.is_running() - - def is_terminated(self, node_id: str): - with self.lock: - node = self._get_cached_node(node_id) - return node.is_terminated() - - def node_tags(self, node_id: str): - with self.lock: - node = self._get_cached_node(node_id) - return node.get_labels() - - @_retry - def set_node_tags(self, node_id: str, tags: dict): - with self.lock: - labels = tags - node = self._get_node(node_id) - - resource = self._get_resource_depending_on_node_name(node_id) - - result = resource.set_labels(node=node, labels=labels) - - return result - - def external_ip(self, node_id: str): - with self.lock: - node = self._get_cached_node(node_id) - - ip = node.get_external_ip() - if ip is None: - node = self._get_node(node_id) - ip = node.get_external_ip() - - return ip - - def internal_ip(self, node_id: str): - with self.lock: - node = self._get_cached_node(node_id) - - ip = node.get_internal_ip() - if ip is None: - node = self._get_node(node_id) - ip = node.get_internal_ip() - - return ip - - @_retry - def create_node(self, base_config: dict, tags: dict, count: int) -> Dict[str, dict]: - """Creates instances. 
- - Returns dict mapping instance id to each create operation result for the created - instances. - """ - with self.lock: - result_dict = {} - labels = tags # gcp uses "labels" instead of aws "tags" - labels = dict(sorted(copy.deepcopy(labels).items())) - - node_type = get_node_type(base_config) - resource = self.resources[node_type] - - # Try to reuse previously stopped nodes with compatible configs - if self.cache_stopped_nodes: - filters = { - TAG_RAY_NODE_KIND: labels[TAG_RAY_NODE_KIND], - # SkyPilot: removed TAG_RAY_LAUNCH_CONFIG to allow reusing nodes - # with different launch configs. - # Reference: https://github.com/skypilot-org/skypilot/pull/1671 - } - # This tag may not always be present. - if TAG_RAY_USER_NODE_TYPE in labels: - filters[TAG_RAY_USER_NODE_TYPE] = labels[TAG_RAY_USER_NODE_TYPE] - filters_with_launch_config = copy.copy(filters) - filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = labels[ - TAG_RAY_LAUNCH_CONFIG - ] - - # SKY: "TERMINATED" for compute VM, "STOPPED" for TPU VM - # "STOPPING" means the VM is being stopped, which needs - # to be included to avoid creating a new VM. - if isinstance(resource, GCPCompute): - STOPPED_STATUS = ["TERMINATED", "STOPPING"] - else: - STOPPED_STATUS = ["STOPPED", "STOPPING"] - - # SkyPilot: We try to use the instances with the same matching launch_config first. If - # there is not enough instances with matching launch_config, we then use all the - # instances with the same matching launch_config plus some instances with wrong - # launch_config. - def get_order_key(node): - import datetime - - timestamp = node.get("lastStartTimestamp") - if timestamp is not None: - return datetime.datetime.strptime( - timestamp, "%Y-%m-%dT%H:%M:%S.%f%z" - ) - return node.id - - nodes_matching_launch_config = resource._list_instances( - filters_with_launch_config, STOPPED_STATUS - ) - nodes_matching_launch_config.sort( - key=lambda n: get_order_key(n), reverse=True - ) - if len(nodes_matching_launch_config) >= count: - reuse_nodes = nodes_matching_launch_config[:count] - else: - nodes_all = resource._list_instances(filters, STOPPED_STATUS) - nodes_matching_launch_config_ids = set( - n.id for n in nodes_matching_launch_config - ) - nodes_non_matching_launch_config = [ - n - for n in nodes_all - if n.id not in nodes_matching_launch_config_ids - ] - # This is for backward compatibility, where the uesr already has leaked - # stopped nodes with the different launch config before update to #1671, - # and the total number of the leaked nodes is greater than the number of - # nodes to be created. With this, we will make sure we will reuse the - # most recently used nodes. - # This can be removed in the future when we are sure all the users - # have updated to #1671. - nodes_non_matching_launch_config.sort( - key=lambda n: get_order_key(n), reverse=True - ) - reuse_nodes = ( - nodes_matching_launch_config + nodes_non_matching_launch_config - ) - # The total number of reusable nodes can be less than the number of nodes to be created. - # This `[:count]` is fine, as it will get all the reusable nodes, even if there are - # less nodes. - reuse_nodes = reuse_nodes[:count] - - reuse_node_ids = [n.id for n in reuse_nodes] - if reuse_nodes: - # TODO(suquark): Some instances could still be stopping. - # We may wait until these instances stop. - cli_logger.print( - # TODO: handle plural vs singular? - f"Reusing nodes {cli_logger.render_list(reuse_node_ids)}. 
" - "To disable reuse, set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration." - ) - for node_id in reuse_node_ids: - result = resource.start_instance(node_id) - result_dict[node_id] = {node_id: result} - for node_id in reuse_node_ids: - self.set_node_tags(node_id, tags) - count -= len(reuse_node_ids) - if count: - results = resource.create_instances(base_config, labels, count) - if "sourceMachineImage" in base_config: - for _, instance_id in results: - resource.resize_disk(base_config, instance_id) - result_dict.update( - {instance_id: result for result, instance_id in results} - ) - return result_dict - - @_retry - def terminate_node(self, node_id: str): - with self.lock: - result = None - resource = self._get_resource_depending_on_node_name(node_id) - try: - if self.cache_stopped_nodes: - cli_logger.print( - f"Stopping instance {node_id} " - + cf.dimmed( - "(to terminate instead, " - "set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration)" - ), - ) - result = resource.stop_instance(node_id=node_id) - - # Check if the instance is actually stopped. - # GCP does not fully stop an instance even after - # the stop operation is finished. - for _ in range(MAX_POLLS_STOP): - instance = resource.get_instance(node_id=node_id) - if instance.is_stopped(): - logger.info(f"Instance {node_id} is stopped.") - break - elif instance.is_stopping(): - time.sleep(POLL_INTERVAL) - else: - raise RuntimeError( - f"Unexpected instance status." " Details: {instance}" - ) - - if instance.is_stopping(): - raise RuntimeError( - f"Maximum number of polls: " - f"{MAX_POLLS_STOP} reached. " - f"Instance {node_id} is still in " - "STOPPING status." - ) - else: - result = resource.delete_instance( - node_id=node_id, - ) - except googleapiclient.errors.HttpError as http_error: - if http_error.resp.status == 404: - logger.warning( - f"Tried to delete the node with id {node_id} " - "but it was already gone." 
- ) - else: - raise http_error from None - - return result - - @_retry - def _get_node(self, node_id: str) -> GCPNode: - self.non_terminated_nodes({}) # Side effect: updates cache - - with self.lock: - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - - resource = self._get_resource_depending_on_node_name(node_id) - instance = resource.get_instance(node_id=node_id) - - return instance - - def _get_cached_node(self, node_id: str) -> GCPNode: - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - - return self._get_node(node_id) - - @staticmethod - def bootstrap_config(cluster_config): - return bootstrap_gcp(cluster_config) - - def get_command_runner( - self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None, - ): - common_args = { - "log_prefix": log_prefix, - "node_id": node_id, - "provider": self, - "auth_config": auth_config, - "cluster_name": cluster_name, - "process_runner": process_runner, - "use_internal_ip": use_internal_ip, - } - if docker_config and docker_config["container_name"] != "": - if "docker_login_config" in self.provider_config: - docker_config["docker_login_config"] = docker_utils.DockerLoginConfig( - **self.provider_config["docker_login_config"] - ) - return SkyDockerCommandRunner(docker_config, **common_args) - else: - return SSHCommandRunner(**common_args) diff --git a/sky/skylet/providers/kubernetes/__init__.py b/sky/skylet/providers/kubernetes/__init__.py deleted file mode 100644 index 0bb7afaea81..00000000000 --- a/sky/skylet/providers/kubernetes/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from sky.skylet.providers.kubernetes.node_provider import KubernetesNodeProvider diff --git a/sky/skylet/providers/kubernetes/config.py b/sky/skylet/providers/kubernetes/config.py deleted file mode 100644 index d4f021cb29e..00000000000 --- a/sky/skylet/providers/kubernetes/config.py +++ /dev/null @@ -1,337 +0,0 @@ -import copy -import logging -import math -from typing import Any, Dict, Union - -from sky.adaptors import kubernetes -from sky.provision.kubernetes import utils as kubernetes_utils - -logger = logging.getLogger(__name__) - -log_prefix = 'KubernetesNodeProvider: ' - -# Timeout for deleting a Kubernetes resource (in seconds). -DELETION_TIMEOUT = 90 - - -class InvalidNamespaceError(ValueError): - - def __init__(self, field_name: str, namespace: str): - self.message = ( - f'Namespace of {field_name} config does not match provided ' - f'namespace "{namespace}". 
Either set it to {namespace} or remove the ' - 'field') - - def __str__(self) -> str: - return self.message - - -def using_existing_msg(resource_type: str, name: str) -> str: - return f'using existing {resource_type} "{name}"' - - -def updating_existing_msg(resource_type: str, name: str) -> str: - return f'updating existing {resource_type} "{name}"' - - -def not_found_msg(resource_type: str, name: str) -> str: - return f'{resource_type} "{name}" not found, attempting to create it' - - -def not_checking_msg(resource_type: str, name: str) -> str: - return f'not checking if {resource_type} "{name}" exists' - - -def created_msg(resource_type: str, name: str) -> str: - return f'successfully created {resource_type} "{name}"' - - -def not_provided_msg(resource_type: str) -> str: - return f'no {resource_type} config provided, must already exist' - - -def bootstrap_kubernetes(config: Dict[str, Any]) -> Dict[str, Any]: - namespace = kubernetes_utils.get_current_kube_config_context_namespace() - - _configure_services(namespace, config['provider']) - - config = _configure_ssh_jump(namespace, config) - - if not config['provider'].get('_operator'): - # These steps are unecessary when using the Operator. - _configure_autoscaler_service_account(namespace, config['provider']) - _configure_autoscaler_role(namespace, config['provider']) - _configure_autoscaler_role_binding(namespace, config['provider']) - - return config - - -def fillout_resources_kubernetes(config: Dict[str, Any]) -> Dict[str, Any]: - """Fills CPU and GPU resources in the ray cluster config. - - For each node type and each of CPU/GPU, looks at container's resources - and limits, takes min of the two. - """ - if 'available_node_types' not in config: - return config - node_types = copy.deepcopy(config['available_node_types']) - head_node_type = config['head_node_type'] - for node_type in node_types: - - node_config = node_types[node_type]['node_config'] - # The next line is for compatibility with configs which define pod specs - # cf. KubernetesNodeProvider.create_node(). 
- pod = node_config.get('pod', node_config) - container_data = pod['spec']['containers'][0] - - autodetected_resources = get_autodetected_resources(container_data) - if node_types == head_node_type: - # we only autodetect worker type node memory resource - autodetected_resources.pop('memory') - if 'resources' not in config['available_node_types'][node_type]: - config['available_node_types'][node_type]['resources'] = {} - autodetected_resources.update( - config['available_node_types'][node_type]['resources']) - config['available_node_types'][node_type][ - 'resources'] = autodetected_resources - logger.debug(f'Updating the resources of node type {node_type} ' - f'to include {autodetected_resources}.') - return config - - -def get_autodetected_resources( - container_data: Dict[str, Any]) -> Dict[str, Any]: - container_resources = container_data.get('resources', None) - if container_resources is None: - return {'CPU': 0, 'GPU': 0} - - node_type_resources = { - resource_name.upper(): get_resource(container_resources, resource_name) - for resource_name in ['cpu', 'gpu'] - } - - memory_limits = get_resource(container_resources, 'memory') - node_type_resources['memory'] = memory_limits - - return node_type_resources - - -def get_resource(container_resources: Dict[str, Any], - resource_name: str) -> int: - request = _get_resource(container_resources, - resource_name, - field_name='requests') - limit = _get_resource(container_resources, - resource_name, - field_name='limits') - # Use request if limit is not set, else use limit. - # float('inf') means there's no limit set - res_count = request if limit == float('inf') else limit - # Convert to int since Ray autoscaler expects int. - # We also round up the resource count to the nearest integer to provide the - # user at least the amount of resource they requested. - rounded_count = math.ceil(res_count) - if resource_name == 'cpu': - # For CPU, we set minimum count to 1 because if CPU count is set to 0, - # (e.g. when the user sets --cpu 0.5), ray will not be able to schedule - # any tasks. - return max(1, rounded_count) - else: - # For GPU and memory, return the rounded count. - return rounded_count - - -def _get_resource(container_resources: Dict[str, Any], resource_name: str, - field_name: str) -> Union[int, float]: - """Returns the resource quantity. - - The amount of resource is rounded up to nearest integer. - Returns float("inf") if the resource is not present. - - Args: - container_resources: Container's resource field. - resource_name: One of 'cpu', 'gpu' or 'memory'. - field_name: One of 'requests' or 'limits'. - - Returns: - Union[int, float]: Detected resource quantity. - """ - if field_name not in container_resources: - # No limit/resource field. - return float('inf') - resources = container_resources[field_name] - # Look for keys containing the resource_name. For example, - # the key 'nvidia.com/gpu' contains the key 'gpu'. - matching_keys = [key for key in resources if resource_name in key.lower()] - if len(matching_keys) == 0: - return float('inf') - if len(matching_keys) > 1: - # Should have only one match -- mostly relevant for gpu. - raise ValueError(f'Multiple {resource_name} types not supported.') - # E.g. 'nvidia.com/gpu' or 'cpu'. 
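A simplified standalone sketch (not part of this diff) of the request/limit handling above: prefer the limit, fall back to the request, round up, and never report zero CPUs. Quantity suffixes such as "500m" or "4Gi" are parsed by helpers in the real code and are out of scope here; detect_count is a hypothetical name:

import math

def detect_count(container_resources: dict, name: str) -> int:
    def pick(field: str) -> float:
        resources = container_resources.get(field, {})
        matches = [v for k, v in resources.items() if name in k.lower()]
        # float('inf') stands for "not specified".
        return float(matches[0]) if matches else float("inf")

    request, limit = pick("requests"), pick("limits")
    count = request if limit == float("inf") else limit
    if count == float("inf"):
        return 0
    count = math.ceil(count)
    # A CPU count of 0 would prevent Ray from scheduling any task.
    return max(1, count) if name == "cpu" else count

print(detect_count({"requests": {"cpu": "3.5"}}, "cpu"))  # 4
print(detect_count({"limits": {"nvidia.com/gpu": "1"}}, "gpu"))  # 1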
- resource_key = matching_keys.pop() - resource_quantity = resources[resource_key] - if resource_name == 'memory': - return kubernetes_utils.parse_memory_resource(resource_quantity) - else: - return kubernetes_utils.parse_cpu_or_gpu_resource(resource_quantity) - - -def _configure_autoscaler_service_account( - namespace: str, provider_config: Dict[str, Any]) -> None: - account_field = 'autoscaler_service_account' - if account_field not in provider_config: - logger.info(log_prefix + not_provided_msg(account_field)) - return - - account = provider_config[account_field] - if 'namespace' not in account['metadata']: - account['metadata']['namespace'] = namespace - elif account['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(account_field, namespace) - - name = account['metadata']['name'] - field_selector = f'metadata.name={name}' - accounts = (kubernetes.core_api().list_namespaced_service_account( - namespace, field_selector=field_selector).items) - if len(accounts) > 0: - assert len(accounts) == 1 - logger.info(log_prefix + using_existing_msg(account_field, name)) - return - - logger.info(log_prefix + not_found_msg(account_field, name)) - kubernetes.core_api().create_namespaced_service_account(namespace, account) - logger.info(log_prefix + created_msg(account_field, name)) - - -def _configure_autoscaler_role(namespace: str, - provider_config: Dict[str, Any]) -> None: - role_field = 'autoscaler_role' - if role_field not in provider_config: - logger.info(log_prefix + not_provided_msg(role_field)) - return - - role = provider_config[role_field] - if 'namespace' not in role['metadata']: - role['metadata']['namespace'] = namespace - elif role['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(role_field, namespace) - - name = role['metadata']['name'] - field_selector = f'metadata.name={name}' - accounts = (kubernetes.auth_api().list_namespaced_role( - namespace, field_selector=field_selector).items) - if len(accounts) > 0: - assert len(accounts) == 1 - logger.info(log_prefix + using_existing_msg(role_field, name)) - return - - logger.info(log_prefix + not_found_msg(role_field, name)) - kubernetes.auth_api().create_namespaced_role(namespace, role) - logger.info(log_prefix + created_msg(role_field, name)) - - -def _configure_autoscaler_role_binding(namespace: str, - provider_config: Dict[str, Any]) -> None: - binding_field = 'autoscaler_role_binding' - if binding_field not in provider_config: - logger.info(log_prefix + not_provided_msg(binding_field)) - return - - binding = provider_config[binding_field] - if 'namespace' not in binding['metadata']: - binding['metadata']['namespace'] = namespace - elif binding['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(binding_field, namespace) - for subject in binding['subjects']: - if 'namespace' not in subject: - subject['namespace'] = namespace - elif subject['namespace'] != namespace: - subject_name = subject['name'] - raise InvalidNamespaceError( - binding_field + f' subject {subject_name}', namespace) - - name = binding['metadata']['name'] - field_selector = f'metadata.name={name}' - accounts = (kubernetes.auth_api().list_namespaced_role_binding( - namespace, field_selector=field_selector).items) - if len(accounts) > 0: - assert len(accounts) == 1 - logger.info(log_prefix + using_existing_msg(binding_field, name)) - return - - logger.info(log_prefix + not_found_msg(binding_field, name)) - kubernetes.auth_api().create_namespaced_role_binding(namespace, binding) - logger.info(log_prefix + 
created_msg(binding_field, name)) - - -def _configure_ssh_jump(namespace, config): - """Creates a SSH jump pod to connect to the cluster. - - Also updates config['auth']['ssh_proxy_command'] to use the newly created - jump pod. - """ - pod_cfg = config['available_node_types']['ray_head_default']['node_config'] - - ssh_jump_name = pod_cfg['metadata']['labels']['skypilot-ssh-jump'] - ssh_jump_image = config['provider']['ssh_jump_image'] - - volumes = pod_cfg['spec']['volumes'] - # find 'secret-volume' and get the secret name - secret_volume = next(filter(lambda x: x['name'] == 'secret-volume', - volumes)) - ssh_key_secret_name = secret_volume['secret']['secretName'] - - # TODO(romilb): We currently split SSH jump pod and svc creation. Service - # is first created in authentication.py::setup_kubernetes_authentication - # and then SSH jump pod creation happens here. This is because we need to - # set the ssh_proxy_command in the ray YAML before we pass it to the - # autoscaler. If in the future if we can write the ssh_proxy_command to the - # cluster yaml through this method, then we should move the service - # creation here. - - # TODO(romilb): We should add a check here to make sure the service is up - # and available before we create the SSH jump pod. If for any reason the - # service is missing, we should raise an error. - - kubernetes_utils.setup_ssh_jump_pod(ssh_jump_name, ssh_jump_image, - ssh_key_secret_name, namespace) - return config - - -def _configure_services(namespace: str, provider_config: Dict[str, - Any]) -> None: - service_field = 'services' - if service_field not in provider_config: - logger.info(log_prefix + not_provided_msg(service_field)) - return - - services = provider_config[service_field] - for service in services: - if 'namespace' not in service['metadata']: - service['metadata']['namespace'] = namespace - elif service['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(service_field, namespace) - - name = service['metadata']['name'] - field_selector = f'metadata.name={name}' - services = (kubernetes.core_api().list_namespaced_service( - namespace, field_selector=field_selector).items) - if len(services) > 0: - assert len(services) == 1 - existing_service = services[0] - if service == existing_service: - logger.info(log_prefix + using_existing_msg('service', name)) - return - else: - logger.info(log_prefix + updating_existing_msg('service', name)) - kubernetes.core_api().patch_namespaced_service( - name, namespace, service) - else: - logger.info(log_prefix + not_found_msg('service', name)) - kubernetes.core_api().create_namespaced_service(namespace, service) - logger.info(log_prefix + created_msg('service', name)) - - -class KubernetesError(Exception): - pass diff --git a/sky/skylet/providers/kubernetes/node_provider.py b/sky/skylet/providers/kubernetes/node_provider.py deleted file mode 100644 index 3648b2bec4d..00000000000 --- a/sky/skylet/providers/kubernetes/node_provider.py +++ /dev/null @@ -1,657 +0,0 @@ -import copy -import logging -import os -import re -import time -from typing import Dict -from uuid import uuid4 - -from ray.autoscaler._private.command_runner import SSHCommandRunner -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import NODE_KIND_HEAD -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME -from ray.autoscaler.tags import TAG_RAY_NODE_KIND - -from sky import exceptions -from sky.adaptors import kubernetes -from sky.backends import backend_utils -from sky.provision.kubernetes import utils as 
kubernetes_utils -from sky.skylet import constants -from sky.skylet.providers.kubernetes import config -from sky.utils import cluster_yaml_utils -from sky.utils import common_utils - -logger = logging.getLogger(__name__) - -MAX_TAG_RETRIES = 3 -DELAY_BEFORE_TAG_RETRY = 0.5 -UPTIME_SSH_TIMEOUT = 10 - -RAY_COMPONENT_LABEL = 'cluster.ray.io/component' - - -# Monkey patch SSHCommandRunner to allow specifying SSH port -def set_port(self, port): - self.ssh_options.arg_dict['Port'] = port - - -SSHCommandRunner.set_port = set_port - -# Monkey patch SSHCommandRunner to use a larger timeout when running uptime to -# check cluster liveness. This is needed because the default timeout of 5s is -# too short when the cluster is accessed from different geographical -# locations over VPN. -# -# Ray autoscaler sets the timeout on a per-call basis (as an arg to -# SSHCommandRunner.run). The 5s timeout is hardcoded in -# NodeUpdater.wait_ready() in updater.py is hard to modify without -# duplicating a large chunk of ray autoscaler code. Instead, we -# monkey patch the run method to check if the command being run is 'uptime', -# and if so change the timeout to 10s. -# -# Fortunately, Ray uses a timeout of 120s for running commands after the -# cluster is ready, so we do not need to modify that. - - -def run_override_timeout(*args, **kwargs): - # If command is `uptime`, change timeout to 10s - command = args[1] - if command == 'uptime': - kwargs['timeout'] = UPTIME_SSH_TIMEOUT - return SSHCommandRunner._run(*args, **kwargs) - - -SSHCommandRunner._run = SSHCommandRunner.run -SSHCommandRunner.run = run_override_timeout - - -def head_service_selector(cluster_name: str) -> Dict[str, str]: - """Selector for Operator-configured head service.""" - return {RAY_COMPONENT_LABEL: f'{cluster_name}-head'} - - -def to_label_selector(tags): - label_selector = '' - for k, v in tags.items(): - if label_selector != '': - label_selector += ',' - label_selector += '{}={}'.format(k, v) - return label_selector - - -def run_command_on_pods(node_name, node_namespace, command): - cmd_output = kubernetes.stream()( - kubernetes.core_api().connect_get_namespaced_pod_exec, - node_name, - node_namespace, - command=command, - stderr=True, - stdin=False, - stdout=True, - tty=False, - _request_timeout=kubernetes.API_TIMEOUT) - return cmd_output - - -class KubernetesNodeProvider(NodeProvider): - - def __init__(self, provider_config, cluster_name): - NodeProvider.__init__(self, provider_config, cluster_name) - self.cluster_name = cluster_name - - # Kubernetes namespace to user - self.namespace = provider_config.get( - 'namespace', - kubernetes_utils.get_current_kube_config_context_namespace()) - - # Timeout for resource provisioning. If it takes longer than this - # timeout, the resource provisioning will be considered failed. - # This is useful for failover. May need to be adjusted for different - # kubernetes setups. - self.timeout = provider_config['timeout'] - - def non_terminated_nodes(self, tag_filters): - # Match pods that are in the 'Pending' or 'Running' phase. - # Unfortunately there is no OR operator in field selectors, so we - # have to match on NOT any of the other phases. 
- field_selector = ','.join([ - 'status.phase!=Failed', - 'status.phase!=Unknown', - 'status.phase!=Succeeded', - 'status.phase!=Terminating', - ]) - - tag_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name - label_selector = to_label_selector(tag_filters) - pod_list = kubernetes.core_api().list_namespaced_pod( - self.namespace, - field_selector=field_selector, - label_selector=label_selector) - - # Don't return pods marked for deletion, - # i.e. pods with non-null metadata.DeletionTimestamp. - return [ - pod.metadata.name - for pod in pod_list.items - if pod.metadata.deletion_timestamp is None - ] - - def is_running(self, node_id): - pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) - return pod.status.phase == 'Running' - - def is_terminated(self, node_id): - pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) - return pod.status.phase not in ['Running', 'Pending'] - - def node_tags(self, node_id): - pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) - return pod.metadata.labels - - def external_ip(self, node_id): - return kubernetes_utils.get_external_ip() - - def external_port(self, node_id): - # Extract the NodePort of the head node's SSH service - # Node id is str e.g., example-cluster-head-v89lb - - # TODO(romilb): Implement caching here for performance. - # TODO(romilb): Multi-node would need more handling here. - cluster_name = node_id.split('-head')[0] - return kubernetes_utils.get_head_ssh_port(cluster_name, self.namespace) - - def internal_ip(self, node_id): - pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) - return pod.status.pod_ip - - def get_node_id(self, ip_address, use_internal_ip=True) -> str: - - def find_node_id(): - if use_internal_ip: - return self._internal_ip_cache.get(ip_address) - else: - return self._external_ip_cache.get(ip_address) - - if not find_node_id(): - all_nodes = self.non_terminated_nodes({}) - ip_func = self.internal_ip if use_internal_ip else self.external_ip - ip_cache = (self._internal_ip_cache - if use_internal_ip else self._external_ip_cache) - for node_id in all_nodes: - ip_cache[ip_func(node_id)] = node_id - - if not find_node_id(): - if use_internal_ip: - known_msg = f'Worker internal IPs: {list(self._internal_ip_cache)}' - else: - known_msg = f'Worker external IP: {list(self._external_ip_cache)}' - raise ValueError(f'ip {ip_address} not found. ' + known_msg) - - return find_node_id() - - def set_node_tags(self, node_ids, tags): - for _ in range(MAX_TAG_RETRIES - 1): - try: - self._set_node_tags(node_ids, tags) - return - except kubernetes.api_exception() as e: - if e.status == 409: - logger.info(config.log_prefix + - 'Caught a 409 error while setting' - ' node tags. Retrying...') - time.sleep(DELAY_BEFORE_TAG_RETRY) - continue - else: - raise - # One more try - self._set_node_tags(node_ids, tags) - - def _recover_cluster_yaml_path(self, cluster_name_with_hash: str) -> str: - # 'cluster_name_with_hash' combines the cluster name and hash value, - # separated by a hyphen. By using 'slice_length', we remove the hash - # (and its preceding hyphen) to retrieve the original cluster name. - slice_length = -(common_utils.USER_HASH_LENGTH_IN_CLUSTER_NAME + 1) - cluster_name = cluster_name_with_hash[:slice_length] - cluster_yaml_path = (os.path.join( - os.path.expanduser(backend_utils.SKY_USER_FILE_PATH), - f'{cluster_name}.yml')) - # Check if cluster_yaml_path exists. 
If not, we are running on - # the master node in a multi-node setup, in which case we must use the - # default ~/.sky/sky_ray.yml path. - if not os.path.exists(cluster_yaml_path): - cluster_yaml_path = os.path.expanduser( - cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH) - return cluster_yaml_path - - def _set_node_tags(self, node_id, tags): - pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) - pod.metadata.labels.update(tags) - kubernetes.core_api().patch_namespaced_pod(node_id, self.namespace, pod) - - def _raise_pod_scheduling_errors(self, new_nodes): - """Raise pod scheduling failure reason. - - When a pod fails to schedule in Kubernetes, the reasons for the failure - are recorded as events. This function retrieves those events and raises - descriptive errors for better debugging and user feedback. - """ - for new_node in new_nodes: - pod = kubernetes.core_api().read_namespaced_pod( - new_node.metadata.name, self.namespace) - pod_status = pod.status.phase - # When there are multiple pods involved while launching instance, - # there may be a single pod causing issue while others are - # successfully scheduled. In this case, we make sure to not surface - # the error message from the pod that is already scheduled. - if pod_status != 'Pending': - continue - pod_name = pod._metadata._name - events = kubernetes.core_api().list_namespaced_event( - self.namespace, - field_selector=(f'involvedObject.name={pod_name},' - 'involvedObject.kind=Pod')) - # Events created in the past hours are kept by - # Kubernetes python client and we want to surface - # the latest event message - events_desc_by_time = sorted( - events.items, - key=lambda e: e.metadata.creation_timestamp, - reverse=True) - - event_message = None - for event in events_desc_by_time: - if event.reason == 'FailedScheduling': - event_message = event.message - break - timeout_err_msg = ('Timed out while waiting for nodes to start. ' - 'Cluster may be out of resources or ' - 'may be too slow to autoscale.') - lack_resource_msg = ( - 'Insufficient {resource} capacity on the cluster. ' - 'Other SkyPilot tasks or pods may be using resources. ' - 'Check resource usage by running `kubectl describe nodes`.') - if event_message is not None: - if pod_status == 'Pending': - if 'Insufficient cpu' in event_message: - raise config.KubernetesError( - lack_resource_msg.format(resource='CPU')) - if 'Insufficient memory' in event_message: - raise config.KubernetesError( - lack_resource_msg.format(resource='memory')) - gpu_lf_keys = [ - lf.get_label_key() - for lf in kubernetes_utils.LABEL_FORMATTER_REGISTRY - ] - if pod.spec.node_selector: - for label_key in pod.spec.node_selector.keys(): - if label_key in gpu_lf_keys: - # TODO(romilb): We may have additional node - # affinity selectors in the future - in that - # case we will need to update this logic. - if ('Insufficient nvidia.com/gpu' - in event_message or - 'didn\'t match Pod\'s node affinity/selector' - in event_message): - raise config.KubernetesError( - f'{lack_resource_msg.format(resource="GPU")} ' - f'Verify if {pod.spec.node_selector[label_key]}' - ' is available in the cluster.') - raise config.KubernetesError(f'{timeout_err_msg} ' - f'Pod status: {pod_status}' - f'Details: \'{event_message}\' ') - raise config.KubernetesError(f'{timeout_err_msg}') - - def _wait_for_pods_to_schedule(self, new_nodes_with_jump_pod): - """Wait for all pods to be scheduled. - - Wait for all pods including jump pod to be scheduled, and if it - exceeds the timeout, raise an exception. 
If pod's container - is ContainerCreating, then we can assume that resources have been - allocated and we can exit. - """ - start_time = time.time() - while time.time() - start_time < self.timeout: - all_pods_scheduled = True - for node in new_nodes_with_jump_pod: - # Iterate over each pod to check their status - pod = kubernetes.core_api().read_namespaced_pod( - node.metadata.name, self.namespace) - if pod.status.phase == 'Pending': - # If container_statuses is None, then the pod hasn't - # been scheduled yet. - if pod.status.container_statuses is None: - all_pods_scheduled = False - break - - if all_pods_scheduled: - return - time.sleep(1) - - # Handle pod scheduling errors - try: - self._raise_pod_scheduling_errors(new_nodes_with_jump_pod) - except config.KubernetesError: - raise - except Exception as e: - raise config.KubernetesError( - 'An error occurred while trying to fetch the reason ' - 'for pod scheduling failure. ' - f'Error: {common_utils.format_exception(e)}') from None - - def _wait_for_pods_to_run(self, new_nodes_with_jump_pod): - """Wait for pods and their containers to be ready. - - Pods may be pulling images or may be in the process of container - creation. - """ - while True: - all_pods_running = True - # Iterate over each pod to check their status - for node in new_nodes_with_jump_pod: - pod = kubernetes.core_api().read_namespaced_pod( - node.metadata.name, self.namespace) - - # Continue if pod and all the containers within the - # pod are successfully created and running. - if pod.status.phase == 'Running' and all([ - container.state.running - for container in pod.status.container_statuses - ]): - continue - - all_pods_running = False - if pod.status.phase == 'Pending': - # Iterate over each container in pod to check their status - for container_status in pod.status.container_statuses: - # If the container wasn't in 'ContainerCreating' - # state, then we know the pod wasn't scheduled or - # had some other error, such as an image pull error. - # See list of possible reasons for waiting here: - # https://stackoverflow.com/a/57886025 - waiting = container_status.state.waiting - if waiting is not None and waiting.reason != 'ContainerCreating': - raise config.KubernetesError( - 'Failed to create container while launching ' - 'the node. Error details: ' - f'{container_status.state.waiting.message}.') - # Reaching this point means that one of the pods had an issue, - # so break out of the loop - break - - if all_pods_running: - break - time.sleep(1) - - def _check_user_privilege(self, new_nodes): - # Checks if the default user has sufficient privilege to set up - # the kubernetes instance pod. - check_k8s_user_sudo_cmd = [ - '/bin/sh', - '-c', - ( - 'if [ $(id -u) -eq 0 ]; then' - # If user is root, create an alias for sudo used in skypilot setup - ' echo \'alias sudo=""\' >> ~/.bashrc; ' - 'else ' - ' if command -v sudo >/dev/null 2>&1; then ' - ' timeout 2 sudo -l >/dev/null 2>&1 || ' - f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ); ' - ' else ' - f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ); ' - ' fi; ' - 'fi') - ] - - for new_node in new_nodes: - privilege_check = run_command_on_pods(new_node.metadata.name, - self.namespace, - check_k8s_user_sudo_cmd) - if privilege_check == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE): - raise config.KubernetesError( - 'Insufficient system privileges detected. 
' - 'Ensure the default user has root access or ' - '"sudo" is installed and the user is added to the sudoers ' - 'from the image.') - - def _setup_ssh_in_pods(self, new_nodes): - # Setting up ssh for the pod instance. This is already setup for - # the jump pod so it does not need to be run for it. - set_k8s_ssh_cmd = [ - '/bin/sh', '-c', - ('prefix_cmd() { if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; }; ' - 'export DEBIAN_FRONTEND=noninteractive;' - '$(prefix_cmd) apt-get update;' - '$(prefix_cmd) apt install openssh-server rsync -y; ' - '$(prefix_cmd) mkdir -p /var/run/sshd; ' - '$(prefix_cmd) sed -i "s/PermitRootLogin prohibit-password/PermitRootLogin yes/" /etc/ssh/sshd_config; ' - '$(prefix_cmd) sed "s@session\\s*required\\s*pam_loginuid.so@session optional pam_loginuid.so@g" -i /etc/pam.d/sshd; ' - 'cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; ' - '$(prefix_cmd) mkdir -p ~/.ssh; ' - '$(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys; ' - '$(prefix_cmd) service ssh restart') - ] - - # TODO(romilb): We need logging and surface errors here. - for new_node in new_nodes: - run_command_on_pods(new_node.metadata.name, self.namespace, - set_k8s_ssh_cmd) - - def _set_env_vars_in_pods(self, new_nodes): - """Setting environment variables in pods. - - Once all containers are ready, we can exec into them and set env vars. - Kubernetes automatically populates containers with critical - environment variables, such as those for discovering services running - in the cluster and CUDA/nvidia environment variables. We need to - make sure these env vars are available in every task and ssh session. - This is needed for GPU support and service discovery. - See https://github.com/skypilot-org/skypilot/issues/2287 for - more details. - - To do so, we capture env vars from the pod's runtime and write them to - /etc/profile.d/, making them available for all users in future - shell sessions. - """ - set_k8s_env_var_cmd = [ - '/bin/sh', '-c', - ('prefix_cmd() { if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } && ' - 'printenv | awk -F "=" \'{print "export " $1 "=\\047" $2 "\\047"}\' > ~/k8s_env_var.sh && ' - 'mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh || ' - '$(prefix_cmd) mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh') - ] - - for new_node in new_nodes: - run_command_on_pods(new_node.metadata.name, self.namespace, - set_k8s_env_var_cmd) - - def _update_ssh_user_config(self, new_nodes, cluster_name_with_hash): - get_k8s_ssh_user_cmd = ['/bin/sh', '-c', ('echo $(whoami)')] - for new_node in new_nodes: - ssh_user = run_command_on_pods(new_node.metadata.name, - self.namespace, get_k8s_ssh_user_cmd) - - cluster_yaml_path = self._recover_cluster_yaml_path( - cluster_name_with_hash) - with open(cluster_yaml_path, 'r') as f: - content = f.read() - - # Replacing the default ssh user name with the actual user name. - # This updates user name specified in user's custom image if it's used. 
- content = re.sub(r'ssh_user: \w+', f'ssh_user: {ssh_user}', content) - - with open(cluster_yaml_path, 'w') as f: - f.write(content) - - def create_node(self, node_config, tags, count): - conf = copy.deepcopy(node_config) - pod_spec = conf.get('pod', conf) - service_spec = conf.get('service') - node_uuid = str(uuid4()) - tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name - tags['ray-node-uuid'] = node_uuid - pod_spec['metadata']['namespace'] = self.namespace - if 'labels' in pod_spec['metadata']: - pod_spec['metadata']['labels'].update(tags) - else: - pod_spec['metadata']['labels'] = tags - - # Allow Operator-configured service to access the head node. - if tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD: - head_selector = head_service_selector(self.cluster_name) - pod_spec['metadata']['labels'].update(head_selector) - - logger.info(config.log_prefix + - 'calling create_namespaced_pod (count={}).'.format(count)) - new_nodes = [] - for _ in range(count): - pod = kubernetes.core_api().create_namespaced_pod( - self.namespace, pod_spec) - new_nodes.append(pod) - - new_svcs = [] - if service_spec is not None: - logger.info(config.log_prefix + 'calling create_namespaced_service ' - '(count={}).'.format(count)) - - for new_node in new_nodes: - metadata = service_spec.get('metadata', {}) - metadata['name'] = new_node.metadata.name - service_spec['metadata'] = metadata - service_spec['spec']['selector'] = {'ray-node-uuid': node_uuid} - svc = kubernetes.core_api().create_namespaced_service( - self.namespace, service_spec) - new_svcs.append(svc) - - # Adding the jump pod to the new_nodes list as well so it can be - # checked if it's scheduled and running along with other pod instances. - ssh_jump_pod_name = conf['metadata']['labels']['skypilot-ssh-jump'] - new_nodes_with_jump_pod = new_nodes[:] - jump_pod = kubernetes.core_api().read_namespaced_pod( - ssh_jump_pod_name, self.namespace) - new_nodes_with_jump_pod.append(jump_pod) - node_names = [node.metadata.name for node in new_nodes_with_jump_pod] - - # Wait until the pods are scheduled and surface cause for error - # if there is one - logger.info(config.log_prefix + - f'Waiting for pods to schedule. Pods: {node_names}') - self._wait_for_pods_to_schedule(new_nodes_with_jump_pod) - # Wait until the pods and their containers are up and running, and - # fail early if there is an error - logger.info(config.log_prefix + - f'Waiting for pods to run. 
Pods: {node_names}') - self._wait_for_pods_to_run(new_nodes_with_jump_pod) - logger.info(config.log_prefix + - f'Checking if user in image has sufficient privileges.') - self._check_user_privilege(new_nodes) - logger.info(config.log_prefix + f'Setting up SSH in pod.') - self._setup_ssh_in_pods(new_nodes) - logger.info(config.log_prefix + - f'Setting up environment variables in pod.') - self._set_env_vars_in_pods(new_nodes) - cluster_name_with_hash = conf['metadata']['labels']['skypilot-cluster'] - logger.info(config.log_prefix + f'Fetching and updating ssh username.') - self._update_ssh_user_config(new_nodes, cluster_name_with_hash) - - def terminate_node(self, node_id): - logger.info(config.log_prefix + 'calling delete_namespaced_pod') - try: - kubernetes_utils.clean_zombie_ssh_jump_pod(self.namespace, node_id) - except Exception as e: - logger.warning(config.log_prefix + - f'Error occurred when analyzing SSH Jump pod: {e}') - try: - kubernetes.core_api().delete_namespaced_service( - node_id, - self.namespace, - _request_timeout=config.DELETION_TIMEOUT) - kubernetes.core_api().delete_namespaced_service( - f'{node_id}-ssh', - self.namespace, - _request_timeout=config.DELETION_TIMEOUT) - except kubernetes.api_exception(): - pass - # Note - delete pod after all other resources are deleted. - # This is to ensure there are no leftover resources if this down is run - # from within the pod, e.g., for autodown. - try: - kubernetes.core_api().delete_namespaced_pod( - node_id, - self.namespace, - _request_timeout=config.DELETION_TIMEOUT) - except kubernetes.api_exception() as e: - if e.status == 404: - logger.warning(config.log_prefix + - f'Tried to delete pod {node_id},' - ' but the pod was not found (404).') - else: - raise - - def terminate_nodes(self, node_ids): - # TODO(romilb): terminate_nodes should include optimizations for - # deletion of multiple nodes. Currently, it deletes one node at a time. - # We should look into using deletecollection here for batch deletion. - for node_id in node_ids: - self.terminate_node(node_id) - - def get_command_runner(self, - log_prefix, - node_id, - auth_config, - cluster_name_with_hash, - process_runner, - use_internal_ip, - docker_config=None): - """Returns the CommandRunner class used to perform SSH commands. - - Args: - log_prefix(str): stores "NodeUpdater: {}: ".format(). Used - to print progress in the CommandRunner. - node_id(str): the node ID. - auth_config(dict): the authentication configs from the autoscaler - yaml file. - cluster_name_with_hash(str): the name of the cluster and hash value, - separated by a hyphen. - process_runner(module): the module to use to run the commands - in the CommandRunner. E.g., subprocess. - use_internal_ip(bool): whether the node_id belongs to an internal ip - or external ip. - docker_config(dict): If set, the docker information of the docker - container that commands should be run on. - """ - # For custom images, the username might differ across images. - # The 'ssh_user' is updated inplace in the YAML at the end of the - # 'create_node()' process in _update_ssh_user_config. - # Since the node provider is initialized with stale auth information, - # we need to reload the updated user from YAML. 
- cluster_yaml_path = self._recover_cluster_yaml_path( - cluster_name_with_hash) - ssh_credentials = backend_utils.ssh_credential_from_yaml( - cluster_yaml_path) - auth_config['ssh_user'] = ssh_credentials['ssh_user'] - - common_args = { - 'log_prefix': log_prefix, - 'node_id': node_id, - 'provider': self, - 'auth_config': auth_config, - 'cluster_name': cluster_name_with_hash, - 'process_runner': process_runner, - 'use_internal_ip': use_internal_ip, - } - command_runner = SSHCommandRunner(**common_args) - if use_internal_ip: - port = 22 - else: - port = self.external_port(node_id) - command_runner.set_port(port) - return command_runner - - @staticmethod - def bootstrap_config(cluster_config): - return config.bootstrap_kubernetes(cluster_config) - - @staticmethod - def fillout_available_node_types_resources(cluster_config): - """Fills out missing "resources" field for available_node_types.""" - return config.fillout_resources_kubernetes(cluster_config) diff --git a/sky/spot/controller.py b/sky/spot/controller.py index 46f1f8b0fc2..fd66fa597ff 100644 --- a/sky/spot/controller.py +++ b/sky/spot/controller.py @@ -70,7 +70,6 @@ def __init__(self, job_id: int, dag_yaml: str, for i, task in enumerate(self._dag.tasks): task_envs = task.envs or {} - task_envs[constants.TASK_ID_ENV_VAR_DEPRECATED] = job_id_env_vars[i] task_envs[constants.TASK_ID_ENV_VAR] = job_id_env_vars[i] task_envs[constants.TASK_ID_LIST_ENV_VAR] = '\n'.join( job_id_env_vars) diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py index 815cc52c246..43bcd31020f 100644 --- a/sky/spot/spot_utils.py +++ b/sky/spot/spot_utils.py @@ -178,9 +178,6 @@ def callback_func(state: str): bash_command=event_callback, log_path=log_path, env_vars=dict( - SKYPILOT_JOB_ID=str( - task.envs.get(constants.TASK_ID_ENV_VAR_DEPRECATED, - 'N.A.')), SKYPILOT_TASK_ID=str( task.envs.get(constants.TASK_ID_ENV_VAR, 'N.A.')), SKYPILOT_TASK_IDS=str( diff --git a/sky/task.py b/sky/task.py index 0f9cfe01053..acca4ff3df5 100644 --- a/sky/task.py +++ b/sky/task.py @@ -262,7 +262,7 @@ def __init__( self.outputs: Optional[str] = None self.estimated_inputs_size_gigabytes: Optional[float] = None self.estimated_outputs_size_gigabytes: Optional[float] = None - # Default to CPUNode + # Default to CPU VM self.resources: Union[List[sky.Resources], Set[sky.Resources]] = {sky.Resources()} self._service: Optional[service_spec.SkyServiceSpec] = None diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 28fcfaf0c59..48a8ef2bde8 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -243,7 +243,6 @@ def get_pretty_entry_point() -> str: Example return values: $ sky launch app.yaml # 'sky launch app.yaml' - $ sky gpunode # 'sky gpunode' $ python examples/app.py # 'app.py' """ argv = sys.argv diff --git a/tests/test_cli.py b/tests/test_cli.py index 7ba55109a91..96203bb8186 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,39 +12,6 @@ ] -def test_infer_gpunode_type(): - resources = [ - sky.Resources(cloud=sky.AWS(), instance_type='p3.2xlarge'), - sky.Resources(cloud=sky.GCP(), accelerators='K80'), - sky.Resources(accelerators={'V100': 8}), - sky.Resources(cloud=sky.Azure(), accelerators='A100'), - ] - for spec in resources: - assert cli._infer_interactive_node_type(spec) == 'gpunode', spec - - -def test_infer_cpunode_type(): - resources = [ - sky.Resources(cloud=sky.AWS(), instance_type='m5.2xlarge'), - sky.Resources(cloud=sky.GCP()), - sky.Resources(), - ] - for spec in resources: - assert cli._infer_interactive_node_type(spec) 
== 'cpunode', spec - - -def test_infer_tpunode_type(): - resources = [ - sky.Resources(cloud=sky.GCP(), accelerators='tpu-v3-8'), - sky.Resources(cloud=sky.GCP(), accelerators='tpu-v2-32'), - sky.Resources(cloud=sky.GCP(), - accelerators={'tpu-v2-128': 1}, - accelerator_args={'tpu_name': 'tpu'}), - ] - for spec in resources: - assert cli._infer_interactive_node_type(spec) == 'tpunode', spec - - def test_accelerator_mismatch(enable_all_clouds): """Test the specified accelerator does not match the instance_type."""