diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js index 3f7689e4d6d..5ae47b7b7be 100644 --- a/docs/source/_static/custom.js +++ b/docs/source/_static/custom.js @@ -7,7 +7,7 @@ document.addEventListener('DOMContentLoaded', function () { script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4'); script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).'); script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.'); - script.setAttribute('data-button-position-bottom', '85px'); + script.setAttribute('data-button-position-bottom', '100px'); script.async = true; document.head.appendChild(script); }); diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst index 018a993f588..61c33b5c43e 100644 --- a/docs/source/examples/managed-jobs.rst +++ b/docs/source/examples/managed-jobs.rst @@ -78,9 +78,9 @@ We can launch it with the following: .. code-block:: console + $ git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 $ sky jobs launch -n bert-qa bert_qa.yaml - .. code-block:: yaml # bert_qa.yaml @@ -88,39 +88,37 @@ We can launch it with the following: resources: accelerators: V100:1 - # Use spot instances to save cost. - use_spot: true - - # Assume your working directory is under `~/transformers`. - # To make this example work, please run the following command: - # git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 - workdir: ~/transformers + use_spot: true # Use spot instances to save cost. - setup: | + envs: # Fill in your wandb key: copy from https://wandb.ai/authorize # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY` # to pass the key in the command line, during `sky jobs launch`. - echo export WANDB_API_KEY=[YOUR-WANDB-API-KEY] >> ~/.bashrc + WANDB_API_KEY: + + # Assume your working directory is under `~/transformers`. + workdir: ~/transformers + setup: | pip install -e . cd examples/pytorch/question-answering/ pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install wandb run: | - cd ./examples/pytorch/question-answering/ + cd examples/pytorch/question-answering/ python run_qa.py \ - --model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --do_eval \ - --per_device_train_batch_size 12 \ - --learning_rate 3e-5 \ - --num_train_epochs 50 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --report_to wandb - + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 50 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --report_to wandb \ + --output_dir /tmp/bert_qa/ .. note:: @@ -162,55 +160,52 @@ An End-to-End Example Below we show an `example `_ for fine-tuning a BERT model on a question-answering task with HuggingFace. .. code-block:: yaml - :emphasize-lines: 13-16,42-45 + :emphasize-lines: 8-11,41-44 # bert_qa.yaml name: bert-qa resources: accelerators: V100:1 - use_spot: true - - # Assume your working directory is under `~/transformers`. 
- # To make this example work, please run the following command: - # git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 - workdir: ~/transformers + use_spot: true # Use spot instances to save cost. file_mounts: /checkpoint: name: # NOTE: Fill in your bucket name mode: MOUNT - setup: | + envs: # Fill in your wandb key: copy from https://wandb.ai/authorize # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY` # to pass the key in the command line, during `sky jobs launch`. - echo export WANDB_API_KEY=[YOUR-WANDB-API-KEY] >> ~/.bashrc + WANDB_API_KEY: + + # Assume your working directory is under `~/transformers`. + workdir: ~/transformers + setup: | pip install -e . cd examples/pytorch/question-answering/ - pip install -r requirements.txt + pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install wandb run: | - cd ./examples/pytorch/question-answering/ + cd examples/pytorch/question-answering/ python run_qa.py \ - --model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --do_eval \ - --per_device_train_batch_size 12 \ - --learning_rate 3e-5 \ - --num_train_epochs 50 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --report_to wandb \ - --run_name $SKYPILOT_TASK_ID \ - --output_dir /checkpoint/bert_qa/ \ - --save_total_limit 10 \ - --save_steps 1000 - - + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 50 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --report_to wandb \ + --output_dir /checkpoint/bert_qa/ \ + --run_name $SKYPILOT_TASK_ID \ + --save_total_limit 10 \ + --save_steps 1000 As HuggingFace has built-in support for periodically checkpointing, we only need to pass the highlighted arguments for setting up the output directory and frequency of checkpointing (see more diff --git a/docs/source/running-jobs/environment-variables.rst b/docs/source/running-jobs/environment-variables.rst index f7138af95fa..d88424359d1 100644 --- a/docs/source/running-jobs/environment-variables.rst +++ b/docs/source/running-jobs/environment-variables.rst @@ -16,7 +16,7 @@ User-specified environment variables User-specified environment variables are useful for passing secrets and any arguments or configurations needed for your tasks. They are made available in ``file_mounts``, ``setup``, and ``run``. -You can specify environment variables to be made available to a task in two ways: +You can specify environment variables to be made available to a task in several ways: - ``envs`` field (dict) in a :ref:`task YAML `: @@ -24,7 +24,18 @@ You can specify environment variables to be made available to a task in two ways envs: MYVAR: val - + + +- ``--env-file`` flag in ``sky launch/exec`` :ref:`CLI `, which is a path to a `dotenv` file (takes precedence over the above): + + .. code-block:: text + + # sky launch example.yaml --env-file my_app.env + # cat my_app.env + MYVAR=val + WANDB_API_KEY=MY_WANDB_API_KEY + HF_TOKEN=MY_HF_TOKEN + - ``--env`` flag in ``sky launch/exec`` :ref:`CLI ` (takes precedence over the above) .. tip:: @@ -145,9 +156,9 @@ Environment variables for ``setup`` - 0 * - ``SKYPILOT_SETUP_NODE_IPS`` - A string of IP addresses of the nodes in the cluster with the same order as the node ranks, where each line contains one IP address. 
- + Note that this is not necessarily the same as the nodes in ``run`` stage: the ``setup`` stage runs on all nodes of the cluster, while the ``run`` stage can run on a subset of nodes. - - + - .. code-block:: text 1.2.3.4 @@ -158,19 +169,19 @@ Environment variables for ``setup`` - 2 * - ``SKYPILOT_TASK_ID`` - A unique ID assigned to each task. - - This environment variable is available only when the task is submitted + + This environment variable is available only when the task is submitted with :code:`sky launch --detach-setup`, or run as a managed spot job. - + Refer to the description in the :ref:`environment variables for run `. - sky-2023-07-06-21-18-31-563597_myclus_1 - + For managed spot jobs: sky-managed-2023-07-06-21-18-31-563597_my-job-name_1-0 * - ``SKYPILOT_CLUSTER_INFO`` - A JSON string containing information about the cluster. To access the information, you could parse the JSON string in bash ``echo $SKYPILOT_CLUSTER_INFO | jq .cloud`` or in Python : .. code-block:: python - + import json json.loads( os.environ['SKYPILOT_CLUSTER_INFO'] @@ -200,7 +211,7 @@ Environment variables for ``run`` - 0 * - ``SKYPILOT_NODE_IPS`` - A string of IP addresses of the nodes reserved to execute the task, where each line contains one IP address. Read more :ref:`here `. - - + - .. code-block:: text 1.2.3.4 @@ -221,13 +232,13 @@ Environment variables for ``run`` If a task is run as a :ref:`managed spot job `, then all recoveries of that job will have the same ID value. The ID is in the format "sky-managed-_(_)_-", where ```` will appear when a pipeline is used, i.e., more than one task in a managed spot job. Read more :ref:`here `. - sky-2023-07-06-21-18-31-563597_myclus_1 - + For managed spot jobs: sky-managed-2023-07-06-21-18-31-563597_my-job-name_1-0 * - ``SKYPILOT_CLUSTER_INFO`` - A JSON string containing information about the cluster. To access the information, you could parse the JSON string in bash ``echo $SKYPILOT_CLUSTER_INFO | jq .cloud`` or in Python : .. code-block:: python - + import json json.loads( os.environ['SKYPILOT_CLUSTER_INFO'] diff --git a/examples/oci/serve-http-cpu.yaml b/examples/oci/serve-http-cpu.yaml new file mode 100644 index 00000000000..68e3d18c9e5 --- /dev/null +++ b/examples/oci/serve-http-cpu.yaml @@ -0,0 +1,11 @@ +service: + readiness_probe: / + replicas: 2 + +resources: + cloud: oci + region: us-sanjose-1 + ports: 8080 + cpus: 2+ + +run: python -m http.server 8080 diff --git a/examples/oci/serve-qwen-7b.yaml b/examples/oci/serve-qwen-7b.yaml new file mode 100644 index 00000000000..799e5a7d891 --- /dev/null +++ b/examples/oci/serve-qwen-7b.yaml @@ -0,0 +1,25 @@ +# service.yaml +service: + readiness_probe: /v1/models + replicas: 2 + +# Fields below describe each replica. 
+resources: + cloud: oci + region: us-sanjose-1 + ports: 8080 + accelerators: {A10:1} + +setup: | + conda create -n vllm python=3.12 -y + conda activate vllm + pip install vllm + pip install vllm-flash-attn + +run: | + conda activate vllm + python -u -m vllm.entrypoints.openai.api_server \ + --host 0.0.0.0 --port 8080 \ + --model Qwen/Qwen2-7B-Instruct \ + --served-model-name Qwen2-7B-Instruct \ + --device=cuda --dtype auto --max-model-len=2048 diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 9d797609571..8daeedc6a96 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -100,6 +100,10 @@ CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock') CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20 +# Time that must elapse since the last status check before we should re-check if +# the cluster has been terminated or autostopped. +_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2 + # Filelocks for updating cluster's file_mounts. CLUSTER_FILE_MOUNTS_LOCK_PATH = os.path.expanduser( '~/.sky/.{}_file_mounts.lock') @@ -1669,11 +1673,27 @@ def check_can_clone_disk_and_override_task( def _update_cluster_status_no_lock( cluster_name: str) -> Optional[Dict[str, Any]]: - """Updates the status of the cluster. + """Update the cluster status. + + The cluster status is updated by checking ray cluster and real status from + cloud. + + The function will update the cached cluster status in the global state. For + the design of the cluster status and transition, please refer to the + sky/design_docs/cluster_status.md + + Returns: + If the cluster is terminated or does not exist, return None. Otherwise + returns the input record with status and handle potentially updated. Raises: + exceptions.ClusterOwnerIdentityMismatchError: if the current user is + not the same as the user who created the cluster. + exceptions.CloudUserIdentityError: if we fail to get the current user + identity. exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider. + fetched from the cloud provider or there are leaked nodes causing + the node number larger than expected. """ record = global_user_state.get_cluster_from_name(cluster_name) if record is None: @@ -1893,52 +1913,22 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: return global_user_state.get_cluster_from_name(cluster_name) -def _update_cluster_status( - cluster_name: str, - acquire_per_cluster_status_lock: bool, - cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS -) -> Optional[Dict[str, Any]]: - """Update the cluster status. +def _must_refresh_cluster_status( + record: Dict[str, Any], + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] +) -> bool: + force_refresh_for_cluster = (force_refresh_statuses is not None and + record['status'] in force_refresh_statuses) - The cluster status is updated by checking ray cluster and real status from - cloud. + use_spot = record['handle'].launched_resources.use_spot + has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and + record['autostop'] >= 0) + recently_refreshed = (record['status_updated_at'] is not None and + time.time() - record['status_updated_at'] < + _CLUSTER_STATUS_CACHE_DURATION_SECONDS) + is_stale = (use_spot or has_autostop) and not recently_refreshed - The function will update the cached cluster status in the global state. 
For - the design of the cluster status and transition, please refer to the - sky/design_docs/cluster_status.md - - Args: - cluster_name: The name of the cluster. - acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. - cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. - - Returns: - If the cluster is terminated or does not exist, return None. Otherwise - returns the input record with status and handle potentially updated. - - Raises: - exceptions.ClusterOwnerIdentityMismatchError: if the current user is - not the same as the user who created the cluster. - exceptions.CloudUserIdentityError: if we fail to get the current user - identity. - exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider or there are leaked nodes causing - the node number larger than expected. - """ - if not acquire_per_cluster_status_lock: - return _update_cluster_status_no_lock(cluster_name) - - try: - with filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout): - return _update_cluster_status_no_lock(cluster_name) - except filelock.Timeout: - logger.debug('Refreshing status: Failed get the lock for cluster ' - f'{cluster_name!r}. Using the cached status.') - record = global_user_state.get_cluster_from_name(cluster_name) - return record + return force_refresh_for_cluster or is_stale def refresh_cluster_record( @@ -1956,16 +1946,22 @@ def refresh_cluster_record( Args: cluster_name: The name of the cluster. - force_refresh_statuses: if specified, refresh the cluster if it has one of - the specified statuses. Additionally, clusters satisfying the - following conditions will always be refreshed no matter the - argument is specified or not: - 1. is a spot cluster, or - 2. is a non-spot cluster, is not STOPPED, and autostop is set. + force_refresh_statuses: if specified, refresh the cluster if it has one + of the specified statuses. Additionally, clusters satisfying the + following conditions will be refreshed no matter the argument is + specified or not: + - the most latest available status update is more than + _CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of: + 1. the cluster is a spot cluster, or + 2. cluster autostop is set and the cluster is not STOPPED. acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. + before updating the status. Even if this is True, the lock may not be + acquired if the status does not need to be refreshed. cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. If timeout, the function will use the cached status. + lock. If timeout, the function will use the cached status. If the + value is <0, do not timeout (wait for the lock indefinitely). By + default, this is set to CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS. Warning: + if correctness is required, you must set this to -1. Returns: If the cluster is terminated or does not exist, return None. 
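As a quick illustration of how these refresh knobs compose (a sketch, not part of the patch; ``cluster_name`` is a placeholder), a caller that needs a guaranteed-fresh status can force a refresh for INIT clusters and wait on the per-cluster lock indefinitely, mirroring the ``--fast`` path added in ``sky/execution.py`` further below:

    # Sketch: force a fresh status for an INIT cluster and block on the
    # per-cluster status lock (timeout=-1 means wait indefinitely).
    from sky import status_lib
    from sky.backends import backend_utils

    status, handle = backend_utils.refresh_cluster_status_handle(
        cluster_name,  # placeholder: name of an existing cluster
        force_refresh_statuses=[status_lib.ClusterStatus.INIT],
        cluster_status_lock_timeout=-1,
    )
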
@@ -1986,19 +1982,58 @@ def refresh_cluster_record( return None check_owner_identity(cluster_name) - handle = record['handle'] - if isinstance(handle, backends.CloudVmRayResourceHandle): - use_spot = handle.launched_resources.use_spot - has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and - record['autostop'] >= 0) - force_refresh_for_cluster = (force_refresh_statuses is not None and - record['status'] in force_refresh_statuses) - if force_refresh_for_cluster or has_autostop or use_spot: - record = _update_cluster_status( - cluster_name, - acquire_per_cluster_status_lock=acquire_per_cluster_status_lock, - cluster_status_lock_timeout=cluster_status_lock_timeout) - return record + if not isinstance(record['handle'], backends.CloudVmRayResourceHandle): + return record + + # The loop logic allows us to notice if the status was updated in the + # global_user_state by another process and stop trying to get the lock. + # The core loop logic is adapted from FileLock's implementation. + lock = filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) + start_time = time.perf_counter() + + # Loop until we have an up-to-date status or until we acquire the lock. + while True: + # Check to see if we can return the cached status. + if not _must_refresh_cluster_status(record, force_refresh_statuses): + return record + + if not acquire_per_cluster_status_lock: + return _update_cluster_status_no_lock(cluster_name) + + # Try to acquire the lock so we can fetch the status. + try: + with lock.acquire(blocking=False): + # Lock acquired. + + # Check the cluster status again, since it could have been + # updated between our last check and acquiring the lock. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None or not _must_refresh_cluster_status( + record, force_refresh_statuses): + return record + + # Update and return the cluster status. + return _update_cluster_status_no_lock(cluster_name) + except filelock.Timeout: + # lock.acquire() will throw a Timeout exception if the lock is not + # available and we have blocking=False. + pass + + # Logic adapted from FileLock.acquire(). + # If cluster_status_lock_time is <0, we will never hit this. No timeout. + # Otherwise, if we have timed out, return the cached status. This has + # the potential to cause correctness issues, but if so it is the + # caller's responsibility to set the timeout to -1. + if 0 <= cluster_status_lock_timeout < time.perf_counter() - start_time: + logger.debug('Refreshing status: Failed get the lock for cluster ' + f'{cluster_name!r}. Using the cached status.') + return record + time.sleep(0.05) + + # Refresh for next loop iteration. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None: + return None @timeline.event @@ -2604,15 +2639,18 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str, pattern = re.compile(r'AttributeError: module \'sky\.(.*)\' has no ' r'attribute \'(.*)\'') if returncode != 0: + # TODO(zhwu): Backward compatibility for old SkyPilot runtime version on + # the remote cluster. Remove this after 0.10.0 is released. attribute_error = re.findall(pattern, stderr) - if attribute_error: + if attribute_error or 'SkyPilot runtime is too old' in stderr: with ux_utils.print_exception_no_traceback(): raise RuntimeError( f'{colorama.Fore.RED}SkyPilot runtime needs to be updated ' - 'on the remote cluster. 
To update, run (existing jobs are ' - f'not interrupted): {colorama.Style.BRIGHT}sky start -f -y ' + f'on the remote cluster: {cluster_name}. To update, run ' + '(existing jobs will not be interrupted): ' + f'{colorama.Style.BRIGHT}sky start -f -y ' f'{cluster_name}{colorama.Style.RESET_ALL}' - f'\n--- Details ---\n{stderr.strip()}\n') + f'\n--- Details ---\n{stderr.strip()}\n') from None def get_endpoints(cluster: str, diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index e338eecb744..d00560ece23 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -276,6 +276,7 @@ def add_prologue(self, job_id: int) -> None: from sky.skylet import constants from sky.skylet import job_lib from sky.utils import log_utils + from sky.utils import subprocess_utils SKY_REMOTE_WORKDIR = {constants.SKY_REMOTE_WORKDIR!r} @@ -3275,14 +3276,13 @@ def _exec_code_on_head( encoded_script = shlex.quote(codegen) create_script_code = (f'{{ echo {encoded_script} > {script_path}; }}') job_submit_cmd = ( - f'RAY_DASHBOARD_PORT=$({constants.SKY_PYTHON_CMD} -c "from sky.skylet import job_lib; print(job_lib.get_job_submission_port())" 2> /dev/null || echo 8265);' # pylint: disable=line-too-long - f'{cd} && {constants.SKY_RAY_CMD} job submit ' - '--address=http://127.0.0.1:$RAY_DASHBOARD_PORT ' - f'--submission-id {job_id}-$(whoami) --no-wait ' - f'"{constants.SKY_PYTHON_CMD} -u {script_path} ' + # JOB_CMD_IDENTIFIER is used for identifying the process retrieved + # with pid is the same driver process. + f'{job_lib.JOB_CMD_IDENTIFIER.format(job_id)} && ' + f'{cd} && {constants.SKY_PYTHON_CMD} -u {script_path}' # Do not use &>, which is not POSIX and may not work. # Note that the order of ">filename 2>&1" matters. - f'> {remote_log_path} 2>&1"') + f'> {remote_log_path} 2>&1') code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd) job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code]) @@ -3330,6 +3330,10 @@ def _dump_code_to_file(codegen: str) -> None: job_submit_cmd, stream_logs=False, require_outputs=True) + # Happens when someone calls `sky exec` but remote is outdated for + # running a job. Necessitating calling `sky launch`. + backend_utils.check_stale_runtime_on_remote(returncode, stderr, + handle.cluster_name) if returncode == 255 and 'too long' in stdout + stderr: # If the generated script is too long, we retry it with dumping # the script to a file and running it with SSH. We use a general @@ -3344,10 +3348,6 @@ def _dump_code_to_file(codegen: str) -> None: stream_logs=False, require_outputs=True) - # Happens when someone calls `sky exec` but remote is outdated - # necessitating calling `sky launch`. - backend_utils.check_stale_runtime_on_remote(returncode, stdout, - handle.cluster_name) subprocess_utils.handle_returncode(returncode, job_submit_cmd, f'Failed to submit job {job_id}.', @@ -3417,6 +3417,10 @@ def _add_job(self, handle: CloudVmRayResourceHandle, stream_logs=False, require_outputs=True, separate_stderr=True) + # Happens when someone calls `sky exec` but remote is outdated for + # adding a job. Necessitating calling `sky launch`. + backend_utils.check_stale_runtime_on_remote(returncode, stderr, + handle.cluster_name) # TODO(zhwu): this sometimes will unexpectedly fail, we can add # retry for this, after we figure out the reason. 
subprocess_utils.handle_returncode(returncode, code, @@ -3554,7 +3558,7 @@ def _teardown(self, backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) try: - with filelock.FileLock( + with timeline.FileLockEvent( lock_path, backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS): self.teardown_no_lock( diff --git a/sky/cli.py b/sky/cli.py index 490749d1231..c49b692add1 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3699,13 +3699,24 @@ def jobs_launch( dag_utils.maybe_infer_and_fill_dag_and_task_names(dag) dag_utils.fill_default_config_in_dag_for_job_launch(dag) - click.secho(f'Managed job {dag.name!r} will be launched on (estimated):', - fg='cyan') dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=False) - dag = sky.optimize(dag) - if not yes: + if yes: + # Skip resource preview if -y is set, since we are probably running in + # a script and the user won't have a chance to review it anyway. + # This can save a couple of seconds. + click.secho( + f'Resources for managed job {dag.name!r} will be computed on the ' + 'managed jobs controller, since --yes is set.', + fg='cyan') + + else: + click.secho( + f'Managed job {dag.name!r} will be launched on (estimated):', + fg='cyan') + dag = sky.optimize(dag) + prompt = f'Launching a managed job {dag.name!r}. Proceed?' if prompt is not None: click.confirm(prompt, default=True, abort=True, show_default=True) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 4a9f2d63f35..22e1039f121 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -663,6 +663,7 @@ def _is_access_key_of_type(type_str: str) -> bool: return AWSIdentityType.SHARED_CREDENTIALS_FILE @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def get_user_identities(cls) -> Optional[List[List[str]]]: """Returns a [UserId, Account] list that uniquely identifies the user. diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 0ebf44b4d0b..37806ff8349 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -75,8 +75,6 @@ def _unsupported_features_for_resources( (f'Docker image is currently not supported on {cls._REPR}. ' 'You can try running docker command inside the ' '`run` section in task.yaml.'), - clouds.CloudImplementationFeatures.OPEN_PORTS: - (f'Opening ports is currently not supported on {cls._REPR}.'), } if resources.use_spot: features[clouds.CloudImplementationFeatures.STOP] = ( diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 918a4070414..bbd48863755 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -20,6 +20,7 @@ from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -100,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/clouds/utils/oci_utils.py b/sky/clouds/utils/oci_utils.py index 9398dff1363..0cd4f33e647 100644 --- a/sky/clouds/utils/oci_utils.py +++ b/sky/clouds/utils/oci_utils.py @@ -4,6 +4,8 @@ - Zhanghao Wu @ Oct 2023: Formatting and refactoring - Hysun He (hysun.he@oracle.com) @ Oct, 2024: Add default image OS configuration. 
+ - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add the constant + SERVICE_PORT_RULE_TAG """ import os @@ -42,6 +44,9 @@ class OCIConfig: VCN_CIDR_INTERNET = '0.0.0.0/0' VCN_CIDR = '192.168.0.0/16' VCN_SUBNET_CIDR = '192.168.0.0/18' + SERVICE_PORT_RULE_TAG = 'SkyServe-Service-Port' + # NSG name template + NSG_NAME_TEMPLATE = 'nsg_{cluster_name}' MAX_RETRY_COUNT = 3 RETRY_INTERVAL_BASE_SECONDS = 5 diff --git a/sky/execution.py b/sky/execution.py index df3cdd5efdb..350a482a418 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -11,10 +11,10 @@ from sky import admin_policy from sky import backends from sky import clouds -from sky import exceptions from sky import global_user_state from sky import optimizer from sky import sky_logging +from sky import status_lib from sky.backends import backend_utils from sky.usage import usage_lib from sky.utils import admin_policy_utils @@ -267,6 +267,12 @@ def _execute( # no-credential machine should not enter optimize(), which # would directly error out ('No cloud is enabled...'). Fix # by moving `sky check` checks out of optimize()? + + controller = controller_utils.Controllers.from_name( + cluster_name) + if controller is not None: + logger.info( + f'Choosing resources for {controller.name}...') dag = sky.optimize(dag, minimize=optimize_target) task = dag.tasks[0] # Keep: dag may have been deep-copied. assert task.best_resources is not None, task @@ -463,28 +469,43 @@ def launch( stages = None # Check if cluster exists and we are doing fast provisioning if fast and cluster_name is not None: - maybe_handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - if maybe_handle is not None: - try: - # This will throw if the cluster is not available - backend_utils.check_cluster_available( + cluster_status, maybe_handle = ( + backend_utils.refresh_cluster_status_handle(cluster_name)) + if cluster_status == status_lib.ClusterStatus.INIT: + # If the cluster is INIT, it may be provisioning. We want to prevent + # concurrent calls from queueing up many sequential reprovision + # attempts. Since provisioning will hold the cluster status lock, we + # wait to hold that lock by force refreshing the status. This will + # block until the cluster finishes provisioning, then correctly see + # that it is UP. + # TODO(cooperc): If multiple processes launched in parallel see that + # the cluster is STOPPED or does not exist, they will still all try + # to provision it, since we do not hold the lock continuously from + # the status check until the provision call. Fixing this requires a + # bigger refactor. + cluster_status, maybe_handle = ( + backend_utils.refresh_cluster_status_handle( cluster_name, - operation='executing tasks', - check_cloud_vm_ray_backend=False, - dryrun=dryrun) - handle = maybe_handle - # Get all stages - stages = [ - Stage.SYNC_WORKDIR, - Stage.SYNC_FILE_MOUNTS, - Stage.PRE_EXEC, - Stage.EXEC, - Stage.DOWN, - ] - except exceptions.ClusterNotUpError: - # Proceed with normal provisioning - pass + force_refresh_statuses=[ + # If the cluster is INIT, we want to try to grab the + # status lock, which should block until provisioning is + # finished. + status_lib.ClusterStatus.INIT, + ], + # Wait indefinitely to obtain the lock, so that we don't + # have multiple processes launching the same cluster at + # once. 
+ cluster_status_lock_timeout=-1, + )) + if cluster_status == status_lib.ClusterStatus.UP: + handle = maybe_handle + stages = [ + Stage.SYNC_WORKDIR, + Stage.SYNC_FILE_MOUNTS, + Stage.PRE_EXEC, + Stage.EXEC, + Stage.DOWN, + ] return _execute( entrypoint=entrypoint, diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 7c040ea55fc..e9f15df4f52 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -60,7 +60,8 @@ def create_table(cursor, conn): owner TEXT DEFAULT null, cluster_hash TEXT DEFAULT null, storage_mounts_metadata BLOB DEFAULT null, - cluster_ever_up INTEGER DEFAULT 0)""") + cluster_ever_up INTEGER DEFAULT 0, + status_updated_at INTEGER DEFAULT null)""") # Table for Cluster History # usage_intervals: List[Tuple[int, int]] @@ -130,6 +131,10 @@ def create_table(cursor, conn): # clusters were never really UP, setting it to 1 means they won't be # auto-deleted during any failover. value_to_replace_existing_entries=1) + + db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at', + 'INTEGER DEFAULT null') + conn.commit() @@ -159,6 +164,7 @@ def add_or_update_cluster(cluster_name: str, status = status_lib.ClusterStatus.INIT if ready: status = status_lib.ClusterStatus.UP + status_updated_at = int(time.time()) # TODO (sumanth): Cluster history table will have multiple entries # when the cluster failover through multiple regions (one entry per region). @@ -191,7 +197,7 @@ def add_or_update_cluster(cluster_name: str, # specified. '(name, launched_at, handle, last_use, status, ' 'autostop, to_down, metadata, owner, cluster_hash, ' - 'storage_mounts_metadata, cluster_ever_up) ' + 'storage_mounts_metadata, cluster_ever_up, status_updated_at) ' 'VALUES (' # name '?, ' @@ -228,7 +234,9 @@ def add_or_update_cluster(cluster_name: str, 'COALESCE(' '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), ' # cluster_ever_up - '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?)' + '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),' + # status_updated_at + '?' ')', ( # name @@ -260,6 +268,8 @@ def add_or_update_cluster(cluster_name: str, # cluster_ever_up cluster_name, int(ready), + # status_updated_at + status_updated_at, )) launched_nodes = getattr(cluster_handle, 'launched_nodes', None) @@ -330,11 +340,13 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: # stopped VM, which leads to timeout. if hasattr(handle, 'stable_internal_external_ips'): handle.stable_internal_external_ips = None + current_time = int(time.time()) _DB.cursor.execute( - 'UPDATE clusters SET handle=(?), status=(?) ' - 'WHERE name=(?)', ( + 'UPDATE clusters SET handle=(?), status=(?), ' + 'status_updated_at=(?) WHERE name=(?)', ( pickle.dumps(handle), status_lib.ClusterStatus.STOPPED.value, + current_time, cluster_name, )) _DB.conn.commit() @@ -359,10 +371,10 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]: def set_cluster_status(cluster_name: str, status: status_lib.ClusterStatus) -> None: - _DB.cursor.execute('UPDATE clusters SET status=(?) WHERE name=(?)', ( - status.value, - cluster_name, - )) + current_time = int(time.time()) + _DB.cursor.execute( + 'UPDATE clusters SET status=(?), status_updated_at=(?) 
WHERE name=(?)', + (status.value, current_time, cluster_name)) count = _DB.cursor.rowcount _DB.conn.commit() assert count <= 1, count @@ -570,15 +582,18 @@ def _load_storage_mounts_metadata( def get_cluster_from_name( cluster_name: Optional[str]) -> Optional[Dict[str, Any]]: - rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)', - (cluster_name,)).fetchall() + rows = _DB.cursor.execute( + 'SELECT name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at FROM clusters WHERE name=(?)', + (cluster_name,)).fetchall() for row in rows: # Explicitly specify the number of fields to unpack, so that # we can add new fields to the database in the future without # breaking the previous code. (name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at) = row[:13] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -594,6 +609,7 @@ def get_cluster_from_name( 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, } return record return None @@ -601,12 +617,15 @@ def get_cluster_from_name( def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( - 'select * from clusters order by launched_at desc').fetchall() + 'select name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at from clusters ' + 'order by launched_at desc').fetchall() records = [] for row in rows: (name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at) = row[:13] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -622,6 +641,7 @@ def get_clusters() -> List[Dict[str, Any]]: 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, } records.append(record) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 5bf3da2d023..f11a556f2d4 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -133,7 +133,6 @@ def launch( controller_task.set_resources(controller_resources) controller_task.managed_job_dag = dag - assert len(controller_task.resources) == 1, controller_task sky_logging.print( f'{colorama.Fore.YELLOW}' diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 896740f6ed6..f82e1132678 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -85,7 +85,8 @@ def get_job_status(backend: 'backends.CloudVmRayBackend', cluster_name: str) -> Optional['job_lib.JobStatus']: """Check the status of the job running on a managed job cluster. - It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_SETUP or CANCELLED. + It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER, + FAILED_SETUP or CANCELLED. 
""" handle = global_user_state.get_handle_from_cluster_name(cluster_name) assert isinstance(handle, backends.CloudVmRayResourceHandle), handle @@ -866,7 +867,7 @@ def stream_logs(cls, code += inspect.getsource(stream_logs) code += textwrap.dedent(f"""\ - msg = stream_logs({job_id!r}, {job_name!r}, + msg = stream_logs({job_id!r}, {job_name!r}, follow={follow}, controller={controller}) print(msg, flush=True) """) @@ -883,7 +884,7 @@ def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str: resources_str = backend_utils.get_task_resources_str( task, is_managed_job=True) code += textwrap.dedent(f"""\ - managed_job_state.set_pending({job_id}, {task_id}, + managed_job_state.set_pending({job_id}, {task_id}, {task.name!r}, {resources_str!r}) """) return cls._build(code) diff --git a/sky/provision/oci/instance.py b/sky/provision/oci/instance.py index e909c9d8fdc..811d27d0e21 100644 --- a/sky/provision/oci/instance.py +++ b/sky/provision/oci/instance.py @@ -2,6 +2,8 @@ History: - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Initial implementation + - Hysun He (hysun.he@oracle.com) @ Nov.13, 2024: Implement open_ports + and cleanup_ports for supporting SkyServe. """ import copy @@ -292,11 +294,11 @@ def open_ports( provider_config: Optional[Dict[str, Any]] = None, ) -> None: """Open ports for inbound traffic.""" - # OCI ports in security groups are opened while creating the new - # VCN (skypilot_vcn). If user configure to use existing VCN, it is - # intended to let user to manage the ports instead of automatically - # opening ports here. - del cluster_name_on_cloud, ports, provider_config + assert provider_config is not None, cluster_name_on_cloud + region = provider_config['region'] + query_helper.create_nsg_rules(region=region, + cluster_name=cluster_name_on_cloud, + ports=ports) @query_utils.debug_enabled(logger) @@ -306,12 +308,11 @@ def cleanup_ports( provider_config: Optional[Dict[str, Any]] = None, ) -> None: """Delete any opened ports.""" - del cluster_name_on_cloud, ports, provider_config - # OCI ports in security groups are opened while creating the new - # VCN (skypilot_vcn). The VCN will only be created at the first - # time when it is not existed. We'll not automatically delete the - # VCN while teardown clusters. it is intended to let user to decide - # to delete the VCN or not from OCI console, for example. + assert provider_config is not None, cluster_name_on_cloud + region = provider_config['region'] + del ports + query_helper.remove_cluster_nsg(region=region, + cluster_name=cluster_name_on_cloud) @query_utils.debug_enabled(logger) diff --git a/sky/provision/oci/query_utils.py b/sky/provision/oci/query_utils.py index 2fbbaf49853..47a0438cb21 100644 --- a/sky/provision/oci/query_utils.py +++ b/sky/provision/oci/query_utils.py @@ -5,6 +5,8 @@ migrated from the old provisioning API. - Hysun He (hysun.he@oracle.com) @ Oct.18, 2024: Enhancement. find_compartment: allow search subtree when find a compartment. 
+ - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add methods to + Add/remove security rules: create_nsg_rules & remove_nsg """ from datetime import datetime import functools @@ -13,12 +15,15 @@ import time import traceback import typing -from typing import Optional +from typing import List, Optional, Tuple +from sky import exceptions from sky import sky_logging from sky.adaptors import common as adaptors_common from sky.adaptors import oci as oci_adaptor from sky.clouds.utils import oci_utils +from sky.provision import constants +from sky.utils import resources_utils if typing.TYPE_CHECKING: import pandas as pd @@ -81,19 +86,33 @@ def query_instances_by_tags(cls, tag_filters, region): return result_set @classmethod + @debug_enabled(logger) def terminate_instances_by_tags(cls, tag_filters, region) -> int: logger.debug(f'Terminate instance by tags: {tag_filters}') + + cluster_name = tag_filters[constants.TAG_RAY_CLUSTER_NAME] + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=False) + + core_client = oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()) + insts = cls.query_instances_by_tags(tag_filters, region) fail_count = 0 for inst in insts: inst_id = inst.identifier - logger.debug(f'Got instance(to be terminated): {inst_id}') + logger.debug(f'Terminating instance {inst_id}') try: - oci_adaptor.get_core_client( - region, - oci_utils.oci_config.get_profile()).terminate_instance( - inst_id) + # Release the NSG reference so that the NSG can be + # deleted without waiting the instance being terminated. + if nsg_id is not None: + cls.detach_nsg(region, inst, nsg_id) + + # Terminate the instance + core_client.terminate_instance(inst_id) + except oci_adaptor.oci.exceptions.ServiceError as e: fail_count += 1 logger.error(f'Terminate instance failed: {str(e)}\n: {inst}') @@ -468,5 +487,192 @@ def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet, logger.error( f'Delete VCN {oci_utils.oci_config.VCN_NAME} Error: {str(e)}') + @classmethod + @debug_enabled(logger) + def find_nsg(cls, region: str, nsg_name: str, + create_if_not_exist: bool) -> Optional[str]: + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + compartment = cls.find_compartment(region) + + list_vcns_resp = net_client.list_vcns( + compartment_id=compartment, + display_name=oci_utils.oci_config.VCN_NAME, + lifecycle_state='AVAILABLE', + ) + + if not list_vcns_resp: + raise exceptions.ResourcesUnavailableError( + 'The VCN is not available') + + # Get the primary vnic. + assert len(list_vcns_resp.data) > 0 + vcn = list_vcns_resp.data[0] + + list_nsg_resp = net_client.list_network_security_groups( + compartment_id=compartment, + vcn_id=vcn.id, + limit=1, + display_name=nsg_name, + ) + + nsgs = list_nsg_resp.data + if nsgs: + assert len(nsgs) == 1 + return nsgs[0].id + elif not create_if_not_exist: + return None + + # Continue to create new NSG if not exists + create_nsg_resp = net_client.create_network_security_group( + create_network_security_group_details=oci_adaptor.oci.core.models. 
+ CreateNetworkSecurityGroupDetails( + compartment_id=compartment, + vcn_id=vcn.id, + display_name=nsg_name, + )) + get_nsg_resp = net_client.get_network_security_group( + network_security_group_id=create_nsg_resp.data.id) + oci_adaptor.oci.wait_until( + net_client, + get_nsg_resp, + 'lifecycle_state', + 'AVAILABLE', + ) + + return get_nsg_resp.data.id + + @classmethod + def get_range_min_max(cls, port_range: str) -> Tuple[int, int]: + range_list = port_range.split('-') + if len(range_list) == 1: + return (int(range_list[0]), int(range_list[0])) + from_port, to_port = range_list + return (int(from_port), int(to_port)) + + @classmethod + @debug_enabled(logger) + def create_nsg_rules(cls, region: str, cluster_name: str, + ports: List[str]) -> None: + """ Create per-cluster NSG with ingress rules """ + if not ports: + return + + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=True) + + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name} + insts = query_helper.query_instances_by_tags(filters, region) + for inst in insts: + vnic = cls.get_instance_primary_vnic( + region=region, + inst_info={ + 'inst_id': inst.identifier, + 'ad': inst.availability_domain, + 'compartment': inst.compartment_id, + }) + nsg_ids = vnic.nsg_ids + if not nsg_ids: + net_client.update_vnic( + vnic_id=vnic.id, + update_vnic_details=oci_adaptor.oci.core.models. + UpdateVnicDetails(nsg_ids=[nsg_id], + skip_source_dest_check=False), + ) + + # pylint: disable=line-too-long + list_nsg_rules_resp = net_client.list_network_security_group_security_rules( + network_security_group_id=nsg_id, + direction='INGRESS', + sort_by='TIMECREATED', + sort_order='DESC', + ) + + ingress_rules: List = list_nsg_rules_resp.data + existing_port_ranges: List[str] = [] + for r in ingress_rules: + if r.tcp_options: + options_range = r.tcp_options.destination_port_range + rule_port_range = f'{options_range.min}-{options_range.max}' + existing_port_ranges.append(rule_port_range) + + new_ports = resources_utils.port_ranges_to_set(ports) + existing_ports = resources_utils.port_ranges_to_set( + existing_port_ranges) + if new_ports.issubset(existing_ports): + # ports already contains in the existing rules, nothing to add. + return + + # Determine the ports to be added, without overlapping. + ports_to_open = new_ports - existing_ports + port_ranges_to_open = resources_utils.port_set_to_ranges(ports_to_open) + + new_rules = [] + for port_range in port_ranges_to_open: + port_range_min, port_range_max = cls.get_range_min_max(port_range) + new_rules.append( + oci_adaptor.oci.core.models.AddSecurityRuleDetails( + direction='INGRESS', + protocol='6', + is_stateless=False, + source=oci_utils.oci_config.VCN_CIDR_INTERNET, + source_type='CIDR_BLOCK', + tcp_options=oci_adaptor.oci.core.models.TcpOptions( + destination_port_range=oci_adaptor.oci.core.models. + PortRange(min=port_range_min, max=port_range_max),), + description=oci_utils.oci_config.SERVICE_PORT_RULE_TAG, + )) + + net_client.add_network_security_group_security_rules( + network_security_group_id=nsg_id, + add_network_security_group_security_rules_details=oci_adaptor.oci. 
+ core.models.AddNetworkSecurityGroupSecurityRulesDetails( + security_rules=new_rules), + ) + + @classmethod + @debug_enabled(logger) + def detach_nsg(cls, region: str, inst, nsg_id: Optional[str]) -> None: + if nsg_id is None: + return + + vnic = cls.get_instance_primary_vnic( + region=region, + inst_info={ + 'inst_id': inst.identifier, + 'ad': inst.availability_domain, + 'compartment': inst.compartment_id, + }) + + # Detatch the NSG before removing it. + oci_adaptor.get_net_client(region, oci_utils.oci_config.get_profile( + )).update_vnic( + vnic_id=vnic.id, + update_vnic_details=oci_adaptor.oci.core.models.UpdateVnicDetails( + nsg_ids=[], skip_source_dest_check=False), + ) + + @classmethod + @debug_enabled(logger) + def remove_cluster_nsg(cls, region: str, cluster_name: str) -> None: + """ Remove NSG of the cluster """ + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=False) + if nsg_id is None: + return + + # Delete the NSG + net_client.delete_network_security_group( + network_security_group_id=nsg_id) + query_helper = QueryHelper() diff --git a/sky/serve/core.py b/sky/serve/core.py index abf9bfbc719..f6f6c53ad7b 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -701,6 +701,7 @@ def tail_logs( with ux_utils.print_exception_no_traceback(): raise ValueError(f'`target` must be a string or ' f'sky.serve.ServiceComponent, got {type(target)}.') + if target == serve_utils.ServiceComponent.REPLICA: if replica_id is None: with ux_utils.print_exception_no_traceback(): diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 3be41cc1593..6ab932f278a 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -46,8 +46,14 @@ constants.CONTROLLER_MEMORY_USAGE_GB) _CONTROLLER_URL = 'http://localhost:{CONTROLLER_PORT}' -_SKYPILOT_PROVISION_LOG_PATTERN = r'.*tail -n100 -f (.*provision\.log).*' -_SKYPILOT_LOG_PATTERN = r'.*tail -n100 -f (.*\.log).*' +# NOTE(dev): We assume log paths are either in ~/sky_logs/... or ~/.sky/... +# and always appear after a space. Be careful when changing UX as this +# assumption is used to expand some log files while ignoring others. +_SKYPILOT_LOG_DIRS = r'~/(sky_logs|\.sky)' +_SKYPILOT_PROVISION_LOG_PATTERN = ( + fr'.* ({_SKYPILOT_LOG_DIRS}/.*provision\.log)') +_SKYPILOT_LOG_PATTERN = fr'.* ({_SKYPILOT_LOG_DIRS}/.*\.log)' + # TODO(tian): Find all existing replica id and print here. _FAILED_TO_FIND_REPLICA_MSG = ( f'{colorama.Fore.RED}Failed to find replica ' @@ -591,7 +597,7 @@ def get_latest_version_with_min_replicas( return active_versions[-1] if active_versions else None -def _follow_replica_logs( +def _follow_logs_with_provision_expanding( file: TextIO, cluster_name: str, *, @@ -599,7 +605,7 @@ def _follow_replica_logs( stop_on_eof: bool = False, idle_timeout_seconds: Optional[int] = None, ) -> Iterator[str]: - """Follows logs for a replica, handling nested log files. + """Follows logs and expands any provision.log references found. Args: file: Log file to read from. @@ -610,7 +616,7 @@ def _follow_replica_logs( new content. Yields: - Log lines from the main file and any nested log files. + Log lines, including expanded content from referenced provision logs. 
""" def cluster_is_up() -> bool: @@ -620,36 +626,35 @@ def cluster_is_up() -> bool: return cluster_record['status'] == status_lib.ClusterStatus.UP def process_line(line: str) -> Iterator[str]: - # Tailing detailed progress for user. All logs in skypilot is - # of format `To view detailed progress: tail -n100 -f *.log`. - # Check if the line is directing users to view logs + # The line might be directing users to view logs, like + # `✓ Cluster launched: new-http. View logs at: *.log` + # We should tail the detailed logs for user. provision_log_prompt = re.match(_SKYPILOT_PROVISION_LOG_PATTERN, line) - other_log_prompt = re.match(_SKYPILOT_LOG_PATTERN, line) + log_prompt = re.match(_SKYPILOT_LOG_PATTERN, line) if provision_log_prompt is not None: nested_log_path = os.path.expanduser(provision_log_prompt.group(1)) - with open(nested_log_path, 'r', newline='', encoding='utf-8') as f: - # We still exit if more than 10 seconds without new content - # to avoid any internal bug that causes the launch to fail - # while cluster status remains INIT. - # Originally, we output the next line first before printing - # the launching logs. Since the next line is always - # `Launching on ()`, we output it first - # to indicate the process is starting. - # TODO(andyl): After refactor #4323, the above logic is broken, - # but coincidentally with the new UX 3.0, the `Cluster launched` - # message is printed first, making the output appear correct. - # Explaining this since it's technically a breaking change - # for this refactor PR #4323. Will remove soon in a fix PR - # for adapting the serve.follow_logs to the new UX. - yield from _follow_replica_logs(f, - cluster_name, - should_stop=cluster_is_up, - stop_on_eof=stop_on_eof, - idle_timeout_seconds=10) + + try: + with open(nested_log_path, 'r', newline='', + encoding='utf-8') as f: + # We still exit if more than 10 seconds without new content + # to avoid any internal bug that causes the launch to fail + # while cluster status remains INIT. + yield from log_utils.follow_logs(f, + should_stop=cluster_is_up, + stop_on_eof=stop_on_eof, + idle_timeout_seconds=10) + except FileNotFoundError: + yield line + + yield (f'{colorama.Fore.YELLOW}{colorama.Style.BRIGHT}' + f'Try to expand log file {nested_log_path} but not ' + f'found. Skipping...{colorama.Style.RESET_ALL}') + pass return - if other_log_prompt is not None: + if log_prompt is not None: # Now we skip other logs (file sync logs) since we lack # utility to determine when these log files are finished # writing. @@ -702,7 +707,7 @@ def _get_replica_status() -> serve_state.ReplicaStatus: replica_provisioned = ( lambda: _get_replica_status() != serve_state.ReplicaStatus.PROVISIONING) with open(launch_log_file_name, 'r', newline='', encoding='utf-8') as f: - for line in _follow_replica_logs( + for line in _follow_logs_with_provision_expanding( f, replica_cluster_name, should_stop=replica_provisioned, diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 91476cf8f6f..77be8119758 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -75,7 +75,7 @@ # cluster yaml is updated. # # TODO(zongheng,zhanghao): make the upgrading of skylet automatic? -SKYLET_VERSION = '8' +SKYLET_VERSION = '9' # The version of the lib files that skylet/jobs use. Whenever there is an API # change for the job_lib or log_lib, we need to bump this version, so that the # user can be notified to update their SkyPilot version on the remote cluster. 
diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index ee7aee85f36..dfd8332b019 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -8,6 +8,7 @@ import os import pathlib import shlex +import signal import sqlite3 import subprocess import time @@ -27,6 +28,10 @@ _LINUX_NEW_LINE = '\n' _JOB_STATUS_LOCK = '~/.sky/locks/.job_{}.lock' +# JOB_CMD_IDENTIFIER is used for identifying the process retrieved +# with pid is the same driver process to guard against the case where +# the same pid is reused by a different process. +JOB_CMD_IDENTIFIER = 'echo "SKYPILOT_JOB_ID <{}>"' def _get_lock_path(job_id: int) -> str: @@ -46,6 +51,7 @@ class JobInfoLoc(enum.IntEnum): START_AT = 6 END_AT = 7 RESOURCES = 8 + PID = 9 _DB_PATH = os.path.expanduser('~/.sky/jobs.db') @@ -67,6 +73,16 @@ def create_table(cursor, conn): # If the database is locked, it is OK to continue, as the WAL mode # is not critical and is likely to be enabled by other processes. + # Pid column is used for keeping track of the driver process of a job. It + # can be in three states: + # -1: The job was submitted with SkyPilot older than #4318, where we use + # ray job submit to submit the job, i.e. no pid is recorded. This is for + # backward compatibility and should be removed after 0.10.0. + # 0: The job driver process has never been started. When adding a job with + # INIT state, the pid will be set to 0 (the default -1 value is just for + # backward compatibility). + # >=0: The job has been started. The pid is the driver process's pid. + # The driver can be actually running or finished. cursor.execute("""\ CREATE TABLE IF NOT EXISTS jobs ( job_id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -75,7 +91,10 @@ def create_table(cursor, conn): submitted_at FLOAT, status TEXT, run_timestamp TEXT CANDIDATE KEY, - start_at FLOAT DEFAULT -1)""") + start_at FLOAT DEFAULT -1, + end_at FLOAT DEFAULT NULL, + resources TEXT DEFAULT NULL, + pid INTEGER DEFAULT -1)""") cursor.execute("""CREATE TABLE IF NOT EXISTS pending_jobs( job_id INTEGER, @@ -86,7 +105,8 @@ def create_table(cursor, conn): db_utils.add_column_to_table(cursor, conn, 'jobs', 'end_at', 'FLOAT') db_utils.add_column_to_table(cursor, conn, 'jobs', 'resources', 'TEXT') - + db_utils.add_column_to_table(cursor, conn, 'jobs', 'pid', + 'INTEGER DEFAULT -1') conn.commit() @@ -118,6 +138,11 @@ class JobStatus(enum.Enum): # In the 'jobs' table, the `start_at` column will be set to the current # time, when the job is firstly transitioned to RUNNING. RUNNING = 'RUNNING' + # The job driver process failed. This happens when the job driver process + # finishes when the status in job table is still not set to terminal state. + # We should keep this state before the SUCCEEDED, as our job status update + # relies on the order of the statuses to keep the latest status. + FAILED_DRIVER = 'FAILED_DRIVER' # 3 terminal states below: once reached, they do not transition. # The job finished successfully. SUCCEEDED = 'SUCCEEDED' @@ -148,11 +173,16 @@ def colored_str(self): return f'{color}{self.value}{colorama.Style.RESET_ALL}' -# Only update status of the jobs after this many seconds of job submission, -# to avoid race condition with `ray job` to make sure it job has been -# correctly updated. +# We have two steps for job submissions: +# 1. Client reserve a job id from the job table by adding a INIT state job. +# 2. Client updates the job status to PENDING by actually submitting the job's +# command to the scheduler. 
+# In normal cases, the two steps happens very close to each other through two +# consecutive SSH connections. +# We should update status for INIT job that has been staying in INIT state for +# a while (60 seconds), which likely fails to reach step 2. # TODO(zhwu): This number should be tuned based on heuristics. -_PENDING_SUBMIT_GRACE_PERIOD = 60 +_INIT_SUBMIT_GRACE_PERIOD = 60 _PRE_RESOURCE_STATUSES = [JobStatus.PENDING] @@ -175,7 +205,39 @@ def _run_job(self, job_id: int, run_cmd: str): _CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} ' f'WHERE job_id={job_id!r}')) _CONN.commit() - subprocess.Popen(run_cmd, shell=True, stdout=subprocess.DEVNULL) + # Use nohup to ensure the job driver process is a separate process tree, + # instead of being a child of the current process. This is important to + # avoid a chain of driver processes (job driver can call schedule_step() + # to submit new jobs, and the new job can also call schedule_step() + # recursively). + # + # echo $! will output the PID of the last background process started + # in the current shell, so we can retrieve it and record in the DB. + # + # TODO(zhwu): A more elegant solution is to use another daemon process + # to be in charge of starting these driver processes, instead of + # starting them in the current process. + wrapped_cmd = (f'nohup bash -c {shlex.quote(run_cmd)} ' + '/dev/null 2>&1 & echo $!') + proc = subprocess.run(wrapped_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.DEVNULL, + start_new_session=True, + check=True, + shell=True, + text=True) + # Get the PID of the detached process + pid = int(proc.stdout.strip()) + + # TODO(zhwu): Backward compatibility, remove this check after 0.10.0. + # This is for the case where the job is submitted with SkyPilot older + # than #4318, using ray job submit. + if 'job submit' in run_cmd: + pid = -1 + _CURSOR.execute((f'UPDATE jobs SET pid={pid} ' + f'WHERE job_id={job_id!r}')) + _CONN.commit() def schedule_step(self, force_update_jobs: bool = False) -> None: if force_update_jobs: @@ -237,59 +299,13 @@ def _get_pending_job_ids(self) -> List[int]: JobStatus.SETTING_UP: colorama.Fore.BLUE, JobStatus.PENDING: colorama.Fore.BLUE, JobStatus.RUNNING: colorama.Fore.GREEN, + JobStatus.FAILED_DRIVER: colorama.Fore.RED, JobStatus.SUCCEEDED: colorama.Fore.GREEN, JobStatus.FAILED: colorama.Fore.RED, JobStatus.FAILED_SETUP: colorama.Fore.RED, JobStatus.CANCELLED: colorama.Fore.YELLOW, } -_RAY_TO_JOB_STATUS_MAP = { - # These are intentionally set this way, because: - # 1. when the ray status indicates the job is PENDING the generated - # python program has been `ray job submit` from the job queue - # and is now PENDING - # 2. when the ray status indicates the job is RUNNING the job can be in - # setup or resources may not be allocated yet, i.e. the job should be - # PENDING. - # For case 2, update_job_status() would compare this mapped PENDING to - # the status in our jobs DB and take the max. This is because the job's - # generated ray program is the only place that can determine a job has - # reserved resources and actually started running: it will set the - # status in the DB to SETTING_UP or RUNNING. - # If there is no setup specified in the task, as soon as it is started - # (ray's status becomes RUNNING), i.e. it will be very rare that the job - # will be set to SETTING_UP by the update_job_status, as our generated - # ray program will set the status to PENDING immediately. 
- 'PENDING': JobStatus.PENDING, - 'RUNNING': JobStatus.PENDING, - 'SUCCEEDED': JobStatus.SUCCEEDED, - 'FAILED': JobStatus.FAILED, - 'STOPPED': JobStatus.CANCELLED, -} - - -def _create_ray_job_submission_client(): - """Import the ray job submission client.""" - try: - import ray # pylint: disable=import-outside-toplevel - except ImportError: - logger.error('Failed to import ray') - raise - try: - # pylint: disable=import-outside-toplevel - from ray import job_submission - except ImportError: - logger.error( - f'Failed to import job_submission with ray=={ray.__version__}') - raise - port = get_job_submission_port() - return job_submission.JobSubmissionClient( - address=f'http://127.0.0.1:{port}') - - -def make_ray_job_id(sky_job_id: int) -> str: - return f'{sky_job_id}-{getpass.getuser()}' - def make_job_command_with_user_switching(username: str, command: str) -> List[str]: @@ -301,9 +317,10 @@ def add_job(job_name: str, username: str, run_timestamp: str, """Atomically reserve the next available job id for the user.""" job_submitted_at = time.time() # job_id will autoincrement with the null value - _CURSOR.execute('INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?)', - (job_name, username, job_submitted_at, JobStatus.INIT.value, - run_timestamp, None, resources_str)) + _CURSOR.execute( + 'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0)', + (job_name, username, job_submitted_at, JobStatus.INIT.value, + run_timestamp, None, resources_str)) _CONN.commit() rows = _CURSOR.execute('SELECT job_id FROM jobs WHERE run_timestamp=(?)', (run_timestamp,)) @@ -478,6 +495,7 @@ def _get_records_from_rows(rows) -> List[Dict[str, Any]]: 'start_at': row[JobInfoLoc.START_AT.value], 'end_at': row[JobInfoLoc.END_AT.value], 'resources': row[JobInfoLoc.RESOURCES.value], + 'pid': row[JobInfoLoc.PID.value], }) return records @@ -537,6 +555,23 @@ def _get_pending_job(job_id: int) -> Optional[Dict[str, Any]]: return None +def _is_job_driver_process_running(job_pid: int, job_id: int) -> bool: + """Check if the job driver process is running. + + We check the cmdline to avoid the case where the same pid is reused by a + different process. + """ + if job_pid <= 0: + return False + try: + job_process = psutil.Process(job_pid) + return job_process.is_running() and any( + JOB_CMD_IDENTIFIER.format(job_id) in line + for line in job_process.cmdline()) + except psutil.NoSuchProcess: + return False + + def update_job_status(job_ids: List[int], silent: bool = False) -> List[JobStatus]: """Updates and returns the job statuses matching our `JobStatus` semantics. @@ -554,11 +589,8 @@ def update_job_status(job_ids: List[int], if len(job_ids) == 0: return [] - ray_job_ids = [make_ray_job_id(job_id) for job_id in job_ids] - job_client = _create_ray_job_submission_client() - statuses = [] - for job_id, ray_job_id in zip(job_ids, ray_job_ids): + for job_id in job_ids: # Per-job status lock is required because between the job status # query and the job status update, the job status in the databse # can be modified by the generated ray program. 
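
# Usage sketch for the pid-reuse guard above: the identifier baked into the
# driver command shows up in the process's cmdline, so a recycled pid that is
# alive but running something else will not match. (Standalone illustration;
# the bash command and job id are stand-ins.)
import subprocess

import psutil

identifier = 'echo "SKYPILOT_JOB_ID <7>"'  # JOB_CMD_IDENTIFIER.format(7)
proc = subprocess.Popen(['bash', '-c', f'{identifier}; sleep 30'])

driver = psutil.Process(proc.pid)
is_driver = driver.is_running() and any(
    identifier in part for part in driver.cmdline())
print(is_driver)  # True while the sleep above is still running
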
@@ -567,11 +599,13 @@ def update_job_status(job_ids: List[int], job_record = _get_jobs_by_ids([job_id])[0] original_status = job_record['status'] job_submitted_at = job_record['submitted_at'] + job_pid = job_record['pid'] - ray_job_query_time = time.time() + pid_query_time = time.time() + failed_driver_transition_message = None if original_status == JobStatus.INIT: if (job_submitted_at >= psutil.boot_time() and job_submitted_at - >= ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): + >= pid_query_time - _INIT_SUBMIT_GRACE_PERIOD): # The job id is reserved, but the job is not submitted yet. # We should keep it in INIT. status = JobStatus.INIT @@ -582,75 +616,98 @@ def update_job_status(job_ids: List[int], # was killed before the job is submitted. We should set it # to FAILED then. Note, if ray job indicates the job is # running, we will change status to PENDING below. - echo(f'INIT job {job_id} is stale, setting to FAILED') - status = JobStatus.FAILED - - try: - # Querying status within the lock is safer than querying - # outside, as it avoids the race condition when job table is - # updated after the ray job status query. - # Also, getting per-job status is faster than querying all jobs, - # when there are significant number of finished jobs. - # Reference: getting 124 finished jobs takes 0.038s, while - # querying a single job takes 0.006s, 10 jobs takes 0.066s. - # TODO: if too slow, directly query against redis. - ray_job_status = job_client.get_job_status(ray_job_id) - status = _RAY_TO_JOB_STATUS_MAP[ray_job_status.value] - except RuntimeError: - # Job not found. - pass + failed_driver_transition_message = ( + f'INIT job {job_id} is stale, setting to FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER + + # job_pid is 0 if the job is not submitted yet. + # job_pid is -1 if the job is submitted with SkyPilot older than + # #4318, using ray job submit. We skip the checking for those + # jobs. + if job_pid > 0: + if _is_job_driver_process_running(job_pid, job_id): + status = JobStatus.PENDING + else: + # By default, if the job driver process does not exist, + # the actual SkyPilot job is one of the following: + # 1. Still pending to be submitted. + # 2. Submitted and finished. + # 3. Driver failed without correctly setting the job + # status in the job table. + # Although we set the status to FAILED_DRIVER, it can be + # overridden to PENDING if the job is not submitted, or + # any other terminal status if the job driver process + # finished correctly. + failed_driver_transition_message = ( + f'Job {job_id} driver process is not running, but ' + 'the job state is not in terminal states, setting ' + 'it to FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER + elif job_pid < 0: + # TODO(zhwu): Backward compatibility, remove after 0.9.0. + # We set the job status to PENDING instead of actually + # checking ray job status and let the status in job table + # take effect in the later max. + status = JobStatus.PENDING pending_job = _get_pending_job(job_id) if pending_job is not None: if pending_job['created_time'] < psutil.boot_time(): - echo(f'Job {job_id} is stale, setting to FAILED: ' - f'created_time={pending_job["created_time"]}, ' - f'boot_time={psutil.boot_time()}') + failed_driver_transition_message = ( + f'Job {job_id} is stale, setting to FAILED_DRIVER: ' + f'created_time={pending_job["created_time"]}, ' + f'boot_time={psutil.boot_time()}') # The job is stale as it is created before the instance # is booted, e.g. the instance is rebooted. 
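
# A standalone sketch of the staleness test applied to INIT records above: a
# record written before the last reboot cannot have a live submitter, and a
# record older than the grace period most likely never reached step 2
# (the grace-period value mirrors _INIT_SUBMIT_GRACE_PERIOD).
import time

import psutil

_GRACE_PERIOD_SECONDS = 60


def init_record_is_stale(job_submitted_at: float) -> bool:
    submitted_before_boot = job_submitted_at < psutil.boot_time()
    exceeded_grace_period = (
        job_submitted_at < time.time() - _GRACE_PERIOD_SECONDS)
    return submitted_before_boot or exceeded_grace_period


print(init_record_is_stale(time.time()))         # False: just submitted
print(init_record_is_stale(time.time() - 3600))  # True: well past the grace period
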
- status = JobStatus.FAILED - # Gives a 60 second grace period between job being submit from - # the pending table until appearing in ray jobs. For jobs - # submitted outside of the grace period, we will consider the - # ray job status. - - if not (pending_job['submit'] > 0 and pending_job['submit'] < - ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): - # Reset the job status to PENDING even though it may not - # appear in the ray jobs, so that it will not be considered - # as stale. + status = JobStatus.FAILED_DRIVER + elif pending_job['submit'] <= 0: + # The job is not submitted (submit <= 0), we set it to + # PENDING. + # For submitted jobs, the driver should have been started, + # because the job_lib.JobScheduler.schedule_step() have + # the submit field and driver process pid set in the same + # job lock. + # The job process check in the above section should + # correctly figured out the status and we don't overwrite + # it here. (Note: the FAILED_DRIVER status will be + # overridden by the actual job terminal status in the table + # if the job driver process finished correctly.) status = JobStatus.PENDING assert original_status is not None, (job_id, status) if status is None: + # The job is submitted but the job driver process pid is not + # set in the database. This is guarding against the case where + # the schedule_step() function is interrupted (e.g., VM stop) + # at the middle of starting a new process and setting the pid. status = original_status if (original_status is not None and not original_status.is_terminal()): - echo(f'Ray job status for job {job_id} is None, ' - 'setting it to FAILED.') - # The job may be stale, when the instance is restarted - # (the ray redis is volatile). We need to reset the - # status of the task to FAILED if its original status - # is RUNNING or PENDING. - status = JobStatus.FAILED + echo(f'Job {job_id} status is None, setting it to ' + 'FAILED_DRIVER.') + # The job may be stale, when the instance is restarted. We + # need to reset the job status to FAILED_DRIVER if its + # original status is in nonterminal_statuses. + echo(f'Job {job_id} is in a unknown state, setting it to ' + 'FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER _set_status_no_lock(job_id, status) - echo(f'Updated job {job_id} status to {status}') else: # Taking max of the status is necessary because: - # 1. It avoids race condition, where the original status has - # already been set to later state by the job. We skip the - # update. - # 2. _RAY_TO_JOB_STATUS_MAP would map `ray job status`'s - # `RUNNING` to our JobStatus.SETTING_UP; if a job has already - # been set to JobStatus.PENDING or JobStatus.RUNNING by the - # generated ray program, `original_status` (job status from our - # DB) would already have that value. So we take the max here to - # keep it at later status. + # 1. The original status has already been set to later + # terminal state by a finished job driver. + # 2. Job driver process check would map any running job process + # to `PENDING`, so we need to take the max to keep it at + # later status for jobs actually started in SETTING_UP or + # RUNNING. status = max(status, original_status) assert status is not None, (job_id, status, original_status) if status != original_status: # Prevents redundant update. 
_set_status_no_lock(job_id, status) echo(f'Updated job {job_id} status to {status}') + if (status == JobStatus.FAILED_DRIVER and + failed_driver_transition_message is not None): + echo(failed_driver_transition_message) statuses.append(status) return statuses @@ -663,17 +720,13 @@ def fail_all_jobs_in_progress() -> None: f"""\ UPDATE jobs SET status=(?) WHERE status IN ({','.join(['?'] * len(in_progress_status))}) - """, (JobStatus.FAILED.value, *in_progress_status)) + """, (JobStatus.FAILED_DRIVER.value, *in_progress_status)) _CONN.commit() def update_status() -> None: # This will be called periodically by the skylet to update the status # of the jobs in the database, to avoid stale job status. - # NOTE: there might be a INIT job in the database set to FAILED by this - # function, as the ray job status does not exist due to the app - # not submitted yet. It will be then reset to PENDING / RUNNING when the - # app starts. nonterminal_jobs = _get_jobs(username=None, status_list=JobStatus.nonterminal_statuses()) nonterminal_job_ids = [job['job_id'] for job in nonterminal_jobs] @@ -756,6 +809,31 @@ def load_job_queue(payload: str) -> List[Dict[str, Any]]: return jobs +# TODO(zhwu): Backward compatibility for jobs submitted before #4318, remove +# after 0.10.0. +def _create_ray_job_submission_client(): + """Import the ray job submission client.""" + try: + import ray # pylint: disable=import-outside-toplevel + except ImportError: + logger.error('Failed to import ray') + raise + try: + # pylint: disable=import-outside-toplevel + from ray import job_submission + except ImportError: + logger.error( + f'Failed to import job_submission with ray=={ray.__version__}') + raise + port = get_job_submission_port() + return job_submission.JobSubmissionClient( + address=f'http://127.0.0.1:{port}') + + +def _make_ray_job_id(sky_job_id: int) -> str: + return f'{sky_job_id}-{getpass.getuser()}' + + def cancel_jobs_encoded_results(jobs: Optional[List[int]], cancel_all: bool = False) -> str: """Cancel jobs. @@ -783,27 +861,51 @@ def cancel_jobs_encoded_results(jobs: Optional[List[int]], # Cancel jobs with specified IDs. job_records = _get_jobs_by_ids(jobs) - # TODO(zhwu): `job_client.stop_job` will wait for the jobs to be killed, but - # when the memory is not enough, this will keep waiting. - job_client = _create_ray_job_submission_client() cancelled_ids = [] # Sequentially cancel the jobs to avoid the resource number bug caused by # ray cluster (tracked in #1262). - for job in job_records: - job_id = make_ray_job_id(job['job_id']) + for job_record in job_records: + job_id = job_record['job_id'] # Job is locked to ensure that pending queue does not start it while # it is being cancelled - with filelock.FileLock(_get_lock_path(job['job_id'])): - try: - job_client.stop_job(job_id) - except RuntimeError as e: - # If the request to the job server fails, we should not - # set the job to CANCELLED. - if 'does not exist' not in str(e): - logger.warning(str(e)) - continue - + with filelock.FileLock(_get_lock_path(job_id)): + job = _get_jobs_by_ids([job_id])[0] + if _is_job_driver_process_running(job['pid'], job_id): + # Not use process.terminate() as that will only terminate the + # process shell process, not the ray driver process + # under the shell. 
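
# A standalone sketch of the per-job file lock used when updating or
# cancelling a job above: the status query, the transition decision and the
# database write all happen while holding the same lock, so the driver process
# cannot change the row in between. The path template mirrors
# _JOB_STATUS_LOCK; the helper name is illustrative.
import os

import filelock


def _example_lock_path(job_id: int) -> str:
    path = os.path.expanduser(f'~/.sky/locks/.job_{job_id}.lock')
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path


with filelock.FileLock(_example_lock_path(42)):
    # e.g. read the current status, decide the transition, write it back.
    pass
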
+ # + # We don't kill all the children of the process, like + # subprocess_utils.kill_process_daemon() does, but just the + # process group here, because the underlying job driver can + # start other jobs with `schedule_step`, causing the other job + # driver processes to be children of the current job driver + # process. + # + # Killing the process group is enough as the underlying job + # should be able to clean itself up correctly by ray driver. + # + # The process group pid should be the same as the job pid as we + # use start_new_session=True, but we use os.getpgid() to be + # extra cautious. + job_pgid = os.getpgid(job['pid']) + os.killpg(job_pgid, signal.SIGTERM) + # We don't have to start a daemon to forcefully kill the process + # as our job driver process will clean up the underlying + # child processes. + elif job['pid'] < 0: + try: + # TODO(zhwu): Backward compatibility, remove after 0.9.0. + # The job was submitted with ray job submit before #4318. + job_client = _create_ray_job_submission_client() + job_client.stop_job(_make_ray_job_id(job['job_id'])) + except RuntimeError as e: + # If the request to the job server fails, we should not + # set the job to CANCELLED. + if 'does not exist' not in str(e): + logger.warning(str(e)) + continue # Get the job status again to avoid race condition. job_status = get_status_no_lock(job['job_id']) if job_status in [ @@ -865,10 +967,17 @@ def add_job(cls, job_name: Optional[str], username: str, run_timestamp: str, if job_name is None: job_name = '-' code = [ - 'job_id = job_lib.add_job(' - f'{job_name!r}, ' - f'{username!r}, ' - f'{run_timestamp!r}, ' + # We disallow job submission when SKYLET_VERSION is older than 9, as + # it was using ray job submit before #4318, and switched to raw + # process. Using the old skylet version will cause the job status + # to be stuck in PENDING state or transition to FAILED_DRIVER state. + '\nif int(constants.SKYLET_VERSION) < 9: ' + 'raise RuntimeError("SkyPilot runtime is too old, which does not ' + 'support submitting jobs.")', + '\njob_id = job_lib.add_job(' + f'{job_name!r},' + f'{username!r},' + f'{run_timestamp!r},' f'{resources_str!r})', 'print("Job ID: " + str(job_id), flush=True)', ] @@ -876,9 +985,11 @@ def add_job(cls, job_name: Optional[str], username: str, run_timestamp: str, @classmethod def queue_job(cls, job_id: int, cmd: str) -> str: - code = ['job_lib.scheduler.queue(' - f'{job_id!r},' - f'{cmd!r})'] + code = [ + 'job_lib.scheduler.queue(' + f'{job_id!r},' + f'{cmd!r})', + ] return cls._build(code) @classmethod diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index ac5b9d5ee16..fa3f7f9f3fc 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -183,40 +183,7 @@ def run_with_log( shell=shell, **kwargs) as proc: try: - # The proc can be defunct if the python program is killed. Here we - # open a new subprocess to gracefully kill the proc, SIGTERM - # and then SIGKILL the process group. 
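
# A standalone sketch of the process-group termination used for cancellation
# above: because the driver was started with start_new_session=True, it leads
# its own process group, so one signal reaches the wrapper shell and
# everything it spawned (the sleep commands stand in for a real driver).
import os
import signal
import subprocess
import time

proc = subprocess.Popen(['bash', '-c', 'sleep 1000 & sleep 1000'],
                        start_new_session=True)
time.sleep(1)  # let bash spawn its children

pgid = os.getpgid(proc.pid)  # equals proc.pid thanks to the new session
os.killpg(pgid, signal.SIGTERM)
proc.wait()
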
- # Adapted from ray/dashboard/modules/job/job_manager.py#L154 - parent_pid = os.getpid() - daemon_script = os.path.join( - os.path.dirname(os.path.abspath(job_lib.__file__)), - 'subprocess_daemon.py') - python_path = subprocess.check_output( - constants.SKY_GET_PYTHON_PATH_CMD, - shell=True, - stderr=subprocess.DEVNULL, - encoding='utf-8').strip() - daemon_cmd = [ - python_path, - daemon_script, - '--parent-pid', - str(parent_pid), - '--proc-pid', - str(proc.pid), - ] - - # We do not need to set `start_new_session=True` here, as the - # daemon script will detach itself from the parent process with - # fork to avoid being killed by ray job. See the reason we - # daemonize the process in `sky/skylet/subprocess_daemon.py`. - subprocess.Popen( - daemon_cmd, - # Suppress output - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - # Disable input - stdin=subprocess.DEVNULL, - ) + subprocess_utils.kill_process_daemon(proc.pid) stdout = '' stderr = '' diff --git a/sky/skylet/subprocess_daemon.py b/sky/skylet/subprocess_daemon.py index 1261f4ecf72..55b63d1f9a5 100644 --- a/sky/skylet/subprocess_daemon.py +++ b/sky/skylet/subprocess_daemon.py @@ -15,10 +15,11 @@ def daemonize(): This detachment is crucial in the context of SkyPilot and Ray job. When 'sky cancel' is executed, it uses Ray's stop job API to terminate the job. - Without daemonization, this subprocess_daemon process would be terminated - along with its parent process, ray::task, which is launched with Ray job. - Daemonization ensures this process survives the 'sky cancel' command, - allowing it to prevent orphaned processes of Ray job. + Without daemonization, this subprocess_daemon process will still be a child + of the parent process which would be terminated along with the parent + process, ray::task or the cancel request for jobs, which is launched with + Ray job. Daemonization ensures this process survives the 'sky cancel' + command, allowing it to prevent orphaned processes of Ray job. """ # First fork: Creates a child process identical to the parent if os.fork() > 0: @@ -42,6 +43,15 @@ def daemonize(): parser = argparse.ArgumentParser() parser.add_argument('--parent-pid', type=int, required=True) parser.add_argument('--proc-pid', type=int, required=True) + parser.add_argument( + '--initial-children', + type=str, + default='', + help=( + 'Comma-separated list of initial children PIDs. This is to guard ' + 'against the case where the target process has already terminated, ' + 'while the children are still running.'), + ) args = parser.parse_args() process = None @@ -52,24 +62,34 @@ def daemonize(): except psutil.NoSuchProcess: pass - if process is None: - sys.exit() - + # Initialize children list from arguments children = [] - if parent_process is not None: - # Wait for either parent or target process to exit. + if args.initial_children: + for pid in args.initial_children.split(','): + try: + child = psutil.Process(int(pid)) + children.append(child) + except (psutil.NoSuchProcess, ValueError): + pass + + if process is not None and parent_process is not None: + # Wait for either parent or target process to exit while process.is_running() and parent_process.is_running(): try: - # process.children() must be called while the target process - # is alive, as it will return an empty list if the target - # process has already terminated. 
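
# A minimal sketch of the double-fork detachment described in daemonize()
# above: the first fork lets the original parent return, setsid() starts a new
# session with no controlling terminal, and the second fork ensures the
# survivor can never re-acquire one, so it outlives 'sky cancel'. (Sketch
# only; see subprocess_daemon.py for the authoritative steps.)
import os
import sys


def _daemonize_sketch() -> None:
    if os.fork() > 0:
        sys.exit(0)  # first parent exits
    os.setsid()      # detach from the controlling terminal / session
    if os.fork() > 0:
        sys.exit(0)  # intermediate parent exits; the grandchild carries on
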
tmp_children = process.children(recursive=True) if tmp_children: children = tmp_children except psutil.NoSuchProcess: pass time.sleep(1) - children.append(process) + + if process is not None: + # Kill the target process first to avoid having more children, or fail + # the process due to the children being defunct. + children = [process] + children + + if not children: + sys.exit() for child in children: try: diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0ab2fd7e117..a6657df960d 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -293,6 +293,13 @@ def _get_cloud_dependencies_installation_commands( 'pip list | grep runpod > /dev/null 2>&1 || ' 'pip install "runpod>=1.5.1" > /dev/null 2>&1') setup_clouds.append(str(cloud)) + elif isinstance(cloud, clouds.OCI): + step_prefix = prefix_str.replace('', + str(len(setup_clouds) + 1)) + commands.append(f'echo -en "\\r{prefix_str}OCI{empty_str}" && ' + 'pip list | grep oci > /dev/null 2>&1 || ' + 'pip install oci > /dev/null 2>&1') + setup_clouds.append(str(cloud)) if controller == Controllers.JOBS_CONTROLLER: if isinstance(cloud, clouds.IBM): step_prefix = prefix_str.replace('', @@ -303,13 +310,6 @@ def _get_cloud_dependencies_installation_commands( 'pip install ibm-cloud-sdk-core ibm-vpc ' 'ibm-platform-services ibm-cos-sdk > /dev/null 2>&1') setup_clouds.append(str(cloud)) - elif isinstance(cloud, clouds.OCI): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append(f'echo -en "\\r{prefix_str}OCI{empty_str}" && ' - 'pip list | grep oci > /dev/null 2>&1 || ' - 'pip install oci > /dev/null 2>&1') - setup_clouds.append(str(cloud)) if (cloudflare.NAME in storage_lib.get_cached_enabled_storage_clouds_or_refresh()): step_prefix = prefix_str.replace('', str(len(setup_clouds) + 1)) @@ -818,8 +818,9 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', '[dim]View storages: sky storage ls')) try: task.sync_storage_mounts() - except ValueError as e: - if 'No enabled cloud for storage' in str(e): + except (ValueError, exceptions.NoCloudAccessError) as e: + if 'No enabled cloud for storage' in str(e) or isinstance( + e, exceptions.NoCloudAccessError): data_src = None if has_local_source_paths_file_mounts: data_src = 'file_mounts' diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index acb8fb9f490..28bd2c2ee07 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -11,6 +11,7 @@ from sky import exceptions from sky import sky_logging +from sky.skylet import constants from sky.skylet import log_lib from sky.utils import timeline from sky.utils import ux_utils @@ -198,3 +199,52 @@ def run_with_retries( continue break return returncode, stdout, stderr + + +def kill_process_daemon(process_pid: int) -> None: + """Start a daemon as a safety net to kill the process. + + Args: + process_pid: The PID of the process to kill. 
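
# One way the final sweep over `children` in subprocess_daemon.py above can be
# finished (a sketch assuming a SIGTERM-then-SIGKILL policy; the real module
# is authoritative): terminate everything, wait briefly, then kill whatever is
# still alive. psutil.wait_procs tolerates processes that already exited.
import psutil


def _sweep_sketch(children: list) -> None:
    for child in children:
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            continue
    _, alive = psutil.wait_procs(children, timeout=10)
    for child in alive:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            continue
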
+ """ + # Get initial children list + try: + process = psutil.Process(process_pid) + initial_children = [p.pid for p in process.children(recursive=True)] + except psutil.NoSuchProcess: + initial_children = [] + + parent_pid = os.getpid() + daemon_script = os.path.join( + os.path.dirname(os.path.abspath(log_lib.__file__)), + 'subprocess_daemon.py') + python_path = subprocess.check_output(constants.SKY_GET_PYTHON_PATH_CMD, + shell=True, + stderr=subprocess.DEVNULL, + encoding='utf-8').strip() + daemon_cmd = [ + python_path, + daemon_script, + '--parent-pid', + str(parent_pid), + '--proc-pid', + str(process_pid), + # We pass the initial children list to avoid the race condition where + # the process_pid is terminated before the daemon starts and gets the + # children list. + '--initial-children', + ','.join(map(str, initial_children)), + ] + + # We do not need to set `start_new_session=True` here, as the + # daemon script will detach itself from the parent process with + # fork to avoid being killed by parent process. See the reason we + # daemonize the process in `sky/skylet/subprocess_daemon.py`. + subprocess.Popen( + daemon_cmd, + # Suppress output + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + # Disable input + stdin=subprocess.DEVNULL, + ) diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index f7244bd9ab2..4db9bd149b2 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -79,11 +79,9 @@ def event(name_or_fn: Union[str, Callable], message: Optional[str] = None): class FileLockEvent: """Serve both as a file lock and event for the lock.""" - def __init__(self, lockfile: Union[str, os.PathLike]): + def __init__(self, lockfile: Union[str, os.PathLike], timeout: float = -1): self._lockfile = lockfile - # TODO(mraheja): remove pylint disabling when filelock version updated - # pylint: disable=abstract-class-instantiated - self._lock = filelock.FileLock(self._lockfile) + self._lock = filelock.FileLock(self._lockfile, timeout) self._hold_lock_event = Event(f'[FileLock.hold]:{self._lockfile}') def acquire(self): diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 276fda899dd..696b87ff6ad 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -187,7 +187,7 @@ sky jobs logs -n "${MANAGED_JOB_JOB_NAME}-7-1" || exit 1 s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7) echo "$s" echo "$s" | grep "SUCCEEDED" | wc -l | grep 2 || exit 1 -echo "$s" | grep "CANCELLED" | wc -l | grep 1 || exit 1 +echo "$s" | grep "CANCELLING\|CANCELLED" | wc -l | grep 1 || exit 1 fi sky down ${CLUSTER_NAME}* -y diff --git a/tests/skyserve/http/oci.yaml b/tests/skyserve/http/oci.yaml new file mode 100644 index 00000000000..d7d98c18ab4 --- /dev/null +++ b/tests/skyserve/http/oci.yaml @@ -0,0 +1,10 @@ +service: + readiness_probe: / + replicas: 2 + +resources: + cloud: oci + ports: 8080 + cpus: 2+ + +run: python -m http.server 8080 \ No newline at end of file diff --git a/tests/test_smoke.py b/tests/test_smoke.py index b1ccf0b7d51..ce93c3bfa30 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -465,6 +465,11 @@ def test_aws_with_ssh_proxy_command(): f'sky logs {name} 1 --status', f'export SKYPILOT_CONFIG={f.name}; sky exec {name} echo hi', f'sky logs {name} 2 --status', + # Start a small job to make sure the controller is created. 
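
# The FileLockEvent change above threads an optional timeout through to
# filelock.FileLock (timeout=-1 keeps the old block-forever behaviour). A
# standalone sketch of what a finite timeout buys the caller (the path and
# value here are illustrative):
import filelock

lock = filelock.FileLock('/tmp/example_event.lock', timeout=10)
try:
    with lock:
        pass  # critical section
except filelock.Timeout:
    print('Could not acquire the lock within 10 seconds; giving up.')
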
+ f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi', + # Wait other tests to create the job controller first, so that + # the job controller is not launched with proxy command. + 'timeout 300s bash -c "until sky status sky-jobs-controller* | grep UP; do sleep 1; done"', f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi', 'sleep 300', f'{_GET_JOB_QUEUE} | grep {name} | grep "STARTING\|RUNNING\|SUCCEEDED"', @@ -976,7 +981,7 @@ def test_stale_job(generic_cloud: str): 'sleep 100', # Ensure this is large enough, else GCP leaks. f'sky start {name} -y', f'sky logs {name} 1 --status', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', ], f'sky down -y {name}', ) @@ -1007,7 +1012,7 @@ def test_aws_stale_job_manual_restart(): f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', ], f'sky down -y {name}', ) @@ -1038,7 +1043,7 @@ def test_gcp_stale_job_manual_restart(): f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', ], f'sky down -y {name}', ) @@ -2663,7 +2668,7 @@ def test_cancel_pytorch(generic_cloud: str): f'sky launch -c {name} --cloud {generic_cloud} examples/resnet_distributed_torch.yaml -y -d', # Wait the GPU process to start. 'sleep 90', - f'sky exec {name} "(nvidia-smi | grep python) || ' + f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep python) || ' # When run inside container/k8s, nvidia-smi cannot show process ids. # See https://github.com/NVIDIA/nvidia-docker/issues/179 # To work around, we check if GPU utilization is greater than 0. @@ -2671,7 +2676,7 @@ def test_cancel_pytorch(generic_cloud: str): f'sky logs {name} 2 --status', # Ensure the job succeeded. f'sky cancel -y {name} 1', 'sleep 60', - f'sky exec {name} "(nvidia-smi | grep \'No running process\') || ' + f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep \'No running process\') || ' # Ensure Xorg is the only process running. '[ \$(nvidia-smi | grep -A 10 Processes | grep -A 10 === | grep -v Xorg) -eq 2 ]"', f'sky logs {name} 3 --status', # Ensure the job succeeded. @@ -3876,6 +3881,15 @@ def test_skyserve_kubernetes_http(): run_one_test(test) +@pytest.mark.oci +@pytest.mark.serve +def test_skyserve_oci_http(): + """Test skyserve on OCI""" + name = _get_service_name() + test = _get_skyserve_http_test(name, 'oci', 20) + run_one_test(test) + + @pytest.mark.no_fluidstack # Fluidstack does not support T4 gpus for now @pytest.mark.serve def test_skyserve_llm(generic_cloud: str):
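
# The added smoke-test line above waits for the jobs controller with
# `timeout 300s bash -c "until ...; do sleep 1; done"`. A rough Python
# equivalent of that bounded-polling idiom (command and pattern below are
# illustrative):
import subprocess
import time


def wait_for_output(cmd: str, pattern: str, timeout: float = 300.0) -> bool:
    deadline = time.time() + timeout
    while time.time() < deadline:
        result = subprocess.run(cmd,
                                shell=True,
                                capture_output=True,
                                text=True,
                                check=False)
        if pattern in result.stdout:
            return True
        time.sleep(1)
    return False


# e.g. wait_for_output('sky status sky-jobs-controller*', 'UP')
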