diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..db40b03b5fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,74 @@ +# Ensure this configuration aligns with format.sh and requirements.txt +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 22.10.0 # Match the version from requirements + hooks: + - id: black + name: black (IBM specific) + files: "^sky/skylet/providers/ibm/.*" # Match only files in the IBM directory + +- repo: https://github.com/pycqa/isort + rev: 5.12.0 # Match the version from requirements + hooks: + # First isort command + - id: isort + name: isort (general) + args: + - "--sg=build/**" # Matches "${ISORT_YAPF_EXCLUDES[@]}" + - "--sg=sky/skylet/providers/ibm/**" + files: "^(sky|tests|examples|llm|docs)/.*" # Only match these directories + # Second isort command + - id: isort + name: isort (IBM specific) + args: + - "--profile=black" + - "-l=88" + - "-m=3" + files: "^sky/skylet/providers/ibm/.*" # Only match IBM-specific directory + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.991 # Match the version from requirements + hooks: + - id: mypy + args: + # From tests/mypy_files.txt + - "sky" + - "--exclude" + - "sky/benchmark|sky/callbacks|sky/skylet/providers/azure|sky/resources.py|sky/backends/monkey_patches" + pass_filenames: false + additional_dependencies: + - types-PyYAML + - types-requests<2.31 # Match the condition in requirements.txt + - types-setuptools + - types-cachetools + - types-pyvmomi + +- repo: https://github.com/google/yapf + rev: v0.32.0 # Match the version from requirements + hooks: + - id: yapf + name: yapf + exclude: (build/.*|sky/skylet/providers/ibm/.*) # Matches exclusions from the script + args: ['--recursive', '--parallel'] # Only necessary flags + additional_dependencies: [toml==0.10.2] + +- repo: https://github.com/pylint-dev/pylint + rev: v2.14.5 # Match the version from requirements + hooks: + - id: pylint + additional_dependencies: + - pylint-quotes==0.2.3 # Match the version from requirements + name: pylint + args: + - --rcfile=.pylintrc # Use your custom pylint configuration + - --load-plugins=pylint_quotes # Load the pylint-quotes plugin + files: ^sky/ # Only include files from the 'sky/' directory + exclude: ^sky/skylet/providers/ibm/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fc115331fa2..85ca90b2c4a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,6 +78,7 @@ It has some convenience features which you might find helpful (see [Dockerfile]( - If relevant, add tests for your changes. For changes that touch the core system, run the [smoke tests](#testing) and ensure they pass. - Follow the [Google style guide](https://google.github.io/styleguide/pyguide.html). - Ensure code is properly formatted by running [`format.sh`](https://github.com/skypilot-org/skypilot/blob/master/format.sh). + - [Optional] You can also install pre-commit hooks by running `pre-commit install` to automatically format your code on commit. - Push your changes to your fork and open a pull request in the SkyPilot repository. - In the PR description, write a `Tested:` section to describe relevant tests performed. 
diff --git a/Dockerfile_k8s b/Dockerfile_k8s index 45625871078..f031dff3668 100644 --- a/Dockerfile_k8s +++ b/Dockerfile_k8s @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # Initialize conda for root user, install ssh and other local dependencies RUN apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* && \ apt remove -y python3 && \ conda init diff --git a/Dockerfile_k8s_gpu b/Dockerfile_k8s_gpu index 09570d102df..6277e7f8d12 100644 --- a/Dockerfile_k8s_gpu +++ b/Dockerfile_k8s_gpu @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # We remove cuda lists to avoid conflicts with the cuda version installed by ray RUN rm -rf /etc/apt/sources.list.d/cuda* && \ apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* # Setup SSH and generate hostkeys @@ -36,6 +36,7 @@ SHELL ["/bin/bash", "-c"] # Install conda and other dependencies # Keep the conda and Ray versions below in sync with the ones in skylet.constants +# Keep this section in sync with the custom image optimization recommendations in our docs (kubernetes-getting-started.rst) RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ bash Miniconda3-Linux-x86_64.sh -b && \ eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index 69303a582e2..deb2307b67b 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -267,6 +267,14 @@ The :code:`~/.oci/config` file should contain the following fields: # Note that we should avoid using full home path for the key_file configuration, e.g. use ~/.oci instead of /home/username/.oci key_file=~/.oci/oci_api_key.pem +By default, the provisioned nodes will be in the root `compartment `__. To specify the `compartment `_ other than root, create/edit the file :code:`~/.sky/config.yaml`, put the compartment's OCID there, as the following: + +.. code-block:: text + + oci: + default: + compartment_ocid: ocid1.compartment.oc1..aaaaaaaa...... + Lambda Cloud ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/comparison.rst b/docs/source/reference/comparison.rst index e9bffabba68..23985e5081b 100644 --- a/docs/source/reference/comparison.rst +++ b/docs/source/reference/comparison.rst @@ -46,7 +46,7 @@ SkyPilot provides faster iteration for interactive development. For example, a c * :strong:`With SkyPilot, a single command (`:literal:`sky launch`:strong:`) takes care of everything.` Behind the scenes, SkyPilot provisions pods, installs all required dependencies, executes the job, returns logs, and provides SSH and VSCode access to debug. -.. figure:: https://blog.skypilot.co/ai-on-kubernetes/images/k8s_vs_skypilot_iterative_v2.png +.. 
figure:: https://i.imgur.com/xfCfz4N.png :align: center :width: 95% :alt: Iterative Development with Kubernetes vs SkyPilot diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst index 0e19eb6e266..e4bbb2c8915 100644 --- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst +++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst @@ -324,3 +324,32 @@ FAQs type: Directory For more details refer to :ref:`config-yaml`. + +* **I am using a custom image. How can I speed up the pod startup time?** + + You can pre-install SkyPilot dependencies in your custom image to speed up the pod startup time. Simply add these lines at the end of your Dockerfile: + + .. code-block:: dockerfile + + FROM + + # Install system dependencies + RUN apt update -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils fuse unzip socat netcat-openbsd curl -y && \ + rm -rf /var/lib/apt/lists/* + + # Install conda and other python dependencies + RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ + bash Miniconda3-Linux-x86_64.sh -b && \ + eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ + grep "# >>> conda initialize >>>" ~/.bashrc || { conda init && source ~/.bashrc; } && \ + rm Miniconda3-Linux-x86_64.sh && \ + export PIP_DISABLE_PIP_VERSION_CHECK=1 && \ + python3 -m venv ~/skypilot-runtime && \ + PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python && \ + $PYTHON_EXEC -m pip install 'skypilot-nightly[remote,kubernetes]' 'ray[default]==2.9.3' 'pycryptodome==3.12.0' && \ + $PYTHON_EXEC -m pip uninstall skypilot-nightly -y && \ + curl -LO "https://dl.k8s.io/release/v1.28.11/bin/linux/amd64/kubectl" && \ + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl && \ + echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc + diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 7addcffbe3c..ff95162ac63 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -139,7 +139,7 @@ def update_current_kubernetes_clusters_from_registry(): def get_allowed_contexts(): """Mock implementation of getting allowed kubernetes contexts.""" from sky.provision.kubernetes import utils - contexts = utils.get_all_kube_config_context_names() + contexts = utils.get_all_kube_context_names() return contexts[:2] diff --git a/examples/oci/serve-qwen-7b.yaml b/examples/oci/serve-qwen-7b.yaml index 799e5a7d891..004e912b088 100644 --- a/examples/oci/serve-qwen-7b.yaml +++ b/examples/oci/serve-qwen-7b.yaml @@ -13,8 +13,8 @@ resources: setup: | conda create -n vllm python=3.12 -y conda activate vllm - pip install vllm - pip install vllm-flash-attn + pip install vllm==0.6.3.post1 + pip install vllm-flash-attn==2.6.2 run: | conda activate vllm diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py index ea8fb194efa..001d397ac9e 100644 --- a/sky/adaptors/kubernetes.py +++ b/sky/adaptors/kubernetes.py @@ -19,6 +19,13 @@ # Timeout to use for API calls API_TIMEOUT = 5 +DEFAULT_IN_CLUSTER_REGION = 'in-cluster' +# The name for the environment variable that stores the in-cluster context name +# for Kubernetes clusters. 
This is used to associate a name with the current +# context when running with in-cluster auth. If not set, the context name is +# set to DEFAULT_IN_CLUSTER_REGION. +IN_CLUSTER_CONTEXT_NAME_ENV_VAR = 'SKYPILOT_IN_CLUSTER_CONTEXT_NAME' + def _decorate_methods(obj: Any, decorator: Callable, decoration_type: str): for attr_name in dir(obj): @@ -57,16 +64,8 @@ def wrapped(*args, **kwargs): def _load_config(context: Optional[str] = None): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - try: - # Load in-cluster config if running in a pod - # Kubernetes set environment variables for service discovery do not - # show up in SkyPilot tasks. For now, we work around by using - # DNS name instead of environment variables. - # See issue: https://github.com/skypilot-org/skypilot/issues/2287 - os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' - os.environ['KUBERNETES_SERVICE_PORT'] = '443' - kubernetes.config.load_incluster_config() - except kubernetes.config.config_exception.ConfigException: + + def _load_config_from_kubeconfig(context: Optional[str] = None): try: kubernetes.config.load_kube_config(context=context) except kubernetes.config.config_exception.ConfigException as e: @@ -90,6 +89,21 @@ def _load_config(context: Optional[str] = None): with ux_utils.print_exception_no_traceback(): raise ValueError(err_str) from None + if context == in_cluster_context_name() or context is None: + try: + # Load in-cluster config if running in a pod and context is None. + # Kubernetes set environment variables for service discovery do not + # show up in SkyPilot tasks. For now, we work around by using + # DNS name instead of environment variables. + # See issue: https://github.com/skypilot-org/skypilot/issues/2287 + os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' + os.environ['KUBERNETES_SERVICE_PORT'] = '443' + kubernetes.config.load_incluster_config() + except kubernetes.config.config_exception.ConfigException: + _load_config_from_kubeconfig() + else: + _load_config_from_kubeconfig(context) + @_api_logging_decorator('urllib3', logging.ERROR) @functools.lru_cache() @@ -154,3 +168,13 @@ def max_retry_error(): def stream(): return kubernetes.stream.stream + + +def in_cluster_context_name() -> Optional[str]: + """Returns the name of the in-cluster context from the environment. + + If the environment variable is not set, returns the default in-cluster + context name. + """ + return (os.environ.get(IN_CLUSTER_CONTEXT_NAME_ENV_VAR) or + DEFAULT_IN_CLUSTER_REGION) diff --git a/sky/authentication.py b/sky/authentication.py index 41a7d02dfb7..2eb65bd9f6f 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -380,8 +380,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name context = config['provider'].get( 'context', kubernetes_utils.get_current_kube_config_context_name()) - if context == kubernetes_utils.IN_CLUSTER_REGION: - # If the context is set to IN_CLUSTER_REGION, we are running in a pod + if context == kubernetes.in_cluster_context_name(): + # If the context is an in-cluster context name, we are running in a pod # with in-cluster configuration. We need to set the context to None # to use the mounted service account. 
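A minimal usage sketch of the new in-cluster context naming added above. The helper simply falls back to the default `'in-cluster'` name when the environment variable is unset; the cluster name below is an illustrative value, not one SkyPilot defines:

```python
import os

# Mirrors the constants added in sky/adaptors/kubernetes.py.
IN_CLUSTER_CONTEXT_NAME_ENV_VAR = 'SKYPILOT_IN_CLUSTER_CONTEXT_NAME'
DEFAULT_IN_CLUSTER_REGION = 'in-cluster'


def in_cluster_context_name() -> str:
    """Returns the env-provided name, or the default 'in-cluster'."""
    return (os.environ.get(IN_CLUSTER_CONTEXT_NAME_ENV_VAR) or
            DEFAULT_IN_CLUSTER_REGION)


print(in_cluster_context_name())  # -> 'in-cluster'
os.environ[IN_CLUSTER_CONTEXT_NAME_ENV_VAR] = 'my-onprem-cluster'
print(in_cluster_context_name())  # -> 'my-onprem-cluster'
```

When `_load_config` is called with this name (or with no context), it first tries the in-cluster service account and only then falls back to the kubeconfig.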
context = None diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 4ddda8f1f5c..c8abd84ccb4 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -683,7 +683,7 @@ def write_cluster_config( resources_utils.ClusterName( cluster_name, cluster_name_on_cloud, - ), region, zones, dryrun) + ), region, zones, num_nodes, dryrun) config_dict = {} specific_reservations = set( @@ -730,7 +730,12 @@ def write_cluster_config( f'{skypilot_config.loaded_config_path!r} for {cloud}, but it ' 'is not supported by this cloud. Remove the config or set: ' '`remote_identity: LOCAL_CREDENTIALS`.') - excluded_clouds.add(cloud) + if isinstance(cloud, clouds.Kubernetes): + if skypilot_config.get_nested( + ('kubernetes', 'allowed_contexts'), None) is None: + excluded_clouds.add(cloud) + else: + excluded_clouds.add(cloud) for cloud_str, cloud_obj in cloud_registry.CLOUD_REGISTRY.items(): remote_identity_config = skypilot_config.get_nested( @@ -844,7 +849,11 @@ def write_cluster_config( '{sky_wheel_hash}', wheel_hash).replace('{cloud}', str(cloud).lower())), - + 'skypilot_wheel_installation_commands': + constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace( + '{sky_wheel_hash}', + wheel_hash).replace('{cloud}', + str(cloud).lower()), # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -1190,18 +1199,18 @@ def ssh_credential_from_yaml( def parallel_data_transfer_to_nodes( - runners: List[command_runner.CommandRunner], - source: Optional[str], - target: str, - cmd: Optional[str], - run_rsync: bool, - *, - action_message: str, - # Advanced options. - log_path: str = os.devnull, - stream_logs: bool = False, - source_bashrc: bool = False, -): + runners: List[command_runner.CommandRunner], + source: Optional[str], + target: str, + cmd: Optional[str], + run_rsync: bool, + *, + action_message: str, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = False, + source_bashrc: bool = False, + num_threads: Optional[int] = None): """Runs a command on all nodes and optionally runs rsync from src->dst. Args: @@ -1213,6 +1222,7 @@ def parallel_data_transfer_to_nodes( log_path: str; Path to the log file stream_logs: bool; Whether to stream logs to stdout source_bashrc: bool; Source bashrc before running the command. + num_threads: Optional[int]; Number of threads to use. """ style = colorama.Style @@ -1253,7 +1263,7 @@ def _sync_node(runner: 'command_runner.CommandRunner') -> None: message = (f' {style.DIM}{action_message} (to {num_nodes} node{plural})' f': {origin_source} -> {target}{style.RESET_ALL}') logger.info(message) - subprocess_utils.run_in_parallel(_sync_node, runners) + subprocess_utils.run_in_parallel(_sync_node, runners, num_threads) def check_local_gpus() -> bool: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index d00560ece23..5682cf24586 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -269,6 +269,13 @@ def add_prologue(self, job_id: int) -> None: import time from typing import Dict, List, Optional, Tuple, Union + # Set the environment variables to avoid deduplicating logs and + # scheduler events. This should be set in driver code, since we are + # not using `ray job submit` anymore, and the environment variables + # from the ray cluster is not inherited. 
+ os.environ['RAY_DEDUP_LOGS'] = '0' + os.environ['RAY_SCHEDULER_EVENTS'] = '0' + import ray import ray.util as ray_util @@ -1528,7 +1535,7 @@ def _retry_zones( to_provision, resources_utils.ClusterName( cluster_name, handle.cluster_name_on_cloud), - region, zones)) + region, zones, num_nodes)) config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle @@ -3086,9 +3093,12 @@ def _sync_workdir_node(runner: command_runner.CommandRunner) -> None: f'{workdir} -> {SKY_REMOTE_WORKDIR}{style.RESET_ALL}') os.makedirs(os.path.expanduser(self.log_dir), exist_ok=True) os.system(f'touch {log_path}') + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) with rich_utils.safe_status( ux_utils.spinner_message('Syncing workdir', log_path)): - subprocess_utils.run_in_parallel(_sync_workdir_node, runners) + subprocess_utils.run_in_parallel(_sync_workdir_node, runners, + num_threads) logger.info(ux_utils.finishing_message('Workdir synced.', log_path)) def _sync_file_mounts( @@ -4416,6 +4426,8 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, start = time.time() runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'file_mounts.log') + num_threads = subprocess_utils.get_max_workers_for_file_mounts( + file_mounts, str(handle.launched_resources.cloud)) # Check the files and warn for dst, src in file_mounts.items(): @@ -4477,6 +4489,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, action_message='Syncing', log_path=log_path, stream_logs=False, + num_threads=num_threads, ) continue @@ -4513,6 +4526,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, # Need to source bashrc, as the cloud specific CLI or SDK may # require PATH in bashrc. source_bashrc=True, + num_threads=num_threads, ) # (2) Run the commands to create symlinks on all the nodes. symlink_command = ' && '.join(symlink_commands) @@ -4531,7 +4545,8 @@ def _symlink_node(runner: command_runner.CommandRunner): 'Failed to create symlinks. The target destination ' f'may already exist. Log: {log_path}') - subprocess_utils.run_in_parallel(_symlink_node, runners) + subprocess_utils.run_in_parallel(_symlink_node, runners, + num_threads) end = time.time() logger.debug(f'File mount sync took {end - start} seconds.') logger.info(ux_utils.finishing_message('Files synced.', log_path)) @@ -4560,6 +4575,8 @@ def _execute_storage_mounts( return start = time.time() runners = handle.get_command_runners() + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) log_path = os.path.join(self.log_dir, 'storage_mounts.log') plural = 's' if len(storage_mounts) > 1 else '' @@ -4598,6 +4615,7 @@ def _execute_storage_mounts( # Need to source bashrc, as the cloud specific CLI or SDK # may require PATH in bashrc. 
source_bashrc=True, + num_threads=num_threads, ) except exceptions.CommandError as e: if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE: diff --git a/sky/cli.py b/sky/cli.py index c49b692add1..94474b30b6c 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -486,7 +486,7 @@ def _parse_override_params( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None) -> Dict[str, Any]: + ports: Optional[Tuple[str, ...]] = None) -> Dict[str, Any]: """Parses the override parameters into a dictionary.""" override_params: Dict[str, Any] = {} if cloud is not None: @@ -539,7 +539,14 @@ def _parse_override_params( else: override_params['disk_tier'] = disk_tier if ports: - override_params['ports'] = ports + if any(p.lower() == 'none' for p in ports): + if len(ports) > 1: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both "none" and other ' + 'ports.') + override_params['ports'] = None + else: + override_params['ports'] = ports return override_params @@ -730,7 +737,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None, + ports: Optional[Tuple[str, ...]] = None, env: Optional[List[Tuple[str, str]]] = None, field_to_ignore: Optional[List[str]] = None, # job launch specific @@ -1084,7 +1091,7 @@ def launch( env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], - ports: Tuple[str], + ports: Tuple[str, ...], idle_minutes_to_autostop: Optional[int], down: bool, # pylint: disable=redefined-outer-name retry_until_up: bool, diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 22e1039f121..c42d67f8ba4 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -401,6 +401,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: del dryrun # unused assert zones is not None, (region, zones) diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 9d399869666..eb76d2b5e48 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -302,6 +302,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: assert zones is None, ('Azure does not support zones', zones) diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index d8e6deac4f4..455baeaf5d9 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -283,6 +283,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'Region', zones: Optional[List['Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. 
diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index 6f02e007049..145a5d1c26e 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -196,6 +196,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: del zones, cluster_name # unused diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 31e2112f8f7..2668ea3e5e0 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -176,6 +176,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 0e20fdc9789..8a28a35505e 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -417,6 +417,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: assert zones is not None, (region, zones) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index 0ac3c36cc48..13f6a27e78a 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -170,6 +170,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 5e1b46d52eb..ace14bc0c51 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -10,8 +10,10 @@ from sky import skypilot_config from sky.adaptors import kubernetes from sky.clouds import service_catalog +from sky.provision import instance_setup from sky.provision.kubernetes import network_utils from sky.provision.kubernetes import utils as kubernetes_utils +from sky.skylet import constants from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import schemas @@ -128,32 +130,30 @@ def _log_skipped_contexts_once(cls, skipped_contexts: Tuple[str, 'Ignoring these contexts.') @classmethod - def _existing_allowed_contexts(cls) -> List[Optional[str]]: + def _existing_allowed_contexts(cls) -> List[str]: """Get existing allowed contexts. If None is returned in the list, it means that we are running in a pod with in-cluster auth. In this case, we specify None context, which will use the service account mounted in the pod. """ - all_contexts = kubernetes_utils.get_all_kube_config_context_names() + all_contexts = kubernetes_utils.get_all_kube_context_names() if len(all_contexts) == 0: return [] - if all_contexts == [None]: - # If only one context is found and it is None, we are running in a - # pod with in-cluster auth. In this case, we allow it to be used - # without checking against allowed_contexts. - # TODO(romilb): We may want check in-cluster auth against - # allowed_contexts in the future by adding a special context name - # for in-cluster auth. 
- return [None] + all_contexts = set(all_contexts) allowed_contexts = skypilot_config.get_nested( ('kubernetes', 'allowed_contexts'), None) if allowed_contexts is None: + # Try kubeconfig if present current_context = ( kubernetes_utils.get_current_kube_config_context_name()) + if (current_context is None and + kubernetes_utils.is_incluster_config_available()): + # If no kubeconfig contexts found, use in-cluster if available + current_context = kubernetes.in_cluster_context_name() allowed_contexts = [] if current_context is not None: allowed_contexts = [current_context] @@ -178,13 +178,7 @@ def regions_with_offering(cls, instance_type: Optional[str], regions = [] for context in existing_contexts: - if context is None: - # If running in-cluster, we allow the region to be set to the - # singleton region since there is no context name available. - regions.append(clouds.Region( - kubernetes_utils.IN_CLUSTER_REGION)) - else: - regions.append(clouds.Region(context)) + regions.append(clouds.Region(context)) if region is not None: regions = [r for r in regions if r.name == region] @@ -311,12 +305,34 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> int: # we don't have a notion of disk size in Kubernetes. return 0 + @staticmethod + def _calculate_provision_timeout(num_nodes: int) -> int: + """Calculate provision timeout based on number of nodes. + + The timeout scales linearly with the number of nodes to account for + scheduling overhead, but is capped to avoid excessive waiting. + + Args: + num_nodes: Number of nodes being provisioned + + Returns: + Timeout in seconds + """ + base_timeout = 10 # Base timeout for single node + per_node_timeout = 0.2 # Additional seconds per node + max_timeout = 60 # Cap at 1 minute + + return int( + min(base_timeout + (per_node_timeout * (num_nodes - 1)), + max_timeout)) + def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, zones, dryrun # Unused. 
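To make the timeout scaling above concrete, here is a standalone re-computation of the formula for a few cluster sizes, assuming the user has not set a `provision_timeout` override in their config:

```python
def calculate_provision_timeout(num_nodes: int) -> int:
    """Same formula as Kubernetes._calculate_provision_timeout above."""
    base_timeout = 10       # seconds for a single node
    per_node_timeout = 0.2  # extra seconds per additional node
    max_timeout = 60        # capped at 1 minute
    return int(min(base_timeout + per_node_timeout * (num_nodes - 1),
                   max_timeout))


for n in (1, 10, 100, 300):
    print(n, calculate_provision_timeout(n))
# 1 -> 10, 10 -> 11, 100 -> 29, 300 -> 60 (hits the cap)
```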
if region is None: @@ -385,12 +401,25 @@ def make_deploy_resources_variables( remote_identity = skypilot_config.get_nested( ('kubernetes', 'remote_identity'), schemas.get_default_remote_identity('kubernetes')) - if (remote_identity == + + if isinstance(remote_identity, dict): + # If remote_identity is a dict, use the service account for the + # current context + k8s_service_account_name = remote_identity.get(context, None) + if k8s_service_account_name is None: + err_msg = (f'Context {context!r} not found in ' + 'remote identities from config.yaml') + raise ValueError(err_msg) + else: + # If remote_identity is not a dict, use + k8s_service_account_name = remote_identity + + if (k8s_service_account_name == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value): # SA name doesn't matter since automounting credentials is disabled k8s_service_account_name = 'default' k8s_automount_sa_token = 'false' - elif (remote_identity == + elif (k8s_service_account_name == schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value): # Use the default service account k8s_service_account_name = ( @@ -398,7 +427,6 @@ def make_deploy_resources_variables( k8s_automount_sa_token = 'true' else: # User specified a custom service account - k8s_service_account_name = remote_identity k8s_automount_sa_token = 'true' fuse_device_required = bool(resources.requires_fuse) @@ -413,12 +441,30 @@ def make_deploy_resources_variables( # Larger timeout may be required for autoscaling clusters, since # autoscaler may take some time to provision new nodes. # Note that this timeout includes time taken by the Kubernetes scheduler - # itself, which can be upto 2-3 seconds. - # For non-autoscaling clusters, we conservatively set this to 10s. + # itself, which can be upto 2-3 seconds, and up to 10-15 seconds when + # scheduling 100s of pods. + # We use a linear scaling formula to determine the timeout based on the + # number of nodes. + + timeout = self._calculate_provision_timeout(num_nodes) timeout = skypilot_config.get_nested( ('kubernetes', 'provision_timeout'), - 10, + timeout, override_configs=resources.cluster_config_overrides) + + # Set environment variables for the pod. Note that SkyPilot env vars + # are set separately when the task is run. These env vars are + # independent of the SkyPilot task to be run. + k8s_env_vars = {kubernetes.IN_CLUSTER_CONTEXT_NAME_ENV_VAR: context} + + # We specify object-store-memory to be 500MB to avoid taking up too + # much memory on the head node. 'num-cpus' should be set to limit + # the CPU usage on the head pod, otherwise the ray cluster will use the + # CPU resources on the node instead within the pod. + custom_ray_options = { + 'object-store-memory': 500000000, + 'num-cpus': str(int(cpus)), + } deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, @@ -444,7 +490,14 @@ def make_deploy_resources_variables( 'k8s_topology_label_key': k8s_topology_label_key, 'k8s_topology_label_value': k8s_topology_label_value, 'k8s_resource_key': k8s_resource_key, + 'k8s_env_vars': k8s_env_vars, 'image_id': image_id, + 'ray_installation_commands': constants.RAY_INSTALLATION_COMMANDS, + 'ray_head_start_command': instance_setup.ray_head_start_command( + custom_resources, custom_ray_options), + 'skypilot_ray_port': constants.SKY_REMOTE_RAY_PORT, + 'ray_worker_start_command': instance_setup.ray_worker_start_command( + custom_resources, custom_ray_options, no_restart=False), } # Add kubecontext if it is set. 
It may be None if SkyPilot is running @@ -536,7 +589,11 @@ def _make(instance_list): @classmethod def check_credentials(cls) -> Tuple[bool, Optional[str]]: # Test using python API - existing_allowed_contexts = cls._existing_allowed_contexts() + try: + existing_allowed_contexts = cls._existing_allowed_contexts() + except ImportError as e: + return (False, + f'{common_utils.format_exception(e, use_bracket=True)}') if not existing_allowed_contexts: if skypilot_config.loaded_config_path() is None: check_skypilot_config_msg = '' @@ -579,16 +636,13 @@ def validate_region_zone(self, region: Optional[str], zone: Optional[str]): # TODO: Remove this after 0.9.0. return region, zone - if region == kubernetes_utils.IN_CLUSTER_REGION: + if region == kubernetes.in_cluster_context_name(): # If running incluster, we set region to IN_CLUSTER_REGION # since there is no context name available. return region, zone - all_contexts = kubernetes_utils.get_all_kube_config_context_names() - if all_contexts == [None]: - # If [None] context is returned, use the singleton region since we - # are running in a pod with in-cluster auth. - all_contexts = [kubernetes_utils.IN_CLUSTER_REGION] + all_contexts = kubernetes_utils.get_all_kube_context_names() + if region not in all_contexts: raise ValueError( f'Context {region} not found in kubeconfig. Kubernetes only ' diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 055a5338750..11ec96a78c1 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -157,6 +157,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert zones is None, 'Lambda does not support zones.' diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 37806ff8349..95f4efe95e3 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -208,6 +208,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert region is not None, resources diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index 4047a2f5926..69a0d69ca61 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -175,6 +175,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0d693fd9f60..487793ecf97 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -160,6 +160,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name # unused diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index d0ad611bf0c..4a6b8564a97 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -181,6 +181,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. 
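A hedged sketch of how the new `ray_*_start_command` deploy variables above are produced: the head and worker `ray start` commands are rendered by the helpers added to `sky/provision/instance_setup.py` (see further below), and the worker command reads the head address from `SKYPILOT_RAY_HEAD_IP` / `SKYPILOT_RAY_PORT` at runtime. The option values mirror the `custom_ray_options` built earlier in this file; the GPU resource string is an illustrative value:

```python
from sky.provision import instance_setup

# Mirrors the custom_ray_options built in
# Kubernetes.make_deploy_resources_variables(); resource JSON is an example.
custom_ray_options = {'object-store-memory': 500000000, 'num-cpus': '8'}
custom_resources = '{"A100": 1}'

head_cmd = instance_setup.ray_head_start_command(custom_resources,
                                                 custom_ray_options)
worker_cmd = instance_setup.ray_worker_start_command(custom_resources,
                                                     custom_ray_options,
                                                     no_restart=False)
# Both are shell strings; the worker command expects SKYPILOT_RAY_HEAD_IP
# and SKYPILOT_RAY_PORT to be exported before it runs.
print(head_cmd)
print(worker_cmd)
```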
assert zones is None, 'SCP does not support zones.' diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 4deab8ac204..d28b530ff06 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -324,9 +324,8 @@ def get_common_gpus() -> List[str]: 'A100', 'A100-80GB', 'H100', - 'K80', 'L4', - 'M60', + 'L40S', 'P100', 'T4', 'V100', @@ -337,13 +336,13 @@ def get_common_gpus() -> List[str]: def get_tpus() -> List[str]: """Returns a list of TPU names.""" # TODO(wei-lin): refactor below hard-coded list. - # There are many TPU configurations available, we show the three smallest - # and the largest configuration for the latest gen TPUs. + # There are many TPU configurations available, we show the some smallest + # ones for each generation, and people should find larger ones with + # sky show-gpus tpu. return [ - 'tpu-v2-512', 'tpu-v3-2048', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', - 'tpu-v4-3968', 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', - 'tpu-v5litepod-256', 'tpu-v5p-8', 'tpu-v5p-32', 'tpu-v5p-128', - 'tpu-v5p-12288' + 'tpu-v2-8', 'tpu-v3-8', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', + 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', 'tpu-v5p-8', + 'tpu-v5p-16', 'tpu-v5p-32', 'tpu-v6e-1', 'tpu-v6e-4', 'tpu-v6e-8' ] diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index f646cac339a..4aef41f9c90 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -64,7 +64,7 @@ 'standardNVSv2Family': 'M60', 'standardNVSv3Family': 'M60', 'standardNVPromoFamily': 'M60', - 'standardNVSv4Family': 'Radeon MI25', + 'standardNVSv4Family': 'MI25', 'standardNDSFamily': 'P40', 'StandardNVADSA10v5Family': 'A10', 'StandardNCadsH100v5Family': 'H100', diff --git a/sky/clouds/service_catalog/kubernetes_catalog.py b/sky/clouds/service_catalog/kubernetes_catalog.py index 6d11d1715e2..1d0c97c0442 100644 --- a/sky/clouds/service_catalog/kubernetes_catalog.py +++ b/sky/clouds/service_catalog/kubernetes_catalog.py @@ -65,9 +65,14 @@ def list_accelerators( # TODO(romilb): We should consider putting a lru_cache() with TTL to # avoid multiple calls to kubernetes API in a short period of time (e.g., # from the optimizer). - return list_accelerators_realtime(gpus_only, name_filter, region_filter, - quantity_filter, case_sensitive, - all_regions, require_price)[0] + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=False)[0] def list_accelerators_realtime( @@ -78,10 +83,36 @@ def list_accelerators_realtime( case_sensitive: bool = True, all_regions: bool = False, require_price: bool = True +) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, + int]]: + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=True) + + +def _list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, + require_price: bool = True, + realtime: bool = False ) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, int]]: """List accelerators in the Kubernetes cluster. 
+ If realtime is True, the function will query the cluster to fetch real-time + GPU usage, which is returned in total_accelerators_available. Note that + this may require an expensive list_pod_for_all_namespaces call, which + requires cluster-wide pod read permissions. + If the user does not have sufficient permissions to list pods in all namespaces, the function will return free GPUs as -1. """ @@ -115,18 +146,20 @@ def list_accelerators_realtime( accelerators_qtys: Set[Tuple[str, int]] = set() keys = lf.get_label_keys() nodes = kubernetes_utils.get_kubernetes_nodes(context) - # Get the pods to get the real-time GPU usage - try: - pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) - except kubernetes.api_exception() as e: - if e.status == 403: - logger.warning('Failed to get pods in the Kubernetes cluster ' - '(forbidden). Please check if your account has ' - 'necessary permissions to list pods. Realtime GPU ' - 'availability information may be incorrect.') - pods = None - else: - raise + pods = None + if realtime: + # Get the pods to get the real-time GPU usage + try: + pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) + except kubernetes.api_exception() as e: + if e.status == 403: + logger.warning( + 'Failed to get pods in the Kubernetes cluster ' + '(forbidden). Please check if your account has ' + 'necessary permissions to list pods. Realtime GPU ' + 'availability information may be incorrect.') + else: + raise # Total number of GPUs in the cluster total_accelerators_capacity: Dict[str, int] = {} # Total number of GPUs currently available in the cluster diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 88d5df3232a..92e62a8a240 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -173,6 +173,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: # TODO get image id here. 
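A short usage sketch of the split introduced in `kubernetes_catalog` above: `list_accelerators` now skips the potentially expensive cluster-wide pod listing, while `list_accelerators_realtime` still queries pods to report availability (returned as -1 when pod-listing permissions are missing). Argument values are illustrative:

```python
from sky.clouds.service_catalog import kubernetes_catalog

# Cheap path: no cluster-wide pod listing, suitable for the optimizer.
offerings = kubernetes_catalog.list_accelerators(
    gpus_only=True, name_filter=None, region_filter=None,
    quantity_filter=None)

# Realtime path: also lists pods to compute free GPUs per accelerator.
qtys, capacity, available = kubernetes_catalog.list_accelerators_realtime(
    gpus_only=True, name_filter=None, region_filter=None,
    quantity_filter=None)
# available[acc] may be -1 if the account cannot list pods cluster-wide.
```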
diff --git a/sky/execution.py b/sky/execution.py index 90eb44e069f..963e0356753 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -305,7 +305,8 @@ def _execute( do_workdir = (Stage.SYNC_WORKDIR in stages and not dryrun and task.workdir is not None) do_file_mounts = (Stage.SYNC_FILE_MOUNTS in stages and not dryrun and - task.file_mounts is not None) + (task.file_mounts is not None or + task.storage_mounts is not None)) if do_workdir or do_file_mounts: logger.info(ux_utils.starting_message('Mounting files.')) diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 60159232787..700d31c597f 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -305,7 +305,8 @@ def _create_vm( network_profile=network_profile, identity=compute.VirtualMachineIdentity( type='UserAssigned', - user_assigned_identities={provider_config['msi']: {}})) + user_assigned_identities={provider_config['msi']: {}}), + priority=node_config['azure_arm_parameters'].get('priority', None)) vm_poller = compute_client.virtual_machines.begin_create_or_update( resource_group_name=provider_config['resource_group'], vm_name=vm_name, diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 8c390adaf87..86d1c59f36c 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -4,7 +4,6 @@ import hashlib import json import os -import resource import time from typing import Any, Callable, Dict, List, Optional, Tuple @@ -20,6 +19,7 @@ from sky.utils import command_runner from sky.utils import common_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -115,7 +115,8 @@ def _parallel_ssh_with_cache(func, if max_workers is None: # Not using the default value of `max_workers` in ThreadPoolExecutor, # as 32 is too large for some machines. - max_workers = subprocess_utils.get_parallel_threads() + max_workers = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) with futures.ThreadPoolExecutor(max_workers=max_workers) as pool: results = [] runners = provision.get_command_runners(cluster_info.provider_name, @@ -170,6 +171,7 @@ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): @common.log_function_start_end +@timeline.event def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -245,20 +247,9 @@ def _ray_gpu_options(custom_resource: str) -> str: return f' --num-gpus={acc_count}' -@common.log_function_start_end -@_auto_retry() -def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], - cluster_info: common.ClusterInfo, - ssh_credentials: Dict[str, Any]) -> None: - """Start Ray on the head node.""" - runners = provision.get_command_runners(cluster_info.provider_name, - cluster_info, **ssh_credentials) - head_runner = runners[0] - assert cluster_info.head_instance_id is not None, (cluster_name, - cluster_info) - - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) +def ray_head_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]]) -> str: + """Returns the command to start Ray on the head node.""" ray_options = ( # --disable-usage-stats in `ray start` saves 10 seconds of idle wait. 
f'--disable-usage-stats ' @@ -270,23 +261,14 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], if custom_resource: ray_options += f' --resources=\'{custom_resource}\'' ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - if 'use_external_ip' in cluster_info.custom_ray_options: - cluster_info.custom_ray_options.pop('use_external_ip') - for key, value in cluster_info.custom_ray_options.items(): + if custom_ray_options: + if 'use_external_ip' in custom_ray_options: + custom_ray_options.pop('use_external_ip') + for key, value in custom_ray_options.items(): ray_options += f' --{key}={value}' - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY to avoid using credentials - # from environment variables set by user. SkyPilot's ray cluster should use - # the `~/.aws/` credentials, as that is the one used to create the cluster, - # and the autoscaler module started by the `ray start` command should use - # the same credentials. Otherwise, `ray status` will fail to fetch the - # available nodes. - # Reference: https://github.com/skypilot-org/skypilot/issues/2441 cmd = ( f'{constants.SKY_RAY_CMD} stop; ' - 'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' # worker_maximum_startup_concurrency controls the maximum number of # workers that can be started concurrently. However, it also controls @@ -305,6 +287,62 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], 'RAY_worker_maximum_startup_concurrency=$(( 3 * $(nproc --all) )) ' f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' + _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND) + return cmd + + +def ray_worker_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]], + no_restart: bool) -> str: + """Returns the command to start Ray on the worker node.""" + # We need to use the ray port in the env variable, because the head node + # determines the port to be used for the worker node. + ray_options = ('--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} ' + '--object-manager-port=8076') + + if custom_resource: + ray_options += f' --resources=\'{custom_resource}\'' + ray_options += _ray_gpu_options(custom_resource) + + if custom_ray_options: + for key, value in custom_ray_options.items(): + ray_options += f' --{key}={value}' + + cmd = ( + 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' + f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' + 'exit 1;' + _RAY_PRLIMIT) + if no_restart: + # We do not use ray status to check whether ray is running, because + # on worker node, if the user started their own ray cluster, ray status + # will return 0, i.e., we don't know skypilot's ray cluster is running. + # Instead, we check whether the raylet process is running on gcs address + # that is connected to the head with the correct port. 
+ cmd = ( + f'ps aux | grep "ray/raylet/raylet" | ' + 'grep "gcs-address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT}" ' + f'|| {{ {cmd} }}') + else: + cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + return cmd + + +@common.log_function_start_end +@_auto_retry() +@timeline.event +def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], + cluster_info: common.ClusterInfo, + ssh_credentials: Dict[str, Any]) -> None: + """Start Ray on the head node.""" + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] + assert cluster_info.head_instance_id is not None, (cluster_name, + cluster_info) + + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + cmd = ray_head_start_command(custom_resource, + cluster_info.custom_ray_options) logger.info(f'Running command on head node: {cmd}') # TODO(zhwu): add the output to log files. returncode, stdout, stderr = head_runner.run( @@ -324,6 +362,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], @common.log_function_start_end @_auto_retry() +@timeline.event def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, custom_resource: Optional[str], ray_port: int, cluster_info: common.ClusterInfo, @@ -358,43 +397,17 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, head_ip = (head_instance.internal_ip if not use_external_ip else head_instance.external_ip) - ray_options = (f'--address={head_ip}:{constants.SKY_REMOTE_RAY_PORT} ' - f'--object-manager-port=8076') - - if custom_resource: - ray_options += f' --resources=\'{custom_resource}\'' - ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - for key, value in cluster_info.custom_ray_options.items(): - ray_options += f' --{key}={value}' + ray_cmd = ray_worker_start_command(custom_resource, + cluster_info.custom_ray_options, + no_restart) - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY, see the comment in - # `start_ray_on_head_node`. - cmd = ( - f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' - 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' - f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' - 'exit 1;' + _RAY_PRLIMIT) - if no_restart: - # We do not use ray status to check whether ray is running, because - # on worker node, if the user started their own ray cluster, ray status - # will return 0, i.e., we don't know skypilot's ray cluster is running. - # Instead, we check whether the raylet process is running on gcs address - # that is connected to the head with the correct port. - cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | ' - f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || ' - f'{{ {cmd} }}') - else: - cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; ' + f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd) logger.info(f'Running command on worker nodes: {cmd}') def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, str]): - # for cmd in config_from_yaml['worker_start_ray_commands']: - # cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0]) - # runner.run(cmd) runner, instance_id = runner_and_id log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id) log_path_abs = str(log_dir / ('ray_cluster' + '.log')) @@ -407,8 +420,10 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, # by ray will have the correct PATH. 
source_bashrc=True) + num_threads = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(worker_runners, cache_ids))) + _setup_ray_worker, list(zip(worker_runners, cache_ids)), num_threads) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): @@ -421,6 +436,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, @common.log_function_start_end @_auto_retry() +@timeline.event def start_skylet_on_head_node(cluster_name: str, cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -482,28 +498,8 @@ def _internal_file_mounts(file_mounts: Dict, ) -def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int: - fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) - - fd_per_rsync = 5 - for src in common_file_mounts.values(): - if os.path.isdir(src): - # Assume that each file/folder under src takes 5 file descriptors - # on average. - fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) - - # Reserve some file descriptors for the system and other processes - fd_reserve = 100 - - max_workers = (fd_limit - fd_reserve) // fd_per_rsync - # At least 1 worker, and avoid too many workers overloading the system. - max_workers = min(max(max_workers, 1), - subprocess_utils.get_parallel_threads()) - logger.debug(f'Using {max_workers} workers for file mounts.') - return max_workers - - @common.log_function_start_end +@timeline.event def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, str]) -> None: @@ -524,4 +520,5 @@ def _setup_node(runner: command_runner.CommandRunner, log_path: str): digest=None, cluster_info=cluster_info, ssh_credentials=ssh_credentials, - max_workers=_max_workers_for_file_mounts(common_file_mounts)) + max_workers=subprocess_utils.get_max_workers_for_file_mounts( + common_file_mounts, cluster_info.provider_name)) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index f5e93204934..2b13e78fdf8 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -20,12 +20,13 @@ from sky.utils import common_utils from sky.utils import kubernetes_enums from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils POLL_INTERVAL = 2 _TIMEOUT_FOR_POD_TERMINATION = 60 # 1 minutes _MAX_RETRIES = 3 -NUM_THREADS = subprocess_utils.get_parallel_threads() * 2 +_NUM_THREADS = subprocess_utils.get_parallel_threads('kubernetes') logger = sky_logging.init_logger(__name__) TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' @@ -120,6 +121,9 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes): are recorded as events. This function retrieves those events and raises descriptive errors for better debugging and user feedback. """ + timeout_err_msg = ('Timed out while waiting for nodes to start. ' + 'Cluster may be out of resources or ' + 'may be too slow to autoscale.') for new_node in new_nodes: pod = kubernetes.core_api(context).read_namespaced_pod( new_node.metadata.name, namespace) @@ -148,9 +152,6 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes): if event.reason == 'FailedScheduling': event_message = event.message break - timeout_err_msg = ('Timed out while waiting for nodes to start. 
' - 'Cluster may be out of resources or ' - 'may be too slow to autoscale.') if event_message is not None: if pod_status == 'Pending': logger.info(event_message) @@ -219,6 +220,7 @@ def _raise_command_running_error(message: str, command: str, pod_name: str, f'code {rc}: {command!r}\nOutput: {stdout}.') +@timeline.event def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): """Wait for all pods to be scheduled. @@ -229,6 +231,10 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): If timeout is set to a negative value, this method will wait indefinitely. """ + # Create a set of pod names we're waiting for + if not new_nodes: + return + expected_pod_names = {node.metadata.name for node in new_nodes} start_time = time.time() def _evaluate_timeout() -> bool: @@ -238,19 +244,34 @@ def _evaluate_timeout() -> bool: return time.time() - start_time < timeout while _evaluate_timeout(): - all_pods_scheduled = True - for node in new_nodes: - # Iterate over each pod to check their status - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) - if pod.status.phase == 'Pending': + # Get all pods in a single API call using the cluster name label + # which all pods in new_nodes should share + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying waiting for pods: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + + # Check if all pods are scheduled + all_scheduled = True + for pod in pods: + if (pod.metadata.name in expected_pod_names and + pod.status.phase == 'Pending'): # If container_statuses is None, then the pod hasn't # been scheduled yet. if pod.status.container_statuses is None: - all_pods_scheduled = False + all_scheduled = False break - if all_pods_scheduled: + if all_scheduled: return time.sleep(1) @@ -266,12 +287,18 @@ def _evaluate_timeout() -> bool: f'Error: {common_utils.format_exception(e)}') from None +@timeline.event def _wait_for_pods_to_run(namespace, context, new_nodes): """Wait for pods and their containers to be ready. Pods may be pulling images or may be in the process of container creation. """ + if not new_nodes: + return + + # Create a set of pod names we're waiting for + expected_pod_names = {node.metadata.name for node in new_nodes} def _check_init_containers(pod): # Check if any of the init containers failed @@ -299,12 +326,25 @@ def _check_init_containers(pod): f'{pod.metadata.name}. 
Error details: {msg}.') while True: - all_pods_running = True - # Iterate over each pod to check their status - for node in new_nodes: - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) + # Get all pods in a single API call + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + all_pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in all_pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying running pods check: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + all_pods_running = True + for pod in all_pods: + if pod.metadata.name not in expected_pod_names: + continue # Continue if pod and all the containers within the # pod are successfully created and running. if pod.status.phase == 'Running' and all( @@ -367,6 +407,7 @@ def _run_function_with_retries(func: Callable, raise +@timeline.event def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None: """Pre-initialization step for SkyPilot pods. @@ -514,7 +555,7 @@ def _pre_init_thread(new_node): logger.info(f'{"-"*20}End: Pre-init in pod {pod_name!r} {"-"*20}') # Run pre_init in parallel across all new_nodes - subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, NUM_THREADS) + subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, _NUM_THREADS) def _label_pod(namespace: str, context: Optional[str], pod_name: str, @@ -528,6 +569,7 @@ def _label_pod(namespace: str, context: Optional[str], pod_name: str, _request_timeout=kubernetes.API_TIMEOUT) +@timeline.event def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, context: Optional[str]) -> Any: """Attempts to create a Kubernetes Pod and handle any errors. @@ -606,6 +648,7 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, raise e +@timeline.event def _create_pods(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Create pods based on the config.""" @@ -627,7 +670,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) start_time = time.time() - while (len(terminating_pods) > 0 and + while (terminating_pods and time.time() - start_time < _TIMEOUT_FOR_POD_TERMINATION): logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods. Waiting them to finish: ' @@ -636,7 +679,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) - if len(terminating_pods) > 0: + if terminating_pods: # If there are still terminating pods, we force delete them. 
logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods still in terminating state after ' @@ -695,24 +738,29 @@ def _create_pods(region: str, cluster_name_on_cloud: str, created_pods = {} logger.debug(f'run_instances: calling create_namespaced_pod ' f'(count={to_start_count}).') - for _ in range(to_start_count): - if head_pod_name is None: - pod_spec['metadata']['labels'].update(constants.HEAD_NODE_TAGS) + + def _create_pod_thread(i: int): + pod_spec_copy = copy.deepcopy(pod_spec) + if head_pod_name is None and i == 0: + # First pod should be head if no head exists + pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS) head_selector = head_service_selector(cluster_name_on_cloud) - pod_spec['metadata']['labels'].update(head_selector) - pod_spec['metadata']['name'] = f'{cluster_name_on_cloud}-head' + pod_spec_copy['metadata']['labels'].update(head_selector) + pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head' else: - pod_spec['metadata']['labels'].update(constants.WORKER_NODE_TAGS) - pod_uuid = str(uuid.uuid4())[:4] + # Worker pods + pod_spec_copy['metadata']['labels'].update( + constants.WORKER_NODE_TAGS) + pod_uuid = str(uuid.uuid4())[:6] pod_name = f'{cluster_name_on_cloud}-{pod_uuid}' - pod_spec['metadata']['name'] = f'{pod_name}-worker' + pod_spec_copy['metadata']['name'] = f'{pod_name}-worker' # For multi-node support, we put a soft-constraint to schedule # worker pods on different nodes than the head pod. # This is not set as a hard constraint because if different nodes # are not available, we still want to be able to schedule worker # pods on larger nodes which may be able to fit multiple SkyPilot # "nodes". - pod_spec['spec']['affinity'] = { + pod_spec_copy['spec']['affinity'] = { 'podAntiAffinity': { # Set as a soft constraint 'preferredDuringSchedulingIgnoredDuringExecution': [{ @@ -747,17 +795,22 @@ def _create_pods(region: str, cluster_name_on_cloud: str, 'value': 'present', 'effect': 'NoSchedule' } - pod_spec['spec']['tolerations'] = [tpu_toleration] + pod_spec_copy['spec']['tolerations'] = [tpu_toleration] - pod = _create_namespaced_pod_with_retries(namespace, pod_spec, context) + return _create_namespaced_pod_with_retries(namespace, pod_spec_copy, + context) + + # Create pods in parallel + pods = subprocess_utils.run_in_parallel(_create_pod_thread, + range(to_start_count), _NUM_THREADS) + + # Process created pods + for pod in pods: created_pods[pod.metadata.name] = pod - if head_pod_name is None: + if head_pod_name is None and pod.metadata.labels.get( + constants.TAG_RAY_NODE_KIND) == 'head': head_pod_name = pod.metadata.name - wait_pods_dict = kubernetes_utils.filter_pods(namespace, context, tags, - ['Pending']) - wait_pods = list(wait_pods_dict.values()) - networking_mode = network_utils.get_networking_mode( config.provider_config.get('networking_mode')) if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: @@ -766,52 +819,24 @@ def _create_pods(region: str, cluster_name_on_cloud: str, ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump'] jump_pod = kubernetes.core_api(context).read_namespaced_pod( ssh_jump_pod_name, namespace) - wait_pods.append(jump_pod) + pods.append(jump_pod) provision_timeout = provider_config['timeout'] wait_str = ('indefinitely' if provision_timeout < 0 else f'for {provision_timeout}s') logger.debug(f'run_instances: waiting {wait_str} for pods to schedule and ' - f'run: {list(wait_pods_dict.keys())}') + f'run: {[pod.metadata.name for pod in pods]}') # Wait 
until the pods are scheduled and surface cause for error # if there is one - _wait_for_pods_to_schedule(namespace, context, wait_pods, provision_timeout) + _wait_for_pods_to_schedule(namespace, context, pods, provision_timeout) # Wait until the pods and their containers are up and running, and # fail early if there is an error logger.debug(f'run_instances: waiting for pods to be running (pulling ' - f'images): {list(wait_pods_dict.keys())}') - _wait_for_pods_to_run(namespace, context, wait_pods) + f'images): {[pod.metadata.name for pod in pods]}') + _wait_for_pods_to_run(namespace, context, pods) logger.debug(f'run_instances: all pods are scheduled and running: ' - f'{list(wait_pods_dict.keys())}') - - running_pods = kubernetes_utils.filter_pods(namespace, context, tags, - ['Running']) - initialized_pods = kubernetes_utils.filter_pods(namespace, context, { - TAG_POD_INITIALIZED: 'true', - **tags - }, ['Running']) - uninitialized_pods = { - pod_name: pod - for pod_name, pod in running_pods.items() - if pod_name not in initialized_pods - } - if len(uninitialized_pods) > 0: - logger.debug(f'run_instances: Initializing {len(uninitialized_pods)} ' - f'pods: {list(uninitialized_pods.keys())}') - uninitialized_pods_list = list(uninitialized_pods.values()) - - # Run pre-init steps in the pod. - pre_init(namespace, context, uninitialized_pods_list) - - for pod in uninitialized_pods.values(): - _label_pod(namespace, - context, - pod.metadata.name, - label={ - TAG_POD_INITIALIZED: 'true', - **pod.metadata.labels - }) + f'{[pod.metadata.name for pod in pods]}') assert head_pod_name is not None, 'head_instance_id should not be None' return common.ProvisionRecord( @@ -854,11 +879,6 @@ def _terminate_node(namespace: str, context: Optional[str], pod_name: str) -> None: """Terminate a pod.""" logger.debug('terminate_instances: calling delete_namespaced_pod') - try: - kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, pod_name) - except Exception as e: # pylint: disable=broad-except - logger.warning('terminate_instances: Error occurred when analyzing ' - f'SSH Jump pod: {e}') try: kubernetes.core_api(context).delete_namespaced_service( pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT) @@ -895,6 +915,18 @@ def terminate_instances( } pods = kubernetes_utils.filter_pods(namespace, context, tag_filters, None) + # Clean up the SSH jump pod if in use + networking_mode = network_utils.get_networking_mode( + provider_config.get('networking_mode')) + if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: + pod_name = list(pods.keys())[0] + try: + kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, + pod_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('terminate_instances: Error occurred when analyzing ' + f'SSH Jump pod: {e}') + def _is_head(pod) -> bool: return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head' @@ -907,7 +939,7 @@ def _terminate_pod_thread(pod_info): # Run pod termination in parallel subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(), - NUM_THREADS) + _NUM_THREADS) def get_cluster_info( diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index ee00e449b78..7442c9be7a6 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -28,6 +28,7 @@ from sky.utils import env_options from sky.utils import kubernetes_enums from sky.utils import schemas +from sky.utils import timeline from sky.utils import ux_utils if 
typing.TYPE_CHECKING: @@ -36,7 +37,6 @@ # TODO(romilb): Move constants to constants.py DEFAULT_NAMESPACE = 'default' -IN_CLUSTER_REGION = 'in-cluster' DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account' @@ -920,6 +920,9 @@ def is_kubeconfig_exec_auth( str: Error message if exec-based authentication is used, None otherwise """ k8s = kubernetes.kubernetes + if context == kubernetes.in_cluster_context_name(): + # If in-cluster config is used, exec-based auth is not used. + return False, None try: k8s.config.load_kube_config() except kubernetes.config_exception(): @@ -1002,30 +1005,34 @@ def is_incluster_config_available() -> bool: return os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token') -def get_all_kube_config_context_names() -> List[Optional[str]]: - """Get all kubernetes context names from the kubeconfig file. +def get_all_kube_context_names() -> List[str]: + """Get all kubernetes context names available in the environment. + + Fetches context names from the kubeconfig file and in-cluster auth, if any. - If running in-cluster, returns [None] to indicate in-cluster config. + If running in-cluster and IN_CLUSTER_CONTEXT_NAME_ENV_VAR is not set, + returns the default in-cluster kubernetes context name. We should not cache the result of this function as the admin policy may update the contexts. Returns: List[Optional[str]]: The list of kubernetes context names if - available, an empty list otherwise. If running in-cluster, - returns [None] to indicate in-cluster config. + available, an empty list otherwise. """ k8s = kubernetes.kubernetes + context_names = [] try: all_contexts, _ = k8s.config.list_kube_config_contexts() # all_contexts will always have at least one context. If kubeconfig # does not have any contexts defined, it will raise ConfigException. - return [context['name'] for context in all_contexts] + context_names = [context['name'] for context in all_contexts] except k8s.config.config_exception.ConfigException: - # If running in cluster, return [None] to indicate in-cluster config - if is_incluster_config_available(): - return [None] - return [] + # If no config found, continue + pass + if is_incluster_config_available(): + context_names.append(kubernetes.in_cluster_context_name()) + return context_names @functools.lru_cache() @@ -1038,11 +1045,15 @@ def get_kube_config_context_namespace( the default namespace. """ k8s = kubernetes.kubernetes - # Get namespace if using in-cluster config ns_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace' - if os.path.exists(ns_path): - with open(ns_path, encoding='utf-8') as f: - return f.read().strip() + # If using in-cluster context, get the namespace from the service account + # namespace file. Uses the same logic as adaptors.kubernetes._load_config() + # to stay consistent with in-cluster config loading. + if (context_name == kubernetes.in_cluster_context_name() or + context_name is None): + if os.path.exists(ns_path): + with open(ns_path, encoding='utf-8') as f: + return f.read().strip() # If not in-cluster, get the namespace from kubeconfig try: contexts, current_context = k8s.config.list_kube_config_contexts() @@ -1129,7 +1140,11 @@ def name(self) -> str: name = (f'{common_utils.format_float(self.cpus)}CPU--' f'{common_utils.format_float(self.memory)}GB') if self.accelerator_count: - name += f'--{self.accelerator_count}{self.accelerator_type}' + # Replace spaces with underscores in accelerator type to make it a + # valid logical instance type name. 
+ assert self.accelerator_type is not None, self.accelerator_count + acc_name = self.accelerator_type.replace(' ', '_') + name += f'--{self.accelerator_count}{acc_name}' return name @staticmethod @@ -1160,7 +1175,9 @@ def _parse_instance_type( accelerator_type = match.group('accelerator_type') if accelerator_count: accelerator_count = int(accelerator_count) - accelerator_type = str(accelerator_type) + # This is to revert the accelerator types with spaces back to + # the original format. + accelerator_type = str(accelerator_type).replace('_', ' ') else: accelerator_count = None accelerator_type = None @@ -2053,6 +2070,7 @@ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str: get_kube_config_context_namespace(context)) +@timeline.event def filter_pods(namespace: str, context: Optional[str], tag_filters: Dict[str, str], @@ -2183,9 +2201,9 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle', def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]: context = provider_config.get('context', get_current_kube_config_context_name()) - if context == IN_CLUSTER_REGION: - # If the context (also used as the region) is set to IN_CLUSTER_REGION - # we need to use in-cluster auth. + if context == kubernetes.in_cluster_context_name(): + # If the context (also used as the region) is in-cluster, we need to + # use in-cluster auth by setting the context to None. context = None return context diff --git a/sky/provision/oci/instance.py b/sky/provision/oci/instance.py index 811d27d0e21..e04089ff8d4 100644 --- a/sky/provision/oci/instance.py +++ b/sky/provision/oci/instance.py @@ -123,8 +123,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, # Let's create additional new nodes (if neccessary) to_start_count = config.count - len(resume_instances) created_instances = [] + node_config = config.node_config if to_start_count > 0: - node_config = config.node_config compartment = query_helper.find_compartment(region) vcn = query_helper.find_create_vcn_subnet(region) @@ -242,10 +242,12 @@ def run_instances(region: str, cluster_name_on_cloud: str, assert head_instance_id is not None, head_instance_id + # Format: TenancyPrefix:AvailabilityDomain, e.g. bxtG:US-SANJOSE-1-AD-1 + _, ad = str(node_config['AvailabilityDomain']).split(':', maxsplit=1) return common.ProvisionRecord( provider_name='oci', region=region, - zone=None, + zone=ad, cluster_name=cluster_name_on_cloud, head_instance_id=head_instance_id, created_instance_ids=[n['inst_id'] for n in created_instances], diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index b3e965769c9..cc2ca73e1dc 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -29,6 +29,7 @@ from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils # Do not use __name__ as we do not want to propagate logs to sky.provision, @@ -343,6 +344,7 @@ def _wait_ssh_connection_indirect(ip: str, return True, '' +@timeline.event def wait_for_ssh(cluster_info: provision_common.ClusterInfo, ssh_credentials: Dict[str, str]): """Wait until SSH is ready.
@@ -432,11 +434,15 @@ def _post_provision_setup( ux_utils.spinner_message( 'Launching - Waiting for SSH access', provision_logging.config.log_path)) as status: - - logger.debug( - f'\nWaiting for SSH to be available for {cluster_name!r} ...') - wait_for_ssh(cluster_info, ssh_credentials) - logger.debug(f'SSH Connection ready for {cluster_name!r}') + # If on Kubernetes, skip SSH check since the pods are guaranteed to be + # ready by the provisioner, and we use kubectl instead of SSH to run the + # commands and rsync on the pods. SSH will still be ready after a while + # for the users to SSH into the pod. + if cloud_name.lower() != 'kubernetes': + logger.debug( + f'\nWaiting for SSH to be available for {cluster_name!r} ...') + wait_for_ssh(cluster_info, ssh_credentials) + logger.debug(f'SSH Connection ready for {cluster_name!r}') vm_str = 'Instance' if cloud_name.lower() != 'kubernetes' else 'Pod' plural = '' if len(cluster_info.instances) == 1 else 's' verb = 'is' if len(cluster_info.instances) == 1 else 'are' @@ -496,31 +502,94 @@ def _post_provision_setup( **ssh_credentials) head_runner = runners[0] - status.update( - runtime_preparation_str.format(step=3, step_name='runtime')) - full_ray_setup = True - ray_port = constants.SKY_REMOTE_RAY_PORT - if not provision_record.is_instance_just_booted( - head_instance.instance_id): + def is_ray_cluster_healthy(ray_status_output: str, + expected_num_nodes: int) -> bool: + """Parse the output of `ray status` to get #active nodes. + + The output of `ray status` looks like: + Node status + --------------------------------------------------------------- + Active: + 1 node_291a8b849439ad6186387c35dc76dc43f9058108f09e8b68108cf9ec + 1 node_0945fbaaa7f0b15a19d2fd3dc48f3a1e2d7c97e4a50ca965f67acbfd + Pending: + (no pending nodes) + Recent failures: + (no failures) + """ + start = ray_status_output.find('Active:') + end = ray_status_output.find('Pending:', start) + if start == -1 or end == -1: + return False + num_active_nodes = 0 + for line in ray_status_output[start:end].split('\n'): + if line.strip() and not line.startswith('Active:'): + num_active_nodes += 1 + return num_active_nodes == expected_num_nodes + + def check_ray_port_and_cluster_healthy() -> Tuple[int, bool, bool]: + head_ray_needs_restart = True + ray_cluster_healthy = False + ray_port = constants.SKY_REMOTE_RAY_PORT + # Check if head node Ray is alive returncode, stdout, _ = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, stream_logs=False, require_outputs=True) - if returncode: - logger.debug('Ray cluster on head is not up. Restarting...') - else: - logger.debug('Ray cluster on head is up.') + if not returncode: ray_port = common_utils.decode_payload(stdout)['ray_port'] - full_ray_setup = bool(returncode) + logger.debug(f'Ray cluster on head is up with port {ray_port}.') + + head_ray_needs_restart = bool(returncode) + # This is a best effort check to see if the ray cluster has expected + # number of nodes connected. 
+ ray_cluster_healthy = (not head_ray_needs_restart and + is_ray_cluster_healthy( + stdout, cluster_info.num_instances)) + return ray_port, ray_cluster_healthy, head_ray_needs_restart + + status.update( + runtime_preparation_str.format(step=3, step_name='runtime')) + + ray_port = constants.SKY_REMOTE_RAY_PORT + head_ray_needs_restart = True + ray_cluster_healthy = False + if (not provision_record.is_instance_just_booted( + head_instance.instance_id)): + # Check if head node Ray is alive + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + elif cloud_name.lower() == 'kubernetes': + timeout = 90 # 1.5-min maximum timeout + start = time.time() + while True: + # Wait until Ray cluster is ready + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + if ray_cluster_healthy: + logger.debug('Ray cluster is ready. Skip head and worker ' + 'node ray cluster setup.') + break + if time.time() - start > timeout: + # In most cases, the ray cluster will be ready after a few + # seconds. Trigger ray start on head or worker nodes to be + # safe, if the ray cluster is not ready after timeout. + break + logger.debug('Ray cluster is not ready yet, waiting for the ' + 'async setup to complete...') + time.sleep(1) - if full_ray_setup: + if head_ray_needs_restart: logger.debug('Starting Ray on the entire cluster.') instance_setup.start_ray_on_head_node( cluster_name.name_on_cloud, custom_resource=custom_resource, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + else: + logger.debug('Ray cluster on head is ready. Skip starting ray ' + 'cluster on head node.') # NOTE: We have to check all worker nodes to make sure they are all # healthy, otherwise we can only start Ray on newly started worker @@ -531,10 +600,13 @@ def _post_provision_setup( # if provision_record.is_instance_just_booted(inst.instance_id): # worker_ips.append(inst.public_ip) - if cluster_info.num_instances > 1: + # We don't need to restart ray on worker nodes if the ray cluster is + # already healthy, i.e. the head node has expected number of nodes + # connected to the ray cluster. + if cluster_info.num_instances > 1 and not ray_cluster_healthy: instance_setup.start_ray_on_worker_nodes( cluster_name.name_on_cloud, - no_restart=not full_ray_setup, + no_restart=not head_ray_needs_restart, custom_resource=custom_resource, # Pass the ray_port to worker nodes for backward compatibility # as in some existing clusters the ray_port is not dumped with @@ -543,6 +615,9 @@ def _post_provision_setup( ray_port=ray_port, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + elif ray_cluster_healthy: + logger.debug('Ray cluster is ready. 
Skip starting ray cluster on ' + 'worker nodes.') instance_setup.start_skylet_on_head_node(cluster_name.name_on_cloud, cluster_info, ssh_credentials) @@ -553,6 +628,7 @@ def _post_provision_setup( return cluster_info +@timeline.event def post_provision_runtime_setup( cloud_name: str, cluster_name: resources_utils.ClusterName, cluster_yaml: str, provision_record: provision_common.ProvisionRecord, diff --git a/sky/resources.py b/sky/resources.py index deb05a6eade..729b9d62a28 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -1041,6 +1041,7 @@ def get_spot_str(self) -> str: def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to resource variables. @@ -1062,7 +1063,7 @@ def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, # Cloud specific variables cloud_specific_variables = self.cloud.make_deploy_resources_variables( - self, cluster_name, region, zones, dryrun) + self, cluster_name, region, zones, num_nodes, dryrun) # Docker run options docker_run_options = skypilot_config.get_nested( diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 77be8119758..97d745b2e26 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -159,10 +159,7 @@ _sky_version = str(version.parse(sky.__version__)) RAY_STATUS = f'RAY_ADDRESS=127.0.0.1:{SKY_REMOTE_RAY_PORT} {SKY_RAY_CMD} status' -# Install ray and skypilot on the remote cluster if they are not already -# installed. {var} will be replaced with the actual value in -# backend_utils.write_cluster_config. -RAY_SKYPILOT_INSTALLATION_COMMANDS = ( +RAY_INSTALLATION_COMMANDS = ( 'mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;' # Disable the pip version check to avoid the warning message, which makes # the output hard to read. @@ -202,24 +199,31 @@ # Writes ray path to file if it does not exist or the file is empty. f'[ -s {SKY_RAY_PATH_FILE} ] || ' f'{{ {ACTIVATE_SKY_REMOTE_PYTHON_ENV} && ' - f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ' - # END ray package check and installation + f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ') + +SKYPILOT_WHEEL_INSTALLATION_COMMANDS = ( f'{{ {SKY_PIP_CMD} list | grep "skypilot " && ' '[ "$(cat ~/.sky/wheels/current_sky_wheel_hash)" == "{sky_wheel_hash}" ]; } || ' # pylint: disable=line-too-long f'{{ {SKY_PIP_CMD} uninstall skypilot -y; ' f'{SKY_PIP_CMD} install "$(echo ~/.sky/wheels/{{sky_wheel_hash}}/' f'skypilot-{_sky_version}*.whl)[{{cloud}}, remote]" && ' 'echo "{sky_wheel_hash}" > ~/.sky/wheels/current_sky_wheel_hash || ' - 'exit 1; }; ' - # END SkyPilot package check and installation + 'exit 1; }; ') +# Install ray and skypilot on the remote cluster if they are not already +# installed. {var} will be replaced with the actual value in +# backend_utils.write_cluster_config. +RAY_SKYPILOT_INSTALLATION_COMMANDS = ( + f'{RAY_INSTALLATION_COMMANDS} ' + f'{SKYPILOT_WHEEL_INSTALLATION_COMMANDS} ' # Only patch ray when the ray version is the same as the expected version. # The ray installation above can be skipped due to the existing ray cluster # for backward compatibility. In this case, we should not patch the ray # files. 
- f'{SKY_PIP_CMD} list | grep "ray " | grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null ' - f'&& {{ {SKY_PYTHON_CMD} -c "from sky.skylet.ray_patches import patch; patch()" ' - '|| exit 1; };') + f'{SKY_PIP_CMD} list | grep "ray " | ' + f'grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null && ' + f'{{ {SKY_PYTHON_CMD} -c ' + '"from sky.skylet.ray_patches import patch; patch()" || exit 1; }; ') # The name for the environment variable that stores SkyPilot user hash, which # is mainly used to make sure sky commands runs on a VM launched by SkyPilot diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 95751ab1849..8e9898cb784 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -173,6 +173,7 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + # Line 'rm ~/.aws/credentials': explicitly remove the credentials file to be safe. This is to guard against the case where the credential files was uploaded once as `remote_identity` was not set in a previous launch. - mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} @@ -185,7 +186,12 @@ setup_commands: sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; {%- endif %} mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; - [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; + [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); + {%- if remote_identity != 'LOCAL_CREDENTIALS' %} + rm ~/.aws/credentials || true; + {%- endif %} + + # Command to start ray clusters are now placed in `sky.provision.instance_setup`. # We do not need to list it here anymore. diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 7b9737748d3..1140704a708 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -75,9 +75,6 @@ available_node_types: {%- if use_spot %} # optionally set priority to use Spot instances priority: Spot - # set a maximum price for spot instances if desired - # billingProfile: - # maxPrice: -1 {%- endif %} cloudInitSetupCommands: |- {%- for cmd in cloud_init_setup_commands %} diff --git a/sky/templates/kubernetes-port-forward-proxy-command.sh b/sky/templates/kubernetes-port-forward-proxy-command.sh index 0407209a77c..f8205c2393c 100644 --- a/sky/templates/kubernetes-port-forward-proxy-command.sh +++ b/sky/templates/kubernetes-port-forward-proxy-command.sh @@ -58,6 +58,11 @@ KUBECTL_ARGS=() if [ -n "$KUBE_CONTEXT" ]; then KUBECTL_ARGS+=("--context=$KUBE_CONTEXT") fi +# If context is not provided, it means we are using incluster auth. In this case, +# we need to set KUBECONFIG to /dev/null to avoid using kubeconfig file. 
+if [ -z "$KUBE_CONTEXT" ]; then + KUBECTL_ARGS+=("--kubeconfig=/dev/null") +fi if [ -n "$KUBE_NAMESPACE" ]; then KUBECTL_ARGS+=("--namespace=$KUBE_NAMESPACE") fi diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index b981ee8bf12..e572b263924 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -222,7 +222,9 @@ provider: - protocol: TCP port: 22 targetPort: 22 - # Service that maps to the head node of the Ray cluster. + # Service that maps to the head node of the Ray cluster, so that the + # worker nodes can find the head node using + # {{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local - apiVersion: v1 kind: Service metadata: @@ -235,18 +237,12 @@ provider: # names. name: {{cluster_name_on_cloud}}-head spec: + # Create a headless service so that the head node can be reached by + # the worker nodes with any port number. + clusterIP: None # This selector must match the head node pod's selector below. selector: component: {{cluster_name_on_cloud}}-head - ports: - - name: client - protocol: TCP - port: 10001 - targetPort: 10001 - - name: dashboard - protocol: TCP - port: 8265 - targetPort: 8265 # Specify the pod type for the ray head node (as configured below). head_node_type: ray_head_default @@ -280,7 +276,6 @@ available_node_types: # serviceAccountName: skypilot-service-account serviceAccountName: {{k8s_service_account_name}} automountServiceAccountToken: {{k8s_automount_sa_token}} - restartPolicy: Never # Add node selector if GPU/TPUs are requested: @@ -322,18 +317,158 @@ available_node_types: - name: ray-node imagePullPolicy: IfNotPresent image: {{image_id}} + env: + - name: SKYPILOT_POD_NODE_TYPE + valueFrom: + fieldRef: + fieldPath: metadata.labels['ray-node-type'] + {% for key, value in k8s_env_vars.items() if k8s_env_vars is not none %} + - name: {{ key }} + value: {{ value }} + {% endfor %} # Do not change this command - it keeps the pod alive until it is # explicitly killed. command: ["/bin/bash", "-c", "--"] args: - | + # For backwards compatibility, we put a marker file in the pod + # to indicate that the pod is running with the changes introduced + # in project nimbus: https://github.com/skypilot-org/skypilot/pull/4393 + # TODO: Remove this marker file and it's usage in setup_commands + # after v0.10.0 release. + touch /tmp/skypilot_is_nimbus + # Helper function to conditionally use sudo + # TODO(zhwu): consolidate the two prefix_cmd and sudo replacements prefix_cmd() { if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } + [ $(id -u) -eq 0 ] && function sudo() { "$@"; } || true; + + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") - # Run apt update in background and log to a file + # STEP 1: Run apt update, install missing packages, and set up ssh. ( + ( DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update > /tmp/apt-update.log 2>&1 || \ echo "Warning: apt-get update failed. Continuing anyway..." >> /tmp/apt-update.log + PACKAGES="rsync curl netcat gcc patch pciutils fuse openssh-server"; + + # Separate packages into two groups: packages that are installed first + # so that curl and rsync are available sooner to unblock the following + # conda installation and rsync. + set -e + INSTALL_FIRST=""; + MISSING_PACKAGES=""; + for pkg in $PACKAGES; do + if [ "$pkg" == "netcat" ]; then + if ! dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; then + INSTALL_FIRST="$INSTALL_FIRST netcat-openbsd"; + fi + elif ! 
dpkg -l | grep -q "^ii $pkg "; then + if [ "$pkg" == "curl" ] || [ "$pkg" == "rsync" ]; then + INSTALL_FIRST="$INSTALL_FIRST $pkg"; + else + MISSING_PACKAGES="$MISSING_PACKAGES $pkg"; + fi + fi + done; + if [ ! -z "$INSTALL_FIRST" ]; then + echo "Installing core packages: $INSTALL_FIRST"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $INSTALL_FIRST; + fi; + # SSH and other packages are not necessary, so we disable set -e + set +e + + if [ ! -z "$MISSING_PACKAGES" ]; then + echo "Installing missing packages: $MISSING_PACKAGES"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $MISSING_PACKAGES; + fi; + $(prefix_cmd) mkdir -p /var/run/sshd; + $(prefix_cmd) sed -i "s/PermitRootLogin prohibit-password/PermitRootLogin yes/" /etc/ssh/sshd_config; + $(prefix_cmd) sed "s@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g" -i /etc/pam.d/sshd; + cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; + $(prefix_cmd) mkdir -p ~/.ssh; + $(prefix_cmd) chown -R $(whoami) ~/.ssh; + $(prefix_cmd) chmod 700 ~/.ssh; + $(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys; + $(prefix_cmd) chmod 644 ~/.ssh/authorized_keys; + $(prefix_cmd) service ssh restart; + $(prefix_cmd) sed -i "s/mesg n/tty -s \&\& mesg n/" ~/.profile; + ) > /tmp/${STEPS[0]}.log 2>&1 || { + echo "Error: ${STEPS[0]} failed. Continuing anyway..." > /tmp/${STEPS[0]}.failed + cat /tmp/${STEPS[0]}.log + exit 1 + } + ) & + + # STEP 2: Install conda, ray and skypilot (for dependencies); start + # ray cluster. + ( + ( + set -e + mkdir -p ~/.sky + # Wait for `curl` package to be installed before installing conda + # and ray. + until dpkg -l | grep -q "^ii curl "; do + sleep 0.1 + echo "Waiting for curl package to be installed..." + done + {{ conda_installation_commands }} + {{ ray_installation_commands }} + ~/skypilot-runtime/bin/python -m pip install skypilot[kubernetes,remote] + touch /tmp/ray_skypilot_installation_complete + echo "=== Ray and skypilot installation completed ===" + + # Disable set -e, as we have some commands that are ok to fail + # after the ray start. + # TODO(zhwu): this is a hack, we should fix the commands that are + # ok to fail. + if [ "$SKYPILOT_POD_NODE_TYPE" == "head" ]; then + set +e + {{ ray_head_start_command }} + else + # Start ray worker on the worker pod. + # Wait until the head pod is available with an IP address + export SKYPILOT_RAY_HEAD_IP="{{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local" + export SKYPILOT_RAY_PORT={{skypilot_ray_port}} + # Wait until the ray cluster is started on the head pod + until dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; do + sleep 0.1 + echo "Waiting for netcat package to be installed..." + done + until nc -z -w 1 ${SKYPILOT_RAY_HEAD_IP} ${SKYPILOT_RAY_PORT}; do + sleep 0.1 + done + + set +e + {{ ray_worker_start_command }} + fi + ) > /tmp/${STEPS[1]}.log 2>&1 || { + echo "Error: ${STEPS[1]} failed. Continuing anyway..." > /tmp/${STEPS[1]}.failed + cat /tmp/${STEPS[1]}.log + exit 1 + } + ) & + + + # STEP 3: Set up environment variables; this should be relatively fast. 
+ ( + ( + set -e + if [ $(id -u) -eq 0 ]; then + echo 'alias sudo=""' >> ~/.bashrc; echo succeed; + else + if command -v sudo >/dev/null 2>&1; then + timeout 2 sudo -l >/dev/null 2>&1 && echo succeed || { echo 52; exit 52; }; + else + { echo 52; exit 52; }; + fi; + fi; + printenv | while IFS='=' read -r key value; do echo "export $key=\"$value\""; done > ~/container_env_var.sh && $(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh + ) > /tmp/${STEPS[2]}.log 2>&1 || { + echo "Error: ${STEPS[2]} failed. Continuing anyway..." > /tmp/${STEPS[2]}.failed + cat /tmp/${STEPS[2]}.log + exit 1 + } ) & function mylsof { p=$(for pid in /proc/{0..9}*; do i=$(basename "$pid"); for file in "$pid"/fd/*; do link=$(readlink -e "$file"); if [ "$link" = "$1" ]; then echo "$i"; fi; done; done); echo "$p"; }; @@ -441,42 +576,51 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + # Line 'for step in ..': check if any failure indicator exists for the setup done in pod args and print the error message. This is only a best effort, as the + # commands in pod args are asynchronous and we cannot guarantee the failure indicators are created before the setup commands finish. - | - PACKAGES="gcc patch pciutils rsync fuse curl"; - MISSING_PACKAGES=""; - for pkg in $PACKAGES; do - if ! dpkg -l | grep -q "^ii $pkg "; then - MISSING_PACKAGES="$MISSING_PACKAGES $pkg"; - fi - done; - if [ ! -z "$MISSING_PACKAGES" ]; then - echo "Installing missing packages: $MISSING_PACKAGES"; - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y $MISSING_PACKAGES; - fi; mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} {%- endfor %} - {{ conda_installation_commands }} - {{ ray_skypilot_installation_commands }} + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") + start_epoch=$(date +%s); + echo "=== Logs for asynchronous ray and skypilot installation ==="; + if [ -f /tmp/skypilot_is_nimbus ]; then + echo "=== Logs for asynchronous ray and skypilot installation ==="; + [ -f /tmp/ray_skypilot_installation_complete ] && cat /tmp/${STEPS[1]}.log || + { tail -f -n +1 /tmp/${STEPS[1]}.log & TAIL_PID=$!; echo "Tail PID: $TAIL_PID"; until [ -f /tmp/ray_skypilot_installation_complete ]; do sleep 0.5; done; kill $TAIL_PID || true; }; + [ -f /tmp/${STEPS[1]}.failed ] && { echo "Error: ${STEPS[1]} failed. 
Exiting."; exit 1; } || true; + fi + end_epoch=$(date +%s); + echo "=== Ray and skypilot dependencies installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); + {{ skypilot_wheel_installation_commands }} + end_epoch=$(date +%s); + echo "=== Skypilot wheel installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); sudo touch ~/.sudo_as_admin_successful; sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf'; - sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; + sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); + ulimit -n 1048576; mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; - {% if tpu_requested %} - # The /tmp/tpu_logs directory is where TPU-related logs, such as logs from - # the TPU runtime, are written. These capture runtime information about the - # TPU execution, including any warnings, errors, or general activity of - # the TPU driver. By default, the /tmp/tpu_logs directory is created with - # 755 permissions, and the user of the provisioned pod is not necessarily - # a root. Hence, we need to update the write permission so the logs can be - # properly written. - # TODO(Doyoung): Investigate to see why TPU workload fails to run without - # execution permission, such as granting 766 to log file. Check if it's a - # must and see if there's a workaround to grant minimum permission. - - sudo chmod 777 /tmp/tpu_logs; - {% endif %} + end_epoch=$(date +%s); + echo "=== Setup system configs and fuse completed in $(($end_epoch - $start_epoch)) secs ==="; + for step in $STEPS; do [ -f "/tmp/${step}.failed" ] && { echo "Error: /tmp/${step}.failed found:"; cat /tmp/${step}.log; exit 1; } || true; done; + {% if tpu_requested %} + # The /tmp/tpu_logs directory is where TPU-related logs, such as logs from + # the TPU runtime, are written. These capture runtime information about the + # TPU execution, including any warnings, errors, or general activity of + # the TPU driver. By default, the /tmp/tpu_logs directory is created with + # 755 permissions, and the user of the provisioned pod is not necessarily + # a root. Hence, we need to update the write permission so the logs can be + # properly written. + # TODO(Doyoung): Investigate to see why TPU workload fails to run without + # execution permission, such as granting 766 to log file. Check if it's a + # must and see if there's a workaround to grant minimum permission. 
+ sudo chmod 777 /tmp/tpu_logs; + {% endif %} # Format: `REMOTE_PATH : LOCAL_PATH` file_mounts: { diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index e705debaf8d..92d1f2749d7 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -767,6 +767,10 @@ def run( ] if self.context: kubectl_args += ['--context', self.context] + # If context is none, it means we are using incluster auth. In this + # case, we need to set KUBECONFIG to /dev/null to avoid using kubeconfig. + if self.context is None: + kubectl_args += ['--kubeconfig', '/dev/null'] kubectl_args += [self.pod_name] if ssh_mode == SshMode.LOGIN: assert isinstance(cmd, list), 'cmd must be a list for login mode.' diff --git a/sky/utils/kubernetes/generate_kubeconfig.sh b/sky/utils/kubernetes/generate_kubeconfig.sh index 8d363370597..4ed27b62e1e 100755 --- a/sky/utils/kubernetes/generate_kubeconfig.sh +++ b/sky/utils/kubernetes/generate_kubeconfig.sh @@ -12,6 +12,7 @@ # * Specify SKYPILOT_NAMESPACE env var to override the default namespace where the service account is created. # * Specify SKYPILOT_SA_NAME env var to override the default service account name. # * Specify SKIP_SA_CREATION=1 to skip creating the service account and use an existing one +# * Specify SUPER_USER=1 to create a service account with cluster-admin permissions # # Usage: # # Create "sky-sa" service account with minimal permissions in "default" namespace and generate kubeconfig @@ -22,6 +23,9 @@ # # # Use an existing service account "my-sa" in "my-namespace" namespace and generate kubeconfig # $ SKIP_SA_CREATION=1 SKYPILOT_SA_NAME=my-sa SKYPILOT_NAMESPACE=my-namespace ./generate_kubeconfig.sh +# +# # Create "sky-sa" service account with cluster-admin permissions in "default" namespace +# $ SUPER_USER=1 ./generate_kubeconfig.sh set -eu -o pipefail @@ -29,9 +33,11 @@ set -eu -o pipefail # use default. SKYPILOT_SA=${SKYPILOT_SA_NAME:-sky-sa} NAMESPACE=${SKYPILOT_NAMESPACE:-default} +SUPER_USER=${SUPER_USER:-0} echo "Service account: ${SKYPILOT_SA}" echo "Namespace: ${NAMESPACE}" +echo "Super user permissions: ${SUPER_USER}" # Set OS specific values. if [[ "$OSTYPE" == "linux-gnu" ]]; then @@ -47,8 +53,43 @@ fi # If the user has set SKIP_SA_CREATION=1, skip creating the service account. if [ -z ${SKIP_SA_CREATION+x} ]; then - echo "Creating the Kubernetes Service Account with minimal RBAC permissions." - kubectl apply -f - <&2 context_lower=$(echo "$context" | tr '[:upper:]' '[:lower:]') shift if [ -z "$context" ] || [ "$context_lower" = "none" ]; then - kubectl exec -i $pod -n $namespace -- "$@" + # If context is none, it means we are using incluster auth. In this case, + # we need to set KUBECONFIG to /dev/null to avoid using kubeconfig file.
+ kubectl exec -i $pod -n $namespace --kubeconfig=/dev/null -- "$@" else kubectl exec -i $pod -n $namespace --context=$context -- "$@" fi diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 0b057e512bf..851e77a57fc 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -684,7 +684,14 @@ def get_default_remote_identity(cloud: str) -> str: _REMOTE_IDENTITY_SCHEMA_KUBERNETES = { 'remote_identity': { - 'type': 'string' + 'anyOf': [{ + 'type': 'string' + }, { + 'type': 'object', + 'additionalProperties': { + 'type': 'string' + } + }] }, } diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 28bd2c2ee07..992c6bbe3ff 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -2,9 +2,10 @@ from multiprocessing import pool import os import random +import resource import subprocess import time -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import colorama import psutil @@ -18,6 +19,8 @@ logger = sky_logging.init_logger(__name__) +_fd_limit_warning_shown = False + @timeline.event def run(cmd, **kwargs): @@ -43,12 +46,54 @@ def run_no_outputs(cmd, **kwargs): **kwargs) -def get_parallel_threads() -> int: - """Returns the number of idle CPUs.""" +def _get_thread_multiplier(cloud_str: Optional[str] = None) -> int: + # If using Kubernetes, we use 4x the number of cores. + if cloud_str and cloud_str.lower() == 'kubernetes': + return 4 + return 1 + + +def get_max_workers_for_file_mounts(common_file_mounts: Dict[str, str], + cloud_str: Optional[str] = None) -> int: + global _fd_limit_warning_shown + fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + + # Raise warning for low fd_limit (only once) + if fd_limit < 1024 and not _fd_limit_warning_shown: + logger.warning( + f'Open file descriptor limit ({fd_limit}) is low. File sync to ' + 'remote clusters may be slow. Consider increasing the limit using ' + '`ulimit -n ` or modifying system limits.') + _fd_limit_warning_shown = True + + fd_per_rsync = 5 + for src in common_file_mounts.values(): + if os.path.isdir(src): + # Assume that each file/folder under src takes 5 file descriptors + # on average. + fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) + + # Reserve some file descriptors for the system and other processes + fd_reserve = 100 + + max_workers = (fd_limit - fd_reserve) // fd_per_rsync + # At least 1 worker, and avoid too many workers overloading the system. + num_threads = get_parallel_threads(cloud_str) + max_workers = min(max(max_workers, 1), num_threads) + logger.debug(f'Using {max_workers} workers for file mounts.') + return max_workers + + +def get_parallel_threads(cloud_str: Optional[str] = None) -> int: + """Returns the number of threads to use for parallel execution. 
+ + Args: + cloud_str: The cloud name, used to pick the per-cloud thread multiplier (e.g. Kubernetes uses 4x the number of CPU cores). + """ cpu_count = os.cpu_count() if cpu_count is None: cpu_count = 1 - return max(4, cpu_count - 1) + return max(4, cpu_count - 1) * _get_thread_multiplier(cloud_str) def run_in_parallel(func: Callable, diff --git a/tests/test_smoke.py b/tests/test_smoke.py index ce93c3bfa30..574dae21ea0 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -25,6 +25,7 @@ # Change cloud for generic tests to aws # > pytest tests/test_smoke.py --generic-cloud aws +import enum import inspect import json import os @@ -58,8 +59,11 @@ from sky.data import data_utils from sky.data import storage as storage_lib from sky.data.data_utils import Rclone +from sky.jobs.state import ManagedJobStatus from sky.skylet import constants from sky.skylet import events +from sky.skylet.job_lib import JobStatus +from sky.status_lib import ClusterStatus from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import subprocess_utils @@ -95,6 +99,165 @@ 'sleep 10; s=$(sky jobs queue);' 'echo "Waiting for job to stop RUNNING"; echo "$s"; done') +# Cluster functions +_ALL_JOB_STATUSES = "|".join([status.value for status in JobStatus]) +_ALL_CLUSTER_STATUSES = "|".join([status.value for status in ClusterStatus]) +_ALL_MANAGED_JOB_STATUSES = "|".join( + [status.value for status in ManagedJobStatus]) + + +def _statuses_to_str(statuses: List[enum.Enum]): + """Convert a list of enums to a string with all the values separated by |.""" + assert len(statuses) > 0, 'statuses must not be empty' + if len(statuses) > 1: + return '(' + '|'.join([status.value for status in statuses]) + ')' + else: + return statuses[0].value + + +_WAIT_UNTIL_CLUSTER_STATUS_CONTAINS = ( + # A while loop to wait until the cluster status + # becomes a certain status, with timeout.
+ 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster status \'{cluster_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky status {cluster_name} --refresh | ' + 'awk "/^{cluster_name}/ ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_CLUSTER_STATUSES + + ')$/) print \$i}}"); ' + 'if [[ "$current_status" =~ {cluster_status} ]]; ' + 'then echo "Target cluster status {cluster_status} reached."; break; fi; ' + 'echo "Waiting for cluster status to become {cluster_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_status_contains( + cluster_name: str, cluster_status: List[ClusterStatus], timeout: int): + return _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.format( + cluster_name=cluster_name, + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +def _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard: str, cluster_status: List[ClusterStatus], + timeout: int): + wait_cmd = _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.replace( + 'sky status {cluster_name}', + 'sky status "{cluster_name}"').replace('awk "/^{cluster_name}/', + 'awk "/^{cluster_name_awk}/') + return wait_cmd.format(cluster_name=cluster_name_wildcard, + cluster_name_awk=cluster_name_wildcard.replace( + '*', '.*'), + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +_WAIT_UNTIL_CLUSTER_IS_NOT_FOUND = ( + # A while loop to wait until the cluster is not found or timeout + 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster to be removed"; exit 1; ' + 'fi; ' + 'if sky status -r {cluster_name}; sky status {cluster_name} | grep "{cluster_name} not found"; then ' + ' echo "Cluster {cluster_name} successfully removed."; break; ' + 'fi; ' + 'echo "Waiting for cluster {cluster_name} to be removed..."; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_is_not_found(cluster_name: str, timeout: int): + return _WAIT_UNTIL_CLUSTER_IS_NOT_FOUND.format(cluster_name=cluster_name, + timeout=timeout) + + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID = ( + # A while loop to wait until the job status + # contains certain status, with timeout. 
+ 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for job status \'{job_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky queue {cluster_name} | ' + 'awk "\\$1 == \\"{job_id}\\" ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_JOB_STATUSES + + ')$/) print \$i}}"); ' + 'found=0; ' # Initialize found variable outside the loop + 'while read -r line; do ' # Read line by line + ' if [[ "$line" =~ {job_status} ]]; then ' # Check each line + ' echo "Target job status {job_status} reached."; ' + ' found=1; ' + ' break; ' # Break inner loop + ' fi; ' + 'done <<< "$current_status"; ' + 'if [ "$found" -eq 1 ]; then break; fi; ' # Break outer loop if match found + 'echo "Waiting for job status to contains {job_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "\\$2 == \\"{job_name}\\"') + + +def _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name: str, job_id: str, job_status: List[JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.format( + cluster_name=cluster_name, + job_id=job_id, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name: str, job_status: List[JobStatus], timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB.format( + cluster_name=cluster_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_matching_job_name( + cluster_name: str, job_name: str, job_status: List[JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + cluster_name=cluster_name, + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# Managed job functions + +_WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.replace( + 'sky queue {cluster_name}', 'sky jobs queue').replace( + 'awk "\\$2 == \\"{job_name}\\"', + 'awk "\\$2 == \\"{job_name}\\" || \\$3 == \\"{job_name}\\"').replace( + _ALL_JOB_STATUSES, _ALL_MANAGED_JOB_STATUSES) + + +def _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name: str, job_status: List[JobStatus], timeout: int): + return _WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# After the timeout, the cluster will stop if autostop is set, and our check +# should be more than the timeout. To address this, we extend the timeout by +# _BUMP_UP_SECONDS before exiting. +_BUMP_UP_SECONDS = 35 + DEFAULT_CMD_TIMEOUT = 15 * 60 @@ -399,7 +562,6 @@ def test_launch_fast_with_autostop(generic_cloud: str): # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure # the VM is stopped. 
autostop_timeout = 600 if generic_cloud == 'azure' else 250 - test = Test( 'test_launch_fast_with_autostop', [ @@ -407,11 +569,15 @@ def test_launch_fast_with_autostop(generic_cloud: str): f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 1 --status', f'sky status -r {name} | grep UP', - f'sleep {autostop_timeout}', # Ensure cluster is stopped - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', - + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=autostop_timeout), + # Even the cluster is stopped, cloud platform may take a while to + # delete the VM. + f'sleep {_BUMP_UP_SECONDS}', # Launch again. Do full output validation - we expect the cluster to re-launch f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 2 --status', @@ -448,6 +614,7 @@ def test_aws_region(): @pytest.mark.aws def test_aws_with_ssh_proxy_command(): name = _get_cluster_name() + with tempfile.NamedTemporaryFile(mode='w') as f: f.write( textwrap.dedent(f"""\ @@ -469,10 +636,18 @@ def test_aws_with_ssh_proxy_command(): f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi', # Wait other tests to create the job controller first, so that # the job controller is not launched with proxy command. - 'timeout 300s bash -c "until sky status sky-jobs-controller* | grep UP; do sleep 1; done"', + _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard='sky-jobs-controller-*', + cluster_status=[ClusterStatus.UP], + timeout=300), f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name} | grep "STARTING\|RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + ManagedJobStatus.SUCCEEDED, ManagedJobStatus.RUNNING, + ManagedJobStatus.STARTING + ], + timeout=300), ], f'sky down -y {name} jump-{name}; sky jobs cancel -y -n {name}', ) @@ -842,6 +1017,12 @@ def test_clone_disk_aws(): f'sky launch -y -c {name} --cloud aws --region us-east-2 --retry-until-up "echo hello > ~/user_file.txt"', f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true', f'sky stop {name} -y', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=60), + # Wait for EC2 instance to be in stopped state. + # TODO: event based wait. 'sleep 60', f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', @@ -888,8 +1069,8 @@ def test_gcp_mig(): # Check MIG exists. f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"', f'sky autostop -i 0 --down -y {name}', - 'sleep 120', - f'sky status -r {name}; sky status {name} | grep "{name} not found"', + _get_cmd_wait_until_cluster_is_not_found(cluster_name=name, + timeout=120), f'gcloud compute instance-templates list | grep "sky-it-{name}"', # Launch again with the same region. The original instance template # should be removed. 
@@ -956,8 +1137,10 @@ def test_custom_default_conda_env(generic_cloud: str): f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', f'sky logs {name} 2 --status', f'sky autostop -y -i 0 {name}', - 'sleep 60', - f'sky status -r {name} | grep "STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=80), f'sky start -y {name}', f'sky logs {name} 2 --no-follow | grep -E "myenv\\s+\\*"', f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', @@ -978,7 +1161,10 @@ def test_stale_job(generic_cloud: str): f'sky launch -y -c {name} --cloud {generic_cloud} "echo hi"', f'sky exec {name} -d "echo start; sleep 10000"', f'sky stop {name} -y', - 'sleep 100', # Ensure this is large enough, else GCP leaks. + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=100), f'sky start {name} -y', f'sky logs {name} 1 --status', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', @@ -1006,13 +1192,18 @@ def test_aws_stale_job_manual_restart(): '--output text`; ' f'aws ec2 stop-instances --region {region} ' '--instance-ids $id', - 'sleep 40', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=40), f'sky launch -c {name} -y "echo hi"', f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. - f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS), ], f'sky down -y {name}', ) @@ -1042,8 +1233,10 @@ def test_gcp_stale_job_manual_restart(): f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. - f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS) ], f'sky down -y {name}', ) @@ -1720,6 +1913,7 @@ def test_large_job_queue(generic_cloud: str): f'for i in `seq 1 75`; do sky exec {name} -n {name}-$i -d "echo $i; sleep 100000000"; done', f'sky cancel -y {name} 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16', 'sleep 90', + # Each job takes 0.5 CPU and the default VM has 8 CPUs, so there should be 8 / 0.5 = 16 jobs running. # The first 16 jobs are canceled, so there should be 75 - 32 = 43 jobs PENDING. f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep PENDING | wc -l | grep 43', @@ -1861,7 +2055,13 @@ def test_multi_echo(generic_cloud: str): f'until sky logs {name} 32 --status; do echo "Waiting for job 32 to finish..."; sleep 1; done', ] + # Ensure jobs succeeded. - [f'sky logs {name} {i + 1} --status' for i in range(32)] + + [ + _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name=name, + job_id=i + 1, + job_status=[JobStatus.SUCCEEDED], + timeout=120) for i in range(32) + ] + # Ensure monitor/autoscaler didn't crash on the 'assert not # unfulfilled' error. If process not found, grep->ssh returns 1. 
[f'ssh {name} \'ps aux | grep "[/]"monitor.py\''], @@ -2433,12 +2633,17 @@ def test_gcp_start_stop(): f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit. f'sky logs {name} 3 --status', # Ensure the job succeeded. f'sky stop -y {name}', - f'sleep 20', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=40), f'sky start -y {name} -i 1', f'sky exec {name} examples/gcp_start_stop.yaml', f'sky logs {name} 4 --status', # Ensure the job succeeded. - 'sleep 180', - f'sky status -r {name} | grep "INIT\|STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED, ClusterStatus.INIT], + timeout=200), ], f'sky down -y {name}', ) @@ -2461,9 +2666,11 @@ def test_azure_start_stop(): f'sky start -y {name} -i 1', f'sky exec {name} examples/azure_start_stop.yaml', f'sky logs {name} 3 --status', # Ensure the job succeeded. - 'sleep 260', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' - f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED, ClusterStatus.INIT], + timeout=280) + + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}', ], f'sky down -y {name}', timeout=30 * 60, # 30 mins @@ -2499,8 +2706,10 @@ def test_autostop(generic_cloud: str): f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=autostop_timeout), # Ensure the cluster is UP and the autostop setting is reset ('-'). f'sky start -y {name}', @@ -2516,8 +2725,10 @@ def test_autostop(generic_cloud: str): f'sky autostop -y {name} -i 1', # Should restart the timer. 'sleep 40', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=autostop_timeout), # Test restarting the idleness timer via exec: f'sky start -y {name}', @@ -2526,9 +2737,10 @@ def test_autostop(generic_cloud: str): 'sleep 45', # Almost reached the threshold. f'sky exec {name} echo hi', # Should restart the timer. 
'sleep 45', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=autostop_timeout + _BUMP_UP_SECONDS), ], f'sky down -y {name}', timeout=total_timeout_minutes * 60, @@ -2745,15 +2957,19 @@ def test_stop_gcp_spot(): f'sky exec {name} -- ls myfile', f'sky logs {name} 2 --status', f'sky autostop {name} -i0 -y', - 'sleep 90', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=90), f'sky start {name} -y', f'sky exec {name} -- ls myfile', f'sky logs {name} 3 --status', # -i option at launch should go through: f'sky launch -c {name} -i0 -y', - 'sleep 120', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=120), ], f'sky down -y {name}', ) @@ -2773,14 +2989,25 @@ def test_managed_jobs(generic_cloud: str): [ f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d', f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[ + ManagedJobStatus.PENDING, ManagedJobStatus.SUBMITTED, + ManagedJobStatus.STARTING, ManagedJobStatus.RUNNING + ], + timeout=60), + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ + ManagedJobStatus.PENDING, ManagedJobStatus.SUBMITTED, + ManagedJobStatus.STARTING, ManagedJobStatus.RUNNING + ], + timeout=60), f'sky jobs cancel -y -n {name}-1', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep CANCELLED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[ManagedJobStatus.CANCELLED], + timeout=230), # Test the functionality for logging. f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"', f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"', @@ -2850,9 +3077,11 @@ def test_managed_jobs_failed_setup(generic_cloud: str): 'managed_jobs_failed_setup', [ f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml', - 'sleep 330', # Make sure the job failed quickly. - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.FAILED_SETUP], + timeout=330 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. 
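A minimal usage sketch, assuming the helper names and status enums introduced by this patch (the job name and timeout are placeholders taken from the test_managed_jobs hunk above): each helper returns a self-contained bash snippet, so it slots into a Test command list exactly where a fixed 'sleep N' plus grep pair used to go.

# Sketch: wait until a managed job reaches one of the expected early states.
wait_cmd = _get_cmd_wait_until_managed_job_status_contains_matching_job_name(
    job_name=f'{name}-1',
    job_status=[
        ManagedJobStatus.PENDING, ManagedJobStatus.SUBMITTED,
        ManagedJobStatus.STARTING, ManagedJobStatus.RUNNING
    ],
    timeout=60 + _BUMP_UP_SECONDS)
# wait_cmd polls `sky jobs queue` every ~10 seconds and exits non-zero on timeout.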
@@ -2875,7 +3104,10 @@ def test_managed_jobs_pipeline_failed_setup(generic_cloud: str): 'managed_jobs_pipeline_failed_setup', [ f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml', - 'sleep 600', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.FAILED_SETUP], + timeout=600), # Make sure the job failed quickly. f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', # Task 0 should be SUCCEEDED. @@ -2909,8 +3141,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): 'managed_jobs_recovery_aws', [ f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=600), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -2920,8 +3154,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2949,15 +3185,19 @@ def test_managed_jobs_recovery_gcp(): 'managed_jobs_recovery_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=300), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. 
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2980,8 +3220,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): 'managed_jobs_pipeline_recovery_aws', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -3000,8 +3242,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3031,8 +3275,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): 'managed_jobs_pipeline_recovery_gcp', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. 
@@ -3043,8 +3289,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3070,8 +3318,12 @@ def test_managed_jobs_recovery_default_resources(generic_cloud: str): 'managed-spot-recovery-default-resources', [ f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|RECOVERING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + ManagedJobStatus.RUNNING, ManagedJobStatus.RECOVERING + ], + timeout=360), ], f'sky jobs cancel -y -n {name}', timeout=25 * 60, @@ -3091,8 +3343,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): 'managed_jobs_recovery_multi_node_aws', [ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 450', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=450), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -3103,8 +3357,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 560', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3132,15 +3388,19 @@ def test_managed_jobs_recovery_multi_node_gcp(): 'managed_jobs_recovery_multi_node_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. 
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 420', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3165,13 +3425,17 @@ def test_managed_jobs_cancellation_aws(aws_config_region): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + ManagedJobStatus.STARTING, ManagedJobStatus.RUNNING + ], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3179,12 +3443,16 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, so it will be shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_2_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3192,8 +3460,11 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + # The job is running in the cluster, so it will be shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually.
(f'aws ec2 terminate-instances --region {region} --instance-ids $(' f'aws ec2 describe-instances --region {region} ' @@ -3203,10 +3474,10 @@ def test_managed_jobs_cancellation_aws(aws_config_region): _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (shutting-down) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$(aws ec2 describe-instances --region {region} ' @@ -3241,34 +3512,42 @@ def test_managed_jobs_cancellation_gcp(): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, so it will be shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually.
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (STOPPING) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED|SUSPENDED"' @@ -3358,8 +3637,10 @@ def test_managed_jobs_storage(generic_cloud: str): *STORAGE_SETUP_COMMANDS, f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y', region_validation_cmd, # Check if the bucket is created in the correct region - 'sleep 60', # Wait the spot queue to be updated - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.SUCCEEDED], + timeout=60 + _BUMP_UP_SECONDS), f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]', # Check if file was written to the mounted output bucket output_check_cmd @@ -3383,10 +3664,17 @@ def test_managed_jobs_tpu(): 'test-spot-tpu', [ f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep STARTING', - 'sleep 900', # TPU takes a while to launch - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), + # TPU takes a while to launch + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + ManagedJobStatus.RUNNING, ManagedJobStatus.SUCCEEDED + ], + timeout=900 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3404,8 +3692,10 @@ def test_managed_jobs_inline_env(generic_cloud: str): 'test-managed-jobs-inline-env', [ f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', - 'sleep 20', - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ManagedJobStatus.SUCCEEDED], + timeout=20 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3512,8 +3802,10 @@ def test_azure_start_stop_two_nodes(): f'sky start -y {name} -i 1', f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 2 --status', # Ensure the job succeeded. 
- 'sleep 200', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.INIT, ClusterStatus.STOPPED], + timeout=200 + _BUMP_UP_SECONDS) + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', @@ -4524,7 +4816,10 @@ def test_core_api_sky_launch_fast(generic_cloud: str): idle_minutes_to_autostop=1, fast=True) # Sleep to let the cluster autostop - time.sleep(120) + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ClusterStatus.STOPPED], + timeout=120) # Run it again - should work with fast=True sky.launch(task, cluster_name=name, diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 48e47a6007c..be40cc55723 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -172,7 +172,7 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) -@mock.patch('sky.provision.kubernetes.utils.get_all_kube_config_context_names', +@mock.patch('sky.provision.kubernetes.utils.get_all_kube_context_names', return_value=['kind-skypilot', 'kind-skypilot2', 'kind-skypilot3']) def test_dynamic_kubernetes_contexts_policy(add_example_policy_paths, task): _, config = _load_task_and_apply_policy( diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 5006fc454aa..65c90544f49 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -140,6 +140,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) expected_config_base = { @@ -180,6 +181,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated') @@ -195,6 +197,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated')
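A final illustrative note on the test_core_api_sky_launch_fast hunk above: that test drives the Python API rather than a shell command list, so the generated wait command only has an effect if it is actually run through a shell. The sketch below shows one possible way to do that; the subprocess call is an assumption for illustration, not code taken from the patch.

import subprocess

# Build the same polling command the smoke tests use, then execute it so the
# wait takes effect in a Python-API test; check=True fails the test on timeout.
wait_cmd = _get_cmd_wait_until_cluster_status_contains(
    cluster_name=name,
    cluster_status=[ClusterStatus.STOPPED],
    timeout=120)
subprocess.run(wait_cmd, shell=True, check=True)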