diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 065924070a7..97a69b36b36 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -22,29 +22,35 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install yapf==0.32.0 - pip install toml==0.10.2 - pip install black==22.10.0 - pip install isort==5.12.0 + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install yapf==0.32.0 + uv pip install toml==0.10.2 + uv pip install black==22.10.0 + uv pip install isort==5.12.0 - name: Running yapf run: | + source ~/test-env/bin/activate yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \ --exclude 'sky/skylet/providers/ibm/**' - name: Running black run: | + source ~/test-env/bin/activate black --diff --check sky/skylet/providers/ibm/ - name: Running isort for black formatted files run: | + source ~/test-env/bin/activate isort --diff --check --profile black -l 88 -m 3 \ sky/skylet/providers/ibm/ - name: Running isort for yapf formatted files run: | + source ~/test-env/bin/activate isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \ --sg 'sky/skylet/providers/ibm/**' diff --git a/.github/workflows/mypy-generic.yml b/.github/workflows/mypy-generic.yml deleted file mode 100644 index 3d7b3ce5e57..00000000000 --- a/.github/workflows/mypy-generic.yml +++ /dev/null @@ -1,23 +0,0 @@ -# This is needed for GitHub Actions for the "Waiting for status to be reported" problem, -# according to https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/defining-the-mergeability-of-pull-requests/troubleshooting-required-status-checks -name: mypy - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - master - - 'releases/**' - pull_request: - branches: - - master - - 'releases/**' - - restapi - merge_group: - -jobs: - mypy: - runs-on: ubuntu-latest - steps: - - run: 'echo "No mypy to run"' diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index dc5c952d84c..6d400ac0ec0 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -12,6 +12,8 @@ on: - master - 'releases/**' - restapi + merge_group: + jobs: mypy: runs-on: ubuntu-latest @@ -20,15 +22,18 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install mypy==$(grep mypy requirements-dev.txt | cut -d'=' -f3) - pip install $(grep types- requirements-dev.txt | tr '\n' ' ') + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install mypy==$(grep mypy requirements-dev.txt | cut -d'=' -f3) + uv pip install $(grep types- requirements-dev.txt | tr '\n' ' ') - name: Running mypy run: | + source ~/test-env/bin/activate mypy $(cat tests/mypy_files.txt) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 19dd96a1469..480cdf157d3 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -22,16 +22,20 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install ".[all]" - pip install pylint==2.14.5 - pip install pylint-quotes==0.2.3 + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + uv pip install ".[all]" + uv pip install pylint==2.14.5 + uv pip install pylint-quotes==0.2.3 - name: Analysing the code with pylint run: | + source ~/test-env/bin/activate pylint --load-plugins pylint_quotes sky diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 757bfec36d2..bface9232cf 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -35,26 +35,21 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 - - - name: Install Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - - - name: Cache dependencies - uses: actions/cache@v3 - if: startsWith(runner.os, 'Linux') - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-pytest-${{ matrix.python-version }} - restore-keys: | - ${{ runner.os }}-pip-pytest-${{ matrix.python-version }} - - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e ".[all]" - pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + # Use -e to include examples and tests folder in the path for unit + # tests to access them. + uv pip install -e ".[all]" + uv pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - name: Run tests with pytest - run: SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 0 --dist no ${{ matrix.test-path }} + run: | + source ~/test-env/bin/activate + SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 0 --dist no ${{ matrix.test-path }} diff --git a/.github/workflows/test-doc-build.yml b/.github/workflows/test-doc-build.yml index 706aa071706..954a1b2c017 100644 --- a/.github/workflows/test-doc-build.yml +++ b/.github/workflows/test-doc-build.yml @@ -11,27 +11,32 @@ on: branches: - master - 'releases/**' + - restapi merge_group: jobs: - format: + doc-build: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.10"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install . + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + uv pip install ".[all]" cd docs - pip install -r ./requirements-docs.txt + uv pip install -r ./requirements-docs.txt - name: Build documentation run: | + source ~/test-env/bin/activate cd ./docs ./build.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..db40b03b5fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,74 @@ +# Ensure this configuration aligns with format.sh and requirements.txt +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 22.10.0 # Match the version from requirements + hooks: + - id: black + name: black (IBM specific) + files: "^sky/skylet/providers/ibm/.*" # Match only files in the IBM directory + +- repo: https://github.com/pycqa/isort + rev: 5.12.0 # Match the version from requirements + hooks: + # First isort command + - id: isort + name: isort (general) + args: + - "--sg=build/**" # Matches "${ISORT_YAPF_EXCLUDES[@]}" + - "--sg=sky/skylet/providers/ibm/**" + files: "^(sky|tests|examples|llm|docs)/.*" # Only match these directories + # Second isort command + - id: isort + name: isort (IBM specific) + args: + - "--profile=black" + - "-l=88" + - "-m=3" + files: "^sky/skylet/providers/ibm/.*" # Only match IBM-specific directory + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.991 # Match the version from requirements + hooks: + - id: mypy + args: + # From tests/mypy_files.txt + - "sky" + - "--exclude" + - "sky/benchmark|sky/callbacks|sky/skylet/providers/azure|sky/resources.py|sky/backends/monkey_patches" + pass_filenames: false + additional_dependencies: + - types-PyYAML + - types-requests<2.31 # Match the condition in requirements.txt + - types-setuptools + - types-cachetools + - types-pyvmomi + +- repo: https://github.com/google/yapf + rev: v0.32.0 # Match the version from requirements + hooks: + - id: yapf + name: yapf + exclude: (build/.*|sky/skylet/providers/ibm/.*) # Matches exclusions from the script + args: ['--recursive', '--parallel'] # Only necessary flags + additional_dependencies: [toml==0.10.2] + +- repo: https://github.com/pylint-dev/pylint + rev: v2.14.5 # Match the version from requirements + hooks: + - id: pylint + additional_dependencies: + - pylint-quotes==0.2.3 # Match the version from requirements + name: pylint + args: + - --rcfile=.pylintrc # Use your custom pylint configuration + - --load-plugins=pylint_quotes # Load the pylint-quotes plugin + files: ^sky/ # Only include files from the 'sky/' directory + exclude: ^sky/skylet/providers/ibm/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 73716f0994e..d204c27969e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,6 +78,7 @@ It has some convenience features which you might find helpful (see [Dockerfile]( - If relevant, add tests for your changes. For changes that touch the core system, run the [smoke tests](#testing) and ensure they pass. - Follow the [Google style guide](https://google.github.io/styleguide/pyguide.html). - Ensure code is properly formatted by running [`format.sh`](https://github.com/skypilot-org/skypilot/blob/master/format.sh). + - [Optional] You can also install pre-commit hooks by running `pre-commit install` to automatically format your code on commit. - Push your changes to your fork and open a pull request in the SkyPilot repository. - In the PR description, write a `Tested:` section to describe relevant tests performed. diff --git a/Dockerfile_k8s b/Dockerfile_k8s index 45625871078..f031dff3668 100644 --- a/Dockerfile_k8s +++ b/Dockerfile_k8s @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # Initialize conda for root user, install ssh and other local dependencies RUN apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* && \ apt remove -y python3 && \ conda init diff --git a/Dockerfile_k8s_gpu b/Dockerfile_k8s_gpu index 09570d102df..6277e7f8d12 100644 --- a/Dockerfile_k8s_gpu +++ b/Dockerfile_k8s_gpu @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # We remove cuda lists to avoid conflicts with the cuda version installed by ray RUN rm -rf /etc/apt/sources.list.d/cuda* && \ apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* # Setup SSH and generate hostkeys @@ -36,6 +36,7 @@ SHELL ["/bin/bash", "-c"] # Install conda and other dependencies # Keep the conda and Ray versions below in sync with the ones in skylet.constants +# Keep this section in sync with the custom image optimization recommendations in our docs (kubernetes-getting-started.rst) RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ bash Miniconda3-Linux-x86_64.sh -b && \ eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index 69303a582e2..deb2307b67b 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -267,6 +267,14 @@ The :code:`~/.oci/config` file should contain the following fields: # Note that we should avoid using full home path for the key_file configuration, e.g. use ~/.oci instead of /home/username/.oci key_file=~/.oci/oci_api_key.pem +By default, the provisioned nodes will be in the root `compartment `__. To specify the `compartment `_ other than root, create/edit the file :code:`~/.sky/config.yaml`, put the compartment's OCID there, as the following: + +.. code-block:: text + + oci: + default: + compartment_ocid: ocid1.compartment.oc1..aaaaaaaa...... + Lambda Cloud ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/comparison.rst b/docs/source/reference/comparison.rst index e9bffabba68..23985e5081b 100644 --- a/docs/source/reference/comparison.rst +++ b/docs/source/reference/comparison.rst @@ -46,7 +46,7 @@ SkyPilot provides faster iteration for interactive development. For example, a c * :strong:`With SkyPilot, a single command (`:literal:`sky launch`:strong:`) takes care of everything.` Behind the scenes, SkyPilot provisions pods, installs all required dependencies, executes the job, returns logs, and provides SSH and VSCode access to debug. -.. figure:: https://blog.skypilot.co/ai-on-kubernetes/images/k8s_vs_skypilot_iterative_v2.png +.. figure:: https://i.imgur.com/xfCfz4N.png :align: center :width: 95% :alt: Iterative Development with Kubernetes vs SkyPilot diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 6edc2b1bc68..10a6b1bc90f 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -17,6 +17,7 @@ Spec: ``~/.sky/config.yaml`` Available fields and semantics: .. code-block:: yaml + # Endpoint of the SkyPilot API server (optional). # # This is used to connect to the SkyPilot API server. @@ -252,6 +253,10 @@ Available fields and semantics: # instances. SkyPilot will auto-create and reuse a service account (IAM # role) for AWS instances. # + # NO_UPLOAD: No credentials will be uploaded to the pods. Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # Customized service account (IAM role): or # - : apply the service account with the specified name to all instances. # Example: @@ -271,7 +276,8 @@ Available fields and semantics: # # - This only affects AWS instances. Local AWS credentials will still be # uploaded to non-AWS instances (since those instances may need to access - # AWS resources). + # AWS resources). To fully disable credential upload, set + # `remote_identity: NO_UPLOAD`. # - If the SkyPilot jobs/serve controller is on AWS, this setting will make # non-AWS managed jobs / non-AWS service replicas fail to access any # resources on AWS (since the controllers don't have AWS credential @@ -414,11 +420,16 @@ Available fields and semantics: # instances. SkyPilot will auto-create and reuse a service account for GCP # instances. # + # NO_UPLOAD: No credentials will be uploaded to the pods. Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # Two caveats of SERVICE_ACCOUNT for multicloud users: # # - This only affects GCP instances. Local GCP credentials will still be # uploaded to non-GCP instances (since those instances may need to access - # GCP resources). + # GCP resources). To fully disable credential uploads, set + # `remote_identity: NO_UPLOAD`. # - If the SkyPilot jobs/serve controller is on GCP, this setting will make # non-GCP managed jobs / non-GCP service replicas fail to access any # resources on GCP (since the controllers don't have GCP credential @@ -505,6 +516,10 @@ Available fields and semantics: # SkyPilot will auto-create and reuse a service account with necessary roles # in the user's namespace. # + # NO_UPLOAD: No credentials will be uploaded to the pods. Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # : The name of a service account to use for all Kubernetes pods. # This service account must exist in the user's namespace and have all # necessary permissions. Refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html @@ -513,7 +528,8 @@ Available fields and semantics: # Using SERVICE_ACCOUNT or a custom service account only affects Kubernetes # instances. Local ~/.kube/config will still be uploaded to non-Kubernetes # instances (e.g., a serve controller on GCP or AWS may need to provision - # Kubernetes resources). + # Kubernetes resources). To fully disable credential uploads, set + # `remote_identity: NO_UPLOAD`. # # Default: 'SERVICE_ACCOUNT'. remote_identity: my-k8s-service-account diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst index 0e19eb6e266..e4bbb2c8915 100644 --- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst +++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst @@ -324,3 +324,32 @@ FAQs type: Directory For more details refer to :ref:`config-yaml`. + +* **I am using a custom image. How can I speed up the pod startup time?** + + You can pre-install SkyPilot dependencies in your custom image to speed up the pod startup time. Simply add these lines at the end of your Dockerfile: + + .. code-block:: dockerfile + + FROM + + # Install system dependencies + RUN apt update -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils fuse unzip socat netcat-openbsd curl -y && \ + rm -rf /var/lib/apt/lists/* + + # Install conda and other python dependencies + RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ + bash Miniconda3-Linux-x86_64.sh -b && \ + eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ + grep "# >>> conda initialize >>>" ~/.bashrc || { conda init && source ~/.bashrc; } && \ + rm Miniconda3-Linux-x86_64.sh && \ + export PIP_DISABLE_PIP_VERSION_CHECK=1 && \ + python3 -m venv ~/skypilot-runtime && \ + PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python && \ + $PYTHON_EXEC -m pip install 'skypilot-nightly[remote,kubernetes]' 'ray[default]==2.9.3' 'pycryptodome==3.12.0' && \ + $PYTHON_EXEC -m pip uninstall skypilot-nightly -y && \ + curl -LO "https://dl.k8s.io/release/v1.28.11/bin/linux/amd64/kubectl" && \ + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl && \ + echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc + diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 7addcffbe3c..ff95162ac63 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -139,7 +139,7 @@ def update_current_kubernetes_clusters_from_registry(): def get_allowed_contexts(): """Mock implementation of getting allowed kubernetes contexts.""" from sky.provision.kubernetes import utils - contexts = utils.get_all_kube_config_context_names() + contexts = utils.get_all_kube_context_names() return contexts[:2] diff --git a/examples/oci/serve-qwen-7b.yaml b/examples/oci/serve-qwen-7b.yaml index 799e5a7d891..004e912b088 100644 --- a/examples/oci/serve-qwen-7b.yaml +++ b/examples/oci/serve-qwen-7b.yaml @@ -13,8 +13,8 @@ resources: setup: | conda create -n vllm python=3.12 -y conda activate vllm - pip install vllm - pip install vllm-flash-attn + pip install vllm==0.6.3.post1 + pip install vllm-flash-attn==2.6.2 run: | conda activate vllm diff --git a/sky/__init__.py b/sky/__init__.py index ac8637651f4..d2977359e59 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -86,11 +86,11 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.admin_policy import AdminPolicy from sky.admin_policy import MutatedUserRequest from sky.admin_policy import UserRequest -# from sky.api.sdk import download_logs from sky.api.sdk import autostop from sky.api.sdk import cancel from sky.api.sdk import cost_report from sky.api.sdk import down +from sky.api.sdk import download_logs from sky.api.sdk import exec # pylint: disable=redefined-builtin from sky.api.sdk import get from sky.api.sdk import job_status @@ -109,6 +109,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.data import Storage from sky.data import StorageMode from sky.data import StoreType +from sky.jobs import ManagedJobStatus # TODO (zhwu): These imports are for backward compatibility, and spot APIs # should be called with `sky.spot.xxx` instead. Remove in release 0.8.0 from sky.jobs.api.sdk import spot_cancel @@ -166,6 +167,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'StoreType', 'ClusterStatus', 'JobStatus', + 'ManagedJobStatus', # APIs 'Dag', 'Task', @@ -185,7 +187,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'cancel', 'tail_logs', 'spot_tail_logs', - # 'download_logs', + 'download_logs', 'job_status', # core APIs Spot Job Management 'spot_queue', diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py index ea8fb194efa..001d397ac9e 100644 --- a/sky/adaptors/kubernetes.py +++ b/sky/adaptors/kubernetes.py @@ -19,6 +19,13 @@ # Timeout to use for API calls API_TIMEOUT = 5 +DEFAULT_IN_CLUSTER_REGION = 'in-cluster' +# The name for the environment variable that stores the in-cluster context name +# for Kubernetes clusters. This is used to associate a name with the current +# context when running with in-cluster auth. If not set, the context name is +# set to DEFAULT_IN_CLUSTER_REGION. +IN_CLUSTER_CONTEXT_NAME_ENV_VAR = 'SKYPILOT_IN_CLUSTER_CONTEXT_NAME' + def _decorate_methods(obj: Any, decorator: Callable, decoration_type: str): for attr_name in dir(obj): @@ -57,16 +64,8 @@ def wrapped(*args, **kwargs): def _load_config(context: Optional[str] = None): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - try: - # Load in-cluster config if running in a pod - # Kubernetes set environment variables for service discovery do not - # show up in SkyPilot tasks. For now, we work around by using - # DNS name instead of environment variables. - # See issue: https://github.com/skypilot-org/skypilot/issues/2287 - os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' - os.environ['KUBERNETES_SERVICE_PORT'] = '443' - kubernetes.config.load_incluster_config() - except kubernetes.config.config_exception.ConfigException: + + def _load_config_from_kubeconfig(context: Optional[str] = None): try: kubernetes.config.load_kube_config(context=context) except kubernetes.config.config_exception.ConfigException as e: @@ -90,6 +89,21 @@ def _load_config(context: Optional[str] = None): with ux_utils.print_exception_no_traceback(): raise ValueError(err_str) from None + if context == in_cluster_context_name() or context is None: + try: + # Load in-cluster config if running in a pod and context is None. + # Kubernetes set environment variables for service discovery do not + # show up in SkyPilot tasks. For now, we work around by using + # DNS name instead of environment variables. + # See issue: https://github.com/skypilot-org/skypilot/issues/2287 + os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' + os.environ['KUBERNETES_SERVICE_PORT'] = '443' + kubernetes.config.load_incluster_config() + except kubernetes.config.config_exception.ConfigException: + _load_config_from_kubeconfig() + else: + _load_config_from_kubeconfig(context) + @_api_logging_decorator('urllib3', logging.ERROR) @functools.lru_cache() @@ -154,3 +168,13 @@ def max_retry_error(): def stream(): return kubernetes.stream.stream + + +def in_cluster_context_name() -> Optional[str]: + """Returns the name of the in-cluster context from the environment. + + If the environment variable is not set, returns the default in-cluster + context name. + """ + return (os.environ.get(IN_CLUSTER_CONTEXT_NAME_ENV_VAR) or + DEFAULT_IN_CLUSTER_REGION) diff --git a/sky/api/requests/payloads.py b/sky/api/requests/payloads.py index 908ee7a724c..d5018eb32df 100644 --- a/sky/api/requests/payloads.py +++ b/sky/api/requests/payloads.py @@ -259,6 +259,7 @@ class JobsLogsBody(RequestBody): job_id: Optional[int] = None follow: bool = True controller: bool = False + refresh: bool = False class RequestIdBody(pydantic.BaseModel): diff --git a/sky/api/sdk.py b/sky/api/sdk.py index 2328f35651a..e66174e8685 100644 --- a/sky/api/sdk.py +++ b/sky/api/sdk.py @@ -567,7 +567,7 @@ def status( Args: cluster_names: names of clusters to get status for. If None, get status for all clusters. The cluster names specified can be in glob pattern - (e.g., 'my-cluster-*'). + (e.g., ``my-cluster-*``). refresh: whether to refresh the status of the clusters. """ # TODO(zhwu): this does not stream the logs output by logger back to the diff --git a/sky/authentication.py b/sky/authentication.py index f7221f22edf..499d276edb0 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -379,9 +379,10 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: # Add the user's public key to the SkyPilot cluster. secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name - context = kubernetes_utils.get_context_from_config(config['provider']) - if context == kubernetes_utils.IN_CLUSTER_REGION: - # If the context is set to IN_CLUSTER_REGION, we are running in a pod + context = config['provider'].get( + 'context', kubernetes_utils.get_current_kube_config_context_name()) + if context == kubernetes.in_cluster_context_name(): + # If the context is an in-cluster context name, we are running in a pod # with in-cluster configuration. We need to set the context to None # to use the mounted service account. context = None diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 40e4f713c29..0af3f31ca73 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1,5 +1,4 @@ """Util constants/functions for the backends.""" -import contextlib from datetime import datetime import enum import fnmatch @@ -99,6 +98,10 @@ CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock') CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20 +# Time that must elapse since the last status check before we should re-check if +# the cluster has been terminated or autostopped. +_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2 + # Filelocks for updating cluster's file_mounts. CLUSTER_FILE_MOUNTS_LOCK_PATH = os.path.expanduser( '~/.sky/.{}_file_mounts.lock') @@ -503,33 +506,68 @@ def write_cluster_config( resources_utils.ClusterName( cluster_name, cluster_name_on_cloud, - ), region, zones, dryrun) + ), region, zones, num_nodes, dryrun) config_dict = {} specific_reservations = set( skypilot_config.get_nested( (str(to_provision.cloud).lower(), 'specific_reservations'), set())) + # Remote identity handling can have 4 cases: + # 1. LOCAL_CREDENTIALS (default for most clouds): Upload local credentials + # 2. SERVICE_ACCOUNT: SkyPilot creates and manages a service account + # 3. Custom service account: Use specified service account + # 4. NO_UPLOAD: Do not upload any credentials + # + # We need to upload credentials only if LOCAL_CREDENTIALS is specified. In + # other cases, we exclude the cloud from credential file uploads after + # running required checks. assert cluster_name is not None - excluded_clouds = [] + excluded_clouds = set() remote_identity_config = skypilot_config.get_nested( (str(cloud).lower(), 'remote_identity'), None) remote_identity = schemas.get_default_remote_identity(str(cloud).lower()) if isinstance(remote_identity_config, str): remote_identity = remote_identity_config if isinstance(remote_identity_config, list): + # Some clouds (e.g., AWS) support specifying multiple service accounts + # chosen based on the cluster name. Do the matching here to pick the + # correct one. for profile in remote_identity_config: if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): remote_identity = list(profile.values())[0] break if remote_identity != schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value: - if not cloud.supports_service_account_on_remote(): + # If LOCAL_CREDENTIALS is not specified, we add the cloud to the + # excluded_clouds set, but we must also check if the cloud supports + # service accounts. + if remote_identity == schemas.RemoteIdentityOptions.NO_UPLOAD.value: + # If NO_UPLOAD is specified, fall back to default remote identity + # for downstream logic but add it to excluded_clouds to skip + # credential file uploads. + remote_identity = schemas.get_default_remote_identity( + str(cloud).lower()) + elif not cloud.supports_service_account_on_remote(): raise exceptions.InvalidCloudConfigs( 'remote_identity: SERVICE_ACCOUNT is specified in ' f'{skypilot_config.loaded_config_path!r} for {cloud}, but it ' 'is not supported by this cloud. Remove the config or set: ' '`remote_identity: LOCAL_CREDENTIALS`.') - excluded_clouds = [cloud] + if isinstance(cloud, clouds.Kubernetes): + if skypilot_config.get_nested( + ('kubernetes', 'allowed_contexts'), None) is None: + excluded_clouds.add(cloud) + else: + excluded_clouds.add(cloud) + + for cloud_str, cloud_obj in registry.CLOUD_REGISTRY.items(): + remote_identity_config = skypilot_config.get_nested( + (cloud_str.lower(), 'remote_identity'), None) + if remote_identity_config: + if (remote_identity_config == + schemas.RemoteIdentityOptions.NO_UPLOAD.value): + excluded_clouds.add(cloud_obj) + credentials = sky_check.get_cloud_credential_file_mounts(excluded_clouds) private_key_path, _ = auth.get_or_generate_keys() @@ -635,7 +673,11 @@ def write_cluster_config( '{sky_wheel_hash}', wheel_hash).replace('{cloud}', str(cloud).lower())), - + 'skypilot_wheel_installation_commands': + constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace( + '{sky_wheel_hash}', + wheel_hash).replace('{cloud}', + str(cloud).lower()), # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -984,18 +1026,18 @@ def ssh_credential_from_yaml( def parallel_data_transfer_to_nodes( - runners: List[command_runner.CommandRunner], - source: Optional[str], - target: str, - cmd: Optional[str], - run_rsync: bool, - *, - action_message: str, - # Advanced options. - log_path: str = os.devnull, - stream_logs: bool = False, - source_bashrc: bool = False, -): + runners: List[command_runner.CommandRunner], + source: Optional[str], + target: str, + cmd: Optional[str], + run_rsync: bool, + *, + action_message: str, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = False, + source_bashrc: bool = False, + num_threads: Optional[int] = None): """Runs a command on all nodes and optionally runs rsync from src->dst. Args: @@ -1007,6 +1049,7 @@ def parallel_data_transfer_to_nodes( log_path: str; Path to the log file stream_logs: bool; Whether to stream logs to stdout source_bashrc: bool; Source bashrc before running the command. + num_threads: Optional[int]; Number of threads to use. """ style = colorama.Style @@ -1047,7 +1090,7 @@ def _sync_node(runner: 'command_runner.CommandRunner') -> None: message = (f' {style.DIM}{action_message} (to {num_nodes} node{plural})' f': {origin_source} -> {target}{style.RESET_ALL}') logger.info(message) - subprocess_utils.run_in_parallel(_sync_node, runners) + subprocess_utils.run_in_parallel(_sync_node, runners, num_threads) def check_local_gpus() -> bool: @@ -1395,14 +1438,14 @@ def check_can_clone_disk_and_override_task( The task to use and the resource handle of the source cluster. Raises: - ValueError: If the source cluster does not exist. + exceptions.ClusterDoesNotExist: If the source cluster does not exist. exceptions.NotSupportedError: If the source cluster is not valid or the task is not compatible to clone disk from the source cluster. """ source_cluster_status, handle = refresh_cluster_status_handle(cluster_name) if source_cluster_status is None: with ux_utils.print_exception_no_traceback(): - raise ValueError( + raise exceptions.ClusterDoesNotExist( f'Cannot find cluster {cluster_name!r} to clone disk from.') if not isinstance(handle, backends.CloudVmRayResourceHandle): @@ -1494,21 +1537,30 @@ def check_can_clone_disk_and_override_task( return task, handle -def _maybe_acquire_lock(lock_path: str, timeout: int, acquire_lock: bool): - if acquire_lock: - # TODO(zhwu): handle timeout - return filelock.FileLock(lock_path, timeout=timeout) - else: - return contextlib.nullcontext() +def _update_cluster_status_no_lock( + cluster_name: str) -> Optional[Dict[str, Any]]: + """Update the cluster status. + The cluster status is updated by checking ray cluster and real status from + cloud. -def _update_cluster_status_or_abort( - cluster_name: str, - acquire_per_cluster_status_lock: bool, - cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS -) -> Optional[Dict[str, Any]]: - """Update cluster status or abort if the cluster's record has been updated.""" - transaction_id = global_user_state.get_transaction_id(cluster_name) + The function will update the cached cluster status in the global state. For + the design of the cluster status and transition, please refer to the + sky/design_docs/cluster_status.md + + Returns: + If the cluster is terminated or does not exist, return None. Otherwise + returns the input record with status and handle potentially updated. + + Raises: + exceptions.ClusterOwnerIdentityMismatchError: if the current user is + not the same as the user who created the cluster. + exceptions.CloudUserIdentityError: if we fail to get the current user + identity. + exceptions.ClusterStatusFetchingError: the cluster status cannot be + fetched from the cloud provider or there are leaked nodes causing + the node number larger than expected. + """ record = global_user_state.get_cluster_from_name(cluster_name) if record is None: return None @@ -1592,16 +1644,11 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: # run_ray_status_to_check_all_nodes_up() is slow due to calling `ray get # head-ip/worker-ips`. record['status'] = status_lib.ClusterStatus.UP - with _maybe_acquire_lock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout, - acquire_lock=acquire_per_cluster_status_lock): - global_user_state.add_or_update_cluster( - cluster_name, - handle, - requested_resources=None, - ready=True, - is_launch=False, - expected_transaction_id=transaction_id) + global_user_state.add_or_update_cluster(cluster_name, + handle, + requested_resources=None, + ready=True, + is_launch=False) return global_user_state.get_cluster_from_name(cluster_name) # All cases below are transitioning the cluster to non-UP states. @@ -1664,61 +1711,48 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: backends.CloudVmRayBackend) and record['autostop'] >= 0: if not backend.is_definitely_autostopping(handle, stream_logs=False): - # Autostop cancellation should not be applied when there is - # another launch in progress, and it should be aborted when the - # cluster's record has been updated, i.e., the transaction_id - # has been changed. - with _maybe_acquire_lock( - CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout, - acquire_lock=acquire_per_cluster_status_lock): - if (transaction_id == global_user_state.get_transaction_id( - cluster_name)): - # Friendly hint. - autostop = record['autostop'] - maybe_down_str = ' --down' if record['to_down'] else '' - noun = 'autodown' if record['to_down'] else 'autostop' - - # Reset the autostopping as the cluster is abnormal, and may - # not correctly autostop. Resetting the autostop will let - # the user know that the autostop may not happen to avoid - # leakages from the assumption that the cluster will autostop. - success = True - reset_local_autostop = True - try: - backend.set_autostop(handle, -1, stream_logs=False) - except exceptions.CommandError as e: - success = False - if e.returncode == 255: - logger.debug(f'The cluster is likely {noun}ed.') - reset_local_autostop = False - except (Exception, SystemExit) as e: # pylint: disable=broad-except - success = False - logger.debug(f'Failed to reset autostop. Due to ' - f'{common_utils.format_exception(e)}') - if reset_local_autostop: - global_user_state.set_cluster_autostop_value( - handle.cluster_name, - -1, - to_down=False, - expected_transaction_id=transaction_id) - - if success: - operation_str = (f'Canceled {noun} on the cluster ' - f'{cluster_name!r}') - else: - operation_str = ( - f'Attempted to cancel {noun} on the ' - f'cluster {cluster_name!r} with best effort') - yellow = colorama.Fore.YELLOW - bright = colorama.Style.BRIGHT - reset = colorama.Style.RESET_ALL - ux_utils.console_newline() - logger.warning( - f'{yellow}{operation_str}, since it is found to be in an ' - f'abnormal state. To fix, try running: {reset}{bright}sky ' - f'start -f -i {autostop}{maybe_down_str} {cluster_name}' - f'{reset}') + # Friendly hint. + autostop = record['autostop'] + maybe_down_str = ' --down' if record['to_down'] else '' + noun = 'autodown' if record['to_down'] else 'autostop' + + # Reset the autostopping as the cluster is abnormal, and may + # not correctly autostop. Resetting the autostop will let + # the user know that the autostop may not happen to avoid + # leakages from the assumption that the cluster will autostop. + success = True + reset_local_autostop = True + try: + backend.set_autostop(handle, -1, stream_logs=False) + except exceptions.CommandError as e: + success = False + if e.returncode == 255: + logger.debug(f'The cluster is likely {noun}ed.') + reset_local_autostop = False + except (Exception, SystemExit) as e: # pylint: disable=broad-except + success = False + logger.debug(f'Failed to reset autostop. Due to ' + f'{common_utils.format_exception(e)}') + if reset_local_autostop: + global_user_state.set_cluster_autostop_value( + handle.cluster_name, -1, to_down=False) + + if success: + operation_str = (f'Canceled {noun} on the cluster ' + f'{cluster_name!r}') + else: + operation_str = ( + f'Attempted to cancel {noun} on the ' + f'cluster {cluster_name!r} with best effort') + yellow = colorama.Fore.YELLOW + bright = colorama.Style.BRIGHT + reset = colorama.Style.RESET_ALL + ux_utils.console_newline() + logger.warning( + f'{yellow}{operation_str}, since it is found to be in an ' + f'abnormal state. To fix, try running: {reset}{bright}sky ' + f'start -f -i {autostop}{maybe_down_str} {cluster_name}' + f'{reset}') else: ux_utils.console_newline() operation_str = 'autodowning' if record[ @@ -1732,73 +1766,35 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: # represent that the cluster is partially preempted. # TODO(zhwu): the definition of INIT should be audited/changed. # Adding a new status UNHEALTHY for abnormal status can be a choice. - with _maybe_acquire_lock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout, - acquire_lock=acquire_per_cluster_status_lock): - global_user_state.add_or_update_cluster( - cluster_name, - handle, - requested_resources=None, - ready=False, - is_launch=False, - expected_transaction_id=transaction_id) + global_user_state.add_or_update_cluster(cluster_name, + handle, + requested_resources=None, + ready=False, + is_launch=False) return global_user_state.get_cluster_from_name(cluster_name) # Now is_abnormal is False: either node_statuses is empty or all nodes are # STOPPED. - with _maybe_acquire_lock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout, - acquire_lock=acquire_per_cluster_status_lock): - if transaction_id == global_user_state.get_transaction_id(cluster_name): - backend = backends.CloudVmRayBackend() - backend.post_teardown_cleanup(handle, - terminate=to_terminate, - purge=False) + backend = backends.CloudVmRayBackend() + backend.post_teardown_cleanup(handle, terminate=to_terminate, purge=False) return global_user_state.get_cluster_from_name(cluster_name) -def _update_cluster_status( - cluster_name: str, - acquire_per_cluster_status_lock: bool, - cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS -) -> Optional[Dict[str, Any]]: - """Update the cluster status. - - The cluster status is updated by checking ray cluster and real status from - cloud. - - The function will update the cached cluster status in the global state. For - the design of the cluster status and transition, please refer to the - sky/design_docs/cluster_status.md +def _must_refresh_cluster_status( + record: Dict[str, Any], + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] +) -> bool: + force_refresh_for_cluster = (force_refresh_statuses is not None and + record['status'] in force_refresh_statuses) - Args: - cluster_name: The name of the cluster. - acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. - cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. - - Returns: - If the cluster is terminated or does not exist, return None. Otherwise - returns the input record with status and handle potentially updated. + use_spot = record['handle'].launched_resources.use_spot + has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and + record['autostop'] >= 0) + recently_refreshed = (record['status_updated_at'] is not None and + time.time() - record['status_updated_at'] < + _CLUSTER_STATUS_CACHE_DURATION_SECONDS) + is_stale = (use_spot or has_autostop) and not recently_refreshed - Raises: - exceptions.ClusterOwnerIdentityMismatchError: if the current user is - not the same as the user who created the cluster. - exceptions.CloudUserIdentityError: if we fail to get the current user - identity. - exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider or there are leaked nodes causing - the node number larger than expected. - """ - try: - return _update_cluster_status_or_abort(cluster_name, - acquire_per_cluster_status_lock, - cluster_status_lock_timeout) - except filelock.Timeout: - logger.debug('Refreshing status: Failed get the lock for cluster ' - f'{cluster_name!r}. Using the cached status.') - record = global_user_state.get_cluster_from_name(cluster_name) - return record + return force_refresh_for_cluster or is_stale def refresh_cluster_record( @@ -1810,22 +1806,28 @@ def refresh_cluster_record( ) -> Optional[Dict[str, Any]]: """Refresh the cluster, and return the possibly updated record. - This function will also check the owner identity of the cluster, and raise - exceptions if the current user is not the same as the user who created the - cluster. + The function will update the cached cluster status in the global state. For + the design of the cluster status and transition, please refer to the + sky/design_docs/cluster_status.md Args: cluster_name: The name of the cluster. - force_refresh_statuses: if specified, refresh the cluster if it has one of - the specified statuses. Additionally, clusters satisfying the - following conditions will always be refreshed no matter the - argument is specified or not: - 1. is a spot cluster, or - 2. is a non-spot cluster, is not STOPPED, and autostop is set. + force_refresh_statuses: if specified, refresh the cluster if it has one + of the specified statuses. Additionally, clusters satisfying the + following conditions will be refreshed no matter the argument is + specified or not: + - the most latest available status update is more than + _CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of: + 1. the cluster is a spot cluster, or + 2. cluster autostop is set and the cluster is not STOPPED. acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. + before updating the status. Even if this is True, the lock may not be + acquired if the status does not need to be refreshed. cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. If timeout, the function will use the cached status. + lock. If timeout, the function will use the cached status. If the + value is <0, do not timeout (wait for the lock indefinitely). By + default, this is set to CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS. Warning: + if correctness is required, you must set this to -1. Returns: If the cluster is terminated or does not exist, return None. @@ -1846,19 +1848,58 @@ def refresh_cluster_record( return None check_owner_identity(cluster_name) - handle = record['handle'] - if isinstance(handle, backends.CloudVmRayResourceHandle): - use_spot = handle.launched_resources.use_spot - has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and - record['autostop'] >= 0) - force_refresh_for_cluster = (force_refresh_statuses is not None and - record['status'] in force_refresh_statuses) - if force_refresh_for_cluster or has_autostop or use_spot: - record = _update_cluster_status( - cluster_name, - acquire_per_cluster_status_lock=acquire_per_cluster_status_lock, - cluster_status_lock_timeout=cluster_status_lock_timeout) - return record + if not isinstance(record['handle'], backends.CloudVmRayResourceHandle): + return record + + # The loop logic allows us to notice if the status was updated in the + # global_user_state by another process and stop trying to get the lock. + # The core loop logic is adapted from FileLock's implementation. + lock = filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) + start_time = time.perf_counter() + + # Loop until we have an up-to-date status or until we acquire the lock. + while True: + # Check to see if we can return the cached status. + if not _must_refresh_cluster_status(record, force_refresh_statuses): + return record + + if not acquire_per_cluster_status_lock: + return _update_cluster_status_no_lock(cluster_name) + + # Try to acquire the lock so we can fetch the status. + try: + with lock.acquire(blocking=False): + # Lock acquired. + + # Check the cluster status again, since it could have been + # updated between our last check and acquiring the lock. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None or not _must_refresh_cluster_status( + record, force_refresh_statuses): + return record + + # Update and return the cluster status. + return _update_cluster_status_no_lock(cluster_name) + except filelock.Timeout: + # lock.acquire() will throw a Timeout exception if the lock is not + # available and we have blocking=False. + pass + + # Logic adapted from FileLock.acquire(). + # If cluster_status_lock_time is <0, we will never hit this. No timeout. + # Otherwise, if we have timed out, return the cached status. This has + # the potential to cause correctness issues, but if so it is the + # caller's responsibility to set the timeout to -1. + if 0 <= cluster_status_lock_timeout < time.perf_counter() - start_time: + logger.debug('Refreshing status: Failed get the lock for cluster ' + f'{cluster_name!r}. Using the cached status.') + return record + time.sleep(0.05) + + # Refresh for next loop iteration. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None: + return None @timeline.event @@ -1921,7 +1962,7 @@ def check_cluster_available( """Check if the cluster is available. Raises: - ValueError: if the cluster does not exist. + exceptions.ClusterDoesNotExist: if the cluster does not exist. exceptions.ClusterNotUpError: if the cluster is not UP. exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -1986,7 +2027,8 @@ def check_cluster_available( error_msg += message with ux_utils.print_exception_no_traceback(): - raise ValueError(f'{colorama.Fore.YELLOW}{error_msg}{reset}') + raise exceptions.ClusterDoesNotExist( + f'{colorama.Fore.YELLOW}{error_msg}{reset}') assert cluster_status is not None, 'handle is not None but status is None' backend = get_backend_from_handle(handle) if check_cloud_vm_ray_backend and not isinstance( diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 8f2aebe1c8d..6bc0a0df620 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -272,6 +272,13 @@ def add_prologue(self, job_id: int) -> None: import time from typing import Dict, List, Optional, Tuple, Union + # Set the environment variables to avoid deduplicating logs and + # scheduler events. This should be set in driver code, since we are + # not using `ray job submit` anymore, and the environment variables + # from the ray cluster is not inherited. + os.environ['RAY_DEDUP_LOGS'] = '0' + os.environ['RAY_SCHEDULER_EVENTS'] = '0' + import ray import ray.util as ray_util @@ -297,6 +304,8 @@ def add_prologue(self, job_id: int) -> None: ) def get_or_fail(futures, pg) -> List[int]: \"\"\"Wait for tasks, if any fails, cancel all unready.\"\"\" + if not futures: + return [] returncodes = [1] * len(futures) # Wait for 1 task to be ready. ready = [] @@ -1531,7 +1540,7 @@ def _retry_zones( to_provision, resources_utils.ClusterName( cluster_name, handle.cluster_name_on_cloud), - region, zones)) + region, zones, num_nodes)) config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle @@ -3092,9 +3101,12 @@ def _sync_workdir_node(runner: command_runner.CommandRunner) -> None: f'{workdir} -> {SKY_REMOTE_WORKDIR}{style.RESET_ALL}') os.makedirs(os.path.expanduser(self.log_dir), exist_ok=True) os.system(f'touch {log_path}') + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) with rich_utils.safe_status( ux_utils.spinner_message('Syncing workdir', log_path)): - subprocess_utils.run_in_parallel(_sync_workdir_node, runners) + subprocess_utils.run_in_parallel(_sync_workdir_node, runners, + num_threads) logger.info(ux_utils.finishing_message('Workdir synced.', log_path)) def _sync_file_mounts( @@ -3417,15 +3429,33 @@ def _execute( Returns: Job id if the task is submitted to the cluster, None otherwise. """ - if task.run is None: + if task.run is None and self._setup_cmd is None: + # This message is fine without mentioning setup, as there are three + # cases when run section is empty: + # 1. setup specified, no --detach-setup: setup is executed and this + # message is fine for saying no run command specified. + # 2. setup specified, with --detach-setup: setup is executed in + # detached mode and this message will not be shown. + # 3. no setup specified: this message is fine as a user is likely + # creating a cluster only, and ok with the empty run command. logger.info('Run commands not specified or empty.') return None - # Check the task resources vs the cluster resources. Since `sky exec` - # will not run the provision and _check_existing_cluster - # We need to check ports here since sky.exec shouldn't change resources - valid_resource = self.check_resources_fit_cluster(handle, - task, - check_ports=True) + if task.run is None: + # If the task has no run command, we still need to execute the + # generated ray driver program to run the setup command in detached + # mode. + # In this case, we reset the resources for the task, so that the + # detached setup does not need to wait for the task resources to be + # ready (which is not used for setup anyway). + valid_resource = sky.Resources() + else: + # Check the task resources vs the cluster resources. Since + # `sky exec` will not run the provision and _check_existing_cluster + # We need to check ports here since sky.exec shouldn't change + # resources. + valid_resource = self.check_resources_fit_cluster(handle, + task, + check_ports=True) task_copy = copy.copy(task) # Handle multiple resources exec case. task_copy.set_resources(valid_resource) @@ -3534,11 +3564,13 @@ def _teardown(self, if terminate: common_utils.remove_file_if_exists(lock_path) break - except filelock.Timeout: + except filelock.Timeout as e: logger.debug(f'Failed to acquire lock for {cluster_name}, ' f'retrying...') if n_attempts <= 0: - raise + raise RuntimeError( + f'Cluster {cluster_name!r} is locked by {lock_path}. ' + 'Check to see if it is still being launched') from e # --- CloudVMRayBackend Specific APIs --- @@ -4379,6 +4411,8 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, start = time.time() runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'file_mounts.log') + num_threads = subprocess_utils.get_max_workers_for_file_mounts( + file_mounts, str(handle.launched_resources.cloud)) # Check the files and warn for dst, src in file_mounts.items(): @@ -4441,6 +4475,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, action_message='Syncing', log_path=log_path, stream_logs=False, + num_threads=num_threads, ) continue @@ -4477,6 +4512,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, # Need to source bashrc, as the cloud specific CLI or SDK may # require PATH in bashrc. source_bashrc=True, + num_threads=num_threads, ) # (2) Run the commands to create symlinks on all the nodes. symlink_command = ' && '.join(symlink_commands) @@ -4495,7 +4531,8 @@ def _symlink_node(runner: command_runner.CommandRunner): 'Failed to create symlinks. The target destination ' f'may already exist. Log: {log_path}') - subprocess_utils.run_in_parallel(_symlink_node, runners) + subprocess_utils.run_in_parallel(_symlink_node, runners, + num_threads) end = time.time() logger.debug(f'File mount sync took {end - start} seconds.') logger.info(ux_utils.finishing_message('Files synced.', log_path)) @@ -4524,6 +4561,8 @@ def _execute_storage_mounts( return start = time.time() runners = handle.get_command_runners() + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) log_path = os.path.join(self.log_dir, 'storage_mounts.log') plural = 's' if len(storage_mounts) > 1 else '' @@ -4564,6 +4603,7 @@ def _execute_storage_mounts( # Need to source bashrc, as the cloud specific CLI or SDK # may require PATH in bashrc. source_bashrc=True, + num_threads=num_threads, ) except exceptions.CommandError as e: if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE: diff --git a/sky/cli.py b/sky/cli.py index 8e5233839a4..241138daa2d 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -572,7 +572,7 @@ def _parse_override_params( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None) -> Dict[str, Any]: + ports: Optional[Tuple[str, ...]] = None) -> Dict[str, Any]: """Parses the override parameters into a dictionary.""" override_params: Dict[str, Any] = {} if cloud is not None: @@ -625,7 +625,14 @@ def _parse_override_params( else: override_params['disk_tier'] = disk_tier if ports: - override_params['ports'] = ports + if any(p.lower() == 'none' for p in ports): + if len(ports) > 1: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both "none" and other ' + 'ports.') + override_params['ports'] = None + else: + override_params['ports'] = ports return override_params @@ -724,7 +731,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None, + ports: Optional[Tuple[str, ...]] = None, env: Optional[List[Tuple[str, str]]] = None, field_to_ignore: Optional[List[str]] = None, # job launch specific @@ -1082,7 +1089,7 @@ def launch( env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], - ports: Tuple[str], + ports: Tuple[str, ...], idle_minutes_to_autostop: Optional[int], down: bool, # pylint: disable=redefined-outer-name retry_until_up: bool, @@ -4034,17 +4041,26 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): default=False, help=('Show the controller logs of this job; useful for debugging ' 'launching/recoveries, etc.')) +@click.option( + '--refresh', + '-r', + default=False, + is_flag=True, + required=False, + help='Query the latest job logs, restarting the jobs controller if stopped.' +) @click.argument('job_id', required=False, type=int) @usage_lib.entrypoint def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool): + controller: bool, refresh: bool): """Tail the log of a managed job.""" log_request_id = None try: log_request_id = managed_jobs.tail_logs(name=name, job_id=job_id, follow=follow, - controller=controller) + controller=controller, + refresh=refresh) sdk.stream_and_get(log_request_id) except exceptions.ClusterNotUpError: with ux_utils.print_exception_no_traceback(): diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index baaf991f5e0..6943227f009 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -402,6 +402,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: del dryrun # unused assert zones is not None, (region, zones) @@ -664,6 +665,7 @@ def _is_access_key_of_type(type_str: str) -> bool: return AWSIdentityType.SHARED_CREDENTIALS_FILE @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def get_user_identities(cls) -> Optional[List[List[str]]]: """Returns a [UserId, Account] list that uniquely identifies the user. diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 5b3275f547b..f6424138386 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -303,6 +303,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: assert zones is None, ('Azure does not support zones', zones) diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 49759acd8f4..697451371f5 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -283,6 +283,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'Region', zones: Optional[List['Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index 4b9363cb89e..801768f7db0 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -197,6 +197,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: del zones, cluster_name # unused diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index a6fea30009a..ec52cf85d31 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -177,6 +177,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 955ca5e9db8..f988ec1d29e 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -418,6 +418,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: assert zones is not None, (region, zones) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index ddbdcd9c01a..c17923bf94b 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -171,6 +171,7 @@ def make_deploy_resources_variables( cluster_name: 'resources_utils.ClusterName', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index bcfcbec95f9..9a25a7ea0b6 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -10,8 +10,10 @@ from sky import skypilot_config from sky.adaptors import kubernetes from sky.clouds import service_catalog +from sky.provision import instance_setup from sky.provision.kubernetes import network_utils from sky.provision.kubernetes import utils as kubernetes_utils +from sky.skylet import constants from sky.utils import common_utils from sky.utils import registry from sky.utils import resources_utils @@ -40,6 +42,8 @@ class Kubernetes(clouds.Cloud): SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys' SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod' + LEGACY_SINGLETON_REGION = 'kubernetes' + # Limit the length of the cluster name to avoid exceeding the limit of 63 # characters for Kubernetes resources. We limit to 42 characters (63-21) to # allow additional characters for creating ingress services to expose ports. @@ -53,7 +57,6 @@ class Kubernetes(clouds.Cloud): _DEFAULT_MEMORY_CPU_RATIO = 1 _DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks _REPR = 'Kubernetes' - _LEGACY_SINGLETON_REGION = 'kubernetes' _CLOUD_UNSUPPORTED_FEATURES = { # TODO(romilb): Stopping might be possible to implement with # container checkpointing introduced in Kubernetes v1.25. See: @@ -129,32 +132,30 @@ def _log_skipped_contexts_once(cls, skipped_contexts: Tuple[str, 'Ignoring these contexts.') @classmethod - def _existing_allowed_contexts(cls) -> List[Optional[str]]: + def _existing_allowed_contexts(cls) -> List[str]: """Get existing allowed contexts. If None is returned in the list, it means that we are running in a pod with in-cluster auth. In this case, we specify None context, which will use the service account mounted in the pod. """ - all_contexts = kubernetes_utils.get_all_kube_config_context_names() + all_contexts = kubernetes_utils.get_all_kube_context_names() if len(all_contexts) == 0: return [] - if all_contexts == [None]: - # If only one context is found and it is None, we are running in a - # pod with in-cluster auth. In this case, we allow it to be used - # without checking against allowed_contexts. - # TODO(romilb): We may want check in-cluster auth against - # allowed_contexts in the future by adding a special context name - # for in-cluster auth. - return [None] + all_contexts = set(all_contexts) allowed_contexts = skypilot_config.get_nested( ('kubernetes', 'allowed_contexts'), None) if allowed_contexts is None: + # Try kubeconfig if present current_context = ( kubernetes_utils.get_current_kube_config_context_name()) + if (current_context is None and + kubernetes_utils.is_incluster_config_available()): + # If no kubeconfig contexts found, use in-cluster if available + current_context = kubernetes.in_cluster_context_name() allowed_contexts = [] if current_context is not None: allowed_contexts = [current_context] @@ -179,13 +180,7 @@ def regions_with_offering(cls, instance_type: Optional[str], regions = [] for context in existing_contexts: - if context is None: - # If running in-cluster, we allow the region to be set to the - # singleton region since there is no context name available. - regions.append(clouds.Region( - kubernetes_utils.IN_CLUSTER_REGION)) - else: - regions.append(clouds.Region(context)) + regions.append(clouds.Region(context)) if region is not None: regions = [r for r in regions if r.name == region] @@ -312,12 +307,34 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> int: # we don't have a notion of disk size in Kubernetes. return 0 + @staticmethod + def _calculate_provision_timeout(num_nodes: int) -> int: + """Calculate provision timeout based on number of nodes. + + The timeout scales linearly with the number of nodes to account for + scheduling overhead, but is capped to avoid excessive waiting. + + Args: + num_nodes: Number of nodes being provisioned + + Returns: + Timeout in seconds + """ + base_timeout = 10 # Base timeout for single node + per_node_timeout = 0.2 # Additional seconds per node + max_timeout = 60 # Cap at 1 minute + + return int( + min(base_timeout + (per_node_timeout * (num_nodes - 1)), + max_timeout)) + def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', cluster_name: 'resources_utils.ClusterName', region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, zones, dryrun # Unused. if region is None: @@ -386,12 +403,25 @@ def make_deploy_resources_variables( remote_identity = skypilot_config.get_nested( ('kubernetes', 'remote_identity'), schemas.get_default_remote_identity('kubernetes')) - if (remote_identity == + + if isinstance(remote_identity, dict): + # If remote_identity is a dict, use the service account for the + # current context + k8s_service_account_name = remote_identity.get(context, None) + if k8s_service_account_name is None: + err_msg = (f'Context {context!r} not found in ' + 'remote identities from config.yaml') + raise ValueError(err_msg) + else: + # If remote_identity is not a dict, use + k8s_service_account_name = remote_identity + + if (k8s_service_account_name == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value): # SA name doesn't matter since automounting credentials is disabled k8s_service_account_name = 'default' k8s_automount_sa_token = 'false' - elif (remote_identity == + elif (k8s_service_account_name == schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value): # Use the default service account k8s_service_account_name = ( @@ -399,7 +429,6 @@ def make_deploy_resources_variables( k8s_automount_sa_token = 'true' else: # User specified a custom service account - k8s_service_account_name = remote_identity k8s_automount_sa_token = 'true' fuse_device_required = bool(resources.requires_fuse) @@ -414,12 +443,30 @@ def make_deploy_resources_variables( # Larger timeout may be required for autoscaling clusters, since # autoscaler may take some time to provision new nodes. # Note that this timeout includes time taken by the Kubernetes scheduler - # itself, which can be upto 2-3 seconds. - # For non-autoscaling clusters, we conservatively set this to 10s. + # itself, which can be upto 2-3 seconds, and up to 10-15 seconds when + # scheduling 100s of pods. + # We use a linear scaling formula to determine the timeout based on the + # number of nodes. + + timeout = self._calculate_provision_timeout(num_nodes) timeout = skypilot_config.get_nested( ('kubernetes', 'provision_timeout'), - 10, + timeout, override_configs=resources.cluster_config_overrides) + + # Set environment variables for the pod. Note that SkyPilot env vars + # are set separately when the task is run. These env vars are + # independent of the SkyPilot task to be run. + k8s_env_vars = {kubernetes.IN_CLUSTER_CONTEXT_NAME_ENV_VAR: context} + + # We specify object-store-memory to be 500MB to avoid taking up too + # much memory on the head node. 'num-cpus' should be set to limit + # the CPU usage on the head pod, otherwise the ray cluster will use the + # CPU resources on the node instead within the pod. + custom_ray_options = { + 'object-store-memory': 500000000, + 'num-cpus': str(int(cpus)), + } deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, @@ -445,7 +492,14 @@ def make_deploy_resources_variables( 'k8s_topology_label_key': k8s_topology_label_key, 'k8s_topology_label_value': k8s_topology_label_value, 'k8s_resource_key': k8s_resource_key, + 'k8s_env_vars': k8s_env_vars, 'image_id': image_id, + 'ray_installation_commands': constants.RAY_INSTALLATION_COMMANDS, + 'ray_head_start_command': instance_setup.ray_head_start_command( + custom_resources, custom_ray_options), + 'skypilot_ray_port': constants.SKY_REMOTE_RAY_PORT, + 'ray_worker_start_command': instance_setup.ray_worker_start_command( + custom_resources, custom_ray_options, no_restart=False), } # Add kubecontext if it is set. It may be None if SkyPilot is running @@ -537,7 +591,11 @@ def _make(instance_list): @classmethod def check_credentials(cls) -> Tuple[bool, Optional[str]]: # Test using python API - existing_allowed_contexts = cls._existing_allowed_contexts() + try: + existing_allowed_contexts = cls._existing_allowed_contexts() + except ImportError as e: + return (False, + f'{common_utils.format_exception(e, use_bracket=True)}') if not existing_allowed_contexts: if skypilot_config.loaded_config_path() is None: check_skypilot_config_msg = '' @@ -574,22 +632,19 @@ def instance_type_exists(self, instance_type: str) -> bool: instance_type) def validate_region_zone(self, region: Optional[str], zone: Optional[str]): - if region == self._LEGACY_SINGLETON_REGION: + if region == self.LEGACY_SINGLETON_REGION: # For backward compatibility, we allow the region to be set to the # legacy singleton region. # TODO: Remove this after 0.9.0. return region, zone - if region == kubernetes_utils.IN_CLUSTER_REGION: + if region == kubernetes.in_cluster_context_name(): # If running incluster, we set region to IN_CLUSTER_REGION # since there is no context name available. return region, zone - all_contexts = kubernetes_utils.get_all_kube_config_context_names() - if all_contexts == [None]: - # If [None] context is returned, use the singleton region since we - # are running in a pod with in-cluster auth. - all_contexts = [kubernetes_utils.IN_CLUSTER_REGION] + all_contexts = kubernetes_utils.get_all_kube_context_names() + if region not in all_contexts: raise ValueError( f'Context {region} not found in kubeconfig. Kubernetes only ' diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 2f790843e70..a24774e2146 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -158,6 +158,7 @@ def make_deploy_resources_variables( cluster_name: 'resources_utils.ClusterName', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert zones is None, 'Lambda does not support zones.' diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index a9512029e00..4ac423a455e 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -209,6 +209,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert region is not None, resources diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index 3b26f81d22d..85f1ed45bdb 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -176,6 +176,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 66a48350064..7dff076d94e 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -161,6 +161,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name # unused diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 7ae765f8b2c..5ccd20960ff 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -182,6 +182,7 @@ def make_deploy_resources_variables( cluster_name: 'resources_utils.ClusterName', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert zones is None, 'SCP does not support zones.' diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 4deab8ac204..d28b530ff06 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -324,9 +324,8 @@ def get_common_gpus() -> List[str]: 'A100', 'A100-80GB', 'H100', - 'K80', 'L4', - 'M60', + 'L40S', 'P100', 'T4', 'V100', @@ -337,13 +336,13 @@ def get_common_gpus() -> List[str]: def get_tpus() -> List[str]: """Returns a list of TPU names.""" # TODO(wei-lin): refactor below hard-coded list. - # There are many TPU configurations available, we show the three smallest - # and the largest configuration for the latest gen TPUs. + # There are many TPU configurations available, we show the some smallest + # ones for each generation, and people should find larger ones with + # sky show-gpus tpu. return [ - 'tpu-v2-512', 'tpu-v3-2048', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', - 'tpu-v4-3968', 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', - 'tpu-v5litepod-256', 'tpu-v5p-8', 'tpu-v5p-32', 'tpu-v5p-128', - 'tpu-v5p-12288' + 'tpu-v2-8', 'tpu-v3-8', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', + 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', 'tpu-v5p-8', + 'tpu-v5p-16', 'tpu-v5p-32', 'tpu-v6e-1', 'tpu-v6e-4', 'tpu-v6e-8' ] diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 918a4070414..bbd48863755 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -20,6 +20,7 @@ from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -100,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index f646cac339a..4aef41f9c90 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -64,7 +64,7 @@ 'standardNVSv2Family': 'M60', 'standardNVSv3Family': 'M60', 'standardNVPromoFamily': 'M60', - 'standardNVSv4Family': 'Radeon MI25', + 'standardNVSv4Family': 'MI25', 'standardNDSFamily': 'P40', 'StandardNVADSA10v5Family': 'A10', 'StandardNCadsH100v5Family': 'H100', diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py b/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py index cf943541e08..7a8b7e42e79 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py @@ -15,6 +15,26 @@ DEFAULT_FLUIDSTACK_API_KEY_PATH = os.path.expanduser('~/.fluidstack/api_key') plan_vcpus_memory = [{ + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 1, + 'min_cpu_count': 52, + 'min_memory': 450 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 2, + 'min_cpu_count': 52, + 'min_memory': 450 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 4, + 'min_cpu_count': 104, + 'min_memory': 900 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 8, + 'min_cpu_count': 192, + 'min_memory': 1800 +}, { 'gpu_type': 'RTX_A6000_48GB', 'gpu_count': 2, 'min_cpu_count': 12, @@ -150,7 +170,8 @@ 'H100_PCIE_80GB': 'H100', 'H100_NVLINK_80GB': 'H100', 'A100_NVLINK_80GB': 'A100-80GB', - 'A100_SXM4_80GB': 'A100-80GB', + 'A100_SXM4_80GB': 'A100-80GB-SXM', + 'H100_SXM5_80GB': 'H100-SXM', 'A100_PCIE_80GB': 'A100-80GB', 'A100_SXM4_40GB': 'A100', 'A100_PCIE_40GB': 'A100', diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py index e4bb6e8547a..008bfe6abeb 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py @@ -46,6 +46,7 @@ 'RTX6000': 24576, 'V100': 16384, 'H100': 81920, + 'GH200': 98304, 'GENERAL': None } diff --git a/sky/clouds/service_catalog/kubernetes_catalog.py b/sky/clouds/service_catalog/kubernetes_catalog.py index 6d11d1715e2..2c7eafc20e5 100644 --- a/sky/clouds/service_catalog/kubernetes_catalog.py +++ b/sky/clouds/service_catalog/kubernetes_catalog.py @@ -65,9 +65,14 @@ def list_accelerators( # TODO(romilb): We should consider putting a lru_cache() with TTL to # avoid multiple calls to kubernetes API in a short period of time (e.g., # from the optimizer). - return list_accelerators_realtime(gpus_only, name_filter, region_filter, - quantity_filter, case_sensitive, - all_regions, require_price)[0] + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=False)[0] def list_accelerators_realtime( @@ -78,10 +83,36 @@ def list_accelerators_realtime( case_sensitive: bool = True, all_regions: bool = False, require_price: bool = True +) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, + int]]: + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=True) + + +def _list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, + require_price: bool = True, + realtime: bool = False ) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, int]]: """List accelerators in the Kubernetes cluster. + If realtime is True, the function will query the cluster to fetch real-time + GPU usage, which is returned in total_accelerators_available. Note that + this may require an expensive list_pod_for_all_namespaces call, which + requires cluster-wide pod read permissions. + If the user does not have sufficient permissions to list pods in all namespaces, the function will return free GPUs as -1. """ @@ -115,18 +146,20 @@ def list_accelerators_realtime( accelerators_qtys: Set[Tuple[str, int]] = set() keys = lf.get_label_keys() nodes = kubernetes_utils.get_kubernetes_nodes(context) - # Get the pods to get the real-time GPU usage - try: - pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) - except kubernetes.api_exception() as e: - if e.status == 403: - logger.warning('Failed to get pods in the Kubernetes cluster ' - '(forbidden). Please check if your account has ' - 'necessary permissions to list pods. Realtime GPU ' - 'availability information may be incorrect.') - pods = None - else: - raise + pods = None + if realtime: + # Get the pods to get the real-time GPU usage + try: + pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) + except kubernetes.api_exception() as e: + if e.status == 403: + logger.warning( + 'Failed to get pods in the Kubernetes cluster ' + '(forbidden). Please check if your account has ' + 'necessary permissions to list pods. Realtime GPU ' + 'availability information may be incorrect.') + else: + raise # Total number of GPUs in the cluster total_accelerators_capacity: Dict[str, int] = {} # Total number of GPUs currently available in the cluster @@ -206,13 +239,12 @@ def list_accelerators_realtime( accelerators_available = accelerator_count - allocated_qty - if accelerator_name not in total_accelerators_available: - total_accelerators_available[accelerator_name] = 0 if accelerators_available >= min_quantity_filter: quantized_availability = min_quantity_filter * ( accelerators_available // min_quantity_filter) - total_accelerators_available[ - accelerator_name] += quantized_availability + total_accelerators_available[accelerator_name] = ( + total_accelerators_available.get(accelerator_name, 0) + + quantized_availability) result = [] diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index f8259aeafbc..6cb6c0d93a8 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -174,6 +174,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: # TODO get image id here. diff --git a/sky/core.py b/sky/core.py index d457e76a80f..eb1af24ef90 100644 --- a/sky/core.py +++ b/sky/core.py @@ -294,7 +294,8 @@ def _start( cluster_status, handle = backend_utils.refresh_cluster_status_handle( cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') if not force and cluster_status == status_lib.ClusterStatus.UP: sky_logging.print(f'Cluster {cluster_name!r} is already up.') return handle @@ -385,12 +386,13 @@ def start( Useful for upgrading SkyPilot runtime. Raises: - ValueError: argument values are invalid: (1) the specified cluster does - not exist; (2) if ``down`` is set to True but - ``idle_minutes_to_autostop`` is None; (3) if the specified cluster is - the managed jobs controller, and either ``idle_minutes_to_autostop`` - is not None or ``down`` is True (omit them to use the default - autostop settings). + ValueError: argument values are invalid: (1) if ``down`` is set to True + but ``idle_minutes_to_autostop`` is None; (2) if the specified + cluster is the managed jobs controller, and either + ``idle_minutes_to_autostop`` is not None or ``down`` is True (omit + them to use the default autostop settings). + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. sky.exceptions.NotSupportedError: if the cluster to restart was launched using a non-default backend that does not support this operation. @@ -438,7 +440,8 @@ def stop(cluster_name: str, purge: bool = False) -> None: related resources. Raises: - ValueError: the specified cluster does not exist. + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. RuntimeError: failed to stop the cluster. sky.exceptions.NotSupportedError: if the specified cluster is a spot cluster, or a TPU VM Pod cluster, or the managed jobs controller. @@ -449,7 +452,8 @@ def stop(cluster_name: str, purge: bool = False) -> None: f'is not supported.') handle = global_user_state.get_handle_from_cluster_name(cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') backend = backend_utils.get_backend_from_handle(handle) @@ -493,14 +497,17 @@ def down(cluster_name: str, purge: bool = False) -> None: resources. Raises: - ValueError: the specified cluster does not exist. + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. RuntimeError: failed to tear down the cluster. sky.exceptions.NotSupportedError: the specified cluster is the managed jobs controller. """ handle = global_user_state.get_handle_from_cluster_name(cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') + usage_lib.record_cluster_name_for_current_operation(cluster_name) backend = backend_utils.get_backend_from_handle(handle) backend.teardown(handle, terminate=True, purge=purge) @@ -546,7 +553,7 @@ def autostop( rather than autostop (restartable). Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend or the cluster is TPU VM Pod. @@ -641,7 +648,7 @@ def queue(cluster_name: str, } ] raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -702,7 +709,8 @@ def cancel( worker node is preempted in the spot cluster. Raises: - ValueError: if arguments are invalid, or the cluster does not exist. + ValueError: if arguments are invalid. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the specified cluster is a controller that does not support this operation. @@ -782,8 +790,8 @@ def tail_logs(cluster_name: str, Please refer to the sky.cli.tail_logs for the document. Raises: - ValueError: arguments are invalid or the cluster is not supported or - the cluster does not exist. + ValueError: if arguments are invalid or the cluster is not supported. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -817,7 +825,7 @@ def download_logs( Returns: Dict[str, str]: a mapping of job_id to local log path. Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -862,7 +870,7 @@ def job_status(cluster_name: str, If job_ids is None and there is no job on the cluster, it will return {None: None}. Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. diff --git a/sky/exceptions.py b/sky/exceptions.py index f2fbe48ee51..213e80f4a50 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -179,6 +179,13 @@ class ClusterSetUpError(Exception): pass +class ClusterDoesNotExist(ValueError): + """Raise when trying to operate on a cluster that does not exist.""" + # This extends ValueError for compatibility reasons - we used to throw + # ValueError instead of this. + pass + + class NotSupportedError(Exception): """Raised when a feature is not supported.""" pass diff --git a/sky/execution.py b/sky/execution.py index 05dce961d58..49cc0c9fd2e 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -276,6 +276,12 @@ def _execute( # no-credential machine should not enter optimize(), which # would directly error out ('No cloud is enabled...'). Fix # by moving `sky check` checks out of optimize()? + controller = controller_utils.Controllers.from_name( + cluster_name) + if controller is not None: + logger.info( + f'Choosing resources for {controller.value.name}...' + ) dag = optimizer.Optimizer.optimize(dag, minimize=optimize_target, quiet=_quiet_optimizer) @@ -309,7 +315,8 @@ def _execute( do_workdir = (Stage.SYNC_WORKDIR in stages and not dryrun and task.workdir is not None) do_file_mounts = (Stage.SYNC_FILE_MOUNTS in stages and not dryrun and - task.file_mounts is not None) + (task.file_mounts is not None or + task.storage_mounts is not None)) if do_workdir or do_file_mounts: logger.info(ux_utils.starting_message('Mounting files.')) @@ -576,8 +583,9 @@ def exec( # pylint: disable=redefined-builtin (CloudVMRayBackend). Raises: - ValueError: if the specified cluster does not exist or is not in UP - status. + ValueError: if the specified cluster is not in UP status. + sky.exceptions.ClusterDoesNotExist: if the specified cluster does not + exist. sky.exceptions.NotSupportedError: if the specified cluster is a controller that does not support this operation. diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 04b5c22bf00..9f75f87d659 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -66,7 +66,7 @@ def create_table(cursor, conn): cluster_hash TEXT DEFAULT null, storage_mounts_metadata BLOB DEFAULT null, cluster_ever_up INTEGER DEFAULT 0, - transaction_id INTEGER DEFAULT 0, + status_updated_at INTEGER DEFAULT null, user_hash TEXT DEFAULT null)""") # Table for Cluster History @@ -143,12 +143,8 @@ def create_table(cursor, conn): # clusters were never really UP, setting it to 1 means they won't be # auto-deleted during any failover. value_to_replace_existing_entries=1) - db_utils.add_column_to_table(cursor, - conn, - 'clusters', - 'transaction_id', - 'INTEGER DEFAULT 0', - value_to_replace_existing_entries=0) + db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at', + 'INTEGER DEFAULT null') db_utils.add_column_to_table( cursor, conn, @@ -164,15 +160,6 @@ def create_table(cursor, conn): _DB = db_utils.SQLiteConn(_DB_PATH, create_table) -def get_transaction_id(cluster_name: str) -> int: - rows = _DB.cursor.execute( - 'SELECT transaction_id FROM clusters WHERE name=(?)', - (cluster_name,)).fetchone() - if rows is None: - return -1 - return rows[0] - - def add_user(user: models.User): if user.name is None: return @@ -193,8 +180,7 @@ def add_or_update_cluster(cluster_name: str, cluster_handle: 'backends.ResourceHandle', requested_resources: Optional[Set[Any]], ready: bool, - is_launch: bool = True, - expected_transaction_id: Optional[int] = None): + is_launch: bool = True): """Adds or updates cluster_name -> cluster_handle mapping. Args: @@ -206,12 +192,6 @@ def add_or_update_cluster(cluster_name: str, is_launch: if the cluster is firstly launched. If True, the launched_at and last_use will be updated. Otherwise, use the old value. """ - transaction_id = get_transaction_id(cluster_name) - if (expected_transaction_id is not None and - expected_transaction_id != transaction_id): - logger.debug(f'Cluster {cluster_name} has been updated by another ' - 'transaction. Skipping update.') - return # FIXME: launched_at will be changed when `sky launch -c` is called. handle = pickle.dumps(cluster_handle) cluster_launched_at = int(time.time()) if is_launch else None @@ -219,6 +199,7 @@ def add_or_update_cluster(cluster_name: str, status = status_lib.ClusterStatus.INIT if ready: status = status_lib.ClusterStatus.UP + status_updated_at = int(time.time()) # TODO (sumanth): Cluster history table will have multiple entries # when the cluster failover through multiple regions (one entry per region). @@ -253,7 +234,8 @@ def add_or_update_cluster(cluster_name: str, # specified. '(name, launched_at, handle, last_use, status, ' 'autostop, to_down, metadata, owner, cluster_hash, ' - 'storage_mounts_metadata, cluster_ever_up, transaction_id, user_hash) ' + 'storage_mounts_metadata, cluster_ever_up, status_updated_at, ' + 'user_hash) ' 'VALUES (' # name '?, ' @@ -291,7 +273,7 @@ def add_or_update_cluster(cluster_name: str, '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), ' # cluster_ever_up '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?), ' - # transaction_id + # status_updated_at '?,' # user_hash: keep original user_hash if it exists 'COALESCE(' @@ -327,8 +309,8 @@ def add_or_update_cluster(cluster_name: str, # cluster_ever_up cluster_name, int(ready), - # transaction_id - transaction_id + 1, + # status_updated_at + status_updated_at, # user_hash cluster_name, user_hash, @@ -407,11 +389,13 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: # stopped VM, which leads to timeout. if hasattr(handle, 'stable_internal_external_ips'): handle.stable_internal_external_ips = None + current_time = int(time.time()) _DB.cursor.execute( - 'UPDATE clusters SET handle=(?), status=(?) ' - 'WHERE name=(?)', ( + 'UPDATE clusters SET handle=(?), status=(?), ' + 'status_updated_at=(?) WHERE name=(?)', ( pickle.dumps(handle), status_lib.ClusterStatus.STOPPED.value, + current_time, cluster_name, )) _DB.conn.commit() @@ -436,10 +420,10 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]: def set_cluster_status(cluster_name: str, status: status_lib.ClusterStatus) -> None: - _DB.cursor.execute('UPDATE clusters SET status=(?) WHERE name=(?)', ( - status.value, - cluster_name, - )) + current_time = int(time.time()) + _DB.cursor.execute( + 'UPDATE clusters SET status=(?), status_updated_at=(?) WHERE name=(?)', + (status.value, current_time, cluster_name)) count = _DB.cursor.rowcount _DB.conn.commit() assert count <= 1, count @@ -447,18 +431,8 @@ def set_cluster_status(cluster_name: str, raise ValueError(f'Cluster {cluster_name} not found.') -def set_cluster_autostop_value( - cluster_name: str, - idle_minutes: int, - to_down: bool, - expected_transaction_id: Optional[int] = None) -> None: - transaction_id = get_transaction_id(cluster_name) - if (expected_transaction_id is not None and - expected_transaction_id != transaction_id): - logger.debug(f'Cluster {cluster_name} has been updated by another ' - 'transaction. Skipping update.') - return - +def set_cluster_autostop_value(cluster_name: str, idle_minutes: int, + to_down: bool) -> None: _DB.cursor.execute( 'UPDATE clusters SET autostop=(?), to_down=(?) WHERE name=(?)', ( idle_minutes, @@ -657,15 +631,18 @@ def _load_storage_mounts_metadata( def get_cluster_from_name( cluster_name: Optional[str]) -> Optional[Dict[str, Any]]: - rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)', - (cluster_name,)).fetchall() + rows = _DB.cursor.execute( + 'SELECT name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at, user_hash ' + 'FROM clusters WHERE name=(?)', (cluster_name,)).fetchall() for row in rows: # Explicitly specify the number of fields to unpack, so that # we can add new fields to the database in the future without # breaking the previous code. (name, launched_at, handle, last_use, status, autostop, metadata, to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, - _, user_hash) = row[:14] + status_updated_at, user_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -681,6 +658,7 @@ def get_cluster_from_name( 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, 'user_hash': user_hash, 'user_name': get_user(user_hash).name, } @@ -690,12 +668,15 @@ def get_cluster_from_name( def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( - 'select * from clusters order by launched_at desc').fetchall() + 'select name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at, user_hash ' + 'from clusters order by launched_at desc').fetchall() records = [] for row in rows: (name, launched_at, handle, last_use, status, autostop, metadata, to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, - _, user_hash) = row[:14] + status_updated_at, user_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -711,6 +692,7 @@ def get_clusters() -> List[Dict[str, Any]]: 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, 'user_hash': user_hash, 'user_name': get_user(user_hash).name, } diff --git a/sky/jobs/api/core.py b/sky/jobs/api/core.py index ab73c49dcd2..b8a2362fc45 100644 --- a/sky/jobs/api/core.py +++ b/sky/jobs/api/core.py @@ -33,8 +33,7 @@ if typing.TYPE_CHECKING: import sky - -# TODO(zhwu): Fix jobs API to work with API server + from sky.backends import cloud_vm_ray_backend @timeline.event @@ -236,6 +235,40 @@ def queue_from_kubernetes_pod( return jobs +def _maybe_restart_controller( + refresh: bool, stopped_message: str, spinner_message: str +) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle': + """Restart controller if refresh is True and it is stopped.""" + jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER + if refresh: + stopped_message = '' + try: + handle = backend_utils.is_controller_accessible( + controller=jobs_controller_type, stopped_message=stopped_message) + except exceptions.ClusterNotUpError as e: + if not refresh: + raise + handle = None + controller_status = e.cluster_status + + if handle is not None: + return handle + + sky_logging.print(f'{colorama.Fore.YELLOW}' + f'Restarting {jobs_controller_type.value.name}...' + f'{colorama.Style.RESET_ALL}') + + rich_utils.force_update_status( + ux_utils.spinner_message(f'{spinner_message} - restarting ' + 'controller')) + handle = core.start(jobs_controller_type.value.cluster_name) + controller_status = status_lib.ClusterStatus.UP + rich_utils.force_update_status(ux_utils.spinner_message(spinner_message)) + + assert handle is not None, (controller_status, refresh) + return handle + + @usage_lib.entrypoint def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. @@ -263,34 +296,11 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: does not exist. RuntimeError: if failed to get the managed jobs with ssh. """ - jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER - stopped_message = '' - if not refresh: - stopped_message = 'No in-progress managed jobs.' - try: - handle = backend_utils.is_controller_accessible( - controller=jobs_controller_type, stopped_message=stopped_message) - except exceptions.ClusterNotUpError as e: - if not refresh: - raise - handle = None - controller_status = e.cluster_status - - if refresh and handle is None: - sky_logging.print(f'{colorama.Fore.YELLOW}' - 'Restarting controller for latest status...' - f'{colorama.Style.RESET_ALL}') - - rich_utils.force_update_status( - ux_utils.spinner_message('Checking managed jobs - restarting ' - 'controller')) - handle = core.start(jobs_controller_type.value.cluster_name) - controller_status = status_lib.ClusterStatus.UP - rich_utils.force_update_status( - ux_utils.spinner_message('Checking managed jobs')) - - assert handle is not None, (controller_status, refresh) - + handle = _maybe_restart_controller(refresh, + stopped_message='No in-progress ' + 'managed jobs.', + spinner_message='Checking ' + 'managed jobs') backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend) @@ -382,7 +392,7 @@ def cancel(name: Optional[str] = None, @usage_lib.entrypoint def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool) -> None: + controller: bool, refresh: bool) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail logs of managed jobs. @@ -393,15 +403,26 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, sky.exceptions.ClusterNotUpError: the jobs controller is not up. """ # TODO(zhwu): Automatically restart the jobs controller + if name is not None and job_id is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both name and job_id.') + jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER - handle = backend_utils.is_controller_accessible( - controller=jobs_controller_type, + job_name_or_id_str = '' + if job_id is not None: + job_name_or_id_str = str(job_id) + elif name is not None: + job_name_or_id_str = f'-n {name}' + else: + job_name_or_id_str = '' + handle = _maybe_restart_controller( + refresh, stopped_message=( - 'Please restart the jobs controller with ' - f'`sky start {jobs_controller_type.value.cluster_name}`.')) + f'{jobs_controller_type.value.name.capitalize()} is stopped. To ' + f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs ' + f'-r {job_name_or_id_str}{colorama.Style.RESET_ALL}'), + spinner_message='Retrieving job logs') - if name is not None and job_id is not None: - raise ValueError('Cannot specify both name and job_id.') backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend), backend diff --git a/sky/jobs/api/sdk.py b/sky/jobs/api/sdk.py index 2ba0e3bb253..7b371c59b34 100644 --- a/sky/jobs/api/sdk.py +++ b/sky/jobs/api/sdk.py @@ -94,14 +94,18 @@ def cancel( @usage_lib.entrypoint @api_common.check_health -def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool) -> str: +def tail_logs(name: Optional[str], + job_id: Optional[int], + follow: bool, + controller: bool, + refresh: bool = False) -> str: """Tail logs of managed jobs.""" body = payloads.JobsLogsBody( name=name, job_id=job_id, follow=follow, controller=controller, + refresh=refresh, ) response = requests.post( f'{api_common.get_server_url()}/jobs/logs', diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 7cfc50fd4b4..23123eae761 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -6,7 +6,7 @@ import time import traceback import typing -from typing import Tuple +from typing import Optional, Tuple import filelock @@ -20,6 +20,7 @@ from sky.skylet import constants from sky.skylet import job_lib from sky.usage import usage_lib +from sky.utils import common from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils @@ -87,18 +88,28 @@ def __init__(self, job_id: int, dag_yaml: str, task.update_envs(task_envs) def _download_log_and_stream( - self, - handle: cloud_vm_ray_backend.CloudVmRayResourceHandle) -> None: - """Downloads and streams the logs of the latest job. + self, task_id: Optional[int], + handle: Optional[cloud_vm_ray_backend.CloudVmRayResourceHandle] + ) -> None: + """Downloads and streams the logs of the current job with given task ID. We do not stream the logs from the cluster directly, as the donwload and stream should be faster, and more robust against preemptions or ssh disconnection during the streaming. """ + if handle is None: + logger.info(f'Cluster for job {self._job_id} is not found. ' + 'Skipping downloading and streaming the logs.') + return managed_job_logs_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, 'managed_jobs') - controller_utils.download_and_stream_latest_job_log( + log_file = controller_utils.download_and_stream_latest_job_log( self._backend, handle, managed_job_logs_dir) + if log_file is not None: + # Set the path of the log file for the current task, so it can be + # accessed even after the job is finished + managed_job_state.set_local_log_file(self._job_id, task_id, + log_file) logger.info(f'\n== End of logs (ID: {self._job_id}) ==') def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: @@ -213,7 +224,8 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: if job_status == job_lib.JobStatus.SUCCEEDED: end_time = managed_job_utils.get_job_timestamp( self._backend, cluster_name, get_end_time=True) - # The job is done. + # The job is done. Set the job to SUCCEEDED first before start + # downloading and streaming the logs to make it more responsive. managed_job_state.set_succeeded(self._job_id, task_id, end_time=end_time, @@ -221,12 +233,21 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: logger.info( f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. ' f'Cleaning up the cluster {cluster_name}.') + clusters = backend_utils.get_clusters( + cluster_names=[cluster_name], + refresh=common.StatusRefreshMode.NONE, + all_users=True) + if clusters: + assert len(clusters) == 1, (clusters, cluster_name) + handle = clusters[0].get('handle') + # Best effort to download and stream the logs. + self._download_log_and_stream(task_id, handle) # Only clean up the cluster, not the storages, because tasks may # share storages. recovery_strategy.terminate_cluster(cluster_name=cluster_name) return True - # For single-node jobs, nonterminated job_status indicates a + # For single-node jobs, non-terminated job_status indicates a # healthy cluster. We can safely continue monitoring. # For multi-node jobs, since the job may not be set to FAILED # immediately (depending on user program) when only some of the @@ -278,7 +299,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: 'The user job failed. Please check the logs below.\n' f'== Logs of the user job (ID: {self._job_id}) ==\n') - self._download_log_and_stream(handle) + self._download_log_and_stream(task_id, handle) managed_job_status = ( managed_job_state.ManagedJobStatus.FAILED) if job_status == job_lib.JobStatus.FAILED_SETUP: diff --git a/sky/jobs/recovery_strategy.py b/sky/jobs/recovery_strategy.py index d4e18970309..0c90f0c279d 100644 --- a/sky/jobs/recovery_strategy.py +++ b/sky/jobs/recovery_strategy.py @@ -50,8 +50,9 @@ def terminate_cluster(cluster_name: str, max_retry: int = 3) -> None: usage_lib.messages.usage.set_internal() core.down(cluster_name) return - except ValueError: + except exceptions.ClusterDoesNotExist: # The cluster is already down. + logger.debug(f'The cluster {cluster_name} is already down.') return except Exception as e: # pylint: disable=broad-except retry_cnt += 1 diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 6a0e3caeda3..9a5ab4b3cad 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -66,7 +66,8 @@ def create_table(cursor, conn): spot_job_id INTEGER, task_id INTEGER DEFAULT 0, task_name TEXT, - specs TEXT)""") + specs TEXT, + local_log_file TEXT DEFAULT NULL)""") conn.commit() db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT') @@ -103,6 +104,8 @@ def create_table(cursor, conn): value_to_replace_existing_entries=json.dumps({ 'max_restarts_on_errors': 0, })) + db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file', + 'TEXT DEFAULT NULL') # `job_info` contains the mapping from job_id to the job_name. # In the future, it may contain more information about each job. @@ -157,6 +160,7 @@ def _get_db_path() -> str: 'task_id', 'task_name', 'specs', + 'local_log_file', # columns from the job_info table '_job_info_job_id', # This should be the same as job_id 'job_name', @@ -512,6 +516,20 @@ def set_cancelled(job_id: int, callback_func: CallbackType): callback_func('CANCELLED') +def set_local_log_file(job_id: int, task_id: Optional[int], + local_log_file: str): + """Set the local log file for a job.""" + filter_str = 'spot_job_id=(?)' + filter_args = [local_log_file, job_id] + if task_id is not None: + filter_str += ' AND task_id=(?)' + filter_args.append(task_id) + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + 'UPDATE spot SET local_log_file=(?) ' + f'WHERE {filter_str}', filter_args) + + # ======== utility functions ======== def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]: """Get non-terminal job ids by name.""" @@ -662,3 +680,17 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]: WHERE spot_job_id=(?) AND task_id=(?)""", (job_id, task_id)).fetchone() return json.loads(task_specs[0]) + + +def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]: + """Get the local log directory for a job.""" + filter_str = 'spot_job_id=(?)' + filter_args = [job_id] + if task_id is not None: + filter_str += ' AND task_id=(?)' + filter_args.append(task_id) + with db_utils.safe_cursor(_DB_PATH) as cursor: + local_log_file = cursor.execute( + f'SELECT local_log_file FROM spot ' + f'WHERE {filter_str}', filter_args).fetchone() + return local_log_file[-1] if local_log_file else None diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 8845f488dab..8e30dcebc04 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -322,10 +322,24 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: if managed_job_status.is_failed(): job_msg = ('\nFailure reason: ' f'{managed_job_state.get_failure_reason(job_id)}') + log_file = managed_job_state.get_local_log_file(job_id, None) + if log_file is not None: + with open(log_file, 'r', encoding='utf-8') as f: + # Stream the logs to the console without reading the whole + # file into memory. + start_streaming = False + for line in f: + if log_lib.LOG_FILE_START_STREAMING_AT in line: + start_streaming = True + if start_streaming: + print(line, end='', flush=True) + return '' return (f'{colorama.Fore.YELLOW}' f'Job {job_id} is already in terminal state ' - f'{managed_job_status.value}. Logs will not be shown.' - f'{colorama.Style.RESET_ALL}{job_msg}') + f'{managed_job_status.value}. For more details, run: ' + f'sky jobs logs --controller {job_id}' + f'{colorama.Style.RESET_ALL}' + f'{job_msg}') backend = backends.CloudVmRayBackend() task_id, managed_job_status = ( managed_job_state.get_latest_task_id_status(job_id)) diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index d9cbb1dd643..ec1b7d2ae15 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -305,7 +305,8 @@ def _create_vm( network_profile=network_profile, identity=compute.VirtualMachineIdentity( type='UserAssigned', - user_assigned_identities={provider_config['msi']: {}})) + user_assigned_identities={provider_config['msi']: {}}), + priority=node_config['azure_arm_parameters'].get('priority', None)) vm_poller = compute_client.virtual_machines.begin_create_or_update( resource_group_name=provider_config['resource_group'], vm_name=vm_name, diff --git a/sky/provision/fluidstack/instance.py b/sky/provision/fluidstack/instance.py index 0e77de23ee6..1bb123ef649 100644 --- a/sky/provision/fluidstack/instance.py +++ b/sky/provision/fluidstack/instance.py @@ -81,9 +81,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Runs instances for the given cluster.""" - pending_status = [ - 'pending', - ] + pending_status = ['pending', 'provisioning'] while True: instances = _filter_instances(cluster_name_on_cloud, pending_status) if len(instances) > config.count: diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index a211cbd07f7..5287197243b 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -4,7 +4,6 @@ import hashlib import json import os -import resource import time from typing import Any, Callable, Dict, List, Optional, Tuple @@ -20,6 +19,7 @@ from sky.utils import command_runner from sky.utils import common_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -115,7 +115,8 @@ def _parallel_ssh_with_cache(func, if max_workers is None: # Not using the default value of `max_workers` in ThreadPoolExecutor, # as 32 is too large for some machines. - max_workers = subprocess_utils.get_parallel_threads() + max_workers = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) with futures.ThreadPoolExecutor(max_workers=max_workers) as pool: results = [] runners = provision.get_command_runners(cluster_info.provider_name, @@ -170,6 +171,7 @@ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): @common.log_function_start_end +@timeline.event def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -245,20 +247,9 @@ def _ray_gpu_options(custom_resource: str) -> str: return f' --num-gpus={acc_count}' -@common.log_function_start_end -@_auto_retry() -def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], - cluster_info: common.ClusterInfo, - ssh_credentials: Dict[str, Any]) -> None: - """Start Ray on the head node.""" - runners = provision.get_command_runners(cluster_info.provider_name, - cluster_info, **ssh_credentials) - head_runner = runners[0] - assert cluster_info.head_instance_id is not None, (cluster_name, - cluster_info) - - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) +def ray_head_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]]) -> str: + """Returns the command to start Ray on the head node.""" ray_options = ( # --disable-usage-stats in `ray start` saves 10 seconds of idle wait. f'--disable-usage-stats ' @@ -270,23 +261,14 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], if custom_resource: ray_options += f' --resources=\'{custom_resource}\'' ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - if 'use_external_ip' in cluster_info.custom_ray_options: - cluster_info.custom_ray_options.pop('use_external_ip') - for key, value in cluster_info.custom_ray_options.items(): + if custom_ray_options: + if 'use_external_ip' in custom_ray_options: + custom_ray_options.pop('use_external_ip') + for key, value in custom_ray_options.items(): ray_options += f' --{key}={value}' - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY to avoid using credentials - # from environment variables set by user. SkyPilot's ray cluster should use - # the `~/.aws/` credentials, as that is the one used to create the cluster, - # and the autoscaler module started by the `ray start` command should use - # the same credentials. Otherwise, `ray status` will fail to fetch the - # available nodes. - # Reference: https://github.com/skypilot-org/skypilot/issues/2441 cmd = ( f'{constants.SKY_RAY_CMD} stop; ' - 'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' # worker_maximum_startup_concurrency controls the maximum number of # workers that can be started concurrently. However, it also controls @@ -305,6 +287,62 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], 'RAY_worker_maximum_startup_concurrency=$(( 3 * $(nproc --all) )) ' f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' + _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND) + return cmd + + +def ray_worker_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]], + no_restart: bool) -> str: + """Returns the command to start Ray on the worker node.""" + # We need to use the ray port in the env variable, because the head node + # determines the port to be used for the worker node. + ray_options = ('--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} ' + '--object-manager-port=8076') + + if custom_resource: + ray_options += f' --resources=\'{custom_resource}\'' + ray_options += _ray_gpu_options(custom_resource) + + if custom_ray_options: + for key, value in custom_ray_options.items(): + ray_options += f' --{key}={value}' + + cmd = ( + 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' + f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' + 'exit 1;' + _RAY_PRLIMIT) + if no_restart: + # We do not use ray status to check whether ray is running, because + # on worker node, if the user started their own ray cluster, ray status + # will return 0, i.e., we don't know skypilot's ray cluster is running. + # Instead, we check whether the raylet process is running on gcs address + # that is connected to the head with the correct port. + cmd = ( + f'ps aux | grep "ray/raylet/raylet" | ' + 'grep "gcs-address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT}" ' + f'|| {{ {cmd} }}') + else: + cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + return cmd + + +@common.log_function_start_end +@_auto_retry() +@timeline.event +def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], + cluster_info: common.ClusterInfo, + ssh_credentials: Dict[str, Any]) -> None: + """Start Ray on the head node.""" + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] + assert cluster_info.head_instance_id is not None, (cluster_name, + cluster_info) + + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + cmd = ray_head_start_command(custom_resource, + cluster_info.custom_ray_options) logger.info(f'Running command on head node: {cmd}') # TODO(zhwu): add the output to log files. returncode, stdout, stderr = head_runner.run( @@ -324,6 +362,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], @common.log_function_start_end @_auto_retry() +@timeline.event def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, custom_resource: Optional[str], ray_port: int, cluster_info: common.ClusterInfo, @@ -358,43 +397,17 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, head_ip = (head_instance.internal_ip if not use_external_ip else head_instance.external_ip) - ray_options = (f'--address={head_ip}:{constants.SKY_REMOTE_RAY_PORT} ' - f'--object-manager-port=8076') - - if custom_resource: - ray_options += f' --resources=\'{custom_resource}\'' - ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - for key, value in cluster_info.custom_ray_options.items(): - ray_options += f' --{key}={value}' + ray_cmd = ray_worker_start_command(custom_resource, + cluster_info.custom_ray_options, + no_restart) - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY, see the comment in - # `start_ray_on_head_node`. - cmd = ( - f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' - 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' - f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' - 'exit 1;' + _RAY_PRLIMIT) - if no_restart: - # We do not use ray status to check whether ray is running, because - # on worker node, if the user started their own ray cluster, ray status - # will return 0, i.e., we don't know skypilot's ray cluster is running. - # Instead, we check whether the raylet process is running on gcs address - # that is connected to the head with the correct port. - cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | ' - f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || ' - f'{{ {cmd} }}') - else: - cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; ' + f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd) logger.info(f'Running command on worker nodes: {cmd}') def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, str]): - # for cmd in config_from_yaml['worker_start_ray_commands']: - # cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0]) - # runner.run(cmd) runner, instance_id = runner_and_id log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id) log_path_abs = str(log_dir / ('ray_cluster' + '.log')) @@ -407,8 +420,10 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, # by ray will have the correct PATH. source_bashrc=True) + num_threads = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(worker_runners, cache_ids))) + _setup_ray_worker, list(zip(worker_runners, cache_ids)), num_threads) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): @@ -421,6 +436,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, @common.log_function_start_end @_auto_retry() +@timeline.event def start_skylet_on_head_node(cluster_name: str, cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -482,28 +498,8 @@ def _internal_file_mounts(file_mounts: Dict, ) -def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int: - fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) - - fd_per_rsync = 5 - for src in common_file_mounts.values(): - if os.path.isdir(src): - # Assume that each file/folder under src takes 5 file descriptors - # on average. - fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) - - # Reserve some file descriptors for the system and other processes - fd_reserve = 100 - - max_workers = (fd_limit - fd_reserve) // fd_per_rsync - # At least 1 worker, and avoid too many workers overloading the system. - max_workers = min(max(max_workers, 1), - subprocess_utils.get_parallel_threads()) - logger.debug(f'Using {max_workers} workers for file mounts.') - return max_workers - - @common.log_function_start_end +@timeline.event def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, str]) -> None: @@ -524,4 +520,5 @@ def _setup_node(runner: command_runner.CommandRunner, log_path: str): digest=None, cluster_info=cluster_info, ssh_credentials=ssh_credentials, - max_workers=_max_workers_for_file_mounts(common_file_mounts)) + max_workers=subprocess_utils.get_max_workers_for_file_mounts( + common_file_mounts, cluster_info.provider_name)) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index cd6f90cc3b3..731e5afb275 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -20,12 +20,13 @@ from sky.utils import kubernetes_enums from sky.utils import status_lib from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils POLL_INTERVAL = 2 _TIMEOUT_FOR_POD_TERMINATION = 60 # 1 minutes _MAX_RETRIES = 3 -NUM_THREADS = subprocess_utils.get_parallel_threads() * 2 +_NUM_THREADS = subprocess_utils.get_parallel_threads('kubernetes') logger = sky_logging.init_logger(__name__) TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' @@ -120,6 +121,9 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes): are recorded as events. This function retrieves those events and raises descriptive errors for better debugging and user feedback. """ + timeout_err_msg = ('Timed out while waiting for nodes to start. ' + 'Cluster may be out of resources or ' + 'may be too slow to autoscale.') for new_node in new_nodes: pod = kubernetes.core_api(context).read_namespaced_pod( new_node.metadata.name, namespace) @@ -148,9 +152,6 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes): if event.reason == 'FailedScheduling': event_message = event.message break - timeout_err_msg = ('Timed out while waiting for nodes to start. ' - 'Cluster may be out of resources or ' - 'may be too slow to autoscale.') if event_message is not None: if pod_status == 'Pending': logger.info(event_message) @@ -219,6 +220,7 @@ def _raise_command_running_error(message: str, command: str, pod_name: str, f'code {rc}: {command!r}\nOutput: {stdout}.') +@timeline.event def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): """Wait for all pods to be scheduled. @@ -229,6 +231,10 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): If timeout is set to a negative value, this method will wait indefinitely. """ + # Create a set of pod names we're waiting for + if not new_nodes: + return + expected_pod_names = {node.metadata.name for node in new_nodes} start_time = time.time() def _evaluate_timeout() -> bool: @@ -238,19 +244,34 @@ def _evaluate_timeout() -> bool: return time.time() - start_time < timeout while _evaluate_timeout(): - all_pods_scheduled = True - for node in new_nodes: - # Iterate over each pod to check their status - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) - if pod.status.phase == 'Pending': + # Get all pods in a single API call using the cluster name label + # which all pods in new_nodes should share + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying waiting for pods: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + + # Check if all pods are scheduled + all_scheduled = True + for pod in pods: + if (pod.metadata.name in expected_pod_names and + pod.status.phase == 'Pending'): # If container_statuses is None, then the pod hasn't # been scheduled yet. if pod.status.container_statuses is None: - all_pods_scheduled = False + all_scheduled = False break - if all_pods_scheduled: + if all_scheduled: return time.sleep(1) @@ -266,12 +287,18 @@ def _evaluate_timeout() -> bool: f'Error: {common_utils.format_exception(e)}') from None +@timeline.event def _wait_for_pods_to_run(namespace, context, new_nodes): """Wait for pods and their containers to be ready. Pods may be pulling images or may be in the process of container creation. """ + if not new_nodes: + return + + # Create a set of pod names we're waiting for + expected_pod_names = {node.metadata.name for node in new_nodes} def _check_init_containers(pod): # Check if any of the init containers failed @@ -299,12 +326,25 @@ def _check_init_containers(pod): f'{pod.metadata.name}. Error details: {msg}.') while True: - all_pods_running = True - # Iterate over each pod to check their status - for node in new_nodes: - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) + # Get all pods in a single API call + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + all_pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in all_pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying running pods check: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + all_pods_running = True + for pod in all_pods: + if pod.metadata.name not in expected_pod_names: + continue # Continue if pod and all the containers within the # pod are successfully created and running. if pod.status.phase == 'Running' and all( @@ -367,6 +407,7 @@ def _run_function_with_retries(func: Callable, raise +@timeline.event def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None: """Pre-initialization step for SkyPilot pods. @@ -430,7 +471,7 @@ def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None: 'start_time=$(date +%s); ' 'while ! grep -q "Fetched" /tmp/apt-update.log 2>/dev/null; do ' ' echo "apt update still running. Logs:"; ' - ' cat /tmp/apt-update.log; ' + ' cat /tmp/apt-update.log || true; ' ' current_time=$(date +%s); ' ' elapsed=$((current_time - start_time)); ' ' if [ $elapsed -ge $timeout_secs ]; then ' @@ -514,7 +555,7 @@ def _pre_init_thread(new_node): logger.info(f'{"-"*20}End: Pre-init in pod {pod_name!r} {"-"*20}') # Run pre_init in parallel across all new_nodes - subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, NUM_THREADS) + subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, _NUM_THREADS) def _label_pod(namespace: str, context: Optional[str], pod_name: str, @@ -528,6 +569,7 @@ def _label_pod(namespace: str, context: Optional[str], pod_name: str, _request_timeout=kubernetes.API_TIMEOUT) +@timeline.event def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, context: Optional[str]) -> Any: """Attempts to create a Kubernetes Pod and handle any errors. @@ -606,6 +648,7 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, raise e +@timeline.event def _create_pods(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Create pods based on the config.""" @@ -627,7 +670,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) start_time = time.time() - while (len(terminating_pods) > 0 and + while (terminating_pods and time.time() - start_time < _TIMEOUT_FOR_POD_TERMINATION): logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods. Waiting them to finish: ' @@ -636,7 +679,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) - if len(terminating_pods) > 0: + if terminating_pods: # If there are still terminating pods, we force delete them. logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods still in terminating state after ' @@ -695,24 +738,29 @@ def _create_pods(region: str, cluster_name_on_cloud: str, created_pods = {} logger.debug(f'run_instances: calling create_namespaced_pod ' f'(count={to_start_count}).') - for _ in range(to_start_count): - if head_pod_name is None: - pod_spec['metadata']['labels'].update(constants.HEAD_NODE_TAGS) + + def _create_pod_thread(i: int): + pod_spec_copy = copy.deepcopy(pod_spec) + if head_pod_name is None and i == 0: + # First pod should be head if no head exists + pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS) head_selector = head_service_selector(cluster_name_on_cloud) - pod_spec['metadata']['labels'].update(head_selector) - pod_spec['metadata']['name'] = f'{cluster_name_on_cloud}-head' + pod_spec_copy['metadata']['labels'].update(head_selector) + pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head' else: - pod_spec['metadata']['labels'].update(constants.WORKER_NODE_TAGS) - pod_uuid = str(uuid.uuid4())[:4] + # Worker pods + pod_spec_copy['metadata']['labels'].update( + constants.WORKER_NODE_TAGS) + pod_uuid = str(uuid.uuid4())[:6] pod_name = f'{cluster_name_on_cloud}-{pod_uuid}' - pod_spec['metadata']['name'] = f'{pod_name}-worker' + pod_spec_copy['metadata']['name'] = f'{pod_name}-worker' # For multi-node support, we put a soft-constraint to schedule # worker pods on different nodes than the head pod. # This is not set as a hard constraint because if different nodes # are not available, we still want to be able to schedule worker # pods on larger nodes which may be able to fit multiple SkyPilot # "nodes". - pod_spec['spec']['affinity'] = { + pod_spec_copy['spec']['affinity'] = { 'podAntiAffinity': { # Set as a soft constraint 'preferredDuringSchedulingIgnoredDuringExecution': [{ @@ -747,17 +795,22 @@ def _create_pods(region: str, cluster_name_on_cloud: str, 'value': 'present', 'effect': 'NoSchedule' } - pod_spec['spec']['tolerations'] = [tpu_toleration] + pod_spec_copy['spec']['tolerations'] = [tpu_toleration] - pod = _create_namespaced_pod_with_retries(namespace, pod_spec, context) + return _create_namespaced_pod_with_retries(namespace, pod_spec_copy, + context) + + # Create pods in parallel + pods = subprocess_utils.run_in_parallel(_create_pod_thread, + range(to_start_count), _NUM_THREADS) + + # Process created pods + for pod in pods: created_pods[pod.metadata.name] = pod - if head_pod_name is None: + if head_pod_name is None and pod.metadata.labels.get( + constants.TAG_RAY_NODE_KIND) == 'head': head_pod_name = pod.metadata.name - wait_pods_dict = kubernetes_utils.filter_pods(namespace, context, tags, - ['Pending']) - wait_pods = list(wait_pods_dict.values()) - networking_mode = network_utils.get_networking_mode( config.provider_config.get('networking_mode')) if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: @@ -766,52 +819,24 @@ def _create_pods(region: str, cluster_name_on_cloud: str, ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump'] jump_pod = kubernetes.core_api(context).read_namespaced_pod( ssh_jump_pod_name, namespace) - wait_pods.append(jump_pod) + pods.append(jump_pod) provision_timeout = provider_config['timeout'] wait_str = ('indefinitely' if provision_timeout < 0 else f'for {provision_timeout}s') logger.debug(f'run_instances: waiting {wait_str} for pods to schedule and ' - f'run: {list(wait_pods_dict.keys())}') + f'run: {[pod.metadata.name for pod in pods]}') # Wait until the pods are scheduled and surface cause for error # if there is one - _wait_for_pods_to_schedule(namespace, context, wait_pods, provision_timeout) + _wait_for_pods_to_schedule(namespace, context, pods, provision_timeout) # Wait until the pods and their containers are up and running, and # fail early if there is an error logger.debug(f'run_instances: waiting for pods to be running (pulling ' - f'images): {list(wait_pods_dict.keys())}') - _wait_for_pods_to_run(namespace, context, wait_pods) + f'images): {[pod.metadata.name for pod in pods]}') + _wait_for_pods_to_run(namespace, context, pods) logger.debug(f'run_instances: all pods are scheduled and running: ' - f'{list(wait_pods_dict.keys())}') - - running_pods = kubernetes_utils.filter_pods(namespace, context, tags, - ['Running']) - initialized_pods = kubernetes_utils.filter_pods(namespace, context, { - TAG_POD_INITIALIZED: 'true', - **tags - }, ['Running']) - uninitialized_pods = { - pod_name: pod - for pod_name, pod in running_pods.items() - if pod_name not in initialized_pods - } - if len(uninitialized_pods) > 0: - logger.debug(f'run_instances: Initializing {len(uninitialized_pods)} ' - f'pods: {list(uninitialized_pods.keys())}') - uninitialized_pods_list = list(uninitialized_pods.values()) - - # Run pre-init steps in the pod. - pre_init(namespace, context, uninitialized_pods_list) - - for pod in uninitialized_pods.values(): - _label_pod(namespace, - context, - pod.metadata.name, - label={ - TAG_POD_INITIALIZED: 'true', - **pod.metadata.labels - }) + f'{[pod.metadata.name for pod in pods]}') assert head_pod_name is not None, 'head_instance_id should not be None' return common.ProvisionRecord( @@ -854,11 +879,6 @@ def _terminate_node(namespace: str, context: Optional[str], pod_name: str) -> None: """Terminate a pod.""" logger.debug('terminate_instances: calling delete_namespaced_pod') - try: - kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, pod_name) - except Exception as e: # pylint: disable=broad-except - logger.warning('terminate_instances: Error occurred when analyzing ' - f'SSH Jump pod: {e}') try: kubernetes.core_api(context).delete_namespaced_service( pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT) @@ -895,6 +915,18 @@ def terminate_instances( } pods = kubernetes_utils.filter_pods(namespace, context, tag_filters, None) + # Clean up the SSH jump pod if in use + networking_mode = network_utils.get_networking_mode( + provider_config.get('networking_mode')) + if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: + pod_name = list(pods.keys())[0] + try: + kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, + pod_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('terminate_instances: Error occurred when analyzing ' + f'SSH Jump pod: {e}') + def _is_head(pod) -> bool: return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head' @@ -907,7 +939,7 @@ def _terminate_pod_thread(pod_info): # Run pod termination in parallel subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(), - NUM_THREADS) + _NUM_THREADS) def get_cluster_info( diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index bff173e841f..d8fac3bf638 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -29,6 +29,7 @@ from sky.utils import kubernetes_enums from sky.utils import schemas from sky.utils import status_lib +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -37,7 +38,6 @@ # TODO(romilb): Move constants to constants.py DEFAULT_NAMESPACE = 'default' -IN_CLUSTER_REGION = 'in-cluster' DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account' @@ -921,6 +921,9 @@ def is_kubeconfig_exec_auth( str: Error message if exec-based authentication is used, None otherwise """ k8s = kubernetes.kubernetes + if context == kubernetes.in_cluster_context_name(): + # If in-cluster config is used, exec-based auth is not used. + return False, None try: k8s.config.load_kube_config() except kubernetes.config_exception(): @@ -1003,30 +1006,34 @@ def is_incluster_config_available() -> bool: return os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token') -def get_all_kube_config_context_names() -> List[Optional[str]]: - """Get all kubernetes context names from the kubeconfig file. +def get_all_kube_context_names() -> List[str]: + """Get all kubernetes context names available in the environment. + + Fetches context names from the kubeconfig file and in-cluster auth, if any. - If running in-cluster, returns [None] to indicate in-cluster config. + If running in-cluster and IN_CLUSTER_CONTEXT_NAME_ENV_VAR is not set, + returns the default in-cluster kubernetes context name. We should not cache the result of this function as the admin policy may update the contexts. Returns: List[Optional[str]]: The list of kubernetes context names if - available, an empty list otherwise. If running in-cluster, - returns [None] to indicate in-cluster config. + available, an empty list otherwise. """ k8s = kubernetes.kubernetes + context_names = [] try: all_contexts, _ = k8s.config.list_kube_config_contexts() # all_contexts will always have at least one context. If kubeconfig # does not have any contexts defined, it will raise ConfigException. - return [context['name'] for context in all_contexts] + context_names = [context['name'] for context in all_contexts] except k8s.config.config_exception.ConfigException: - # If running in cluster, return [None] to indicate in-cluster config - if is_incluster_config_available(): - return [None] - return [] + # If no config found, continue + pass + if is_incluster_config_available(): + context_names.append(kubernetes.in_cluster_context_name()) + return context_names @functools.lru_cache() @@ -1039,11 +1046,15 @@ def get_kube_config_context_namespace( the default namespace. """ k8s = kubernetes.kubernetes - # Get namespace if using in-cluster config ns_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace' - if os.path.exists(ns_path): - with open(ns_path, encoding='utf-8') as f: - return f.read().strip() + # If using in-cluster context, get the namespace from the service account + # namespace file. Uses the same logic as adaptors.kubernetes._load_config() + # to stay consistent with in-cluster config loading. + if (context_name == kubernetes.in_cluster_context_name() or + context_name is None): + if os.path.exists(ns_path): + with open(ns_path, encoding='utf-8') as f: + return f.read().strip() # If not in-cluster, get the namespace from kubeconfig try: contexts, current_context = k8s.config.list_kube_config_contexts() @@ -1130,7 +1141,11 @@ def name(self) -> str: name = (f'{common_utils.format_float(self.cpus)}CPU--' f'{common_utils.format_float(self.memory)}GB') if self.accelerator_count: - name += f'--{self.accelerator_count}{self.accelerator_type}' + # Replace spaces with underscores in accelerator type to make it a + # valid logical instance type name. + assert self.accelerator_type is not None, self.accelerator_count + acc_name = self.accelerator_type.replace(' ', '_') + name += f'--{self.accelerator_count}{acc_name}' return name @staticmethod @@ -1161,7 +1176,9 @@ def _parse_instance_type( accelerator_type = match.group('accelerator_type') if accelerator_count: accelerator_count = int(accelerator_count) - accelerator_type = str(accelerator_type) + # This is to revert the accelerator types with spaces back to + # the original format. + accelerator_type = str(accelerator_type).replace('_', ' ') else: accelerator_count = None accelerator_type = None @@ -1697,6 +1714,8 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): else: destination[key].extend(value) else: + if destination is None: + destination = {} destination[key] = value @@ -2045,6 +2064,7 @@ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str: get_kube_config_context_namespace(context)) +@timeline.event def filter_pods(namespace: str, context: Optional[str], tag_filters: Dict[str, str], @@ -2175,9 +2195,9 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle', def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]: context = provider_config.get('context', get_current_kube_config_context_name()) - if context == IN_CLUSTER_REGION: - # If the context (also used as the region) is set to IN_CLUSTER_REGION - # we need to use in-cluster auth. + if context == kubernetes.in_cluster_context_name(): + # If the context (also used as the region) is in-cluster, we need to + # we need to use in-cluster auth by setting the context to None. context = None return context diff --git a/sky/provision/oci/instance.py b/sky/provision/oci/instance.py index 26a21d9792b..bb00cc7a32f 100644 --- a/sky/provision/oci/instance.py +++ b/sky/provision/oci/instance.py @@ -126,8 +126,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, # Let's create additional new nodes (if neccessary) to_start_count = config.count - len(resume_instances) created_instances = [] + node_config = config.node_config if to_start_count > 0: - node_config = config.node_config compartment = query_helper.find_compartment(region) vcn = query_helper.find_create_vcn_subnet(region) @@ -245,10 +245,12 @@ def run_instances(region: str, cluster_name_on_cloud: str, assert head_instance_id is not None, head_instance_id + # Format: TenancyPrefix:AvailabilityDomain, e.g. bxtG:US-SANJOSE-1-AD-1 + _, ad = str(node_config['AvailabilityDomain']).split(':', maxsplit=1) return common.ProvisionRecord( provider_name='oci', region=region, - zone=None, + zone=ad, cluster_name=cluster_name_on_cloud, head_instance_id=head_instance_id, created_instance_ids=[n['inst_id'] for n in created_instances], diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 94a73d5bc8d..fb9ea380c0d 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -30,6 +30,7 @@ from sky.utils import rich_utils from sky.utils import status_lib from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils # Do not use __name__ as we do not want to propagate logs to sky.provision, @@ -344,6 +345,7 @@ def _wait_ssh_connection_indirect(ip: str, return True, '' +@timeline.event def wait_for_ssh(cluster_info: provision_common.ClusterInfo, ssh_credentials: Dict[str, str]): """Wait until SSH is ready. @@ -433,11 +435,15 @@ def _post_provision_setup( ux_utils.spinner_message( 'Launching - Waiting for SSH access', provision_logging.config.log_path)) as status: - - logger.debug( - f'\nWaiting for SSH to be available for {cluster_name!r} ...') - wait_for_ssh(cluster_info, ssh_credentials) - logger.debug(f'SSH Connection ready for {cluster_name!r}') + # If on Kubernetes, skip SSH check since the pods are guaranteed to be + # ready by the provisioner, and we use kubectl instead of SSH to run the + # commands and rsync on the pods. SSH will still be ready after a while + # for the users to SSH into the pod. + if cloud_name.lower() != 'kubernetes': + logger.debug( + f'\nWaiting for SSH to be available for {cluster_name!r} ...') + wait_for_ssh(cluster_info, ssh_credentials) + logger.debug(f'SSH Connection ready for {cluster_name!r}') vm_str = 'Instance' if cloud_name.lower() != 'kubernetes' else 'Pod' plural = '' if len(cluster_info.instances) == 1 else 's' verb = 'is' if len(cluster_info.instances) == 1 else 'are' @@ -497,31 +503,94 @@ def _post_provision_setup( **ssh_credentials) head_runner = runners[0] - status.update( - runtime_preparation_str.format(step=3, step_name='runtime')) - full_ray_setup = True - ray_port = constants.SKY_REMOTE_RAY_PORT - if not provision_record.is_instance_just_booted( - head_instance.instance_id): + def is_ray_cluster_healthy(ray_status_output: str, + expected_num_nodes: int) -> bool: + """Parse the output of `ray status` to get #active nodes. + + The output of `ray status` looks like: + Node status + --------------------------------------------------------------- + Active: + 1 node_291a8b849439ad6186387c35dc76dc43f9058108f09e8b68108cf9ec + 1 node_0945fbaaa7f0b15a19d2fd3dc48f3a1e2d7c97e4a50ca965f67acbfd + Pending: + (no pending nodes) + Recent failures: + (no failures) + """ + start = ray_status_output.find('Active:') + end = ray_status_output.find('Pending:', start) + if start == -1 or end == -1: + return False + num_active_nodes = 0 + for line in ray_status_output[start:end].split('\n'): + if line.strip() and not line.startswith('Active:'): + num_active_nodes += 1 + return num_active_nodes == expected_num_nodes + + def check_ray_port_and_cluster_healthy() -> Tuple[int, bool, bool]: + head_ray_needs_restart = True + ray_cluster_healthy = False + ray_port = constants.SKY_REMOTE_RAY_PORT + # Check if head node Ray is alive returncode, stdout, _ = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, stream_logs=False, require_outputs=True) - if returncode: - logger.debug('Ray cluster on head is not up. Restarting...') - else: - logger.debug('Ray cluster on head is up.') + if not returncode: ray_port = message_utils.decode_payload(stdout)['ray_port'] - full_ray_setup = bool(returncode) + logger.debug(f'Ray cluster on head is up with port {ray_port}.') + + head_ray_needs_restart = bool(returncode) + # This is a best effort check to see if the ray cluster has expected + # number of nodes connected. + ray_cluster_healthy = (not head_ray_needs_restart and + is_ray_cluster_healthy( + stdout, cluster_info.num_instances)) + return ray_port, ray_cluster_healthy, head_ray_needs_restart + + status.update( + runtime_preparation_str.format(step=3, step_name='runtime')) + + ray_port = constants.SKY_REMOTE_RAY_PORT + head_ray_needs_restart = True + ray_cluster_healthy = False + if (not provision_record.is_instance_just_booted( + head_instance.instance_id)): + # Check if head node Ray is alive + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + elif cloud_name.lower() == 'kubernetes': + timeout = 90 # 1.5-min maximum timeout + start = time.time() + while True: + # Wait until Ray cluster is ready + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + if ray_cluster_healthy: + logger.debug('Ray cluster is ready. Skip head and worker ' + 'node ray cluster setup.') + break + if time.time() - start > timeout: + # In most cases, the ray cluster will be ready after a few + # seconds. Trigger ray start on head or worker nodes to be + # safe, if the ray cluster is not ready after timeout. + break + logger.debug('Ray cluster is not ready yet, waiting for the ' + 'async setup to complete...') + time.sleep(1) - if full_ray_setup: + if head_ray_needs_restart: logger.debug('Starting Ray on the entire cluster.') instance_setup.start_ray_on_head_node( cluster_name.name_on_cloud, custom_resource=custom_resource, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + else: + logger.debug('Ray cluster on head is ready. Skip starting ray ' + 'cluster on head node.') # NOTE: We have to check all worker nodes to make sure they are all # healthy, otherwise we can only start Ray on newly started worker @@ -532,10 +601,13 @@ def _post_provision_setup( # if provision_record.is_instance_just_booted(inst.instance_id): # worker_ips.append(inst.public_ip) - if cluster_info.num_instances > 1: + # We don't need to restart ray on worker nodes if the ray cluster is + # already healthy, i.e. the head node has expected number of nodes + # connected to the ray cluster. + if cluster_info.num_instances > 1 and not ray_cluster_healthy: instance_setup.start_ray_on_worker_nodes( cluster_name.name_on_cloud, - no_restart=not full_ray_setup, + no_restart=not head_ray_needs_restart, custom_resource=custom_resource, # Pass the ray_port to worker nodes for backward compatibility # as in some existing clusters the ray_port is not dumped with @@ -544,6 +616,9 @@ def _post_provision_setup( ray_port=ray_port, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + elif ray_cluster_healthy: + logger.debug('Ray cluster is ready. Skip starting ray cluster on ' + 'worker nodes.') instance_setup.start_skylet_on_head_node(cluster_name.name_on_cloud, cluster_info, ssh_credentials) @@ -554,6 +629,7 @@ def _post_provision_setup( return cluster_info +@timeline.event def post_provision_runtime_setup( cloud_name: str, cluster_name: resources_utils.ClusterName, cluster_yaml: str, provision_record: provision_common.ProvisionRecord, diff --git a/sky/resources.py b/sky/resources.py index c2ed2205caf..5f70c6cb603 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -45,7 +45,7 @@ class Resources: """ # If any fields changed, increment the version. For backward compatibility, # modify the __setstate__ method to handle the old version. - _VERSION = 20 + _VERSION = 21 def __init__( self, @@ -1040,6 +1040,7 @@ def get_spot_str(self) -> str: def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to resource variables. @@ -1061,7 +1062,7 @@ def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, # Cloud specific variables cloud_specific_variables = self.cloud.make_deploy_resources_variables( - self, cluster_name, region, zones, dryrun) + self, cluster_name, region, zones, num_nodes, dryrun) # Docker run options docker_run_options = skypilot_config.get_nested( @@ -1606,6 +1607,27 @@ def __setstate__(self, state): '_cluster_config_overrides', None) if version < 20: + # Pre-0.7.0, we used 'kubernetes' as the default region for + # Kubernetes clusters. With the introduction of support for + # multiple contexts, we now set the region to the context name. + # Since we do not have information on which context the cluster + # was run in, we default it to the current active context. + legacy_region = clouds.Kubernetes().LEGACY_SINGLETON_REGION + original_cloud = state.get('_cloud', None) + original_region = state.get('_region', None) + if (isinstance(original_cloud, clouds.Kubernetes) and + original_region == legacy_region): + current_context = ( + kubernetes_utils.get_current_kube_config_context_name()) + state['_region'] = current_context + # Also update the image_id dict if it contains the old region + if isinstance(state['_image_id'], dict): + if legacy_region in state['_image_id']: + state['_image_id'][current_context] = ( + state['_image_id'][legacy_region]) + del state['_image_id'][legacy_region] + + if version < 21: self._cached_repr = None self.__dict__.update(state) diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index ce3c240626f..99820e63388 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -159,10 +159,7 @@ _sky_version = str(version.parse(sky.__version__)) RAY_STATUS = f'RAY_ADDRESS=127.0.0.1:{SKY_REMOTE_RAY_PORT} {SKY_RAY_CMD} status' -# Install ray and skypilot on the remote cluster if they are not already -# installed. {var} will be replaced with the actual value in -# backend_utils.write_cluster_config. -RAY_SKYPILOT_INSTALLATION_COMMANDS = ( +RAY_INSTALLATION_COMMANDS = ( 'mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;' # Disable the pip version check to avoid the warning message, which makes # the output hard to read. @@ -202,24 +199,31 @@ # Writes ray path to file if it does not exist or the file is empty. f'[ -s {SKY_RAY_PATH_FILE} ] || ' f'{{ {ACTIVATE_SKY_REMOTE_PYTHON_ENV} && ' - f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ' - # END ray package check and installation + f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ') + +SKYPILOT_WHEEL_INSTALLATION_COMMANDS = ( f'{{ {SKY_PIP_CMD} list | grep "skypilot " && ' '[ "$(cat ~/.sky/wheels/current_sky_wheel_hash)" == "{sky_wheel_hash}" ]; } || ' # pylint: disable=line-too-long f'{{ {SKY_PIP_CMD} uninstall skypilot -y; ' f'{SKY_PIP_CMD} install "$(echo ~/.sky/wheels/{{sky_wheel_hash}}/' f'skypilot-{_sky_version}*.whl)[{{cloud}}, remote]" && ' 'echo "{sky_wheel_hash}" > ~/.sky/wheels/current_sky_wheel_hash || ' - 'exit 1; }; ' - # END SkyPilot package check and installation + 'exit 1; }; ') +# Install ray and skypilot on the remote cluster if they are not already +# installed. {var} will be replaced with the actual value in +# backend_utils.write_cluster_config. +RAY_SKYPILOT_INSTALLATION_COMMANDS = ( + f'{RAY_INSTALLATION_COMMANDS} ' + f'{SKYPILOT_WHEEL_INSTALLATION_COMMANDS} ' # Only patch ray when the ray version is the same as the expected version. # The ray installation above can be skipped due to the existing ray cluster # for backward compatibility. In this case, we should not patch the ray # files. - f'{SKY_PIP_CMD} list | grep "ray " | grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null ' - f'&& {{ {SKY_PYTHON_CMD} -c "from sky.skylet.ray_patches import patch; patch()" ' - '|| exit 1; };') + f'{SKY_PIP_CMD} list | grep "ray " | ' + f'grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null && ' + f'{{ {SKY_PYTHON_CMD} -c ' + '"from sky.skylet.ray_patches import patch; patch()" || exit 1; }; ') # The name for the environment variable that stores SkyPilot user hash, which # is mainly used to make sure sky commands runs on a VM launched by SkyPilot diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index b79a66a2327..35e093e8afb 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -34,6 +34,8 @@ logger = sky_logging.init_logger(__name__) +LOG_FILE_START_STREAMING_AT = 'Waiting for task resources on ' + class _ProcessingArgs: """Arguments for processing logs.""" @@ -435,7 +437,7 @@ def tail_logs(job_id: Optional[int], time.sleep(_SKY_LOG_WAITING_GAP_SECONDS) status = job_lib.update_job_status([job_id], silent=True)[0] - start_stream_at = 'Waiting for task resources on ' + start_stream_at = LOG_FILE_START_STREAMING_AT # Explicitly declare the type to avoid mypy warning. lines: Iterable[str] = [] if follow and status in [ diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 01b08b6444f..89d1628ec11 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,9 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils +LOG_FILE_START_STREAMING_AT: str = ... + + class _ProcessingArgs: log_path: str stream_logs: bool diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 0e6f58cfd28..1cadd550f36 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -174,6 +174,8 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + # Line 'mkdir -p ~/.ssh ...': adding the key in the ssh config to allow interconnection for nodes in the cluster + # Line 'rm ~/.aws/credentials': explicitly remove the credentials file to be safe. This is to guard against the case where the credential files was uploaded once as `remote_identity` was not set in a previous launch. - mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} @@ -186,7 +188,10 @@ setup_commands: sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; {%- endif %} mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no\n IdentityFile ~/.ssh/sky-cluster-key\n IdentityFile ~/.ssh/id_rsa" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n IdentityFile ~/.ssh/sky-cluster-key\n IdentityFile ~/.ssh/id_rsa\n" >> ~/.ssh/config; - [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; + [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); + {%- if remote_identity != 'LOCAL_CREDENTIALS' %} + rm ~/.aws/credentials || true; + {%- endif %} # Command to start ray clusters are now placed in `sky.provision.instance_setup`. # We do not need to list it here anymore. diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 30918709d60..5ca0020cbfe 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -75,9 +75,6 @@ available_node_types: {%- if use_spot %} # optionally set priority to use Spot instances priority: Spot - # set a maximum price for spot instances if desired - # billingProfile: - # maxPrice: -1 {%- endif %} cloudInitSetupCommands: |- {%- for cmd in cloud_init_setup_commands %} diff --git a/sky/templates/kubernetes-port-forward-proxy-command.sh b/sky/templates/kubernetes-port-forward-proxy-command.sh index 0407209a77c..f8205c2393c 100644 --- a/sky/templates/kubernetes-port-forward-proxy-command.sh +++ b/sky/templates/kubernetes-port-forward-proxy-command.sh @@ -58,6 +58,11 @@ KUBECTL_ARGS=() if [ -n "$KUBE_CONTEXT" ]; then KUBECTL_ARGS+=("--context=$KUBE_CONTEXT") fi +# If context is not provided, it means we are using incluster auth. In this case, +# we need to set KUBECONFIG to /dev/null to avoid using kubeconfig file. +if [ -z "$KUBE_CONTEXT" ]; then + KUBECTL_ARGS+=("--kubeconfig=/dev/null") +fi if [ -n "$KUBE_NAMESPACE" ]; then KUBECTL_ARGS+=("--namespace=$KUBE_NAMESPACE") fi diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index 99a4b6f601a..71d2d3491f6 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -222,7 +222,9 @@ provider: - protocol: TCP port: 22 targetPort: 22 - # Service that maps to the head node of the Ray cluster. + # Service that maps to the head node of the Ray cluster, so that the + # worker nodes can find the head node using + # {{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local - apiVersion: v1 kind: Service metadata: @@ -235,18 +237,12 @@ provider: # names. name: {{cluster_name_on_cloud}}-head spec: + # Create a headless service so that the head node can be reached by + # the worker nodes with any port number. + clusterIP: None # This selector must match the head node pod's selector below. selector: component: {{cluster_name_on_cloud}}-head - ports: - - name: client - protocol: TCP - port: 10001 - targetPort: 10001 - - name: dashboard - protocol: TCP - port: 8265 - targetPort: 8265 # Specify the pod type for the ray head node (as configured below). head_node_type: ray_head_default @@ -280,7 +276,6 @@ available_node_types: # serviceAccountName: skypilot-service-account serviceAccountName: {{k8s_service_account_name}} automountServiceAccountToken: {{k8s_automount_sa_token}} - restartPolicy: Never # Add node selector if GPU/TPUs are requested: @@ -322,18 +317,158 @@ available_node_types: - name: ray-node imagePullPolicy: IfNotPresent image: {{image_id}} + env: + - name: SKYPILOT_POD_NODE_TYPE + valueFrom: + fieldRef: + fieldPath: metadata.labels['ray-node-type'] + {% for key, value in k8s_env_vars.items() if k8s_env_vars is not none %} + - name: {{ key }} + value: {{ value }} + {% endfor %} # Do not change this command - it keeps the pod alive until it is # explicitly killed. command: ["/bin/bash", "-c", "--"] args: - | + # For backwards compatibility, we put a marker file in the pod + # to indicate that the pod is running with the changes introduced + # in project nimbus: https://github.com/skypilot-org/skypilot/pull/4393 + # TODO: Remove this marker file and it's usage in setup_commands + # after v0.10.0 release. + touch /tmp/skypilot_is_nimbus + # Helper function to conditionally use sudo + # TODO(zhwu): consolidate the two prefix_cmd and sudo replacements prefix_cmd() { if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } + [ $(id -u) -eq 0 ] && function sudo() { "$@"; } || true; + + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") - # Run apt update in background and log to a file + # STEP 1: Run apt update, install missing packages, and set up ssh. ( + ( DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update > /tmp/apt-update.log 2>&1 || \ echo "Warning: apt-get update failed. Continuing anyway..." >> /tmp/apt-update.log + PACKAGES="rsync curl netcat gcc patch pciutils fuse openssh-server"; + + # Separate packages into two groups: packages that are installed first + # so that curl and rsync are available sooner to unblock the following + # conda installation and rsync. + set -e + INSTALL_FIRST=""; + MISSING_PACKAGES=""; + for pkg in $PACKAGES; do + if [ "$pkg" == "netcat" ]; then + if ! dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; then + INSTALL_FIRST="$INSTALL_FIRST netcat-openbsd"; + fi + elif ! dpkg -l | grep -q "^ii $pkg "; then + if [ "$pkg" == "curl" ] || [ "$pkg" == "rsync" ]; then + INSTALL_FIRST="$INSTALL_FIRST $pkg"; + else + MISSING_PACKAGES="$MISSING_PACKAGES $pkg"; + fi + fi + done; + if [ ! -z "$INSTALL_FIRST" ]; then + echo "Installing core packages: $INSTALL_FIRST"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $INSTALL_FIRST; + fi; + # SSH and other packages are not necessary, so we disable set -e + set +e + + if [ ! -z "$MISSING_PACKAGES" ]; then + echo "Installing missing packages: $MISSING_PACKAGES"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $MISSING_PACKAGES; + fi; + $(prefix_cmd) mkdir -p /var/run/sshd; + $(prefix_cmd) sed -i "s/PermitRootLogin prohibit-password/PermitRootLogin yes/" /etc/ssh/sshd_config; + $(prefix_cmd) sed "s@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g" -i /etc/pam.d/sshd; + cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; + $(prefix_cmd) mkdir -p ~/.ssh; + $(prefix_cmd) chown -R $(whoami) ~/.ssh; + $(prefix_cmd) chmod 700 ~/.ssh; + $(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys; + $(prefix_cmd) chmod 644 ~/.ssh/authorized_keys; + $(prefix_cmd) service ssh restart; + $(prefix_cmd) sed -i "s/mesg n/tty -s \&\& mesg n/" ~/.profile; + ) > /tmp/${STEPS[0]}.log 2>&1 || { + echo "Error: ${STEPS[0]} failed. Continuing anyway..." > /tmp/${STEPS[0]}.failed + cat /tmp/${STEPS[0]}.log + exit 1 + } + ) & + + # STEP 2: Install conda, ray and skypilot (for dependencies); start + # ray cluster. + ( + ( + set -e + mkdir -p ~/.sky + # Wait for `curl` package to be installed before installing conda + # and ray. + until dpkg -l | grep -q "^ii curl "; do + sleep 0.1 + echo "Waiting for curl package to be installed..." + done + {{ conda_installation_commands }} + {{ ray_installation_commands }} + ~/skypilot-runtime/bin/python -m pip install skypilot[kubernetes,remote] + touch /tmp/ray_skypilot_installation_complete + echo "=== Ray and skypilot installation completed ===" + + # Disable set -e, as we have some commands that are ok to fail + # after the ray start. + # TODO(zhwu): this is a hack, we should fix the commands that are + # ok to fail. + if [ "$SKYPILOT_POD_NODE_TYPE" == "head" ]; then + set +e + {{ ray_head_start_command }} + else + # Start ray worker on the worker pod. + # Wait until the head pod is available with an IP address + export SKYPILOT_RAY_HEAD_IP="{{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local" + export SKYPILOT_RAY_PORT={{skypilot_ray_port}} + # Wait until the ray cluster is started on the head pod + until dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; do + sleep 0.1 + echo "Waiting for netcat package to be installed..." + done + until nc -z -w 1 ${SKYPILOT_RAY_HEAD_IP} ${SKYPILOT_RAY_PORT}; do + sleep 0.1 + done + + set +e + {{ ray_worker_start_command }} + fi + ) > /tmp/${STEPS[1]}.log 2>&1 || { + echo "Error: ${STEPS[1]} failed. Continuing anyway..." > /tmp/${STEPS[1]}.failed + cat /tmp/${STEPS[1]}.log + exit 1 + } + ) & + + + # STEP 3: Set up environment variables; this should be relatively fast. + ( + ( + set -e + if [ $(id -u) -eq 0 ]; then + echo 'alias sudo=""' >> ~/.bashrc; echo succeed; + else + if command -v sudo >/dev/null 2>&1; then + timeout 2 sudo -l >/dev/null 2>&1 && echo succeed || { echo 52; exit 52; }; + else + { echo 52; exit 52; }; + fi; + fi; + printenv | while IFS='=' read -r key value; do echo "export $key=\"$value\""; done > ~/container_env_var.sh && $(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh + ) > /tmp/${STEPS[2]}.log 2>&1 || { + echo "Error: ${STEPS[2]} failed. Continuing anyway..." > /tmp/${STEPS[2]}.failed + cat /tmp/${STEPS[2]}.log + exit 1 + } ) & function mylsof { p=$(for pid in /proc/{0..9}*; do i=$(basename "$pid"); for file in "$pid"/fd/*; do link=$(readlink -e "$file"); if [ "$link" = "$1" ]; then echo "$i"; fi; done; done); echo "$p"; }; @@ -441,42 +576,50 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + # Line 'for step in ..': check if any failure indicator exists for the setup done in pod args and print the error message. This is only a best effort, as the + # commands in pod args are asynchronous and we cannot guarantee the failure indicators are created before the setup commands finish. - | - PACKAGES="gcc patch pciutils rsync fuse curl"; - MISSING_PACKAGES=""; - for pkg in $PACKAGES; do - if ! dpkg -l | grep -q "^ii $pkg "; then - MISSING_PACKAGES="$MISSING_PACKAGES $pkg"; - fi - done; - if [ ! -z "$MISSING_PACKAGES" ]; then - echo "Installing missing packages: $MISSING_PACKAGES"; - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y $MISSING_PACKAGES; - fi; mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} {%- endfor %} - {{ conda_installation_commands }} - {{ ray_skypilot_installation_commands }} + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") + start_epoch=$(date +%s); + echo "=== Logs for asynchronous ray and skypilot installation ==="; + if [ -f /tmp/skypilot_is_nimbus ]; then + echo "=== Logs for asynchronous ray and skypilot installation ==="; + [ -f /tmp/ray_skypilot_installation_complete ] && cat /tmp/${STEPS[1]}.log || + { tail -f -n +1 /tmp/${STEPS[1]}.log & TAIL_PID=$!; echo "Tail PID: $TAIL_PID"; until [ -f /tmp/ray_skypilot_installation_complete ]; do sleep 0.5; done; kill $TAIL_PID || true; }; + [ -f /tmp/${STEPS[1]}.failed ] && { echo "Error: ${STEPS[1]} failed. Exiting."; exit 1; } || true; + fi + end_epoch=$(date +%s); + echo "=== Ray and skypilot dependencies installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); + {{ skypilot_wheel_installation_commands }} + end_epoch=$(date +%s); + echo "=== Skypilot wheel installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); sudo touch ~/.sudo_as_admin_successful; sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf'; sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no\n IdentityFile ~/.ssh/sky-cluster-key\n IdentityFile ~/.ssh/id_rsa" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n IdentityFile ~/.ssh/sky-cluster-key\n IdentityFile ~/.ssh/id_rsa\n" >> ~/.ssh/config; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; - {% if tpu_requested %} - # The /tmp/tpu_logs directory is where TPU-related logs, such as logs from - # the TPU runtime, are written. These capture runtime information about the - # TPU execution, including any warnings, errors, or general activity of - # the TPU driver. By default, the /tmp/tpu_logs directory is created with - # 755 permissions, and the user of the provisioned pod is not necessarily - # a root. Hence, we need to update the write permission so the logs can be - # properly written. - # TODO(Doyoung): Investigate to see why TPU workload fails to run without - # execution permission, such as granting 766 to log file. Check if it's a - # must and see if there's a workaround to grant minimum permission. - - sudo chmod 777 /tmp/tpu_logs; - {% endif %} + end_epoch=$(date +%s); + echo "=== Setup system configs and fuse completed in $(($end_epoch - $start_epoch)) secs ==="; + for step in $STEPS; do [ -f "/tmp/${step}.failed" ] && { echo "Error: /tmp/${step}.failed found:"; cat /tmp/${step}.log; exit 1; } || true; done; + {% if tpu_requested %} + # The /tmp/tpu_logs directory is where TPU-related logs, such as logs from + # the TPU runtime, are written. These capture runtime information about the + # TPU execution, including any warnings, errors, or general activity of + # the TPU driver. By default, the /tmp/tpu_logs directory is created with + # 755 permissions, and the user of the provisioned pod is not necessarily + # a root. Hence, we need to update the write permission so the logs can be + # properly written. + # TODO(Doyoung): Investigate to see why TPU workload fails to run without + # execution permission, such as granting 766 to log file. Check if it's a + # must and see if there's a workaround to grant minimum permission. + sudo chmod 777 /tmp/tpu_logs; + {% endif %} # Format: `REMOTE_PATH : LOCAL_PATH` file_mounts: { diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 7842eaf2d27..e166c5f2306 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -764,6 +764,10 @@ def run( ] if self.context: kubectl_args += ['--context', self.context] + # If context is none, it means we are using incluster auth. In this + # case, need to set KUBECONFIG to /dev/null to avoid using kubeconfig. + if self.context is None: + kubectl_args += ['--kubeconfig', '/dev/null'] kubectl_args += [self.pod_name] if ssh_mode == SshMode.LOGIN: assert isinstance(cmd, list), 'cmd must be a list for login mode.' diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 472fd65d72f..df6852dec1e 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -24,6 +24,7 @@ from sky.jobs import constants as managed_job_constants from sky.serve import constants as serve_constants from sky.skylet import constants +from sky.skylet import log_lib from sky.utils import common from sky.utils import common_utils from sky.utils import env_options @@ -379,11 +380,19 @@ def download_and_stream_latest_job_log( else: log_dir = list(log_dirs.values())[0] log_file = os.path.join(log_dir, 'run.log') - # Print the logs to the console. + # TODO(zhwu): refactor this into log_utils, along with the + # refactoring for the log_lib.tail_logs. try: with open(log_file, 'r', encoding='utf-8') as f: - print(f.read()) + # Stream the logs to the console without reading the whole + # file into memory. + start_streaming = False + for line in f: + if log_lib.LOG_FILE_START_STREAMING_AT in line: + start_streaming = True + if start_streaming: + print(line, end='', flush=True) except FileNotFoundError: logger.error('Failed to find the logs for the user ' f'program at {log_file}.') diff --git a/sky/utils/kubernetes/generate_kubeconfig.sh b/sky/utils/kubernetes/generate_kubeconfig.sh index 8d363370597..4ed27b62e1e 100755 --- a/sky/utils/kubernetes/generate_kubeconfig.sh +++ b/sky/utils/kubernetes/generate_kubeconfig.sh @@ -12,6 +12,7 @@ # * Specify SKYPILOT_NAMESPACE env var to override the default namespace where the service account is created. # * Specify SKYPILOT_SA_NAME env var to override the default service account name. # * Specify SKIP_SA_CREATION=1 to skip creating the service account and use an existing one +# * Specify SUPER_USER=1 to create a service account with cluster-admin permissions # # Usage: # # Create "sky-sa" service account with minimal permissions in "default" namespace and generate kubeconfig @@ -22,6 +23,9 @@ # # # Use an existing service account "my-sa" in "my-namespace" namespace and generate kubeconfig # $ SKIP_SA_CREATION=1 SKYPILOT_SA_NAME=my-sa SKYPILOT_NAMESPACE=my-namespace ./generate_kubeconfig.sh +# +# # Create "sky-sa" service account with cluster-admin permissions in "default" namespace +# $ SUPER_USER=1 ./generate_kubeconfig.sh set -eu -o pipefail @@ -29,9 +33,11 @@ set -eu -o pipefail # use default. SKYPILOT_SA=${SKYPILOT_SA_NAME:-sky-sa} NAMESPACE=${SKYPILOT_NAMESPACE:-default} +SUPER_USER=${SUPER_USER:-0} echo "Service account: ${SKYPILOT_SA}" echo "Namespace: ${NAMESPACE}" +echo "Super user permissions: ${SUPER_USER}" # Set OS specific values. if [[ "$OSTYPE" == "linux-gnu" ]]; then @@ -47,8 +53,43 @@ fi # If the user has set SKIP_SA_CREATION=1, skip creating the service account. if [ -z ${SKIP_SA_CREATION+x} ]; then - echo "Creating the Kubernetes Service Account with minimal RBAC permissions." - kubectl apply -f - <&2 context_lower=$(echo "$context" | tr '[:upper:]' '[:lower:]') shift if [ -z "$context" ] || [ "$context_lower" = "none" ]; then - kubectl exec -i $pod -n $namespace -- "$@" + # If context is none, it means we are using incluster auth. In this case, + # use need to set KUBECONFIG to /dev/null to avoid using kubeconfig file. + kubectl exec -i $pod -n $namespace --kubeconfig=/dev/null -- "$@" else kubectl exec -i $pod -n $namespace --context=$context -- "$@" fi diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 1bba5ddfdfe..7d716134272 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -666,6 +666,7 @@ class RemoteIdentityOptions(enum.Enum): """ LOCAL_CREDENTIALS = 'LOCAL_CREDENTIALS' SERVICE_ACCOUNT = 'SERVICE_ACCOUNT' + NO_UPLOAD = 'NO_UPLOAD' def get_default_remote_identity(cloud: str) -> str: @@ -686,7 +687,14 @@ def get_default_remote_identity(cloud: str) -> str: _REMOTE_IDENTITY_SCHEMA_KUBERNETES = { 'remote_identity': { - 'type': 'string' + 'anyOf': [{ + 'type': 'string' + }, { + 'type': 'object', + 'additionalProperties': { + 'type': 'string' + } + }] }, } diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 7ef8a6a8534..19903467edf 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -2,9 +2,10 @@ from multiprocessing import pool import os import random +import resource import subprocess import time -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import colorama import psutil @@ -18,6 +19,8 @@ logger = sky_logging.init_logger(__name__) +_fd_limit_warning_shown = False + @timeline.event def run(cmd, **kwargs): @@ -43,12 +46,54 @@ def run_no_outputs(cmd, **kwargs): **kwargs) -def get_parallel_threads() -> int: - """Returns the number of idle CPUs.""" +def _get_thread_multiplier(cloud_str: Optional[str] = None) -> int: + # If using Kubernetes, we use 4x the number of cores. + if cloud_str and cloud_str.lower() == 'kubernetes': + return 4 + return 1 + + +def get_max_workers_for_file_mounts(common_file_mounts: Dict[str, str], + cloud_str: Optional[str] = None) -> int: + global _fd_limit_warning_shown + fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + + # Raise warning for low fd_limit (only once) + if fd_limit < 1024 and not _fd_limit_warning_shown: + logger.warning( + f'Open file descriptor limit ({fd_limit}) is low. File sync to ' + 'remote clusters may be slow. Consider increasing the limit using ' + '`ulimit -n ` or modifying system limits.') + _fd_limit_warning_shown = True + + fd_per_rsync = 5 + for src in common_file_mounts.values(): + if os.path.isdir(src): + # Assume that each file/folder under src takes 5 file descriptors + # on average. + fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) + + # Reserve some file descriptors for the system and other processes + fd_reserve = 100 + + max_workers = (fd_limit - fd_reserve) // fd_per_rsync + # At least 1 worker, and avoid too many workers overloading the system. + num_threads = get_parallel_threads(cloud_str) + max_workers = min(max(max_workers, 1), num_threads) + logger.debug(f'Using {max_workers} workers for file mounts.') + return max_workers + + +def get_parallel_threads(cloud_str: Optional[str] = None) -> int: + """Returns the number of threads to use for parallel execution. + + Args: + cloud_str: The cloud + """ cpu_count = os.cpu_count() if cpu_count is None: cpu_count = 1 - return max(4, cpu_count - 1) + return max(4, cpu_count - 1) * _get_thread_multiplier(cloud_str) def run_in_parallel(func: Callable, diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index f7244bd9ab2..4db9bd149b2 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -79,11 +79,9 @@ def event(name_or_fn: Union[str, Callable], message: Optional[str] = None): class FileLockEvent: """Serve both as a file lock and event for the lock.""" - def __init__(self, lockfile: Union[str, os.PathLike]): + def __init__(self, lockfile: Union[str, os.PathLike], timeout: float = -1): self._lockfile = lockfile - # TODO(mraheja): remove pylint disabling when filelock version updated - # pylint: disable=abstract-class-instantiated - self._lock = filelock.FileLock(self._lockfile) + self._lock = filelock.FileLock(self._lockfile, timeout) self._hold_lock_event = Event(f'[FileLock.hold]:{self._lockfile}') def acquire(self): diff --git a/tests/kubernetes/README.md b/tests/kubernetes/README.md index 7c5ed7586ff..e15f593e006 100644 --- a/tests/kubernetes/README.md +++ b/tests/kubernetes/README.md @@ -1,10 +1,10 @@ # SkyPilot Kubernetes Development Scripts -This directory contains useful scripts and notes for developing SkyPilot on Kubernetes. +This directory contains useful scripts and notes for developing SkyPilot on Kubernetes. ## Building and pushing SkyPilot image -We maintain a container image that has all basic SkyPilot dependencies installed. +We maintain a container image that has all basic SkyPilot dependencies installed. This image is hosted at `us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot:latest`. To build this image locally and optionally push to the SkyPilot registry, run: @@ -18,10 +18,10 @@ To build this image locally and optionally push to the SkyPilot registry, run: ``` ## Running a local development cluster -We use (kind)[https://kind.sigs.k8s.io/] to run a local Kubernetes cluster +We use (kind)[https://kind.sigs.k8s.io/] to run a local Kubernetes cluster for development. To create a local development cluster, run: -```bash +```bash sky local up ``` @@ -50,7 +50,13 @@ curl --header "Content-Type: application/json-patch+json" \ ```bash PROJECT_ID=$(gcloud config get-value project) CLUSTER_NAME=testclusterromil - gcloud beta container --project "${PROJECT_ID}" clusters create "${CLUSTER_NAME}" --zone "us-central1-c" --no-enable-basic-auth --cluster-version "1.29.1-gke.1589020" --release-channel "regular" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-t4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/${PROJECT_ID}/global/networks/default" --subnetwork "projects/${PROJECT_ID}/regions/us-central1/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "110" --security-posture=standard --workload-vulnerability-scanning=disabled --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "v100" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-v100,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "largecpu" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "n1-standard-16" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "l4" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "g2-standard-4" --accelerator "type=nvidia-l4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" + REGION=us-central1-c + GKE_VERSION=$(gcloud container get-server-config \ + --region=${REGION} \ + --flatten=channels \ + --filter="channels.channel=REGULAR" \ + --format="value(channels.defaultVersion)") + gcloud beta container --project "${PROJECT_ID}" clusters create "${CLUSTER_NAME}" --zone "${REGION}" --no-enable-basic-auth --cluster-version "${GKE_VERSION}" --release-channel "regular" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-t4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/${PROJECT_ID}/global/networks/default" --subnetwork "projects/${PROJECT_ID}/regions/${REGION%-*}/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "110" --security-posture=standard --workload-vulnerability-scanning=disabled --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools create "v100" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-v100,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools create "largecpu" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "n1-standard-16" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools create "l4" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "g2-standard-4" --accelerator "type=nvidia-l4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" ``` 2. Get the kubeconfig for your cluster and place it in `~/.kube/config`: ```bash @@ -65,7 +71,7 @@ curl --header "Content-Type: application/json-patch+json" \ kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded-latest.yaml - + # If using Ubuntu based nodes: kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/ubuntu/daemonset-preloaded.yaml @@ -123,6 +129,6 @@ NOTE - If are using nodeport networking, make sure port 32100 is open in your no NOTE - If are using nodeport networking, make sure port 32100 is open in your EKS cluster's default security group. ## Other useful scripts -`scripts` directory contains other useful scripts for development, including -Kubernetes dashboard, ray yaml for testing the SkyPilot Kubernetes node provider +`scripts` directory contains other useful scripts for development, including +Kubernetes dashboard, ray yaml for testing the SkyPilot Kubernetes node provider and more. diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 727fb704a18..73d4963d1f2 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -25,6 +25,7 @@ # Change cloud for generic tests to aws # > pytest tests/test_smoke.py --generic-cloud aws +import enum import inspect import json import os @@ -49,7 +50,6 @@ from sky import jobs from sky import serve from sky import skypilot_config -from sky.adaptors import azure from sky.adaptors import cloudflare from sky.adaptors import ibm from sky.clouds import AWS @@ -95,6 +95,166 @@ 'sleep 10; s=$(sky jobs queue);' 'echo "Waiting for job to stop RUNNING"; echo "$s"; done') +# Cluster functions +_ALL_JOB_STATUSES = "|".join([status.value for status in sky.JobStatus]) +_ALL_CLUSTER_STATUSES = "|".join([status.value for status in sky.ClusterStatus]) +_ALL_MANAGED_JOB_STATUSES = "|".join( + [status.value for status in sky.ManagedJobStatus]) + + +def _statuses_to_str(statuses: List[enum.Enum]): + """Convert a list of enums to a string with all the values separated by |.""" + assert len(statuses) > 0, 'statuses must not be empty' + if len(statuses) > 1: + return '(' + '|'.join([status.value for status in statuses]) + ')' + else: + return statuses[0].value + + +_WAIT_UNTIL_CLUSTER_STATUS_CONTAINS = ( + # A while loop to wait until the cluster status + # becomes certain status, with timeout. + 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster status \'{cluster_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky status {cluster_name} --refresh | ' + 'awk "/^{cluster_name}/ ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_CLUSTER_STATUSES + + ')$/) print \$i}}"); ' + 'if [[ "$current_status" =~ {cluster_status} ]]; ' + 'then echo "Target cluster status {cluster_status} reached."; break; fi; ' + 'echo "Waiting for cluster status to become {cluster_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_status_contains( + cluster_name: str, cluster_status: List[sky.ClusterStatus], + timeout: int): + return _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.format( + cluster_name=cluster_name, + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +def _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard: str, cluster_status: List[sky.ClusterStatus], + timeout: int): + wait_cmd = _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.replace( + 'sky status {cluster_name}', + 'sky status "{cluster_name}"').replace('awk "/^{cluster_name}/', + 'awk "/^{cluster_name_awk}/') + return wait_cmd.format(cluster_name=cluster_name_wildcard, + cluster_name_awk=cluster_name_wildcard.replace( + '*', '.*'), + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +_WAIT_UNTIL_CLUSTER_IS_NOT_FOUND = ( + # A while loop to wait until the cluster is not found or timeout + 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster to be removed"; exit 1; ' + 'fi; ' + 'if sky status -r {cluster_name}; sky status {cluster_name} | grep "{cluster_name} not found"; then ' + ' echo "Cluster {cluster_name} successfully removed."; break; ' + 'fi; ' + 'echo "Waiting for cluster {cluster_name} to be removed..."; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_is_not_found(cluster_name: str, timeout: int): + return _WAIT_UNTIL_CLUSTER_IS_NOT_FOUND.format(cluster_name=cluster_name, + timeout=timeout) + + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID = ( + # A while loop to wait until the job status + # contains certain status, with timeout. + 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for job status \'{job_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky queue {cluster_name} | ' + 'awk "\\$1 == \\"{job_id}\\" ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_JOB_STATUSES + + ')$/) print \$i}}"); ' + 'found=0; ' # Initialize found variable outside the loop + 'while read -r line; do ' # Read line by line + ' if [[ "$line" =~ {job_status} ]]; then ' # Check each line + ' echo "Target job status {job_status} reached."; ' + ' found=1; ' + ' break; ' # Break inner loop + ' fi; ' + 'done <<< "$current_status"; ' + 'if [ "$found" -eq 1 ]; then break; fi; ' # Break outer loop if match found + 'echo "Waiting for job status to contains {job_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "\\$2 == \\"{job_name}\\"') + + +def _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name: str, job_id: str, job_status: List[sky.JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.format( + cluster_name=cluster_name, + job_id=job_id, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name: str, job_status: List[sky.JobStatus], timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB.format( + cluster_name=cluster_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_matching_job_name( + cluster_name: str, job_name: str, job_status: List[sky.JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + cluster_name=cluster_name, + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# Managed job functions + +_WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.replace( + 'sky queue {cluster_name}', 'sky jobs queue').replace( + 'awk "\\$2 == \\"{job_name}\\"', + 'awk "\\$2 == \\"{job_name}\\" || \\$3 == \\"{job_name}\\"').replace( + _ALL_JOB_STATUSES, _ALL_MANAGED_JOB_STATUSES) + + +def _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name: str, job_status: List[sky.JobStatus], timeout: int): + return _WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# After the timeout, the cluster will stop if autostop is set, and our check +# should be more than the timeout. To address this, we extend the timeout by +# _BUMP_UP_SECONDS before exiting. +_BUMP_UP_SECONDS = 35 + DEFAULT_CMD_TIMEOUT = 15 * 60 @@ -401,7 +561,6 @@ def test_launch_fast_with_autostop(generic_cloud: str): # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure # the VM is stopped. autostop_timeout = 600 if generic_cloud == 'azure' else 250 - test = Test( 'test_launch_fast_with_autostop', [ @@ -409,11 +568,15 @@ def test_launch_fast_with_autostop(generic_cloud: str): f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 1 --status', f'sky status -r {name} | grep UP', - f'sleep {autostop_timeout}', # Ensure cluster is stopped - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', - + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), + # Even the cluster is stopped, cloud platform may take a while to + # delete the VM. + f'sleep {_BUMP_UP_SECONDS}', # Launch again. Do full output validation - we expect the cluster to re-launch f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 2 --status', @@ -479,10 +642,19 @@ def test_aws_with_ssh_proxy_command(): f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi', # Wait other tests to create the job controller first, so that # the job controller is not launched with proxy command. - 'timeout 300s bash -c "until sky status sky-jobs-controller* | grep UP; do sleep 1; done"', + _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard='sky-jobs-controller-*', + cluster_status=[sky.ClusterStatus.UP], + timeout=300), f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name} | grep "STARTING\|RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.SUCCEEDED, + sky.ManagedJobStatus.RUNNING, + sky.ManagedJobStatus.STARTING + ], + timeout=300), ], f'sky down -y {name} jump-{name}; sky jobs cancel -y -n {name}', # Make sure this test runs on local API server. @@ -854,6 +1026,12 @@ def test_clone_disk_aws(): f'sky launch -y -c {name} --cloud aws --region us-east-2 --retry-until-up "echo hello > ~/user_file.txt"', f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true', f'sky stop {name} -y', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=60), + # Wait for EC2 instance to be in stopped state. + # TODO: event based wait. 'sleep 60', f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', @@ -900,8 +1078,8 @@ def test_gcp_mig(): # Check MIG exists. f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"', f'sky autostop -i 0 --down -y {name}', - 'sleep 120', - f'sky status -r {name}; sky status {name} | grep "No existing clusters."', + _get_cmd_wait_until_cluster_is_not_found(cluster_name=name, + timeout=120), f'gcloud compute instance-templates list | grep "sky-it-{name}"', # Launch again with the same region. The original instance template # should be removed. @@ -968,8 +1146,10 @@ def test_custom_default_conda_env(generic_cloud: str): f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', f'sky logs {name} 2 --status', f'sky autostop -y -i 0 {name}', - 'sleep 60', - f'sky status -r {name} | grep "STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=80), f'sky start -y {name}', f'sky logs {name} 2 --no-follow | grep -E "myenv\\s+\\*"', f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', @@ -990,7 +1170,10 @@ def test_stale_job(generic_cloud: str): f'sky launch -y -c {name} --cloud {generic_cloud} "echo hi"', f'sky exec {name} -d "echo start; sleep 10000"', f'sky stop {name} -y', - 'sleep 100', # Ensure this is large enough, else GCP leaks. + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=100), f'sky start {name} -y', f'sky logs {name} 1 --status', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', @@ -1018,13 +1201,18 @@ def test_aws_stale_job_manual_restart(): '--output text`; ' f'aws ec2 stop-instances --region {region} ' '--instance-ids $id', - 'sleep 40', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=40), f'sky launch -c {name} -y "echo hi"', f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. - f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[sky.JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS), ], f'sky down -y {name}', ) @@ -1054,8 +1242,10 @@ def test_gcp_stale_job_manual_restart(): f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. - f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[sky.JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS) ], f'sky down -y {name}', ) @@ -1073,6 +1263,10 @@ def test_env_check(generic_cloud: str): [ f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup examples/env_check.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. + # Test --detach-setup with only setup. + f'sky launch -y -c {name} --detach-setup tests/test_yamls/test_only_setup.yaml', + f'sky logs {name} 2 --status', + f'sky logs {name} 2 | grep "hello world"', ], f'sky down -y {name}', timeout=total_timeout_minutes * 60, @@ -1732,6 +1926,7 @@ def test_large_job_queue(generic_cloud: str): f'for i in `seq 1 75`; do sky exec {name} -n {name}-$i -d "echo $i; sleep 100000000"; done', f'sky cancel -y {name} 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16', 'sleep 90', + # Each job takes 0.5 CPU and the default VM has 8 CPUs, so there should be 8 / 0.5 = 16 jobs running. # The first 16 jobs are canceled, so there should be 75 - 32 = 43 jobs PENDING. f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep PENDING | wc -l | grep 43', @@ -1873,7 +2068,13 @@ def test_multi_echo(generic_cloud: str): f'until sky logs {name} 32 --status; do echo "Waiting for job 32 to finish..."; sleep 1; done', ] + # Ensure jobs succeeded. - [f'sky logs {name} {i + 1} --status' for i in range(32)] + + [ + _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name=name, + job_id=i + 1, + job_status=[sky.JobStatus.SUCCEEDED], + timeout=120) for i in range(32) + ] + # Ensure monitor/autoscaler didn't crash on the 'assert not # unfulfilled' error. If process not found, grep->ssh returns 1. [f'ssh {name} \'ps aux | grep "[/]"monitor.py\''], @@ -2445,12 +2646,19 @@ def test_gcp_start_stop(): f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit. f'sky logs {name} 3 --status', # Ensure the job succeeded. f'sky stop -y {name}', - f'sleep 20', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=40), f'sky start -y {name} -i 1', f'sky exec {name} examples/gcp_start_stop.yaml', f'sky logs {name} 4 --status', # Ensure the job succeeded. - 'sleep 180', - f'sky status -r {name} | grep "INIT\|STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT + ], + timeout=200), ], f'sky down -y {name}', ) @@ -2473,9 +2681,13 @@ def test_azure_start_stop(): f'sky start -y {name} -i 1', f'sky exec {name} examples/azure_start_stop.yaml', f'sky logs {name} 3 --status', # Ensure the job succeeded. - 'sleep 260', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' - f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT + ], + timeout=280) + + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}', ], f'sky down -y {name}', timeout=30 * 60, # 30 mins @@ -2511,8 +2723,10 @@ def test_autostop(generic_cloud: str): f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), # Ensure the cluster is UP and the autostop setting is reset ('-'). f'sky start -y {name}', @@ -2528,8 +2742,10 @@ def test_autostop(generic_cloud: str): f'sky autostop -y {name} -i 1', # Should restart the timer. 'sleep 40', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), # Test restarting the idleness timer via exec: f'sky start -y {name}', @@ -2538,9 +2754,10 @@ def test_autostop(generic_cloud: str): 'sleep 45', # Almost reached the threshold. f'sky exec {name} echo hi', # Should restart the timer. 'sleep 45', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout + _BUMP_UP_SECONDS), ], f'sky down -y {name}', timeout=total_timeout_minutes * 60, @@ -2759,15 +2976,19 @@ def test_stop_gcp_spot(): f'sky exec {name} -- ls myfile', f'sky logs {name} 2 --status', f'sky autostop {name} -i0 -y', - 'sleep 90', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=90), f'sky start {name} -y', f'sky exec {name} -- ls myfile', f'sky logs {name} 3 --status', # -i option at launch should go through: f'sky launch -c {name} -i0 -y', - 'sleep 120', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=120), ], f'sky down -y {name}', ) @@ -2787,14 +3008,27 @@ def test_managed_jobs(generic_cloud: str): [ f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d', f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[ + sky.ManagedJobStatus.PENDING, + sky.ManagedJobStatus.SUBMITTED, + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60), + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ + sky.ManagedJobStatus.PENDING, + sky.ManagedJobStatus.SUBMITTED, + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60), f'sky jobs cancel -y -n {name}-1', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep CANCELLED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=230), # Test the functionality for logging. f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"', f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"', @@ -2865,9 +3099,11 @@ def test_managed_jobs_failed_setup(generic_cloud: str): 'managed_jobs_failed_setup', [ f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml', - 'sleep 330', # Make sure the job failed quickly. - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.FAILED_SETUP], + timeout=330 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -2890,7 +3126,10 @@ def test_managed_jobs_pipeline_failed_setup(generic_cloud: str): 'managed_jobs_pipeline_failed_setup', [ f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml', - 'sleep 600', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.FAILED_SETUP], + timeout=600), # Make sure the job failed quickly. f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', # Task 0 should be SUCCEEDED. @@ -2924,8 +3163,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): 'managed_jobs_recovery_aws', [ f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=600), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -2935,8 +3176,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2964,15 +3207,19 @@ def test_managed_jobs_recovery_gcp(): 'managed_jobs_recovery_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2995,8 +3242,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): 'managed_jobs_pipeline_recovery_aws', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -3015,8 +3264,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3046,8 +3297,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): 'managed_jobs_pipeline_recovery_gcp', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -3058,8 +3311,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3085,8 +3340,13 @@ def test_managed_jobs_recovery_default_resources(generic_cloud: str): 'managed-spot-recovery-default-resources', [ f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|RECOVERING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.RUNNING, + sky.ManagedJobStatus.RECOVERING + ], + timeout=360), ], f'sky jobs cancel -y -n {name}', timeout=25 * 60, @@ -3106,8 +3366,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): 'managed_jobs_recovery_multi_node_aws', [ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 450', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=450), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -3118,8 +3380,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 560', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3147,15 +3411,19 @@ def test_managed_jobs_recovery_multi_node_gcp(): 'managed_jobs_recovery_multi_node_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 420', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3180,13 +3448,17 @@ def test_managed_jobs_cancellation_aws(aws_config_region): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters "Name=tag:ray-cluster-name,Values={name_on_cloud}-*" ' '--query "Reservations[].Instances[].State[].Name" ' @@ -3194,12 +3466,16 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters "Name=tag:ray-cluster-name,Values={name_2_on_cloud}-*" ' '--query "Reservations[].Instances[].State[].Name" ' @@ -3207,8 +3483,11 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + # The job is running in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' f'aws ec2 describe-instances --region {region} ' @@ -3218,10 +3497,10 @@ def test_managed_jobs_cancellation_aws(aws_config_region): _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (shutting-down) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$(aws ec2 describe-instances --region {region} ' @@ -3256,34 +3535,42 @@ def test_managed_jobs_cancellation_gcp(): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually. terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (STOPPING) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED|SUSPENDED"' @@ -3373,8 +3660,12 @@ def test_managed_jobs_storage(generic_cloud: str): *STORAGE_SETUP_COMMANDS, f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y', region_validation_cmd, # Check if the bucket is created in the correct region - 'sleep 60', # Wait the spot queue to be updated - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=60 + _BUMP_UP_SECONDS), + # Wait for the job to be cleaned up. + 'sleep 20', f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]', # Check if file was written to the mounted output bucket output_check_cmd @@ -3398,10 +3689,17 @@ def test_managed_jobs_tpu(): 'test-spot-tpu', [ f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep STARTING', - 'sleep 900', # TPU takes a while to launch - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), + # TPU takes a while to launch + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.RUNNING, sky.ManagedJobStatus.SUCCEEDED + ], + timeout=900 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3418,9 +3716,19 @@ def test_managed_jobs_inline_env(generic_cloud: str): test = Test( 'test-managed-jobs-inline-env', [ - f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', - 'sleep 20', - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "echo "\\$TEST_ENV"; ([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=20 + _BUMP_UP_SECONDS), + f'JOB_ROW=$(sky jobs queue | grep {name} | head -n1) && ' + f'echo "$JOB_ROW" && echo "$JOB_ROW" | grep "SUCCEEDED" && ' + f'JOB_ID=$(echo "$JOB_ROW" | awk \'{{print $1}}\') && ' + f'echo "JOB_ID=$JOB_ID" && ' + # Test that logs are still available after the job finishes. + 'unset SKYPILOT_DEBUG; s=$(sky jobs logs $JOB_ID --refresh) && echo "$s" && echo "$s" | grep "hello world" && ' + # Make sure we skip the unnecessary logs. + 'echo "$s" | head -n1 | grep "Waiting for"', ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3527,8 +3835,12 @@ def test_azure_start_stop_two_nodes(): f'sky start -y {name} -i 1', f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 2 --status', # Ensure the job succeeded. - 'sleep 200', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.INIT, sky.ClusterStatus.STOPPED + ], + timeout=200 + _BUMP_UP_SECONDS) + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', @@ -4540,7 +4852,10 @@ def test_core_api_sky_launch_fast(generic_cloud: str): idle_minutes_to_autostop=1, fast=True) # Sleep to let the cluster autostop - time.sleep(120) + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=120) # Run it again - should work with fast=True sky.launch(task, cluster_name=name, diff --git a/tests/test_yamls/test_only_setup.yaml b/tests/test_yamls/test_only_setup.yaml new file mode 100644 index 00000000000..245d2b1de69 --- /dev/null +++ b/tests/test_yamls/test_only_setup.yaml @@ -0,0 +1,2 @@ +setup: | + echo "hello world" diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 48e47a6007c..c9e7ad35af2 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -16,6 +16,10 @@ POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), 'examples', 'admin_policy') +if not os.path.exists(POLICY_PATH): + # This is used for GitHub Actions, as we copy the examples to the package. + POLICY_PATH = os.path.join(os.path.dirname(__file__), 'examples', + 'admin_policy') @pytest.fixture @@ -172,7 +176,7 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) -@mock.patch('sky.provision.kubernetes.utils.get_all_kube_config_context_names', +@mock.patch('sky.provision.kubernetes.utils.get_all_kube_context_names', return_value=['kind-skypilot', 'kind-skypilot2', 'kind-skypilot3']) def test_dynamic_kubernetes_contexts_policy(add_example_policy_paths, task): _, config = _load_task_and_apply_policy( diff --git a/tests/unit_tests/test_recovery_strategy.py b/tests/unit_tests/test_recovery_strategy.py new file mode 100644 index 00000000000..da8e8142da0 --- /dev/null +++ b/tests/unit_tests/test_recovery_strategy.py @@ -0,0 +1,48 @@ +from unittest import mock + +from sky.exceptions import ClusterDoesNotExist +from sky.jobs import recovery_strategy + + +@mock.patch('sky.down') +@mock.patch('sky.usage.usage_lib.messages.usage.set_internal') +def test_terminate_cluster_retry_on_value_error(mock_set_internal, + mock_sky_down) -> None: + # Set up mock to fail twice with ValueError, then succeed + mock_sky_down.side_effect = [ + ValueError('Mock error 1'), + ValueError('Mock error 2'), + None, + ] + + # Call should succeed after retries + recovery_strategy.terminate_cluster('test-cluster') + + # Verify sky.down was called 3 times + assert mock_sky_down.call_count == 3 + mock_sky_down.assert_has_calls([ + mock.call('test-cluster'), + mock.call('test-cluster'), + mock.call('test-cluster'), + ]) + + # Verify usage.set_internal was called before each sky.down + assert mock_set_internal.call_count == 3 + + +@mock.patch('sky.down') +@mock.patch('sky.usage.usage_lib.messages.usage.set_internal') +def test_terminate_cluster_handles_nonexistent_cluster(mock_set_internal, + mock_sky_down) -> None: + # Set up mock to raise ClusterDoesNotExist + mock_sky_down.side_effect = ClusterDoesNotExist('test-cluster') + + # Call should succeed silently + recovery_strategy.terminate_cluster('test-cluster') + + # Verify sky.down was called once + assert mock_sky_down.call_count == 1 + mock_sky_down.assert_called_once_with('test-cluster') + + # Verify usage.set_internal was called once + assert mock_set_internal.call_count == 1 diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 5006fc454aa..65c90544f49 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -140,6 +140,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) expected_config_base = { @@ -180,6 +181,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated') @@ -195,6 +197,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated')