diff --git a/.buildkite/generate_pipeline.py b/.buildkite/generate_pipeline.py
index 99f29ee258a..3889105d02d 100644
--- a/.buildkite/generate_pipeline.py
+++ b/.buildkite/generate_pipeline.py
@@ -21,11 +21,13 @@
 clouds are not supported yet, smoke tests for those clouds are not generated.
 """
-import ast
 import os
 import random
+import re
+import subprocess
 from typing import Any, Dict, List, Optional
+import click
 from conftest import cloud_to_pytest_keyword
 from conftest import default_clouds_to_run
 import yaml
@@ -36,7 +38,7 @@
 QUEUE_GENERIC_CLOUD = 'generic_cloud'
 QUEUE_GENERIC_CLOUD_SERVE = 'generic_cloud_serve'
 QUEUE_KUBERNETES = 'kubernetes'
-QUEUE_KUBERNETES_SERVE = 'kubernetes_serve'
+QUEUE_GKE = 'gke'
 # Only aws, gcp, azure, and kubernetes are supported for now.
 # Other clouds do not have credentials.
 CLOUD_QUEUE_MAP = {
@@ -52,7 +54,9 @@
     'aws': QUEUE_GENERIC_CLOUD_SERVE,
     'gcp': QUEUE_GENERIC_CLOUD_SERVE,
     'azure': QUEUE_GENERIC_CLOUD_SERVE,
-    'kubernetes': QUEUE_KUBERNETES_SERVE
+    # Now we run kubernetes on a local cluster, so it should be fine to run
+    # serve tests on the same queue as kubernetes.
+    'kubernetes': QUEUE_KUBERNETES
 }
 GENERATED_FILE_HEAD = ('# This is an auto-generated Buildkite pipeline by '
@@ -60,18 +64,8 @@
                        'edit directly.\n')
-def _get_full_decorator_path(decorator: ast.AST) -> str:
-    """Recursively get the full path of a decorator."""
-    if isinstance(decorator, ast.Attribute):
-        return f'{_get_full_decorator_path(decorator.value)}.{decorator.attr}'
-    elif isinstance(decorator, ast.Name):
-        return decorator.id
-    elif isinstance(decorator, ast.Call):
-        return _get_full_decorator_path(decorator.func)
-    raise ValueError(f'Unknown decorator type: {type(decorator)}')
-
-
-def _extract_marked_tests(file_path: str) -> Dict[str, List[str]]:
+def _extract_marked_tests(file_path: str,
+                          filter_marks: List[str]) -> Dict[str, List[str]]:
     """Extract test functions and filter clouds using pytest.mark
     from a Python test file.
@@ -85,80 +79,72 @@ def _extract_marked_tests(file_path: str) -> Dict[str, List[str]]:
     rerun failures. Additionally, the parallelism would be controlled by pytest
     instead of the buildkite job queue.
     """
-    with open(file_path, 'r', encoding='utf-8') as file:
-        tree = ast.parse(file.read(), filename=file_path)
-
-    for node in ast.walk(tree):
-        for child in ast.iter_child_nodes(node):
-            setattr(child, 'parent', node)
-
+    cmd = f'pytest {file_path} --collect-only'
+    output = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+    matches = re.findall('Collected .+?\.py::(.+?) with marks: \[(.*?)\]',
+                         output.stdout)
+    function_name_marks_map = {}
+    for function_name, marks in matches:
+        function_name = re.sub(r'\[.*?\]', '', function_name)
+        marks = marks.replace('\'', '').split(',')
+        marks = [i.strip() for i in marks]
+        if function_name not in function_name_marks_map:
+            function_name_marks_map[function_name] = set(marks)
+        else:
+            function_name_marks_map[function_name].update(marks)
     function_cloud_map = {}
-    for node in ast.walk(tree):
-        if isinstance(node, ast.FunctionDef) and node.name.startswith('test_'):
-            class_name = None
-            if hasattr(node, 'parent') and isinstance(node.parent,
-                                                      ast.ClassDef):
-                class_name = node.parent.name
-
-            clouds_to_include = []
-            clouds_to_exclude = []
-            is_serve_test = False
-            for decorator in node.decorator_list:
-                if isinstance(decorator, ast.Call):
-                    # We only need to consider the decorator with no arguments
-                    # to extract clouds.
+ filter_marks = set(filter_marks) + for function_name, marks in function_name_marks_map.items(): + if filter_marks and not filter_marks & marks: + continue + clouds_to_include = [] + clouds_to_exclude = [] + is_serve_test = 'serve' in marks + run_on_gke = 'requires_gke' in marks + for mark in marks: + if mark.startswith('no_'): + clouds_to_exclude.append(mark[3:]) + else: + if mark not in PYTEST_TO_CLOUD_KEYWORD: + # This mark does not specify a cloud, so we skip it. continue - full_path = _get_full_decorator_path(decorator) - if full_path.startswith('pytest.mark.'): - assert isinstance(decorator, ast.Attribute) - suffix = decorator.attr - if suffix.startswith('no_'): - clouds_to_exclude.append(suffix[3:]) - else: - if suffix == 'serve': - is_serve_test = True - continue - if suffix not in PYTEST_TO_CLOUD_KEYWORD: - # This mark does not specify a cloud, so we skip it. - continue - clouds_to_include.append( - PYTEST_TO_CLOUD_KEYWORD[suffix]) - clouds_to_include = (clouds_to_include if clouds_to_include else - DEFAULT_CLOUDS_TO_RUN) - clouds_to_include = [ - cloud for cloud in clouds_to_include - if cloud not in clouds_to_exclude - ] - cloud_queue_map = SERVE_CLOUD_QUEUE_MAP if is_serve_test else CLOUD_QUEUE_MAP - final_clouds_to_include = [ - cloud for cloud in clouds_to_include if cloud in cloud_queue_map - ] - if clouds_to_include and not final_clouds_to_include: - print(f'Warning: {file_path}:{node.name} ' - f'is marked to run on {clouds_to_include}, ' - f'but we do not have credentials for those clouds. ' - f'Skipped.') - continue - if clouds_to_include != final_clouds_to_include: - excluded_clouds = set(clouds_to_include) - set( - final_clouds_to_include) - print( - f'Warning: {file_path}:{node.name} ' - f'is marked to run on {clouds_to_include}, ' - f'but we only have credentials for {final_clouds_to_include}. ' - f'clouds {excluded_clouds} are skipped.') - function_name = (f'{class_name}::{node.name}' - if class_name else node.name) - function_cloud_map[function_name] = (final_clouds_to_include, [ - cloud_queue_map[cloud] for cloud in final_clouds_to_include - ]) + clouds_to_include.append(PYTEST_TO_CLOUD_KEYWORD[mark]) + + clouds_to_include = (clouds_to_include + if clouds_to_include else DEFAULT_CLOUDS_TO_RUN) + clouds_to_include = [ + cloud for cloud in clouds_to_include + if cloud not in clouds_to_exclude + ] + cloud_queue_map = SERVE_CLOUD_QUEUE_MAP if is_serve_test else CLOUD_QUEUE_MAP + final_clouds_to_include = [ + cloud for cloud in clouds_to_include if cloud in cloud_queue_map + ] + if clouds_to_include and not final_clouds_to_include: + print( + f'Warning: {function_name} is marked to run on {clouds_to_include}, ' + f'but we do not have credentials for those clouds. Skipped.') + continue + if clouds_to_include != final_clouds_to_include: + excluded_clouds = set(clouds_to_include) - set( + final_clouds_to_include) + print( + f'Warning: {function_name} is marked to run on {clouds_to_include}, ' + f'but we only have credentials for {final_clouds_to_include}. 
' + f'clouds {excluded_clouds} are skipped.') + function_cloud_map[function_name] = (final_clouds_to_include, [ + QUEUE_GKE if run_on_gke else cloud_queue_map[cloud] + for cloud in final_clouds_to_include + ]) return function_cloud_map -def _generate_pipeline(test_file: str) -> Dict[str, Any]: +def _generate_pipeline(test_file: str, + filter_marks: List[str], + auto_retry: bool = False) -> Dict[str, Any]: """Generate a Buildkite pipeline from test files.""" steps = [] - function_cloud_map = _extract_marked_tests(test_file) + function_cloud_map = _extract_marked_tests(test_file, filter_marks) for test_function, clouds_and_queues in function_cloud_map.items(): for cloud, queue in zip(*clouds_and_queues): step = { @@ -172,6 +158,11 @@ def _generate_pipeline(test_file: str) -> Dict[str, Any]: }, 'if': f'build.env("{cloud}") == "1"' } + if auto_retry: + step['retry'] = { + # Automatically retry 2 times on any failure by default. + 'automatic': True + } steps.append(step) return {'steps': steps} @@ -194,12 +185,12 @@ def _dump_pipeline_to_file(yaml_file_path: str, yaml.dump(final_pipeline, file, default_flow_style=False) -def _convert_release(test_files: List[str]): +def _convert_release(test_files: List[str], filter_marks: List[str]): yaml_file_path = '.buildkite/pipeline_smoke_tests_release.yaml' output_file_pipelines = [] for test_file in test_files: print(f'Converting {test_file} to {yaml_file_path}') - pipeline = _generate_pipeline(test_file) + pipeline = _generate_pipeline(test_file, filter_marks, auto_retry=True) output_file_pipelines.append(pipeline) print(f'Converted {test_file} to {yaml_file_path}\n\n') # Enable all clouds by default for release pipeline. @@ -208,7 +199,7 @@ def _convert_release(test_files: List[str]): extra_env={cloud: '1' for cloud in CLOUD_QUEUE_MAP}) -def _convert_quick_tests_core(test_files: List[str]): +def _convert_quick_tests_core(test_files: List[str], filter_marks: List[str]): yaml_file_path = '.buildkite/pipeline_smoke_tests_quick_tests_core.yaml' output_file_pipelines = [] for test_file in test_files: @@ -216,7 +207,7 @@ def _convert_quick_tests_core(test_files: List[str]): # We want enable all clouds by default for each test function # for pre-merge. And let the author controls which clouds # to run by parameter. 
- pipeline = _generate_pipeline(test_file) + pipeline = _generate_pipeline(test_file, filter_marks) pipeline['steps'].append({ 'label': 'Backward compatibility test', 'command': 'bash tests/backward_compatibility_tests.sh', @@ -231,7 +222,12 @@ def _convert_quick_tests_core(test_files: List[str]): extra_env={'SKYPILOT_SUPPRESS_SENSITIVE_LOG': '1'}) -def main(): +@click.command() +@click.option( + '--filter-marks', + type=str, + help='Filter to include only a subset of pytest marks, e.g., managed_jobs') +def main(filter_marks): test_files = os.listdir('tests/smoke_tests') release_files = [] quick_tests_core_files = [] @@ -244,8 +240,15 @@ def main(): else: release_files.append(test_file_path) - _convert_release(release_files) - _convert_quick_tests_core(quick_tests_core_files) + filter_marks = filter_marks or os.getenv('FILTER_MARKS') + if filter_marks: + filter_marks = filter_marks.split(',') + print(f'Filter marks: {filter_marks}') + else: + filter_marks = [] + + _convert_release(release_files, filter_marks) + _convert_quick_tests_core(quick_tests_core_files, filter_marks) if __name__ == '__main__': diff --git a/.github/workflows/test-poetry-build.yml b/.github/workflows/test-poetry-build.yml deleted file mode 100644 index 4cce22809ef..00000000000 --- a/.github/workflows/test-poetry-build.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Poetry Test -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - master - - 'releases/**' - pull_request: - branches: - - master - - 'releases/**' - merge_group: - -jobs: - poetry-build-test: - runs-on: ubuntu-latest - steps: - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python - - echo "$HOME/.poetry/bin" >> $GITHUB_PATH - - name: Create foo package - run: | - mkdir foo - MASTER_REPO_URL=${{ github.server_url }}/${{ github.repository }} - REPO_URL=${{ github.event.pull_request.head.repo.html_url }} - if [ -z "$REPO_URL" ]; then - # This is a push, not a PR, so use the repo URL - REPO_URL=$MASTER_REPO_URL - fi - echo Master repo URL: $MASTER_REPO_URL - echo Using repo URL: $REPO_URL - cat < foo/pyproject.toml - [tool.poetry] - name = "foo" - version = "1.0.0" - authors = ["skypilot-bot"] - description = "" - - [tool.poetry.dependencies] - python = "3.10.x" - - [tool.poetry.group.dev.dependencies] - skypilot = {git = "${REPO_URL}.git", branch = "${{ github.head_ref }}"} - - [build-system] - requires = ["poetry-core"] - build-backend = "poetry.core.masonry.api" - - EOF - - - name: Check poetry lock time - run: | - cd foo - poetry lock --no-update - timeout-minutes: 2 - - diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 85ca90b2c4a..25c6421c347 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,9 +50,10 @@ pytest tests/test_smoke.py --generic-cloud azure For profiling code, use: ``` -pip install tuna # Tuna is used for visualization of profiling data. 
-python3 -m cProfile -o sky.prof -m sky.cli status # Or some other command -tuna sky.prof +pip install py-spy # py-spy is a sampling profiler for Python programs +py-spy record -t -o sky.svg -- python -m sky.cli status # Or some other command +py-spy top -- python -m sky.cli status # Get a live top view +py-spy -h # For more options ``` #### Testing in a container diff --git a/README.md b/README.md index 1ed99325df5..5c1699c0ee1 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,8 @@ Read the research: - [Sky Computing vision paper](https://sigops.org/s/conferences/hotos/2021/papers/hotos21-s02-stoica.pdf) (HotOS 2021) - [Policy for Managed Spot Jobs](https://www.usenix.org/conference/nsdi24/presentation/wu-zhanghao) (NSDI 2024) +SkyPilot was initially started at the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley and has since gained many industry contributors. Read more about the project's origin [here](https://docs.skypilot.co/en/latest/sky-computing.html). + ## Support and Questions We are excited to hear your feedback! * For issues and feature requests, please [open a GitHub issue](https://github.com/skypilot-org/skypilot/issues/new). diff --git a/docs/source/_static/SkyPilot_wide_dark.svg b/docs/source/_static/SkyPilot_wide_dark.svg index 6be00d9e591..cb2f742ab98 100644 --- a/docs/source/_static/SkyPilot_wide_dark.svg +++ b/docs/source/_static/SkyPilot_wide_dark.svg @@ -1,64 +1,54 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/SkyPilot_wide_light.svg b/docs/source/_static/SkyPilot_wide_light.svg index 0b2eaae8538..71945c0f927 100644 --- a/docs/source/_static/SkyPilot_wide_light.svg +++ b/docs/source/_static/SkyPilot_wide_light.svg @@ -1,64 +1,55 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css index d5bbdd6cb51..aae9defea90 100644 --- a/docs/source/_static/custom.css +++ b/docs/source/_static/custom.css @@ -27,6 +27,7 @@ html[data-theme="light"] { --pst-color-primary: #176de8; --pst-color-secondary: var(--pst-color-primary); --pst-color-text-base: #4c4c4d; + --logo-text-color: #0E2E65; } html[data-theme="dark"] { @@ -34,6 +35,7 @@ html[data-theme="dark"] { --pst-color-primary: #176de8; --pst-color-secondary: var(--pst-color-primary); --pst-color-text-base: #d8d8d8; + --logo-text-color: #D8D8D8; .bd-sidebar::-webkit-scrollbar { width: 6px; diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js index 5ae47b7b7be..6daa0883885 100644 --- a/docs/source/_static/custom.js +++ b/docs/source/_static/custom.js @@ -1,16 +1,17 @@ -document.addEventListener('DOMContentLoaded', function () { - var script = document.createElement('script'); - script.src = 'https://widget.kapa.ai/kapa-widget.bundle.js'; - script.setAttribute('data-website-id', '4223d017-a3d2-4b92-b191-ea4d425a23c3'); - script.setAttribute('data-project-name', 'SkyPilot'); - script.setAttribute('data-project-color', '#4C4C4D'); - script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4'); - script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. 
Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).'); - script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.'); - script.setAttribute('data-button-position-bottom', '100px'); - script.async = true; - document.head.appendChild(script); -}); +// As of 2025-01-01, Kapa seems to be having issues loading on some ISPs, including comcast. Uncomment once resolved. +// document.addEventListener('DOMContentLoaded', function () { +// var script = document.createElement('script'); +// script.src = 'https://widget.kapa.ai/kapa-widget.bundle.js'; +// script.setAttribute('data-website-id', '4223d017-a3d2-4b92-b191-ea4d425a23c3'); +// script.setAttribute('data-project-name', 'SkyPilot'); +// script.setAttribute('data-project-color', '#4C4C4D'); +// script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4'); +// script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).'); +// script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.'); +// script.setAttribute('data-button-position-bottom', '100px'); +// script.async = true; +// document.head.appendChild(script); +// }); (function(h,o,t,j,a,r){ h.hj=h.hj||function(){(h.hj.q=h.hj.q||[]).push(arguments)}; @@ -25,14 +26,9 @@ document.addEventListener('DOMContentLoaded', function () { document.addEventListener('DOMContentLoaded', () => { // New items: const newItems = [ - { selector: '.toctree-l1 > a', text: 'Managed Jobs' }, - { selector: '.toctree-l1 > a', text: 'Pixtral (Mistral AI)' }, { selector: '.toctree-l1 > a', text: 'Many Parallel Jobs' }, - { selector: '.toctree-l1 > a', text: 'Reserved, Capacity Blocks, DWS' }, - { selector: '.toctree-l1 > a', text: 'Llama 3.2 (Meta)' }, { selector: '.toctree-l1 > a', text: 'Admin Policy Enforcement' }, { selector: '.toctree-l1 > a', text: 'Using Existing Machines' }, - { selector: '.toctree-l1 > a', text: 'Concept: Sky Computing' }, ]; newItems.forEach(({ selector, text }) => { document.querySelectorAll(selector).forEach((el) => { diff --git a/docs/source/_templates/navbar-skypilot-logo.html b/docs/source/_templates/navbar-skypilot-logo.html index 0323953acde..1692f1f2a5d 100644 --- a/docs/source/_templates/navbar-skypilot-logo.html +++ b/docs/source/_templates/navbar-skypilot-logo.html @@ -9,5 +9,59 @@ {#- Logo HTML and image #} diff --git a/docs/source/cloud-setup/cloud-permissions/aws.rst b/docs/source/cloud-setup/cloud-permissions/aws.rst index 89510331988..57fc7ac9732 100644 --- a/docs/source/cloud-setup/cloud-permissions/aws.rst +++ b/docs/source/cloud-setup/cloud-permissions/aws.rst @@ -223,7 +223,7 @@ IAM Role Creation Using a specific VPC ----------------------- -By default, SkyPilot uses the "default" VPC in each region. +By default, SkyPilot uses the "default" VPC in each region. If a region does not have a `default VPC `_, SkyPilot will not be able to use the region. 
To instruct SkyPilot to use a specific VPC, you can use SkyPilot's global config file ``~/.sky/config.yaml`` to specify the VPC name in the ``aws.vpc_name`` diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst index 99fa461249d..2cd99b6c24b 100644 --- a/docs/source/examples/managed-jobs.rst +++ b/docs/source/examples/managed-jobs.rst @@ -152,6 +152,7 @@ The :code:`MOUNT` mode in :ref:`SkyPilot bucket mounting ` ensures Note that the application code should save program checkpoints periodically and reload those states when the job is restarted. This is typically achieved by reloading the latest checkpoint at the beginning of your program. + .. _spot-jobs-end-to-end: An End-to-End Example @@ -455,6 +456,46 @@ especially useful when there are many in-progress jobs to monitor, which the terminal-based CLI may need more than one page to display. +.. _intermediate-bucket: + +Intermediate storage for files +------------------------------ + +For managed jobs, SkyPilot requires an intermediate bucket to store files used in the task, such as local file mounts, temporary files, and the workdir. +If you do not configure a bucket, SkyPilot will automatically create a temporary bucket named :code:`skypilot-filemounts-{username}-{run_id}` for each job launch. SkyPilot automatically deletes the bucket after the job completes. + +Alternatively, you can pre-provision a bucket and use it as an intermediate for storing file by setting :code:`jobs.bucket` in :code:`~/.sky/config.yaml`: + +.. code-block:: yaml + + # ~/.sky/config.yaml + jobs: + bucket: s3://my-bucket # Supports s3://, gs://, https://.blob.core.windows.net/, r2://, cos:/// + + +If you choose to specify a bucket, ensure that the bucket already exists and that you have the necessary permissions. + +When using a pre-provisioned intermediate bucket with :code:`jobs.bucket`, SkyPilot creates job-specific directories under the bucket root to store files. They are organized in the following structure: + +.. code-block:: text + + # cloud bucket, s3://my-bucket/ for example + my-bucket/ + ├── job-15891b25/ # Job-specific directory + │ ├── local-file-mounts/ # Files from local file mounts + │ ├── tmp-files/ # Temporary files + │ └── workdir/ # Files from workdir + └── job-cae228be/ # Another job's directory + ├── local-file-mounts/ + ├── tmp-files/ + └── workdir/ + +When using a custom bucket (:code:`jobs.bucket`), the job-specific directories (e.g., :code:`job-15891b25/`) created by SkyPilot are removed when the job completes. + +.. tip:: + Multiple users can share the same intermediate bucket. Each user's jobs will have their own unique job-specific directories, ensuring that files are kept separate and organized. + + Concept: Jobs Controller ------------------------ @@ -505,4 +546,3 @@ The :code:`resources` field has the same spec as a normal SkyPilot job; see `her These settings will not take effect if you have an existing controller (either stopped or live). For them to take effect, tear down the existing controller first, which requires all in-progress jobs to finish or be canceled. 
- diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index deb2307b67b..93c730ef651 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -59,6 +59,7 @@ Install SkyPilot using pip: pip install "skypilot-nightly[runpod]" pip install "skypilot-nightly[fluidstack]" pip install "skypilot-nightly[paperspace]" + pip install "skypilot-nightly[do]" pip install "skypilot-nightly[cudo]" pip install "skypilot-nightly[ibm]" pip install "skypilot-nightly[scp]" @@ -303,7 +304,7 @@ RunPod .. code-block:: shell - pip install "runpod>=1.5.1" + pip install "runpod>=1.6.1" runpod config diff --git a/docs/source/getting-started/tutorial.rst b/docs/source/getting-started/tutorial.rst index 175f1391a6d..9b067be2876 100644 --- a/docs/source/getting-started/tutorial.rst +++ b/docs/source/getting-started/tutorial.rst @@ -2,19 +2,20 @@ Tutorial: AI Training ====================== -This example uses SkyPilot to train a Transformer-based language model from HuggingFace. +This example uses SkyPilot to train a GPT-like model (inspired by Karpathy's `minGPT `_) with Distributed Data Parallel (DDP) in PyTorch. -First, define a :ref:`task YAML ` with the resource requirements, the setup commands, +We define a :ref:`task YAML ` with the resource requirements, the setup commands, and the commands to run: .. code-block:: yaml - # dnn.yaml + # train.yaml - name: huggingface + name: minGPT-ddp resources: - accelerators: V100:4 + cpus: 4+ + accelerators: L4:4 # Or A100:8, H100:8 # Optional: upload a working directory to remote ~/sky_workdir. # Commands in "setup" and "run" will be executed under it. @@ -30,26 +31,21 @@ and the commands to run: # ~/.netrc: ~/.netrc setup: | - set -e # Exit if any command failed. - git clone https://github.com/huggingface/transformers/ || true - cd transformers - pip install . - cd examples/pytorch/text-classification - pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + git clone --depth 1 https://github.com/pytorch/examples || true + cd examples + git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp + # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). + uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113 run: | - set -e # Exit if any command failed. - cd transformers/examples/pytorch/text-classification - python run_glue.py \ - --model_name_or_path bert-base-cased \ - --dataset_name imdb \ - --do_train \ - --max_seq_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --max_steps 50 \ - --output_dir /tmp/imdb/ --overwrite_output_dir \ - --fp16 + cd examples/mingpt + export LOGLEVEL=INFO + + echo "Starting minGPT-ddp training" + + torchrun \ + --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \ + main.py .. tip:: @@ -57,11 +53,15 @@ and the commands to run: learn about how to use them to mount local dirs/files or object store buckets (S3, GCS, R2) into your cluster, see :ref:`sync-code-artifacts`. +.. tip:: + + The ``SKYPILOT_NUM_GPUS_PER_NODE`` environment variable is automatically set by SkyPilot to the number of GPUs per node. See :ref:`env-vars` for more. + Then, launch training: .. 
code-block:: console - $ sky launch -c lm-cluster dnn.yaml + $ sky launch -c mingpt train.yaml This will provision the cheapest cluster with the required resources, execute the setup commands, then execute the run commands. diff --git a/docs/source/images/skypilot-wide-dark-1k.png b/docs/source/images/skypilot-wide-dark-1k.png index 057b6a0ae97..b6ed7caec6f 100644 Binary files a/docs/source/images/skypilot-wide-dark-1k.png and b/docs/source/images/skypilot-wide-dark-1k.png differ diff --git a/docs/source/images/skypilot-wide-light-1k.png b/docs/source/images/skypilot-wide-light-1k.png index 7af87ad2864..178c6553dd3 100644 Binary files a/docs/source/images/skypilot-wide-light-1k.png and b/docs/source/images/skypilot-wide-light-1k.png differ diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index d5ee4d2134a..a76dc473206 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -24,6 +24,10 @@ Available fields and semantics: # # Ref: https://docs.skypilot.co/en/latest/examples/managed-jobs.html#customizing-job-controller-resources jobs: + # Bucket to store managed jobs mount files and tmp files. Bucket must already exist. + # Optional. If not set, SkyPilot will create a new bucket for each managed job launch. + # Supports s3://, gs://, https://.blob.core.windows.net/, r2://, cos:/// + bucket: s3://my-bucket/ controller: resources: # same spec as 'resources' in a task YAML cloud: gcp @@ -624,20 +628,30 @@ Available fields and semantics: # Advanced OCI configurations (optional). oci: # A dict mapping region names to region-specific configurations, or - # `default` for the default configuration. + # `default` for the default/global configuration. default: - # The OCID of the profile to use for launching instances (optional). - oci_config_profile: DEFAULT - # The OCID of the compartment to use for launching instances (optional). + # The profile name in ~/.oci/config to use for launching instances. If not + # set, the one named DEFAULT will be used (optional). + oci_config_profile: SKY_PROVISION_PROFILE + # The OCID of the compartment to use for launching instances. If not set, + # the root compartment will be used (optional). compartment_ocid: ocid1.compartment.oc1..aaaaaaaahr7aicqtodxmcfor6pbqn3hvsngpftozyxzqw36gj4kh3w3kkj4q - # The image tag to use for launching general instances (optional). - image_tag_general: skypilot:cpu-ubuntu-2004 - # The image tag to use for launching GPU instances (optional). - image_tag_gpu: skypilot:gpu-ubuntu-2004 - + # The default image tag to use for launching general instances (CPU) if the + # image_id parameter is not specified. If not set, the default is + # skypilot:cpu-ubuntu-2204 (optional). + image_tag_general: skypilot:cpu-oraclelinux8 + # The default image tag to use for launching GPU instances if the image_id + # parameter is not specified. If not set, the default is + # skypilot:gpu-ubuntu-2204 (optional). + image_tag_gpu: skypilot:gpu-oraclelinux8 + + # Region-specific configurations ap-seoul-1: + # The OCID of the VCN to use for instances (optional). + vcn_ocid: ocid1.vcn.oc1.ap-seoul-1.amaaaaaaak7gbriarkfs2ssus5mh347ktmi3xa72tadajep6asio3ubqgarq # The OCID of the subnet to use for instances (optional). 
vcn_subnet: ocid1.subnet.oc1.ap-seoul-1.aaaaaaaa5c6wndifsij6yfyfehmi3tazn6mvhhiewqmajzcrlryurnl7nuja us-ashburn-1: + vcn_ocid: ocid1.vcn.oc1.ap-seoul-1.amaaaaaaak7gbriarkfs2ssus5mh347ktmi3xa72tadajep6asio3ubqgarq vcn_subnet: ocid1.subnet.oc1.iad.aaaaaaaafbj7i3aqc4ofjaapa5edakde6g4ea2yaslcsay32cthp7qo55pxa diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst index e4bbb2c8915..3323559bb36 100644 --- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst +++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst @@ -258,6 +258,67 @@ After launching the cluster with :code:`sky launch -c myclus task.yaml`, you can To learn more about opening ports in SkyPilot tasks, see :ref:`Opening Ports `. +Customizing SkyPilot pods +------------------------- + +You can override the pod configuration used by SkyPilot by setting the :code:`pod_config` key in :code:`~/.sky/config.yaml`. +The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_. This will apply to all pods created by SkyPilot. + +For example, to set custom environment variables and use GPUDirect RDMA, you can add the following to your :code:`~/.sky/config.yaml` file: + +.. code-block:: yaml + + # ~/.sky/config.yaml + kubernetes: + pod_config: + spec: + containers: + - env: # Custom environment variables to set in pod + - name: MY_ENV_VAR + value: MY_ENV_VALUE + resources: # Custom resources for GPUDirect RDMA + requests: + rdma/rdma_shared_device_a: 1 + limits: + rdma/rdma_shared_device_a: 1 + + +Similarly, you can attach `Kubernetes volumes `_ (e.g., an `NFS volume `_) directly to your SkyPilot pods: + +.. code-block:: yaml + + # ~/.sky/config.yaml + kubernetes: + pod_config: + spec: + containers: + - volumeMounts: # Custom volume mounts for the pod + - mountPath: /data + name: nfs-volume + volumes: + - name: nfs-volume + nfs: # Alternatively, use hostPath if your NFS is directly attached to the nodes + server: nfs.example.com + path: /nfs + + +.. tip:: + + As an alternative to setting ``pod_config`` globally, you can also set it on a per-task basis directly in your task YAML with the ``config_overrides`` :ref:`field `. + + .. code-block:: yaml + + # task.yaml + run: | + python myscript.py + + # Set pod_config for this task + experimental: + config_overrides: + pod_config: + ... + + FAQs ---- @@ -293,38 +354,6 @@ FAQs You can use your existing observability tools to filter resources with the label :code:`parent=skypilot` (:code:`kubectl get pods -l 'parent=skypilot'`). As an example, follow the instructions :ref:`here ` to deploy the Kubernetes Dashboard on your cluster. -* **How can I specify custom configuration for the pods created by SkyPilot?** - - You can override the pod configuration used by SkyPilot by setting the :code:`pod_config` key in :code:`~/.sky/config.yaml`. - The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_. - - For example, to set custom environment variables and attach a volume on your pods, you can add the following to your :code:`~/.sky/config.yaml` file: - - .. 
code-block:: yaml - - kubernetes: - pod_config: - spec: - containers: - - env: - - name: MY_ENV_VAR - value: MY_ENV_VALUE - volumeMounts: # Custom volume mounts for the pod - - mountPath: /foo - name: example-volume - resources: # Custom resource requests and limits - requests: - rdma/rdma_shared_device_a: 1 - limits: - rdma/rdma_shared_device_a: 1 - volumes: - - name: example-volume - hostPath: - path: /tmp - type: Directory - - For more details refer to :ref:`config-yaml`. - * **I am using a custom image. How can I speed up the pod startup time?** You can pre-install SkyPilot dependencies in your custom image to speed up the pod startup time. Simply add these lines at the end of your Dockerfile: diff --git a/docs/source/reference/storage.rst b/docs/source/reference/storage.rst index 3c54680e79b..16f87c1ce2f 100644 --- a/docs/source/reference/storage.rst +++ b/docs/source/reference/storage.rst @@ -3,7 +3,7 @@ Cloud Object Storage ==================== -SkyPilot tasks can access data from buckets in cloud object storages such as AWS S3, Google Cloud Storage (GCS), Cloudflare R2 or IBM COS. +SkyPilot tasks can access data from buckets in cloud object storages such as AWS S3, Google Cloud Storage (GCS), Cloudflare R2, OCI Object Storage or IBM COS. Buckets are made available to each task at a local path on the remote VM, so the task can access bucket objects as if they were local files. @@ -28,7 +28,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot # Mount an existing S3 bucket file_mounts: /my_data: - source: s3://my-bucket/ # or gs://, https://.blob.core.windows.net/, r2://, cos:/// + source: s3://my-bucket/ # or gs://, https://.blob.core.windows.net/, r2://, cos:///, oci:// mode: MOUNT # Optional: either MOUNT or COPY. Defaults to MOUNT. This will `mount `__ the contents of the bucket at ``s3://my-bucket/`` to the remote VM at ``/my_data``. @@ -45,7 +45,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot file_mounts: /my_data: name: my-sky-bucket - store: gcs # Optional: either of s3, gcs, azure, r2, ibm + store: gcs # Optional: either of s3, gcs, azure, r2, ibm, oci SkyPilot will create an empty GCS bucket called ``my-sky-bucket`` and mount it at ``/my_data``. This bucket can be used to write checkpoints, logs or other outputs directly to the cloud. @@ -68,7 +68,7 @@ Object storages are specified using the :code:`file_mounts` field in a SkyPilot /my_data: name: my-sky-bucket source: ~/dataset # Optional: path to local data to upload to the bucket - store: s3 # Optional: either of s3, gcs, azure, r2, ibm + store: s3 # Optional: either of s3, gcs, azure, r2, ibm, oci mode: MOUNT # Optional: either MOUNT or COPY. Defaults to MOUNT. SkyPilot will create a S3 bucket called ``my-sky-bucket`` and upload the @@ -290,12 +290,13 @@ Storage YAML reference - https://.blob.core.windows.net/ - r2:// - cos:/// + - oci:// If the source is local, data is uploaded to the cloud to an appropriate - bucket (s3, gcs, azure, r2, or ibm). If source is bucket URI, + bucket (s3, gcs, azure, r2, oci, or ibm). If source is bucket URI, the data is copied or mounted directly (see mode flag below). - store: str; either of 's3', 'gcs', 'azure', 'r2', 'ibm' + store: str; either of 's3', 'gcs', 'azure', 'r2', 'ibm', 'oci' If you wish to force sky.Storage to be backed by a specific cloud object storage, you can specify it here. If not specified, SkyPilot chooses the appropriate object storage based on the source path and task's cloud provider. 
diff --git a/docs/source/reference/yaml-spec.rst b/docs/source/reference/yaml-spec.rst index 0be708305c8..d2f0506993a 100644 --- a/docs/source/reference/yaml-spec.rst +++ b/docs/source/reference/yaml-spec.rst @@ -176,9 +176,9 @@ Available fields: # tpu_vm: True # True to use TPU VM (the default); False to use TPU node. # Custom image id (optional, advanced). The image id used to boot the - # instances. Only supported for AWS and GCP (for non-docker image). If not - # specified, SkyPilot will use the default debian-based image suitable for - # machine learning tasks. + # instances. Only supported for AWS, GCP, OCI and IBM (for non-docker image). + # If not specified, SkyPilot will use the default debian-based image + # suitable for machine learning tasks. # # Docker support # You can specify docker image to use by setting the image_id to @@ -204,7 +204,7 @@ Available fields: # image_id: # us-east-1: ami-0729d913a335efca7 # us-west-2: ami-050814f384259894c - image_id: ami-0868a20f5a3bf9702 + # # GCP # To find GCP images: https://cloud.google.com/compute/docs/images # image_id: projects/deeplearning-platform-release/global/images/common-cpu-v20230615-debian-11-py310 @@ -215,6 +215,24 @@ Available fields: # To find Azure images: https://docs.microsoft.com/en-us/azure/virtual-machines/linux/cli-ps-findimage # image_id: microsoft-dsvm:ubuntu-2004:2004:21.11.04 # + # OCI + # To find OCI images: https://docs.oracle.com/en-us/iaas/images + # You can choose the image with OS version from the following image tags + # provided by SkyPilot: + # image_id: skypilot:gpu-ubuntu-2204 + # image_id: skypilot:gpu-ubuntu-2004 + # image_id: skypilot:gpu-oraclelinux9 + # image_id: skypilot:gpu-oraclelinux8 + # image_id: skypilot:cpu-ubuntu-2204 + # image_id: skypilot:cpu-ubuntu-2004 + # image_id: skypilot:cpu-oraclelinux9 + # image_id: skypilot:cpu-oraclelinux8 + # + # It is also possible to specify your custom image's OCID with OS type, + # for example: + # image_id: ocid1.image.oc1.us-sanjose-1.aaaaaaaaywwfvy67wwe7f24juvjwhyjn3u7g7s3wzkhduxcbewzaeki2nt5q:oraclelinux + # image_id: ocid1.image.oc1.us-sanjose-1.aaaaaaaa5tnuiqevhoyfnaa5pqeiwjv6w5vf6w4q2hpj3atyvu3yd6rhlhyq:ubuntu + # # IBM # Create a private VPC image and paste its ID in the following format: # image_id: @@ -224,6 +242,7 @@ Available fields: # https://www.ibm.com/cloud/blog/use-ibm-packer-plugin-to-create-custom-images-on-ibm-cloud-vpc-infrastructure # To use a more limited but easier to manage tool: # https://github.com/IBM/vpc-img-inst + image_id: ami-0868a20f5a3bf9702 # Labels to apply to the instances (optional). # @@ -307,7 +326,7 @@ Available fields: /datasets-storage: name: sky-dataset # Name of storage, optional when source is bucket URI source: /local/path/datasets # Source path, can be local or bucket URI. Optional, do not specify to create an empty bucket. - store: s3 # Could be either 's3', 'gcs', 'azure', 'r2', or 'ibm'; default: None. Optional. + store: s3 # Could be either 's3', 'gcs', 'azure', 'r2', 'oci', or 'ibm'; default: None. Optional. persistent: True # Defaults to True; can be set to false to delete bucket after cluster is downed. Optional. mode: MOUNT # Either MOUNT or COPY. Defaults to MOUNT. Optional. 
diff --git a/docs/source/running-jobs/distributed-jobs.rst b/docs/source/running-jobs/distributed-jobs.rst index f6c8cba9c9d..7c3421aa276 100644 --- a/docs/source/running-jobs/distributed-jobs.rst +++ b/docs/source/running-jobs/distributed-jobs.rst @@ -6,39 +6,40 @@ Distributed Multi-Node Jobs SkyPilot supports multi-node cluster provisioning and distributed execution on many nodes. -For example, here is a simple PyTorch Distributed training example: +For example, here is a simple example to train a GPT-like model (inspired by Karpathy's `minGPT `_) across 2 nodes with Distributed Data Parallel (DDP) in PyTorch. .. code-block:: yaml - :emphasize-lines: 6-6,21-21,23-26 + :emphasize-lines: 6,19,23-24,26 - name: resnet-distributed-app + name: minGPT-ddp - resources: - accelerators: A100:8 + resources: + accelerators: A100:8 - num_nodes: 2 + num_nodes: 2 - setup: | - pip3 install --upgrade pip - git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet - cd pytorch-distributed-resnet - # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). - pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 - mkdir -p data && mkdir -p saved_models && cd data && \ - wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz - tar -xvzf cifar-10-python.tar.gz + setup: | + git clone --depth 1 https://github.com/pytorch/examples || true + cd examples + git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp + # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). + uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113 - run: | - cd pytorch-distributed-resnet + run: | + cd examples/mingpt + export LOGLEVEL=INFO + + MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1) + echo "Starting distributed training, head node: $MASTER_ADDR" - MASTER_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1` - torchrun \ + torchrun \ --nnodes=$SKYPILOT_NUM_NODES \ - --master_addr=$MASTER_ADDR \ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \ - --node_rank=$SKYPILOT_NODE_RANK \ - --master_port=12375 \ - resnet_ddp.py --num_epochs 20 + --master_addr=$MASTER_ADDR \ + --node_rank=${SKYPILOT_NODE_RANK} \ + --master_port=8008 \ + main.py + In the above, @@ -55,6 +56,7 @@ In the above, ulimit -n 65535 +You can find more `distributed training examples `_ (including `using rdvz backend for pytorch `_) in our `GitHub repository `_. Environment variables ----------------------------------------- diff --git a/examples/local_docker/docker_in_docker.yaml b/examples/local_docker/docker_in_docker.yaml deleted file mode 100644 index bdb6ed70ecf..00000000000 --- a/examples/local_docker/docker_in_docker.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Runs a docker container as a SkyPilot task. -# -# This demo can be run using the --docker flag, demonstrating the -# docker-in-docker (dind) capabilities of SkyPilot docker mode. -# -# Usage: -# sky launch --docker -c dind docker_in_docker.yaml -# sky down dind - -name: dind - -resources: - cloud: aws - -setup: | - echo "No setup required!" - -run: | - docker run --rm hello-world diff --git a/examples/local_docker/ping.py b/examples/local_docker/ping.py deleted file mode 100644 index c3a90c62243..00000000000 --- a/examples/local_docker/ping.py +++ /dev/null @@ -1,22 +0,0 @@ -"""An example app which pings localhost. - -This script is designed to demonstrate the use of different backends with -SkyPilot. 
It is useful to support a LocalDockerBackend that users can use to -debug their programs even before they run them on the Sky. -""" - -import sky - -# Set backend here. It can be either LocalDockerBackend or CloudVmRayBackend. -backend = sky.backends.LocalDockerBackend( -) # or sky.backends.CloudVmRayBackend() - -with sky.Dag() as dag: - resources = sky.Resources(accelerators={'K80': 1}) - setup_commands = 'apt-get update && apt-get install -y iputils-ping' - task = sky.Task(run='ping 127.0.0.1 -c 100', - docker_image='ubuntu', - setup=setup_commands, - name='ping').set_resources(resources) - -sky.launch(dag, backend=backend) diff --git a/examples/local_docker/ping.yaml b/examples/local_docker/ping.yaml deleted file mode 100644 index 0d0efd12419..00000000000 --- a/examples/local_docker/ping.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# A minimal ping example. -# -# Runs a task that pings localhost 100 times. -# -# Usage: -# sky launch --docker -c ping ping.yaml -# sky down ping - -name: ping - -resources: - cloud: aws - -setup: | - sudo apt-get update --allow-insecure-repositories - sudo apt-get install -y iputils-ping - -run: | - ping 127.0.0.1 -c 100 diff --git a/examples/oci/dataset-mount.yaml b/examples/oci/dataset-mount.yaml new file mode 100644 index 00000000000..91bec9cda65 --- /dev/null +++ b/examples/oci/dataset-mount.yaml @@ -0,0 +1,35 @@ +name: cpu-task1 + +resources: + cloud: oci + region: us-sanjose-1 + cpus: 2 + disk_size: 256 + disk_tier: medium + use_spot: False + +file_mounts: + # Mount an existing oci bucket + /datasets-storage: + source: oci://skybucket + mode: MOUNT # Either MOUNT or COPY. Optional. + +# Working directory (optional) containing the project codebase. +# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup for the task. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + timestamp=$(date +%s) + ls -lthr /datasets-storage + echo "hi" >> /datasets-storage/foo.txt + ls -lthr /datasets-storage diff --git a/examples/oci/dataset-upload-and-mount.yaml b/examples/oci/dataset-upload-and-mount.yaml new file mode 100644 index 00000000000..13ddc4d2b35 --- /dev/null +++ b/examples/oci/dataset-upload-and-mount.yaml @@ -0,0 +1,47 @@ +name: cpu-task1 + +resources: + cloud: oci + region: us-sanjose-1 + cpus: 2 + disk_size: 256 + disk_tier: medium + use_spot: False + +file_mounts: + /datasets-storage: + name: skybucket # Name of storage, optional when source is bucket URI + source: ['./examples/oci'] # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket. + store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional. + persistent: True # Defaults to True; can be set to false. Optional. + mode: MOUNT # Either MOUNT or COPY. Optional. + + /datasets-storage2: + name: skybucket2 # Name of storage, optional when source is bucket URI + source: './examples/oci' # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket. + store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional. + persistent: True # Defaults to True; can be set to false. Optional. + mode: MOUNT # Either MOUNT or COPY. Optional. + +# Working directory (optional) containing the project codebase. 
+# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup for the task. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + ls -lthr /datasets-storage + echo "hi" >> /datasets-storage/foo.txt + ls -lthr /datasets-storage + + ls -lthr /datasets-storage2 + echo "hi" >> /datasets-storage2/foo2.txt + ls -lthr /datasets-storage2 diff --git a/examples/oci/gpu-oraclelinux9.yaml b/examples/oci/gpu-oraclelinux9.yaml new file mode 100644 index 00000000000..cc7b05ea0fc --- /dev/null +++ b/examples/oci/gpu-oraclelinux9.yaml @@ -0,0 +1,33 @@ +name: gpu-task + +resources: + # Optional; if left out, automatically pick the cheapest cloud. + cloud: oci + + accelerators: A10:1 + + disk_size: 1024 + + disk_tier: high + + image_id: skypilot:gpu-oraclelinux9 + + +# Working directory (optional) containing the project codebase. +# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + echo "hello, world" + nvidia-smi + echo "The task is completed." diff --git a/examples/oci/gpu-ubuntu-2204.yaml b/examples/oci/gpu-ubuntu-2204.yaml new file mode 100644 index 00000000000..e0012a31a1a --- /dev/null +++ b/examples/oci/gpu-ubuntu-2204.yaml @@ -0,0 +1,33 @@ +name: gpu-task + +resources: + # Optional; if left out, automatically pick the cheapest cloud. + cloud: oci + + accelerators: A10:1 + + disk_size: 1024 + + disk_tier: high + + image_id: skypilot:gpu-ubuntu-2204 + + +# Working directory (optional) containing the project codebase. +# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + echo "hello, world" + nvidia-smi + echo "The task is completed." 
diff --git a/examples/oci/oci-mounts.yaml b/examples/oci/oci-mounts.yaml new file mode 100644 index 00000000000..6fd2aaf16eb --- /dev/null +++ b/examples/oci/oci-mounts.yaml @@ -0,0 +1,26 @@ +resources: + cloud: oci + +file_mounts: + ~/tmpfile: ~/tmpfile + ~/a/b/c/tmpfile: ~/tmpfile + /tmp/workdir: ~/tmp-workdir + + /mydir: + name: skybucket + source: ['~/tmp-workdir'] + store: oci + mode: MOUNT + +setup: | + echo "*** Setup ***" + +run: | + echo "*** Run ***" + + ls -lthr ~/tmpfile + ls -lthr ~/a/b/c + echo hi >> /tmp/workdir/new_file + ls -lthr /tmp/workdir + + ls -lthr /mydir diff --git a/examples/spot/lightning_cifar10/train.py b/examples/spot/lightning_cifar10/train.py index 0df6f18484b..14901e635ef 100644 --- a/examples/spot/lightning_cifar10/train.py +++ b/examples/spot/lightning_cifar10/train.py @@ -163,7 +163,7 @@ def main(): ) model_ckpts = glob.glob(argv.root_dir + "/*.ckpt") - if argv.resume and len(model_ckpts) > 0: + if argv.resume and model_ckpts: latest_ckpt = max(model_ckpts, key=os.path.getctime) trainer.fit(model, cifar10_dm, ckpt_path=latest_ckpt) else: diff --git a/llm/ollama/ollama.yaml b/llm/ollama/ollama.yaml index 851dfe45dee..ed37c0ceb1b 100644 --- a/llm/ollama/ollama.yaml +++ b/llm/ollama/ollama.yaml @@ -47,13 +47,9 @@ service: setup: | # Install Ollama - if [ "$(uname -m)" == "aarch64" ]; then - # For apple silicon support - sudo curl -L https://ollama.com/download/ollama-linux-arm64 -o /usr/bin/ollama - else - sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama - fi - sudo chmod +x /usr/bin/ollama + # official installation reference: https://ollama.com/download/linux + curl -fsSL https://ollama.com/install.sh | sh + sudo chmod +x /usr/local/bin/ollama # Start `ollama serve` and capture PID to kill it after pull is done ollama serve & diff --git a/sky/adaptors/common.py b/sky/adaptors/common.py index 0cfb91cb587..d039813cd32 100644 --- a/sky/adaptors/common.py +++ b/sky/adaptors/common.py @@ -1,6 +1,7 @@ """Lazy import for modules to avoid import error when not used.""" import functools import importlib +import threading from typing import Any, Callable, Optional, Tuple @@ -24,17 +25,22 @@ def __init__(self, self._module = None self._import_error_message = import_error_message self._set_loggers = set_loggers + self._lock = threading.RLock() def load_module(self): - if self._module is None: - try: - self._module = importlib.import_module(self._module_name) - if self._set_loggers is not None: - self._set_loggers() - except ImportError as e: - if self._import_error_message is not None: - raise ImportError(self._import_error_message) from e - raise + # Avoid extra imports when multiple threads try to import the same + # module. The overhead is minor since import can only run in serial + # due to GIL even in multi-threaded environments. + with self._lock: + if self._module is None: + try: + self._module = importlib.import_module(self._module_name) + if self._set_loggers is not None: + self._set_loggers() + except ImportError as e: + if self._import_error_message is not None: + raise ImportError(self._import_error_message) from e + raise return self._module def __getattr__(self, name: str) -> Any: diff --git a/sky/adaptors/do.py b/sky/adaptors/do.py new file mode 100644 index 00000000000..d619efebc1c --- /dev/null +++ b/sky/adaptors/do.py @@ -0,0 +1,20 @@ +"""Digital Ocean cloud adaptors""" + +# pylint: disable=import-outside-toplevel + +from sky.adaptors import common + +_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for DO. 
' + 'Try pip install "skypilot[do]"') +pydo = common.LazyImport('pydo', import_error_message=_IMPORT_ERROR_MESSAGE) +azure = common.LazyImport('azure', import_error_message=_IMPORT_ERROR_MESSAGE) +_LAZY_MODULES = (pydo, azure) + + +# `pydo`` inherits Azure exceptions. See: +# https://github.com/digitalocean/pydo/blob/7b01498d99eb0d3a772366b642e5fab3d6fc6aa2/examples/poc_droplets_volumes_sshkeys.py#L6 +@common.load_lazy_modules(modules=_LAZY_MODULES) +def exceptions(): + """Azure exceptions.""" + from azure.core import exceptions as azure_exceptions + return azure_exceptions diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py index 8fe09479a38..31712de414f 100644 --- a/sky/adaptors/oci.py +++ b/sky/adaptors/oci.py @@ -1,9 +1,11 @@ """Oracle OCI cloud adaptor""" +import functools import logging import os from sky.adaptors import common +from sky.clouds.utils import oci_utils # Suppress OCI circuit breaker logging before lazy import, because # oci modules prints additional message during imports, i.e., the @@ -30,10 +32,16 @@ def get_config_file() -> str: def get_oci_config(region=None, profile='DEFAULT'): conf_file_path = get_config_file() + if not profile or profile == 'DEFAULT': + config_profile = oci_utils.oci_config.get_profile() + else: + config_profile = profile + oci_config = oci.config.from_file(file_location=conf_file_path, - profile_name=profile) + profile_name=config_profile) if region is not None: oci_config['region'] = region + return oci_config @@ -54,6 +62,29 @@ def get_identity_client(region=None, profile='DEFAULT'): return oci.identity.IdentityClient(get_oci_config(region, profile)) +def get_object_storage_client(region=None, profile='DEFAULT'): + return oci.object_storage.ObjectStorageClient( + get_oci_config(region, profile)) + + def service_exception(): """OCI service exception.""" return oci.exceptions.ServiceError + + +def with_oci_env(f): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + # pylint: disable=line-too-long + enter_env_cmds = [ + 'conda info --envs | grep "sky-oci-cli-env" || conda create -n sky-oci-cli-env python=3.10 -y', + '. 
$(conda info --base 2> /dev/null)/etc/profile.d/conda.sh > /dev/null 2>&1 || true', + 'conda activate sky-oci-cli-env', 'pip install oci-cli', + 'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True' + ] + operation_cmd = [f(*args, **kwargs)] + leave_env_cmds = ['conda deactivate'] + return ' && '.join(enter_env_cmds + operation_cmd + leave_env_cmds) + + return wrapper diff --git a/sky/authentication.py b/sky/authentication.py index 2eb65bd9f6f..6108073494f 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -408,14 +408,26 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: secret = k8s.client.V1Secret( metadata=k8s.client.V1ObjectMeta(**secret_metadata), string_data={secret_field_name: public_key}) - if kubernetes_utils.check_secret_exists(secret_name, namespace, context): - logger.debug(f'Key {secret_name} exists in the cluster, patching it...') - kubernetes.core_api(context).patch_namespaced_secret( - secret_name, namespace, secret) - else: - logger.debug( - f'Key {secret_name} does not exist in the cluster, creating it...') - kubernetes.core_api(context).create_namespaced_secret(namespace, secret) + try: + if kubernetes_utils.check_secret_exists(secret_name, namespace, + context): + logger.debug(f'Key {secret_name} exists in the cluster, ' + 'patching it...') + kubernetes.core_api(context).patch_namespaced_secret( + secret_name, namespace, secret) + else: + logger.debug(f'Key {secret_name} does not exist in the cluster, ' + 'creating it...') + kubernetes.core_api(context).create_namespaced_secret( + namespace, secret) + except kubernetes.api_exception() as e: + if e.status == 409 and e.reason == 'AlreadyExists': + logger.debug(f'Key {secret_name} was created concurrently, ' + 'patching it...') + kubernetes.core_api(context).patch_namespaced_secret( + secret_name, namespace, secret) + else: + raise e private_key_path, _ = get_or_generate_keys() if network_mode == nodeport_mode: diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 0f55b8a7f17..bf92f442d2f 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -650,6 +650,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]): return common_utils.dump_yaml_str(new_config) +def get_expirable_clouds( + enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]: + """Returns a list of clouds that use local credentials and whose credentials can expire. + + This function checks each cloud in the provided sequence to determine if it uses local credentials + and if its credentials can expire. If both conditions are met, the cloud is added to the list of + expirable clouds. + + Args: + enabled_clouds (Sequence[clouds.Cloud]): A sequence of cloud objects to check. + + Returns: + list[clouds.Cloud]: A list of cloud objects that use local credentials and whose credentials can expire. 
+ """ + expirable_clouds = [] + local_credentials_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value + for cloud in enabled_clouds: + remote_identities = skypilot_config.get_nested( + (str(cloud).lower(), 'remote_identity'), None) + if remote_identities is None: + remote_identities = schemas.get_default_remote_identity( + str(cloud).lower()) + + local_credential_expiring = cloud.can_credential_expire() + if isinstance(remote_identities, str): + if remote_identities == local_credentials_value and local_credential_expiring: + expirable_clouds.append(cloud) + elif isinstance(remote_identities, list): + for profile in remote_identities: + if list(profile.values( + ))[0] == local_credentials_value and local_credential_expiring: + expirable_clouds.append(cloud) + break + return expirable_clouds + + # TODO: too many things happening here - leaky abstraction. Refactor. @timeline.event def write_cluster_config( @@ -926,6 +962,13 @@ def write_cluster_config( tmp_yaml_path, cluster_config_overrides=to_provision.cluster_config_overrides) kubernetes_utils.combine_metadata_fields(tmp_yaml_path) + yaml_obj = common_utils.read_yaml(tmp_yaml_path) + pod_config = yaml_obj['available_node_types']['ray_head_default'][ + 'node_config'] + valid, message = kubernetes_utils.check_pod_config(pod_config) + if not valid: + raise exceptions.InvalidCloudConfigs( + f'Invalid pod_config. Details: {message}') if dryrun: # If dryrun, return the unfinished tmp yaml path. @@ -1000,6 +1043,7 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str): clouds.Cudo, clouds.Paperspace, clouds.Azure, + clouds.DO, )): config = auth.configure_ssh_info(config) elif isinstance(cloud, clouds.GCP): @@ -1019,10 +1063,6 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str): common_utils.dump_yaml(cluster_config_file, config) -def get_run_timestamp() -> str: - return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') - - def get_timestamp_from_run_timestamp(run_timestamp: str) -> float: return datetime.strptime( run_timestamp.partition('-')[2], '%Y-%m-%d-%H-%M-%S-%f').timestamp() diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 8974a0129bd..e1ca81b9e0a 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -26,6 +26,7 @@ import sky from sky import backends +from sky import check as sky_check from sky import cloud_stores from sky import clouds from sky import exceptions @@ -178,6 +179,7 @@ def _get_cluster_config_template(cloud): clouds.SCP: 'scp-ray.yml.j2', clouds.OCI: 'oci-ray.yml.j2', clouds.Paperspace: 'paperspace-ray.yml.j2', + clouds.DO: 'do-ray.yml.j2', clouds.RunPod: 'runpod-ray.yml.j2', clouds.Kubernetes: 'kubernetes-ray.yml.j2', clouds.Vsphere: 'vsphere-ray.yml.j2', @@ -1995,6 +1997,22 @@ def provision_with_retries( skip_unnecessary_provisioning else None) failover_history: List[Exception] = list() + # If the user is using local credentials which may expire, the + # controller may leak resources if the credentials expire while a job + # is running. Here we check the enabled clouds and expiring credentials + # and raise a warning to the user. + if task.is_controller_task(): + enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh() + expirable_clouds = backend_utils.get_expirable_clouds( + enabled_clouds) + + if len(expirable_clouds) > 0: + warnings = (f'\033[93mWarning: Credentials used for ' + f'{expirable_clouds} may expire. 
Clusters may be ' + f'leaked if the credentials expire while jobs ' + f'are running. It is recommended to use credentials' + f' that never expire or a service account.\033[0m') + logger.warning(warnings) # Retrying launchable resources. while True: @@ -2599,7 +2617,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']): ResourceHandle = CloudVmRayResourceHandle # pylint: disable=invalid-name def __init__(self): - self.run_timestamp = backend_utils.get_run_timestamp() + self.run_timestamp = sky_logging.get_run_timestamp() # NOTE: do not expanduser() here, as this '~/...' path is used for # remote as well to be expanded on the remote side. self.log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, @@ -2626,7 +2644,7 @@ def register_info(self, **kwargs) -> None: self._optimize_target) or optimizer.OptimizeTarget.COST self._requested_features = kwargs.pop('requested_features', self._requested_features) - assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}' + assert not kwargs, f'Unexpected kwargs: {kwargs}' def check_resources_fit_cluster( self, @@ -3309,7 +3327,7 @@ def error_message() -> str: # even if some of them raise exceptions. We should replace it with # multi-process. rich_utils.stop_safe_status() - subprocess_utils.run_in_parallel(_setup_node, range(num_nodes)) + subprocess_utils.run_in_parallel(_setup_node, list(range(num_nodes))) if detach_setup: # Only set this when setup needs to be run outside the self._setup() @@ -3873,6 +3891,152 @@ def tail_managed_job_logs(self, stdin=subprocess.DEVNULL, ) + def sync_down_managed_job_logs( + self, + handle: CloudVmRayResourceHandle, + job_id: Optional[int] = None, + job_name: Optional[str] = None, + controller: bool = False, + local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]: + """Sync down logs for a managed job. + + Args: + handle: The handle to the cluster. + job_id: The job ID to sync down logs for. + job_name: The job name to sync down logs for. + controller: Whether to sync down logs for the controller. + local_dir: The local directory to sync down logs to. + + Returns: + A dictionary mapping job_id to log path. + """ + # if job_name is not None, job_id should be None + assert job_name is None or job_id is None, (job_name, job_id) + if job_id is None and job_name is not None: + # generate code to get the job_id + code = managed_jobs.ManagedJobCodeGen.get_all_job_ids_by_name( + job_name=job_name) + returncode, run_timestamps, stderr = self.run_on_head( + handle, + code, + stream_logs=False, + require_outputs=True, + separate_stderr=True) + subprocess_utils.handle_returncode(returncode, code, + 'Failed to sync down logs.', + stderr) + job_ids = common_utils.decode_payload(run_timestamps) + if not job_ids: + logger.info(f'{colorama.Fore.YELLOW}' + 'No matching job found' + f'{colorama.Style.RESET_ALL}') + return {} + elif len(job_ids) > 1: + logger.info( + f'{colorama.Fore.YELLOW}' + f'Multiple jobs IDs found under the name {job_name}. ' + 'Downloading the latest job logs.' 
+ f'{colorama.Style.RESET_ALL}') + job_ids = [job_ids[0]] # descending order + else: + job_ids = [job_id] + + # get the run_timestamp + # the function takes in [job_id] + code = job_lib.JobLibCodeGen.get_run_timestamp_with_globbing(job_ids) + returncode, run_timestamps, stderr = self.run_on_head( + handle, + code, + stream_logs=False, + require_outputs=True, + separate_stderr=True) + subprocess_utils.handle_returncode(returncode, code, + 'Failed to sync logs.', stderr) + # returns with a dict of {job_id: run_timestamp} + run_timestamps = common_utils.decode_payload(run_timestamps) + if not run_timestamps: + logger.info(f'{colorama.Fore.YELLOW}' + 'No matching log directories found' + f'{colorama.Style.RESET_ALL}') + return {} + + run_timestamp = list(run_timestamps.values())[0] + job_id = list(run_timestamps.keys())[0] + local_log_dir = '' + if controller: # download controller logs + remote_log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, + run_timestamp) + local_log_dir = os.path.expanduser( + os.path.join(local_dir, run_timestamp)) + + logger.info(f'{colorama.Fore.CYAN}' + f'Job {job_ids} local logs: {local_log_dir}' + f'{colorama.Style.RESET_ALL}') + + runners = handle.get_command_runners() + + def _rsync_down(args) -> None: + """Rsync down logs from remote nodes. + + Args: + args: A tuple of (runner, local_log_dir, remote_log_dir) + """ + (runner, local_log_dir, remote_log_dir) = args + try: + os.makedirs(local_log_dir, exist_ok=True) + runner.rsync( + source=f'{remote_log_dir}/', + target=local_log_dir, + up=False, + stream_logs=False, + ) + except exceptions.CommandError as e: + if e.returncode == exceptions.RSYNC_FILE_NOT_FOUND_CODE: + # Raised by rsync_down. Remote log dir may not exist + # since the job can be run on some part of the nodes. + logger.debug( + f'{runner.node_id} does not have the tasks/*.') + else: + raise + + parallel_args = [[runner, *item] + for item in zip([local_log_dir], [remote_log_dir]) + for runner in runners] + subprocess_utils.run_in_parallel(_rsync_down, parallel_args) + else: # download job logs + local_log_dir = os.path.expanduser( + os.path.join(local_dir, 'managed_jobs', run_timestamp)) + os.makedirs(os.path.dirname(local_log_dir), exist_ok=True) + log_file = os.path.join(local_log_dir, 'run.log') + + code = managed_jobs.ManagedJobCodeGen.stream_logs(job_name=None, + job_id=job_id, + follow=False, + controller=False) + + # With the stdin=subprocess.DEVNULL, the ctrl-c will not + # kill the process, so we need to handle it manually here. 
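# A minimal, self-contained sketch of the pattern used in the next few lines:
# CPython only allows signal handlers to be installed from the main thread
# (signal.signal raises ValueError elsewhere), hence the main-thread guard.
# The handler below is an illustrative placeholder, not SkyPilot's real one.
import signal
import threading


def _placeholder_interrupt_handler(signum, frame):
    del signum, frame  # unused
    raise KeyboardInterrupt


def install_interrupt_handler() -> None:
    # Only install the handler when running on the main thread.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, _placeholder_interrupt_handler)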
+ if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGINT, backend_utils.interrupt_handler) + signal.signal(signal.SIGTSTP, backend_utils.stop_handler) + + # We redirect the output to the log file + # and disable the STDOUT and STDERR + self.run_on_head( + handle, + code, + log_path=log_file, + stream_logs=False, + process_stream=False, + ssh_mode=command_runner.SshMode.INTERACTIVE, + stdin=subprocess.DEVNULL, + ) + + logger.info(f'{colorama.Fore.CYAN}' + f'Job {job_id} logs: {local_log_dir}' + f'{colorama.Style.RESET_ALL}') + return {str(job_id): local_log_dir} + def tail_serve_logs(self, handle: CloudVmRayResourceHandle, service_name: str, target: serve_lib.ServiceComponent, replica_id: Optional[int], follow: bool) -> None: @@ -4198,11 +4362,20 @@ def post_teardown_cleanup(self, attempts = 0 while True: logger.debug(f'instance statuses attempt {attempts + 1}') - node_status_dict = provision_lib.query_instances( - repr(cloud), - cluster_name_on_cloud, - config['provider'], - non_terminated_only=False) + try: + node_status_dict = provision_lib.query_instances( + repr(cloud), + cluster_name_on_cloud, + config['provider'], + non_terminated_only=False) + except Exception as e: # pylint: disable=broad-except + if purge: + logger.warning( + f'Failed to query instances. Skipping since purge is ' + f'set. Details: ' + f'{common_utils.format_exception(e, use_bracket=True)}') + break + raise unexpected_node_state: Optional[Tuple[str, str]] = None for node_id, node_status in node_status_dict.items(): @@ -4221,8 +4394,13 @@ def post_teardown_cleanup(self, time.sleep(_TEARDOWN_WAIT_BETWEEN_ATTEMPS_SECONDS) else: (node_id, node_status) = unexpected_node_state - raise RuntimeError(f'Instance {node_id} in unexpected state ' - f'{node_status}.') + if purge: + logger.warning(f'Instance {node_id} in unexpected ' + f'state {node_status}. Skipping since purge ' + 'is set.') + break + raise RuntimeError(f'Instance {node_id} in unexpected ' + f'state {node_status}.') global_user_state.remove_cluster(handle.cluster_name, terminate=terminate) diff --git a/sky/backends/wheel_utils.py b/sky/backends/wheel_utils.py index ed580569e0b..805117ee2a3 100644 --- a/sky/backends/wheel_utils.py +++ b/sky/backends/wheel_utils.py @@ -153,7 +153,10 @@ def _get_latest_modification_time(path: pathlib.Path) -> float: if not path.exists(): return -1. try: - return max(os.path.getmtime(root) for root, _, _ in os.walk(path)) + return max( + os.path.getmtime(os.path.join(root, f)) + for root, dirs, files in os.walk(path) + for f in (*dirs, *files)) except ValueError: return -1. diff --git a/sky/benchmark/benchmark_utils.py b/sky/benchmark/benchmark_utils.py index c9c17f00944..766b1fa9138 100644 --- a/sky/benchmark/benchmark_utils.py +++ b/sky/benchmark/benchmark_utils.py @@ -535,7 +535,7 @@ def launch_benchmark_clusters(benchmark: str, clusters: List[str], for yaml_fd, cluster in zip(yaml_fds, clusters)] # Save stdout/stderr from cluster launches. 
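# A hedged sketch of the purge semantics introduced in the
# post_teardown_cleanup hunk above: when --purge is set, failures while
# querying instance state are logged and cleanup continues; otherwise the
# error propagates. `query_fn` is a stand-in for provision_lib.query_instances.
from typing import Callable, Dict


def best_effort_query(query_fn: Callable[[], Dict[str, str]],
                      purge: bool) -> Dict[str, str]:
    try:
        return query_fn()
    except Exception as exc:  # pylint: disable=broad-except
        if not purge:
            raise
        print('Failed to query instances; continuing because --purge is set. '
              f'Details: {exc}')
        return {}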
- run_timestamp = backend_utils.get_run_timestamp() + run_timestamp = sky_logging.get_run_timestamp() log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp) log_dir = os.path.expanduser(log_dir) logger.info( diff --git a/sky/check.py b/sky/check.py index ee5ea77234b..1ab92cb1af6 100644 --- a/sky/check.py +++ b/sky/check.py @@ -127,7 +127,7 @@ def get_all_clouds(): '\nNote: The following clouds were disabled because they were not ' 'included in allowed_clouds in ~/.sky/config.yaml: ' f'{", ".join([c for c in disallowed_cloud_names])}') - if len(all_enabled_clouds) == 0: + if not all_enabled_clouds: echo( click.style( 'No cloud is enabled. SkyPilot will not be able to run any ' diff --git a/sky/cli.py b/sky/cli.py index 12f77e9f6c9..27948f9ec85 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -114,7 +114,7 @@ def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]: glob_clusters = [] for cluster in clusters: glob_cluster = global_user_state.get_glob_cluster_names(cluster) - if len(glob_cluster) == 0 and not silent: + if not glob_cluster and not silent: click.echo(f'Cluster {cluster} not found.') glob_clusters.extend(glob_cluster) return list(set(glob_clusters)) @@ -125,7 +125,7 @@ def _get_glob_storages(storages: List[str]) -> List[str]: glob_storages = [] for storage_object in storages: glob_storage = global_user_state.get_glob_storage_name(storage_object) - if len(glob_storage) == 0: + if not glob_storage: click.echo(f'Storage {storage_object} not found.') glob_storages.extend(glob_storage) return list(set(glob_storages)) @@ -830,7 +830,7 @@ class _NaturalOrderGroup(click.Group): Reference: https://github.com/pallets/click/issues/513 """ - def list_commands(self, ctx): + def list_commands(self, ctx): # pylint: disable=unused-argument return self.commands.keys() @usage_lib.entrypoint('sky.cli', fallback=True) @@ -998,8 +998,10 @@ def cli(): @click.option('--docker', 'backend_name', flag_value=backends.LocalDockerBackend.NAME, - default=False, - help='If used, runs locally inside a docker container.') + hidden=True, + help=('(Deprecated) Local docker support is deprecated. ' + 'To run locally, create a local Kubernetes cluster with ' + '``sky local up``.')) @_add_click_options(_TASK_OPTIONS_WITH_NAME + _EXTRA_RESOURCES_OPTIONS) @click.option( '--idle-minutes-to-autostop', @@ -1142,6 +1144,11 @@ def launch( backend: backends.Backend if backend_name == backends.LocalDockerBackend.NAME: backend = backends.LocalDockerBackend() + click.secho( + 'WARNING: LocalDockerBackend is deprecated and will be ' + 'removed in a future release. To run locally, create a local ' + 'Kubernetes cluster with `sky local up`.', + fg='yellow') elif backend_name == backends.CloudVmRayBackend.NAME: backend = backends.CloudVmRayBackend() else: @@ -1473,7 +1480,7 @@ def _get_services(service_names: Optional[List[str]], if len(service_records) != 1: plural = 's' if len(service_records) > 1 else '' service_num = (str(len(service_records)) - if len(service_records) > 0 else 'No') + if service_records else 'No') raise click.UsageError( f'{service_num} service{plural} found. Please specify ' 'an existing service to show its endpoint. 
Usage: ' @@ -1696,8 +1703,7 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, if len(clusters) != 1: with ux_utils.print_exception_no_traceback(): plural = 's' if len(clusters) > 1 else '' - cluster_num = (str(len(clusters)) - if len(clusters) > 0 else 'No') + cluster_num = (str(len(clusters)) if clusters else 'No') cause = 'a single' if len(clusters) > 1 else 'an existing' raise ValueError( _STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format( @@ -1722,9 +1728,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, with ux_utils.print_exception_no_traceback(): plural = 's' if len(cluster_records) > 1 else '' cluster_num = (str(len(cluster_records)) - if len(cluster_records) > 0 else - f'{clusters[0]!r}') - verb = 'found' if len(cluster_records) > 0 else 'not found' + if cluster_records else f'{clusters[0]!r}') + verb = 'found' if cluster_records else 'not found' cause = 'a single' if len(clusters) > 1 else 'an existing' raise ValueError( _STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format( @@ -2470,7 +2475,7 @@ def start( '(see `sky status`), or the -a/--all flag.') if all: - if len(clusters) > 0: + if clusters: click.echo('Both --all and cluster(s) specified for sky start. ' 'Letting --all take effect.') @@ -2800,7 +2805,7 @@ def _down_or_stop_clusters( option_str = '{stop,down}' operation = f'{verb} auto{option_str} on' - if len(names) > 0: + if names: controllers = [ name for name in names if controller_utils.Controllers.from_name(name) is not None @@ -2814,7 +2819,7 @@ def _down_or_stop_clusters( # Make sure the controllers are explicitly specified without other # normal clusters. if controllers: - if len(names) != 0: + if names: names_str = ', '.join(map(repr, names)) raise click.UsageError( f'{operation} controller(s) ' @@ -2867,7 +2872,7 @@ def _down_or_stop_clusters( if apply_to_all: all_clusters = global_user_state.get_clusters() - if len(names) > 0: + if names: click.echo( f'Both --all and cluster(s) specified for `sky {command}`. ' 'Letting --all take effect.') @@ -2894,7 +2899,7 @@ def _down_or_stop_clusters( click.echo('Cluster(s) not found (tip: see `sky status`).') return - if not no_confirm and len(clusters) > 0: + if not no_confirm and clusters: cluster_str = 'clusters' if len(clusters) > 1 else 'cluster' cluster_list = ', '.join(clusters) click.confirm( @@ -3003,7 +3008,7 @@ def check(clouds: Tuple[str], verbose: bool): # Check only specific clouds - AWS and GCP. sky check aws gcp """ - clouds_arg = clouds if len(clouds) > 0 else None + clouds_arg = clouds if clouds else None sky_check.check(verbose=verbose, clouds=clouds_arg) @@ -3138,7 +3143,7 @@ def _get_kubernetes_realtime_gpu_table( f'capacity ({list(capacity.keys())}), ' f'and available ({list(available.keys())}) ' 'must be same.') - if len(counts) == 0: + if not counts: err_msg = 'No GPUs found in Kubernetes cluster. ' debug_msg = 'To further debug, run: sky check ' if name_filter is not None: @@ -3282,7 +3287,7 @@ def _output(): for tpu in service_catalog.get_tpus(): if tpu in result: tpu_table.add_row([tpu, _list_to_str(result.pop(tpu))]) - if len(tpu_table.get_string()) > 0: + if tpu_table.get_string(): yield '\n\n' yield from tpu_table.get_string() @@ -3393,7 +3398,7 @@ def _output(): yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' f'Cloud GPUs{colorama.Style.RESET_ALL}\n') - if len(result) == 0: + if not result: quantity_str = (f' with requested quantity {quantity}' if quantity else '') cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.' 
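# The cli.py edits in this hunk standardize on implicit truthiness for
# emptiness checks. A tiny sketch of the idiom, equivalent for built-in
# containers (pandas DataFrames are handled separately with `.empty`, since
# their truth value is ambiguous):
def summarize(clusters: list) -> str:
    if not clusters:  # preferred over len(clusters) == 0
        return 'No clusters found.'
    plural = 's' if len(clusters) > 1 else ''
    return f'{len(clusters)} cluster{plural} found.'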
@@ -3522,7 +3527,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r # Delete all storage objects. sky storage delete -a """ - if sum([len(names) > 0, all]) != 1: + if sum([bool(names), all]) != 1: raise click.UsageError('Either --all or a name must be specified.') if all: storages = sky.storage_ls() @@ -3881,8 +3886,8 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): exit_if_not_accessible=True) job_id_str = ','.join(map(str, job_ids)) - if sum([len(job_ids) > 0, name is not None, all]) != 1: - argument_str = f'--job-ids {job_id_str}' if len(job_ids) > 0 else '' + if sum([bool(job_ids), name is not None, all]) != 1: + argument_str = f'--job-ids {job_id_str}' if job_ids else '' argument_str += f' --name {name}' if name is not None else '' argument_str += ' --all' if all else '' raise click.UsageError( @@ -3928,17 +3933,29 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): required=False, help='Query the latest job logs, restarting the jobs controller if stopped.' ) +@click.option('--sync-down', + '-s', + default=False, + is_flag=True, + required=False, + help='Download logs for all jobs shown in the queue.') @click.argument('job_id', required=False, type=int) @usage_lib.entrypoint def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool, refresh: bool): - """Tail the log of a managed job.""" + controller: bool, refresh: bool, sync_down: bool): + """Tail or sync down the log of a managed job.""" try: - managed_jobs.tail_logs(name=name, - job_id=job_id, - follow=follow, - controller=controller, - refresh=refresh) + if sync_down: + managed_jobs.sync_down_logs(name=name, + job_id=job_id, + controller=controller, + refresh=refresh) + else: + managed_jobs.tail_logs(name=name, + job_id=job_id, + follow=follow, + controller=controller, + refresh=refresh) except exceptions.ClusterNotUpError: with ux_utils.print_exception_no_traceback(): raise @@ -4523,9 +4540,9 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool, # Forcefully tear down a specific replica, even in failed status. sky serve down my-service --replica-id 1 --purge """ - if sum([len(service_names) > 0, all]) != 1: - argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len( - service_names) > 0 else '' + if sum([bool(service_names), all]) != 1: + argument_str = (f'SERVICE_NAMES={",".join(service_names)}' + if service_names else '') argument_str += ' --all' if all else '' raise click.UsageError( 'Can only specify one of SERVICE_NAMES or --all. ' @@ -4898,7 +4915,7 @@ def benchmark_launch( if idle_minutes_to_autostop is None: idle_minutes_to_autostop = 5 commandline_args['idle-minutes-to-autostop'] = idle_minutes_to_autostop - if len(env) > 0: + if env: commandline_args['env'] = [f'{k}={v}' for k, v in env] # Launch the benchmarking clusters in detach mode in parallel. 
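# Illustrative usage of the --sync-down/-s flag added to `sky jobs logs`
# above. The command group name and the option names other than --sync-down
# are inferred from the handler signature, and the job ID/name are
# placeholders:
#
#   sky jobs logs --sync-down 42            # download logs of managed job 42
#   sky jobs logs --sync-down --name train  # download logs by job name
#
# Internally this path calls managed_jobs.sync_down_logs(...) instead of
# managed_jobs.tail_logs(...), as shown in the jobs_logs handler above.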
@@ -5177,7 +5194,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool], raise click.BadParameter( 'Either specify benchmarks or use --all to delete all benchmarks.') to_delete = [] - if len(benchmarks) > 0: + if benchmarks: for benchmark in benchmarks: record = benchmark_state.get_benchmark_from_name(benchmark) if record is None: @@ -5186,7 +5203,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool], to_delete.append(record) if all: to_delete = benchmark_state.get_benchmarks() - if len(benchmarks) > 0: + if benchmarks: print('Both --all and benchmark(s) specified ' 'for sky bench delete. Letting --all take effect.') @@ -5288,7 +5305,7 @@ def _deploy_local_cluster(gpus: bool): run_command = shlex.split(run_command) # Setup logging paths - run_timestamp = backend_utils.get_run_timestamp() + run_timestamp = sky_logging.get_run_timestamp() log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp, 'local_up.log') tail_cmd = 'tail -n100 -f ' + log_path @@ -5402,7 +5419,7 @@ def _deploy_remote_cluster(ip_file: str, ssh_user: str, ssh_key_path: str, deploy_command = shlex.split(deploy_command) # Setup logging paths - run_timestamp = backend_utils.get_run_timestamp() + run_timestamp = sky_logging.get_run_timestamp() log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp, 'local_up.log') tail_cmd = 'tail -n100 -f ' + log_path @@ -5517,7 +5534,7 @@ def local_down(): run_command = shlex.split(down_script_path) # Setup logging paths - run_timestamp = backend_utils.get_run_timestamp() + run_timestamp = sky_logging.get_run_timestamp() log_path = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp, 'local_down.log') tail_cmd = 'tail -n100 -f ' + log_path diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py index ee1b051d32b..e24c4f3ad03 100644 --- a/sky/cloud_stores.py +++ b/sky/cloud_stores.py @@ -7,6 +7,7 @@ * Better interface. * Better implementation (e.g., fsspec, smart_open, using each cloud's SDK). """ +import os import shlex import subprocess import time @@ -18,6 +19,7 @@ from sky.adaptors import azure from sky.adaptors import cloudflare from sky.adaptors import ibm +from sky.adaptors import oci from sky.clouds import gcp from sky.data import data_utils from sky.data.data_utils import Rclone @@ -111,8 +113,16 @@ class GcsCloudStorage(CloudStorage): @property def _gsutil_command(self): gsutil_alias, alias_gen = data_utils.get_gsutil_command() - return (f'{alias_gen}; GOOGLE_APPLICATION_CREDENTIALS=' - f'{gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH} {gsutil_alias}') + return ( + f'{alias_gen}; GOOGLE_APPLICATION_CREDENTIALS=' + f'{gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH}; ' + # Explicitly activate service account. Unlike the gcp packages + # and other GCP commands, gsutil does not automatically pick up + # the default credential keys when it is a service account. + 'gcloud auth activate-service-account ' + '--key-file=$GOOGLE_APPLICATION_CREDENTIALS ' + '2> /dev/null || true; ' + f'{gsutil_alias}') def is_directory(self, url: str) -> bool: """Returns whether 'url' is a directory. @@ -133,7 +143,7 @@ def is_directory(self, url: str) -> bool: # If is a bucket root, then we only need `gsutil` to succeed # to make sure the bucket exists. It is already a directory. 
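# A rough sketch of the command string the GcsCloudStorage._gsutil_command
# property above composes; the alias and key path below are placeholders.
# The activation step is best-effort (stderr suppressed, `|| true`) so that
# setups without a service-account key keep working.
ALIAS_GEN = 'alias gsutil="gsutil -m"'                  # placeholder
GSUTIL_ALIAS = 'gsutil'                                 # placeholder
KEY_PATH = '~/.config/gcloud/service-account-key.json'  # placeholder

GSUTIL_COMMAND = (
    f'{ALIAS_GEN}; GOOGLE_APPLICATION_CREDENTIALS={KEY_PATH}; '
    'gcloud auth activate-service-account '
    '--key-file=$GOOGLE_APPLICATION_CREDENTIALS 2> /dev/null || true; '
    f'{GSUTIL_ALIAS}')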
_, key = data_utils.split_gcs_path(url) - if len(key) == 0: + if not key: return True # Otherwise, gsutil ls -d url will return: # --> url.rstrip('/') if url is not a directory @@ -470,6 +480,64 @@ def make_sync_file_command(self, source: str, destination: str) -> str: return self.make_sync_dir_command(source, destination) +class OciCloudStorage(CloudStorage): + """OCI Cloud Storage.""" + + def is_directory(self, url: str) -> bool: + """Returns whether OCI 'url' is a directory. + In cloud object stores, a "directory" refers to a regular object whose + name is a prefix of other objects. + """ + bucket_name, path = data_utils.split_oci_path(url) + + client = oci.get_object_storage_client() + namespace = client.get_namespace( + compartment_id=oci.get_oci_config()['tenancy']).data + + objects = client.list_objects(namespace_name=namespace, + bucket_name=bucket_name, + prefix=path).data.objects + + if len(objects) == 0: + # A directory with few or no items + return True + + if len(objects) > 1: + # A directory with more than 1 items + return True + + object_name = objects[0].name + if path.endswith(object_name): + # An object path + return False + + # A directory with only 1 item + return True + + @oci.with_oci_env + def make_sync_dir_command(self, source: str, destination: str) -> str: + """Downloads using OCI CLI.""" + bucket_name, path = data_utils.split_oci_path(source) + + download_via_ocicli = (f'oci os object sync --no-follow-symlinks ' + f'--bucket-name {bucket_name} ' + f'--prefix "{path}" --dest-dir "{destination}"') + + return download_via_ocicli + + @oci.with_oci_env + def make_sync_file_command(self, source: str, destination: str) -> str: + """Downloads a file using OCI CLI.""" + bucket_name, path = data_utils.split_oci_path(source) + filename = os.path.basename(path) + destination = os.path.join(destination, filename) + + download_via_ocicli = (f'oci os object get --bucket-name {bucket_name} ' + f'--name "{path}" --file "{destination}"') + + return download_via_ocicli + + def get_storage_from_path(url: str) -> CloudStorage: """Returns a CloudStorage by identifying the scheme:// in a URL.""" result = urllib.parse.urlsplit(url) @@ -485,6 +553,7 @@ def get_storage_from_path(url: str) -> CloudStorage: 's3': S3CloudStorage(), 'r2': R2CloudStorage(), 'cos': IBMCosCloudStorage(), + 'oci': OciCloudStorage(), # TODO: This is a hack, as Azure URL starts with https://, we should # refactor the registry to be able to take regex, so that Azure blob can # be identified with `https://(.*?)\.blob\.core\.windows\.net` diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index c4d46e93adf..24b805fe8bc 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -15,6 +15,7 @@ from sky.clouds.aws import AWS from sky.clouds.azure import Azure from sky.clouds.cudo import Cudo +from sky.clouds.do import DO from sky.clouds.fluidstack import Fluidstack from sky.clouds.gcp import GCP from sky.clouds.ibm import IBM @@ -34,6 +35,7 @@ 'Cudo', 'GCP', 'Lambda', + 'DO', 'Paperspace', 'SCP', 'RunPod', diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index cafc789c5be..b37992e97c3 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -2,6 +2,8 @@ import enum import fnmatch import functools +import hashlib +import json import os import re import subprocess @@ -16,6 +18,7 @@ from sky import skypilot_config from sky.adaptors import aws from sky.clouds import service_catalog +from sky.clouds.service_catalog import common as catalog_common from sky.clouds.utils import aws_utils from sky.skylet 
import constants from sky.utils import common_utils @@ -92,6 +95,10 @@ class AWSIdentityType(enum.Enum): CONTAINER_ROLE = 'container-role' + CUSTOM_PROCESS = 'custom-process' + + ASSUME_ROLE = 'assume-role' + # Name Value Type Location # ---- ----- ---- -------- # profile None None @@ -100,6 +107,24 @@ class AWSIdentityType(enum.Enum): # region us-east-1 config-file ~/.aws/config SHARED_CREDENTIALS_FILE = 'shared-credentials-file' + def can_credential_expire(self) -> bool: + """Check if the AWS identity type can expire. + + SSO,IAM_ROLE and CONTAINER_ROLE are temporary credentials and refreshed + automatically. ENV and SHARED_CREDENTIALS_FILE are short-lived + credentials without refresh. + IAM ROLE: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + SSO/Container-role refresh token: + https://docs.aws.amazon.com/solutions/latest/dea-api/auth-refreshtoken.html + """ + # TODO(hong): Add a CLI based check for the expiration of the temporary + # credentials + expirable_types = { + AWSIdentityType.ENV, AWSIdentityType.SHARED_CREDENTIALS_FILE + } + return self in expirable_types + @clouds.CLOUD_REGISTRY.register class AWS(clouds.Cloud): @@ -593,10 +618,27 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: hints = f'AWS IAM role is set.{single_cloud_hint}' elif identity_type == AWSIdentityType.CONTAINER_ROLE: # Similar to the IAM ROLE, an ECS container may not store credentials - # in the~/.aws/credentials file. So we don't check for the existence of + # in the ~/.aws/credentials file. So we don't check for the existence of # the file. i.e. the container will be assigned the IAM role of the # task: skypilot-v1. hints = f'AWS container-role is set.{single_cloud_hint}' + elif identity_type == AWSIdentityType.CUSTOM_PROCESS: + # Similar to the IAM ROLE, a custom process may not store credentials + # in the ~/.aws/credentials file. So we don't check for the existence of + # the file. i.e. the custom process will be assigned the IAM role of the + # task: skypilot-v1. + hints = f'AWS custom-process is set.{single_cloud_hint}' + elif identity_type == AWSIdentityType.ASSUME_ROLE: + # When using ASSUME ROLE, the credentials are coming from a different + # source profile. So we don't check for the existence of ~/.aws/credentials. + # i.e. the assumed role will be assigned the IAM role of the + # task: skypilot-v1. + hints = f'AWS assume-role is set.{single_cloud_hint}' + elif identity_type == AWSIdentityType.ENV: + # When using ENV vars, the credentials are coming from the environment + # variables. So we don't check for the existence of ~/.aws/credentials. + # i.e. the identity is not determined by the file. + hints = f'AWS env is set.{single_cloud_hint}' else: # This file is required because it is required by the VMs launched on # other clouds to access private s3 buckets and resources like EC2. @@ -624,14 +666,10 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: @classmethod def _current_identity_type(cls) -> Optional[AWSIdentityType]: - proc = subprocess.run('aws configure list', - shell=True, - check=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if proc.returncode != 0: + stdout = cls._aws_configure_list() + if stdout is None: return None - stdout = proc.stdout.decode() + output = stdout.decode() # We determine the identity type by looking at the output of # `aws configure list`. 
The output looks like: @@ -646,55 +684,32 @@ def _current_identity_type(cls) -> Optional[AWSIdentityType]: def _is_access_key_of_type(type_str: str) -> bool: # The dot (.) does not match line separators. - results = re.findall(fr'access_key.*{type_str}', stdout) + results = re.findall(fr'access_key.*{type_str}', output) if len(results) > 1: raise RuntimeError( - f'Unexpected `aws configure list` output:\n{stdout}') + f'Unexpected `aws configure list` output:\n{output}') return len(results) == 1 - if _is_access_key_of_type(AWSIdentityType.SSO.value): - return AWSIdentityType.SSO - elif _is_access_key_of_type(AWSIdentityType.IAM_ROLE.value): - return AWSIdentityType.IAM_ROLE - elif _is_access_key_of_type(AWSIdentityType.CONTAINER_ROLE.value): - return AWSIdentityType.CONTAINER_ROLE - elif _is_access_key_of_type(AWSIdentityType.ENV.value): - return AWSIdentityType.ENV - else: - return AWSIdentityType.SHARED_CREDENTIALS_FILE + for identity_type in AWSIdentityType: + if _is_access_key_of_type(identity_type.value): + return identity_type + return AWSIdentityType.SHARED_CREDENTIALS_FILE @classmethod - @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. - def get_user_identities(cls) -> Optional[List[List[str]]]: - """Returns a [UserId, Account] list that uniquely identifies the user. - - These fields come from `aws sts get-caller-identity`. We permit the same - actual user to: - - - switch between different root accounts (after which both elements - of the list will be different) and have their clusters owned by - each account be protected; or - - - within the same root account, switch between different IAM - users, and treat [user_id=1234, account=A] and - [user_id=4567, account=A] to be the *same*. Namely, switching - between these IAM roles within the same root account will cause - the first element of the returned list to differ, and will allow - the same actual user to continue to interact with their clusters. - Note: this is not 100% safe, since the IAM users can have very - specific permissions, that disallow them to access the clusters - but it is a reasonable compromise as that could be rare. - - Returns: - A list of strings that uniquely identifies the user on this cloud. - For identity check, we will fallback through the list of strings - until we find a match, and print a warning if we fail for the - first string. + @functools.lru_cache(maxsize=1) + def _aws_configure_list(cls) -> Optional[bytes]: + proc = subprocess.run('aws configure list', + shell=True, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + if proc.returncode != 0: + return None + return proc.stdout - Raises: - exceptions.CloudUserIdentityError: if the user identity cannot be - retrieved. - """ + @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. + def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]: try: sts = aws.client('sts') # The caller identity contains 3 fields: UserId, Account, Arn. @@ -773,6 +788,72 @@ def get_user_identities(cls) -> Optional[List[List[str]]]: # automatic switching for AWS. Currently we only support one identity. return [user_ids] + @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. + def get_user_identities(cls) -> Optional[List[List[str]]]: + """Returns a [UserId, Account] list that uniquely identifies the user. + + These fields come from `aws sts get-caller-identity` and are cached + locally by `aws configure list` output. 
The identities are assumed to
+        be stable for the duration of the `sky` process. Modifying the
+        credentials while the `sky` process is running will not affect the
+        identity returned by this function.
+
+        We permit the same actual user to:
+
+        - switch between different root accounts (after which both elements
+          of the list will be different) and have their clusters owned by
+          each account be protected; or
+
+        - within the same root account, switch between different IAM
+          users, and treat [user_id=1234, account=A] and
+          [user_id=4567, account=A] to be the *same*. Namely, switching
+          between these IAM roles within the same root account will cause
+          the first element of the returned list to differ, and will allow
+          the same actual user to continue to interact with their clusters.
+          Note: this is not 100% safe, since the IAM users can have very
+          specific permissions that disallow them from accessing the
+          clusters, but it is a reasonable compromise as that could be rare.
+
+        Returns:
+            A list of strings that uniquely identifies the user on this cloud.
+            For identity check, we will fall back through the list of strings
+            until we find a match, and print a warning if we fail for the
+            first string.
+
+        Raises:
+            exceptions.CloudUserIdentityError: if the user identity cannot be
+                retrieved.
+        """
+        stdout = cls._aws_configure_list()
+        if stdout is None:
+            # `aws configure list` is not available, possible reasons:
+            # - awscli is not installed but credentials are valid, e.g. run from
+            #   an EC2 instance with IAM role
+            # - aws credentials are not set, proceed anyway to get unified error
+            #   message for users
+            return cls._sts_get_caller_identity()
+        config_hash = hashlib.md5(stdout).hexdigest()[:8]
+        # Getting the AWS identity costs ~1s, so we cache the result, keyed by
+        # the output of `aws configure list`. Different `aws configure list`
+        # outputs can map to the same AWS identity; we assume the output is
+        # stable in practice, so the number of cache files stays bounded.
+        # TODO(aylei): consider using a more stable cache key and evaluate eviction.
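# A generic sketch of the caching pattern implemented just below: key a slow
# lookup by a short hash of a cheap command's output and persist the result as
# JSON. `slow_lookup` and the cache directory are illustrative stand-ins.
import hashlib
import json
import os
from typing import Callable


def cached_lookup(cheap_output: bytes, slow_lookup: Callable[[], list],
                  cache_dir: str) -> list:
    key = hashlib.md5(cheap_output).hexdigest()[:8]
    cache_path = os.path.join(cache_dir, f'identity-{key}.json')
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except json.JSONDecodeError:
            pass  # corrupt cache entry; fall through and refresh it
    result = slow_lookup()
    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'w', encoding='utf-8') as f:
        json.dump(result, f)
    return result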
+ cache_path = catalog_common.get_catalog_path( + f'aws/.cache/user-identity-{config_hash}.txt') + if os.path.exists(cache_path): + try: + with open(cache_path, 'r', encoding='utf-8') as f: + return json.loads(f.read()) + except json.JSONDecodeError: + # cache is invalid, ignore it and fetch identity again + pass + + result = cls._sts_get_caller_identity() + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(json.dumps(result)) + return result + @classmethod def get_active_user_identity_str(cls) -> Optional[str]: user_identity = cls.get_active_user_identity() @@ -812,6 +893,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]: if os.path.exists(os.path.expanduser(f'~/.aws/{filename}')) } + @functools.lru_cache(maxsize=1) + def can_credential_expire(self) -> bool: + identity_type = self._current_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) + def instance_type_exists(self, instance_type): return service_catalog.instance_type_exists(instance_type, clouds='aws') diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 455baeaf5d9..2cb45ca14fc 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]: """ raise NotImplementedError + def can_credential_expire(self) -> bool: + """Returns whether the cloud credential can expire.""" + return False + @classmethod def get_image_size(cls, image_id: str, region: Optional[str]) -> float: """Check the image size from the cloud. diff --git a/sky/clouds/do.py b/sky/clouds/do.py new file mode 100644 index 00000000000..3a64ead3ad0 --- /dev/null +++ b/sky/clouds/do.py @@ -0,0 +1,303 @@ +""" Digital Ocean Cloud. """ + +import json +import typing +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from sky import clouds +from sky.adaptors import do +from sky.clouds import service_catalog +from sky.provision.do import utils as do_utils +from sky.utils import resources_utils + +if typing.TYPE_CHECKING: + from sky import resources as resources_lib + +_CREDENTIAL_FILE = 'config.yaml' + + +@clouds.CLOUD_REGISTRY.register(aliases=['digitalocean']) +class DO(clouds.Cloud): + """Digital Ocean Cloud""" + + _REPR = 'DO' + _CLOUD_UNSUPPORTED_FEATURES = { + clouds.CloudImplementationFeatures.CLONE_DISK_FROM_CLUSTER: + 'Migrating ' + f'disk is not supported in {_REPR}.', + clouds.CloudImplementationFeatures.SPOT_INSTANCE: + 'Spot instances are ' + f'not supported in {_REPR}.', + clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER: + 'Custom disk tiers' + f' is not supported in {_REPR}.', + } + # DO maximum node name length defined as <= 255 + # https://docs.digitalocean.com/reference/api/api-reference/#operation/droplets_create + # 255 - 8 = 247 characters since + # our provisioner adds additional `-worker`. + _MAX_CLUSTER_NAME_LEN_LIMIT = 247 + _regions: List[clouds.Region] = [] + + # Using the latest SkyPilot provisioner API to provision and check status. + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + + @classmethod + def _unsupported_features_for_resources( + cls, resources: 'resources_lib.Resources' + ) -> Dict[clouds.CloudImplementationFeatures, str]: + """The features not supported based on the resources provided. + + This method is used by check_features_are_supported() to check if the + cloud implementation supports all the requested features. 
+ + Returns: + A dict of {feature: reason} for the features not supported by the + cloud implementation. + """ + del resources # unused + return cls._CLOUD_UNSUPPORTED_FEATURES + + @classmethod + def _max_cluster_name_length(cls) -> Optional[int]: + return cls._MAX_CLUSTER_NAME_LEN_LIMIT + + @classmethod + def regions_with_offering( + cls, + instance_type: str, + accelerators: Optional[Dict[str, int]], + use_spot: bool, + region: Optional[str], + zone: Optional[str], + ) -> List[clouds.Region]: + assert zone is None, 'DO does not support zones.' + del accelerators, zone # unused + if use_spot: + return [] + regions = service_catalog.get_region_zones_for_instance_type( + instance_type, use_spot, 'DO') + if region is not None: + regions = [r for r in regions if r.name == region] + return regions + + @classmethod + def get_vcpus_mem_from_instance_type( + cls, + instance_type: str, + ) -> Tuple[Optional[float], Optional[float]]: + return service_catalog.get_vcpus_mem_from_instance_type(instance_type, + clouds='DO') + + @classmethod + def zones_provision_loop( + cls, + *, + region: str, + num_nodes: int, + instance_type: str, + accelerators: Optional[Dict[str, int]] = None, + use_spot: bool = False, + ) -> Iterator[None]: + del num_nodes # unused + regions = cls.regions_with_offering(instance_type, + accelerators, + use_spot, + region=region, + zone=None) + for r in regions: + assert r.zones is None, r + yield r.zones + + def instance_type_to_hourly_cost( + self, + instance_type: str, + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None, + ) -> float: + return service_catalog.get_hourly_cost( + instance_type, + use_spot=use_spot, + region=region, + zone=zone, + clouds='DO', + ) + + def accelerators_to_hourly_cost( + self, + accelerators: Dict[str, int], + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None, + ) -> float: + """Returns the hourly cost of the accelerators, in dollars/hour.""" + # the acc price is include in the instance price. 
+ del accelerators, use_spot, region, zone # unused + return 0.0 + + def get_egress_cost(self, num_gigabytes: float) -> float: + return 0.0 + + def __repr__(self): + return self._REPR + + @classmethod + def get_default_instance_type( + cls, + cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[resources_utils.DiskTier] = None, + ) -> Optional[str]: + """Returns the default instance type for DO.""" + return service_catalog.get_default_instance_type(cpus=cpus, + memory=memory, + disk_tier=disk_tier, + clouds='DO') + + @classmethod + def get_accelerators_from_instance_type( + cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]: + return service_catalog.get_accelerators_from_instance_type( + instance_type, clouds='DO') + + @classmethod + def get_zone_shell_cmd(cls) -> Optional[str]: + return None + + def make_deploy_resources_variables( + self, + resources: 'resources_lib.Resources', + cluster_name: resources_utils.ClusterName, + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], + num_nodes: int, + dryrun: bool = False) -> Dict[str, Optional[str]]: + del zones, dryrun, cluster_name + + r = resources + acc_dict = self.get_accelerators_from_instance_type(r.instance_type) + if acc_dict is not None: + custom_resources = json.dumps(acc_dict, separators=(',', ':')) + else: + custom_resources = None + image_id = None + if (resources.image_id is not None and + resources.extract_docker_image() is None): + if None in resources.image_id: + image_id = resources.image_id[None] + else: + assert region.name in resources.image_id + image_id = resources.image_id[region.name] + return { + 'instance_type': resources.instance_type, + 'custom_resources': custom_resources, + 'region': region.name, + **({ + 'image_id': image_id + } if image_id else {}) + } + + def _get_feasible_launchable_resources( + self, resources: 'resources_lib.Resources' + ) -> resources_utils.FeasibleResources: + """Returns a list of feasible resources for the given resources.""" + if resources.use_spot: + # TODO: Add hints to all return values in this method to help + # users understand why the resources are not launchable. + return resources_utils.FeasibleResources([], [], None) + if resources.instance_type is not None: + assert resources.is_launchable(), resources + resources = resources.copy(accelerators=None) + return resources_utils.FeasibleResources([resources], [], None) + + def _make(instance_list): + resource_list = [] + for instance_type in instance_list: + r = resources.copy( + cloud=DO(), + instance_type=instance_type, + accelerators=None, + cpus=None, + ) + resource_list.append(r) + return resource_list + + # Currently, handle a filter on accelerators only. 
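# A schematic, cloud-agnostic sketch of the flow the code below implements:
# with no accelerator request, fall back to the cloud's default instance type;
# otherwise ask the catalog for instance types carrying that accelerator and
# pass fuzzy candidates through for error messages. `catalog_lookup` is a
# stand-in for service_catalog.get_instance_type_for_accelerator.
from typing import Callable, Dict, List, Optional, Tuple


def pick_instance_types(
    accelerators: Optional[Dict[str, int]],
    default_instance_type: str,
    catalog_lookup: Callable[[str, int], Tuple[Optional[List[str]],
                                               List[str]]],
) -> Tuple[List[str], List[str]]:
    if accelerators is None:
        return [default_instance_type], []
    assert len(accelerators) == 1, accelerators
    acc, acc_count = next(iter(accelerators.items()))
    instance_types, fuzzy_candidates = catalog_lookup(acc, acc_count)
    return (instance_types or []), fuzzy_candidates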
+ accelerators = resources.accelerators + if accelerators is None: + # Return a default instance type + default_instance_type = DO.get_default_instance_type( + cpus=resources.cpus, + memory=resources.memory, + disk_tier=resources.disk_tier) + return resources_utils.FeasibleResources( + _make([default_instance_type]), [], None) + + assert len(accelerators) == 1, resources + acc, acc_count = list(accelerators.items())[0] + (instance_list, fuzzy_candidate_list) = ( + service_catalog.get_instance_type_for_accelerator( + acc, + acc_count, + use_spot=resources.use_spot, + cpus=resources.cpus, + memory=resources.memory, + region=resources.region, + zone=resources.zone, + clouds='DO', + )) + if instance_list is None: + return resources_utils.FeasibleResources([], fuzzy_candidate_list, + None) + return resources_utils.FeasibleResources(_make(instance_list), + fuzzy_candidate_list, None) + + @classmethod + def check_credentials(cls) -> Tuple[bool, Optional[str]]: + """Verify that the user has valid credentials for DO.""" + try: + # attempt to make a CURL request for listing instances + do_utils.client().droplets.list() + except do.exceptions().HttpResponseError as err: + return False, str(err) + except do_utils.DigitalOceanError as err: + return False, str(err) + + return True, None + + def get_credential_file_mounts(self) -> Dict[str, str]: + try: + do_utils.client() + return { + f'~/.config/doctl/{_CREDENTIAL_FILE}': do_utils.CREDENTIALS_PATH + } + except do_utils.DigitalOceanError: + return {} + + @classmethod + def get_current_user_identity(cls) -> Optional[List[str]]: + # NOTE: used for very advanced SkyPilot functionality + # Can implement later if desired + return None + + @classmethod + def get_image_size(cls, image_id: str, region: Optional[str]) -> float: + del region + try: + response = do_utils.client().images.get(image_id=image_id) + return response['image']['size_gigabytes'] + except do.exceptions().HttpResponseError as err: + raise do_utils.DigitalOceanError( + 'HTTP error while retrieving size of ' + f'image_id {response}: {err.error.message}') from err + except KeyError as err: + raise do_utils.DigitalOceanError( + f'No image_id `{image_id}` found') from err + + def instance_type_exists(self, instance_type: str) -> bool: + return service_catalog.instance_type_exists(instance_type, 'DO') + + def validate_region_zone(self, region: Optional[str], zone: Optional[str]): + return service_catalog.validate_region_zone(region, zone, clouds='DO') diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index c0f22cc860b..3502fee8e1c 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum): SHARED_CREDENTIALS_FILE = '' + def can_credential_expire(self) -> bool: + return self == GCPIdentityType.SHARED_CREDENTIALS_FILE + @clouds.CLOUD_REGISTRY.register class GCP(clouds.Cloud): @@ -830,7 +833,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: ret_permissions = request.execute().get('permissions', []) diffs = set(gcp_minimal_permissions).difference(set(ret_permissions)) - if len(diffs) > 0: + if diffs: identity_str = identity[0] if identity else None return False, ( 'The following permissions are not enabled for the current ' @@ -863,6 +866,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]: pass return credentials + @functools.lru_cache(maxsize=1) + def can_credential_expire(self) -> bool: + identity_type = self._get_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) + @classmethod def 
_get_identity_type(cls) -> Optional[GCPIdentityType]: try: diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 65b50042aba..f9242bd77aa 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -139,7 +139,7 @@ def _existing_allowed_contexts(cls) -> List[str]: use the service account mounted in the pod. """ all_contexts = kubernetes_utils.get_all_kube_context_names() - if len(all_contexts) == 0: + if not all_contexts: return [] all_contexts = set(all_contexts) diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index d4ae6f298d2..b0234e2802c 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -232,6 +232,14 @@ def make_deploy_resources_variables( listing_id = None res_ver = None + os_type = None + if ':' in image_id: + # OS type provided in the --image-id. This is the case where + # custom image's ocid provided in the --image-id parameter. + # - ocid1.image...aaa:oraclelinux (os type is oraclelinux) + # - ocid1.image...aaa (OS not provided) + image_id, os_type = image_id.replace(' ', '').split(':') + cpus = resources.cpus instance_type_arr = resources.instance_type.split( oci_utils.oci_config.INSTANCE_TYPE_RES_SPERATOR) @@ -297,15 +305,18 @@ def make_deploy_resources_variables( cpus=None if cpus is None else float(cpus), disk_tier=resources.disk_tier) - image_str = self._get_image_str(image_id=resources.image_id, - instance_type=resources.instance_type, - region=region.name) - - # pylint: disable=import-outside-toplevel - from sky.clouds.service_catalog import oci_catalog - os_type = oci_catalog.get_image_os_from_tag(tag=image_str, - region=region.name) - logger.debug(f'OS type for the image {image_str} is {os_type}') + if os_type is None: + # OS type is not determined yet. So try to get it from vms.csv + image_str = self._get_image_str( + image_id=resources.image_id, + instance_type=resources.instance_type, + region=region.name) + + # pylint: disable=import-outside-toplevel + from sky.clouds.service_catalog import oci_catalog + os_type = oci_catalog.get_image_os_from_tag(tag=image_str, + region=region.name) + logger.debug(f'OS type for the image {image_id} is {os_type}') return { 'instance_type': instance_type, diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index d28b530ff06..3aad5a0b7fd 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -10,6 +10,7 @@ from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL from sky.utils import resources_utils +from sky.utils import subprocess_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -31,8 +32,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): if single: clouds = [clouds] # type: ignore - results = [] - for cloud in clouds: + def _execute_catalog_method(cloud: str): try: cloud_module = importlib.import_module( f'sky.clouds.service_catalog.{cloud.lower()}_catalog') @@ -46,7 +46,11 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): raise AttributeError( f'Module "{cloud}_catalog" does not ' f'implement the "{method_name}" method') from None - results.append(method(*args, **kwargs)) + return method(*args, **kwargs) + + results = subprocess_utils.run_in_parallel(_execute_catalog_method, + args=list(clouds), + num_threads=len(clouds)) if single: return results[0] return results diff --git a/sky/clouds/service_catalog/common.py 
b/sky/clouds/service_catalog/common.py index 29df92d7535..0fce7c25f6a 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -270,9 +270,10 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str: candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9) candidate_loc = sorted(candidate_loc) candidate_strs = '' - if len(candidate_loc) > 0: + if candidate_loc: candidate_strs = ', '.join(candidate_loc) candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?' + return candidate_strs def _get_all_supported_regions_str() -> str: @@ -286,7 +287,7 @@ def _get_all_supported_regions_str() -> str: filter_df = df if region is not None: filter_df = _filter_region_zone(filter_df, region, zone=None) - if len(filter_df) == 0: + if filter_df.empty: with ux_utils.print_exception_no_traceback(): error_msg = (f'Invalid region {region!r}') candidate_strs = _get_candidate_str( @@ -310,7 +311,7 @@ def _get_all_supported_regions_str() -> str: if zone is not None: maybe_region_df = filter_df filter_df = filter_df[filter_df['AvailabilityZone'] == zone] - if len(filter_df) == 0: + if filter_df.empty: region_str = f' for region {region!r}' if region else '' df = maybe_region_df if region else df with ux_utils.print_exception_no_traceback(): @@ -378,7 +379,7 @@ def get_vcpus_mem_from_instance_type_impl( instance_type: str, ) -> Tuple[Optional[float], Optional[float]]: df = _get_instance_type(df, instance_type, None) - if len(df) == 0: + if df.empty: with ux_utils.print_exception_no_traceback(): raise ValueError(f'No instance type {instance_type} found.') assert len(set(df['vCPUs'])) == 1, ('Cannot determine the number of vCPUs ' @@ -484,7 +485,7 @@ def get_accelerators_from_instance_type_impl( instance_type: str, ) -> Optional[Dict[str, Union[int, float]]]: df = _get_instance_type(df, instance_type, None) - if len(df) == 0: + if df.empty: with ux_utils.print_exception_no_traceback(): raise ValueError(f'No instance type {instance_type} found.') row = df.iloc[0] @@ -518,7 +519,7 @@ def get_instance_type_for_accelerator_impl( result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) & (abs(df['AcceleratorCount'] - acc_count) <= 0.01)] result = _filter_region_zone(result, region, zone) - if len(result) == 0: + if result.empty: fuzzy_result = df[ (df['AcceleratorName'].str.contains(acc_name, case=False)) & (df['AcceleratorCount'] >= acc_count)] @@ -527,7 +528,7 @@ def get_instance_type_for_accelerator_impl( fuzzy_result = fuzzy_result[['AcceleratorName', 'AcceleratorCount']].drop_duplicates() fuzzy_candidate_list = [] - if len(fuzzy_result) > 0: + if not fuzzy_result.empty: for _, row in fuzzy_result.iterrows(): acc_cnt = float(row['AcceleratorCount']) acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else @@ -539,7 +540,7 @@ def get_instance_type_for_accelerator_impl( result = _filter_with_cpus(result, cpus) result = _filter_with_mem(result, memory) result = _filter_region_zone(result, region, zone) - if len(result) == 0: + if result.empty: return ([], []) # Current strategy: choose the cheapest instance @@ -680,7 +681,7 @@ def get_image_id_from_tag_impl(df: 'pd.DataFrame', tag: str, df = _filter_region_zone(df, region, zone=None) assert len(df) <= 1, ('Multiple images found for tag ' f'{tag} in region {region}') - if len(df) == 0: + if df.empty: return None image_id = df['ImageId'].iloc[0] if pd.isna(image_id): @@ -694,4 +695,4 @@ def is_image_tag_valid_impl(df: 'pd.DataFrame', tag: str, df = df[df['Tag'] == tag] df = 
_filter_region_zone(df, region, zone=None) df = df.dropna(subset=['ImageId']) - return len(df) > 0 + return not df.empty diff --git a/sky/clouds/service_catalog/constants.py b/sky/clouds/service_catalog/constants.py index a125258ac35..945152582f6 100644 --- a/sky/clouds/service_catalog/constants.py +++ b/sky/clouds/service_catalog/constants.py @@ -4,4 +4,4 @@ CATALOG_DIR = '~/.sky/catalogs' ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci', 'kubernetes', 'runpod', 'vsphere', 'cudo', 'fluidstack', - 'paperspace') + 'paperspace', 'do') diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index 4aef41f9c90..00768d5c6bb 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -134,7 +134,7 @@ def get_pricing_df(region: Optional[str] = None) -> 'pd.DataFrame': content_str = r.content.decode('ascii') content = json.loads(content_str) items = content.get('Items', []) - if len(items) == 0: + if not items: break all_items += items url = content.get('NextPageLink') diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py index 570bc773d2e..b3a71e9514a 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py @@ -47,10 +47,6 @@ TPU_V4_ZONES = ['us-central2-b'] # TPU v3 pods are available in us-east1-d, but hidden in the skus. # We assume the TPU prices are the same as us-central1. -# TPU v6e's pricing info is not available on the SKUs. However, in -# https://cloud.google.com/tpu/pricing, it listed the price for 4 regions: -# us-east1, us-east5, europe-west4, and asia-northeast1. We hardcode them here -# and filtered out the other regions (us-central{1,2}, us-south1). 
HIDDEN_TPU_DF = pd.read_csv( io.StringIO( textwrap.dedent("""\ @@ -62,49 +58,10 @@ ,tpu-v3-512,1,,,tpu-v3-512,512.0,153.6,us-east1,us-east1-d ,tpu-v3-1024,1,,,tpu-v3-1024,1024.0,307.2,us-east1,us-east1-d ,tpu-v3-2048,1,,,tpu-v3-2048,2048.0,614.4,us-east1,us-east1-d - ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east5,us-east5-b - ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east5,us-east5-c - ,tpu-v6e-1,1,,,tpu-v6e-1,2.97,,europe-west4,europe-west4-a - ,tpu-v6e-1,1,,,tpu-v6e-1,3.24,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east1,us-east1-d - ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east5,us-east5-b - ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east5,us-east5-c - ,tpu-v6e-4,1,,,tpu-v6e-4,11.88,,europe-west4,europe-west4-a - ,tpu-v6e-4,1,,,tpu-v6e-4,12.96,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east1,us-east1-d - ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east5,us-east5-b - ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east5,us-east5-c - ,tpu-v6e-8,1,,,tpu-v6e-8,23.76,,europe-west4,europe-west4-a - ,tpu-v6e-8,1,,,tpu-v6e-8,25.92,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east1,us-east1-d - ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east5,us-east5-b - ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east5,us-east5-c - ,tpu-v6e-16,1,,,tpu-v6e-16,47.52,,europe-west4,europe-west4-a - ,tpu-v6e-16,1,,,tpu-v6e-16,51.84,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east1,us-east1-d - ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east5,us-east5-b - ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east5,us-east5-c - ,tpu-v6e-32,1,,,tpu-v6e-32,95.04,,europe-west4,europe-west4-a - ,tpu-v6e-32,1,,,tpu-v6e-32,103.68,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east1,us-east1-d - ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east5,us-east5-b - ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east5,us-east5-c - ,tpu-v6e-64,1,,,tpu-v6e-64,190.08,,europe-west4,europe-west4-a - ,tpu-v6e-64,1,,,tpu-v6e-64,207.36,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east1,us-east1-d - ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east5,us-east5-b - ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east5,us-east5-c - ,tpu-v6e-128,1,,,tpu-v6e-128,380.16,,europe-west4,europe-west4-a - ,tpu-v6e-128,1,,,tpu-v6e-128,414.72,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east1,us-east1-d - ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east5,us-east5-b - ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east5,us-east5-c - ,tpu-v6e-256,1,,,tpu-v6e-256,760.32,,europe-west4,europe-west4-a - ,tpu-v6e-256,1,,,tpu-v6e-256,829.44,,asia-northeast1,asia-northeast1-b - ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east1,us-east1-d """))) -TPU_V6E_MISSING_REGIONS = ['us-central1', 'us-central2', 'us-south1'] +# TPU V6e price for us-central2 is missing in the SKUs. +TPU_V6E_MISSING_REGIONS = ['us-central2'] # TPU V5 is not visible in specific zones. We hardcode the missing zones here. # NOTE(dev): Keep the zones and the df in sync. @@ -670,6 +627,8 @@ def _get_tpu_description_str(tpu_version: str) -> str: return 'TpuV5p' assert tpu_version == 'v5litepod', tpu_version return 'TpuV5e' + if tpu_version.startswith('v6e'): + return 'TpuV6e' return f'Tpu-{tpu_version}' def get_tpu_price(row: pd.Series, spot: bool) -> Optional[float]: @@ -684,10 +643,10 @@ def get_tpu_price(row: pd.Series, spot: bool) -> Optional[float]: # whether the TPU is a single device or a pod. # For TPU-v4, the pricing is uniform, and thus the pricing API # only provides the price of TPU-v4 pods. 
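# A small sketch of the per-core price arithmetic updated in the hunk below:
# v5 litepod and v6e SKUs are priced per chip (1 core per SKU), v5p per
# 2 cores, and earlier generations per 8 cores, so the hourly price scales
# with the requested core count accordingly.
def tpu_hourly_price(num_cores: int, tpu_version: str,
                     tpu_device_price: float) -> float:
    core_per_sku = (1 if tpu_version in ('v5litepod', 'v6e') else
                    2 if tpu_version == 'v5p' else 8)
    tpu_core_price = tpu_device_price / core_per_sku
    return num_cores * tpu_core_price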
- # The price shown for v5 TPU is per chip hour, so there is no 'Pod' - # keyword in the description. + # The price shown for v5 & v6e TPU is per chip hour, so there is + # no 'Pod' keyword in the description. is_pod = ((num_cores > 8 or tpu_version == 'v4') and - not tpu_version.startswith('v5')) + not tpu_version.startswith('v5') and tpu_version != 'v6e') for sku in gce_skus + tpu_skus: if tpu_region not in sku['serviceRegions']: @@ -718,7 +677,9 @@ def get_tpu_price(row: pd.Series, spot: bool) -> Optional[float]: # for v5e. Reference here: # https://cloud.google.com/tpu/docs/v5p#using-accelerator-type # https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config - core_per_sku = (1 if tpu_version == 'v5litepod' else + # v6e is also per chip price. Reference here: + # https://cloud.google.com/tpu/docs/v6e#configurations + core_per_sku = (1 if tpu_version in ['v5litepod', 'v6e'] else 2 if tpu_version == 'v5p' else 8) tpu_core_price = tpu_device_price / core_per_sku tpu_price = num_cores * tpu_core_price @@ -738,8 +699,6 @@ def get_tpu_price(row: pd.Series, spot: bool) -> Optional[float]: spot_str = 'spot ' if spot else '' print(f'The {spot_str}price of {tpu_name} in {tpu_region} is ' 'not found in SKUs or hidden TPU price DF.') - # TODO(tian): Hack. Should investigate how to retrieve the price - # for TPU-v6e. if (tpu_name.startswith('tpu-v6e') and tpu_region in TPU_V6E_MISSING_REGIONS): if not spot: diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py index 216e8ed9b4f..c08a56955a0 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py @@ -534,7 +534,7 @@ def initialize_images_csv(csv_saving_path: str, vc_object, gpu_name = tag_name.split('-')[1] if gpu_name not in gpu_tags: gpu_tags.append(gpu_name) - if len(gpu_tags) > 0: + if gpu_tags: gpu_tags_str = str(gpu_tags).replace('\'', '\"') f.write(f'{item.id},{vcenter_name},{item_cpu},{item_memory}' f',,,\'{gpu_tags_str}\'\n') diff --git a/sky/clouds/service_catalog/do_catalog.py b/sky/clouds/service_catalog/do_catalog.py new file mode 100644 index 00000000000..40a53aa6bc4 --- /dev/null +++ b/sky/clouds/service_catalog/do_catalog.py @@ -0,0 +1,111 @@ +"""Digital ocean service catalog. + +This module loads the service catalog file and can be used to +query instance types and pricing information for digital ocean. 
+""" + +import typing +from typing import Dict, List, Optional, Tuple, Union + +from sky.clouds.service_catalog import common +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from sky.clouds import cloud + +_df = common.read_catalog('do/vms.csv') + + +def instance_type_exists(instance_type: str) -> bool: + return common.instance_type_exists_impl(_df, instance_type) + + +def validate_region_zone( + region: Optional[str], + zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('DO does not support zones.') + return common.validate_region_zone_impl('DO', _df, region, zone) + + +def get_hourly_cost( + instance_type: str, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None, +) -> float: + """Returns the cost, or the cheapest cost among all zones for spot.""" + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('DO does not support zones.') + return common.get_hourly_cost_impl(_df, instance_type, use_spot, region, + zone) + + +def get_vcpus_mem_from_instance_type( + instance_type: str,) -> Tuple[Optional[float], Optional[float]]: + return common.get_vcpus_mem_from_instance_type_impl(_df, instance_type) + + +def get_default_instance_type( + cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[str] = None, +) -> Optional[str]: + # NOTE: After expanding catalog to multiple entries, you may + # want to specify a default instance type or family. + del disk_tier # unused + return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory) + + +def get_accelerators_from_instance_type( + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: + return common.get_accelerators_from_instance_type_impl(_df, instance_type) + + +def get_instance_type_for_accelerator( + acc_name: str, + acc_count: int, + cpus: Optional[str] = None, + memory: Optional[str] = None, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None, +) -> Tuple[Optional[List[str]], List[str]]: + """Returns a list of instance types that have the given accelerator.""" + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('DO does not support zones.') + return common.get_instance_type_for_accelerator_impl( + df=_df, + acc_name=acc_name, + acc_count=acc_count, + cpus=cpus, + memory=memory, + use_spot=use_spot, + region=region, + zone=zone, + ) + + +def get_region_zones_for_instance_type(instance_type: str, + use_spot: bool) -> List['cloud.Region']: + df = _df[_df['InstanceType'] == instance_type] + return common.get_region_zones(df, use_spot) + + +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, + require_price: bool = True, +) -> Dict[str, List[common.InstanceTypeInfo]]: + """Returns all instance types in DO offering GPUs.""" + del require_price # unused + return common.list_accelerators_impl('DO', _df, gpus_only, name_filter, + region_filter, quantity_filter, + case_sensitive, all_regions) diff --git a/sky/clouds/service_catalog/kubernetes_catalog.py b/sky/clouds/service_catalog/kubernetes_catalog.py index 655b3b54a66..c6becef4750 100644 --- a/sky/clouds/service_catalog/kubernetes_catalog.py +++ b/sky/clouds/service_catalog/kubernetes_catalog.py @@ -115,6 +115,16 @@ def _list_accelerators( If the user does not have 
sufficient permissions to list pods in all namespaces, the function will return free GPUs as -1. + + Returns: + A tuple of three dictionaries: + - qtys_map: Dict mapping accelerator names to lists of InstanceTypeInfo + objects with quantity information. + - total_accelerators_capacity: Dict mapping accelerator names to their + total capacity in the cluster. + - total_accelerators_available: Dict mapping accelerator names to their + current availability. Returns -1 for each accelerator if + realtime=False or if insufficient permissions. """ # TODO(romilb): This should be refactored to use get_kubernetes_node_info() # function from kubernetes_utils. @@ -243,6 +253,10 @@ def _list_accelerators( accelerators_available = accelerator_count - allocated_qty + # Initialize the entry if it doesn't exist yet + if accelerator_name not in total_accelerators_available: + total_accelerators_available[accelerator_name] = 0 + if accelerators_available >= min_quantity_filter: quantized_availability = min_quantity_filter * ( accelerators_available // min_quantity_filter) diff --git a/sky/clouds/utils/oci_utils.py b/sky/clouds/utils/oci_utils.py index 0cd4f33e647..46d4454d866 100644 --- a/sky/clouds/utils/oci_utils.py +++ b/sky/clouds/utils/oci_utils.py @@ -6,6 +6,12 @@ configuration. - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add the constant SERVICE_PORT_RULE_TAG + - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Set the default image + from ubuntu 20.04 to ubuntu 22.04, including: + - GPU: skypilot:gpu-ubuntu-2004 -> skypilot:gpu-ubuntu-2204 + - CPU: skypilot:cpu-ubuntu-2004 -> skypilot:cpu-ubuntu-2204 + - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Support reuse existing + VCN for SkyServe. """ import os @@ -105,8 +111,15 @@ def get_compartment(cls, region): ('oci', region, 'compartment_ocid'), default_compartment_ocid) return compartment + @classmethod + def get_vcn_ocid(cls, region): + # Will reuse the regional VCN if specified. + vcn = skypilot_config.get_nested(('oci', region, 'vcn_ocid'), None) + return vcn + @classmethod def get_vcn_subnet(cls, region): + # Will reuse the subnet if specified. vcn = skypilot_config.get_nested(('oci', region, 'vcn_subnet'), None) return vcn @@ -117,7 +130,7 @@ def get_default_gpu_image_tag(cls) -> str: # the sky's user-config file (if not specified, use the hardcode one at # last) return skypilot_config.get_nested(('oci', 'default', 'image_tag_gpu'), - 'skypilot:gpu-ubuntu-2004') + 'skypilot:gpu-ubuntu-2204') @classmethod def get_default_image_tag(cls) -> str: @@ -125,7 +138,7 @@ def get_default_image_tag(cls) -> str: # set the default image tag in the sky's user-config file. 
(if not # specified, use the hardcode one at last) return skypilot_config.get_nested( - ('oci', 'default', 'image_tag_general'), 'skypilot:cpu-ubuntu-2004') + ('oci', 'default', 'image_tag_general'), 'skypilot:cpu-ubuntu-2204') @classmethod def get_sky_user_config_file(cls) -> str: diff --git a/sky/clouds/utils/scp_utils.py b/sky/clouds/utils/scp_utils.py index 3e91e22e6d9..4efc79313c5 100644 --- a/sky/clouds/utils/scp_utils.py +++ b/sky/clouds/utils/scp_utils.py @@ -65,7 +65,7 @@ def __setitem__(self, instance_id: str, value: Optional[Dict[str, if value is None: if instance_id in metadata: metadata.pop(instance_id) # del entry - if len(metadata) == 0: + if not metadata: if os.path.exists(self.path): os.remove(self.path) return @@ -84,7 +84,7 @@ def refresh(self, instance_ids: List[str]) -> None: for instance_id in list(metadata.keys()): if instance_id not in instance_ids: del metadata[instance_id] - if len(metadata) == 0: + if not metadata: os.remove(self.path) return with open(self.path, 'w', encoding='utf-8') as f: @@ -410,7 +410,7 @@ def list_security_groups(self, vpc_id=None, sg_name=None): parameter.append('vpcId=' + vpc_id) if sg_name is not None: parameter.append('securityGroupName=' + sg_name) - if len(parameter) > 0: + if parameter: url = url + '?' + '&'.join(parameter) return self._get(url) diff --git a/sky/core.py b/sky/core.py index 9f1288d7fb6..36b3d45b849 100644 --- a/sky/core.py +++ b/sky/core.py @@ -732,7 +732,7 @@ def cancel( f'{colorama.Fore.YELLOW}' f'Cancelling latest running job on cluster {cluster_name!r}...' f'{colorama.Style.RESET_ALL}') - elif len(job_ids): + elif job_ids: # all = False, len(job_ids) > 0 => cancel the specified jobs. jobs_str = ', '.join(map(str, job_ids)) sky_logging.print( @@ -817,7 +817,7 @@ def download_logs( backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend), backend - if job_ids is not None and len(job_ids) == 0: + if job_ids is not None and not job_ids: return {} usage_lib.record_cluster_name_for_current_operation(cluster_name) @@ -866,7 +866,7 @@ def job_status(cluster_name: str, f'of type {backend.__class__.__name__!r}.') assert isinstance(handle, backends.CloudVmRayResourceHandle), handle - if job_ids is not None and len(job_ids) == 0: + if job_ids is not None and not job_ids: return {} sky_logging.print(f'{colorama.Fore.YELLOW}' diff --git a/sky/data/data_transfer.py b/sky/data/data_transfer.py index 374871031cb..3ccc6f8fc0e 100644 --- a/sky/data/data_transfer.py +++ b/sky/data/data_transfer.py @@ -200,3 +200,40 @@ def _add_bucket_iam_member(bucket_name: str, role: str, member: str) -> None: bucket.set_iam_policy(policy) logger.debug(f'Added {member} with role {role} to {bucket_name}.') + + +def s3_to_oci(s3_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Amazon S3 to OCI Object Storage. + Args: + s3_bucket_name: str; Name of the Amazon S3 Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + # TODO(HysunHe): Implement sync with other clouds (s3, gs) + raise NotImplementedError('Moving data directly from S3 to OCI bucket ' + 'is currently not supported. Please specify ' + 'a local source for the storage object.') + + +def gcs_to_oci(gs_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Google Cloud Storage to + OCI Object Storage. 
+ Args: + gs_bucket_name: str; Name of the Google Cloud Storage Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + # TODO(HysunHe): Implement sync with other clouds (s3, gs) + raise NotImplementedError('Moving data directly from GCS to OCI bucket ' + 'is currently not supported. Please specify ' + 'a local source for the storage object.') + + +def r2_to_oci(r2_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Cloudflare R2 to OCI Bucket. + Args: + r2_bucket_name: str; Name of the Cloudflare R2 Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + raise NotImplementedError( + 'Moving data directly from Cloudflare R2 to OCI ' + 'bucket is currently not supported. Please specify ' + 'a local source for the storage object.') diff --git a/sky/data/data_utils.py b/sky/data/data_utils.py index 0c8fd64ddea..e8dcaa83017 100644 --- a/sky/data/data_utils.py +++ b/sky/data/data_utils.py @@ -20,6 +20,7 @@ from sky.adaptors import cloudflare from sky.adaptors import gcp from sky.adaptors import ibm +from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import ux_utils @@ -430,6 +431,7 @@ def _group_files_by_dir( def parallel_upload(source_path_list: List[str], filesync_command_generator: Callable[[str, List[str]], str], dirsync_command_generator: Callable[[str, str], str], + log_path: str, bucket_name: str, access_denied_message: str, create_dirs: bool = False, @@ -445,6 +447,7 @@ def parallel_upload(source_path_list: List[str], for a list of files belonging to the same dir. dirsync_command_generator: Callable that generates rsync command for a directory. + log_path: Path to the log file. access_denied_message: Message to intercept from the underlying upload utility when permissions are insufficient. Used in exception handling. @@ -477,7 +480,7 @@ def parallel_upload(source_path_list: List[str], p.starmap( run_upload_cli, zip(commands, [access_denied_message] * len(commands), - [bucket_name] * len(commands))) + [bucket_name] * len(commands), [log_path] * len(commands))) def get_gsutil_command() -> Tuple[str, str]: @@ -518,37 +521,31 @@ def get_gsutil_command() -> Tuple[str, str]: return gsutil_alias, alias_gen -def run_upload_cli(command: str, access_denied_message: str, bucket_name: str): - # TODO(zhwu): Use log_lib.run_with_log() and redirect the output - # to a log file. - with subprocess.Popen(command, - stderr=subprocess.PIPE, - stdout=subprocess.DEVNULL, - shell=True) as process: - stderr = [] - assert process.stderr is not None # for mypy - while True: - line = process.stderr.readline() - if not line: - break - str_line = line.decode('utf-8') - stderr.append(str_line) - if access_denied_message in str_line: - process.kill() - with ux_utils.print_exception_no_traceback(): - raise PermissionError( - 'Failed to upload files to ' - 'the remote bucket. The bucket does not have ' - 'write permissions. It is possible that ' - 'the bucket is public.') - returncode = process.wait() - if returncode != 0: - stderr_str = '\n'.join(stderr) - with ux_utils.print_exception_no_traceback(): - logger.error(stderr_str) - raise exceptions.StorageUploadError( - f'Upload to bucket failed for store {bucket_name}. 
' - 'Please check the logs.') +def run_upload_cli(command: str, access_denied_message: str, bucket_name: str, + log_path: str): + returncode, stdout, stderr = log_lib.run_with_log( + command, + log_path, + shell=True, + require_outputs=True, + # We need to use bash as some of the cloud commands uses bash syntax, + # such as [[ ... ]] + executable='/bin/bash') + if access_denied_message in stderr: + with ux_utils.print_exception_no_traceback(): + raise PermissionError('Failed to upload files to ' + 'the remote bucket. The bucket does not have ' + 'write permissions. It is possible that ' + 'the bucket is public.') + if returncode != 0: + with ux_utils.print_exception_no_traceback(): + logger.error(stderr) + raise exceptions.StorageUploadError( + f'Upload to bucket failed for store {bucket_name}. ' + f'Please check the logs: {log_path}') + if not stdout: + logger.debug('No file uploaded. This could be due to an error or ' + 'because all files already exist on the cloud.') def get_cos_regions() -> List[str]: @@ -737,3 +734,14 @@ def _remove_bucket_profile_rclone(bucket_name: str, lines_to_keep.append(line) return lines_to_keep + + +def split_oci_path(oci_path: str) -> Tuple[str, str]: + """Splits OCI Path into Bucket name and Relative Path to Bucket + Args: + oci_path: str; OCI Path, e.g. oci://imagenet/train/ + """ + path_parts = oci_path.replace('oci://', '').split('/') + bucket = path_parts.pop(0) + key = '/'.join(path_parts) + return bucket, key diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py index 22b26c372c4..a00a73aa9e0 100644 --- a/sky/data/mounting_utils.py +++ b/sky/data/mounting_utils.py @@ -19,6 +19,7 @@ _BLOBFUSE_CACHE_ROOT_DIR = '~/.sky/blobfuse2_cache' _BLOBFUSE_CACHE_DIR = ('~/.sky/blobfuse2_cache/' '{storage_account_name}_{container_name}') +RCLONE_VERSION = 'v1.68.2' def get_s3_mount_install_cmd() -> str: @@ -30,12 +31,19 @@ def get_s3_mount_install_cmd() -> str: return install_cmd -def get_s3_mount_cmd(bucket_name: str, mount_path: str) -> str: +# pylint: disable=invalid-name +def get_s3_mount_cmd(bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount an S3 bucket using goofys.""" + if _bucket_sub_path is None: + _bucket_sub_path = '' + else: + _bucket_sub_path = f':{_bucket_sub_path}' mount_cmd = ('goofys -o allow_other ' f'--stat-cache-ttl {_STAT_CACHE_TTL} ' f'--type-cache-ttl {_TYPE_CACHE_TTL} ' - f'{bucket_name} {mount_path}') + f'{bucket_name}{_bucket_sub_path} {mount_path}') return mount_cmd @@ -49,15 +57,20 @@ def get_gcs_mount_install_cmd() -> str: return install_cmd -def get_gcs_mount_cmd(bucket_name: str, mount_path: str) -> str: +# pylint: disable=invalid-name +def get_gcs_mount_cmd(bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount a GCS bucket using gcsfuse.""" - + bucket_sub_path_arg = f'--only-dir {_bucket_sub_path} '\ + if _bucket_sub_path else '' mount_cmd = ('gcsfuse -o allow_other ' '--implicit-dirs ' f'--stat-cache-capacity {_STAT_CACHE_CAPACITY} ' f'--stat-cache-ttl {_STAT_CACHE_TTL} ' f'--type-cache-ttl {_TYPE_CACHE_TTL} ' f'--rename-dir-limit {_RENAME_DIR_LIMIT} ' + f'{bucket_sub_path_arg}' f'{bucket_name} {mount_path}') return mount_cmd @@ -78,10 +91,12 @@ def get_az_mount_install_cmd() -> str: return install_cmd +# pylint: disable=invalid-name def get_az_mount_cmd(container_name: str, storage_account_name: str, mount_path: str, - storage_account_key: Optional[str] = None) -> str: + storage_account_key: 
Optional[str] = None,
+                     _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount an AZ Container using blobfuse2.
    Args:
@@ -90,6 +105,7 @@ def get_az_mount_cmd(container_name: str,
            belongs to.
        mount_path: Path where the container will be mounting.
        storage_account_key: Access key for the given storage account.
+        _bucket_sub_path: Sub path of the mounting container.
    Returns:
        str: Command used to mount AZ container with blobfuse2.
@@ -106,25 +122,44 @@ def get_az_mount_cmd(container_name: str,
    cache_path = _BLOBFUSE_CACHE_DIR.format(
        storage_account_name=storage_account_name,
        container_name=container_name)
+    # The line below ensures the cache directory is new before mounting to
+    # avoid "config error in file_cache [temp directory not empty]" error, which
+    # can occur after stopping and starting the same cluster on Azure.
+    # This helps ensure a clean state for blobfuse2 operations.
+    remote_boot_time_cmd = 'date +%s -d "$(uptime -s)"'
+    if _bucket_sub_path is None:
+        bucket_sub_path_arg = ''
+    else:
+        bucket_sub_path_arg = f'--subdirectory={_bucket_sub_path}/ '
+    # TODO(zpoint): clear old cache that has been created in the previous boot.
    mount_cmd = (f'AZURE_STORAGE_ACCOUNT={storage_account_name} '
                 f'{key_env_var} '
                 f'blobfuse2 {mount_path} --allow-other --no-symlinks '
                 '-o umask=022 -o default_permissions '
-                 f'--tmp-path {cache_path} '
+                 f'--tmp-path {cache_path}_$({remote_boot_time_cmd}) '
+                 f'{bucket_sub_path_arg}'
                 f'--container-name {container_name}')
    return mount_cmd
-def get_r2_mount_cmd(r2_credentials_path: str, r2_profile_name: str,
-                     endpoint_url: str, bucket_name: str,
-                     mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_r2_mount_cmd(r2_credentials_path: str,
+                     r2_profile_name: str,
+                     endpoint_url: str,
+                     bucket_name: str,
+                     mount_path: str,
+                     _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to install R2 mount utility goofys."""
+    if _bucket_sub_path is None:
+        _bucket_sub_path = ''
+    else:
+        _bucket_sub_path = f':{_bucket_sub_path}'
    mount_cmd = (f'AWS_SHARED_CREDENTIALS_FILE={r2_credentials_path} '
                 f'AWS_PROFILE={r2_profile_name} goofys -o allow_other '
                 f'--stat-cache-ttl {_STAT_CACHE_TTL} '
                 f'--type-cache-ttl {_TYPE_CACHE_TTL} '
                 f'--endpoint {endpoint_url} '
-                 f'{bucket_name} {mount_path}')
+                 f'{bucket_name}{_bucket_sub_path} {mount_path}')
    return mount_cmd
@@ -136,9 +171,12 @@ def get_cos_mount_install_cmd() -> str:
    return install_cmd
-def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
-                      bucket_rclone_profile: str, bucket_name: str,
-                      mount_path: str) -> str:
+def get_cos_mount_cmd(rclone_config_data: str,
+                      rclone_config_path: str,
+                      bucket_rclone_profile: str,
+                      bucket_name: str,
+                      mount_path: str,
+                      _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount an IBM COS bucket using rclone."""
    # creates a fusermount soft link on older (<22) Ubuntu systems for
    # rclone's mount utility.
@@ -150,14 +188,60 @@ def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
                               'mkdir -p ~/.config/rclone/ && '
                               f'echo "{rclone_config_data}" >> '
                               f'{rclone_config_path}')
+    if _bucket_sub_path is None:
+        sub_path_arg = bucket_name
+    else:
+        sub_path_arg = f'{bucket_name}/{_bucket_sub_path}'
    # --daemon will keep the mounting process running in the background.
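    # Illustrative sketch only (assuming the sub-path handling above): the
    # rendered command is roughly
    #   <configure_rclone_profile> && rclone mount \
    #       <bucket_rclone_profile>:<bucket_name>[/<bucket_sub_path>] <mount_path> --daemon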
mount_cmd = (f'{configure_rclone_profile} && ' 'rclone mount ' - f'{bucket_rclone_profile}:{bucket_name} {mount_path} ' + f'{bucket_rclone_profile}:{sub_path_arg} {mount_path} ' '--daemon') return mount_cmd +def get_rclone_install_cmd() -> str: + """ RClone installation for both apt-get and rpm. + This would be common command. + """ + # pylint: disable=line-too-long + install_cmd = ( + f'(which dpkg > /dev/null 2>&1 && (which rclone > /dev/null || (cd ~ > /dev/null' + f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.deb' + f' && sudo dpkg -i rclone-{RCLONE_VERSION}-linux-amd64.deb' + f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.deb)))' + f' || (which rclone > /dev/null || (cd ~ > /dev/null' + f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.rpm' + f' && sudo yum --nogpgcheck install rclone-{RCLONE_VERSION}-linux-amd64.rpm -y' + f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.rpm))') + return install_cmd + + +def get_oci_mount_cmd(mount_path: str, store_name: str, region: str, + namespace: str, compartment: str, config_file: str, + config_profile: str) -> str: + """ OCI specific RClone mount command for oci object storage. """ + # pylint: disable=line-too-long + mount_cmd = ( + f'sudo chown -R `whoami` {mount_path}' + f' && rclone config create oos_{store_name} oracleobjectstorage' + f' provider user_principal_auth namespace {namespace}' + f' compartment {compartment} region {region}' + f' oci-config-file {config_file}' + f' oci-config-profile {config_profile}' + f' && sed -i "s/oci-config-file/config_file/g;' + f' s/oci-config-profile/config_profile/g" ~/.config/rclone/rclone.conf' + f' && ([ ! -f /bin/fusermount3 ] && sudo ln -s /bin/fusermount /bin/fusermount3 || true)' + f' && (grep -q {mount_path} /proc/mounts || rclone mount oos_{store_name}:{store_name} {mount_path} --daemon --allow-non-empty)' + ) + return mount_cmd + + +def get_rclone_version_check_cmd() -> str: + """ RClone version check. This would be common command. """ + return f'rclone --version | grep -q {RCLONE_VERSION}' + + def _get_mount_binary(mount_cmd: str) -> str: """Returns mounting binary in string given as the mount command. @@ -209,7 +293,7 @@ def get_mounting_script( script = textwrap.dedent(f""" #!/usr/bin/env bash set -e - + {command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} MOUNT_PATH={mount_path} diff --git a/sky/data/storage.py b/sky/data/storage.py index d3d18a9d18f..018cb2797ca 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -24,6 +24,7 @@ from sky.adaptors import cloudflare from sky.adaptors import gcp from sky.adaptors import ibm +from sky.adaptors import oci from sky.data import data_transfer from sky.data import data_utils from sky.data import mounting_utils @@ -54,7 +55,9 @@ str(clouds.AWS()), str(clouds.GCP()), str(clouds.Azure()), - str(clouds.IBM()), cloudflare.NAME + str(clouds.IBM()), + str(clouds.OCI()), + cloudflare.NAME, ] # Maximum number of concurrent rsync upload processes @@ -72,6 +75,8 @@ 'Bucket {bucket_name!r} does not exist. 
' 'It may have been deleted externally.') +_STORAGE_LOG_FILE_NAME = 'storage_sync.log' + def get_cached_enabled_storage_clouds_or_refresh( raise_if_no_cloud_access: bool = False) -> List[str]: @@ -113,6 +118,7 @@ class StoreType(enum.Enum): AZURE = 'AZURE' R2 = 'R2' IBM = 'IBM' + OCI = 'OCI' @classmethod def from_cloud(cls, cloud: str) -> 'StoreType': @@ -126,6 +132,8 @@ def from_cloud(cls, cloud: str) -> 'StoreType': return StoreType.R2 elif cloud.lower() == str(clouds.Azure()).lower(): return StoreType.AZURE + elif cloud.lower() == str(clouds.OCI()).lower(): + return StoreType.OCI elif cloud.lower() == str(clouds.Lambda()).lower(): with ux_utils.print_exception_no_traceback(): raise ValueError('Lambda Cloud does not provide cloud storage.') @@ -147,6 +155,8 @@ def from_store(cls, store: 'AbstractStore') -> 'StoreType': return StoreType.R2 elif isinstance(store, IBMCosStore): return StoreType.IBM + elif isinstance(store, OciStore): + return StoreType.OCI else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {store}') @@ -163,6 +173,8 @@ def store_prefix(self) -> str: return 'r2://' elif self == StoreType.IBM: return 'cos://' + elif self == StoreType.OCI: + return 'oci://' else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {self}') @@ -188,6 +200,45 @@ def get_endpoint_url(cls, store: 'AbstractStore', path: str) -> str: bucket_endpoint_url = f'{store_type.store_prefix()}{path}' return bucket_endpoint_url + @classmethod + def get_fields_from_store_url( + cls, store_url: str + ) -> Tuple['StoreType', Type['AbstractStore'], str, str, Optional[str], + Optional[str]]: + """Returns the store type, store class, bucket name, and sub path from + a store URL, and the storage account name and region if applicable. + + Args: + store_url: str; The store URL. + """ + # The full path from the user config of IBM COS contains the region, + # and Azure Blob Storage contains the storage account name, we need to + # pass these information to the store constructor. 
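        # Illustrative examples only (hypothetical values; the exact splits
        # come from the data_utils.split_* helpers used below):
        #   's3://my-bucket/my-dir' ->
        #       (StoreType.S3, S3Store, 'my-bucket', 'my-dir', None, None)
        #   'gs://my-bucket/my-dir' ->
        #       (StoreType.GCS, GcsStore, 'my-bucket', 'my-dir', None, None)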
+ storage_account_name = None + region = None + for store_type in StoreType: + if store_url.startswith(store_type.store_prefix()): + if store_type == StoreType.AZURE: + storage_account_name, bucket_name, sub_path = \ + data_utils.split_az_path(store_url) + store_cls: Type['AbstractStore'] = AzureBlobStore + elif store_type == StoreType.IBM: + bucket_name, sub_path, region = data_utils.split_cos_path( + store_url) + store_cls = IBMCosStore + elif store_type == StoreType.R2: + bucket_name, sub_path = data_utils.split_r2_path(store_url) + store_cls = R2Store + elif store_type == StoreType.GCS: + bucket_name, sub_path = data_utils.split_gcs_path(store_url) + store_cls = GcsStore + elif store_type == StoreType.S3: + bucket_name, sub_path = data_utils.split_s3_path(store_url) + store_cls = S3Store + return store_type, store_cls,bucket_name, \ + sub_path, storage_account_name, region + raise ValueError(f'Unknown store URL: {store_url}') + class StorageMode(enum.Enum): MOUNT = 'MOUNT' @@ -214,25 +265,29 @@ def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, - is_sky_managed: Optional[bool] = None): + is_sky_managed: Optional[bool] = None, + _bucket_sub_path: Optional[str] = None): self.name = name self.source = source self.region = region self.is_sky_managed = is_sky_managed + self._bucket_sub_path = _bucket_sub_path def __repr__(self): return (f'StoreMetadata(' f'\n\tname={self.name},' f'\n\tsource={self.source},' f'\n\tregion={self.region},' - f'\n\tis_sky_managed={self.is_sky_managed})') + f'\n\tis_sky_managed={self.is_sky_managed},' + f'\n\t_bucket_sub_path={self._bucket_sub_path})') def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): # pylint: disable=invalid-name """Initialize AbstractStore Args: @@ -246,7 +301,11 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. - + _bucket_sub_path: str; The prefix of the bucket directory to be + created in the store, e.g. if _bucket_sub_path=my-dir, the files + will be uploaded to s3:///my-dir/. + This only works if source is a local directory. + # TODO(zpoint): Add support for non-local source. Raises: StorageBucketCreateError: If bucket creation fails StorageBucketGetError: If fetching existing bucket fails @@ -257,10 +316,29 @@ def __init__(self, self.region = region self.is_sky_managed = is_sky_managed self.sync_on_reconstruction = sync_on_reconstruction + + # To avoid mypy error + self._bucket_sub_path: Optional[str] = None + # Trigger the setter to strip any leading/trailing slashes. + self.bucket_sub_path = _bucket_sub_path # Whether sky is responsible for the lifecycle of the Store. 
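        # For example, the normalization triggered above (see the
        # bucket_sub_path setter) stores _bucket_sub_path='/my-dir/' as
        # 'my-dir' and keeps None as None.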
self._validate() self.initialize() + @property + def bucket_sub_path(self) -> Optional[str]: + """Get the bucket_sub_path.""" + return self._bucket_sub_path + + @bucket_sub_path.setter + # pylint: disable=invalid-name + def bucket_sub_path(self, bucket_sub_path: Optional[str]) -> None: + """Set the bucket_sub_path, stripping any leading/trailing slashes.""" + if bucket_sub_path is not None: + self._bucket_sub_path = bucket_sub_path.strip('/') + else: + self._bucket_sub_path = None + @classmethod def from_metadata(cls, metadata: StoreMetadata, **override_args): """Create a Store from a StoreMetadata object. @@ -268,19 +346,26 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args): Used when reconstructing Storage and Store objects from global_user_state. """ - return cls(name=override_args.get('name', metadata.name), - source=override_args.get('source', metadata.source), - region=override_args.get('region', metadata.region), - is_sky_managed=override_args.get('is_sky_managed', - metadata.is_sky_managed), - sync_on_reconstruction=override_args.get( - 'sync_on_reconstruction', True)) + return cls( + name=override_args.get('name', metadata.name), + source=override_args.get('source', metadata.source), + region=override_args.get('region', metadata.region), + is_sky_managed=override_args.get('is_sky_managed', + metadata.is_sky_managed), + sync_on_reconstruction=override_args.get('sync_on_reconstruction', + True), + # backward compatibility + _bucket_sub_path=override_args.get( + '_bucket_sub_path', + metadata._bucket_sub_path # pylint: disable=protected-access + ) if hasattr(metadata, '_bucket_sub_path') else None) def get_metadata(self) -> StoreMetadata: return self.StoreMetadata(name=self.name, source=self.source, region=self.region, - is_sky_managed=self.is_sky_managed) + is_sky_managed=self.is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) def initialize(self): """Initializes the Store object on the cloud. @@ -308,7 +393,11 @@ def upload(self) -> None: raise NotImplementedError def delete(self) -> None: - """Removes the Storage object from the cloud.""" + """Removes the Storage from the cloud.""" + raise NotImplementedError + + def _delete_sub_path(self) -> None: + """Removes objects from the sub path in the bucket.""" raise NotImplementedError def get_handle(self) -> StorageHandle: @@ -452,13 +541,19 @@ def remove_store(self, store: AbstractStore) -> None: if storetype in self.sky_stores: del self.sky_stores[storetype] - def __init__(self, - name: Optional[str] = None, - source: Optional[SourceType] = None, - stores: Optional[Dict[StoreType, AbstractStore]] = None, - persistent: Optional[bool] = True, - mode: StorageMode = StorageMode.MOUNT, - sync_on_reconstruction: bool = True) -> None: + def __init__( + self, + name: Optional[str] = None, + source: Optional[SourceType] = None, + stores: Optional[Dict[StoreType, AbstractStore]] = None, + persistent: Optional[bool] = True, + mode: StorageMode = StorageMode.MOUNT, + sync_on_reconstruction: bool = True, + # pylint: disable=invalid-name + _is_sky_managed: Optional[bool] = None, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> None: """Initializes a Storage object. Three fields are required: the name of the storage, the source @@ -496,6 +591,18 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. 
+ _is_sky_managed: Optional[bool]; Indicates if the storage is managed + by Sky. Without this argument, the controller's behavior differs + from the local machine. For example, if a bucket does not exist: + Local Machine (is_sky_managed=True) → + Controller (is_sky_managed=False). + With this argument, the controller aligns with the local machine, + ensuring it retains the is_sky_managed information from the YAML. + During teardown, if is_sky_managed is True, the controller should + delete the bucket. Otherwise, it might mistakenly delete only the + sub-path, assuming is_sky_managed is False. + _bucket_sub_path: Optional[str]; The subdirectory to use for the + storage object. """ self.name: str self.source = source @@ -503,6 +610,8 @@ def __init__(self, self.mode = mode assert mode in StorageMode self.sync_on_reconstruction = sync_on_reconstruction + self._is_sky_managed = _is_sky_managed + self._bucket_sub_path = _bucket_sub_path # TODO(romilb, zhwu): This is a workaround to support storage deletion # for spot. Once sky storage supports forced management for external @@ -562,6 +671,14 @@ def __init__(self, self.add_store(StoreType.R2) elif self.source.startswith('cos://'): self.add_store(StoreType.IBM) + elif self.source.startswith('oci://'): + self.add_store(StoreType.OCI) + + def get_bucket_sub_path_prefix(self, blob_path: str) -> str: + """Adds the bucket sub path prefix to the blob path.""" + if self._bucket_sub_path is not None: + return f'{blob_path}/{self._bucket_sub_path}' + return blob_path @staticmethod def _validate_source( @@ -642,7 +759,7 @@ def _validate_local_source(local_source): 'using a bucket by writing : ' f'{source} in the file_mounts section of your YAML') is_local_source = True - elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos']: + elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']: is_local_source = False # Storage mounting does not support mounting specific files from # cloud store - ensure path points to only a directory @@ -666,7 +783,7 @@ def _validate_local_source(local_source): with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSourceError( f'Supported paths: local, s3://, gs://, https://, ' - f'r2://, cos://. Got: {source}') + f'r2://, cos://, oci://. 
Got: {source}') return source, is_local_source def _validate_storage_spec(self, name: Optional[str]) -> None: @@ -681,7 +798,7 @@ def validate_name(name): """ prefix = name.split('://')[0] prefix = prefix.lower() - if prefix in ['s3', 'gs', 'https', 'r2', 'cos']: + if prefix in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageNameError( 'Prefix detected: `name` cannot start with ' @@ -773,29 +890,40 @@ def _add_store_from_metadata( store = S3Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.GCS: store = GcsStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.AZURE: assert isinstance(s_metadata, AzureBlobStore.AzureBlobStoreMetadata) store = AzureBlobStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.R2: store = R2Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.IBM: store = IBMCosStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) + elif s_type == StoreType.OCI: + store = OciStore.from_metadata( + s_metadata, + source=self.source, + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {s_type}') @@ -815,7 +943,6 @@ def _add_store_from_metadata( 'to be reconstructed while the corresponding ' 'bucket was externally deleted.') continue - self._add_store(store, is_reconstructed=True) @classmethod @@ -871,6 +998,7 @@ def add_store(self, f'storage account {storage_account_name!r}.') else: logger.info(f'Storage type {store_type} already exists.') + return self.stores[store_type] store_cls: Type[AbstractStore] @@ -884,25 +1012,30 @@ def add_store(self, store_cls = R2Store elif store_type == StoreType.IBM: store_cls = IBMCosStore + elif store_type == StoreType.OCI: + store_cls = OciStore else: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSpecError( f'{store_type} not supported as a Store.') - - # Initialize store object and get/create bucket try: store = store_cls( name=self.name, source=self.source, region=region, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + is_sky_managed=self._is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) except exceptions.StorageBucketCreateError: # Creation failed, so this must be sky managed store. Add failure # to state. 
logger.error(f'Could not create {store_type} store ' f'with name {self.name}.') - global_user_state.set_storage_status(self.name, - StorageStatus.INIT_FAILED) + try: + global_user_state.set_storage_status(self.name, + StorageStatus.INIT_FAILED) + except ValueError as e: + logger.error(f'Error setting storage status: {e}') raise except exceptions.StorageBucketGetError: # Bucket get failed, so this is not sky managed. Do not update state @@ -1018,12 +1151,15 @@ def warn_for_git_dir(source: str): def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': common_utils.validate_schema(config, schemas.get_storage_schema(), 'Invalid storage YAML: ') - name = config.pop('name', None) source = config.pop('source', None) store = config.pop('store', None) mode_str = config.pop('mode', None) force_delete = config.pop('_force_delete', None) + # pylint: disable=invalid-name + _is_sky_managed = config.pop('_is_sky_managed', None) + # pylint: disable=invalid-name + _bucket_sub_path = config.pop('_bucket_sub_path', None) if force_delete is None: force_delete = False @@ -1043,7 +1179,9 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj = cls(name=name, source=source, persistent=persistent, - mode=mode) + mode=mode, + _is_sky_managed=_is_sky_managed, + _bucket_sub_path=_bucket_sub_path) if store is not None: storage_obj.add_store(StoreType(store.upper())) @@ -1051,7 +1189,7 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj.force_delete = force_delete return storage_obj - def to_yaml_config(self) -> Dict[str, str]: + def to_yaml_config(self) -> Dict[str, Any]: config = {} def add_if_not_none(key: str, value: Optional[Any]): @@ -1067,13 +1205,18 @@ def add_if_not_none(key: str, value: Optional[Any]): add_if_not_none('source', self.source) stores = None - if len(self.stores) > 0: + is_sky_managed = self._is_sky_managed + if self.stores: stores = ','.join([store.value for store in self.stores]) + is_sky_managed = list(self.stores.values())[0].is_sky_managed add_if_not_none('store', stores) + add_if_not_none('_is_sky_managed', is_sky_managed) add_if_not_none('persistent', self.persistent) add_if_not_none('mode', self.mode.value) if self.force_delete: config['_force_delete'] = True + if self._bucket_sub_path is not None: + config['_bucket_sub_path'] = self._bucket_sub_path return config @@ -1095,7 +1238,8 @@ def __init__(self, source: str, region: Optional[str] = _DEFAULT_REGION, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' # TODO(romilb): This is purely a stopgap fix for @@ -1108,7 +1252,7 @@ def __init__(self, f'{self._DEFAULT_REGION} for bucket {name!r}.') region = self._DEFAULT_REGION super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1147,6 +1291,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. 
', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to S3 is currently not supported.') # Validate name self.name = self.validate_name(self.name) @@ -1258,6 +1405,8 @@ def upload(self): self._transfer_to_s3() elif self.source.startswith('r2://'): self._transfer_to_s3() + elif self.source.startswith('oci://'): + self._transfer_to_s3() else: self.batch_aws_rsync([self.source]) except exceptions.StorageUploadError: @@ -1267,6 +1416,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_s3_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted S3 bucket {self.name}.' @@ -1276,6 +1428,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_s3_bucket_sub_path( + self.name, self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Removed objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'Failed to remove objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return aws.resource('s3').Bucket(self.name) @@ -1306,9 +1471,11 @@ def get_file_sync_command(base_dir_path, file_names): for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' - f's3://{self.name}') + f's3://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1320,9 +1487,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): for file_name in excluded_list ]) src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' - f's3://{self.name}/{dest_dir_name}') + f's3://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1331,17 +1500,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): else: source_message = source_path_list[0] + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> s3://{self.name}/' with rich_utils.safe_status( - ux_utils.spinner_message(f'Syncing {source_message} -> ' - f's3://{self.name}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.parallel_upload( source_path_list, get_file_sync_command, get_dir_sync_command, + log_path, self.name, self._ACCESS_DENIED_MESSAGE, create_dirs=create_dirs, max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def _transfer_to_s3(self) -> None: assert isinstance(self.source, str), self.source @@ -1433,7 +1609,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_s3_mount_install_cmd() mount_cmd = mounting_utils.get_s3_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return 
mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -1483,6 +1660,27 @@ def _create_s3_bucket(self, ) from e return aws.resource('s3').Bucket(bucket_name) + def _execute_s3_remove_command(self, command: str, bucket_name: str, + hint_operating: str, + hint_failed: str) -> bool: + try: + with rich_utils.safe_status( + ux_utils.spinner_message(hint_operating)): + subprocess.check_output(command.split(' '), + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + if 'NoSuchBucket' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'{hint_failed}' + f'Detailed error: {e.output}') + return True + def _delete_s3_bucket(self, bucket_name: str) -> bool: """Deletes S3 bucket, including all objects in bucket @@ -1500,29 +1698,28 @@ def _delete_s3_bucket(self, bucket_name: str) -> bool: # The fastest way to delete is to run `aws s3 rb --force`, # which removes the bucket by force. remove_command = f'aws s3 rb s3://{bucket_name} --force' - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting S3 bucket [green]{bucket_name}')): - subprocess.check_output(remove_command.split(' '), - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete S3 bucket {bucket_name}.' - f'Detailed error: {e.output}') + success = self._execute_s3_remove_command( + remove_command, bucket_name, + f'Deleting S3 bucket [green]{bucket_name}[/]', + f'Failed to delete S3 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_s3_bucket(bucket_name): time.sleep(0.1) return True + def _delete_s3_bucket_sub_path(self, bucket_name: str, + sub_path: str) -> bool: + """Deletes the sub path from the bucket.""" + remove_command = f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive' + return self._execute_s3_remove_command( + remove_command, bucket_name, f'Removing objects from S3 bucket ' + f'[green]{bucket_name}/{sub_path}[/]', + f'Failed to remove objects from S3 bucket {bucket_name}/{sub_path}.' + ) + class GcsStore(AbstractStore): """GcsStore inherits from Storage Object and represents the backend @@ -1536,11 +1733,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-central1', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: StorageHandle super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1579,6 +1777,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. 
', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to GCS is currently not supported.') # Validate name self.name = self.validate_name(self.name) # Check if the storage is enabled @@ -1687,6 +1888,8 @@ def upload(self): self._transfer_to_gcs() elif self.source.startswith('r2://'): self._transfer_to_gcs() + elif self.source.startswith('oci://'): + self._transfer_to_gcs() else: # If a single directory is specified in source, upload # contents to root of bucket by suffixing /*. @@ -1698,6 +1901,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_gcs_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted GCS bucket {self.name}.' @@ -1707,6 +1913,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_gcs_bucket(self.name, + self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Deleted objects in GCS bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'GCS bucket {self.name} may have ' \ + 'been deleted externally.' + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return self.client.get_bucket(self.name) @@ -1741,13 +1960,19 @@ def batch_gsutil_cp(self, gsutil_alias, alias_gen = data_utils.get_gsutil_command() sync_command = (f'{alias_gen}; echo "{copy_list}" | {gsutil_alias} ' f'cp -e -n -r -I gs://{self.name}') - + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> gs://{self.name}/' with rich_utils.safe_status( - ux_utils.spinner_message(f'Syncing {source_message} -> ' - f'gs://{self.name}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.run_upload_cli(sync_command, self._ACCESS_DENIED_MESSAGE, - bucket_name=self.name) + bucket_name=self.name, + log_path=log_path) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def batch_gsutil_rsync(self, source_path_list: List[Path], @@ -1774,9 +1999,11 @@ def get_file_sync_command(base_dir_path, file_names): sync_format = '|'.join(file_names) gsutil_alias, alias_gen = data_utils.get_gsutil_command() base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -x \'^(?!{sync_format}$).*\' ' - f'{base_dir_path} gs://{self.name}') + f'{base_dir_path} gs://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1786,9 +2013,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): excludes = '|'.join(excluded_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -r -x \'({excludes})\' {src_dir_path} ' - f'gs://{self.name}/{dest_dir_name}') + f'gs://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1797,17 +2026,24 @@ def get_dir_sync_command(src_dir_path, 
dest_dir_name): else: source_message = source_path_list[0] + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> gs://{self.name}/' with rich_utils.safe_status( - ux_utils.spinner_message(f'Syncing {source_message} -> ' - f'gs://{self.name}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.parallel_upload( source_path_list, get_file_sync_command, get_dir_sync_command, + log_path, self.name, self._ACCESS_DENIED_MESSAGE, create_dirs=create_dirs, max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def _transfer_to_gcs(self) -> None: if isinstance(self.source, str) and self.source.startswith('s3://'): @@ -1886,7 +2122,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_gcs_mount_install_cmd() mount_cmd = mounting_utils.get_gcs_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) version_check_cmd = ( f'gcsfuse --version | grep -q {mounting_utils.GCSFUSE_VERSION}') return mounting_utils.get_mounting_command(mount_path, install_cmd, @@ -1926,19 +2163,33 @@ def _create_gcs_bucket(self, f'{new_bucket.storage_class}{colorama.Style.RESET_ALL}') return new_bucket - def _delete_gcs_bucket(self, bucket_name: str) -> bool: - """Deletes GCS bucket, including all objects in bucket + def _delete_gcs_bucket( + self, + bucket_name: str, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> bool: + """Deletes objects in GCS bucket Args: bucket_name: str; Name of bucket + _bucket_sub_path: str; Sub path in the bucket, if provided only + objects in the sub path will be deleted, else the whole bucket will + be deleted Returns: bool; True if bucket was deleted, False if it was deleted externally. """ - + if _bucket_sub_path is not None: + command_suffix = f'/{_bucket_sub_path}' + hint_text = 'objects in ' + else: + command_suffix = '' + hint_text = '' with rich_utils.safe_status( ux_utils.spinner_message( - f'Deleting GCS bucket [green]{bucket_name}')): + f'Deleting {hint_text}GCS bucket ' + f'[green]{bucket_name}{command_suffix}[/]')): try: self.client.get_bucket(bucket_name) except gcp.forbidden_exception() as e: @@ -1956,8 +2207,9 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: return False try: gsutil_alias, alias_gen = data_utils.get_gsutil_command() - remove_obj_command = (f'{alias_gen};{gsutil_alias} ' - f'rm -r gs://{bucket_name}') + remove_obj_command = ( + f'{alias_gen};{gsutil_alias} ' + f'rm -r gs://{bucket_name}{command_suffix}') subprocess.check_output(remove_obj_command, stderr=subprocess.STDOUT, shell=True, @@ -1966,7 +2218,8 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: except subprocess.CalledProcessError as e: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageBucketDeleteError( - f'Failed to delete GCS bucket {bucket_name}.' + f'Failed to delete {hint_text}GCS bucket ' + f'{bucket_name}{command_suffix}.' 
f'Detailed error: {e.output}') @@ -2018,7 +2271,8 @@ def __init__(self, storage_account_name: str = '', region: Optional[str] = 'eastus', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.storage_client: 'storage.Client' self.resource_client: 'storage.Client' self.container_name: str @@ -2030,7 +2284,7 @@ def __init__(self, if region is None: region = 'eastus' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) @classmethod def from_metadata(cls, metadata: AbstractStore.StoreMetadata, @@ -2100,6 +2354,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. ', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to AZureBlob is not supported.') # Validate name self.name = self.validate_name(self.name) @@ -2177,6 +2434,17 @@ def initialize(self): """ self.storage_client = data_utils.create_az_client('storage') self.resource_client = data_utils.create_az_client('resource') + self._update_storage_account_name_and_resource() + + self.container_name, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def _update_storage_account_name_and_resource(self): self.storage_account_name, self.resource_group_name = ( self._get_storage_account_and_resource_group()) @@ -2187,13 +2455,13 @@ def initialize(self): self.storage_account_name, self.resource_group_name, self.storage_client, self.resource_client) - self.container_name, is_new_bucket = self._get_bucket() - if self.is_sky_managed is None: - # If is_sky_managed is not specified, then this is a new storage - # object (i.e., did not exist in global_user_state) and we should - # set the is_sky_managed property. - # If is_sky_managed is specified, then we take no action. 
- self.is_sky_managed = is_new_bucket + def update_storage_attributes(self, **kwargs: Dict[str, Any]): + assert 'storage_account_name' in kwargs, ( + 'only storage_account_name supported') + assert isinstance(kwargs['storage_account_name'], + str), ('storage_account_name must be a string') + self.storage_account_name = kwargs['storage_account_name'] + self._update_storage_account_name_and_resource() @staticmethod def get_default_storage_account_name(region: Optional[str]) -> str: @@ -2452,6 +2720,8 @@ def upload(self): raise NotImplementedError(error_message.format('R2')) elif self.source.startswith('cos://'): raise NotImplementedError(error_message.format('IBM COS')) + elif self.source.startswith('oci://'): + raise NotImplementedError(error_message.format('OCI')) else: self.batch_az_blob_sync([self.source]) except exceptions.StorageUploadError: @@ -2462,6 +2732,9 @@ def upload(self): def delete(self) -> None: """Deletes the storage.""" + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_az_bucket(self.name) if deleted_by_skypilot: msg_str = (f'Deleted AZ Container {self.name!r} under storage ' @@ -2472,6 +2745,32 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + try: + container_url = data_utils.AZURE_CONTAINER_URL.format( + storage_account_name=self.storage_account_name, + container_name=self.name) + container_client = data_utils.create_az_client( + client_type='container', + container_url=container_url, + storage_account_name=self.storage_account_name, + resource_group_name=self.resource_group_name) + # List and delete blobs in the specified directory + blobs = container_client.list_blobs( + name_starts_with=self._bucket_sub_path + '/') + for blob in blobs: + container_client.delete_blob(blob.name) + logger.info( + f'Deleted objects from sub path {self._bucket_sub_path} ' + f'in container {self.name}.') + except Exception as e: # pylint: disable=broad-except + logger.error( + f'Failed to delete objects from sub path ' + f'{self._bucket_sub_path} in container {self.name}. 
' + f'Details: {common_utils.format_exception(e, use_bracket=True)}' + ) + def get_handle(self) -> StorageHandle: """Returns the Storage Handle object.""" return self.storage_client.blob_containers.get( @@ -2498,13 +2797,15 @@ def get_file_sync_command(base_dir_path, file_names) -> str: includes_list = ';'.join(file_names) includes = f'--include-pattern "{includes_list}"' base_dir_path = shlex.quote(base_dir_path) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else self.container_name) sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' f'{includes} ' '--delete-destination false ' f'--source {base_dir_path} ' - f'--container {self.container_name}') + f'--container {container_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: @@ -2515,8 +2816,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: [file_name.rstrip('*') for file_name in excluded_list]) excludes = f'--exclude-path "{excludes_list}"' src_dir_path = shlex.quote(src_dir_path) - container_path = (f'{self.container_name}/{dest_dir_name}' - if dest_dir_name else self.container_name) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else + f'{self.container_name}') + if dest_dir_name: + container_path = f'{container_path}/{dest_dir_name}' sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' @@ -2535,17 +2839,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: container_endpoint = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=self.storage_account_name, container_name=self.name) + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> {container_endpoint}/' with rich_utils.safe_status( - ux_utils.spinner_message( - f'Syncing {source_message} -> {container_endpoint}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.parallel_upload( source_path_list, get_file_sync_command, get_dir_sync_command, + log_path, self.name, self._ACCESS_DENIED_MESSAGE, create_dirs=create_dirs, max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def _get_bucket(self) -> Tuple[str, bool]: """Obtains the AZ Container. @@ -2632,6 +2943,7 @@ def _get_bucket(self) -> Tuple[str, bool]: f'{self.storage_account_name!r}.' 'Details: ' f'{common_utils.format_exception(e, use_bracket=True)}') + # If the container cannot be found in both private and public settings, # the container is to be created by Sky. 
However, creation is skipped # if Store object is being reconstructed for deletion or re-mount with @@ -2662,7 +2974,8 @@ def mount_command(self, mount_path: str) -> str: mount_cmd = mounting_utils.get_az_mount_cmd(self.container_name, self.storage_account_name, mount_path, - self.storage_account_key) + self.storage_account_key, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -2761,11 +3074,12 @@ def __init__(self, source: str, region: Optional[str] = 'auto', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -2804,6 +3118,10 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. ', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to R2 is currently not supported.') + # Validate name self.name = S3Store.validate_name(self.name) # Check if the storage is enabled @@ -2855,6 +3173,8 @@ def upload(self): self._transfer_to_r2() elif self.source.startswith('r2://'): pass + elif self.source.startswith('oci://'): + self._transfer_to_r2() else: self.batch_aws_rsync([self.source]) except exceptions.StorageUploadError: @@ -2864,6 +3184,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_r2_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted R2 bucket {self.name}.' @@ -2873,6 +3196,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_r2_bucket_sub_path( + self.name, self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Removed objects from R2 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'Failed to remove objects from R2 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' 
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return cloudflare.resource('s3').Bucket(self.name) @@ -2904,11 +3240,13 @@ def get_file_sync_command(base_dir_path, file_names): ]) endpoint_url = cloudflare.create_endpoint() base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' 'aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' - f's3://{self.name} ' + f's3://{self.name}{sub_path} ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') return sync_command @@ -2923,11 +3261,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): ]) endpoint_url = cloudflare.create_endpoint() src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' - f's3://{self.name}/{dest_dir_name} ' + f's3://{self.name}{sub_path}/{dest_dir_name} ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') return sync_command @@ -2938,17 +3278,24 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): else: source_message = source_path_list[0] + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> r2://{self.name}/' with rich_utils.safe_status( - ux_utils.spinner_message( - f'Syncing {source_message} -> r2://{self.name}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.parallel_upload( source_path_list, get_file_sync_command, get_dir_sync_command, + log_path, self.name, self._ACCESS_DENIED_MESSAGE, create_dirs=create_dirs, max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def _transfer_to_r2(self) -> None: assert isinstance(self.source, str), self.source @@ -3051,11 +3398,9 @@ def mount_command(self, mount_path: str) -> str: endpoint_url = cloudflare.create_endpoint() r2_credential_path = cloudflare.R2_CREDENTIALS_PATH r2_profile_name = cloudflare.R2_PROFILE_NAME - mount_cmd = mounting_utils.get_r2_mount_cmd(r2_credential_path, - r2_profile_name, - endpoint_url, - self.bucket.name, - mount_path) + mount_cmd = mounting_utils.get_r2_mount_cmd( + r2_credential_path, r2_profile_name, endpoint_url, self.bucket.name, + mount_path, self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -3088,6 +3433,43 @@ def _create_r2_bucket(self, f'{self.name} but failed.') from e return cloudflare.resource('s3').Bucket(bucket_name) + def _execute_r2_remove_command(self, command: str, bucket_name: str, + hint_operating: str, + hint_failed: str) -> bool: + try: + with rich_utils.safe_status( + ux_utils.spinner_message(hint_operating)): + subprocess.check_output(command.split(' '), + stderr=subprocess.STDOUT, + shell=True) + except subprocess.CalledProcessError as e: + if 'NoSuchBucket' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'{hint_failed}' + f'Detailed error: {e.output}') + return True + + def 
_delete_r2_bucket_sub_path(self, bucket_name: str, + sub_path: str) -> bool: + """Deletes the sub path from the bucket.""" + endpoint_url = cloudflare.create_endpoint() + remove_command = ( + f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} ' + f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + return self._execute_r2_remove_command( + remove_command, bucket_name, + f'Removing objects from R2 bucket {bucket_name}/{sub_path}', + f'Failed to remove objects from R2 bucket {bucket_name}/{sub_path}.' + ) + def _delete_r2_bucket(self, bucket_name: str) -> bool: """Deletes R2 bucket, including all objects in bucket @@ -3110,24 +3492,12 @@ def _delete_r2_bucket(self, bucket_name: str) -> bool: f'aws s3 rb s3://{bucket_name} --force ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting R2 bucket {bucket_name}')): - subprocess.check_output(remove_command, - stderr=subprocess.STDOUT, - shell=True) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete R2 bucket {bucket_name}.' - f'Detailed error: {e.output}') + + success = self._execute_r2_remove_command( + remove_command, bucket_name, f'Deleting R2 bucket {bucket_name}', + f'Failed to delete R2 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_r2_bucket(bucket_name): @@ -3146,11 +3516,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-east', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) self.bucket_rclone_profile = \ Rclone.generate_rclone_bucket_profile_name( self.name, Rclone.RcloneClouds.IBM) @@ -3295,10 +3666,22 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + self._delete_cos_bucket() logger.info(f'{colorama.Fore.GREEN}Deleted COS bucket {self.name}.' 
f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket, self._bucket_sub_path + '/') + except ibm.ibm_botocore.exceptions.ClientError as e: + if e.__class__.__name__ == 'NoSuchBucket': + logger.debug('bucket already removed') + def get_handle(self) -> StorageHandle: return self.s3_resource.Bucket(self.name) @@ -3339,10 +3722,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: # .git directory is excluded from the sync # wrapping src_dir_path with "" to support path with spaces src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ( 'rclone copy --exclude ".git/*" ' f'{src_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}/{dest_dir_name}') + f'{self.bucket_rclone_profile}:{self.name}{sub_path}' + f'/{dest_dir_name}') return sync_command def get_file_sync_command(base_dir_path, file_names) -> str: @@ -3368,9 +3754,12 @@ def get_file_sync_command(base_dir_path, file_names) -> str: for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sync_command = ('rclone copy ' - f'{includes} {base_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}') + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') + sync_command = ( + 'rclone copy ' + f'{includes} {base_dir_path} ' + f'{self.bucket_rclone_profile}:{self.name}{sub_path}') return sync_command # Generate message for upload @@ -3379,17 +3768,24 @@ def get_file_sync_command(base_dir_path, file_names) -> str: else: source_message = source_path_list[0] + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> cos://{self.region}/{self.name}/' with rich_utils.safe_status( - ux_utils.spinner_message(f'Syncing {source_message} -> ' - f'cos://{self.region}/{self.name}/')): + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): data_utils.parallel_upload( source_path_list, get_file_sync_command, get_dir_sync_command, + log_path, self.name, self._ACCESS_DENIED_MESSAGE, create_dirs=create_dirs, max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) def _get_bucket(self) -> Tuple[StorageHandle, bool]: """returns IBM COS bucket object if exists, otherwise creates it. 
@@ -3448,6 +3844,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]: Rclone.RcloneClouds.IBM, self.region, # type: ignore ) + if not bucket_region and self.sync_on_reconstruction: # bucket doesn't exist return self._create_cos_bucket(self.name, self.region), True @@ -3494,7 +3891,8 @@ def mount_command(self, mount_path: str) -> str: Rclone.RCLONE_CONFIG_PATH, self.bucket_rclone_profile, self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -3532,18 +3930,442 @@ def _create_cos_bucket(self, return self.bucket - def _delete_cos_bucket(self): - bucket = self.s3_resource.Bucket(self.name) - try: - bucket_versioning = self.s3_resource.BucketVersioning(self.name) - if bucket_versioning.status == 'Enabled': + def _delete_cos_bucket_objects(self, + bucket: Any, + prefix: Optional[str] = None): + bucket_versioning = self.s3_resource.BucketVersioning(bucket.name) + if bucket_versioning.status == 'Enabled': + if prefix is not None: + res = list( + bucket.object_versions.filter(Prefix=prefix).delete()) + else: res = list(bucket.object_versions.delete()) + else: + if prefix is not None: + res = list(bucket.objects.filter(Prefix=prefix).delete()) else: res = list(bucket.objects.delete()) - logger.debug(f'Deleted bucket\'s content:\n{res}') + logger.debug(f'Deleted bucket\'s content:\n{res}, prefix: {prefix}') + + def _delete_cos_bucket(self): + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket) bucket.delete() bucket.wait_until_not_exists() except ibm.ibm_botocore.exceptions.ClientError as e: if e.__class__.__name__ == 'NoSuchBucket': logger.debug('bucket already removed') Rclone.delete_rclone_bucket_profile(self.name, Rclone.RcloneClouds.IBM) + + +class OciStore(AbstractStore): + """OciStore inherits from Storage Object and represents the backend + for OCI buckets. + """ + + _ACCESS_DENIED_MESSAGE = 'AccessDeniedException' + + def __init__(self, + name: str, + source: str, + region: Optional[str] = None, + is_sky_managed: Optional[bool] = None, + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): + self.client: Any + self.bucket: StorageHandle + self.oci_config_file: str + self.config_profile: str + self.compartment: str + self.namespace: str + + # Bucket region should be consistence with the OCI config file + region = oci.get_oci_config()['region'] + + super().__init__(name, source, region, is_sky_managed, + sync_on_reconstruction, _bucket_sub_path) + # TODO(zpoint): add _bucket_sub_path to the sync/mount/delete commands + + def _validate(self): + if self.source is not None and isinstance(self.source, str): + if self.source.startswith('oci://'): + assert self.name == data_utils.split_oci_path(self.source)[0], ( + 'OCI Bucket is specified as path, the name should be ' + 'the same as OCI bucket.') + elif not re.search(r'^\w+://', self.source): + # Treat it as local path. + pass + else: + raise NotImplementedError( + f'Moving data from {self.source} to OCI is not supported.') + + # Validate name + self.name = self.validate_name(self.name) + # Check if the storage is enabled + if not _is_storage_cloud_enabled(str(clouds.OCI())): + with ux_utils.print_exception_no_traceback(): + raise exceptions.ResourcesUnavailableError( + 'Storage \'store: oci\' specified, but ' \ + 'OCI access is disabled. To fix, enable '\ + 'OCI by running `sky check`. 
'\ + 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long + ) + + @classmethod + def validate_name(cls, name) -> str: + """Validates the name of the OCI store. + + Source for rules: https://docs.oracle.com/en-us/iaas/Content/Object/Tasks/managingbuckets.htm#Managing_Buckets # pylint: disable=line-too-long + """ + + def _raise_no_traceback_name_error(err_str): + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageNameError(err_str) + + if name is not None and isinstance(name, str): + # Check for overall length + if not 1 <= len(name) <= 256: + _raise_no_traceback_name_error( + f'Invalid store name: name {name} must contain 1-256 ' + 'characters.') + + # Check for valid characters and start/end with a number or letter + pattern = r'^[A-Za-z0-9-._]+$' + if not re.match(pattern, name): + _raise_no_traceback_name_error( + f'Invalid store name: name {name} can only contain ' + 'upper or lower case letters, numeric characters, hyphens ' + '(-), underscores (_), and dots (.). Spaces are not ' + 'allowed. Names must start and end with a number or ' + 'letter.') + else: + _raise_no_traceback_name_error('Store name must be specified.') + return name + + def initialize(self): + """Initializes the OCI store object on the cloud. + + Initialization involves fetching bucket if exists, or creating it if + it does not. + + Raises: + StorageBucketCreateError: If bucket creation fails + StorageBucketGetError: If fetching existing bucket fails + StorageInitError: If general initialization fails. + """ + # pylint: disable=import-outside-toplevel + from sky.clouds.utils import oci_utils + from sky.provision.oci.query_utils import query_helper + + self.oci_config_file = oci.get_config_file() + self.config_profile = oci_utils.oci_config.get_profile() + + ## pylint: disable=line-too-long + # What's compartment? See thttps://docs.oracle.com/en/cloud/foundation/cloud_architecture/governance/compartments.html + self.compartment = query_helper.find_compartment(self.region) + self.client = oci.get_object_storage_client(region=self.region, + profile=self.config_profile) + self.namespace = self.client.get_namespace( + compartment_id=oci.get_oci_config()['tenancy']).data + + self.bucket, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def upload(self): + """Uploads source to store bucket. + + Upload must be called by the Storage handler - it is not called on + Store initialization. + + Raises: + StorageUploadError: if upload fails. + """ + try: + if isinstance(self.source, list): + self.batch_oci_rsync(self.source, create_dirs=True) + elif self.source is not None: + if self.source.startswith('oci://'): + pass + else: + self.batch_oci_rsync([self.source]) + except exceptions.StorageUploadError: + raise + except Exception as e: + raise exceptions.StorageUploadError( + f'Upload failed for store {self.name}') from e + + def delete(self) -> None: + deleted_by_skypilot = self._delete_oci_bucket(self.name) + if deleted_by_skypilot: + msg_str = f'Deleted OCI bucket {self.name}.' + else: + msg_str = (f'OCI bucket {self.name} may have been deleted ' + f'externally. 
Removing from local state.') + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + + def get_handle(self) -> StorageHandle: + return self.client.get_bucket(namespace_name=self.namespace, + bucket_name=self.name).data + + def batch_oci_rsync(self, + source_path_list: List[Path], + create_dirs: bool = False) -> None: + """Invokes oci sync to batch upload a list of local paths to Bucket + + Use OCI bulk operation to batch process the file upload + + Args: + source_path_list: List of paths to local files or directories + create_dirs: If the local_path is a directory and this is set to + False, the contents of the directory are directly uploaded to + root of the bucket. If the local_path is a directory and this is + set to True, the directory is created in the bucket root and + contents are uploaded to it. + """ + + @oci.with_oci_env + def get_file_sync_command(base_dir_path, file_names): + includes = ' '.join( + [f'--include "{file_name}"' for file_name in file_names]) + sync_command = ( + 'oci os object bulk-upload --no-follow-symlinks --overwrite ' + f'--bucket-name {self.name} --namespace-name {self.namespace} ' + f'--src-dir "{base_dir_path}" {includes}') + + return sync_command + + @oci.with_oci_env + def get_dir_sync_command(src_dir_path, dest_dir_name): + if dest_dir_name and not str(dest_dir_name).endswith('/'): + dest_dir_name = f'{dest_dir_name}/' + + excluded_list = storage_utils.get_excluded_files(src_dir_path) + excluded_list.append('.git/*') + excludes = ' '.join([ + f'--exclude {shlex.quote(file_name)}' + for file_name in excluded_list + ]) + + # we exclude .git directory from the sync + sync_command = ( + 'oci os object bulk-upload --no-follow-symlinks --overwrite ' + f'--bucket-name {self.name} --namespace-name {self.namespace} ' + f'--object-prefix "{dest_dir_name}" --src-dir "{src_dir_path}" ' + f'{excludes} ') + + return sync_command + + # Generate message for upload + if len(source_path_list) > 1: + source_message = f'{len(source_path_list)} paths' + else: + source_message = source_path_list[0] + + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> oci://{self.name}/' + with rich_utils.safe_status( + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): + data_utils.parallel_upload( + source_path_list=source_path_list, + filesync_command_generator=get_file_sync_command, + dirsync_command_generator=get_dir_sync_command, + log_path=log_path, + bucket_name=self.name, + access_denied_message=self._ACCESS_DENIED_MESSAGE, + create_dirs=create_dirs, + max_concurrent_uploads=1) + + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) + + def _get_bucket(self) -> Tuple[StorageHandle, bool]: + """Obtains the OCI bucket. + If the bucket exists, this method will connect to the bucket. + + If the bucket does not exist, there are three cases: + 1) Raise an error if the bucket source starts with oci:// + 2) Return None if bucket has been externally deleted and + sync_on_reconstruction is False + 3) Create and return a new bucket otherwise + + Return tuple (Bucket, Boolean): The first item is the bucket + json payload from the OCI API call, the second item indicates + if this is a new created bucket(True) or an existing bucket(False). 
+ + Raises: + StorageBucketCreateError: If creating the bucket fails + StorageBucketGetError: If fetching a bucket fails + """ + try: + get_bucket_response = self.client.get_bucket( + namespace_name=self.namespace, bucket_name=self.name) + bucket = get_bucket_response.data + return bucket, False + except oci.service_exception() as e: + if e.status == 404: # Not Found + if isinstance(self.source, + str) and self.source.startswith('oci://'): + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + 'Attempted to connect to a non-existent bucket: ' + f'{self.source}') from e + else: + # If bucket cannot be found (i.e., does not exist), it is + # to be created by Sky. However, creation is skipped if + # Store object is being reconstructed for deletion. + if self.sync_on_reconstruction: + bucket = self._create_oci_bucket(self.name) + return bucket, True + else: + return None, False + elif e.status == 401: # Unauthorized + # AccessDenied error for buckets that are private and not + # owned by user. + command = ( + f'oci os object list --namespace-name {self.namespace} ' + f'--bucket-name {self.name}') + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + _BUCKET_FAIL_TO_CONNECT_MESSAGE.format(name=self.name) + + f' To debug, consider running `{command}`.') from e + else: + # Unknown / unexpected error happened. This might happen when + # Object storage service itself functions not normal (e.g. + # maintainance event causes internal server error or request + # timeout, etc). + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + f'Failed to connect to OCI bucket {self.name}') from e + + def mount_command(self, mount_path: str) -> str: + """Returns the command to mount the bucket to the mount_path. + + Uses Rclone to mount the bucket. + + Args: + mount_path: str; Path to mount the bucket to. 
+ """ + install_cmd = mounting_utils.get_rclone_install_cmd() + mount_cmd = mounting_utils.get_oci_mount_cmd( + mount_path=mount_path, + store_name=self.name, + region=str(self.region), + namespace=self.namespace, + compartment=self.bucket.compartment_id, + config_file=self.oci_config_file, + config_profile=self.config_profile) + version_check_cmd = mounting_utils.get_rclone_version_check_cmd() + + return mounting_utils.get_mounting_command(mount_path, install_cmd, + mount_cmd, version_check_cmd) + + def _download_file(self, remote_path: str, local_path: str) -> None: + """Downloads file from remote to local on OCI bucket + + Args: + remote_path: str; Remote path on OCI bucket + local_path: str; Local path on user's device + """ + if remote_path.startswith(f'/{self.name}'): + # If the remote path is /bucket_name, we need to + # remove the leading / + remote_path = remote_path.lstrip('/') + + filename = os.path.basename(remote_path) + if not local_path.endswith(filename): + local_path = os.path.join(local_path, filename) + + @oci.with_oci_env + def get_file_download_command(remote_path, local_path): + download_command = (f'oci os object get --bucket-name {self.name} ' + f'--namespace-name {self.namespace} ' + f'--name {remote_path} --file {local_path}') + + return download_command + + download_command = get_file_download_command(remote_path, local_path) + + try: + with rich_utils.safe_status( + f'[bold cyan]Downloading: {remote_path} -> {local_path}[/]' + ): + subprocess.check_output(download_command, + stderr=subprocess.STDOUT, + shell=True) + except subprocess.CalledProcessError as e: + logger.error(f'Download failed: {remote_path} -> {local_path}.\n' + f'Detail errors: {e.output}') + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'Failed download file {self.name}:{remote_path}.') from e + + def _create_oci_bucket(self, bucket_name: str) -> StorageHandle: + """Creates OCI bucket with specific name in specific region + + Args: + bucket_name: str; Name of bucket + region: str; Region name, e.g. us-central1, us-west1 + """ + logger.debug(f'_create_oci_bucket: {bucket_name}') + try: + create_bucket_response = self.client.create_bucket( + namespace_name=self.namespace, + create_bucket_details=oci.oci.object_storage.models. + CreateBucketDetails( + name=bucket_name, + compartment_id=self.compartment, + )) + bucket = create_bucket_response.data + return bucket + except oci.service_exception() as e: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Failed to create OCI bucket: {self.name}') from e + + def _delete_oci_bucket(self, bucket_name: str) -> bool: + """Deletes OCI bucket, including all objects in bucket + + Args: + bucket_name: str; Name of bucket + + Returns: + bool; True if bucket was deleted, False if it was deleted externally. 
+ """ + logger.debug(f'_delete_oci_bucket: {bucket_name}') + + @oci.with_oci_env + def get_bucket_delete_command(bucket_name): + remove_command = (f'oci os bucket delete --bucket-name ' + f'{bucket_name} --empty --force') + + return remove_command + + remove_command = get_bucket_delete_command(bucket_name) + + try: + with rich_utils.safe_status( + f'[bold cyan]Deleting OCI bucket {bucket_name}[/]'): + subprocess.check_output(remove_command.split(' '), + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + if 'BucketNotFound' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + logger.error(e.output) + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'Failed to delete OCI bucket {bucket_name}.') + return True diff --git a/sky/jobs/__init__.py b/sky/jobs/__init__.py index 5688ca7c7a2..5f52a863e36 100644 --- a/sky/jobs/__init__.py +++ b/sky/jobs/__init__.py @@ -9,6 +9,7 @@ from sky.jobs.core import launch from sky.jobs.core import queue from sky.jobs.core import queue_from_kubernetes_pod +from sky.jobs.core import sync_down_logs from sky.jobs.core import tail_logs from sky.jobs.recovery_strategy import DEFAULT_RECOVERY_STRATEGY from sky.jobs.recovery_strategy import RECOVERY_STRATEGIES @@ -37,6 +38,7 @@ 'queue', 'queue_from_kubernetes_pod', 'tail_logs', + 'sync_down_logs', # utils 'ManagedJobCodeGen', 'format_job_table', diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 1348441a5bd..3cb67daba94 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -347,8 +347,8 @@ def cancel(name: Optional[str] = None, stopped_message='All managed jobs should have finished.') job_id_str = ','.join(map(str, job_ids)) - if sum([len(job_ids) > 0, name is not None, all]) != 1: - argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else '' + if sum([bool(job_ids), name is not None, all]) != 1: + argument_str = f'job_ids={job_id_str}' if job_ids else '' argument_str += f' name={name}' if name is not None else '' argument_str += ' all' if all else '' with ux_utils.print_exception_no_traceback(): @@ -427,6 +427,52 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, controller=controller) +@usage_lib.entrypoint +def sync_down_logs( + name: Optional[str], + job_id: Optional[int], + refresh: bool, + controller: bool, + local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> None: + """Sync down logs of managed jobs. + + Please refer to sky.cli.job_logs for documentation. + + Raises: + ValueError: invalid arguments. + sky.exceptions.ClusterNotUpError: the jobs controller is not up. + """ + # TODO(zhwu): Automatically restart the jobs controller + if name is not None and job_id is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both name and job_id.') + + jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER + job_name_or_id_str = '' + if job_id is not None: + job_name_or_id_str = str(job_id) + elif name is not None: + job_name_or_id_str = f'-n {name}' + else: + job_name_or_id_str = '' + handle = _maybe_restart_controller( + refresh, + stopped_message=( + f'{jobs_controller_type.value.name.capitalize()} is stopped. 
To ' + f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs ' + f'-r --sync-down {job_name_or_id_str}{colorama.Style.RESET_ALL}'), + spinner_message='Retrieving job logs') + + backend = backend_utils.get_backend_from_handle(handle) + assert isinstance(backend, backends.CloudVmRayBackend), backend + + backend.sync_down_managed_job_logs(handle, + job_id=job_id, + job_name=name, + controller=controller, + local_dir=local_dir) + + spot_launch = common_utils.deprecated_function( launch, name='sky.jobs.launch', diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 9a5ab4b3cad..5da807b8bbb 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -564,6 +564,33 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]: return job_ids +def get_all_job_ids_by_name(name: Optional[str]) -> List[int]: + """Get all job ids by name.""" + name_filter = '' + field_values = [] + if name is not None: + # We match the job name from `job_info` for the jobs submitted after + # #1982, and from `spot` for the jobs submitted before #1982, whose + # job_info is not available. + name_filter = ('WHERE (job_info.name=(?) OR ' + '(job_info.name IS NULL AND spot.task_name=(?)))') + field_values = [name, name] + + # Left outer join is used here instead of join, because the job_info does + # not contain the managed jobs submitted before #1982. + with db_utils.safe_cursor(_DB_PATH) as cursor: + rows = cursor.execute( + f"""\ + SELECT DISTINCT spot.spot_job_id + FROM spot + LEFT OUTER JOIN job_info + ON spot.spot_job_id=job_info.spot_job_id + {name_filter} + ORDER BY spot.spot_job_id DESC""", field_values).fetchall() + job_ids = [row[0] for row in rows if row[0] is not None] + return job_ids + + def _get_all_task_ids_statuses( job_id: int) -> List[Tuple[int, ManagedJobStatus]]: with db_utils.safe_cursor(_DB_PATH) as cursor: @@ -591,7 +618,7 @@ def get_latest_task_id_status( If the job_id does not exist, (None, None) will be returned. """ id_statuses = _get_all_task_ids_statuses(job_id) - if len(id_statuses) == 0: + if not id_statuses: return None, None task_id, status = id_statuses[-1] for task_id, status in id_statuses: @@ -617,7 +644,7 @@ def get_failure_reason(job_id: int) -> Optional[str]: WHERE spot_job_id=(?) ORDER BY task_id ASC""", (job_id,)).fetchall() reason = [r[0] for r in reason if r[0] is not None] - if len(reason) == 0: + if not reason: return None return reason[0] diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 267c205285b..b044e31bda6 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -234,11 +234,11 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: if job_ids is None: job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None) job_ids = list(set(job_ids)) - if len(job_ids) == 0: + if not job_ids: return 'No job to cancel.' job_id_str = ', '.join(map(str, job_ids)) logger.info(f'Cancelling jobs {job_id_str}.') - cancelled_job_ids = [] + cancelled_job_ids: List[int] = [] for job_id in job_ids: # Check the status of the managed job status. If it is in # terminal state, we can safely skip it. @@ -268,7 +268,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: shutil.copy(str(signal_file), str(legacy_signal_file)) cancelled_job_ids.append(job_id) - if len(cancelled_job_ids) == 0: + if not cancelled_job_ids: return 'No job to cancel.' 
identity_str = f'Job with ID {cancelled_job_ids[0]} is' if len(cancelled_job_ids) > 1: @@ -281,7 +281,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: def cancel_job_by_name(job_name: str) -> str: """Cancel a job by name.""" job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name) - if len(job_ids) == 0: + if not job_ids: return f'No running job found with name {job_name!r}.' if len(job_ids) > 1: return (f'{colorama.Fore.RED}Multiple running jobs found ' @@ -515,7 +515,7 @@ def stream_logs(job_id: Optional[int], for job in managed_jobs if job['job_name'] == job_name } - if len(managed_job_ids) == 0: + if not managed_job_ids: return f'No managed job found with name {job_name!r}.' if len(managed_job_ids) > 1: job_ids_str = ', '.join( @@ -541,7 +541,7 @@ def stream_logs(job_id: Optional[int], if job_id is None: assert job_name is not None job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name) - if len(job_ids) == 0: + if not job_ids: return f'No running managed job found with name {job_name!r}.' if len(job_ids) > 1: raise ValueError( @@ -855,6 +855,15 @@ def cancel_job_by_name(cls, job_name: str) -> str: """) return cls._build(code) + @classmethod + def get_all_job_ids_by_name(cls, job_name: str) -> str: + code = textwrap.dedent(f"""\ + from sky.utils import common_utils + job_id = managed_job_state.get_all_job_ids_by_name({job_name!r}) + print(common_utils.encode_payload(job_id), end="", flush=True) + """) + return cls._build(code) + @classmethod def stream_logs(cls, job_name: Optional[str], diff --git a/sky/optimizer.py b/sky/optimizer.py index 2f70dd39429..5aab31d7750 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -188,7 +188,7 @@ def _remove_dummy_source_sink_nodes(dag: 'dag_lib.Dag'): """Removes special Source and Sink nodes.""" source = [t for t in dag.tasks if t.name == _DUMMY_SOURCE_NAME] sink = [t for t in dag.tasks if t.name == _DUMMY_SINK_NAME] - if len(source) == len(sink) == 0: + if not source and not sink: return assert len(source) == len(sink) == 1, dag.tasks dag.remove(source[0]) @@ -1293,12 +1293,15 @@ def _fill_in_launchable_resources( if resources.cloud is not None else enabled_clouds) # If clouds provide hints, store them for later printing. hints: Dict[clouds.Cloud, str] = {} - for cloud in clouds_list: - feasible_resources = cloud.get_feasible_launchable_resources( - resources, num_nodes=task.num_nodes) + + feasible_list = subprocess_utils.run_in_parallel( + lambda cloud, r=resources, n=task.num_nodes: + (cloud, cloud.get_feasible_launchable_resources(r, n)), + clouds_list) + for cloud, feasible_resources in feasible_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint - if len(feasible_resources.resources_list) > 0: + if feasible_resources.resources_list: # Assume feasible_resources is sorted by prices. 
Guaranteed by # the implementation of get_feasible_launchable_resources and # the underlying service_catalog filtering @@ -1310,7 +1313,7 @@ def _fill_in_launchable_resources( else: all_fuzzy_candidates.update( feasible_resources.fuzzy_candidate_list) - if len(launchable[resources]) == 0: + if not launchable[resources]: clouds_str = str(clouds_list) if len(clouds_list) > 1 else str( clouds_list[0]) num_node_str = '' diff --git a/sky/provision/aws/config.py b/sky/provision/aws/config.py index 6a8c77eafed..acc6fcb0e56 100644 --- a/sky/provision/aws/config.py +++ b/sky/provision/aws/config.py @@ -279,7 +279,7 @@ def _has_igw_route(route_tables): logger.debug(f'subnet {subnet_id} route tables: {route_tables}') if _has_igw_route(route_tables): return True - if len(route_tables) > 0: + if route_tables: return False # Handle the case that a "main" route table is implicitly associated with @@ -383,10 +383,13 @@ def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]: raise exc if not subnets: + vpc_msg = (f'Does a default VPC exist in region ' + f'{ec2.meta.client.meta.region_name}? ') if ( + vpc_id_of_sg is None) else '' _skypilot_log_error_and_exit_for_failover( - 'No usable subnets found, try ' - 'manually creating an instance in your specified region to ' - 'populate the list of subnets and trying this again. ' + f'No usable subnets found. {vpc_msg}' + 'Try manually creating an instance in your specified region to ' + 'populate the list of subnets and try again. ' 'Note that the subnet must map public IPs ' 'on instance launch unless you set `use_internal_ips: true` in ' 'the `provider` config.') @@ -454,7 +457,7 @@ def _vpc_id_from_security_group_ids(ec2, sg_ids: List[str]) -> Any: no_sg_msg = ('Failed to detect a security group with id equal to any of ' 'the configured SecurityGroupIds.') - assert len(vpc_ids) > 0, no_sg_msg + assert vpc_ids, no_sg_msg return vpc_ids[0] @@ -495,6 +498,11 @@ def _get_subnet_and_vpc_id(ec2, security_group_ids: Optional[List[str]], vpc_id_of_sg = None all_subnets = list(ec2.subnets.all()) + # If no VPC is specified, use the default VPC. + # We filter only for default VPCs to avoid using subnets that users may + # not want SkyPilot to use. 
+ if vpc_id_of_sg is None: + all_subnets = [s for s in all_subnets if s.vpc.is_default] subnets, vpc_id = _usable_subnets( ec2, user_specified_subnets=None, diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 229d7361e22..4e461375a14 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -343,7 +343,7 @@ def create_single_instance(vm_i): _create_vm(compute_client, vm_name, node_tags, provider_config, node_config, network_interface.id) - subprocess_utils.run_in_parallel(create_single_instance, range(count)) + subprocess_utils.run_in_parallel(create_single_instance, list(range(count))) # Update disk performance tier performance_tier = node_config.get('disk_performance_tier', None) diff --git a/sky/provision/do/__init__.py b/sky/provision/do/__init__.py new file mode 100644 index 00000000000..75502d3cb05 --- /dev/null +++ b/sky/provision/do/__init__.py @@ -0,0 +1,11 @@ +"""DO provisioner for SkyPilot.""" + +from sky.provision.do.config import bootstrap_instances +from sky.provision.do.instance import cleanup_ports +from sky.provision.do.instance import get_cluster_info +from sky.provision.do.instance import open_ports +from sky.provision.do.instance import query_instances +from sky.provision.do.instance import run_instances +from sky.provision.do.instance import stop_instances +from sky.provision.do.instance import terminate_instances +from sky.provision.do.instance import wait_instances diff --git a/sky/provision/do/config.py b/sky/provision/do/config.py new file mode 100644 index 00000000000..0b10f7f7698 --- /dev/null +++ b/sky/provision/do/config.py @@ -0,0 +1,14 @@ +"""Paperspace configuration bootstrapping.""" + +from sky import sky_logging +from sky.provision import common + +logger = sky_logging.init_logger(__name__) + + +def bootstrap_instances( + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + """Bootstraps instances for the given cluster.""" + del region, cluster_name + return config diff --git a/sky/provision/do/constants.py b/sky/provision/do/constants.py new file mode 100644 index 00000000000..0010646f873 --- /dev/null +++ b/sky/provision/do/constants.py @@ -0,0 +1,10 @@ +"""DO cloud constants +""" + +POLL_INTERVAL = 5 +WAIT_DELETE_VOLUMES = 5 + +GPU_IMAGES = { + 'gpu-h100x1-80gb': 'gpu-h100x1-base', + 'gpu-h100x8-640gb': 'gpu-h100x8-base', +} diff --git a/sky/provision/do/instance.py b/sky/provision/do/instance.py new file mode 100644 index 00000000000..098ef6e0595 --- /dev/null +++ b/sky/provision/do/instance.py @@ -0,0 +1,287 @@ +"""DigitalOcean instance provisioning.""" + +import time +from typing import Any, Dict, List, Optional +import uuid + +from sky import sky_logging +from sky import status_lib +from sky.provision import common +from sky.provision.do import constants +from sky.provision.do import utils + +# The maximum number of times to poll for the status of an operation +MAX_POLLS = 60 // constants.POLL_INTERVAL +# Stopping instances can take several minutes, so we increase the timeout +MAX_POLLS_FOR_UP_OR_STOP = MAX_POLLS * 8 + +logger = sky_logging.init_logger(__name__) + + +def _get_head_instance( + instances: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]: + for instance_name, instance_meta in instances.items(): + if instance_name.endswith('-head'): + return instance_meta + return None + + +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """Runs instances for the given 
cluster.""" + + pending_status = ['new'] + newly_started_instances = utils.filter_instances(cluster_name_on_cloud, + pending_status + ['off']) + while True: + instances = utils.filter_instances(cluster_name_on_cloud, + pending_status) + if not instances: + break + instance_statuses = [ + instance['status'] for instance in instances.values() + ] + logger.info(f'Waiting for {len(instances)} instances to be ready: ' + f'{instance_statuses}') + time.sleep(constants.POLL_INTERVAL) + + exist_instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=pending_status + + ['active', 'off']) + if len(exist_instances) > config.count: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, but {config.count} are required.') + + stopped_instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=['off']) + for instance in stopped_instances.values(): + utils.start_instance(instance) + for _ in range(MAX_POLLS_FOR_UP_OR_STOP): + instances = utils.filter_instances(cluster_name_on_cloud, ['off']) + if len(instances) == 0: + break + num_stopped_instances = len(stopped_instances) + num_restarted_instances = num_stopped_instances - len(instances) + logger.info( + f'Waiting for {num_restarted_instances}/{num_stopped_instances} ' + 'stopped instances to be restarted.') + time.sleep(constants.POLL_INTERVAL) + else: + msg = ('run_instances: Failed to restart all' + 'instances possibly due to to capacity issue.') + logger.warning(msg) + raise RuntimeError(msg) + + exist_instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=['active']) + head_instance = _get_head_instance(exist_instances) + to_start_count = config.count - len(exist_instances) + if to_start_count < 0: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, but {config.count} are required.') + if to_start_count == 0: + if head_instance is None: + head_instance = list(exist_instances.values())[0] + utils.rename_instance( + head_instance, + f'{cluster_name_on_cloud}-{uuid.uuid4().hex[:4]}-head') + assert head_instance is not None, ('`head_instance` should not be None') + logger.info(f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, no need to start more.') + return common.ProvisionRecord( + provider_name='do', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance['name'], + resumed_instance_ids=list(newly_started_instances.keys()), + created_instance_ids=[], + ) + + created_instances: List[Dict[str, Any]] = [] + for _ in range(to_start_count): + instance_type = 'head' if head_instance is None else 'worker' + instance = utils.create_instance( + region=region, + cluster_name_on_cloud=cluster_name_on_cloud, + instance_type=instance_type, + config=config) + logger.info(f'Launched instance {instance["name"]}.') + created_instances.append(instance) + if head_instance is None: + head_instance = instance + + # Wait for instances to be ready. 
+ for _ in range(MAX_POLLS_FOR_UP_OR_STOP): + instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=['active']) + logger.info('Waiting for instances to be ready: ' + f'({len(instances)}/{config.count}).') + if len(instances) == config.count: + break + + time.sleep(constants.POLL_INTERVAL) + else: + # Failed to launch config.count of instances after max retries + msg = 'run_instances: Failed to create the instances' + logger.warning(msg) + raise RuntimeError(msg) + assert head_instance is not None, 'head_instance should not be None' + return common.ProvisionRecord( + provider_name='do', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance['name'], + resumed_instance_ids=list(stopped_instances.keys()), + created_instance_ids=[ + instance['name'] for instance in created_instances + ], + ) + + +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + del region, cluster_name_on_cloud, state # unused + # We already wait on ready state in `run_instances` no need + + +def stop_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + del provider_config # unused + all_instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=None) + num_instances = len(all_instances) + + # Request a stop on all instances + for instance_name, instance_meta in all_instances.items(): + if worker_only and instance_name.endswith('-head'): + num_instances -= 1 + continue + utils.stop_instance(instance_meta) + + # Wait for instances to stop + for _ in range(MAX_POLLS_FOR_UP_OR_STOP): + all_instances = utils.filter_instances(cluster_name_on_cloud, ['off']) + if len(all_instances) >= num_instances: + break + time.sleep(constants.POLL_INTERVAL) + else: + raise RuntimeError(f'Maximum number of polls: ' + f'{MAX_POLLS_FOR_UP_OR_STOP} reached. 
' + f'Instance {all_instances} is still not in ' + 'STOPPED status.') + + +def terminate_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + """See sky/provision/__init__.py""" + del provider_config # unused + instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=None) + for instance_name, instance_meta in instances.items(): + logger.debug(f'Terminating instance {instance_name}') + if worker_only and instance_name.endswith('-head'): + continue + utils.down_instance(instance_meta) + + for _ in range(MAX_POLLS_FOR_UP_OR_STOP): + instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=None) + if len(instances) == 0 or len(instances) <= 1 and worker_only: + break + time.sleep(constants.POLL_INTERVAL) + else: + msg = ('Failed to delete all instances') + logger.warning(msg) + raise RuntimeError(msg) + + +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, +) -> common.ClusterInfo: + del region # unused + running_instances = utils.filter_instances(cluster_name_on_cloud, + ['active']) + instances: Dict[str, List[common.InstanceInfo]] = {} + head_instance: Optional[str] = None + for instance_name, instance_meta in running_instances.items(): + if instance_name.endswith('-head'): + head_instance = instance_name + for net in instance_meta['networks']['v4']: + if net['type'] == 'public': + instance_ip = net['ip_address'] + break + instances[instance_name] = [ + common.InstanceInfo( + instance_id=instance_meta['name'], + internal_ip=instance_ip, + external_ip=instance_ip, + ssh_port=22, + tags={}, + ) + ] + + assert head_instance is not None, 'no head instance found' + return common.ClusterInfo( + instances=instances, + head_instance_id=head_instance, + provider_name='do', + provider_config=provider_config, + ) + + +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + # terminated instances are not retrieved by the + # API making `non_terminated_only` argument moot. 
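+    # DigitalOcean droplet statuses ('new', 'archive', 'active', 'off') are
+    # mapped to SkyPilot ClusterStatus values below; terminated droplets are
+    # simply absent from the listing.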
+ del non_terminated_only + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + instances = utils.filter_instances(cluster_name_on_cloud, + status_filters=None) + + status_map = { + 'new': status_lib.ClusterStatus.INIT, + 'archive': status_lib.ClusterStatus.INIT, + 'active': status_lib.ClusterStatus.UP, + 'off': status_lib.ClusterStatus.STOPPED, + } + statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {} + for instance_meta in instances.values(): + status = status_map[instance_meta['status']] + statuses[instance_meta['name']] = status + return statuses + + +def open_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """See sky/provision/__init__.py""" + logger.debug( + f'Skip opening ports {ports} for DigitalOcean instances, as all ' + 'ports are open by default.') + del cluster_name_on_cloud, provider_config, ports + + +def cleanup_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + del cluster_name_on_cloud, provider_config, ports diff --git a/sky/provision/do/utils.py b/sky/provision/do/utils.py new file mode 100644 index 00000000000..ebc1b4ac389 --- /dev/null +++ b/sky/provision/do/utils.py @@ -0,0 +1,306 @@ +"""DigitalOcean API client wrapper for SkyPilot. + +Example usage of `pydo` client library was mostly taken from here: +https://github.com/digitalocean/pydo/blob/main/examples/poc_droplets_volumes_sshkeys.py +""" + +import copy +import os +import typing +from typing import Any, Dict, List, Optional +import urllib +import uuid + +from sky import sky_logging +from sky.adaptors import do +from sky.provision import common +from sky.provision import constants as provision_constants +from sky.provision.do import constants +from sky.utils import common_utils + +if typing.TYPE_CHECKING: + from sky import resources + from sky import status_lib + +logger = sky_logging.init_logger(__name__) + +POSSIBLE_CREDENTIALS_PATHS = [ + os.path.expanduser( + '~/Library/Application Support/doctl/config.yaml'), # OS X + os.path.expanduser( + os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config/'), + 'doctl/config.yaml')), # Linux +] +INITIAL_BACKOFF_SECONDS = 10 +MAX_BACKOFF_FACTOR = 10 +MAX_ATTEMPTS = 6 +SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}' + +CREDENTIALS_PATH = '~/.config/doctl/config.yaml' +_client = None +_ssh_key_id = None + + +class DigitalOceanError(Exception): + pass + + +def _init_client(): + global _client, CREDENTIALS_PATH + assert _client is None + CREDENTIALS_PATH = None + credentials_found = 0 + for path in POSSIBLE_CREDENTIALS_PATHS: + if os.path.exists(path): + CREDENTIALS_PATH = path + credentials_found += 1 + logger.debug(f'Digital Ocean credential path found at {path}') + if not credentials_found > 1: + logger.debug('more than 1 credential file found') + if CREDENTIALS_PATH is None: + raise DigitalOceanError( + 'no credentials file found from ' + f'the following paths {POSSIBLE_CREDENTIALS_PATHS}') + + # attempt default context + credentials = common_utils.read_yaml(CREDENTIALS_PATH) + default_token = credentials.get('access-token', None) + if default_token is not None: + try: + test_client = do.pydo.Client(token=default_token) + test_client.droplets.list() + logger.debug('trying `default` context') + _client = test_client + return _client + except do.exceptions().HttpResponseError: + pass + + auth_contexts = credentials.get('auth-contexts', None) + if auth_contexts is not None: + for 
context, api_token in auth_contexts.items(): + try: + test_client = do.pydo.Client(token=api_token) + test_client.droplets.list() + logger.debug(f'using {context} context') + _client = test_client + break + except do.exceptions().HttpResponseError: + continue + else: + raise DigitalOceanError( + 'no valid api tokens found try ' + 'setting a new API token with `doctl auth init`') + return _client + + +def client(): + global _client + if _client is None: + _client = _init_client() + return _client + + +def ssh_key_id(public_key: str): + global _ssh_key_id + if _ssh_key_id is None: + page = 1 + paginated = True + while paginated: + try: + resp = client().ssh_keys.list(per_page=50, page=page) + for ssh_key in resp['ssh_keys']: + if ssh_key['public_key'] == public_key: + _ssh_key_id = ssh_key + return _ssh_key_id + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: ' + f'{err.error.message}') from err + + pages = resp['links'] + if 'pages' in pages and 'next' in pages['pages']: + pages = pages['pages'] + parsed_url = urllib.parse.urlparse(pages['next']) + page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0]) + else: + paginated = False + + request = { + 'public_key': public_key, + 'name': SSH_KEY_NAME_ON_DO, + } + _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key'] + return _ssh_key_id + + +def _create_volume(request: Dict[str, Any]) -> Dict[str, Any]: + try: + resp = client().volumes.create(body=request) + volume = resp['volume'] + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + else: + return volume + + +def _create_droplet(request: Dict[str, Any]) -> Dict[str, Any]: + try: + resp = client().droplets.create(body=request) + droplet_id = resp['droplet']['id'] + + get_resp = client().droplets.get(droplet_id) + droplet = get_resp['droplet'] + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + return droplet + + +def create_instance(region: str, cluster_name_on_cloud: str, instance_type: str, + config: common.ProvisionConfig) -> Dict[str, Any]: + """Creates a instance and mounts the requested block storage + + Args: + region (str): instance region + instance_name (str): name of instance + config (common.ProvisionConfig): provisioner configuration + + Returns: + Dict[str, Any]: instance metadata + """ + # sort tags by key to support deterministic unit test stubbing + tags = dict(sorted(copy.deepcopy(config.tags).items())) + tags = { + 'Name': cluster_name_on_cloud, + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud, + **tags + } + tags = [f'{key}:{value}' for key, value in tags.items()] + default_image = constants.GPU_IMAGES.get( + config.node_config['InstanceType'], + 'gpu-h100x1-base', + ) + image_id = config.node_config['ImageId'] + image_id = image_id if image_id is not None else default_image + instance_name = (f'{cluster_name_on_cloud}-' + f'{uuid.uuid4().hex[:4]}-{instance_type}') + instance_request = { + 'name': instance_name, + 'region': region, + 'size': config.node_config['InstanceType'], + 'image': image_id, + 'ssh_keys': [ + ssh_key_id( + config.authentication_config['ssh_public_key'])['fingerprint'] + ], + 'tags': tags, + } + instance = _create_droplet(instance_request) + + volume_request = { + 
'size_gigabytes': config.node_config['DiskSize'], + 'name': instance_name, + 'region': region, + 'filesystem_type': 'ext4', + 'tags': tags + } + volume = _create_volume(volume_request) + + attach_request = {'type': 'attach', 'droplet_id': instance['id']} + try: + client().volume_actions.post_by_id(volume['id'], attach_request) + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + logger.debug(f'{instance_name} created') + return instance + + +def start_instance(instance: Dict[str, Any]): + try: + client().droplet_actions.post(droplet_id=instance['id'], + body={'type': 'power_on'}) + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + + +def stop_instance(instance: Dict[str, Any]): + try: + client().droplet_actions.post( + droplet_id=instance['id'], + body={'type': 'shutdown'}, + ) + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + + +def down_instance(instance: Dict[str, Any]): + # We use dangerous destroy to atomically delete + # block storage and instance for autodown + try: + client().droplets.destroy_with_associated_resources_dangerous( + droplet_id=instance['id'], x_dangerous=True) + except do.exceptions().HttpResponseError as err: + if 'a destroy is already in progress' in err.error.message: + return + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + + +def rename_instance(instance: Dict[str, Any], new_name: str): + try: + client().droplet_actions.rename(droplet=instance['id'], + body={ + 'type': 'rename', + 'name': new_name + }) + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + + +def filter_instances( + cluster_name_on_cloud: str, + status_filters: Optional[List[str]] = None) -> Dict[str, Any]: + """Returns Dict mapping instance name + to instance metadata filtered by status + """ + + filtered_instances: Dict[str, Any] = {} + page = 1 + paginated = True + while paginated: + try: + resp = client().droplets.list( + tag_name=f'{provision_constants.TAG_SKYPILOT_CLUSTER_NAME}:' + f'{cluster_name_on_cloud}', + per_page=50, + page=page) + for instance in resp['droplets']: + if status_filters is None or instance[ + 'status'] in status_filters: + filtered_instances[instance['name']] = instance + except do.exceptions().HttpResponseError as err: + raise DigitalOceanError( + f'Error: {err.status_code} {err.reason}: {err.error.message}' + ) from err + + pages = resp['links'] + if 'pages' in pages and 'next' in pages['pages']: + pages = pages['pages'] + parsed_url = urllib.parse.urlparse(pages['next']) + page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0]) + else: + paginated = False + return filtered_instances diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index c55508ab41a..0aadcc55335 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -38,6 +38,13 @@ class DockerLoginConfig: password: str server: str + def format_image(self, image: str) -> str: + """Format the image name with the server prefix.""" + server_prefix = f'{self.server}/' + if not image.startswith(server_prefix): + return f'{server_prefix}{image}' + return image + @classmethod def from_env_vars(cls, 
d: Dict[str, str]) -> 'DockerLoginConfig': return cls( @@ -220,9 +227,7 @@ def initialize(self) -> str: wait_for_docker_daemon=True) # We automatically add the server prefix to the image name if # the user did not add it. - server_prefix = f'{docker_login_config.server}/' - if not specific_image.startswith(server_prefix): - specific_image = f'{server_prefix}{specific_image}' + specific_image = docker_login_config.format_image(specific_image) if self.docker_config.get('pull_before_run', True): assert specific_image, ('Image must be included in config if ' + @@ -338,14 +343,20 @@ def _check_docker_installed(self): no_exist = 'NoExist' # SkyPilot: Add the current user to the docker group first (if needed), # before checking if docker is installed to avoid permission issues. - cleaned_output = self._run( - 'id -nG $USER | grep -qw docker || ' - 'sudo usermod -aG docker $USER > /dev/null 2>&1;' - f'command -v {self.docker_cmd} || echo {no_exist!r}') - if no_exist in cleaned_output or 'docker' not in cleaned_output: - logger.error( - f'{self.docker_cmd.capitalize()} not installed. Please use an ' - f'image with {self.docker_cmd.capitalize()} installed.') + docker_cmd = ('id -nG $USER | grep -qw docker || ' + 'sudo usermod -aG docker $USER > /dev/null 2>&1;' + f'command -v {self.docker_cmd} || echo {no_exist!r}') + cleaned_output = self._run(docker_cmd) + timeout = 60 * 10 # 10 minute timeout + start = time.time() + while no_exist in cleaned_output or 'docker' not in cleaned_output: + if time.time() - start > timeout: + logger.error( + f'{self.docker_cmd.capitalize()} not installed. Please use ' + f'an image with {self.docker_cmd.capitalize()} installed.') + return + time.sleep(5) + cleaned_output = self._run(docker_cmd) def _check_container_status(self): if self.initialized: diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index a8292669a7c..a99267eb0b9 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -397,7 +397,7 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, operation = compute.networks().getEffectiveFirewalls(project=project_id, network=vpc_name) response = operation.execute() - if len(response) == 0: + if not response: return False effective_rules = response['firewalls'] @@ -515,7 +515,7 @@ def _create_rules(project_id: str, compute, rules, vpc_name): rule_list = _list_firewall_rules(project_id, compute, filter=f'(name={rule_name})') - if len(rule_list) > 0: + if rule_list: _delete_firewall_rule(project_id, compute, rule_name) body = rule.copy() @@ -624,7 +624,7 @@ def get_usable_vpc_and_subnet( vpc_list = _list_vpcnets(project_id, compute, filter=f'name={constants.SKYPILOT_VPC_NAME}') - if len(vpc_list) == 0: + if not vpc_list: body = constants.VPC_TEMPLATE.copy() body['name'] = body['name'].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) body['selfLink'] = body['selfLink'].format( diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index 370430720f0..0fe920be9d6 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -232,7 +232,7 @@ def _get_resource(container_resources: Dict[str, Any], resource_name: str, # Look for keys containing the resource_name. For example, # the key 'nvidia.com/gpu' contains the key 'gpu'. 
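A tiny, self-contained illustration of the matching rule described in the comment above; the container resource dict here is made up for the example:

container_resources = {'cpu': '500m', 'memory': '1Gi', 'nvidia.com/gpu': '1'}
resource_name = 'gpu'
matching_keys = [k for k in container_resources if resource_name in k.lower()]
assert matching_keys == ['nvidia.com/gpu']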
matching_keys = [key for key in resources if resource_name in key.lower()] - if len(matching_keys) == 0: + if not matching_keys: return float('inf') if len(matching_keys) > 1: # Should have only one match -- mostly relevant for gpu. @@ -265,7 +265,7 @@ def _configure_autoscaler_service_account( field_selector = f'metadata.name={name}' accounts = (kubernetes.core_api(context).list_namespaced_service_account( namespace, field_selector=field_selector).items) - if len(accounts) > 0: + if accounts: assert len(accounts) == 1 # Nothing to check for equality and patch here, # since the service_account.metadata.name is the only important @@ -308,7 +308,7 @@ def _configure_autoscaler_role(namespace: str, context: Optional[str], field_selector = f'metadata.name={name}' roles = (kubernetes.auth_api(context).list_namespaced_role( namespace, field_selector=field_selector).items) - if len(roles) > 0: + if roles: assert len(roles) == 1 existing_role = roles[0] # Convert to k8s object to compare @@ -374,7 +374,7 @@ def _configure_autoscaler_role_binding( field_selector = f'metadata.name={name}' role_bindings = (kubernetes.auth_api(context).list_namespaced_role_binding( rb_namespace, field_selector=field_selector).items) - if len(role_bindings) > 0: + if role_bindings: assert len(role_bindings) == 1 existing_binding = role_bindings[0] new_rb = kubernetes_utils.dict_to_k8s_object(binding, 'V1RoleBinding') @@ -415,7 +415,7 @@ def _configure_autoscaler_cluster_role(namespace, context, field_selector = f'metadata.name={name}' cluster_roles = (kubernetes.auth_api(context).list_cluster_role( field_selector=field_selector).items) - if len(cluster_roles) > 0: + if cluster_roles: assert len(cluster_roles) == 1 existing_cr = cluster_roles[0] new_cr = kubernetes_utils.dict_to_k8s_object(role, 'V1ClusterRole') @@ -460,7 +460,7 @@ def _configure_autoscaler_cluster_role_binding( field_selector = f'metadata.name={name}' cr_bindings = (kubernetes.auth_api(context).list_cluster_role_binding( field_selector=field_selector).items) - if len(cr_bindings) > 0: + if cr_bindings: assert len(cr_bindings) == 1 existing_binding = cr_bindings[0] new_binding = kubernetes_utils.dict_to_k8s_object( @@ -639,7 +639,7 @@ def _configure_services(namespace: str, context: Optional[str], field_selector = f'metadata.name={name}' services = (kubernetes.core_api(context).list_namespaced_service( namespace, field_selector=field_selector).items) - if len(services) > 0: + if services: assert len(services) == 1 existing_service = services[0] # Convert to k8s object to compare diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index c431b023ab9..a849dfc3044 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -804,7 +804,8 @@ def _create_pod_thread(i: int): # Create pods in parallel pods = subprocess_utils.run_in_parallel(_create_pod_thread, - range(to_start_count), _NUM_THREADS) + list(range(to_start_count)), + _NUM_THREADS) # Process created pods for pod in pods: diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index b16482e5072..29fcf181edd 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -230,7 +230,7 @@ def get_ingress_external_ip_and_ports( namespace, _request_timeout=kubernetes.API_TIMEOUT).items if item.metadata.name == 'ingress-nginx-controller' ] - if len(ingress_services) == 0: + if not ingress_services: return (None, None) ingress_service = 
ingress_services[0] diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 4c23a41161a..14b6b42aa58 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -340,14 +340,15 @@ def get_accelerator_from_label_value(cls, value: str) -> str: """ canonical_gpu_names = [ 'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4', - 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4' + 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4' ] for canonical_name in canonical_gpu_names: # A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB if canonical_name == 'A100-80GB' and re.search( r'A100.*-80GB', value): return canonical_name - elif canonical_name in value: + # Use word boundary matching to prevent substring matches + elif re.search(rf'\b{re.escape(canonical_name)}\b', value): return canonical_name # If we didn't find a canonical name: @@ -583,7 +584,7 @@ def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType', node for node in nodes if gpu_label_key in node.metadata.labels and node.metadata.labels[gpu_label_key] == gpu_label_val ] - assert len(gpu_nodes) > 0, 'GPU nodes not found' + assert gpu_nodes, 'GPU nodes not found' if is_tpu_on_gke(acc_type): # If requested accelerator is a TPU type, check if the cluster # has sufficient TPU resource to meet the requirement. @@ -892,6 +893,52 @@ def check_credentials(context: Optional[str], return True, None +def check_pod_config(pod_config: dict) \ + -> Tuple[bool, Optional[str]]: + """Check if the pod_config is a valid pod config + + Using deserialize api to check the pod_config is valid or not. + + Returns: + bool: True if pod_config is valid. + str: Error message about why the pod_config is invalid, None otherwise. + """ + errors = [] + # This api_client won't be used to send any requests, so there is no need to + # load kubeconfig + api_client = kubernetes.kubernetes.client.ApiClient() + + # Used for kubernetes api_client deserialize function, the function will use + # data attr, the detail ref: + # https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244 + class InnerResponse(): + + def __init__(self, data: dict): + self.data = json.dumps(data) + + try: + # Validate metadata if present + if 'metadata' in pod_config: + try: + value = InnerResponse(pod_config['metadata']) + api_client.deserialize( + value, kubernetes.kubernetes.client.V1ObjectMeta) + except ValueError as e: + errors.append(f'Invalid metadata: {str(e)}') + # Validate spec if present + if 'spec' in pod_config: + try: + value = InnerResponse(pod_config['spec']) + api_client.deserialize(value, + kubernetes.kubernetes.client.V1PodSpec) + except ValueError as e: + errors.append(f'Invalid spec: {str(e)}') + return len(errors) == 0, '.'.join(errors) + except Exception as e: # pylint: disable=broad-except + errors.append(f'Validation error: {str(e)}') + return False, '.'.join(errors) + + def is_kubeconfig_exec_auth( context: Optional[str] = None) -> Tuple[bool, Optional[str]]: """Checks if the kubeconfig file uses exec-based authentication @@ -1526,7 +1573,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str], def find(l, predicate): """Utility function to find element in given list""" results = [x for x in l if predicate(x)] - return results[0] if len(results) > 0 else None + return results[0] if results else None # Get the SSH jump pod name from the head pod try: diff --git a/sky/provision/lambda_cloud/lambda_utils.py 
b/sky/provision/lambda_cloud/lambda_utils.py index 4d8e6246b6d..cfd8e02ad23 100644 --- a/sky/provision/lambda_cloud/lambda_utils.py +++ b/sky/provision/lambda_cloud/lambda_utils.py @@ -50,7 +50,7 @@ def set(self, instance_id: str, value: Optional[Dict[str, Any]]) -> None: if value is None: if instance_id in metadata: metadata.pop(instance_id) # del entry - if len(metadata) == 0: + if not metadata: if os.path.exists(self.path): os.remove(self.path) return @@ -69,7 +69,7 @@ def refresh(self, instance_ids: List[str]) -> None: for instance_id in list(metadata.keys()): if instance_id not in instance_ids: del metadata[instance_id] - if len(metadata) == 0: + if not metadata: os.remove(self.path) return with open(self.path, 'w', encoding='utf-8') as f: @@ -150,7 +150,7 @@ def create_instances( ['regions_with_capacity_available']) available_regions = [reg['name'] for reg in available_regions] if region not in available_regions: - if len(available_regions) > 0: + if available_regions: aval_reg = ' '.join(available_regions) else: aval_reg = 'None' diff --git a/sky/provision/oci/query_utils.py b/sky/provision/oci/query_utils.py index 47a0438cb21..3f545aca4ba 100644 --- a/sky/provision/oci/query_utils.py +++ b/sky/provision/oci/query_utils.py @@ -7,6 +7,8 @@ find_compartment: allow search subtree when find a compartment. - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add methods to Add/remove security rules: create_nsg_rules & remove_nsg + - Hysun He (hysun.he@oracle.com) @ Jan.01, 2025: Support reuse existing + VCN for SkyServe. """ from datetime import datetime import functools @@ -17,7 +19,6 @@ import typing from typing import List, Optional, Tuple -from sky import exceptions from sky import sky_logging from sky.adaptors import common as adaptors_common from sky.adaptors import oci as oci_adaptor @@ -248,7 +249,7 @@ def find_compartment(cls, region) -> str: limit=1) compartments = list_compartments_response.data - if len(compartments) > 0: + if compartments: skypilot_compartment = compartments[0].id return skypilot_compartment @@ -274,7 +275,7 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]: display_name=oci_utils.oci_config.VCN_NAME, lifecycle_state='AVAILABLE') vcns = list_vcns_response.data - if len(vcns) > 0: + if vcns: # Found the VCN. skypilot_vcn = vcns[0].id list_subnets_response = net_client.list_subnets( @@ -359,7 +360,7 @@ def create_vcn_subnet(cls, net_client, if str(s.cidr_block).startswith('all-') and str(s.cidr_block). endswith('-services-in-oracle-services-network') ] - if len(services) > 0: + if services: # Create service gateway for regional services. create_sg_response = net_client.create_service_gateway( create_service_gateway_details=oci_adaptor.oci.core.models. @@ -496,23 +497,25 @@ def find_nsg(cls, region: str, nsg_name: str, compartment = cls.find_compartment(region) - list_vcns_resp = net_client.list_vcns( - compartment_id=compartment, - display_name=oci_utils.oci_config.VCN_NAME, - lifecycle_state='AVAILABLE', - ) + vcn_id = oci_utils.oci_config.get_vcn_ocid(region) + if vcn_id is None: + list_vcns_resp = net_client.list_vcns( + compartment_id=compartment, + display_name=oci_utils.oci_config.VCN_NAME, + lifecycle_state='AVAILABLE', + ) - if not list_vcns_resp: - raise exceptions.ResourcesUnavailableError( - 'The VCN is not available') + # Get the primary vnic. The vnic might be an empty list for the + # corner case when the cluster was exited during provision. + if not list_vcns_resp.data: + return None - # Get the primary vnic. 
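The find_nsg change above reduces to a small fallback chain: use the VCN OCID the user configured (VCN reuse), otherwise discover the SkyPilot-created VCN by display name, and give up gracefully if neither exists. A compact sketch of that control flow, written against stand-in callables rather than the real OCI client:

from typing import Callable, List, Optional

def resolve_vcn_id(
        configured_vcn_ocid: Optional[str],
        list_vcn_ids_by_name: Callable[[], List[str]]) -> Optional[str]:
    if configured_vcn_ocid is not None:
        return configured_vcn_ocid
    vcn_ids = list_vcn_ids_by_name()
    if not vcn_ids:
        # E.g. provisioning was interrupted before the VCN was created.
        return None
    return vcn_ids[0]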
- assert len(list_vcns_resp.data) > 0 - vcn = list_vcns_resp.data[0] + vcn = list_vcns_resp.data[0] + vcn_id = vcn.id list_nsg_resp = net_client.list_network_security_groups( compartment_id=compartment, - vcn_id=vcn.id, + vcn_id=vcn_id, limit=1, display_name=nsg_name, ) @@ -529,7 +532,7 @@ def find_nsg(cls, region: str, nsg_name: str, create_network_security_group_details=oci_adaptor.oci.core.models. CreateNetworkSecurityGroupDetails( compartment_id=compartment, - vcn_id=vcn.id, + vcn_id=vcn_id, display_name=nsg_name, )) get_nsg_resp = net_client.get_network_security_group( diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index cc2ca73e1dc..8f2142df273 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -415,7 +415,6 @@ def _post_provision_setup( f'{json.dumps(dataclasses.asdict(provision_record), indent=2)}\n' 'Cluster info:\n' f'{json.dumps(dataclasses.asdict(cluster_info), indent=2)}') - head_instance = cluster_info.get_head_instance() if head_instance is None: e = RuntimeError(f'Provision failed for cluster {cluster_name!r}. ' diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 8f992f569d9..9e57887c3f1 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -83,7 +83,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, node_type = 'head' if head_instance_id is None else 'worker' try: instance_id = utils.launch( - name=f'{cluster_name_on_cloud}-{node_type}', + cluster_name=cluster_name_on_cloud, + node_type=node_type, instance_type=config.node_config['InstanceType'], region=region, disk_size=config.node_config['DiskSize'], @@ -92,6 +93,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, public_key=config.node_config['PublicKey'], preemptible=config.node_config['Preemptible'], bid_per_gpu=config.node_config['BidPerGPU'], + docker_login_config=config.provider_config.get( + 'docker_login_config'), ) except Exception as e: # pylint: disable=broad-except logger.warning(f'run_instances error: {e}') @@ -145,6 +148,8 @@ def terminate_instances( """See sky/provision/__init__.py""" del provider_config # unused instances = _filter_instances(cluster_name_on_cloud, None) + template_name, registry_auth_id = utils.get_registry_auth_resources( + cluster_name_on_cloud) for inst_id, inst in instances.items(): logger.debug(f'Terminating instance {inst_id}: {inst}') if worker_only and inst['name'].endswith('-head'): @@ -157,6 +162,10 @@ def terminate_instances( f'Failed to terminate instance {inst_id}: ' f'{common_utils.format_exception(e, use_bracket=False)}' ) from e + if template_name is not None: + utils.delete_pod_template(template_name) + if registry_auth_id is not None: + utils.delete_register_auth(registry_auth_id) def get_cluster_info( diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index d0a06b026b3..6600cfd6198 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -2,10 +2,11 @@ import base64 import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from sky import sky_logging from sky.adaptors import runpod +from sky.provision import docker_utils import sky.provision.runpod.api.commands as runpod_commands from sky.skylet import constants from sky.utils import common_utils @@ -47,6 +48,11 @@ } +def _construct_docker_login_template_name(cluster_name: str) -> str: + """Constructs the registry auth template name.""" + return 
f'{cluster_name}-docker-login-template' + + def retry(func): """Decorator to retry a function.""" @@ -66,9 +72,83 @@ def wrapper(*args, **kwargs): return wrapper +# Adapted from runpod.api.queries.pods.py::QUERY_POD. +# Adding containerRegistryAuthId to the query. +_QUERY_POD = """ +query myPods { + myself { + pods { + id + containerDiskInGb + containerRegistryAuthId + costPerHr + desiredStatus + dockerArgs + dockerId + env + gpuCount + imageName + lastStatusChange + machineId + memoryInGb + name + podType + port + ports + uptimeSeconds + vcpuCount + volumeInGb + volumeMountPath + runtime { + ports{ + ip + isIpPublic + privatePort + publicPort + type + } + } + machine { + gpuDisplayName + } + } + } +} +""" + + +def _sky_get_pods() -> dict: + """List all pods with extra registry auth information. + + Adapted from runpod.get_pods() to include containerRegistryAuthId. + """ + raw_return = runpod.runpod.api.graphql.run_graphql_query(_QUERY_POD) + cleaned_return = raw_return['data']['myself']['pods'] + return cleaned_return + + +_QUERY_POD_TEMPLATE_WITH_REGISTRY_AUTH = """ +query myself { + myself { + podTemplates { + name + containerRegistryAuthId + } + } +} +""" + + +def _list_pod_templates_with_container_registry() -> dict: + """List all pod templates.""" + raw_return = runpod.runpod.api.graphql.run_graphql_query( + _QUERY_POD_TEMPLATE_WITH_REGISTRY_AUTH) + return raw_return['data']['myself']['podTemplates'] + + def list_instances() -> Dict[str, Dict[str, Any]]: """Lists instances associated with API key.""" - instances = runpod.runpod.get_pods() + instances = _sky_get_pods() instance_dict: Dict[str, Dict[str, Any]] = {} for instance in instances: @@ -100,14 +180,75 @@ def list_instances() -> Dict[str, Dict[str, Any]]: return instance_dict -def launch(name: str, instance_type: str, region: str, disk_size: int, - image_name: str, ports: Optional[List[int]], public_key: str, - preemptible: Optional[bool], bid_per_gpu: float) -> str: +def delete_pod_template(template_name: str) -> None: + """Deletes a pod template.""" + try: + runpod.runpod.api.graphql.run_graphql_query( + f'mutation {{deleteTemplate(templateName: "{template_name}")}}') + except runpod.runpod.error.QueryError as e: + logger.warning(f'Failed to delete template {template_name}: {e}' + 'Please delete it manually.') + + +def delete_register_auth(registry_auth_id: str) -> None: + """Deletes a registry auth.""" + try: + runpod.runpod.delete_container_registry_auth(registry_auth_id) + except runpod.runpod.error.QueryError as e: + logger.warning(f'Failed to delete registry auth {registry_auth_id}: {e}' + 'Please delete it manually.') + + +def _create_template_for_docker_login( + cluster_name: str, + image_name: str, + docker_login_config: Optional[Dict[str, str]], +) -> Tuple[str, Optional[str]]: + """Creates a template for the given image with the docker login config. + + Returns: + formatted_image_name: The formatted image name. + template_id: The template ID. None for no docker login config. + """ + if docker_login_config is None: + return image_name, None + login_config = docker_utils.DockerLoginConfig(**docker_login_config) + container_registry_auth_name = f'{cluster_name}-registry-auth' + container_template_name = _construct_docker_login_template_name( + cluster_name) + # The `name` argument is only for display purpose and the registry server + # will be splitted from the docker image name (Tested with AWS ECR). + # Here we only need the username and password to create the registry auth. 
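These per-cluster names are what tie creation and cleanup together: the template is created under a deterministic name here, and get_registry_auth_resources later finds it (and its containerRegistryAuthId) by that same name when the cluster is terminated. A small sketch of that lookup, with a made-up template list standing in for the GraphQL response:

def docker_login_template_name(cluster_name: str) -> str:
    return f'{cluster_name}-docker-login-template'

pod_templates = [
    {'name': 'unrelated-template', 'containerRegistryAuthId': 'auth-123'},
    {'name': 'my-cluster-docker-login-template',
     'containerRegistryAuthId': 'auth-456'},
]
wanted = docker_login_template_name('my-cluster')
match = next((t for t in pod_templates if t['name'] == wanted), None)
assert match is not None and match['containerRegistryAuthId'] == 'auth-456'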
+ # TODO(tian): Now we create a template and a registry auth for each cluster. + # Consider create one for each server and reuse them. Challenges including + # calculate the reference count and delete them when no longer needed. + create_auth_resp = runpod.runpod.create_container_registry_auth( + name=container_registry_auth_name, + username=login_config.username, + password=login_config.password, + ) + registry_auth_id = create_auth_resp['id'] + create_template_resp = runpod.runpod.create_template( + name=container_template_name, + image_name=None, + registry_auth_id=registry_auth_id, + ) + return login_config.format_image(image_name), create_template_resp['id'] + + +def launch(cluster_name: str, node_type: str, instance_type: str, region: str, + disk_size: int, image_name: str, ports: Optional[List[int]], + public_key: str, preemptible: Optional[bool], bid_per_gpu: float, + docker_login_config: Optional[Dict[str, str]]) -> str: """Launches an instance with the given parameters. Converts the instance_type to the RunPod GPU name, finds the specs for the GPU, and launches the instance. + + Returns: + instance_id: The instance ID. """ + name = f'{cluster_name}-{node_type}' gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]] gpu_quantity = int(instance_type.split('_')[0].replace('x', '')) cloud_type = instance_type.split('_')[2] @@ -139,21 +280,24 @@ def launch(name: str, instance_type: str, region: str, disk_size: int, # Use base64 to deal with the tricky quoting issues caused by runpod API. encoded = base64.b64encode(setup_cmd.encode('utf-8')).decode('utf-8') + docker_args = (f'bash -c \'echo {encoded} | base64 --decode > init.sh; ' + f'bash init.sh\'') + # Port 8081 is occupied for nginx in the base image. custom_ports_str = '' if ports is not None: custom_ports_str = ''.join([f'{p}/tcp,' for p in ports]) + ports_str = (f'22/tcp,' + f'{custom_ports_str}' + f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' + f'{constants.SKY_REMOTE_RAY_PORT}/http') - docker_args = (f'bash -c \'echo {encoded} | base64 --decode > init.sh; ' - f'bash init.sh\'') - ports = (f'22/tcp,' - f'{custom_ports_str}' - f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' - f'{constants.SKY_REMOTE_RAY_PORT}/http') + image_name_formatted, template_id = _create_template_for_docker_login( + cluster_name, image_name, docker_login_config) params = { 'name': name, - 'image_name': image_name, + 'image_name': image_name_formatted, 'gpu_type_id': gpu_type, 'cloud_type': cloud_type, 'container_disk_in_gb': disk_size, @@ -161,9 +305,10 @@ def launch(name: str, instance_type: str, region: str, disk_size: int, 'min_memory_in_gb': gpu_specs['memoryInGb'] * gpu_quantity, 'gpu_count': gpu_quantity, 'country_code': region, - 'ports': ports, + 'ports': ports_str, 'support_public_ip': True, 'docker_args': docker_args, + 'template_id': template_id, } if preemptible is None or not preemptible: @@ -177,6 +322,18 @@ def launch(name: str, instance_type: str, region: str, disk_size: int, return new_instance['id'] +def get_registry_auth_resources( + cluster_name: str) -> Tuple[Optional[str], Optional[str]]: + """Gets the registry auth resources.""" + container_registry_auth_name = _construct_docker_login_template_name( + cluster_name) + for template in _list_pod_templates_with_container_registry(): + if template['name'] == container_registry_auth_name: + return container_registry_auth_name, template[ + 'containerRegistryAuthId'] + return None, None + + def remove(instance_id: str) -> None: """Terminates the given instance.""" 
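launch() above sidesteps shell-quoting issues by shipping the setup script to the pod as base64 and decoding it inside the container. A minimal round trip of that trick; the setup command is only a placeholder:

import base64

setup_cmd = 'echo "hello from skypilot" && mkdir -p ~/sky_workdir'
encoded = base64.b64encode(setup_cmd.encode('utf-8')).decode('utf-8')
docker_args = (f'bash -c \'echo {encoded} | base64 --decode > init.sh; '
               f'bash init.sh\'')
# Only [A-Za-z0-9+/=] characters cross the API boundary, so quotes in
# setup_cmd never need to survive another layer of shell parsing.
assert base64.b64decode(encoded).decode('utf-8') == setup_cmd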
runpod.runpod.terminate_pod(instance_id) diff --git a/sky/provision/vsphere/common/vim_utils.py b/sky/provision/vsphere/common/vim_utils.py index 33c02db8feb..bde1bc25cf0 100644 --- a/sky/provision/vsphere/common/vim_utils.py +++ b/sky/provision/vsphere/common/vim_utils.py @@ -56,7 +56,7 @@ def get_hosts_by_cluster_names(content, vcenter_name, cluster_name_dicts=None): 'name': cluster.name } for cluster in cluster_view.view] cluster_view.Destroy() - if len(cluster_name_dicts) == 0: + if not cluster_name_dicts: logger.warning(f'vCenter \'{vcenter_name}\' has no clusters') # Retrieve all cluster names from the cluster_name_dicts diff --git a/sky/provision/vsphere/instance.py b/sky/provision/vsphere/instance.py index 787d8c97f62..2075cdb9c36 100644 --- a/sky/provision/vsphere/instance.py +++ b/sky/provision/vsphere/instance.py @@ -162,7 +162,7 @@ def _create_instances( if not gpu_instance: # Find an image for CPU images_df = images_df[images_df['GpuTags'] == '\'[]\''] - if len(images_df) == 0: + if not images_df: logger.error( f'Can not find an image for instance type: {instance_type}.') raise Exception( @@ -185,7 +185,7 @@ def _create_instances( image_instance_mapping_df = image_instance_mapping_df[ image_instance_mapping_df['InstanceType'] == instance_type] - if len(image_instance_mapping_df) == 0: + if not image_instance_mapping_df: raise Exception(f"""There is no image can match instance type named {instance_type} If you are using CPU-only instance, assign an image with tag @@ -218,10 +218,9 @@ def _create_instances( hosts_df = hosts_df[(hosts_df['AvailableCPUs'] / hosts_df['cpuMhz']) >= cpus_needed] hosts_df = hosts_df[hosts_df['AvailableMemory(MB)'] >= memory_needed] - assert len(hosts_df) > 0, ( - f'There is no host available to create the instance ' - f'{vms_item["InstanceType"]}, at least {cpus_needed} ' - f'cpus and {memory_needed}MB memory are required.') + assert hosts_df, (f'There is no host available to create the instance ' + f'{vms_item["InstanceType"]}, at least {cpus_needed} ' + f'cpus and {memory_needed}MB memory are required.') # Sort the hosts df by AvailableCPUs to get the compatible host with the # least resource @@ -365,7 +364,7 @@ def _choose_vsphere_cluster_name(config: common.ProvisionConfig, region: str, skypilot framework-optimized availability_zones""" vsphere_cluster_name = None vsphere_cluster_name_str = config.provider_config['availability_zone'] - if len(vc_object.clusters) > 0: + if vc_object.clusters: for optimized_cluster_name in vsphere_cluster_name_str.split(','): if optimized_cluster_name in [ item['name'] for item in vc_object.clusters diff --git a/sky/provision/vsphere/vsphere_utils.py b/sky/provision/vsphere/vsphere_utils.py index faec5d54930..51f284b0fc6 100644 --- a/sky/provision/vsphere/vsphere_utils.py +++ b/sky/provision/vsphere/vsphere_utils.py @@ -257,7 +257,7 @@ def get_skypilot_profile_id(self): # hard code here. should support configure later. 
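Most of the emptiness checks simplified in this area operate on plain Python containers, where `not x` is equivalent to `len(x) == 0`. pandas objects such as the images_df/hosts_df DataFrames used by the vsphere provisioner are the exception: they refuse plain truth-testing and expose .empty instead. A quick demonstration, assuming pandas is installed:

import pandas as pd

items: list = []
assert (not items) == (len(items) == 0)   # built-in containers: equivalent

df = pd.DataFrame()
assert df.empty                           # the pandas emptiness check
try:
    bool(df)                              # truth-testing a DataFrame raises
except ValueError:
    pass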
profile_name = 'skypilot_policy' storage_profile_id = None - if len(profile_ids) > 0: + if profile_ids: profiles = pm.PbmRetrieveContent(profileIds=profile_ids) for profile in profiles: if profile_name in profile.name: diff --git a/sky/resources.py b/sky/resources.py index 5184278e02e..68d1b6f9ea8 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -661,7 +661,7 @@ def _validate_and_set_region_zone(self, region: Optional[str], continue valid_clouds.append(cloud) - if len(valid_clouds) == 0: + if not valid_clouds: if len(enabled_clouds) == 1: cloud_str = f'for cloud {enabled_clouds[0]}' else: @@ -773,7 +773,7 @@ def _try_validate_instance_type(self) -> None: for cloud in enabled_clouds: if cloud.instance_type_exists(self._instance_type): valid_clouds.append(cloud) - if len(valid_clouds) == 0: + if not valid_clouds: if len(enabled_clouds) == 1: cloud_str = f'for cloud {enabled_clouds[0]}' else: @@ -1008,7 +1008,7 @@ def _try_validate_labels(self) -> None: f'Label rejected due to {cloud}: {err_msg}' ]) break - if len(invalid_table.rows) > 0: + if invalid_table.rows: with ux_utils.print_exception_no_traceback(): raise ValueError( 'The following labels are invalid:' @@ -1283,7 +1283,7 @@ def copy(self, **override) -> 'Resources': _cluster_config_overrides=override.pop( '_cluster_config_overrides', self._cluster_config_overrides), ) - assert len(override) == 0 + assert not override return resources def valid_on_region_zones(self, region: str, zones: List[str]) -> bool: diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py index bcfac54f4c2..570e47669db 100644 --- a/sky/serve/autoscalers.py +++ b/sky/serve/autoscalers.py @@ -220,8 +220,8 @@ def _select_outdated_replicas_to_scale_down( """Select outdated replicas to scale down.""" if self.update_mode == serve_utils.UpdateMode.ROLLING: - latest_ready_replicas = [] - old_nonterminal_replicas = [] + latest_ready_replicas: List['replica_managers.ReplicaInfo'] = [] + old_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = [] for info in replica_infos: if info.version == self.latest_version: if info.is_ready: diff --git a/sky/serve/core.py b/sky/serve/core.py index 561314bcbe0..f71c60b2fef 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -360,7 +360,7 @@ def update( raise RuntimeError(e.error_msg) from e service_statuses = serve_utils.load_service_status(serve_status_payload) - if len(service_statuses) == 0: + if not service_statuses: with ux_utils.print_exception_no_traceback(): raise RuntimeError(f'Cannot find service {service_name!r}.' f'To spin up a service, use {ux_utils.BOLD}' @@ -491,9 +491,9 @@ def down( stopped_message='All services should have terminated.') service_names_str = ','.join(service_names) - if sum([len(service_names) > 0, all]) != 1: - argument_str = f'service_names={service_names_str}' if len( - service_names) > 0 else '' + if sum([bool(service_names), all]) != 1: + argument_str = (f'service_names={service_names_str}' + if service_names else '') argument_str += ' all' if all else '' raise ValueError('Can only specify one of service_names or all. 
' f'Provided {argument_str!r}.') diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index c0e5220e779..5f92dda0e2f 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -172,7 +172,7 @@ def _get_resources_ports(task_yaml: str) -> str: """Get the resources ports used by the task.""" task = sky.Task.from_yaml(task_yaml) # Already checked all ports are the same in sky.serve.core.up - assert len(task.resources) >= 1, task + assert task.resources, task task_resources: 'resources.Resources' = list(task.resources)[0] # Already checked the resources have and only have one port # before upload the task yaml. diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 983e17d00ae..f3e8fbf1e53 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -226,7 +226,7 @@ def from_replica_statuses( for status in ReplicaStatus.failed_statuses()) > 0: return cls.FAILED # When min_replicas = 0, there is no (provisioning) replica. - if len(replica_statuses) == 0: + if not replica_statuses: return cls.NO_REPLICA return cls.REPLICA_INIT diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 7e665929d66..35d2c25ff40 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -110,7 +110,7 @@ class UpdateMode(enum.Enum): class ThreadSafeDict(Generic[KeyType, ValueType]): """A thread-safe dict.""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._dict: Dict[KeyType, ValueType] = dict(*args, **kwargs) self._lock = threading.Lock() @@ -383,7 +383,7 @@ def _get_service_status( def get_service_status_encoded(service_names: Optional[List[str]]) -> str: - service_statuses = [] + service_statuses: List[Dict[str, str]] = [] if service_names is None: # Get all service names service_names = serve_state.get_glob_service_names(None) @@ -400,7 +400,7 @@ def get_service_status_encoded(service_names: Optional[List[str]]) -> str: def load_service_status(payload: str) -> List[Dict[str, Any]]: service_statuses_encoded = common_utils.decode_payload(payload) - service_statuses = [] + service_statuses: List[Dict[str, Any]] = [] for service_status in service_statuses_encoded: service_statuses.append({ k: pickle.loads(base64.b64decode(v)) @@ -432,7 +432,7 @@ def _terminate_failed_services( A message indicating potential resource leak (if any). If no resource leak is detected, return None. """ - remaining_replica_clusters = [] + remaining_replica_clusters: List[str] = [] # The controller should have already attempted to terminate those # replicas, so we don't need to try again here. 
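load_service_status above undoes a per-field encoding in which every value of a service record is pickled and then base64-encoded, so a whole record can travel as plain text between the controller and the client. A toy round trip of that scheme; the record contents and the outer JSON wrapper are illustrative (SkyPilot wraps the outer layer with its own encode_payload/decode_payload helpers):

import base64
import json
import pickle

record = {'name': 'my-service', 'uptime': 42}
encoded = {k: base64.b64encode(pickle.dumps(v)).decode('utf-8')
           for k, v in record.items()}
payload = json.dumps([encoded])  # stand-in for the real payload wrapper

decoded = [{k: pickle.loads(base64.b64decode(v)) for k, v in rec.items()}
           for rec in json.loads(payload)]
assert decoded == [record]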
for replica_info in serve_state.get_replica_infos(service_name): @@ -459,8 +459,8 @@ def _terminate_failed_services( def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: service_names = serve_state.get_glob_service_names(service_names) - terminated_service_names = [] - messages = [] + terminated_service_names: List[str] = [] + messages: List[str] = [] for service_name in service_names: service_status = _get_service_status(service_name, with_replica_info=False) @@ -506,7 +506,7 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: f.write(UserSignal.TERMINATE.value) f.flush() terminated_service_names.append(f'{service_name!r}') - if len(terminated_service_names) == 0: + if not terminated_service_names: messages.append('No service to terminate.') else: identity_str = f'Service {terminated_service_names[0]} is' @@ -784,9 +784,9 @@ def get_endpoint(service_record: Dict[str, Any]) -> str: # Don't use backend_utils.is_controller_accessible since it is too slow. handle = global_user_state.get_handle_from_cluster_name( SKY_SERVE_CONTROLLER_NAME) - assert isinstance(handle, backends.CloudVmRayResourceHandle) if handle is None: return '-' + assert isinstance(handle, backends.CloudVmRayResourceHandle) load_balancer_port = service_record['load_balancer_port'] if load_balancer_port is None: return '-' @@ -816,7 +816,7 @@ def format_service_table(service_records: List[Dict[str, Any]], ]) service_table = log_utils.create_table(service_columns) - replica_infos = [] + replica_infos: List[Dict[str, Any]] = [] for record in service_records: for replica in record['replica_info']: replica['service_name'] = record['name'] @@ -888,7 +888,8 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], region = '-' zone = '-' - replica_handle: 'backends.CloudVmRayResourceHandle' = record['handle'] + replica_handle: Optional['backends.CloudVmRayResourceHandle'] = record[ + 'handle'] if replica_handle is not None: resources_str = resources_utils.get_readable_resources_repr( replica_handle, simplify=not show_all) diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index fbbca5bc0dd..41de54cf806 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -2,7 +2,7 @@ import json import os import textwrap -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml @@ -186,9 +186,12 @@ def from_yaml(yaml_path: str) -> 'SkyServiceSpec': return SkyServiceSpec.from_yaml_config(config['service']) def to_yaml_config(self) -> Dict[str, Any]: - config = dict() + config: Dict[str, Any] = {} - def add_if_not_none(section, key, value, no_empty: bool = False): + def add_if_not_none(section: str, + key: Optional[str], + value: Any, + no_empty: bool = False): if no_empty and not value: return if value is not None: @@ -231,8 +234,8 @@ def probe_str(self): ' with custom headers') return f'{method}{headers}' - def spot_policy_str(self): - policy_strs = [] + def spot_policy_str(self) -> str: + policy_strs: List[str] = [] if (self.dynamic_ondemand_fallback is not None and self.dynamic_ondemand_fallback): policy_strs.append('Dynamic on-demand fallback') diff --git a/sky/setup_files/dependencies.py b/sky/setup_files/dependencies.py index 18d2f5cdc08..13b99770e5b 100644 --- a/sky/setup_files/dependencies.py +++ b/sky/setup_files/dependencies.py @@ -123,10 +123,13 @@ 'oci': ['oci'] + local_ray, 'kubernetes': ['kubernetes>=20.0.0'], 'remote': remote, - 'runpod': ['runpod>=1.5.1'], + # For the container registry auth 
api. Reference: + # https://github.com/runpod/runpod-python/releases/tag/1.6.1 + 'runpod': ['runpod>=1.6.1'], 'fluidstack': [], # No dependencies needed for fluidstack 'cudo': ['cudo-compute>=0.1.10'], 'paperspace': [], # No dependencies needed for paperspace + 'do': ['pydo>=0.3.0', 'azure-core>=1.24.0', 'azure-common'], 'vsphere': [ 'pyvmomi==8.0.1.0.2', # vsphere-automation-sdk is also required, but it does not have diff --git a/sky/sky_logging.py b/sky/sky_logging.py index effeab310d8..8c6ac6911d9 100644 --- a/sky/sky_logging.py +++ b/sky/sky_logging.py @@ -1,12 +1,15 @@ """Logging utilities.""" import builtins import contextlib +from datetime import datetime import logging +import os import sys import threading import colorama +from sky.skylet import constants from sky.utils import env_options from sky.utils import rich_utils @@ -113,7 +116,7 @@ def reload_logger(): _setup_logger() -def init_logger(name: str): +def init_logger(name: str) -> logging.Logger: return logging.getLogger(name) @@ -161,3 +164,16 @@ def is_silent(): # threads. _logging_config.is_silent = False return _logging_config.is_silent + + +def get_run_timestamp() -> str: + return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') + + +def generate_tmp_logging_file_path(file_name: str) -> str: + """Generate an absolute path of a tmp file for logging.""" + run_timestamp = get_run_timestamp() + log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp) + log_path = os.path.expanduser(os.path.join(log_dir, file_name)) + + return log_path diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 0b2a5b08e1b..96651eddc39 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -268,12 +268,16 @@ # Used for translate local file mounts to cloud storage. Please refer to # sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for # more details. -WORKDIR_BUCKET_NAME = 'skypilot-workdir-{username}-{id}' -FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-folder-{username}-{id}' -FILE_MOUNTS_FILE_ONLY_BUCKET_NAME = 'skypilot-filemounts-files-{username}-{id}' +FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-{username}-{id}' FILE_MOUNTS_LOCAL_TMP_DIR = 'skypilot-filemounts-files-{id}' FILE_MOUNTS_REMOTE_TMP_DIR = '/tmp/sky-{}-filemounts-files' +# Used when an managed jobs are created and +# files are synced up to the cloud. +FILE_MOUNTS_WORKDIR_SUBPATH = 'job-{run_id}/workdir' +FILE_MOUNTS_SUBPATH = 'job-{run_id}/local-file-mounts/{i}' +FILE_MOUNTS_TMP_SUBPATH = 'job-{run_id}/tmp-files' + # The default idle timeout for SkyPilot controllers. This include spot # controller and sky serve controller. # TODO(tian): Refactor to controller_utils. Current blocker: circular import. diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index dfd8332b019..65311688fb4 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -586,7 +586,7 @@ def update_job_status(job_ids: List[int], This function should only be run on the remote instance with ray>=2.4.0. """ echo = logger.info if not silent else logger.debug - if len(job_ids) == 0: + if not job_ids: return [] statuses = [] diff --git a/sky/skylet/providers/command_runner.py b/sky/skylet/providers/command_runner.py index 4f66ef54383..16dbc4d2668 100644 --- a/sky/skylet/providers/command_runner.py +++ b/sky/skylet/providers/command_runner.py @@ -25,7 +25,7 @@ def docker_start_cmds( docker_cmd, ): """Generating docker start command without --rm. - + The code is borrowed from `ray.autoscaler._private.docker`. 
Changes we made: @@ -159,19 +159,17 @@ def run_init(self, *, as_head: bool, file_mounts: Dict[str, str], return True # SkyPilot: Docker login if user specified a private docker registry. - if "docker_login_config" in self.docker_config: + if 'docker_login_config' in self.docker_config: # TODO(tian): Maybe support a command to get the login password? - docker_login_config: docker_utils.DockerLoginConfig = self.docker_config[ - "docker_login_config"] + docker_login_config: docker_utils.DockerLoginConfig = ( + self.docker_config['docker_login_config']) self._run_with_retry( f'{self.docker_cmd} login --username ' f'{docker_login_config.username} --password ' f'{docker_login_config.password} {docker_login_config.server}') # We automatically add the server prefix to the image name if # the user did not add it. - server_prefix = f'{docker_login_config.server}/' - if not specific_image.startswith(server_prefix): - specific_image = f'{server_prefix}{specific_image}' + specific_image = docker_login_config.format_image(specific_image) if self.docker_config.get('pull_before_run', True): assert specific_image, ('Image must be included in config if ' diff --git a/sky/skylet/providers/ibm/node_provider.py b/sky/skylet/providers/ibm/node_provider.py index 5e2a2d64493..44622369e92 100644 --- a/sky/skylet/providers/ibm/node_provider.py +++ b/sky/skylet/providers/ibm/node_provider.py @@ -377,7 +377,7 @@ def non_terminated_nodes(self, tag_filters) -> List[str]: node["id"], nic_id ).get_result() floating_ips = res["floating_ips"] - if len(floating_ips) == 0: + if not floating_ips: # not adding a node that's yet/failed to # to get a floating ip provisioned continue @@ -485,7 +485,7 @@ def _get_instance_data(self, name): """Returns instance (node) information matching the specified name""" instances_data = self.ibm_vpc_client.list_instances(name=name).get_result() - if len(instances_data["instances"]) > 0: + if instances_data["instances"]: return instances_data["instances"][0] return None diff --git a/sky/skylet/providers/scp/config.py b/sky/skylet/providers/scp/config.py index c20b1837f26..d19744e7322 100644 --- a/sky/skylet/providers/scp/config.py +++ b/sky/skylet/providers/scp/config.py @@ -107,7 +107,7 @@ def get_vcp_subnets(self): for item in subnet_contents if item['subnetState'] == 'ACTIVE' and item["vpcId"] == vpc ] - if len(subnet_list) > 0: + if subnet_list: vpc_subnets[vpc] = subnet_list return vpc_subnets diff --git a/sky/skylet/providers/scp/node_provider.py b/sky/skylet/providers/scp/node_provider.py index 004eaac3830..f99b477ab06 100644 --- a/sky/skylet/providers/scp/node_provider.py +++ b/sky/skylet/providers/scp/node_provider.py @@ -259,7 +259,7 @@ def _config_security_group(self, zone_id, vpc, cluster_name): for sg in sg_contents if sg["securityGroupId"] == sg_id ] - if len(sg) != 0 and sg[0] == "ACTIVE": + if sg and sg[0] == "ACTIVE": break time.sleep(5) @@ -282,16 +282,16 @@ def _del_security_group(self, sg_id): for sg in sg_contents if sg["securityGroupId"] == sg_id ] - if len(sg) == 0: + if not sg: break def _refresh_security_group(self, vms): - if len(vms) > 0: + if vms: return # remove security group if vm does not exist keys = self.metadata.keys() security_group_id = self.metadata[ - keys[0]]['creation']['securityGroupId'] if len(keys) > 0 else None + keys[0]]['creation']['securityGroupId'] if keys else None if security_group_id: try: self._del_security_group(security_group_id) @@ -308,7 +308,7 @@ def _del_vm(self, vm_id): for vm in vm_contents if vm["virtualServerId"] == vm_id ] - if 
len(vms) == 0: + if not vms: break def _del_firwall_rules(self, firewall_id, rule_ids): @@ -391,7 +391,7 @@ def _create_instance_sequence(self, vpc, instance_config): return None, None, None, None def _undo_funcs(self, undo_func_list): - while len(undo_func_list) > 0: + while undo_func_list: func = undo_func_list.pop() func() @@ -468,7 +468,7 @@ def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], zone_config = ZoneConfig(self.scp_client, node_config) vpc_subnets = zone_config.get_vcp_subnets() - if (len(vpc_subnets) == 0): + if not vpc_subnets: raise SCPError("This region/zone does not have available VPCs.") instance_config = zone_config.bootstrap_instance_config(node_config) diff --git a/sky/task.py b/sky/task.py index cebc616dc6d..bbf6d59b2ae 100644 --- a/sky/task.py +++ b/sky/task.py @@ -948,19 +948,31 @@ def _get_preferred_store( store_type = storage_lib.StoreType.from_cloud(storage_cloud_str) return store_type, storage_region - def sync_storage_mounts(self) -> None: + def sync_storage_mounts(self, force_sync: bool = False) -> None: """(INTERNAL) Eagerly syncs storage mounts to cloud storage. After syncing up, COPY-mode storage mounts are translated into regular file_mounts of the form ``{ /remote/path: {s3,gs,..}:// }``. + + Args: + force_sync: If True, forces the synchronization of storage mounts. + If the store object is added via storage.add_store(), + the sync will happen automatically via add_store. + However, if it is passed via the construction function + of storage, it is usually because the user passed an + intermediate bucket name in the config and we need to + construct from the user config. In this case, set + force_sync to True. """ for storage in self.storage_mounts.values(): - if len(storage.stores) == 0: + if not storage.stores: store_type, store_region = self._get_preferred_store() self.storage_plans[storage] = store_type storage.add_store(store_type, store_region) else: + if force_sync: + storage.sync_all_stores() # We will download the first store that is added to remote. 
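sync_storage_mounts is an internal helper, but a short sketch makes the new flag's intent concrete: stores attached later through storage.add_store() sync themselves as a side effect, whereas storages constructed with stores already attached (for example an intermediate bucket taken from the user's config) only upload when force_sync=True is passed. The task.yaml path below is hypothetical:

import sky

task = sky.Task.from_yaml('task.yaml')   # declares workdir/storage mounts
# Storage objects built directly from user config already carry store objects,
# so an explicit pass is needed to push the local files up:
task.sync_storage_mounts(force_sync=True)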
self.storage_plans[storage] = list(storage.stores.keys())[0] @@ -977,6 +989,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 's3://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -987,6 +1000,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 'gs://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1005,6 +1019,7 @@ def sync_storage_mounts(self) -> None: blob_path = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=storage_account_name, container_name=storage.name) + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1015,6 +1030,7 @@ def sync_storage_mounts(self) -> None: blob_path = storage.source else: blob_path = 'r2://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1030,7 +1046,18 @@ def sync_storage_mounts(self) -> None: cos_region = data_utils.Rclone.get_region_from_rclone( storage.name, data_utils.Rclone.RcloneClouds.IBM) blob_path = f'cos://{cos_region}/{storage.name}' + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({mnt_path: blob_path}) + elif store_type is storage_lib.StoreType.OCI: + if storage.source is not None and not isinstance( + storage.source, + list) and storage.source.startswith('oci://'): + blob_path = storage.source + else: + blob_path = 'oci://' + storage.name + self.update_file_mounts({ + mnt_path: blob_path, + }) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Storage Type {store_type} ' diff --git a/sky/templates/do-ray.yml.j2 b/sky/templates/do-ray.yml.j2 new file mode 100644 index 00000000000..ea9db59398e --- /dev/null +++ b/sky/templates/do-ray.yml.j2 @@ -0,0 +1,98 @@ +cluster_name: {{cluster_name_on_cloud}} + +# The maximum number of workers nodes to launch in addition to the head node. 
+max_workers: {{num_nodes - 1}} +upscaling_speed: {{num_nodes - 1}} +idle_timeout_minutes: 60 + +{%- if docker_image is not none %} +docker: + image: {{docker_image}} + container_name: {{docker_container_name}} + run_options: + - --ulimit nofile=1048576:1048576 + {%- for run_option in docker_run_options %} + - {{run_option}} + {%- endfor %} + {%- if docker_login_config is not none %} + docker_login_config: + username: |- + {{docker_login_config.username}} + password: |- + {{docker_login_config.password}} + server: |- + {{docker_login_config.server}} + {%- endif %} +{%- endif %} + +provider: + type: external + module: sky.provision.do + region: "{{region}}" + +auth: + ssh_user: root + ssh_private_key: {{ssh_private_key}} + ssh_public_key: |- + skypilot:ssh_public_key_content + +available_node_types: + ray_head_default: + resources: {} + node_config: + InstanceType: {{instance_type}} + DiskSize: {{disk_size}} + {%- if image_id is not none %} + ImageId: {{image_id}} + {%- endif %} + +head_node_type: ray_head_default + +# Format: `REMOTE_PATH : LOCAL_PATH` +file_mounts: { + "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}", + "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}", +{%- for remote_path, local_path in credentials.items() %} + "{{remote_path}}": "{{local_path}}", +{%- endfor %} +} + +rsync_exclude: [] + +initialization_commands: [] + +# List of shell commands to run to set up nodes. +# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH +# connection, which is expensive. Try your best to co-locate commands into fewer +# items! +# +# Increment the following for catching performance bugs easier: +# current num items (num SSH connections): 1 +setup_commands: + # Disable `unattended-upgrades` to prevent apt-get from hanging. It should be called at the beginning before the process started to avoid being blocked. (This is a temporary fix.) + # Create ~/.ssh/config file in case the file does not exist in the image. + # Line 'rm ..': there is another installation of pip. + # Line 'sudo bash ..': set the ulimit as suggested by ray docs for performance. https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html#system-configuration + # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. 
+ # Line 'mkdir -p ..': disable host key check + # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + - {%- for initial_setup_command in initial_setup_commands %} + {{ initial_setup_command }} + {%- endfor %} + sudo systemctl stop unattended-upgrades || true; + sudo systemctl disable unattended-upgrades || true; + sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true; + sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true; + sudo pkill -9 apt-get; + sudo pkill -9 dpkg; + sudo dpkg --configure -a; + mkdir -p ~/.ssh; touch ~/.ssh/config; + {{ conda_installation_commands }} + {{ ray_skypilot_installation_commands }} + sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf'; + sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; + mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; + [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); + +# Command to start ray clusters are now placed in `sky.provision.instance_setup`. +# We do not need to list it here anymore. diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 45cdb5141d4..71c808fdd0f 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -26,10 +26,40 @@ setup: | echo 'export SKYPILOT_DEV=1' >> ~/.bashrc {% endif %} - # Dashboard. - ps aux | grep -v nohup | grep -v grep | grep -- "-m sky.spot.dashboard" | awk '{print $2}' | xargs kill > /dev/null 2>&1 || true - pip list | grep flask > /dev/null 2>&1 || pip install flask 2>&1 > /dev/null - ((ps aux | grep -v nohup | grep -v grep | grep -q -- "-m sky.jobs.dashboard.dashboard") || (nohup {{ sky_python_cmd }} -m sky.jobs.dashboard.dashboard >> ~/.sky/job-dashboard.log 2>&1 &)); + # Create systemd service file + mkdir -p ~/.config/systemd/user/ + + # Create systemd user service file + cat << EOF > ~/.config/systemd/user/skypilot-dashboard.service + [Unit] + Description=SkyPilot Jobs Dashboard + After=network.target + + [Service] + Environment="PATH={{ sky_python_env_path }}:\$PATH" + Environment="SKYPILOT_USER_ID={{controller_envs.SKYPILOT_USER_ID}}" + Environment="SKYPILOT_USER={{controller_envs.SKYPILOT_USER}}" + Restart=always + StandardOutput=append:/home/$USER/.sky/job-dashboard.log + StandardError=append:/home/$USER/.sky/job-dashboard.log + ExecStart={{ sky_python_cmd }} -m sky.jobs.dashboard.dashboard + + [Install] + WantedBy=default.target + EOF + + if command -v systemctl &>/dev/null && systemctl --user show &>/dev/null; then + systemctl --user daemon-reload + systemctl --user enable --now skypilot-dashboard + else + echo "Systemd user services not found. Setting up SkyPilot dashboard manually." 
+ # Kill any old dashboard processes + ps aux | grep -v nohup | grep -v grep | grep -- '-m sky.jobs.dashboard.dashboard' \ + | awk '{print $2}' | xargs kill > /dev/null 2>&1 || true + # Launch the dashboard in the background if not already running + (ps aux | grep -v nohup | grep -v grep | grep -q -- '-m sky.jobs.dashboard.dashboard') || \ + (nohup {{ sky_python_cmd }} -m sky.jobs.dashboard.dashboard >> ~/.sky/job-dashboard.log 2>&1 &) + fi run: | {{ sky_activate_python_env }} diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2 index 853b9142037..ea57c9ac808 100644 --- a/sky/templates/runpod-ray.yml.j2 +++ b/sky/templates/runpod-ray.yml.j2 @@ -10,6 +10,19 @@ provider: module: sky.provision.runpod region: "{{region}}" disable_launch_config_check: true + # For RunPod, we directly set the image id for the docker as runtime environment + # support, thus we need to avoid the DockerInitializer detects the docker field + # and performs the initialization. Therefore we put the docker login config in + # the provider config here. + {%- if docker_login_config is not none %} + docker_login_config: + username: |- + {{docker_login_config.username}} + password: |- + {{docker_login_config.password}} + server: |- + {{docker_login_config.server}} + {%- endif %} auth: ssh_user: root diff --git a/sky/usage/usage_lib.py b/sky/usage/usage_lib.py index 07867939ee5..3cc630b3a98 100644 --- a/sky/usage/usage_lib.py +++ b/sky/usage/usage_lib.py @@ -3,7 +3,6 @@ import contextlib import datetime import enum -import inspect import json import os import time @@ -12,19 +11,28 @@ from typing import Any, Callable, Dict, List, Optional, Union import click -import requests import sky from sky import sky_logging +from sky.adaptors import common as adaptors_common from sky.usage import constants from sky.utils import common_utils from sky.utils import env_options from sky.utils import ux_utils if typing.TYPE_CHECKING: + import inspect + + import requests + from sky import resources as resources_lib from sky import status_lib from sky import task as task_lib +else: + # requests and inspect cost ~100ms to load, which can be postponed to + # collection phase or skipped if user specifies no collection + requests = adaptors_common.LazyImport('requests') + inspect = adaptors_common.LazyImport('inspect') logger = sky_logging.init_logger(__name__) diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py index 78a708efb91..8134fb9e7d1 100644 --- a/sky/utils/accelerator_registry.py +++ b/sky/utils/accelerator_registry.py @@ -3,6 +3,7 @@ from typing import Optional from sky.clouds import service_catalog +from sky.utils import rich_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -88,14 +89,17 @@ def canonicalize_accelerator_name(accelerator: str, if accelerator.lower() in mapping: return mapping[accelerator.lower()] - # _ACCELERATORS may not be comprehensive. - # Users may manually add new accelerators to the catalogs, or download new - # catalogs (that have new accelerators) without upgrading SkyPilot. - # To cover such cases, we should search the accelerator name - # in the service catalog. - searched = service_catalog.list_accelerators(name_filter=accelerator, - case_sensitive=False, - clouds=cloud_str) + # Listing accelerators can be time-consuming since canonicalizing usually + # involves catalog reading with cache not warmed up. + with rich_utils.safe_status('Listing accelerators...'): + # _ACCELERATORS may not be comprehensive. 
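The `usage_lib` hunk above replaces eager `requests`/`inspect` imports with `adaptors_common.LazyImport`, deferring their ~100 ms import cost until usage data is actually collected. A minimal sketch of that pattern, assuming a proxy module that imports on first attribute access (the real `sky.adaptors.common.LazyImport` may differ in details):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Illustrative lazy-import proxy: the wrapped module is only imported
    the first time one of its attributes is accessed."""

    def __init__(self, module_name: str):
        super().__init__(module_name)
        self._module = None

    def __getattr__(self, name):
        if self._module is None:
            self._module = importlib.import_module(self.__name__)
        return getattr(self._module, name)


requests = LazyModule('requests')  # near-zero cost at import time
# The real import cost is only paid on first use, e.g.:
# requests.post(endpoint, json=payload, timeout=5)
```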
+ # Users may manually add new accelerators to the catalogs, or download + # new catalogs (that have new accelerators) without upgrading SkyPilot. + # To cover such cases, we should search the accelerator name + # in the service catalog. + searched = service_catalog.list_accelerators(name_filter=accelerator, + case_sensitive=False, + clouds=cloud_str) names = list(searched.keys()) # Exact match. @@ -106,7 +110,7 @@ def canonicalize_accelerator_name(accelerator: str, return names[0] # Do not print an error message here. Optimizer will handle it. - if len(names) == 0: + if not names: return accelerator # Currently unreachable. diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 3fcdd24e505..ee8f5cf7bec 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -633,7 +633,7 @@ def get_cleaned_username(username: str = '') -> str: return username -def fill_template(template_name: str, variables: Dict, +def fill_template(template_name: str, variables: Dict[str, Any], output_path: str) -> None: """Create a file from a Jinja template and return the filename.""" assert template_name.endswith('.j2'), template_name diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0166a16ff16..acb636893a5 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -206,6 +206,9 @@ def _get_cloud_dependencies_installation_commands( # installed, so we don't check that. python_packages: Set[str] = set() + # add flask to the controller dependencies for dashboard + python_packages.add('flask') + step_prefix = prefix_str.replace('', str(len(commands) + 1)) commands.append(f'echo -en "\\r{step_prefix}uv{empty_str}" &&' f'{constants.SKY_UV_INSTALL_CMD} >/dev/null 2>&1') @@ -649,10 +652,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', still sync up any storage mounts with local source paths (which do not undergo translation). """ + # ================================================================ # Translate the workdir and local file mounts to cloud file mounts. # ================================================================ + def _sub_path_join(sub_path: Optional[str], path: str) -> str: + if sub_path is None: + return path + return os.path.join(sub_path, path).strip('/') + + def assert_no_bucket_creation(store: storage_lib.AbstractStore) -> None: + if store.is_sky_managed: + # Bucket was created, this should not happen since use configured + # the bucket and we assumed it already exists. + store.delete() + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Jobs bucket {store.name!r} does not exist. ' + 'Please check jobs.bucket configuration in ' + 'your SkyPilot config.') + run_id = common_utils.get_usage_run_id()[:8] original_file_mounts = task.file_mounts if task.file_mounts else {} original_storage_mounts = task.storage_mounts if task.storage_mounts else {} @@ -679,11 +699,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message( f'Translating {msg} to SkyPilot Storage...')) + # Get the bucket name for the workdir and file mounts, + # we store all these files in same bucket from config. 
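A few standalone checks of the `_sub_path_join` helper defined above (bucket and run names are illustrative), showing that the configured sub-path is prepended and stray slashes are stripped so the result can serve as an object-key prefix:

```python
import os
from typing import Optional


# Copy of the helper above so the assertions can run standalone.
def _sub_path_join(sub_path: Optional[str], path: str) -> str:
    if sub_path is None:
        return path
    return os.path.join(sub_path, path).strip('/')


# No jobs.bucket sub-path configured: the path is returned unchanged.
assert _sub_path_join(None, 'job-workdir-1a2b') == 'job-workdir-1a2b'
# Sub-path configured: the run-specific directory is nested under it.
assert _sub_path_join('team-a', 'job-workdir-1a2b') == 'team-a/job-workdir-1a2b'
# Leading/trailing slashes are stripped so the result is a clean key prefix.
assert _sub_path_join('/team-a/', 'job-workdir-1a2b') == 'team-a/job-workdir-1a2b'
```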
+ bucket_wth_prefix = skypilot_config.get_nested(('jobs', 'bucket'), None) + store_kwargs: Dict[str, Any] = {} + if bucket_wth_prefix is None: + store_type = store_cls = sub_path = None + storage_account_name = region = None + bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( + username=common_utils.get_cleaned_username(), id=run_id) + else: + store_type, store_cls, bucket_name, sub_path, storage_account_name, \ + region = storage_lib.StoreType.get_fields_from_store_url( + bucket_wth_prefix) + if storage_account_name is not None: + store_kwargs['storage_account_name'] = storage_account_name + if region is not None: + store_kwargs['region'] = region + # Step 1: Translate the workdir to SkyPilot storage. new_storage_mounts = {} if task.workdir is not None: - bucket_name = constants.WORKDIR_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) workdir = task.workdir task.workdir = None if (constants.SKY_REMOTE_WORKDIR in original_file_mounts or @@ -691,14 +727,28 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', raise ValueError( f'Cannot mount {constants.SKY_REMOTE_WORKDIR} as both the ' 'workdir and file_mounts contains it as the target.') - new_storage_mounts[ - constants. - SKY_REMOTE_WORKDIR] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': workdir, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, + constants.FILE_MOUNTS_WORKDIR_SUBPATH.format(run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls(name=bucket_name, + source=workdir, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + + storage_obj = storage_lib.Storage(name=bucket_name, + source=workdir, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[constants.SKY_REMOTE_WORKDIR] = storage_obj # Check of the existence of the workdir in file_mounts is done in # the task construction. logger.info(f' {colorama.Style.DIM}Workdir: {workdir!r} ' @@ -716,27 +766,37 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', if os.path.isfile(os.path.abspath(os.path.expanduser(src))): copy_mounts_with_file_in_src[dst] = src continue - bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), - id=f'{run_id}-{i}', - ) - new_storage_mounts[dst] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': src, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, constants.FILE_MOUNTS_SUBPATH.format(i=i, run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + store = store_cls(name=bucket_name, + source=src, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + + stores = {store_type: store} + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage(name=bucket_name, + source=src, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[dst] = storage_obj logger.info(f' {colorama.Style.DIM}Folder : {src!r} ' f'-> storage: {bucket_name!r}.{colorama.Style.RESET_ALL}') # Step 3: Translate local file mounts with file in src to SkyPilot storage. 
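Before Step 3 below, a rough illustration of the kind of information the code above extracts from a `jobs.bucket` URL such as `s3://my-jobs-bucket/team-a`. This is an approximation for intuition only; the real `StoreType.get_fields_from_store_url` also resolves the store class, storage account, and region:

```python
from urllib.parse import urlsplit


def split_bucket_url(bucket_url: str):
    # Illustrative parsing only: scheme selects the store, netloc is the
    # bucket, and any remaining path becomes the sub-path prefix.
    parts = urlsplit(bucket_url)
    scheme = parts.scheme                      # 's3', 'gs', 'https', 'r2', 'cos'
    bucket_name = parts.netloc                 # 'my-jobs-bucket'
    sub_path = parts.path.strip('/') or None   # 'team-a' or None
    return scheme, bucket_name, sub_path


assert split_bucket_url('s3://my-jobs-bucket/team-a') == ('s3', 'my-jobs-bucket', 'team-a')
assert split_bucket_url('gs://my-jobs-bucket') == ('gs', 'my-jobs-bucket', None)
```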
# Hard link the files in src to a temporary directory, and upload folder. + file_mounts_tmp_subpath = _sub_path_join( + sub_path, constants.FILE_MOUNTS_TMP_SUBPATH.format(run_id=run_id)) local_fm_path = os.path.join( tempfile.gettempdir(), constants.FILE_MOUNTS_LOCAL_TMP_DIR.format(id=run_id)) os.makedirs(local_fm_path, exist_ok=True) - file_bucket_name = constants.FILE_MOUNTS_FILE_ONLY_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) file_mount_remote_tmp_dir = constants.FILE_MOUNTS_REMOTE_TMP_DIR.format( path) if copy_mounts_with_file_in_src: @@ -745,14 +805,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', src_to_file_id[src] = i os.link(os.path.abspath(os.path.expanduser(src)), os.path.join(local_fm_path, f'file-{i}')) - - new_storage_mounts[ - file_mount_remote_tmp_dir] = storage_lib.Storage.from_yaml_config({ - 'name': file_bucket_name, - 'source': local_fm_path, - 'persistent': False, - 'mode': 'MOUNT', - }) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls( + name=bucket_name, + source=local_fm_path, + _bucket_sub_path=file_mounts_tmp_subpath, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage( + name=bucket_name, + source=local_fm_path, + persistent=False, + mode=storage_lib.StorageMode.MOUNT, + stores=stores, + _bucket_sub_path=file_mounts_tmp_subpath) + + new_storage_mounts[file_mount_remote_tmp_dir] = storage_obj if file_mount_remote_tmp_dir in original_storage_mounts: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -762,8 +835,9 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', sources = list(src_to_file_id.keys()) sources_str = '\n '.join(sources) logger.info(f' {colorama.Style.DIM}Files (listed below) ' - f' -> storage: {file_bucket_name}:' + f' -> storage: {bucket_name}:' f'\n {sources_str}{colorama.Style.RESET_ALL}') + rich_utils.force_update_status( ux_utils.spinner_message('Uploading translated local files/folders')) task.update_storage_mounts(new_storage_mounts) @@ -779,7 +853,7 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message('Uploading local sources to storage[/] ' '[dim]View storages: sky storage ls')) try: - task.sync_storage_mounts() + task.sync_storage_mounts(force_sync=bucket_wth_prefix is not None) except (ValueError, exceptions.NoCloudAccessError) as e: if 'No enabled cloud for storage' in str(e) or isinstance( e, exceptions.NoCloudAccessError): @@ -809,10 +883,11 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', # file_mount_remote_tmp_dir will only exist when there are files in # the src for copy mounts. 
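The staging step above hard-links every single-file source into one temporary folder under names `file-<i>`, so the whole folder can be uploaded as a single prefix in the jobs bucket. A standalone sketch of that staging (paths are illustrative):

```python
import os
import tempfile

# Illustrative staging of file-typed mounts: each source file is hard-linked
# into one temporary folder as 'file-<i>' so the folder can be uploaded as a
# single storage object prefix.
sources = ['~/.netrc', '~/data/config.json']  # made-up local paths
staging_dir = tempfile.mkdtemp(prefix='sky-file-mounts-')
for i, src in enumerate(sources):
    src = os.path.abspath(os.path.expanduser(src))
    if os.path.isfile(src):
        # A hard link avoids copying file contents: both names point at the
        # same inode on the local filesystem.
        os.link(src, os.path.join(staging_dir, f'file-{i}'))
print(staging_dir)
```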
storage_obj = task.storage_mounts[file_mount_remote_tmp_dir] - store_type = list(storage_obj.stores.keys())[0] - store_object = storage_obj.stores[store_type] + curr_store_type = list(storage_obj.stores.keys())[0] + store_object = storage_obj.stores[curr_store_type] bucket_url = storage_lib.StoreType.get_endpoint_url( - store_object, file_bucket_name) + store_object, bucket_name) + bucket_url += f'/{file_mounts_tmp_subpath}' for dst, src in copy_mounts_with_file_in_src.items(): file_id = src_to_file_id[src] new_file_mounts[dst] = bucket_url + f'/file-{file_id}' @@ -829,8 +904,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', store_types = list(storage_obj.stores.keys()) assert len(store_types) == 1, ( 'We only support one store type for now.', storage_obj.stores) - store_type = store_types[0] - store_object = storage_obj.stores[store_type] + curr_store_type = store_types[0] + store_object = storage_obj.stores[curr_store_type] storage_obj.source = storage_lib.StoreType.get_endpoint_url( store_object, storage_obj.name) storage_obj.force_delete = True @@ -847,8 +922,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', store_types = list(storage_obj.stores.keys()) assert len(store_types) == 1, ( 'We only support one store type for now.', storage_obj.stores) - store_type = store_types[0] - store_object = storage_obj.stores[store_type] + curr_store_type = store_types[0] + store_object = storage_obj.stores[curr_store_type] source = storage_lib.StoreType.get_endpoint_url( store_object, storage_obj.name) new_storage = storage_lib.Storage.from_yaml_config({ diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 3229f86abf9..d0eb03d46ea 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -89,7 +89,7 @@ def load_chain_dag_from_yaml( elif len(configs) == 1: dag_name = configs[0].get('name') - if len(configs) == 0: + if not configs: # YAML has only `name: xxx`. Still instantiate a task. configs = [{'name': dag_name}] diff --git a/sky/utils/db_utils.py b/sky/utils/db_utils.py index b74af27340b..09218aea87d 100644 --- a/sky/utils/db_utils.py +++ b/sky/utils/db_utils.py @@ -4,11 +4,27 @@ import threading from typing import Any, Callable, Optional +# This parameter (passed to sqlite3.connect) controls how long we will wait to +# obtains a database lock (not necessarily during connection, but whenever it is +# needed). It is not a connection timeout. +# Even in WAL mode, only a single writer is allowed at a time. Other writers +# will block until the write lock can be obtained. This behavior is described in +# the SQLite documentation for WAL: https://www.sqlite.org/wal.html +# Python's default timeout is 5s. In normal usage, lock contention is very low, +# and this is more than sufficient. However, in some highly concurrent cases, +# such as a jobs controller suddenly recovering thousands of jobs at once, we +# can see a small number of processes that take much longer to obtain the lock. +# In contrived highly contentious cases, around 0.1% of transactions will take +# >30s to take the lock. We have not seen cases that take >60s. For cases up to +# 1000x parallelism, this is thus thought to be a conservative setting. +# For more info, see the PR description for #4552. 
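The comment above is describing SQLite's per-connection busy timeout rather than a connect timeout. A small self-contained demonstration (table and values are made up): with a generous `timeout`, a second writer blocks until the first commits instead of failing with `database is locked` after Python's 5-second default.

```python
import os
import sqlite3
import tempfile
import threading
import time

db = os.path.join(tempfile.mkdtemp(), 'state.db')
conn_a = sqlite3.connect(db, timeout=60)
conn_a.execute('CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)')
conn_a.commit()

# Writer A grabs the write lock and holds it for a moment.
conn_a.execute('BEGIN IMMEDIATE')
conn_a.execute("INSERT INTO kv VALUES ('a', '1')")


def writer_b():
    # Waits (up to 60 s) on the lock held by A instead of raising
    # sqlite3.OperationalError: database is locked.
    conn_b = sqlite3.connect(db, timeout=60)
    conn_b.execute("INSERT INTO kv VALUES ('b', '2')")
    conn_b.commit()
    conn_b.close()


t = threading.Thread(target=writer_b)
t.start()
time.sleep(1)
conn_a.commit()  # Release the lock; writer B proceeds.
t.join()
conn_a.close()
```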
+_DB_TIMEOUT_S = 60 + @contextlib.contextmanager def safe_cursor(db_path: str): """A newly created, auto-committing, auto-closing cursor.""" - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(db_path, timeout=_DB_TIMEOUT_S) cursor = conn.cursor() try: yield cursor @@ -79,8 +95,6 @@ class SQLiteConn(threading.local): def __init__(self, db_path: str, create_table: Callable): super().__init__() self.db_path = db_path - # NOTE: We use a timeout of 10 seconds to avoid database locked - # errors. This is a hack, but it works. - self.conn = sqlite3.connect(db_path, timeout=10) + self.conn = sqlite3.connect(db_path, timeout=_DB_TIMEOUT_S) self.cursor = self.conn.cursor() create_table(self.cursor, self.conn) diff --git a/sky/utils/kubernetes/gpu_labeler.py b/sky/utils/kubernetes/gpu_labeler.py index 6877c94a2a8..9f5a11cea42 100644 --- a/sky/utils/kubernetes/gpu_labeler.py +++ b/sky/utils/kubernetes/gpu_labeler.py @@ -139,7 +139,7 @@ def label(): # Create the job for this node` batch_v1.create_namespaced_job(namespace, job_manifest) print(f'Created GPU labeler job for node {node_name}') - if len(gpu_nodes) == 0: + if not gpu_nodes: print('No GPU nodes found in the cluster. If you have GPU nodes, ' 'please ensure that they have the label ' f'`{kubernetes_utils.get_gpu_resource_key()}: `') diff --git a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py index 380c82f8c88..a764fb6e5e4 100644 --- a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +++ b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py @@ -126,7 +126,7 @@ def manage_lifecycle(): f'error: {e}\n') raise - if len(ret.items) == 0: + if not ret.items: sys.stdout.write( f'[Lifecycle] Did not find pods with label ' f'"{label_selector}" in namespace {current_namespace}\n') diff --git a/sky/utils/log_utils.py b/sky/utils/log_utils.py index a5884333609..5a7d8cfd5f7 100644 --- a/sky/utils/log_utils.py +++ b/sky/utils/log_utils.py @@ -5,6 +5,8 @@ from typing import Callable, Iterator, List, Optional, TextIO, Type import colorama +# slow due to https://github.com/python-pendulum/pendulum/issues/808 +# FIXME(aylei): bump pendulum if it get fixed import pendulum import prettytable diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 851e77a57fc..3194dc79da5 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -299,6 +299,12 @@ def get_storage_schema(): mode.value for mode in storage.StorageMode ] }, + '_is_sky_managed': { + 'type': 'boolean', + }, + '_bucket_sub_path': { + 'type': 'string', + }, '_force_delete': { 'type': 'boolean', } @@ -721,6 +727,11 @@ def get_config_schema(): 'resources': resources_schema, } }, + 'bucket': { + 'type': 'string', + 'pattern': '^(https|s3|gs|r2|cos)://.+', + 'required': [], + } } } cloud_configs = { @@ -875,6 +886,9 @@ def get_config_schema(): 'image_tag_gpu': { 'type': 'string', }, + 'vcn_ocid': { + 'type': 'string', + }, 'vcn_subnet': { 'type': 'string', }, diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 992c6bbe3ff..88d351632a3 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -5,7 +5,7 @@ import resource import subprocess import time -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import colorama import psutil @@ -97,7 +97,7 @@ def get_parallel_threads(cloud_str: Optional[str] = None) -> int: def run_in_parallel(func: Callable, - args: Iterable[Any], + args: 
List[Any], num_threads: Optional[int] = None) -> List[Any]: """Run a function in parallel on a list of arguments. @@ -113,6 +113,11 @@ def run_in_parallel(func: Callable, A list of the return values of the function func, in the same order as the arguments. """ + if len(args) == 0: + return [] + # Short-circuit for single element + if len(args) == 1: + return [func(args[0])] # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long processes = num_threads if num_threads is not None else get_parallel_threads( ) diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index d32e1e9e224..941cd455e64 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -50,6 +50,14 @@ uv pip install --prerelease=allow "azure-cli>=2.65.0" uv pip install -e ".[all]" +clear_resources() { + sky down ${CLUSTER_NAME}* -y + sky jobs cancel -n ${MANAGED_JOB_JOB_NAME}* -y +} + +# Set trap to call cleanup on script exit +trap clear_resources EXIT + # exec + launch if [ "$start_from" -le 1 ]; then conda activate sky-back-compat-master @@ -193,6 +201,3 @@ echo "$s" echo "$s" | grep "SUCCEEDED" | wc -l | grep 2 || exit 1 echo "$s" | grep "CANCELLING\|CANCELLED" | wc -l | grep 1 || exit 1 fi - -sky down ${CLUSTER_NAME}* -y -sky jobs cancel -n ${MANAGED_JOB_JOB_NAME}* -y diff --git a/tests/conftest.py b/tests/conftest.py index ee5caf062b9..af6367fdac6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ # To only run tests for managed jobs (without generic tests), use # --managed-jobs. all_clouds_in_smoke_tests = [ - 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', + 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', 'do', 'kubernetes', 'vsphere', 'cudo', 'fluidstack', 'paperspace', 'runpod' ] default_clouds_to_run = ['aws', 'azure'] @@ -43,6 +43,7 @@ 'fluidstack': 'fluidstack', 'cudo': 'cudo', 'paperspace': 'paperspace', + 'do': 'do', 'runpod': 'runpod' } @@ -120,6 +121,12 @@ def _get_cloud_to_run(config) -> List[str]: def pytest_collection_modifyitems(config, items): + if config.option.collectonly: + for item in items: + full_name = item.nodeid + marks = [mark.name for mark in item.iter_markers()] + print(f"Collected {full_name} with marks: {marks}") + skip_marks = {} skip_marks['slow'] = pytest.mark.skip(reason='need --runslow option to run') skip_marks['managed_jobs'] = pytest.mark.skip( diff --git a/tests/skyserve/update/bump_version_after.yaml b/tests/skyserve/update/bump_version_after.yaml index 6e845f54b9e..0f2c6925bc6 100644 --- a/tests/skyserve/update/bump_version_after.yaml +++ b/tests/skyserve/update/bump_version_after.yaml @@ -16,7 +16,7 @@ service: replicas: 3 resources: - ports: 8081 + ports: 8080 cpus: 2+ setup: | diff --git a/tests/skyserve/update/bump_version_before.yaml b/tests/skyserve/update/bump_version_before.yaml index c9fd957e41a..de922b66434 100644 --- a/tests/skyserve/update/bump_version_before.yaml +++ b/tests/skyserve/update/bump_version_before.yaml @@ -16,7 +16,7 @@ service: replicas: 2 resources: - ports: 8081 + ports: 8080 cpus: 2+ setup: | diff --git a/tests/smoke_tests/test_basic.py b/tests/smoke_tests/test_basic.py index e8dffe53846..30576d3272f 100644 --- a/tests/smoke_tests/test_basic.py +++ b/tests/smoke_tests/test_basic.py @@ -422,6 +422,7 @@ def test_load_dump_yaml_config_equivalent(self): # ---------- Testing Multiple Accelerators ---------- @pytest.mark.no_fluidstack # Fluidstack 
does not support K80 gpus for now @pytest.mark.no_paperspace # Paperspace does not support K80 gpus +@pytest.mark.no_do # DO does not support K80s def test_multiple_accelerators_ordered(): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -438,6 +439,7 @@ def test_multiple_accelerators_ordered(): @pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs @pytest.mark.no_paperspace # Paperspace does not support T4 GPUs +@pytest.mark.no_do # DO does not have multiple accelerators def test_multiple_accelerators_ordered_with_default(): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -454,6 +456,7 @@ def test_multiple_accelerators_ordered_with_default(): @pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs @pytest.mark.no_paperspace # Paperspace does not support T4 GPUs +@pytest.mark.no_do # DO does not have multiple accelerators def test_multiple_accelerators_unordered(): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -469,6 +472,7 @@ def test_multiple_accelerators_unordered(): @pytest.mark.no_fluidstack # Fluidstack has low availability for T4 GPUs @pytest.mark.no_paperspace # Paperspace does not support T4 GPUs +@pytest.mark.no_do # DO does not support multiple accelerators def test_multiple_accelerators_unordered_with_default(): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -502,6 +506,7 @@ def test_multiple_resources(): @pytest.mark.no_paperspace # Requires other clouds to be enabled @pytest.mark.no_kubernetes @pytest.mark.aws # SkyBenchmark requires S3 access +@pytest.mark.no_do # requires other clouds to be enabled def test_sky_bench(generic_cloud: str): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -566,7 +571,7 @@ def test_kubernetes_context_failover(): 'kubectl get namespaces --context kind-skypilot | grep test-namespace || ' '{ echo "Should set the namespace to test-namespace for kind-skypilot. Check the instructions in ' 'tests/test_smoke.py::test_kubernetes_context_failover." && exit 1; }', - 'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep "1, 2, 3, 4, 5, 6, 7, 8"', + 'sky show-gpus --cloud kubernetes --region kind-skypilot | grep H100 | grep "1, 2, 4, 8"', # Get contexts and set current context to the other cluster that is not kind-skypilot f'kubectl config use-context {context}', # H100 should not in the current context diff --git a/tests/smoke_tests/test_cluster_job.py b/tests/smoke_tests/test_cluster_job.py index 18b82c649e7..1fbb1b3d875 100644 --- a/tests/smoke_tests/test_cluster_job.py +++ b/tests/smoke_tests/test_cluster_job.py @@ -22,6 +22,7 @@ import pathlib import tempfile import textwrap +from typing import Dict import jinja2 import pytest @@ -43,15 +44,17 @@ @pytest.mark.no_scp # SCP does not have T4 gpus. Run test_scp_job_queue instead @pytest.mark.no_paperspace # Paperspace does not have T4 gpus. 
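Several smoke tests below share the same pattern: `@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])` combined with `accelerator.get(generic_cloud, 'T4')`, so DigitalOcean runs on H100 while all other clouds keep T4. A standalone sketch of that lookup (test name is made up):

```python
from typing import Dict

import pytest


# Mirrors the per-cloud GPU override used in the smoke tests below: the
# parametrized dict maps cloud -> GPU, and everything else falls back to T4.
@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
def test_accelerator_fallback(accelerator: Dict[str, str]):
    assert accelerator.get('do', 'T4') == 'H100'
    assert accelerator.get('aws', 'T4') == 'T4'
    assert accelerator.get('kubernetes', 'T4') == 'T4'
```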
@pytest.mark.no_oci # OCI does not have T4 gpus -def test_job_queue(generic_cloud: str): +@pytest.mark.parametrize('accelerator', [{'do': 'H100'}]) +def test_job_queue(generic_cloud: str, accelerator: Dict[str, str]): + accelerator = accelerator.get(generic_cloud, 'T4') name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( 'job_queue', [ - f'sky launch -y -c {name} --cloud {generic_cloud} examples/job_queue/cluster.yaml', - f'sky exec {name} -n {name}-1 -d examples/job_queue/job.yaml', - f'sky exec {name} -n {name}-2 -d examples/job_queue/job.yaml', - f'sky exec {name} -n {name}-3 -d examples/job_queue/job.yaml', + f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/job_queue/cluster.yaml', + f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml', + f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml', + f'sky exec {name} -n {name}-3 -d --gpus {accelerator}:0.5 examples/job_queue/job.yaml', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING', @@ -59,8 +62,8 @@ def test_job_queue(generic_cloud: str): 'sleep 5', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING', f'sky cancel -y {name} 3', - f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', - f'sky exec {name} --gpus T4:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', f'sky logs {name} 4 --status', f'sky logs {name} 5 --status', ], @@ -77,6 +80,7 @@ def test_job_queue(generic_cloud: str): @pytest.mark.no_scp # Doesn't support SCP for now @pytest.mark.no_oci # Doesn't support OCI for now @pytest.mark.no_kubernetes # Doesn't support Kubernetes for now +@pytest.mark.parametrize('accelerator', [{'do': 'H100'}]) @pytest.mark.parametrize( 'image_id', [ @@ -93,17 +97,19 @@ def test_job_queue(generic_cloud: str): # 2. python>=3.12 works with SkyPilot runtime. 
'docker:winglian/axolotl:main-latest' ]) -def test_job_queue_with_docker(generic_cloud: str, image_id: str): +def test_job_queue_with_docker(generic_cloud: str, image_id: str, + accelerator: Dict[str, str]): + accelerator = accelerator.get(generic_cloud, 'T4') name = smoke_tests_utils.get_cluster_name() + image_id[len('docker:'):][:4] total_timeout_minutes = 40 if generic_cloud == 'azure' else 15 time_to_sleep = 300 if generic_cloud == 'azure' else 180 test = smoke_tests_utils.Test( 'job_queue_with_docker', [ - f'sky launch -y -c {name} --cloud {generic_cloud} --image-id {image_id} examples/job_queue/cluster_docker.yaml', - f'sky exec {name} -n {name}-1 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', - f'sky exec {name} -n {name}-2 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', - f'sky exec {name} -n {name}-3 -d --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', + f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} --image-id {image_id} examples/job_queue/cluster_docker.yaml', + f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', + f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', + f'sky exec {name} -n {name}-3 -d --gpus {accelerator}:0.5 --image-id {image_id} --env TIME_TO_SLEEP={time_to_sleep} examples/job_queue/job_docker.yaml', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep RUNNING', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep RUNNING', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep PENDING', @@ -112,7 +118,7 @@ def test_job_queue_with_docker(generic_cloud: str, image_id: str): f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep RUNNING', f'sky cancel -y {name} 3', # Make sure the GPU is still visible to the container. - f'sky exec {name} --image-id {image_id} nvidia-smi | grep "Tesla T4"', + f'sky exec {name} --image-id {image_id} nvidia-smi | grep -i "{accelerator}"', f'sky logs {name} 4 --status', f'sky stop -y {name}', # Make sure the job status preserve after stop and start the @@ -122,12 +128,12 @@ def test_job_queue_with_docker(generic_cloud: str, image_id: str): f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-1 | grep FAILED', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-2 | grep CANCELLED', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep CANCELLED', - f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', - f'sky exec {name} --gpus T4:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:1 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', f'sky logs {name} 5 --status', f'sky logs {name} 6 --status', # Make sure it is still visible after an stop & start cycle. 
- f'sky exec {name} --image-id {image_id} nvidia-smi | grep "Tesla T4"', + f'sky exec {name} --image-id {image_id} nvidia-smi | grep -i "{accelerator}"', f'sky logs {name} 7 --status' ], f'sky down -y {name}', @@ -214,16 +220,18 @@ def test_scp_job_queue(): @pytest.mark.no_scp # SCP does not support num_nodes > 1 yet @pytest.mark.no_oci # OCI Cloud does not have T4 gpus. @pytest.mark.no_kubernetes # Kubernetes not support num_nodes > 1 yet -def test_job_queue_multinode(generic_cloud: str): +@pytest.mark.parametrize('accelerator', [{'do': 'H100'}]) +def test_job_queue_multinode(generic_cloud: str, accelerator: Dict[str, str]): + accelerator = accelerator.get(generic_cloud, 'T4') name = smoke_tests_utils.get_cluster_name() total_timeout_minutes = 30 if generic_cloud == 'azure' else 15 test = smoke_tests_utils.Test( 'job_queue_multinode', [ - f'sky launch -y -c {name} --cloud {generic_cloud} examples/job_queue/cluster_multinode.yaml', - f'sky exec {name} -n {name}-1 -d examples/job_queue/job_multinode.yaml', - f'sky exec {name} -n {name}-2 -d examples/job_queue/job_multinode.yaml', - f'sky launch -c {name} -n {name}-3 --detach-setup -d examples/job_queue/job_multinode.yaml', + f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/job_queue/cluster_multinode.yaml', + f'sky exec {name} -n {name}-1 -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml', + f'sky exec {name} -n {name}-2 -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml', + f'sky launch -c {name} -n {name}-3 --detach-setup -d --gpus {accelerator}:0.5 examples/job_queue/job_multinode.yaml', f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-1 | grep RUNNING)', f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-2 | grep RUNNING)', f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-3 | grep PENDING)', @@ -232,16 +240,16 @@ def test_job_queue_multinode(generic_cloud: str): 'sleep 5', f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep {name}-3 | grep SETTING_UP', f'sky cancel -y {name} 1 2 3', - f'sky launch -c {name} -n {name}-4 --detach-setup -d examples/job_queue/job_multinode.yaml', + f'sky launch -c {name} -n {name}-4 --detach-setup -d --gpus {accelerator} examples/job_queue/job_multinode.yaml', # Test the job status is correctly set to SETTING_UP, during the setup is running, # and the job can be cancelled during the setup. 
'sleep 5', f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep SETTING_UP)', f'sky cancel -y {name} 4', f's=$(sky queue {name}) && echo "$s" && (echo "$s" | grep {name}-4 | grep CANCELLED)', - f'sky exec {name} --gpus T4:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', - f'sky exec {name} --gpus T4:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', - f'sky exec {name} --gpus T4:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:0.2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:0.2 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', + f'sky exec {name} --gpus {accelerator}:1 --num-nodes 2 "[[ \$SKYPILOT_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"', f'sky logs {name} 5 --status', f'sky logs {name} 6 --status', f'sky logs {name} 7 --status', @@ -385,6 +393,7 @@ def test_docker_preinstalled_package(generic_cloud: str): @pytest.mark.no_ibm # IBM Cloud does not have T4 gpus @pytest.mark.no_scp # SCP does not support num_nodes > 1 yet @pytest.mark.no_oci # OCI Cloud does not have T4 gpus +@pytest.mark.no_do # DO does not have T4 gpus def test_multi_echo(generic_cloud: str): name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( @@ -427,14 +436,16 @@ def test_multi_echo(generic_cloud: str): @pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus @pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA @pytest.mark.no_scp # SCP does not have V100 (16GB) GPUs. Run test_scp_huggingface instead. -def test_huggingface(generic_cloud: str): +@pytest.mark.parametrize('accelerator', [{'do': 'H100'}]) +def test_huggingface(generic_cloud: str, accelerator: Dict[str, str]): + accelerator = accelerator.get(generic_cloud, 'T4') name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( 'huggingface_glue_imdb_app', [ - f'sky launch -y -c {name} --cloud {generic_cloud} examples/huggingface_glue_imdb_app.yaml', + f'sky launch -y -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/huggingface_glue_imdb_app.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. - f'sky exec {name} examples/huggingface_glue_imdb_app.yaml', + f'sky exec {name} --gpus {accelerator} examples/huggingface_glue_imdb_app.yaml', f'sky logs {name} 2 --status', # Ensure the job succeeded. ], f'sky down -y {name}', @@ -556,6 +567,7 @@ def test_tpu_vm_pod(): # ---------- TPU Pod Slice on GKE. 
---------- +@pytest.mark.requires_gke @pytest.mark.kubernetes def test_tpu_pod_slice_gke(): name = smoke_tests_utils.get_cluster_name() @@ -865,6 +877,7 @@ def test_add_and_remove_pod_annotations_with_autostop(): # ---------- Container logs from task on Kubernetes ---------- +@pytest.mark.requires_gke @pytest.mark.kubernetes def test_container_logs_multinode_kubernetes(): name = smoke_tests_utils.get_cluster_name() @@ -953,6 +966,7 @@ def test_container_logs_two_simultaneous_jobs_kubernetes(): @pytest.mark.no_lambda_cloud # Lambda Cloud does not have V100 gpus @pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA @pytest.mark.no_scp # SCP does not support num_nodes > 1 yet +@pytest.mark.no_dos # DO does not have V100 gpus @pytest.mark.skip( reason= 'The resnet_distributed_tf_app is flaky, due to it failing to detect GPUs.') @@ -1228,12 +1242,14 @@ def test_cancel_azure(): @pytest.mark.no_ibm # IBM cloud currently doesn't provide public image with CUDA @pytest.mark.no_paperspace # Paperspace has `gnome-shell` on nvidia-smi @pytest.mark.no_scp # SCP does not support num_nodes > 1 yet -def test_cancel_pytorch(generic_cloud: str): +@pytest.mark.parametrize('accelerator', [{'do': 'H100'}]) +def test_cancel_pytorch(generic_cloud: str, accelerator: Dict[str, str]): + accelerator = accelerator.get(generic_cloud, 'T4') name = smoke_tests_utils.get_cluster_name() test = smoke_tests_utils.Test( 'cancel-pytorch', [ - f'sky launch -c {name} --cloud {generic_cloud} examples/resnet_distributed_torch.yaml -y -d', + f'sky launch -c {name} --cloud {generic_cloud} --gpus {accelerator} examples/resnet_distributed_torch.yaml -y -d', # Wait the GPU process to start. 'sleep 90', f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep python) || ' @@ -1283,6 +1299,7 @@ def test_cancel_ibm(): @pytest.mark.no_ibm # IBM Cloud does not support spot instances @pytest.mark.no_scp # SCP does not support spot instances @pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances +@pytest.mark.no_do def test_use_spot(generic_cloud: str): """Test use-spot and sky exec.""" name = smoke_tests_utils.get_cluster_name() @@ -1414,6 +1431,7 @@ def test_aws_custom_image(): smoke_tests_utils.run_one_test(test) +@pytest.mark.requires_gke @pytest.mark.kubernetes @pytest.mark.parametrize( 'image_id', @@ -1555,7 +1573,7 @@ def test_azure_disk_tier(): name = smoke_tests_utils.get_cluster_name() + '-' + disk_tier.value name_on_cloud = common_utils.make_cluster_name_on_cloud( name, sky.Azure.max_cluster_name_length()) - region = 'westus2' + region = 'eastus2' test = smoke_tests_utils.Test( 'azure-disk-tier-' + disk_tier.value, [ @@ -1577,7 +1595,7 @@ def test_azure_best_tier_failover(): name = smoke_tests_utils.get_cluster_name() name_on_cloud = common_utils.make_cluster_name_on_cloud( name, sky.Azure.max_cluster_name_length()) - region = 'westus2' + region = 'eastus2' test = smoke_tests_utils.Test( 'azure-best-tier-failover', [ diff --git a/tests/smoke_tests/test_images.py b/tests/smoke_tests/test_images.py index 27d6a693ae6..14769161675 100644 --- a/tests/smoke_tests/test_images.py +++ b/tests/smoke_tests/test_images.py @@ -19,6 +19,9 @@ # Change cloud for generic tests to aws # > pytest tests/smoke_tests/test_images.py --generic-cloud aws +import os +import subprocess + import pytest from smoke_tests import smoke_tests_utils @@ -345,6 +348,21 @@ def test_gcp_mig(): @pytest.mark.gcp def test_gcp_force_enable_external_ips(): name = smoke_tests_utils.get_cluster_name() + + # Command to check if 
the instance is on GCP + is_on_gcp_command = ( + 'curl -s -H "Metadata-Flavor: Google" ' + '"http://metadata.google.internal/computeMetadata/v1/instance/name"') + + # Run the GCP check + is_on_gcp = subprocess.run(f'{is_on_gcp_command}', + shell=True, + check=False, + text=True, + capture_output=True).stdout.strip() + if not is_on_gcp: + pytest.skip('Not on GCP, skipping test') + test_commands = [ f'sky launch -y -c {name} --cloud gcp --cpus 2 tests/test_yamls/minimal.yaml', # Check network of vm is "default" diff --git a/tests/smoke_tests/test_managed_job.py b/tests/smoke_tests/test_managed_job.py index c8ef5c1a502..12d1654a773 100644 --- a/tests/smoke_tests/test_managed_job.py +++ b/tests/smoke_tests/test_managed_job.py @@ -23,6 +23,7 @@ # > pytest tests/smoke_tests/test_managed_job.py --generic-cloud aws import pathlib +import re import tempfile import time @@ -94,6 +95,7 @@ def test_managed_jobs(generic_cloud: str): @pytest.mark.no_scp # SCP does not support spot instances @pytest.mark.no_paperspace # Paperspace does not support spot instances @pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances +@pytest.mark.no_do # DO does not support spot instances @pytest.mark.managed_jobs def test_job_pipeline(generic_cloud: str): """Test a job pipeline.""" @@ -135,6 +137,7 @@ def test_job_pipeline(generic_cloud: str): @pytest.mark.no_scp # SCP does not support spot instances @pytest.mark.no_paperspace # Paperspace does not support spot instances @pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances +@pytest.mark.no_do # DO does not support spot instances @pytest.mark.managed_jobs def test_managed_jobs_failed_setup(generic_cloud: str): """Test managed job with failed setup.""" @@ -362,7 +365,7 @@ def test_managed_jobs_pipeline_recovery_gcp(): # separated by `-`. (f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | ' f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'), - smoke_tests_utils.zJOB_WAIT_NOT_RUNNING.format(job_name=name), + smoke_tests_utils.JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{smoke_tests_utils.GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', smoke_tests_utils. 
get_cmd_wait_until_managed_job_status_contains_matching_job_name( @@ -386,6 +389,7 @@ def test_managed_jobs_pipeline_recovery_gcp(): @pytest.mark.no_scp # SCP does not support spot instances @pytest.mark.no_paperspace # Paperspace does not support spot instances @pytest.mark.no_kubernetes # Kubernetes does not have a notion of spot instances +@pytest.mark.no_do # DO does not have spot instances @pytest.mark.managed_jobs def test_managed_jobs_recovery_default_resources(generic_cloud: str): """Test managed job recovery for default resources.""" @@ -657,6 +661,7 @@ def test_managed_jobs_cancellation_gcp(): @pytest.mark.no_ibm # IBM Cloud does not support spot instances @pytest.mark.no_paperspace # Paperspace does not support spot instances @pytest.mark.no_scp # SCP does not support spot instances +@pytest.mark.no_do # DO does not support spot instances @pytest.mark.managed_jobs def test_managed_jobs_storage(generic_cloud: str): """Test storage with managed job""" @@ -693,7 +698,7 @@ def test_managed_jobs_storage(generic_cloud: str): storage_lib.StoreType.GCS, output_storage_name, 'output.txt') output_check_cmd = f'{gcs_check_file_count} | grep 1' elif generic_cloud == 'azure': - region = 'westus2' + region = 'centralus' region_flag = f' --region {region}' storage_account_name = ( storage_lib.AzureBlobStore.get_default_storage_account_name(region)) @@ -742,14 +747,70 @@ def test_managed_jobs_storage(generic_cloud: str): # Check if file was written to the mounted output bucket output_check_cmd ], - (f'sky jobs cancel -y -n {name}', - f'; sky storage delete {output_storage_name} || true'), + (f'sky jobs cancel -y -n {name}' + f'; sky storage delete {output_storage_name} -y || true'), # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=20 * 60, ) smoke_tests_utils.run_one_test(test) +@pytest.mark.aws +def test_managed_jobs_intermediate_storage(generic_cloud: str): + """Test storage with managed job""" + name = smoke_tests_utils.get_cluster_name() + yaml_str = pathlib.Path( + 'examples/managed_job_with_storage.yaml').read_text() + timestamp = int(time.time()) + storage_name = f'sky-test-{timestamp}' + output_storage_name = f'sky-test-output-{timestamp}' + + yaml_str_user_config = pathlib.Path( + 'tests/test_yamls/use_intermediate_bucket_config.yaml').read_text() + intermediate_storage_name = f'intermediate-smoke-test-{timestamp}' + + yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name) + yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name) + yaml_str_user_config = re.sub(r'bucket-jobs-[\w\d]+', + intermediate_storage_name, + yaml_str_user_config) + + with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_user_config: + f_user_config.write(yaml_str_user_config) + f_user_config.flush() + user_config_path = f_user_config.name + with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_task: + f_task.write(yaml_str) + f_task.flush() + file_path = f_task.name + + test = smoke_tests_utils.Test( + 'managed_jobs_intermediate_storage', + [ + *smoke_tests_utils.STORAGE_SETUP_COMMANDS, + # Verify command fails with correct error - run only once + f'err=$(sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y 2>&1); ret=$?; echo "$err" ; [ $ret -eq 0 ] || ! echo "$err" | grep "StorageBucketCreateError: Jobs bucket \'{intermediate_storage_name}\' does not exist. Please check jobs.bucket configuration in your SkyPilot config." 
> /dev/null && exit 1 || exit 0', + f'aws s3api create-bucket --bucket {intermediate_storage_name}', + f'sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y', + # fail because the bucket does not exist + smoke_tests_utils. + get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS), + # check intermediate bucket exists, it won't be deletd if its user specific + f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{intermediate_storage_name}\')].Name" --output text | wc -l) -eq 1 ]', + ], + (f'sky jobs cancel -y -n {name}' + f'; aws s3 rb s3://{intermediate_storage_name} --force' + f'; sky storage delete {output_storage_name} -y || true'), + env={'SKYPILOT_CONFIG': user_config_path}, + # Increase timeout since sky jobs queue -r can be blocked by other spot tests. + timeout=20 * 60, + ) + smoke_tests_utils.run_one_test(test) + + # ---------- Testing spot TPU ---------- @pytest.mark.gcp @pytest.mark.managed_jobs @@ -810,3 +871,26 @@ def test_managed_jobs_inline_env(generic_cloud: str): timeout=20 * 60, ) smoke_tests_utils.run_one_test(test) + + +@pytest.mark.managed_jobs +def test_managed_jobs_logs_sync_down(): + name = smoke_tests_utils.get_cluster_name() + test = smoke_tests_utils.Test( + 'test-managed-jobs-logs-sync-down', + [ + f'sky jobs launch -n {name} -y examples/managed_job.yaml -d', + smoke_tests_utils. + get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS), + f'sky jobs logs --controller 1 --sync-down', + f'sky jobs logs 1 --sync-down', + f'sky jobs logs --controller --name minimal --sync-down', + f'sky jobs logs --name minimal --sync-down', + ], + f'sky jobs cancel -y -n {name}', + timeout=20 * 60, + ) + smoke_tests_utils.run_one_test(test) diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py index 93a5f22c274..3f2ddb16c57 100644 --- a/tests/smoke_tests/test_mount_and_storage.py +++ b/tests/smoke_tests/test_mount_and_storage.py @@ -19,6 +19,7 @@ # Change cloud for generic tests to aws # > pytest tests/smoke_tests/test_mount_and_storage.py --generic-cloud aws +import json import os import pathlib import shlex @@ -37,6 +38,7 @@ import sky from sky import global_user_state from sky import skypilot_config +from sky.adaptors import azure from sky.adaptors import cloudflare from sky.adaptors import ibm from sky.data import data_utils @@ -85,6 +87,23 @@ def test_scp_file_mounts(): smoke_tests_utils.run_one_test(test) +@pytest.mark.oci # For OCI object storage mounts and file mounts. +def test_oci_mounts(): + name = smoke_tests_utils.get_cluster_name() + test_commands = [ + *smoke_tests_utils.STORAGE_SETUP_COMMANDS, + f'sky launch -y -c {name} --cloud oci --num-nodes 2 examples/oci/oci-mounts.yaml', + f'sky logs {name} 1 --status', # Ensure the job succeeded. 
+ ] + test = smoke_tests_utils.Test( + 'oci_mounts', + test_commands, + f'sky down -y {name}', + timeout=20 * 60, # 20 mins + ) + smoke_tests_utils.run_one_test(test) + + @pytest.mark.no_fluidstack # Requires GCP to be enabled def test_using_file_mounts_with_env_vars(generic_cloud: str): name = smoke_tests_utils.get_cluster_name() @@ -612,21 +631,69 @@ def cli_delete_cmd(store_type, bucket_name, Rclone.RcloneClouds.IBM) return f'rclone purge {bucket_rclone_profile}:{bucket_name} && rclone config delete {bucket_rclone_profile}' + @classmethod + def list_all_files(cls, store_type, bucket_name): + cmd = cls.cli_ls_cmd(store_type, bucket_name, recursive=True) + if store_type == storage_lib.StoreType.GCS: + try: + out = subprocess.check_output(cmd, + shell=True, + stderr=subprocess.PIPE) + files = [line[5:] for line in out.decode('utf-8').splitlines()] + except subprocess.CalledProcessError as e: + error_output = e.stderr.decode('utf-8') + if "One or more URLs matched no objects" in error_output: + files = [] + else: + raise + elif store_type == storage_lib.StoreType.AZURE: + out = subprocess.check_output(cmd, shell=True) + try: + blobs = json.loads(out.decode('utf-8')) + files = [blob['name'] for blob in blobs] + except json.JSONDecodeError: + files = [] + elif store_type == storage_lib.StoreType.IBM: + # rclone ls format: " 1234 path/to/file" + out = subprocess.check_output(cmd, shell=True) + files = [] + for line in out.decode('utf-8').splitlines(): + # Skip empty lines + if not line.strip(): + continue + # Split by whitespace and get the file path (last column) + parts = line.strip().split( + None, 1) # Split into max 2 parts (size and path) + if len(parts) == 2: + files.append(parts[1]) + else: + out = subprocess.check_output(cmd, shell=True) + files = [ + line.split()[-1] for line in out.decode('utf-8').splitlines() + ] + return files + @staticmethod - def cli_ls_cmd(store_type, bucket_name, suffix=''): + def cli_ls_cmd(store_type, bucket_name, suffix='', recursive=False): if store_type == storage_lib.StoreType.S3: if suffix: url = f's3://{bucket_name}/{suffix}' else: url = f's3://{bucket_name}' - return f'aws s3 ls {url}' + cmd = f'aws s3 ls {url}' + if recursive: + cmd += ' --recursive' + return cmd if store_type == storage_lib.StoreType.GCS: if suffix: url = f'gs://{bucket_name}/{suffix}' else: url = f'gs://{bucket_name}' + if recursive: + url = f'"{url}/**"' return f'gsutil ls {url}' if store_type == storage_lib.StoreType.AZURE: + # azure isrecursive by default default_region = 'eastus' config_storage_account = skypilot_config.get_nested( ('azure', 'storage_account'), None) @@ -648,8 +715,10 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''): url = f's3://{bucket_name}/{suffix}' else: url = f's3://{bucket_name}' - return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2' + recursive_flag = '--recursive' if recursive else '' + return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2 {recursive_flag}' if store_type == storage_lib.StoreType.IBM: + # rclone ls is recursive by default bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name( bucket_name, Rclone.RcloneClouds.IBM) return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}' @@ -747,6 +816,12 @@ def tmp_source(self, tmp_path): circle_link.symlink_to(tmp_dir, target_is_directory=True) yield str(tmp_dir) + @pytest.fixture + def tmp_sub_path(self): + tmp_dir1 = 
uuid.uuid4().hex[:8] + tmp_dir2 = uuid.uuid4().hex[:8] + yield "/".join([tmp_dir1, tmp_dir2]) + @staticmethod def generate_bucket_name(): # Creates a temporary bucket name @@ -766,13 +841,15 @@ def yield_storage_object( stores: Optional[Dict[storage_lib.StoreType, storage_lib.AbstractStore]] = None, persistent: Optional[bool] = True, - mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT): + mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT, + _bucket_sub_path: Optional[str] = None): # Creates a temporary storage object. Stores must be added in the test. storage_obj = storage_lib.Storage(name=name, source=source, stores=stores, persistent=persistent, - mode=mode) + mode=mode, + _bucket_sub_path=_bucket_sub_path) yield storage_obj handle = global_user_state.get_handle_from_storage_name( storage_obj.name) @@ -839,6 +916,15 @@ def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source): yield from self.yield_storage_object(name=tmp_bucket_name, source=tmp_source) + @pytest.fixture + def tmp_local_storage_obj_with_sub_path(self, tmp_bucket_name, tmp_source, + tmp_sub_path): + # Creates a temporary storage object with sub. Stores must be added in the test. + list_source = [tmp_source, tmp_source + '/tmp-file'] + yield from self.yield_storage_object(name=tmp_bucket_name, + source=list_source, + _bucket_sub_path=tmp_sub_path) + @pytest.fixture def tmp_local_list_storage_obj(self, tmp_bucket_name, tmp_source): # Creates a temp storage object which uses a list of paths as source. @@ -997,6 +1083,59 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj, out = subprocess.check_output(['sky', 'storage', 'ls']) assert tmp_local_storage_obj.name not in out.decode('utf-8') + @pytest.mark.no_fluidstack + @pytest.mark.parametrize('store_type', [ + pytest.param(storage_lib.StoreType.S3, marks=pytest.mark.aws), + pytest.param(storage_lib.StoreType.GCS, marks=pytest.mark.gcp), + pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure), + pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm), + pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare) + ]) + def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path, + store_type): + # Creates a new bucket with a local source, uploads files to it + # and deletes it. 
+        tmp_local_storage_obj_with_sub_path.add_store(store_type)
+
+        # Check that every file listed in the bucket falls under the sub path.
+        files = self.list_all_files(store_type,
+                                    tmp_local_storage_obj_with_sub_path.name)
+        assert len(files) > 0
+        if store_type == storage_lib.StoreType.GCS:
+            assert all([
+                file.startswith(
+                    tmp_local_storage_obj_with_sub_path.name + '/' +
+                    tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+                for file in files
+            ])
+        else:
+            assert all([
+                file.startswith(
+                    tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+                for file in files
+            ])
+
+        # Delete the store: all files under the sub path should be removed,
+        # leaving the bucket empty.
+        store = tmp_local_storage_obj_with_sub_path.stores[store_type]
+        store.is_sky_managed = False
+        if store_type == storage_lib.StoreType.AZURE:
+            azure.assign_storage_account_iam_role(
+                storage_account_name=store.storage_account_name,
+                resource_group_name=store.resource_group_name)
+        store.delete()
+        files = self.list_all_files(store_type,
+                                    tmp_local_storage_obj_with_sub_path.name)
+        assert len(files) == 0
+
+        # Now, delete the entire bucket
+        store.is_sky_managed = True
+        tmp_local_storage_obj_with_sub_path.delete()
+
+        # Run sky storage ls to check that the storage object is deleted
+        out = subprocess.check_output(['sky', 'storage', 'ls'])
+        assert tmp_local_storage_obj_with_sub_path.name not in out.decode(
+            'utf-8')
+
     @pytest.mark.no_fluidstack
     @pytest.mark.xdist_group('multiple_bucket_deletion')
     @pytest.mark.parametrize('store_type', [
@@ -1466,8 +1605,8 @@ def test_aws_regions(self, tmp_local_storage_obj, region):
         'europe-west8', 'europe-west9', 'europe-west10', 'europe-west12',
         'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2',
         'asia-northeast3', 'asia-southeast1', 'asia-south1', 'asia-south2',
-        'asia-southeast2', 'me-central1', 'me-central2', 'me-west1',
-        'australia-southeast1', 'australia-southeast2', 'africa-south1'
+        'asia-southeast2', 'me-central1', 'me-west1', 'australia-southeast1',
+        'australia-southeast2', 'africa-south1'
     ])
     def test_gcs_regions(self, tmp_local_storage_obj, region):
         # This tests creation and upload to bucket in all GCS regions
diff --git a/tests/smoke_tests/test_sky_serve.py b/tests/smoke_tests/test_sky_serve.py
index 5f34eba8728..3ba36d8a092 100644
--- a/tests/smoke_tests/test_sky_serve.py
+++ b/tests/smoke_tests/test_sky_serve.py
@@ -25,7 +25,7 @@
 import inspect
 import json
 import shlex
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import pytest
 from smoke_tests import smoke_tests_utils
@@ -180,6 +180,7 @@ def test_skyserve_azure_http():

 @pytest.mark.kubernetes
 @pytest.mark.serve
+@pytest.mark.requires_gke
 def test_skyserve_kubernetes_http():
     """Test skyserve on Kubernetes"""
     name = _get_service_name()
@@ -197,9 +198,11 @@ def test_skyserve_oci_http():


 @pytest.mark.no_fluidstack  # Fluidstack does not support T4 gpus for now
+@pytest.mark.parametrize('accelerator', [{'do': 'H100'}])
 @pytest.mark.serve
-def test_skyserve_llm(generic_cloud: str):
+def test_skyserve_llm(generic_cloud: str, accelerator: Dict[str, str]):
     """Test skyserve with real LLM usecase"""
+    accelerator = accelerator.get(generic_cloud, 'T4')
     name = _get_service_name()

     def generate_llm_test_command(prompt: str, expected_output: str) -> str:
@@ -217,7 +220,7 @@ def generate_llm_test_command(prompt: str, expected_output: str) -> str:
     test = smoke_tests_utils.Test(
         f'test-skyserve-llm',
         [
-            f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/llm/service.yaml',
+            f'sky serve up -n {name} --cloud {generic_cloud} --gpus {accelerator} -y tests/skyserve/llm/service.yaml',
             _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
             *[
                 generate_llm_test_command(prompt, output)
@@ -257,6 +260,7 @@ def test_skyserve_spot_recovery():
 @pytest.mark.no_fluidstack  # Fluidstack does not support spot instances
 @pytest.mark.serve
 @pytest.mark.no_kubernetes
+@pytest.mark.no_do
 def test_skyserve_base_ondemand_fallback(generic_cloud: str):
     name = _get_service_name()
     test = smoke_tests_utils.Test(
@@ -321,6 +325,7 @@ def test_skyserve_dynamic_ondemand_fallback():

 # TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
 @pytest.mark.no_fluidstack
+@pytest.mark.no_do  # DO does not support `--cpus 2`
 @pytest.mark.serve
 def test_skyserve_user_bug_restart(generic_cloud: str):
     """Tests that we restart the service after user bug."""
@@ -507,6 +512,7 @@ def test_skyserve_large_readiness_timeout(generic_cloud: str):

 # TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
 @pytest.mark.no_fluidstack
+@pytest.mark.no_do  # DO does not support `--cpus 2`
 @pytest.mark.serve
 def test_skyserve_update(generic_cloud: str):
     """Test skyserve with update"""
@@ -537,6 +543,7 @@ def test_skyserve_update(generic_cloud: str):

 # TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
 @pytest.mark.no_fluidstack
+@pytest.mark.no_do  # DO does not support `--cpus 2`
 @pytest.mark.serve
 def test_skyserve_rolling_update(generic_cloud: str):
     """Test skyserve with rolling update"""
@@ -654,6 +661,7 @@ def test_skyserve_update_autoscale(generic_cloud: str):
 @pytest.mark.no_fluidstack  # Spot instances are not supported by Fluidstack
 @pytest.mark.serve
 @pytest.mark.no_kubernetes  # Spot instances are not supported in Kubernetes
+@pytest.mark.no_do  # Spot instances are not supported on DO
 @pytest.mark.parametrize('mode', ['rolling', 'blue_green'])
 def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str):
     """Test skyserve with update that changes autoscaler"""
@@ -717,6 +725,7 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str):

 # TODO: fluidstack does not support `--cpus 2`, but the check for services in this test is based on CPUs
 @pytest.mark.no_fluidstack
+@pytest.mark.no_do  # DO does not support `--cpus 2`
 @pytest.mark.serve
 def test_skyserve_failures(generic_cloud: str):
     """Test replica failure statuses"""
diff --git a/tests/test_config.py b/tests/test_config.py
index 5789214dc61..d3eaeb261bc 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -7,6 +7,7 @@

 import sky
 from sky import skypilot_config
+import sky.exceptions
 from sky.skylet import constants
 from sky.utils import common_utils
 from sky.utils import kubernetes_enums
@@ -99,6 +100,29 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None:
         """))

+
+def _create_invalid_config_yaml_file(task_file_path: pathlib.Path) -> None:
+    task_file_path.write_text(
+        textwrap.dedent("""\
+        experimental:
+          config_overrides:
+            kubernetes:
+              pod_config:
+                metadata:
+                  labels:
+                    test-key: test-value
+                  annotations:
+                    abc: def
+                spec:
+                  containers:
+                    - name:
+                  imagePullSecrets:
+                    - name: my-secret-2
+
+        setup: echo 'Setting up...'
+        run: echo 'Running...'
+ """)) + + def test_nested_config(monkeypatch) -> None: """Test that the nested config works.""" config = skypilot_config.Config() @@ -335,6 +359,28 @@ def test_k8s_config_with_override(monkeypatch, tmp_path, assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia' +def test_k8s_config_with_invalid_config(monkeypatch, tmp_path, + enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + task_path = tmp_path / 'task.yaml' + _create_invalid_config_yaml_file(task_path) + task = sky.Task.from_yaml(task_path) + + # Test Kubernetes pod_config invalid + cluster_name = 'test_k8s_config_with_invalid_config' + task.set_resources_override({'cloud': sky.Kubernetes()}) + exception_occurred = False + try: + sky.launch(task, cluster_name=cluster_name, dryrun=True) + except sky.exceptions.ResourcesUnavailableError: + exception_occurred = True + assert exception_occurred + + def test_gcp_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index 7d304b60633..a9fad1b4b83 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -96,8 +96,8 @@ def test_empty_fields_storage(tmp_path): storage = task.storage_mounts['/mystorage'] assert storage.name == 'sky-dataset' assert storage.source is None - assert len(storage.stores) == 0 - assert storage.persistent is True + assert not storage.stores + assert storage.persistent def test_invalid_fields_storage(tmp_path): diff --git a/tests/test_yamls/intermediate_bucket.yaml b/tests/test_yamls/intermediate_bucket.yaml new file mode 100644 index 00000000000..fe9aafd0675 --- /dev/null +++ b/tests/test_yamls/intermediate_bucket.yaml @@ -0,0 +1,21 @@ +name: intermediate-bucket + +file_mounts: + /setup.py: ./setup.py + /sky: . + /train-00001-of-01024: gs://cloud-tpu-test-datasets/fake_imagenet/train-00001-of-01024 + +workdir: . + + +setup: | + echo "running setup" + +run: | + echo "listing workdir" + ls . + echo "listing file_mounts" + ls /setup.py + ls /sky + ls /train-00001-of-01024 + echo "task run finish" diff --git a/tests/test_yamls/use_intermediate_bucket_config.yaml b/tests/test_yamls/use_intermediate_bucket_config.yaml new file mode 100644 index 00000000000..cdfb5fbabc1 --- /dev/null +++ b/tests/test_yamls/use_intermediate_bucket_config.yaml @@ -0,0 +1,2 @@ +jobs: + bucket: "s3://bucket-jobs-s3" diff --git a/tests/unit_tests/kubernetes/test_gpu_label_formatters.py b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py new file mode 100644 index 00000000000..cd7337dc7a1 --- /dev/null +++ b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py @@ -0,0 +1,22 @@ +"""Tests for GPU label formatting in Kubernetes integration. + +Tests verify correct GPU detection from Kubernetes labels. 
+""" +import pytest + +from sky.provision.kubernetes.utils import GFDLabelFormatter + + +def test_gfd_label_formatter(): + """Test word boundary regex matching in GFDLabelFormatter.""" + # Test various GPU name patterns + test_cases = [ + ('NVIDIA-L4-24GB', 'L4'), + ('NVIDIA-L40-48GB', 'L40'), + ('NVIDIA-L400', 'L400'), # Should not match L4 or L40 + ('NVIDIA-L4', 'L4'), + ('L40-GPU', 'L40'), + ] + for input_value, expected in test_cases: + result = GFDLabelFormatter.get_accelerator_from_label_value(input_value) + assert result == expected, f'Failed for {input_value}' diff --git a/tests/unit_tests/test_storage_utils.py b/tests/unit_tests/test_storage_utils.py index cd1e436390b..6edb5abf2f5 100644 --- a/tests/unit_tests/test_storage_utils.py +++ b/tests/unit_tests/test_storage_utils.py @@ -7,7 +7,7 @@ def test_get_excluded_files_from_skyignore_no_file(): excluded_files = storage_utils.get_excluded_files_from_skyignore('.') - assert len(excluded_files) == 0 + assert not excluded_files def test_get_excluded_files_from_skyignore():