diff --git a/.gitignore b/.gitignore
index f1dbf59a52f..efa74dd744b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,4 @@ sky_logs/
 sky/clouds/service_catalog/data_fetchers/*.csv
 .vscode/
 .idea/
-
+.env
diff --git a/docs/source/cloud-setup/cloud-permissions/index.rst b/docs/source/cloud-setup/cloud-permissions/index.rst
index 873cbf339fc..e2a1aaf16ae 100644
--- a/docs/source/cloud-setup/cloud-permissions/index.rst
+++ b/docs/source/cloud-setup/cloud-permissions/index.rst
@@ -20,3 +20,4 @@ Table of Contents
    aws
    gcp
    vsphere
+   kubernetes
diff --git a/docs/source/cloud-setup/cloud-permissions/kubernetes.rst b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst
new file mode 100644
index 00000000000..5318d76b1a3
--- /dev/null
+++ b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst
@@ -0,0 +1,234 @@
+.. _cloud-permissions-kubernetes:
+
+Kubernetes
+==========
+
+When running outside your Kubernetes cluster, SkyPilot uses your local ``~/.kube/config`` file
+for authentication and creating resources on your Kubernetes cluster.
+
+When running inside your Kubernetes cluster (e.g., as a Spot controller or Serve controller),
+SkyPilot can operate using any of the following three authentication methods:
+
+1. **Using your local kubeconfig file**: In this case, SkyPilot will
+   copy your local ``~/.kube/config`` file to the controller pod and use it for
+   authentication. This is the default method when running inside the cluster,
+   and no additional configuration is required.
+
+   .. note::
+
+       If your cluster uses exec-based authentication in your ``~/.kube/config`` file
+       (e.g., GKE uses exec auth by default), SkyPilot may not be able to authenticate using this method. In this case,
+       consider using the service account methods below.
+
+2. **Creating a service account**: SkyPilot can automatically create the service
+   account and roles for itself to manage resources in the Kubernetes cluster.
+   To use this method, set ``remote_identity: SERVICE_ACCOUNT`` in your
+   Kubernetes configuration in the :ref:`~/.sky/config.yaml <config-yaml>` file:
+
+   .. code-block:: yaml
+
+       kubernetes:
+         remote_identity: SERVICE_ACCOUNT
+
+   For details on the permissions that are granted to the service account,
+   refer to the `Permissions required for SkyPilot`_ section below.
+
+3. **Using a custom service account**: If you have a custom service account
+   with the :ref:`necessary permissions <k8s-permissions>`, you can configure
+   SkyPilot to use it by adding this to your :ref:`~/.sky/config.yaml <config-yaml>` file:
+
+   .. code-block:: yaml
+
+       kubernetes:
+         remote_identity: your-service-account-name
+
+.. note::
+
+    Service account based authentication applies only when the remote SkyPilot
+    cluster (including spot and serve controllers) is launched inside the
+    Kubernetes cluster. When running outside the cluster (e.g., on AWS),
+    SkyPilot will use the local ``~/.kube/config`` file for authentication.
+
+Below are the permissions required by SkyPilot and an example service account YAML that you can use to create a service account with the necessary permissions.
+
+.. _k8s-permissions:
+
+Permissions required for SkyPilot
+---------------------------------
+
+SkyPilot requires permissions equivalent to the following roles to be able to manage the resources in the Kubernetes cluster:
+
+.. code-block:: yaml
+
+    # Namespaced role for the service account
+    # Required for creating pods, services and other necessary resources in the namespace.
+    # Note these permissions only apply in the namespace where SkyPilot is deployed.
+ kind: Role + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: sky-sa-role + namespace: default + rules: + - apiGroups: ["*"] + resources: ["*"] + verbs: ["*"] + --- + # ClusterRole for accessing cluster-wide resources. Details for each resource below: + kind: ClusterRole + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: sky-sa-cluster-role + namespace: default + labels: + parent: skypilot + rules: + - apiGroups: [""] + resources: ["nodes"] # Required for getting node resources. + verbs: ["get", "list", "watch"] + - apiGroups: ["rbac.authorization.k8s.io"] + resources: ["clusterroles", "clusterrolebindings"] # Required for launching more SkyPilot clusters from within the pod. + verbs: ["get", "list", "watch"] + - apiGroups: ["node.k8s.io"] + resources: ["runtimeclasses"] # Required for autodetecting the runtime class of the nodes. + verbs: ["get", "list", "watch"] + --- + # Optional: If using ingresses, role for accessing ingress service IP + apiVersion: rbac.authorization.k8s.io/v1 + kind: Role + metadata: + namespace: ingress-nginx + name: sky-sa-role-ingress-nginx + rules: + - apiGroups: [""] + resources: ["services"] + verbs: ["list", "get"] + +These roles must apply to both the user account configured in the kubeconfig file and the service account used by SkyPilot (if configured). + +.. _k8s-sa-example: + +Example using Custom Service Account +------------------------------------ + +To create a service account that has the necessary permissions for SkyPilot, you can use the following YAML: + +.. code-block:: yaml + + # create-sky-sa.yaml + kind: ServiceAccount + apiVersion: v1 + metadata: + name: sky-sa + namespace: default + labels: + parent: skypilot + --- + # Role for the service account + kind: Role + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: sky-sa-role + namespace: default + labels: + parent: skypilot + rules: + - apiGroups: ["*"] # Required for creating pods, services, secrets and other necessary resources in the namespace. + resources: ["*"] + verbs: ["*"] + --- + # RoleBinding for the service account + kind: RoleBinding + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: sky-sa-rb + namespace: default + labels: + parent: skypilot + subjects: + - kind: ServiceAccount + name: sky-sa + roleRef: + kind: Role + name: sky-sa-role + apiGroup: rbac.authorization.k8s.io + --- + # Role for accessing ingress resources + apiVersion: rbac.authorization.k8s.io/v1 + kind: Role + metadata: + namespace: ingress-nginx + name: sky-sa-role-ingress-nginx + rules: + - apiGroups: [""] + resources: ["services"] + verbs: ["list", "get", "watch"] + - apiGroups: ["rbac.authorization.k8s.io"] + resources: ["roles", "rolebindings"] + verbs: ["list", "get", "watch"] + --- + # RoleBinding for accessing ingress resources + apiVersion: rbac.authorization.k8s.io/v1 + kind: RoleBinding + metadata: + name: sky-sa-rolebinding-ingress-nginx + namespace: ingress-nginx + subjects: + - kind: ServiceAccount + name: sky-sa + namespace: default + roleRef: + kind: Role + name: sky-sa-role-ingress-nginx + apiGroup: rbac.authorization.k8s.io + --- + # ClusterRole for the service account + kind: ClusterRole + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + name: sky-sa-cluster-role + namespace: default + labels: + parent: skypilot + rules: + - apiGroups: [""] + resources: ["nodes"] # Required for getting node resources. 
+      verbs: ["get", "list", "watch"]
+    - apiGroups: ["rbac.authorization.k8s.io"]
+      resources: ["clusterroles", "clusterrolebindings"]  # Required for launching more SkyPilot clusters from within the pod.
+      verbs: ["get", "list", "watch"]
+    - apiGroups: ["node.k8s.io"]
+      resources: ["runtimeclasses"]  # Required for autodetecting the runtime class of the nodes.
+      verbs: ["get", "list", "watch"]
+    - apiGroups: ["networking.k8s.io"]  # Required for exposing services.
+      resources: ["ingressclasses"]
+      verbs: ["get", "list", "watch"]
+    ---
+    # ClusterRoleBinding for the service account
+    apiVersion: rbac.authorization.k8s.io/v1
+    kind: ClusterRoleBinding
+    metadata:
+      name: sky-sa-cluster-role-binding
+      namespace: default
+      labels:
+        parent: skypilot
+    subjects:
+    - kind: ServiceAccount
+      name: sky-sa
+      namespace: default
+    roleRef:
+      kind: ClusterRole
+      name: sky-sa-cluster-role
+      apiGroup: rbac.authorization.k8s.io
+
+Create the service account using the following command:
+
+.. code-block:: bash
+
+    $ kubectl apply -f create-sky-sa.yaml
+
+After creating the service account, configure SkyPilot to use it through ``~/.sky/config.yaml``:
+
+.. code-block:: yaml
+
+    kubernetes:
+      remote_identity: sky-sa   # Or your service account name
diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst
index 1dfda834ee0..e7594142331 100644
--- a/docs/source/reference/config.rst
+++ b/docs/source/reference/config.rst
@@ -109,7 +109,7 @@ Available fields and semantics:
     # permission to create a security group.
     security_group_name: my-security-group
 
-    # Identity to use for all AWS instances (optional).
+    # Identity to use for AWS instances (optional).
     #
     # LOCAL_CREDENTIALS: The user's local credential files will be uploaded to
     # AWS instances created by SkyPilot. They are used for accessing cloud
@@ -120,6 +120,21 @@ Available fields and semantics:
     # instances. SkyPilot will auto-create and reuse a service account (IAM
     # role) for AWS instances.
     #
+    # Customized service account (IAM role): <string> or <list of single-element dict>
+    # - <string>: apply the service account with the specified name to all instances.
+    #   Example:
+    #     remote_identity: my-service-account-name
+    # - <list of single-element dict>: A list of single-element dict mapping from the cluster name (pattern)
+    #   to the service account name to use. The matching of the cluster name is done in the same order
+    #   as the list.
+    #   NOTE: If none of the wildcard expressions in the dict match the cluster name, LOCAL_CREDENTIALS will be used.
+    #   To specify your default, use "*" as the wildcard expression.
+    #   Example:
+    #     remote_identity:
+    #       - my-cluster-name: my-service-account-1
+    #       - sky-serve-controller-*: my-service-account-2
+    #       - "*": my-default-service-account
+    #
     # Two caveats of SERVICE_ACCOUNT for multicloud users:
     #
     # - This only affects AWS instances. Local AWS credentials will still be
@@ -190,21 +205,21 @@ Available fields and semantics:
     # Reserved capacity (optional).
-    # 
+    #
     # Whether to prioritize reserved instance types/locations (considered as 0
     # cost) in the optimizer.
-    # 
+    #
     # If you have "automatically consumed" reservations in your GCP project:
     # Setting this to true guarantees the optimizer will pick any matching
     # reservation and GCP will auto consume your reservation, and setting to
     # false means optimizer uses regular, non-zero pricing in optimization (if
     # by chance any matching reservation is selected, GCP still auto consumes
     # the reservation).
-    # 
+    #
     # If you have "specifically targeted" reservations (set by the
     # `specific_reservations` field below): This field will automatically be set
     # to true.
-    # 
+    #
     # Default: false.
     prioritize_reservations: false
     #
@@ -283,6 +298,30 @@ Available fields and semantics:
     # Default: loadbalancer
     ports: loadbalancer
+
+    # Identity to use for all Kubernetes pods (optional).
+    #
+    # LOCAL_CREDENTIALS: The user's local ~/.kube/config will be uploaded to the
+    # Kubernetes pods created by SkyPilot. They are used for authenticating with
+    # the Kubernetes API server and launching new pods (e.g., for
+    # spot/serve controllers).
+    #
+    # SERVICE_ACCOUNT: Local ~/.kube/config is not uploaded to Kubernetes pods.
+    # SkyPilot will auto-create and reuse a service account with necessary roles
+    # in the user's namespace.
+    #
+    # <string>: The name of a service account to use for all Kubernetes pods.
+    # This service account must exist in the user's namespace and have all
+    # necessary permissions. Refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html
+    # for details on the roles required by the service account.
+    #
+    # Using SERVICE_ACCOUNT or a custom service account only affects Kubernetes
+    # instances. Local ~/.kube/config will still be uploaded to non-Kubernetes
+    # instances (e.g., a serve controller on GCP or AWS may need to provision
+    # Kubernetes resources).
+    #
+    # Default: 'LOCAL_CREDENTIALS'.
+    remote_identity: my-k8s-service-account
 
     # Attach custom metadata to Kubernetes objects created by SkyPilot
     #
     # Uses the same schema as Kubernetes metadata object: https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.26/#objectmeta-v1-meta
@@ -313,6 +352,25 @@ Available fields and semantics:
     # Default: 10 seconds
     provision_timeout: 10
+
+    # Autoscaler configured in the Kubernetes cluster (optional)
+    #
+    # This field informs SkyPilot about the cluster autoscaler used in the
+    # Kubernetes cluster. Setting this field disables pre-launch checks for
+    # GPU capacity in the cluster and SkyPilot relies on the autoscaler to
+    # provision nodes with the required GPU capacity.
+    #
+    # Remember to set provision_timeout accordingly when using an autoscaler.
+    #
+    # Supported values: gke, karpenter, generic
+    #   gke: uses cloud.google.com/gke-accelerator label to identify GPUs on nodes
+    #   karpenter: uses karpenter.k8s.aws/instance-gpu-name label to identify GPUs on nodes
+    #   generic: uses skypilot.co/accelerator labels to identify GPUs on nodes
+    # Refer to https://skypilot.readthedocs.io/en/latest/reference/kubernetes/kubernetes-setup.html#setting-up-gpu-support
+    # for more details on setting up labels for GPU support.
+    #
+    # Default: null (no autoscaler, autodetect label format for GPU nodes)
+    autoscaler: gke
 
     # Additional fields to override the pod fields used by SkyPilot (optional)
     #
     # Any key:value pairs added here would get added to the pod spec used to
diff --git a/docs/source/reference/kubernetes/kubernetes-setup.rst b/docs/source/reference/kubernetes/kubernetes-setup.rst
index fca4d539327..3ed1b8c89f0 100644
--- a/docs/source/reference/kubernetes/kubernetes-setup.rst
+++ b/docs/source/reference/kubernetes/kubernetes-setup.rst
@@ -382,7 +382,7 @@ To use this mode:
        # ingress-nginx-controller   LoadBalancer   10.24.4.254   35.202.58.117   80:31253/TCP,443:32699/TCP
 
 .. note::
-    If the ``EXTERNAL-IP`` field is ``<none>``, you must manually assign an External IP.
+    If the ``EXTERNAL-IP`` field is ``<none>``, you may manually assign it an External IP.
     This can be done by patching the service with an IP that can be accessed from outside the cluster.
     If the service type is ``NodePort``, you can set the ``EXTERNAL-IP`` to any node's IP address:
 
@@ -395,6 +395,22 @@ To use this mode:
     If the ``EXTERNAL-IP`` field is left as ``<none>``, SkyPilot will use ``localhost`` as the external IP
     for the Ingress, and the endpoint may not be accessible from outside the cluster.
 
+.. note::
+    If you cannot update the ``EXTERNAL-IP`` field of the service, you can also
+    specify the Ingress IP or hostname through the ``skypilot.co/external-ip``
+    annotation on the ``ingress-nginx-controller`` service. In this case,
+    having a valid ``EXTERNAL-IP`` field is not required.
+
+    For example, if your ``ingress-nginx-controller`` service is ``NodePort``:
+
+    .. code-block:: bash
+
+        # Add skypilot.co/external-ip annotation to the nginx ingress service.
+        # Replace <IP> in the following command with the IP you select.
+        # Can be any node's IP if using NodePort service type.
+        $ kubectl annotate service ingress-nginx-controller skypilot.co/external-ip=<IP> -n ingress-nginx
+
+
 3. Update the :ref:`SkyPilot config <config-yaml>` at :code:`~/.sky/config` to use the ingress mode.
 
 .. code-block:: yaml
diff --git a/docs/source/serving/sky-serve.rst b/docs/source/serving/sky-serve.rst
index c7425e28c14..3ccbed140c0 100644
--- a/docs/source/serving/sky-serve.rst
+++ b/docs/source/serving/sky-serve.rst
@@ -302,11 +302,12 @@ Let's bring up a real LLM chat service with FastChat + Vicuna. We'll use the `Vi
       conda activate chatbot
 
       echo 'Starting controller...'
-      python -u -m fastchat.serve.controller > ~/controller.log 2>&1 &
+      python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 &
       sleep 10
       echo 'Starting model worker...'
       python -u -m fastchat.serve.model_worker \
         --model-path lmsys/vicuna-${MODEL_SIZE}b-v1.3 2>&1 \
+        --host 127.0.0.1 \
         | tee model_worker.log &
 
       echo 'Waiting for model worker to start...'
diff --git a/examples/serve/gorilla/gorilla.yaml b/examples/serve/gorilla/gorilla.yaml
index ee46aa94568..e3072d816fb 100644
--- a/examples/serve/gorilla/gorilla.yaml
+++ b/examples/serve/gorilla/gorilla.yaml
@@ -35,11 +35,12 @@ run: |
   conda activate chatbot
 
   echo 'Starting controller...'
-  python -u -m fastchat.serve.controller > ~/controller.log 2>&1 &
+  python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 &
   sleep 10
   echo 'Starting model worker...'
   python -u -m fastchat.serve.model_worker \
     --model-path gorilla-llm/gorilla-falcon-7b-hf-v0 2>&1 \
+    --host 127.0.0.1 \
     | tee model_worker.log &
 
   echo 'Waiting for model worker to start...'
diff --git a/examples/serve/vicuna-v1.5.yaml b/examples/serve/vicuna-v1.5.yaml
index c94115ea3d7..0f659e85697 100644
--- a/examples/serve/vicuna-v1.5.yaml
+++ b/examples/serve/vicuna-v1.5.yaml
@@ -34,11 +34,12 @@ run: |
   conda activate chatbot
 
   echo 'Starting controller...'
-  python -u -m fastchat.serve.controller > ~/controller.log 2>&1 &
+  python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 &
   sleep 10
   echo 'Starting model worker...'
   python -u -m fastchat.serve.model_worker \
     --model-path lmsys/vicuna-${MODEL_SIZE}b-v1.5 2>&1 \
+    --host 127.0.0.1 \
     | tee model_worker.log &
 
   echo 'Waiting for model worker to start...'
diff --git a/llm/llama-2/chatbot-hf.yaml b/llm/llama-2/chatbot-hf.yaml
index 4c0132e4dd4..992c01346e6 100644
--- a/llm/llama-2/chatbot-hf.yaml
+++ b/llm/llama-2/chatbot-hf.yaml
@@ -24,12 +24,13 @@ run: |
   conda activate chatbot
 
   echo 'Starting controller...'
- python -u -m fastchat.serve.controller > ~/controller.log 2>&1 & + python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 & sleep 10 echo 'Starting model worker...' python -u -m fastchat.serve.model_worker \ --model-path meta-llama/Llama-2-${MODEL_SIZE}b-chat-hf \ --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE 2>&1 \ + --host 127.0.0.1 \ | tee model_worker.log & echo 'Waiting for model worker to start...' diff --git a/llm/vicuna-llama-2/serve.yaml b/llm/vicuna-llama-2/serve.yaml index 0a98dab5d26..69f89f2fc28 100644 --- a/llm/vicuna-llama-2/serve.yaml +++ b/llm/vicuna-llama-2/serve.yaml @@ -27,11 +27,12 @@ run: | conda activate chatbot echo 'Starting controller...' - python -u -m fastchat.serve.controller > ~/controller.log 2>&1 & + python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 & sleep 10 echo 'Starting model worker...' python -u -m fastchat.serve.model_worker \ --model-path /skypilot-vicuna 2>&1 \ + --host 127.0.0.1 \ | tee model_worker.log & echo 'Waiting for model worker to start...' diff --git a/llm/vicuna/serve-openai-api-endpoint.yaml b/llm/vicuna/serve-openai-api-endpoint.yaml index 247043ee3c2..639dfadc6d6 100644 --- a/llm/vicuna/serve-openai-api-endpoint.yaml +++ b/llm/vicuna/serve-openai-api-endpoint.yaml @@ -19,11 +19,12 @@ run: | conda activate chatbot echo 'Starting controller...' - python -u -m fastchat.serve.controller > ~/controller.log 2>&1 & + python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 & sleep 10 echo 'Starting model worker...' python -u -m fastchat.serve.model_worker \ --model-path lmsys/vicuna-${MODEL_SIZE}b-v1.3 2>&1 \ + --host 127.0.0.1 \ | tee model_worker.log & echo 'Waiting for model worker to start...' diff --git a/llm/vicuna/serve.yaml b/llm/vicuna/serve.yaml index d458112a42f..49185fcea20 100644 --- a/llm/vicuna/serve.yaml +++ b/llm/vicuna/serve.yaml @@ -19,11 +19,12 @@ run: | conda activate chatbot echo 'Starting controller...' - python -u -m fastchat.serve.controller > ~/controller.log 2>&1 & + python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 & sleep 10 echo 'Starting model worker...' python -u -m fastchat.serve.model_worker \ --model-path lmsys/vicuna-${MODEL_SIZE}b-v1.3 2>&1 \ + --host 127.0.0.1 \ | tee model_worker.log & echo 'Waiting for model worker to start...' diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py index ce6f93a8905..f4c84d3f578 100644 --- a/sky/adaptors/kubernetes.py +++ b/sky/adaptors/kubernetes.py @@ -55,9 +55,9 @@ def _load_config(): ' If you were running a local Kubernetes ' 'cluster, run `sky local up` to start the cluster.') else: - err_str = ( - 'Failed to load Kubernetes configuration. ' - f'Please check if your kubeconfig file is valid.{suffix}') + err_str = ('Failed to load Kubernetes configuration. ' + 'Please check if your kubeconfig file exists at ' + f'~/.kube/config and is valid.{suffix}') err_str += '\nTo disable Kubernetes for SkyPilot: run `sky check`.' with ux_utils.print_exception_no_traceback(): raise ValueError(err_str) from None diff --git a/sky/authentication.py b/sky/authentication.py index 581fdc12c7f..966dad670c5 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -408,7 +408,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: # Add the user's public key to the SkyPilot cluster. 
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH) secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME - secret_field_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_FIELD_NAME + secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name namespace = kubernetes_utils.get_current_kube_config_context_namespace() k8s = kubernetes.kubernetes with open(public_key_path, 'r', encoding='utf-8') as f: diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index fecbcaad0b8..5a9663b1275 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1,6 +1,7 @@ """Util constants/functions for the backends.""" from datetime import datetime import enum +import fnmatch import functools import os import pathlib @@ -47,7 +48,9 @@ from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import env_options +from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import schemas from sky.utils import subprocess_utils from sky.utils import timeline from sky.utils import ux_utils @@ -108,6 +111,9 @@ # Remote dir that holds our runtime files. _REMOTE_RUNTIME_FILES_DIR = '~/.sky/.runtime_files' +_ENDPOINTS_RETRY_MESSAGE = ('If the cluster was recently started, ' + 'please retry after a while.') + # Include the fields that will be used for generating tags that distinguishes # the cluster in ray, to avoid the stopped cluster being discarded due to # updates in the yaml template. @@ -327,7 +333,7 @@ def wrap_file_mount(cls, path: str) -> str: def make_safe_symlink_command(cls, *, source: str, target: str) -> str: """Returns a command that safely symlinks 'source' to 'target'. - All intermediate directories of 'source' will be owned by $USER, + All intermediate directories of 'source' will be owned by $(whoami), excluding the root directory (/). 'source' must be an absolute path; both 'source' and 'target' must not @@ -354,17 +360,17 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str: target) # Below, use sudo in case the symlink needs sudo access to create. # Prepare to create the symlink: - # 1. make sure its dir(s) exist & are owned by $USER. + # 1. make sure its dir(s) exist & are owned by $(whoami). dir_of_symlink = os.path.dirname(source) commands = [ # mkdir, then loop over '/a/b/c' as /a, /a/b, /a/b/c. For each, - # chown $USER on it so user can use these intermediate dirs + # chown $(whoami) on it so user can use these intermediate dirs # (excluding /). f'sudo mkdir -p {dir_of_symlink}', # p: path so far ('(p=""; ' f'for w in $(echo {dir_of_symlink} | tr "/" " "); do ' - 'p=${p}/${w}; sudo chown $USER $p; done)') + 'p=${p}/${w}; sudo chown $(whoami) $p; done)') ] # 2. remove any existing symlink (ln -f may throw 'cannot # overwrite directory', if the link exists and points to a @@ -380,7 +386,7 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str: # Link. f'sudo ln -s {target} {source}', # chown. -h to affect symlinks only. 
- f'sudo chown -h $USER {source}', + f'sudo chown -h $(whoami) {source}', ] return ' && '.join(commands) @@ -797,8 +803,14 @@ def write_cluster_config( assert cluster_name is not None excluded_clouds = [] remote_identity = skypilot_config.get_nested( - (str(cloud).lower(), 'remote_identity'), 'LOCAL_CREDENTIALS') - if remote_identity == 'SERVICE_ACCOUNT': + (str(cloud).lower(), 'remote_identity'), + schemas.REMOTE_IDENTITY_DEFAULT) + if remote_identity is not None and not isinstance(remote_identity, str): + for profile in remote_identity: + if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): + remote_identity = list(profile.values())[0] + break + if remote_identity != schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value: if not cloud.supports_service_account_on_remote(): raise exceptions.InvalidCloudConfigs( 'remote_identity: SERVICE_ACCOUNT is specified in ' @@ -888,6 +900,8 @@ def write_cluster_config( # User-supplied labels. 'labels': labels, + # User-supplied remote_identity + 'remote_identity': remote_identity, # The reservation pools that specified by the user. This is # currently only used by GCP. 'specific_reservations': specific_reservations, @@ -1066,7 +1080,7 @@ def get_ready_nodes_counts(pattern, output): def get_docker_user(ip: str, cluster_config_file: str) -> str: """Find docker container username.""" ssh_credentials = ssh_credential_from_yaml(cluster_config_file) - runner = command_runner.SSHCommandRunner(ip, port=22, **ssh_credentials) + runner = command_runner.SSHCommandRunner(node=(ip, 22), **ssh_credentials) container_name = constants.DEFAULT_DOCKER_CONTAINER_NAME whoami_returncode, whoami_stdout, whoami_stderr = runner.run( f'sudo docker exec {container_name} whoami', @@ -1099,7 +1113,7 @@ def wait_until_ray_cluster_ready( try: head_ip = _query_head_ip_with_retries( cluster_config_file, max_attempts=WAIT_HEAD_NODE_IP_MAX_ATTEMPTS) - except exceptions.FetchIPError as e: + except exceptions.FetchClusterInfoError as e: logger.error(common_utils.format_exception(e)) return False, None # failed @@ -1115,8 +1129,7 @@ def wait_until_ray_cluster_ready( ssh_credentials = ssh_credential_from_yaml(cluster_config_file, docker_user) last_nodes_so_far = 0 start = time.time() - runner = command_runner.SSHCommandRunner(head_ip, - port=22, + runner = command_runner.SSHCommandRunner(node=(head_ip, 22), **ssh_credentials) with rich_utils.safe_status( '[bold cyan]Waiting for workers...') as worker_status: @@ -1222,7 +1235,7 @@ def ssh_credential_from_yaml( def parallel_data_transfer_to_nodes( - runners: List[command_runner.SSHCommandRunner], + runners: List[command_runner.CommandRunner], source: Optional[str], target: str, cmd: Optional[str], @@ -1232,32 +1245,36 @@ def parallel_data_transfer_to_nodes( # Advanced options. log_path: str = os.devnull, stream_logs: bool = False, + source_bashrc: bool = False, ): """Runs a command on all nodes and optionally runs rsync from src->dst. Args: - runners: A list of SSHCommandRunner objects that represent multiple nodes. + runners: A list of CommandRunner objects that represent multiple nodes. source: Optional[str]; Source for rsync on local node target: str; Destination on remote node for rsync cmd: str; Command to be executed on all nodes action_message: str; Message to be printed while the command runs log_path: str; Path to the log file stream_logs: bool; Whether to stream logs to stdout + source_bashrc: bool; Source bashrc before running the command. 
""" fore = colorama.Fore style = colorama.Style origin_source = source - def _sync_node(runner: 'command_runner.SSHCommandRunner') -> None: + def _sync_node(runner: 'command_runner.CommandRunner') -> None: if cmd is not None: rc, stdout, stderr = runner.run(cmd, log_path=log_path, stream_logs=stream_logs, - require_outputs=True) + require_outputs=True, + source_bashrc=source_bashrc) err_msg = ('Failed to run command before rsync ' f'{origin_source} -> {target}. ' - 'Ensure that the network is stable, then retry.') + 'Ensure that the network is stable, then retry. ' + f'{cmd}') if log_path != os.devnull: err_msg += f' See logs in {log_path}' subprocess_utils.handle_returncode(rc, @@ -1322,7 +1339,7 @@ def _query_head_ip_with_retries(cluster_yaml: str, """Returns the IP of the head node by querying the cloud. Raises: - exceptions.FetchIPError: if we failed to get the head IP. + exceptions.FetchClusterInfoError: if we failed to get the head IP. """ backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) for i in range(max_attempts): @@ -1351,8 +1368,8 @@ def _query_head_ip_with_retries(cluster_yaml: str, break except subprocess.CalledProcessError as e: if i == max_attempts - 1: - raise exceptions.FetchIPError( - reason=exceptions.FetchIPError.Reason.HEAD) from e + raise exceptions.FetchClusterInfoError( + reason=exceptions.FetchClusterInfoError.Reason.HEAD) from e # Retry if the cluster is not up yet. logger.debug('Retrying to get head ip.') time.sleep(backoff.current_backoff()) @@ -1377,7 +1394,7 @@ def get_node_ips(cluster_yaml: str, IPs. Raises: - exceptions.FetchIPError: if we failed to get the IPs. e.reason is + exceptions.FetchClusterInfoError: if we failed to get the IPs. e.reason is HEAD or WORKER. """ ray_config = common_utils.read_yaml(cluster_yaml) @@ -1398,11 +1415,12 @@ def get_node_ips(cluster_yaml: str, 'Failed to get cluster info for ' f'{ray_config["cluster_name"]} from the new provisioner ' f'with {common_utils.format_exception(e)}.') - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.HEAD) from e + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) from e if len(metadata.instances) < expected_num_nodes: # Simulate the exception when Ray head node is not up. - raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD) + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) return metadata.get_feasible_ips(get_internal_ips) if get_internal_ips: @@ -1432,8 +1450,8 @@ def get_node_ips(cluster_yaml: str, break except subprocess.CalledProcessError as e: if retry_cnt == worker_ip_max_attempts - 1: - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.WORKER) from e + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.WORKER) from e # Retry if the ssh is not ready for the workers yet. 
backoff_time = backoff.current_backoff() logger.debug('Retrying to get worker ip ' @@ -1458,8 +1476,8 @@ def get_node_ips(cluster_yaml: str, f'detected IP(s): {worker_ips[-n:]}.') worker_ips = worker_ips[-n:] else: - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.WORKER) + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.WORKER) else: worker_ips = [] return head_ip_list + worker_ips @@ -1523,7 +1541,7 @@ def check_owner_identity(cluster_name: str) -> None: for i, (owner, current) in enumerate(zip(owner_identity, current_user_identity)): - # Clean up the owner identiy for the backslash and newlines, caused + # Clean up the owner identity for the backslash and newlines, caused # by the cloud CLI output, e.g. gcloud. owner = owner.replace('\n', '').replace('\\', '') if owner == current: @@ -1746,14 +1764,11 @@ def _update_cluster_status_no_lock( def run_ray_status_to_check_ray_cluster_healthy() -> bool: try: - # TODO(zhwu): This function cannot distinguish transient network - # error in ray's get IPs vs. ray runtime failing. - # NOTE: fetching the IPs is very slow as it calls into # `ray get head-ip/worker-ips`. Using cached IPs is safe because # in the worst case we time out in the `ray status` SSH command # below. - external_ips = handle.cached_external_ips + runners = handle.get_command_runners(force_cached=True) # This happens when user interrupt the `sky launch` process before # the first time resources handle is written back to local database. # This is helpful when user interrupt after the provision is done @@ -1761,27 +1776,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: # helps keep the cluster status to INIT after `sky status -r`, so # user will be notified that any auto stop/down might not be # triggered. - if external_ips is None or len(external_ips) == 0: + if not runners: logger.debug(f'Refreshing status ({cluster_name!r}): No cached ' f'IPs found. Handle: {handle}') - raise exceptions.FetchIPError( - reason=exceptions.FetchIPError.Reason.HEAD) - - # Potentially refresh the external SSH ports, in case the existing - # cluster before #2491 was launched without external SSH ports - # cached. - external_ssh_ports = handle.external_ssh_ports() - head_ssh_port = external_ssh_ports[0] - - # Check if ray cluster status is healthy. 
- ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml, - handle.docker_user, - handle.ssh_user) - - runner = command_runner.SSHCommandRunner(external_ips[0], - **ssh_credentials, - port=head_ssh_port) - rc, output, stderr = runner.run( + raise exceptions.FetchClusterInfoError( + reason=exceptions.FetchClusterInfoError.Reason.HEAD) + head_runner = runners[0] + rc, output, stderr = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, stream_logs=False, require_outputs=True, @@ -1801,7 +1802,7 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: f'Refreshing status ({cluster_name!r}): ray status not showing ' f'all nodes ({ready_head + ready_workers}/' f'{total_nodes}); output: {output}; stderr: {stderr}') - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: logger.debug( f'Refreshing status ({cluster_name!r}) failed to get IPs.') except RuntimeError as e: @@ -2342,9 +2343,9 @@ def is_controller_accessible( handle.docker_user, handle.ssh_user) - runner = command_runner.SSHCommandRunner(handle.head_ip, - **ssh_credentials, - port=handle.head_ssh_port) + runner = command_runner.SSHCommandRunner(node=(handle.head_ip, + handle.head_ssh_port), + **ssh_credentials) if not runner.check_connection(): error_msg = controller.value.connection_error_hint else: @@ -2704,3 +2705,115 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str, f'not interrupted): {colorama.Style.BRIGHT}sky start -f -y ' f'{cluster_name}{colorama.Style.RESET_ALL}' f'\n--- Details ---\n{stderr.strip()}\n') + + +def get_endpoints(cluster: str, + port: Optional[Union[int, str]] = None, + skip_status_check: bool = False) -> Dict[int, str]: + """Gets the endpoint for a given cluster and port number (endpoint). + + Args: + cluster: The name of the cluster. + port: The port number to get the endpoint for. If None, endpoints + for all ports are returned. + skip_status_check: Whether to skip the status check for the cluster. + This is useful when the cluster is known to be in a INIT state + and the caller wants to query the endpoints. Used by serve + controller to query endpoints during cluster launch when multiple + services may be getting launched in parallel (and as a result, + the controller may be in INIT status due to a concurrent launch). + + Returns: A dictionary of port numbers to endpoints. If endpoint is None, + the dictionary will contain all ports:endpoints exposed on the cluster. + If the endpoint is not exposed yet (e.g., during cluster launch or + waiting for cloud provider to expose the endpoint), an empty dictionary + is returned. + + Raises: + ValueError: if the port is invalid or the cloud provider does not + support querying endpoints. + exceptions.ClusterNotUpError: if the cluster is not in UP status. 
+ """ + # Cast endpoint to int if it is not None + if port is not None: + try: + port = int(port) + except ValueError: + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Invalid endpoint {port!r}.') from None + cluster_records = get_clusters(include_controller=True, + refresh=False, + cluster_names=[cluster]) + cluster_record = cluster_records[0] + if (not skip_status_check and + cluster_record['status'] != status_lib.ClusterStatus.UP): + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterNotUpError( + f'Cluster {cluster_record["name"]!r} ' + 'is not in UP status.', cluster_record['status']) + handle = cluster_record['handle'] + if not isinstance(handle, backends.CloudVmRayResourceHandle): + with ux_utils.print_exception_no_traceback(): + raise ValueError('Querying IP address is not supported ' + f'for cluster {cluster!r} with backend ' + f'{get_backend_from_handle(handle).NAME}.') + + launched_resources = handle.launched_resources + cloud = launched_resources.cloud + try: + cloud.check_features_are_supported( + launched_resources, {clouds.CloudImplementationFeatures.OPEN_PORTS}) + except exceptions.NotSupportedError: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Querying endpoints is not supported ' + f'for cluster {cluster!r} on {cloud}.') from None + + config = common_utils.read_yaml(handle.cluster_yaml) + port_details = provision_lib.query_ports(repr(cloud), + handle.cluster_name_on_cloud, + handle.launched_resources.ports, + head_ip=handle.head_ip, + provider_config=config['provider']) + + # Validation before returning the endpoints + if port is not None: + # If the requested endpoint was not to be exposed + port_set = resources_utils.port_ranges_to_set( + handle.launched_resources.ports) + if port not in port_set: + logger.warning(f'Port {port} is not exposed on ' + f'cluster {cluster!r}.') + return {} + # If the user requested a specific port endpoint, check if it is exposed + if port not in port_details: + error_msg = (f'Port {port} not exposed yet. ' + f'{_ENDPOINTS_RETRY_MESSAGE} ') + if handle.launched_resources.cloud.is_same_cloud( + clouds.Kubernetes()): + # Add Kubernetes specific debugging info + error_msg += (kubernetes_utils.get_endpoint_debug_message()) + logger.warning(error_msg) + return {} + return {port: port_details[port][0].url()} + else: + if not port_details: + # If cluster had no ports to be exposed + if handle.launched_resources.ports is None: + logger.warning(f'Cluster {cluster!r} does not have any ' + 'ports to be exposed.') + return {} + # Else ports have not been exposed even though they exist. + # In this case, ask the user to retry. + else: + error_msg = (f'No endpoints exposed yet. 
' + f'{_ENDPOINTS_RETRY_MESSAGE} ') + if handle.launched_resources.cloud.is_same_cloud( + clouds.Kubernetes()): + # Add Kubernetes specific debugging info + error_msg += \ + kubernetes_utils.get_endpoint_debug_message() + logger.warning(error_msg) + return {} + return { + port_num: urls[0].url() for port_num, urls in port_details.items() + } diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index f916d931b5f..5f1930123b0 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2,6 +2,7 @@ import base64 import copy import enum +import functools import getpass import inspect import json @@ -1478,10 +1479,12 @@ def _retry_zones( if zones and len(zones) == 1: launched_resources = launched_resources.copy(zone=zones[0].name) - prev_cluster_ips, prev_ssh_ports = None, None + prev_cluster_ips, prev_ssh_ports, prev_cluster_info = (None, None, + None) if prev_handle is not None: prev_cluster_ips = prev_handle.stable_internal_external_ips prev_ssh_ports = prev_handle.stable_ssh_ports + prev_cluster_info = prev_handle.cached_cluster_info # Record early, so if anything goes wrong, 'sky status' will show # the cluster name and users can appropriately 'sky down'. It also # means a second 'sky launch -c ' will attempt to reuse. @@ -1500,7 +1503,9 @@ def _retry_zones( # optimize the case where the cluster is restarted, i.e., no # need to query IPs and ports from the cloud provider. stable_internal_external_ips=prev_cluster_ips, - stable_ssh_ports=prev_ssh_ports) + stable_ssh_ports=prev_ssh_ports, + cluster_info=prev_cluster_info, + ) usage_lib.messages.usage.update_final_cluster_status( status_lib.ClusterStatus.INIT) @@ -1581,14 +1586,14 @@ def _retry_zones( # manually or by the cloud provider. # Optimize the case where the cluster's head IPs can be parsed # from the output of 'ray up'. - kwargs = {} if handle.launched_nodes == 1: - kwargs = { - 'internal_ips': [head_internal_ip], - 'external_ips': [head_external_ip] - } - handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS, - **kwargs) + handle.update_cluster_ips( + max_attempts=_FETCH_IP_MAX_ATTEMPTS, + internal_ips=[head_internal_ip], + external_ips=[head_external_ip]) + else: + handle.update_cluster_ips( + max_attempts=_FETCH_IP_MAX_ATTEMPTS) handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS) if cluster_exists: # Guard against the case where there's an existing cluster @@ -1991,9 +1996,20 @@ def provision_with_retries( cloud_user = None else: cloud_user = to_provision.cloud.get_current_user_identity() + + requested_features = self._requested_features.copy() + # Skip stop feature for Kubernetes jobs controller. + if (isinstance(to_provision.cloud, clouds.Kubernetes) and + controller_utils.Controllers.from_name(cluster_name) + == controller_utils.Controllers.JOBS_CONTROLLER): + assert (clouds.CloudImplementationFeatures.STOP + in requested_features), requested_features + requested_features.remove( + clouds.CloudImplementationFeatures.STOP) + # Skip if to_provision.cloud does not support requested features to_provision.cloud.check_features_are_supported( - to_provision, self._requested_features) + to_provision, requested_features) config_dict = self._retry_zones( to_provision, @@ -2114,7 +2130,7 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle): """ # Bump if any fields get added/removed/changed, and add backward # compaitibility logic in __setstate__. 
- _VERSION = 7 + _VERSION = 8 def __init__( self, @@ -2127,6 +2143,7 @@ def __init__( stable_internal_external_ips: Optional[List[Tuple[str, str]]] = None, stable_ssh_ports: Optional[List[int]] = None, + cluster_info: Optional[provision_common.ClusterInfo] = None, # The following 2 fields are deprecated. SkyPilot new provisioner # API handles the TPU node creation/deletion. # Backward compatibility for TPU nodes created before #2943. @@ -2143,10 +2160,10 @@ def __init__( # internal or external ips, depending on the use_internal_ips flag. self.stable_internal_external_ips = stable_internal_external_ips self.stable_ssh_ports = stable_ssh_ports + self.cached_cluster_info = cluster_info self.launched_nodes = launched_nodes self.launched_resources = launched_resources self.docker_user: Optional[str] = None - self.ssh_user: Optional[str] = None # Deprecated. SkyPilot new provisioner API handles the TPU node # creation/deletion. # Backward compatibility for TPU nodes created before #2943. @@ -2210,13 +2227,8 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: Use this method to use any cloud-specific port fetching logic. """ del max_attempts # Unused. - if isinstance(self.launched_resources.cloud, clouds.RunPod): - cluster_info = provision_lib.get_cluster_info( - str(self.launched_resources.cloud).lower(), - region=self.launched_resources.region, - cluster_name_on_cloud=self.cluster_name_on_cloud, - provider_config=None) - self.stable_ssh_ports = cluster_info.get_ssh_ports() + if self.cached_cluster_info is not None: + self.stable_ssh_ports = self.cached_cluster_info.get_ssh_ports() return head_ssh_port = 22 @@ -2224,11 +2236,49 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: [head_ssh_port] + [22] * (self.num_ips_per_node * self.launched_nodes - 1)) + def _update_cluster_info(self): + # When a cluster is on a cloud that does not support the new + # provisioner, we should skip updating cluster_info. + if (self.launched_resources.cloud.PROVISIONER_VERSION >= + clouds.ProvisionerVersion.SKYPILOT): + provider_name = str(self.launched_resources.cloud).lower() + config = {} + if os.path.exists(self.cluster_yaml): + # It is possible that the cluster yaml is not available when + # the handle is unpickled for service replicas from the + # controller with older version. + config = common_utils.read_yaml(self.cluster_yaml) + try: + cluster_info = provision_lib.get_cluster_info( + provider_name, + region=self.launched_resources.region, + cluster_name_on_cloud=self.cluster_name_on_cloud, + provider_config=config.get('provider', None)) + except Exception as e: # pylint: disable=broad-except + # This could happen when the VM is not fully launched, and a + # user is trying to terminate it with `sky down`. 
+ logger.debug('Failed to get cluster info for ' + f'{self.cluster_name} from the new provisioner ' + f'with {common_utils.format_exception(e)}.') + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) from e + if cluster_info.num_instances != self.launched_nodes: + logger.debug( + f'Available nodes in the cluster {self.cluster_name} ' + 'do not match the number of nodes requested (' + f'{cluster_info.num_instances} != ' + f'{self.launched_nodes}).') + raise exceptions.FetchClusterInfoError( + exceptions.FetchClusterInfoError.Reason.HEAD) + self.cached_cluster_info = cluster_info + def update_cluster_ips( self, max_attempts: int = 1, internal_ips: Optional[List[Optional[str]]] = None, - external_ips: Optional[List[Optional[str]]] = None) -> None: + external_ips: Optional[List[Optional[str]]] = None, + cluster_info: Optional[provision_common.ClusterInfo] = None + ) -> None: """Updates the cluster IPs cached in the handle. We cache the cluster IPs in the handle to avoid having to retrieve @@ -2254,61 +2304,74 @@ def update_cluster_ips( external IPs from the cloud provider. Raises: - exceptions.FetchIPError: if we failed to get the IPs. e.reason is - HEAD or WORKER. + exceptions.FetchClusterInfoError: if we failed to get the cluster + infos. e.reason is HEAD or WORKER. """ - - def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: - return (ips is not None and - len(ips) == self.num_ips_per_node * self.launched_nodes and - all(ip is not None for ip in ips)) - - use_internal_ips = self._use_internal_ips() - - # cluster_feasible_ips is the list of IPs of the nodes in the cluster - # which can be used to connect to the cluster. It is a list of external - # IPs if the cluster is assigned public IPs, otherwise it is a list of - # internal IPs. - cluster_feasible_ips: List[str] - if is_provided_ips_valid(external_ips): - logger.debug(f'Using provided external IPs: {external_ips}') - cluster_feasible_ips = typing.cast(List[str], external_ips) + if cluster_info is not None: + self.cached_cluster_info = cluster_info + use_internal_ips = self._use_internal_ips() + cluster_feasible_ips = self.cached_cluster_info.get_feasible_ips( + use_internal_ips) + cluster_internal_ips = self.cached_cluster_info.get_feasible_ips( + force_internal_ips=True) else: - cluster_feasible_ips = backend_utils.get_node_ips( - self.cluster_yaml, - self.launched_nodes, - head_ip_max_attempts=max_attempts, - worker_ip_max_attempts=max_attempts, - get_internal_ips=use_internal_ips) - - if self.cached_external_ips == cluster_feasible_ips: - logger.debug('Skipping the fetching of internal IPs as the cached ' - 'external IPs matches the newly fetched ones.') - # Optimization: If the cached external IPs are the same as the - # retrieved feasible IPs, then we can skip retrieving internal - # IPs since the cached IPs are up-to-date. - return - logger.debug( - 'Cached external IPs do not match with the newly fetched ones: ' - f'cached ({self.cached_external_ips}), new ({cluster_feasible_ips})' - ) + # For clouds that do not support the SkyPilot Provisioner API. 
+ # TODO(zhwu): once all the clouds are migrated to SkyPilot + # Provisioner API, we should remove this else block + def is_provided_ips_valid( + ips: Optional[List[Optional[str]]]) -> bool: + return (ips is not None and len(ips) + == self.num_ips_per_node * self.launched_nodes and + all(ip is not None for ip in ips)) + + use_internal_ips = self._use_internal_ips() + + # cluster_feasible_ips is the list of IPs of the nodes in the + # cluster which can be used to connect to the cluster. It is a list + # of external IPs if the cluster is assigned public IPs, otherwise + # it is a list of internal IPs. + if is_provided_ips_valid(external_ips): + logger.debug(f'Using provided external IPs: {external_ips}') + cluster_feasible_ips = typing.cast(List[str], external_ips) + else: + cluster_feasible_ips = backend_utils.get_node_ips( + self.cluster_yaml, + self.launched_nodes, + head_ip_max_attempts=max_attempts, + worker_ip_max_attempts=max_attempts, + get_internal_ips=use_internal_ips) + + if self.cached_external_ips == cluster_feasible_ips: + logger.debug( + 'Skipping the fetching of internal IPs as the cached ' + 'external IPs matches the newly fetched ones.') + # Optimization: If the cached external IPs are the same as the + # retrieved feasible IPs, then we can skip retrieving internal + # IPs since the cached IPs are up-to-date. + return - if use_internal_ips: - # Optimization: if we know use_internal_ips is True (currently - # only exposed for AWS and GCP), then our provisioner is guaranteed - # to not assign public IPs, thus the first list of IPs returned - # above are already private IPs. So skip the second query. - cluster_internal_ips = list(cluster_feasible_ips) - elif is_provided_ips_valid(internal_ips): - logger.debug(f'Using provided internal IPs: {internal_ips}') - cluster_internal_ips = typing.cast(List[str], internal_ips) - else: - cluster_internal_ips = backend_utils.get_node_ips( - self.cluster_yaml, - self.launched_nodes, - head_ip_max_attempts=max_attempts, - worker_ip_max_attempts=max_attempts, - get_internal_ips=True) + logger.debug( + 'Cached external IPs do not match with the newly fetched ones: ' + f'cached ({self.cached_external_ips}), new ' + f'({cluster_feasible_ips})') + + if use_internal_ips: + # Optimization: if we know use_internal_ips is True (currently + # only exposed for AWS and GCP), then our provisioner is + # guaranteed to not assign public IPs, thus the first list of + # IPs returned above are already private IPs. So skip the second + # query. 
+ cluster_internal_ips = list(cluster_feasible_ips) + elif is_provided_ips_valid(internal_ips): + logger.debug(f'Using provided internal IPs: {internal_ips}') + cluster_internal_ips = typing.cast(List[str], internal_ips) + else: + cluster_internal_ips = backend_utils.get_node_ips( + self.cluster_yaml, + self.launched_nodes, + head_ip_max_attempts=max_attempts, + worker_ip_max_attempts=max_attempts, + get_internal_ips=True) assert len(cluster_feasible_ips) == len(cluster_internal_ips), ( f'Cluster {self.cluster_name!r}:' @@ -2327,6 +2390,39 @@ def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: internal_external_ips[1:], key=lambda x: x[1]) self.stable_internal_external_ips = stable_internal_external_ips + @functools.lru_cache() + @timeline.event + def get_command_runners(self, + force_cached: bool = False, + avoid_ssh_control: bool = False + ) -> List[command_runner.CommandRunner]: + """Returns a list of command runners for the cluster.""" + ssh_credentials = backend_utils.ssh_credential_from_yaml( + self.cluster_yaml, self.docker_user, self.ssh_user) + if avoid_ssh_control: + ssh_credentials.pop('ssh_control_name', None) + if (clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR >= + self.launched_resources.cloud.PROVISIONER_VERSION): + ip_list = (self.cached_external_ips + if force_cached else self.external_ips()) + if ip_list is None: + return [] + # Potentially refresh the external SSH ports, in case the existing + # cluster before #2491 was launched without external SSH ports + # cached. + port_list = self.external_ssh_ports() + runners = command_runner.SSHCommandRunner.make_runner_list( + zip(ip_list, port_list), **ssh_credentials) + return runners + if self.cached_cluster_info is None: + assert not force_cached, 'cached_cluster_info is None.' + self._update_cluster_info() + assert self.cached_cluster_info is not None, self + runners = provision_lib.get_command_runners( + self.cached_cluster_info.provider_name, self.cached_cluster_info, + **ssh_credentials) + return runners + @property def cached_internal_ips(self) -> Optional[List[str]]: if self.stable_internal_external_ips is not None: @@ -2392,6 +2488,16 @@ def setup_docker_user(self, cluster_config_file: str): def cluster_yaml(self): return os.path.expanduser(self._cluster_yaml) + @property + def ssh_user(self): + if self.cached_cluster_info is not None: + # Overload ssh_user with the user stored in cluster_info, which is + # useful for kubernetes case, where the ssh_user can depend on the + # container image used. For those clusters launched with ray + # autoscaler, we directly use the ssh_user in yaml config. + return self.cached_cluster_info.ssh_user + return None + @property def head_ip(self): external_ips = self.cached_external_ips @@ -2437,8 +2543,8 @@ def __setstate__(self, state): if version < 6: state['cluster_name_on_cloud'] = state['cluster_name'] - if version < 7: - self.ssh_user = None + if version < 8: + self.cached_cluster_info = None self.__dict__.update(state) @@ -2448,7 +2554,7 @@ def __setstate__(self, state): if version < 3 and head_ip is not None: try: self.update_cluster_ips() - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: # This occurs when an old cluster from was autostopped, # so the head IP in the database is not updated. 
pass @@ -2457,6 +2563,14 @@ def __setstate__(self, state): self._update_cluster_region() + if version < 8: + try: + self._update_cluster_info() + except exceptions.FetchClusterInfoError: + # This occurs when an old cluster from was autostopped, + # so the head IP in the database is not updated. + pass + class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']): """Backend: runs on cloud virtual machines, managed by Ray. @@ -2748,22 +2862,17 @@ def _provision( provision_record=provision_record, custom_resource=resources_vars.get('custom_resources'), log_dir=self.log_dir) - # We must query the IPs from the cloud provider, when the - # provisioning is done, to make sure the cluster IPs are - # up-to-date. + # We use the IPs from the cluster_info to update_cluster_ips, + # when the provisioning is done, to make sure the cluster IPs + # are up-to-date. # The staled IPs may be caused by the node being restarted # manually or by the cloud provider. # Optimize the case where the cluster's IPs can be retrieved # from cluster_info. - internal_ips, external_ips = zip(*cluster_info.ip_tuples()) - if not cluster_info.has_external_ips(): - external_ips = internal_ips + handle.docker_user = cluster_info.docker_user handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS, - internal_ips=list(internal_ips), - external_ips=list(external_ips)) + cluster_info=cluster_info) handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS) - handle.docker_user = cluster_info.docker_user - handle.ssh_user = cluster_info.ssh_user # Update launched resources. handle.launched_resources = handle.launched_resources.copy( @@ -2795,11 +2904,7 @@ def _provision( handle.launched_resources.cloud.get_zone_shell_cmd()) # zone is None for Azure if get_zone_cmd is not None: - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, - handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=ssh_port_list, **ssh_credentials) + runners = handle.get_command_runners() def _get_zone(runner): retry_count = 0 @@ -2840,8 +2945,11 @@ def _get_zone(runner): logger.debug('Checking if skylet is running on the head node.') with rich_utils.safe_status( '[bold cyan]Preparing SkyPilot runtime'): + # We need to source bashrc for skylet to make sure the autostop + # event can access the path to the cloud CLIs. 
self.run_on_head(handle, - instance_setup.MAYBE_SKYLET_RESTART_CMD) + instance_setup.MAYBE_SKYLET_RESTART_CMD, + source_bashrc=True) self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, @@ -2940,7 +3048,6 @@ def _sync_workdir(self, handle: CloudVmRayResourceHandle, fore = colorama.Fore style = colorama.Style ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() assert ip_list is not None, 'external_ips is not cached in handle' full_workdir = os.path.abspath(os.path.expanduser(workdir)) @@ -2965,14 +3072,10 @@ def _sync_workdir(self, handle: CloudVmRayResourceHandle, log_path = os.path.join(self.log_dir, 'workdir_sync.log') - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - # TODO(zhwu): refactor this with backend_utils.parallel_cmd_with_rsync - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() - def _sync_workdir_node(runner: command_runner.SSHCommandRunner) -> None: + def _sync_workdir_node(runner: command_runner.CommandRunner) -> None: runner.rsync( source=workdir, target=SKY_REMOTE_WORKDIR, @@ -3019,29 +3122,18 @@ def _setup(self, handle: CloudVmRayResourceHandle, task: task_lib.Task, return setup = task.setup # Sync the setup script up and run it. - ip_list = handle.external_ips() internal_ips = handle.internal_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - # Disable connection sharing for setup script to avoid old - # connections being reused, which may cause stale ssh agent - # forwarding. - ssh_credentials.pop('ssh_control_name', None) - remote_setup_file_name = f'/tmp/sky_setup_{self.run_timestamp}' # Need this `-i` option to make sure `source ~/.bashrc` work setup_cmd = f'/bin/bash -i {remote_setup_file_name} 2>&1' + runners = handle.get_command_runners(avoid_ssh_control=True) def _setup_node(node_id: int) -> None: setup_envs = task.envs.copy() setup_envs.update(self._skypilot_predefined_env_vars(handle)) setup_envs['SKYPILOT_SETUP_NODE_IPS'] = '\n'.join(internal_ips) setup_envs['SKYPILOT_SETUP_NODE_RANK'] = str(node_id) - runner = command_runner.SSHCommandRunner(ip_list[node_id], - port=port_list[node_id], - **ssh_credentials) + runner = runners[node_id] setup_script = log_lib.make_task_bash_script(setup, env_vars=setup_envs) with tempfile.NamedTemporaryFile('w', prefix='sky_setup_') as f: @@ -3055,11 +3147,12 @@ def _setup_node(node_id: int) -> None: if detach_setup: return setup_log_path = os.path.join(self.log_dir, - f'setup-{runner.ip}.log') + f'setup-{runner.node_id}.log') returncode = runner.run( setup_cmd, log_path=setup_log_path, process_stream=False, + source_bashrc=True, ) def error_message() -> str: @@ -3089,7 +3182,7 @@ def error_message() -> str: command=setup_cmd, error_msg=error_message) - num_nodes = len(ip_list) + num_nodes = len(runners) plural = 's' if num_nodes > 1 else '' if not detach_setup: logger.info(f'{fore.CYAN}Running setup on {num_nodes} node{plural}.' @@ -3144,6 +3237,32 @@ def _exec_code_on_head( code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd) job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code]) + if len(job_submit_cmd) > 120 * 1024: + # The maximum size of a command line arguments is 128 KB, i.e. 
the + # command executed with /bin/sh should be less than 128KB. + # https://github.com/torvalds/linux/blob/master/include/uapi/linux/binfmts.h + # If a user have very long run or setup commands, the generated + # command may exceed the limit, as we encode the script in base64 + # and directly include it in the job submission command. If the + # command is too long, we instead write it to a file, rsync and + # execute it. + # We use 120KB as a threshold to be safe for other arguments that + # might be added during ssh. + runners = handle.get_command_runners() + head_runner = runners[0] + with tempfile.NamedTemporaryFile('w', prefix='sky_app_') as fp: + fp.write(codegen) + fp.flush() + script_path = os.path.join(SKY_REMOTE_APP_DIR, + f'sky_job_{job_id}') + # We choose to sync code + exec, because the alternative of 'ray + # submit' may not work as it may use system python (python2) to + # execute the script. Happens for AWS. + head_runner.rsync(source=fp.name, + target=script_path, + up=True, + stream_logs=False) + job_submit_cmd = f'{mkdir_code} && {code}' if managed_job_dag is not None: # Add the managed job to job queue database. @@ -3498,15 +3617,7 @@ def sync_down_logs( logger.info(f'{fore.CYAN}Job {job_id} logs: {log_dir}' f'{style.RESET_ALL}') - ip_list = handle.external_ips() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_port_list = handle.external_ssh_ports() - assert ssh_port_list is not None, 'external_ssh_ports is not cached ' \ - 'in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=ssh_port_list, **ssh_credentials) + runners = handle.get_command_runners() def _rsync_down(args) -> None: """Rsync down logs from remote nodes. @@ -3700,7 +3811,7 @@ def teardown_no_lock(self, # even when the command was executed successfully. self.run_on_head(handle, f'{constants.SKY_RAY_CMD} stop --force') - except exceptions.FetchIPError: + except exceptions.FetchClusterInfoError: # This error is expected if the previous cluster IP is # failed to be found, # i.e., the cluster is already stopped/terminated. @@ -3977,6 +4088,14 @@ def post_teardown_cleanup(self, pass except exceptions.PortDoesNotExistError: logger.debug('Ports do not exist. Skipping cleanup.') + except Exception as e: # pylint: disable=broad-except + if purge: + logger.warning( + f'Failed to cleanup ports. Skipping since purge is ' + f'set. Details: ' + f'{common_utils.format_exception(e, use_bracket=True)}') + else: + raise # The cluster file must exist because the cluster_yaml will only # be removed after the cluster entry in the database is removed. @@ -4015,6 +4134,17 @@ def set_autostop(self, # The core.autostop() function should have already checked that the # cloud and resources support requested autostop. if idle_minutes_to_autostop is not None: + # Skip auto-stop for Kubernetes clusters. + if (isinstance(handle.launched_resources.cloud, clouds.Kubernetes) + and not down and idle_minutes_to_autostop >= 0): + # We should hit this code path only for the jobs controller on + # Kubernetes clusters. + assert (controller_utils.Controllers.from_name( + handle.cluster_name) == controller_utils.Controllers. + JOBS_CONTROLLER), handle.cluster_name + logger.info('Auto-stop is not supported for Kubernetes ' + 'clusters. 
Skipping.') + return # Check if we're stopping spot assert (handle.launched_resources is not None and @@ -4077,6 +4207,7 @@ def run_on_head( require_outputs: bool = False, separate_stderr: bool = False, process_stream: bool = True, + source_bashrc: bool = False, **kwargs, ) -> Union[int, Tuple[int, str, str]]: """Runs 'cmd' on the cluster's head node. @@ -4102,6 +4233,9 @@ def run_on_head( process_stream: Whether to post-process the stdout/stderr of the command, such as replacing or skipping lines on the fly. If enabled, lines are printed only when '\r' or '\n' is found. + source_bashrc: Whether to source bashrc when running the command + on the VM. If it is a user-related command, it is always + good to source bashrc to make sure the env vars are set. Returns: returncode @@ -4109,24 +4243,17 @@ def run_on_head( A tuple of (returncode, stdout, stderr). Raises: - exceptions.FetchIPError: If the head node IP cannot be fetched. + exceptions.FetchClusterInfoError: If the cluster info cannot be + fetched. """ # This will try to fetch the head node IP if it is not cached. - external_ips = handle.external_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS) - head_ip = external_ips[0] - external_ssh_ports = handle.external_ssh_ports( - max_attempts=_FETCH_IP_MAX_ATTEMPTS) - head_ssh_port = external_ssh_ports[0] - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runner = command_runner.SSHCommandRunner(head_ip, - port=head_ssh_port, - **ssh_credentials) + runners = handle.get_command_runners() + head_runner = runners[0] if under_remote_workdir: cmd = f'cd {SKY_REMOTE_WORKDIR} && {cmd}' - return runner.run( + return head_runner.run( cmd, port_forward=port_forward, log_path=log_path, @@ -4135,6 +4262,7 @@ def run_on_head( ssh_mode=ssh_mode, require_outputs=require_outputs, separate_stderr=separate_stderr, + source_bashrc=source_bashrc, **kwargs, ) @@ -4278,13 +4406,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, style = colorama.Style logger.info(f'{fore.CYAN}Processing file mounts.{style.RESET_ALL}') start = time.time() - ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'file_mounts.log') # Check the files and warn @@ -4380,12 +4502,21 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, action_message='Syncing', log_path=log_path, stream_logs=False, + # Need to source bashrc, as the cloud specific CLI or SDK may + # require PATH in bashrc. + source_bashrc=True, ) # (2) Run the commands to create symlinks on all the nodes. symlink_command = ' && '.join(symlink_commands) if symlink_command: - - def _symlink_node(runner: command_runner.SSHCommandRunner): + # ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD sets sudo to empty string for + # root. We need this as we do not source bashrc for the command for + # better performance, and our sudo handling is only in bashrc.
+ symlink_command = ( + f'{command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} && ' + f'{symlink_command}') + + def _symlink_node(runner: command_runner.CommandRunner): returncode = runner.run(symlink_command, log_path=log_path) subprocess_utils.handle_returncode( returncode, symlink_command, @@ -4425,13 +4556,7 @@ def _execute_storage_mounts( logger.info(f'{fore.CYAN}Processing {len(storage_mounts)} ' f'storage mount{plural}.{style.RESET_ALL}') start = time.time() - ip_list = handle.external_ips() - port_list = handle.external_ssh_ports() - assert ip_list is not None, 'external_ips is not cached in handle' - ssh_credentials = backend_utils.ssh_credential_from_yaml( - handle.cluster_yaml, handle.docker_user, handle.ssh_user) - runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list, port_list=port_list, **ssh_credentials) + runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'storage_mounts.log') for dst, storage_obj in storage_mounts.items(): @@ -4462,6 +4587,9 @@ def _execute_storage_mounts( run_rsync=False, action_message='Mounting', log_path=log_path, + # Need to source bashrc, as the cloud specific CLI or SDK + # may require PATH in bashrc. + source_bashrc=True, ) except exceptions.CommandError as e: if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE: diff --git a/sky/benchmark/benchmark_utils.py b/sky/benchmark/benchmark_utils.py index 9fd0c529453..e1323bb714a 100644 --- a/sky/benchmark/benchmark_utils.py +++ b/sky/benchmark/benchmark_utils.py @@ -172,7 +172,7 @@ def _create_benchmark_bucket() -> Tuple[str, str]: raise_if_no_cloud_access=True) # Already checked by raise_if_no_cloud_access=True. assert enabled_clouds - bucket_type = data.StoreType.from_cloud(enabled_clouds[0]) + bucket_type = data.StoreType.from_cloud(enabled_clouds[0]).value # Create a benchmark bucket. logger.info(f'Creating a bucket {bucket_name} to save the benchmark logs.') diff --git a/sky/cli.py b/sky/cli.py index 485703e4caf..2e863f2eef7 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -52,7 +52,6 @@ from sky import exceptions from sky import global_user_state from sky import jobs as managed_jobs -from sky import provision as provision_lib from sky import serve as serve_lib from sky import sky_logging from sky import status_lib @@ -1650,71 +1649,28 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, head_ip = handle.external_ips()[0] if show_endpoints: - launched_resources = handle.launched_resources - cloud = launched_resources.cloud - try: - cloud.check_features_are_supported( - launched_resources, - {clouds.CloudImplementationFeatures.OPEN_PORTS}) - except exceptions.NotSupportedError: - with ux_utils.print_exception_no_traceback(): - raise ValueError('Querying endpoints is not supported ' - f'for {cloud}.') from None - - config = common_utils.read_yaml(handle.cluster_yaml) - port_details = provision_lib.query_ports( - repr(cloud), handle.cluster_name_on_cloud, - handle.launched_resources.ports, config['provider']) - - if endpoint is not None: - # If cluster had no ports to be exposed - ports_set = resources_utils.port_ranges_to_set( - handle.launched_resources.ports) - if endpoint not in ports_set: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Port {endpoint} is not exposed ' - 'on cluster ' - f'{cluster_record["name"]!r}.') - # If the user requested a specific port endpoint - if endpoint not in port_details: - error_msg = (f'Port {endpoint} not exposed yet. 
' - f'{_ENDPOINTS_RETRY_MESSAGE} ') - if handle.launched_resources.cloud.is_same_cloud( - clouds.Kubernetes()): - # Add Kubernetes specific debugging info - error_msg += ( - kubernetes_utils.get_endpoint_debug_message()) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(error_msg) - click.echo(port_details[endpoint][0].url(ip=head_ip)) - return - - if not port_details: - # If cluster had no ports to be exposed - if handle.launched_resources.ports is None: - with ux_utils.print_exception_no_traceback(): - raise ValueError('Cluster does not have any ports ' - 'to be exposed.') - # Else wait for the ports to be exposed - else: - error_msg = (f'No endpoints exposed yet. ' - f'{_ENDPOINTS_RETRY_MESSAGE} ') - if handle.launched_resources.cloud.is_same_cloud( - clouds.Kubernetes()): - # Add Kubernetes specific debugging info - error_msg += \ - kubernetes_utils.get_endpoint_debug_message() - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(error_msg) - - for port, urls in port_details.items(): - click.echo( - f'{colorama.Fore.BLUE}{colorama.Style.BRIGHT}{port}' - f'{colorama.Style.RESET_ALL}: ' - f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' - f'{urls[0].url(ip=head_ip)}{colorama.Style.RESET_ALL}') + if endpoint: + cluster_endpoint = core.endpoints(cluster_record['name'], + endpoint).get( + endpoint, None) + if not cluster_endpoint: + raise click.Abort( + f'Endpoint {endpoint} not found for cluster ' + f'{cluster_record["name"]!r}.') + click.echo(cluster_endpoint) + else: + cluster_endpoints = core.endpoints(cluster_record['name']) + assert isinstance(cluster_endpoints, dict) + if not cluster_endpoints: + raise click.Abort(f'No endpoint found for cluster ' + f'{cluster_record["name"]!r}.') + for port, port_endpoint in cluster_endpoints.items(): + click.echo( + f'{colorama.Fore.BLUE}{colorama.Style.BRIGHT}{port}' + f'{colorama.Style.RESET_ALL}: ' + f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' + f'{port_endpoint}{colorama.Style.RESET_ALL}') return - click.echo(head_ip) return hints = [] @@ -2590,6 +2546,17 @@ def down( def _hint_or_raise_for_down_jobs_controller(controller_name: str): + """Helper function to check job controller status before tearing it down. + + Raises helpful exceptions and errors if the controller is not in a safe + state to be torn down. + + Raises: + RuntimeError: if failed to get the job queue. + exceptions.NotSupportedError: if the controller is not in a safe state + to be torn down (e.g., because it has jobs running or + it is in init state) + """ controller = controller_utils.Controllers.from_name(controller_name) assert controller is not None, controller_name @@ -2633,6 +2600,17 @@ def _hint_or_raise_for_down_jobs_controller(controller_name: str): def _hint_or_raise_for_down_sky_serve_controller(controller_name: str): + """Helper function to check serve controller status before tearing it down. + + Raises helpful exceptions and errors if the controller is not in a safe + state to be torn down. + + Raises: + RuntimeError: if failed to get the service status. + exceptions.NotSupportedError: if the controller is not in a safe state + to be torn down (e.g., because it has services running or + it is in init state) + """ controller = controller_utils.Controllers.from_name(controller_name) assert controller is not None, controller_name with rich_utils.safe_status('[bold cyan]Checking for live services[/]'): @@ -2756,7 +2734,8 @@ def _down_or_stop_clusters( # managed job or service. We should make this check atomic # with the termination. 
hint_or_raise(controller_name) - except exceptions.ClusterOwnerIdentityMismatchError as e: + except (exceptions.ClusterOwnerIdentityMismatchError, + RuntimeError) as e: if purge: click.echo(common_utils.format_exception(e)) else: @@ -2998,6 +2977,11 @@ def _output(): name, quantity = None, None + # Kubernetes specific bools + cloud_is_kubernetes = isinstance(cloud_obj, clouds.Kubernetes) + kubernetes_autoscaling = kubernetes_utils.get_autoscaler_type( + ) is not None + if accelerator_str is None: result = service_catalog.list_accelerator_counts( gpus_only=True, @@ -3005,16 +2989,17 @@ def _output(): region_filter=region, ) - if (len(result) == 0 and cloud_obj is not None and - cloud_obj.is_same_cloud(clouds.Kubernetes())): + if len(result) == 0 and cloud_is_kubernetes: yield kubernetes_utils.NO_GPU_ERROR_MESSAGE + if kubernetes_autoscaling: + yield '\n' + yield kubernetes_utils.KUBERNETES_AUTOSCALER_NOTE return # "Common" GPUs # If cloud is kubernetes, we want to show all GPUs here, even if # they are not listed as common in SkyPilot. - if (cloud_obj is not None and - cloud_obj.is_same_cloud(clouds.Kubernetes())): + if cloud_is_kubernetes: for gpu, _ in sorted(result.items()): gpu_table.add_row([gpu, _list_to_str(result.pop(gpu))]) else: @@ -3038,9 +3023,16 @@ def _output(): other_table.add_row([gpu, _list_to_str(qty)]) yield from other_table.get_string() yield '\n\n' + if (cloud_is_kubernetes or + cloud is None) and kubernetes_autoscaling: + yield kubernetes_utils.KUBERNETES_AUTOSCALER_NOTE + yield '\n\n' else: yield ('\n\nHint: use -a/--all to see all accelerators ' '(including non-common ones) and pricing.') + if (cloud_is_kubernetes or + cloud is None) and kubernetes_autoscaling: + yield kubernetes_utils.KUBERNETES_AUTOSCALER_NOTE return else: # Parse accelerator string diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 1fef481d8d0..b2d76e7b7df 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -81,6 +81,8 @@ class AWSIdentityType(enum.Enum): IAM_ROLE = 'iam-role' + CONTAINER_ROLE = 'container-role' + # Name Value Type Location # ---- ----- ---- -------- # profile None None @@ -545,6 +547,12 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: # jobs-controller) created by an SSO account, i.e. the VM will be # assigned the IAM role: skypilot-v1. hints = f'AWS IAM role is set.{single_cloud_hint}' + elif identity_type == AWSIdentityType.CONTAINER_ROLE: + # Similar to the IAM ROLE, an ECS container may not store credentials + # in the~/.aws/credentials file. So we don't check for the existence of + # the file. i.e. the container will be assigned the IAM role of the + # task: skypilot-v1. + hints = f'AWS container-role is set.{single_cloud_hint}' else: # This file is required because it is required by the VMs launched on # other clouds to access private s3 buckets and resources like EC2. 
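As a rough illustration of the idea behind the new `CONTAINER_ROLE` handling above (not part of this patch; the helper name below is hypothetical), metadata-based credentials can be recognized from the Type column of `aws configure list`, which reports values such as `iam-role` or `container-role` when credentials come from instance or ECS task metadata instead of `~/.aws/credentials`:

.. code-block:: python

   import subprocess

   # Type values reported by `aws configure list` when credentials come from
   # instance or ECS task metadata rather than ~/.aws/credentials.
   _METADATA_CREDENTIAL_TYPES = ('iam-role', 'container-role')


   def credentials_come_from_metadata() -> bool:
       """Best-effort check: does `aws configure list` report a metadata type?"""
       proc = subprocess.run(['aws', 'configure', 'list'],
                             capture_output=True,
                             text=True,
                             check=False)
       if proc.returncode != 0:
           return False
       return any(t in proc.stdout for t in _METADATA_CREDENTIAL_TYPES)

In that case there is no credentials file to upload, which is why the check above only emits a hint for these identity types instead of failing.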
@@ -604,6 +612,8 @@ def _is_access_key_of_type(type_str: str) -> bool: return AWSIdentityType.SSO elif _is_access_key_of_type(AWSIdentityType.IAM_ROLE.value): return AWSIdentityType.IAM_ROLE + elif _is_access_key_of_type(AWSIdentityType.CONTAINER_ROLE.value): + return AWSIdentityType.CONTAINER_ROLE elif _is_access_key_of_type(AWSIdentityType.ENV.value): return AWSIdentityType.ENV else: diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 9777a28948b..5740e0ed9b1 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -14,6 +14,7 @@ from sky.provision.kubernetes import utils as kubernetes_utils from sky.utils import common_utils from sky.utils import resources_utils +from sky.utils import schemas if typing.TYPE_CHECKING: # Renaming to avoid shadowing variables. @@ -36,9 +37,8 @@ class Kubernetes(clouds.Cloud): """Kubernetes.""" SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys' - SKY_SSH_KEY_SECRET_FIELD_NAME = \ - f'ssh-publickey-{common_utils.get_user_hash()}' SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod' + SKY_DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account' PORT_FORWARD_PROXY_CMD_TEMPLATE = \ 'kubernetes-port-forward-proxy-command.sh.j2' PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh' @@ -52,6 +52,8 @@ class Kubernetes(clouds.Cloud): timeout = skypilot_config.get_nested(['kubernetes', 'provision_timeout'], 10) + _SUPPORTS_SERVICE_ACCOUNT_ON_REMOTE = True + _DEFAULT_NUM_VCPUS = 2 _DEFAULT_MEMORY_CPU_RATIO = 1 _DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks @@ -71,13 +73,6 @@ class Kubernetes(clouds.Cloud): 'tiers are not ' 'supported in ' 'Kubernetes.', - # Kubernetes may be using exec-based auth, which may not work by - # directly copying the kubeconfig file to the controller. - # Support for service accounts for auth will be added in #3377, which - # will allow us to support hosting controllers. - clouds.CloudImplementationFeatures.HOST_CONTROLLERS: 'Kubernetes can ' - 'not host ' - 'controllers.', } IMAGE_CPU = 'skypilot:cpu-ubuntu-2004' @@ -86,11 +81,24 @@ class Kubernetes(clouds.Cloud): PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT STATUS_VERSION = clouds.StatusVersion.SKYPILOT + @property + def ssh_key_secret_field_name(self): + # Use a fresh user hash to avoid conflicts in the secret object naming. + # This can happen when the controller is reusing the same user hash + # through USER_ID_ENV_VAR but has a different SSH key. 
+ fresh_user_hash = common_utils.get_user_hash(force_fresh_hash=True) + return f'ssh-publickey-{fresh_user_hash}' + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources_lib.Resources' ) -> Dict[clouds.CloudImplementationFeatures, str]: unsupported_features = cls._CLOUD_UNSUPPORTED_FEATURES + is_exec_auth, message = kubernetes_utils.is_kubeconfig_exec_auth() + if is_exec_auth: + assert isinstance(message, str), message + unsupported_features[ + clouds.CloudImplementationFeatures.HOST_CONTROLLERS] = message return unsupported_features @classmethod @@ -261,6 +269,23 @@ def make_deploy_resources_variables( port_mode = network_utils.get_port_mode(None) + remote_identity = skypilot_config.get_nested( + ('kubernetes', 'remote_identity'), schemas.REMOTE_IDENTITY_DEFAULT) + if (remote_identity == + schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value): + # SA name doesn't matter since automounting credentials is disabled + k8s_service_account_name = 'default' + k8s_automount_sa_token = 'false' + elif (remote_identity == + schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value): + # Use the default service account + k8s_service_account_name = self.SKY_DEFAULT_SERVICE_ACCOUNT_NAME + k8s_automount_sa_token = 'true' + else: + # User specified a custom service account + k8s_service_account_name = remote_identity + k8s_automount_sa_token = 'true' + fuse_device_required = bool(resources.requires_fuse) deploy_vars = { @@ -279,6 +304,8 @@ def make_deploy_resources_variables( 'k8s_acc_label_value': k8s_acc_label_value, 'k8s_ssh_jump_name': self.SKY_SSH_JUMP_NAME, 'k8s_ssh_jump_image': ssh_jump_image, + 'k8s_service_account_name': k8s_service_account_name, + 'k8s_automount_sa_token': k8s_automount_sa_token, 'k8s_fuse_device_required': fuse_device_required, # Namespace to run the FUSE device manager in 'k8s_fuse_device_manager_namespace': _SKY_SYSTEM_NAMESPACE, @@ -337,30 +364,32 @@ def _make(instance_list): gpu_task_cpus, gpu_task_memory, acc_count, acc_type).name) # Check if requested instance type will fit in the cluster. - # TODO(romilb): This will fail early for autoscaling clusters. - fits, reason = kubernetes_utils.check_instance_fits( - chosen_instance_type) - if not fits: - logger.debug(f'Instance type {chosen_instance_type} does ' - 'not fit in the Kubernetes cluster. ' - f'Reason: {reason}') - return [], [] + autoscaler_type = kubernetes_utils.get_autoscaler_type() + if autoscaler_type is None: + # If autoscaler is not set, check if the instance type fits in the + # cluster. Else, rely on the autoscaler to provision the right + # instance type without running checks. Worst case, if autoscaling + # fails, the pod will be stuck in pending state until + # provision_timeout, after which failover will be triggered. + fits, reason = kubernetes_utils.check_instance_fits( + chosen_instance_type) + if not fits: + logger.debug(f'Instance type {chosen_instance_type} does ' + 'not fit in the Kubernetes cluster. 
' + f'Reason: {reason}') + return [], [] # No fuzzy lists for Kubernetes return _make([chosen_instance_type]), [] @classmethod def check_credentials(cls) -> Tuple[bool, Optional[str]]: - if os.path.exists(os.path.expanduser(CREDENTIAL_PATH)): - # Test using python API - try: - return kubernetes_utils.check_credentials() - except Exception as e: # pylint: disable=broad-except - return (False, 'Credential check failed: ' - f'{common_utils.format_exception(e)}') - else: - return (False, 'Credentials not found - ' - f'check if {CREDENTIAL_PATH} exists.') + # Test using python API + try: + return kubernetes_utils.check_credentials() + except Exception as e: # pylint: disable=broad-except + return (False, 'Credential check failed: ' + f'{common_utils.format_exception(e)}') def get_credential_file_mounts(self) -> Dict[str, str]: if os.path.exists(os.path.expanduser(CREDENTIAL_PATH)): diff --git a/sky/core.py b/sky/core.py index c71a3fa9734..b5ecc483354 100644 --- a/sky/core.py +++ b/sky/core.py @@ -109,6 +109,26 @@ def status(cluster_names: Optional[Union[str, List[str]]] = None, cluster_names=cluster_names) +def endpoints(cluster: str, + port: Optional[Union[int, str]] = None) -> Dict[int, str]: + """Gets the endpoint for a given cluster and port number (endpoint). + + Args: + cluster: The name of the cluster. + port: The port number to get the endpoint for. If None, endpoints + for all ports are returned. + + Returns: A dictionary of port numbers to endpoints. If port is None, + the dictionary will contain all ports:endpoints exposed on the cluster. + + Raises: + ValueError: if the cluster is not UP or the endpoint is not exposed. + RuntimeError: if the cluster has no ports to be exposed or no endpoints + are exposed yet. + """ + return backend_utils.get_endpoints(cluster=cluster, port=port) + + @usage_lib.entrypoint def cost_report() -> List[Dict[str, Any]]: # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
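A short usage sketch for the `endpoints()` API added above (the cluster name `mycluster` and port `8080` are placeholders; the cluster must be UP and have the port opened):

.. code-block:: python

   from sky import core

   # All exposed ports mapped to their endpoints.
   for port, endpoint in core.endpoints('mycluster').items():
       print(f'{port}: {endpoint}')

   # A single port; raises if the cluster is not UP or the port is not exposed.
   print(core.endpoints('mycluster', port=8080).get(8080))

This is the same helper the updated `sky status --endpoints` code path in `sky/cli.py` now calls, instead of querying the cloud provider directly from the CLI.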
diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py index 2f4e37a1b66..043799e5ab5 100644 --- a/sky/data/mounting_utils.py +++ b/sky/data/mounting_utils.py @@ -4,6 +4,7 @@ from typing import Optional from sky import exceptions +from sky.utils import command_runner # Values used to construct mounting commands _STAT_CACHE_TTL = '5s' @@ -129,6 +130,8 @@ def get_mounting_script( script = textwrap.dedent(f""" #!/usr/bin/env bash set -e + + {command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} MOUNT_PATH={mount_path} MOUNT_BINARY={mount_binary} diff --git a/sky/exceptions.py b/sky/exceptions.py index e3b33ea3e5e..4fced20ce4e 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -190,8 +190,8 @@ class StorageExternalDeletionError(StorageBucketGetError): pass -class FetchIPError(Exception): - """Raised when fetching the IP fails.""" +class FetchClusterInfoError(Exception): + """Raised when fetching the cluster info fails.""" class Reason(enum.Enum): HEAD = 'HEAD' diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 9dc73a54a53..8371fb8ad83 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -5,7 +5,7 @@ """ import functools import inspect -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from sky import sky_logging from sky import status_lib @@ -21,6 +21,7 @@ from sky.provision import kubernetes from sky.provision import runpod from sky.provision import vsphere +from sky.utils import command_runner logger = sky_logging.init_logger(__name__) @@ -41,8 +42,12 @@ def _wrapper(*args, **kwargs): module = globals().get(module_name) assert module is not None, f'Unknown provider: {module_name}' - impl = getattr(module, func.__name__) - return impl(*args, **kwargs) + impl = getattr(module, func.__name__, None) + if impl is not None: + return impl(*args, **kwargs) + + # If implementation does not exist, fall back to default implementation + return func(provider_name, *args, **kwargs) return _wrapper @@ -141,13 +146,19 @@ def query_ports( provider_name: str, cluster_name_on_cloud: str, ports: List[str], + head_ip: Optional[str] = None, provider_config: Optional[Dict[str, Any]] = None, ) -> Dict[int, List[common.Endpoint]]: """Query details about ports on a cluster. + If head_ip is provided, it may be used by the cloud implementation to + return the endpoint without querying the cloud provider. If head_ip is not + provided, the cloud provider will be queried to get the endpoint info. + Returns a dict with port as the key and a list of common.Endpoint. 
""" - raise NotImplementedError + del provider_name, provider_config, cluster_name_on_cloud # unused + return common.query_ports_passthrough(ports, head_ip) @_route_to_cloud_impl @@ -165,3 +176,18 @@ def get_cluster_info( provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """Get the metadata of instances in a cluster.""" raise NotImplementedError + + +@_route_to_cloud_impl +def get_command_runners( + provider_name: str, + cluster_info: common.ClusterInfo, + **crednetials: Dict[str, Any], +) -> List[command_runner.CommandRunner]: + """Get a command runner for the given cluster.""" + ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() + return command_runner.SSHCommandRunner.make_runner_list( + node_list=zip(ip_list, port_list), + **crednetials, + ) diff --git a/sky/provision/aws/__init__.py b/sky/provision/aws/__init__.py index bcbe646f219..e569d3b042e 100644 --- a/sky/provision/aws/__init__.py +++ b/sky/provision/aws/__init__.py @@ -5,7 +5,6 @@ from sky.provision.aws.instance import get_cluster_info from sky.provision.aws.instance import open_ports from sky.provision.aws.instance import query_instances -from sky.provision.aws.instance import query_ports from sky.provision.aws.instance import run_instances from sky.provision.aws.instance import stop_instances from sky.provision.aws.instance import terminate_instances diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index b9fdf80326d..e279b30c74b 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -843,7 +843,6 @@ def get_cluster_info( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - del provider_config # unused ec2 = _default_ec2_resource(region) filters = [ { @@ -875,14 +874,6 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='aws', + provider_config=provider_config, ) - - -def query_ports( - cluster_name_on_cloud: str, - ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, -) -> Dict[int, List[common.Endpoint]]: - """See sky/provision/__init__.py""" - return common.query_ports_passthrough(cluster_name_on_cloud, ports, - provider_config) diff --git a/sky/provision/azure/__init__.py b/sky/provision/azure/__init__.py index 9c87fc907db..b83dbb462d9 100644 --- a/sky/provision/azure/__init__.py +++ b/sky/provision/azure/__init__.py @@ -2,4 +2,3 @@ from sky.provision.azure.instance import cleanup_ports from sky.provision.azure.instance import open_ports -from sky.provision.azure.instance import query_ports diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index dc7b23dee5c..de5c7cbf0e9 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -4,7 +4,6 @@ from sky import sky_logging from sky.adaptors import azure -from sky.provision import common from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -94,13 +93,3 @@ def cleanup_ports( # Azure will automatically cleanup network security groups when cleanup # resource group. So we don't need to do anything here. del cluster_name_on_cloud, ports, provider_config # Unused. 
- - -def query_ports( - cluster_name_on_cloud: str, - ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, -) -> Dict[int, List[common.Endpoint]]: - """See sky/provision/__init__.py""" - return common.query_ports_passthrough(cluster_name_on_cloud, ports, - provider_config) diff --git a/sky/provision/common.py b/sky/provision/common.py index 75178b15623..7c1bcb32652 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -4,7 +4,7 @@ import os from typing import Any, Dict, List, Optional, Tuple -from sky.utils.resources_utils import port_ranges_to_set +from sky.utils import resources_utils # NOTE: we can use pydantic instead of dataclasses or namedtuples, because # pydantic provides more features like validation or parsing from @@ -104,6 +104,10 @@ class ClusterInfo: # The unique identifier of the head instance, i.e., the # `instance_info.instance_id` of the head node. head_instance_id: Optional[InstanceId] + # Provider related information. + provider_name: str + provider_config: Optional[Dict[str, Any]] = None + docker_user: Optional[str] = None # Override the ssh_user from the cluster config. ssh_user: Optional[str] = None @@ -151,6 +155,19 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: other_ips.append(pair) return head_instance_ip + other_ips + def instance_ids(self) -> List[str]: + """Return the instance ids in the same order of ip_tuples.""" + id_list = [] + if self.head_instance_id is not None: + id_list.append(self.head_instance_id + '-0') + for inst_id, instances in self.instances.items(): + start_idx = 0 + if inst_id == self.head_instance_id: + start_idx = 1 + id_list.extend( + [f'{inst_id}-{i}' for i in range(start_idx, len(instances))]) + return id_list + def has_external_ips(self) -> bool: """True if the cluster has external IP.""" ip_tuples = self.ip_tuples() @@ -186,8 +203,10 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: def get_ssh_ports(self) -> List[int]: """Get the SSH port of all the instances.""" head_instance = self.get_head_instance() - assert head_instance is not None, self - head_instance_port = [head_instance.ssh_port] + + head_instance_port = [] + if head_instance is not None: + head_instance_port = [head_instance.ssh_port] worker_instances = self.get_worker_instances() worker_instance_ports = [ @@ -201,7 +220,7 @@ class Endpoint: pass @abc.abstractmethod - def url(self, ip: str): + def url(self, override_ip: Optional[str] = None) -> str: raise NotImplementedError @@ -211,44 +230,41 @@ class SocketEndpoint(Endpoint): port: Optional[int] host: str = '' - def url(self, ip: str): - if not self.host: - self.host = ip - return f'{self.host}{":" + str(self.port) if self.port else ""}' + def url(self, override_ip: Optional[str] = None) -> str: + host = override_ip if override_ip else self.host + return f'{host}{":" + str(self.port) if self.port else ""}' @dataclasses.dataclass class HTTPEndpoint(SocketEndpoint): - """HTTP endpoint accesible via a url.""" + """HTTP endpoint accessible via a url.""" path: str = '' - def url(self, ip: str): - del ip # Unused. 
- return f'http://{os.path.join(super().url(self.host), self.path)}' + def url(self, override_ip: Optional[str] = None) -> str: + host = override_ip if override_ip else self.host + return f'http://{os.path.join(super().url(host), self.path)}' @dataclasses.dataclass class HTTPSEndpoint(SocketEndpoint): - """HTTPS endpoint accesible via a url.""" + """HTTPS endpoint accessible via a url.""" path: str = '' - def url(self, ip: str): - del ip # Unused. - return f'https://{os.path.join(super().url(self.host), self.path)}' + def url(self, override_ip: Optional[str] = None) -> str: + host = override_ip if override_ip else self.host + return f'https://{os.path.join(super().url(host), self.path)}' def query_ports_passthrough( - cluster_name_on_cloud: str, ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, + head_ip: Optional[str], ) -> Dict[int, List[Endpoint]]: - """Common function to query ports for AWS, GCP and Azure. + """Common function to get endpoints for AWS, GCP and Azure. - Returns a list of socket endpoint with empty host and the input ports.""" - del cluster_name_on_cloud, provider_config # Unused. - ports = list(port_ranges_to_set(ports)) + Returns a list of socket endpoint using head_ip and ports.""" + assert head_ip is not None, head_ip + ports = list(resources_utils.port_ranges_to_set(ports)) result: Dict[int, List[Endpoint]] = {} for port in ports: - result[port] = [SocketEndpoint(port=port)] - + result[port] = [SocketEndpoint(port=port, host=head_ip)] return result diff --git a/sky/provision/cudo/instance.py b/sky/provision/cudo/instance.py index e4a2db722e4..39d4bc6b3d1 100644 --- a/sky/provision/cudo/instance.py +++ b/sky/provision/cudo/instance.py @@ -162,7 +162,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config + del region nodes = _filter_instances(cluster_name_on_cloud, ['runn', 'pend']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -178,10 +178,10 @@ def get_cluster_info( if node_info['name'].endswith('-head'): head_instance_id = node_id - return common.ClusterInfo( - instances=instances, - head_instance_id=head_instance_id, - ) + return common.ClusterInfo(instances=instances, + head_instance_id=head_instance_id, + provider_name='cudo', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index 8de7beab2e7..10ae5dafc07 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -3,16 +3,13 @@ import dataclasses import shlex import time -import typing from typing import Any, Dict, List from sky import sky_logging from sky.skylet import constants +from sky.utils import command_runner from sky.utils import subprocess_utils -if typing.TYPE_CHECKING: - from sky.utils import command_runner - logger = sky_logging.init_logger(__name__) DOCKER_PERMISSION_DENIED_STR = ('permission denied while trying to connect to ' @@ -117,7 +114,7 @@ class DockerInitializer: """Initializer for docker containers on a remote node.""" def __init__(self, docker_config: Dict[str, Any], - runner: 'command_runner.SSHCommandRunner', log_path: str): + runner: 'command_runner.CommandRunner', log_path: str): self.docker_config = docker_config self.container_name = docker_config['container_name'] self.runner = runner @@ -255,7 +252,8 @@ def initialize(self) -> str: # Disable apt-get from asking user input during installation. 
# see https://askubuntu.com/questions/909277/avoiding-user-interaction-with-tzdata-when-installing-certbot-in-a-docker-contai # pylint: disable=line-too-long self._run( - 'echo \'[ "$(whoami)" == "root" ] && alias sudo=""\' >> ~/.bashrc;' + f'echo \'{command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD}\' ' + '>> ~/.bashrc;' 'echo "export DEBIAN_FRONTEND=noninteractive" >> ~/.bashrc;', run_env='docker') # Install dependencies. diff --git a/sky/provision/fluidstack/instance.py b/sky/provision/fluidstack/instance.py index 2c0d836fadc..b37519a8458 100644 --- a/sky/provision/fluidstack/instance.py +++ b/sky/provision/fluidstack/instance.py @@ -273,7 +273,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['running']) instances: Dict[str, List[common.InstanceInfo]] = {} @@ -296,7 +296,9 @@ def get_cluster_info( return common.ClusterInfo(instances=instances, head_instance_id=head_instance_id, - custom_ray_options={'use_external_ip': True}) + custom_ray_options={'use_external_ip': True}, + provider_name='fluidstack', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index fdadd5345e2..0d24a577690 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -5,7 +5,6 @@ from sky.provision.gcp.instance import get_cluster_info from sky.provision.gcp.instance import open_ports from sky.provision.gcp.instance import query_instances -from sky.provision.gcp.instance import query_ports from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import stop_instances from sky.provision.gcp.instance import terminate_instances diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index e7f69f8c6eb..a4996fc4d4b 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -428,6 +428,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='gcp', + provider_config=provider_config, ) @@ -615,13 +617,3 @@ def cleanup_ports( firewall_rule_name = provider_config['firewall_rule'] instance_utils.GCPComputeInstance.delete_firewall_rule( project_id, firewall_rule_name) - - -def query_ports( - cluster_name_on_cloud: str, - ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, -) -> Dict[int, List[common.Endpoint]]: - """See sky/provision/__init__.py""" - return common.query_ports_passthrough(cluster_name_on_cloud, ports, - provider_config) diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 81e13f54fd0..1e5e6285fef 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -8,6 +8,7 @@ import time from typing import Any, Dict, List, Optional, Tuple +from sky import provision from sky import sky_logging from sky.provision import common from sky.provision import docker_utils @@ -123,24 +124,22 @@ def _parallel_ssh_with_cache(func, max_workers = subprocess_utils.get_parallel_threads() with futures.ThreadPoolExecutor(max_workers=max_workers) as pool: results = [] - for instance_id, metadatas in cluster_info.instances.items(): - for i, metadata in enumerate(metadatas): - cache_id = f'{instance_id}-{i}' - runner = command_runner.SSHCommandRunner( - metadata.get_feasible_ip(), - port=metadata.ssh_port, - 
**ssh_credentials) - wrapper = metadata_utils.cache_func(cluster_name, cache_id, - stage_name, digest) - if (cluster_info.head_instance_id == instance_id and i == 0): - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) - else: - log_dir_abs = metadata_utils.get_instance_log_dir( - cluster_name, cache_id) - log_path_abs = str(log_dir_abs / (stage_name + '.log')) - results.append( - pool.submit(wrapper(func), runner, metadata, log_path_abs)) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + # instance_ids is guaranteed to be in the same order as runners. + instance_ids = cluster_info.instance_ids() + for i, runner in enumerate(runners): + cache_id = instance_ids[i] + wrapper = metadata_utils.cache_func(cluster_name, cache_id, + stage_name, digest) + if i == 0: + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + else: + log_dir_abs = metadata_utils.get_instance_log_dir( + cluster_name, cache_id) + log_path_abs = str(log_dir_abs / (stage_name + '.log')) + results.append(pool.submit(wrapper(func), runner, log_path_abs)) return [future.result() for future in results] @@ -155,9 +154,7 @@ def initialize_docker(cluster_name: str, docker_config: Dict[str, Any], _hint_worker_log_path(cluster_name, cluster_info, 'initialize_docker') @_auto_retry - def _initialize_docker(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata # Unused. + def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): docker_user = docker_utils.DockerInitializer(docker_config, runner, log_path).initialize() logger.debug(f'Initialized docker user: {docker_user}') @@ -194,14 +191,16 @@ def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], digest = hasher.hexdigest() @_auto_retry - def _setup_node(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata + def _setup_node(runner: command_runner.CommandRunner, log_path: str): for cmd in setup_commands: - returncode, stdout, stderr = runner.run(cmd, - stream_logs=False, - log_path=log_path, - require_outputs=True) + returncode, stdout, stderr = runner.run( + cmd, + stream_logs=False, + log_path=log_path, + require_outputs=True, + # Installing dependencies requires sourcing bashrc to access the PATH + # in bashrc. + source_bashrc=True) retry_cnt = 0 while returncode == 255 and retry_cnt < _MAX_RETRY: # Got network connection issue occur during setup.
This could @@ -215,7 +214,8 @@ def _setup_node(runner: command_runner.SSHCommandRunner, returncode, stdout, stderr = runner.run(cmd, stream_logs=False, log_path=log_path, - require_outputs=True) + require_outputs=True, + source_bashrc=True) if not returncode: break @@ -256,11 +256,9 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the head node.""" - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() - ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] assert cluster_info.head_instance_id is not None, (cluster_name, cluster_info) @@ -297,10 +295,14 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND) logger.info(f'Running command on head node: {cmd}') # TODO(zhwu): add the output to log files. - returncode, stdout, stderr = ssh_runner.run(cmd, - stream_logs=False, - log_path=log_path_abs, - require_outputs=True) + returncode, stdout, stderr = head_runner.run( + cmd, + stream_logs=False, + log_path=log_path_abs, + require_outputs=True, + # Source bashrc for starting ray cluster to make sure actors started by + # ray will have the correct PATH. + source_bashrc=True) if returncode: raise RuntimeError('Failed to start ray on the head node ' f'(exit code {returncode}). Error: \n' @@ -318,11 +320,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, if cluster_info.num_instances <= 1: return _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') - ip_list = cluster_info.get_feasible_ips() - ssh_runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list[1:], - port_list=cluster_info.get_ssh_ports()[1:], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + worker_runners = runners[1:] worker_instances = cluster_info.get_worker_instances() cache_ids = [] prev_instance_id = None @@ -374,11 +374,11 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || ' f'{{ {cmd} }}') else: - cmd = 'ray stop; ' + cmd + cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd logger.info(f'Running command on worker nodes: {cmd}') - def _setup_ray_worker(runner_and_id: Tuple[command_runner.SSHCommandRunner, + def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, str]): # for cmd in config_from_yaml['worker_start_ray_commands']: # cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0]) @@ -386,13 +386,17 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.SSHCommandRunner, runner, instance_id = runner_and_id log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id) log_path_abs = str(log_dir / ('ray_cluster' + '.log')) - return runner.run(cmd, - stream_logs=False, - require_outputs=True, - log_path=log_path_abs) + return runner.run( + cmd, + stream_logs=False, + require_outputs=True, + log_path=log_path_abs, + # Source bashrc for starting ray cluster to make sure actors started + # by ray will have the correct PATH. 
+ source_bashrc=True) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(ssh_runners, cache_ids))) + _setup_ray_worker, list(zip(worker_runners, cache_ids))) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): @@ -410,18 +414,19 @@ def start_skylet_on_head_node(cluster_name: str, ssh_credentials: Dict[str, Any]) -> None: """Start skylet on the head node.""" del cluster_name - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() - ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] assert cluster_info.head_instance_id is not None, cluster_info log_path_abs = str(provision_logging.get_log_path()) logger.info(f'Running command on head node: {MAYBE_SKYLET_RESTART_CMD}') - returncode, stdout, stderr = ssh_runner.run(MAYBE_SKYLET_RESTART_CMD, - stream_logs=False, - require_outputs=True, - log_path=log_path_abs) + # We need to source bashrc for skylet to make sure the autostop event can + # access the path to the cloud CLIs. + returncode, stdout, stderr = head_runner.run(MAYBE_SKYLET_RESTART_CMD, + stream_logs=False, + require_outputs=True, + log_path=log_path_abs, + source_bashrc=True) if returncode: raise RuntimeError('Failed to start skylet on the head node ' f'(exit code {returncode}). Error: ' @@ -431,7 +436,7 @@ def start_skylet_on_head_node(cluster_name: str, @_auto_retry def _internal_file_mounts(file_mounts: Dict, - runner: command_runner.SSHCommandRunner, + runner: command_runner.CommandRunner, log_path: str) -> None: if file_mounts is None or not file_mounts: return @@ -493,9 +498,7 @@ def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], """Executes file mounts - rsyncing internal local files""" _hint_worker_log_path(cluster_name, cluster_info, 'internal_file_mounts') - def _setup_node(runner: command_runner.SSHCommandRunner, - metadata: common.InstanceInfo, log_path: str): - del metadata + def _setup_node(runner: command_runner.CommandRunner, log_path: str): _internal_file_mounts(common_file_mounts, runner, log_path) _parallel_ssh_with_cache( diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index ef1926ac9ce..d5c30133ef2 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -30,12 +30,47 @@ def bootstrap_instances( if config.provider_config.get('fuse_device_required', False): _configure_fuse_mounting(config.provider_config) - if not config.provider_config.get('_operator'): - # These steps are unecessary when using the Operator. + requested_service_account = config.node_config['spec']['serviceAccountName'] + if requested_service_account == 'skypilot-service-account': + # If the user has requested a different service account (via pod_config + # in ~/.sky/config.yaml), we assume they have already set up the + # necessary roles and role bindings. + # If not, set up the roles and bindings for skypilot-service-account + # here. 
_configure_autoscaler_service_account(namespace, config.provider_config) - _configure_autoscaler_role(namespace, config.provider_config) - _configure_autoscaler_role_binding(namespace, config.provider_config) - + _configure_autoscaler_role(namespace, + config.provider_config, + role_field='autoscaler_role') + _configure_autoscaler_role_binding( + namespace, + config.provider_config, + binding_field='autoscaler_role_binding') + _configure_autoscaler_cluster_role(namespace, config.provider_config) + _configure_autoscaler_cluster_role_binding(namespace, + config.provider_config) + if config.provider_config.get('port_mode', 'loadbalancer') == 'ingress': + logger.info('Port mode is set to ingress, setting up ingress role ' + 'and role binding.') + try: + _configure_autoscaler_role(namespace, + config.provider_config, + role_field='autoscaler_ingress_role') + _configure_autoscaler_role_binding( + namespace, + config.provider_config, + binding_field='autoscaler_ingress_role_binding') + except kubernetes.api_exception() as e: + # If namespace is not found, we will ignore the error + if e.status == 404: + logger.info( + 'Namespace not found - is your nginx ingress installed?' + ' Skipping ingress role and role binding setup.') + else: + raise e + + elif requested_service_account != 'default': + logger.info(f'Using service account {requested_service_account!r}, ' + 'skipping role and role binding setup.') return config @@ -214,9 +249,16 @@ def _configure_autoscaler_service_account( f'{created_msg(account_field, name)}') -def _configure_autoscaler_role(namespace: str, - provider_config: Dict[str, Any]) -> None: - role_field = 'autoscaler_role' +def _configure_autoscaler_role(namespace: str, provider_config: Dict[str, Any], + role_field: str) -> None: + """ Reads the role from the provider config, creates if it does not exist. + + Args: + namespace: The namespace to create the role in. + provider_config: The provider config. + role_field: The field in the provider config that contains the role. + """ + if role_field not in provider_config: logger.info('_configure_autoscaler_role: ' f'{not_provided_msg(role_field)}') @@ -225,8 +267,8 @@ def _configure_autoscaler_role(namespace: str, role = provider_config[role_field] if 'namespace' not in role['metadata']: role['metadata']['namespace'] = namespace - elif role['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(role_field, namespace) + else: + namespace = role['metadata']['namespace'] name = role['metadata']['name'] field_selector = f'metadata.name={name}' @@ -245,8 +287,16 @@ def _configure_autoscaler_role(namespace: str, def _configure_autoscaler_role_binding(namespace: str, - provider_config: Dict[str, Any]) -> None: - binding_field = 'autoscaler_role_binding' + provider_config: Dict[str, Any], + binding_field: str) -> None: + """ Reads the role binding from the config, creates if it does not exist. + + Args: + namespace: The namespace to create the role binding in. + provider_config: The provider config. 
+ binding_field: The field in the provider config that contains the role + """ + if binding_field not in provider_config: logger.info('_configure_autoscaler_role_binding: ' f'{not_provided_msg(binding_field)}') @@ -255,8 +305,10 @@ def _configure_autoscaler_role_binding(namespace: str, binding = provider_config[binding_field] if 'namespace' not in binding['metadata']: binding['metadata']['namespace'] = namespace - elif binding['metadata']['namespace'] != namespace: - raise InvalidNamespaceError(binding_field, namespace) + rb_namespace = namespace + else: + rb_namespace = binding['metadata']['namespace'] + for subject in binding['subjects']: if 'namespace' not in subject: subject['namespace'] = namespace @@ -268,7 +320,7 @@ def _configure_autoscaler_role_binding(namespace: str, name = binding['metadata']['name'] field_selector = f'metadata.name={name}' accounts = (kubernetes.auth_api().list_namespaced_role_binding( - namespace, field_selector=field_selector).items) + rb_namespace, field_selector=field_selector).items) if len(accounts) > 0: assert len(accounts) == 1 logger.info('_configure_autoscaler_role_binding: ' @@ -277,11 +329,80 @@ def _configure_autoscaler_role_binding(namespace: str, logger.info('_configure_autoscaler_role_binding: ' f'{not_found_msg(binding_field, name)}') - kubernetes.auth_api().create_namespaced_role_binding(namespace, binding) + kubernetes.auth_api().create_namespaced_role_binding(rb_namespace, binding) logger.info('_configure_autoscaler_role_binding: ' f'{created_msg(binding_field, name)}') +def _configure_autoscaler_cluster_role(namespace, + provider_config: Dict[str, Any]) -> None: + role_field = 'autoscaler_cluster_role' + if role_field not in provider_config: + logger.info('_configure_autoscaler_cluster_role: ' + f'{not_provided_msg(role_field)}') + return + + role = provider_config[role_field] + if 'namespace' not in role['metadata']: + role['metadata']['namespace'] = namespace + elif role['metadata']['namespace'] != namespace: + raise InvalidNamespaceError(role_field, namespace) + + name = role['metadata']['name'] + field_selector = f'metadata.name={name}' + accounts = (kubernetes.auth_api().list_cluster_role( + field_selector=field_selector).items) + if len(accounts) > 0: + assert len(accounts) == 1 + logger.info('_configure_autoscaler_cluster_role: ' + f'{using_existing_msg(role_field, name)}') + return + + logger.info('_configure_autoscaler_cluster_role: ' + f'{not_found_msg(role_field, name)}') + kubernetes.auth_api().create_cluster_role(role) + logger.info( + f'_configure_autoscaler_cluster_role: {created_msg(role_field, name)}') + + +def _configure_autoscaler_cluster_role_binding( + namespace, provider_config: Dict[str, Any]) -> None: + binding_field = 'autoscaler_cluster_role_binding' + if binding_field not in provider_config: + logger.info('_configure_autoscaler_cluster_role_binding: ' + f'{not_provided_msg(binding_field)}') + return + + binding = provider_config[binding_field] + if 'namespace' not in binding['metadata']: + binding['metadata']['namespace'] = namespace + elif binding['metadata']['namespace'] != namespace: + raise InvalidNamespaceError(binding_field, namespace) + for subject in binding['subjects']: + if 'namespace' not in subject: + subject['namespace'] = namespace + elif subject['namespace'] != namespace: + subject_name = subject['name'] + raise InvalidNamespaceError( + binding_field + f' subject {subject_name}', namespace) + + name = binding['metadata']['name'] + field_selector = f'metadata.name={name}' + accounts = 
(kubernetes.auth_api().list_cluster_role_binding( + field_selector=field_selector).items) + if len(accounts) > 0: + assert len(accounts) == 1 + logger.info('_configure_autoscaler_cluster_role_binding: ' + f'{using_existing_msg(binding_field, name)}') + return + + logger.info('_configure_autoscaler_cluster_role_binding: ' + f'{not_found_msg(binding_field, name)}') + kubernetes.auth_api().create_cluster_role_binding(binding) + logger.info('_configure_autoscaler_cluster_role_binding: ' + f'{created_msg(binding_field, name)}') + + def _configure_ssh_jump(namespace, config: common.ProvisionConfig): """Creates a SSH jump pod to connect to the cluster. diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 51484f1f579..9068079701f 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -218,7 +218,7 @@ def _wait_for_pods_to_run(namespace, new_nodes): node.metadata.name, namespace) # Continue if pod and all the containers within the - # pod are succesfully created and running. + # pod are successfully created and running. if pod.status.phase == 'Running' and all( container.state.running for container in pod.status.container_statuses): @@ -730,7 +730,9 @@ def get_cluster_info( custom_ray_options={ 'object-store-memory': 500000000, 'num-cpus': cpu_request, - }) + }, + provider_name='kubernetes', + provider_config=provider_config) def query_instances( diff --git a/sky/provision/kubernetes/network.py b/sky/provision/kubernetes/network.py index 4abde138b1a..61870cb9119 100644 --- a/sky/provision/kubernetes/network.py +++ b/sky/provision/kubernetes/network.py @@ -73,7 +73,7 @@ def _open_ports_using_ingress( 'https://github.com/kubernetes/ingress-nginx/blob/main/docs/deploy/index.md.' # pylint: disable=line-too-long ) - # Prepare service names, ports, for template rendering + # Prepare service names, ports, for template rendering service_details = [ (f'{cluster_name_on_cloud}-skypilot-service--{port}', port, _PATH_PREFIX.format(cluster_name_on_cloud=cluster_name_on_cloud, @@ -177,9 +177,11 @@ def _cleanup_ports_for_ingress( def query_ports( cluster_name_on_cloud: str, ports: List[str], + head_ip: Optional[str] = None, provider_config: Optional[Dict[str, Any]] = None, ) -> Dict[int, List[common.Endpoint]]: """See sky/provision/__init__.py""" + del head_ip # unused assert provider_config is not None, 'provider_config is required' port_mode = network_utils.get_port_mode( provider_config.get('port_mode', None)) diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index 19abe01888d..836d75af41f 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -199,11 +199,17 @@ def get_ingress_external_ip_and_ports( ingress_service = ingress_services[0] if ingress_service.status.load_balancer.ingress is None: - # Try to use assigned external IP if it exists, - # otherwise return 'localhost' + # We try to get an IP/host for the service in the following order: + # 1. Try to use assigned external IP if it exists + # 2. Use the skypilot.co/external-ip annotation in the service + # 3. 
Otherwise return 'localhost' + ip = None if ingress_service.spec.external_i_ps is not None: ip = ingress_service.spec.external_i_ps[0] - else: + elif ingress_service.metadata.annotations is not None: + ip = ingress_service.metadata.annotations.get( + 'skypilot.co/external-ip', None) + if ip is None: ip = 'localhost' ports = ingress_service.spec.ports http_port = [port for port in ports if port.name == 'http'][0].node_port diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index c7c19680e07..3b3608947ad 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -18,6 +18,7 @@ from sky.utils import common_utils from sky.utils import env_options from sky.utils import kubernetes_enums +from sky.utils import schemas from sky.utils import ux_utils DEFAULT_NAMESPACE = 'default' @@ -35,6 +36,12 @@ (e.g., skypilot.co/accelerator) are setup correctly. \ To further debug, run: sky check.' +KUBERNETES_AUTOSCALER_NOTE = ( + 'Note: Kubernetes cluster autoscaling is enabled. ' + 'All GPUs that can be provisioned may not be listed ' + 'here. Refer to your autoscaler\'s node pool ' + 'configuration to see the list of supported GPUs.') + # TODO(romilb): Add links to docs for configuration instructions when ready. ENDPOINTS_DEBUG_MESSAGE = ('Additionally, make sure your {endpoint_type} ' 'is configured correctly. ' @@ -178,13 +185,31 @@ def get_accelerator_from_label_value(cls, value: str) -> str: f'Invalid accelerator name in GKE cluster: {value}') +class KarpenterLabelFormatter(SkyPilotLabelFormatter): + """Karpenter label formatter. + Karpenter uses the label `karpenter.k8s.aws/instance-gpu-name` to identify + the GPU type. Details: https://karpenter.sh/docs/reference/instance-types/ + The naming scheme is the same as the SkyPilot formatter, so we inherit from it. + """ + LABEL_KEY = 'karpenter.k8s.aws/instance-gpu-name' + + # LABEL_FORMATTER_REGISTRY stores the label formats SkyPilot will try to # discover the accelerator type from. The order of the list is important, as -# it will be used to determine the priority of the label formats. +# it will be used to determine the priority of the label formats when +# auto-detecting the GPU label type. LABEL_FORMATTER_REGISTRY = [ - SkyPilotLabelFormatter, CoreWeaveLabelFormatter, GKELabelFormatter + SkyPilotLabelFormatter, CoreWeaveLabelFormatter, GKELabelFormatter, + KarpenterLabelFormatter ] +# Mapping of autoscaler type to label formatter +AUTOSCALER_TO_LABEL_FORMATTER = { + kubernetes_enums.KubernetesAutoscalerType.GKE: GKELabelFormatter, + kubernetes_enums.KubernetesAutoscalerType.KARPENTER: KarpenterLabelFormatter, # pylint: disable=line-too-long + kubernetes_enums.KubernetesAutoscalerType.GENERIC: SkyPilotLabelFormatter, +} + def detect_gpu_label_formatter( ) -> Tuple[Optional[GPULabelFormatter], Dict[str, List[Tuple[str, str]]]]: @@ -348,10 +373,26 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]: # Check if the cluster has GPU resources # TODO(romilb): This assumes the accelerator is a nvidia GPU. We # need to support TPUs and other accelerators as well. - # TODO(romilb): This will fail early for autoscaling clusters. - # For AS clusters, we may need a way for users to specify GPU node pools - # to use since the cluster may be scaling up from zero nodes and may not - # have any GPU nodes yet.
+ # TODO(romilb): Currently, we broadly disable all GPU checks if autoscaling + # is configured in config.yaml since the cluster may be scaling up from + # zero nodes and may not have any GPU nodes yet. In the future, we should + # support polling the clusters for autoscaling information, such as the + # node pools configured, etc. + + autoscaler_type = get_autoscaler_type() + if autoscaler_type is not None: + # If autoscaler is set in config.yaml, override the label key and value + # to the autoscaler's format and bypass the GPU checks. + if check_mode: + # If check mode is enabled and autoscaler is set, we can return + # early since we assume the cluster autoscaler will handle GPU + # node provisioning. + return '', '' + formatter = AUTOSCALER_TO_LABEL_FORMATTER.get(autoscaler_type) + assert formatter is not None, ('Unsupported autoscaler type:' + f' {autoscaler_type}') + return formatter.get_label_key(), formatter.get_label_value(acc_type) + has_gpus, cluster_resources = detect_gpu_resource() if has_gpus: # Check if the cluster has GPU labels setup correctly @@ -509,20 +550,106 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \ except Exception as e: # pylint: disable=broad-except return False, ('An error occurred: ' f'{common_utils.format_exception(e, use_bracket=True)}') - # If we reach here, the credentials are valid and Kubernetes cluster is up + + # If we reach here, the credentials are valid and the Kubernetes cluster is up. + # We now run softer checks to see if exec-based auth is used and whether + # the cluster is GPU-enabled. + + _, exec_msg = is_kubeconfig_exec_auth() + # We now check if GPUs are available and labels are set correctly on the # cluster, and if not we return hints that may help debug any issues. # This early check avoids later surprises for user when they try to run # `sky launch --gpus ` and the optimizer does not list Kubernetes as a # provider if their cluster GPUs are not setup correctly. + gpu_msg = '' try: _, _ = get_gpu_label_key_value(acc_type='', check_mode=True) except exceptions.ResourcesUnavailableError as e: # If GPUs are not available, we return cluster as enabled (since it can # be a CPU-only cluster) but we also return the exception message which # serves as a hint for how to enable GPU access. - return True, f'{e}' - return True, None + gpu_msg = str(e) + if exec_msg and gpu_msg: + return True, f'{gpu_msg}\n Additionally, {exec_msg}' + elif gpu_msg: + return True, gpu_msg + elif exec_msg: + return True, exec_msg + else: + return True, None + + +def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]: + """Checks if the kubeconfig file uses exec-based authentication + + Exec-based auth is commonly used for authenticating with cloud hosted + Kubernetes services, such as GKE. Here is an example snippet from a + kubeconfig using exec-based authentication for a GKE cluster: + - name: mycluster + user: + exec: + apiVersion: client.authentication.k8s.io/v1beta1 + command: /Users/romilb/google-cloud-sdk/bin/gke-gcloud-auth-plugin + installHint: Install gke-gcloud-auth-plugin ... + provideClusterInfo: true + + + Using exec-based authentication is problematic when used in conjunction + with kubernetes.remote_identity = LOCAL_CREDENTIALS in ~/.sky/config.yaml. + This is because the exec-based authentication may not have the relevant + dependencies installed on the remote cluster or may have hardcoded paths + that are not available on the remote cluster.
+ + Returns: + bool: True if exec-based authentication is used and LOCAL_CREDENTIAL + mode is used for remote_identity in ~/.sky/config.yaml. + str: Error message if exec-based authentication is used, None otherwise + """ + k8s = kubernetes.kubernetes + try: + k8s.config.load_kube_config() + except kubernetes.config_exception(): + # Using service account token or other auth methods, continue + return False, None + + # Get active context and user from kubeconfig using k8s api + _, current_context = k8s.config.list_kube_config_contexts() + target_username = current_context['context']['user'] + + # K8s api does not provide a mechanism to get the user details from the + # context. We need to load the kubeconfig file and parse it to get the + # user details. + kubeconfig_path = os.path.expanduser( + os.getenv('KUBECONFIG', + k8s.config.kube_config.KUBE_CONFIG_DEFAULT_LOCATION)) + # Load the kubeconfig file as a dictionary + with open(kubeconfig_path, 'r', encoding='utf-8') as f: + kubeconfig = yaml.safe_load(f) + + user_details = kubeconfig['users'] + + # Find user matching the target username + user_details = next( + user for user in user_details if user['name'] == target_username) + + remote_identity = skypilot_config.get_nested( + ('kubernetes', 'remote_identity'), schemas.REMOTE_IDENTITY_DEFAULT) + if ('exec' in user_details.get('user', {}) and remote_identity + == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value): + ctx_name = current_context['name'] + exec_msg = ('exec-based authentication is used for ' + f'Kubernetes context {ctx_name!r}.' + ' This may cause issues when running Managed Jobs ' + 'or SkyServe controller on Kubernetes. To fix, configure ' + 'SkyPilot to create a service account for running pods by ' + 'adding the following in ~/.sky/config.yaml:\n' + ' kubernetes:\n' + ' remote_identity: SERVICE_ACCOUNT\n' + ' More: https://skypilot.readthedocs.io/en/latest/' + 'reference/config.html') + return True, exec_msg + return False, None def get_current_kube_config_context_name() -> Optional[str]: @@ -588,23 +715,18 @@ def parse_memory_resource(resource_qty_str: str, class KubernetesInstanceType: """Class to represent the "Instance Type" in a Kubernetes. - Since Kubernetes does not have a notion of instances, we generate virtual instance types that represent the resources requested by a pod ("node"). - This name captures the following resource requests: - CPU - Memory - Accelerators - The name format is "{n}CPU--{k}GB" where n is the number of vCPUs and k is the amount of memory in GB. Accelerators can be specified by appending "--{a}{type}" where a is the number of accelerators and type is the accelerator type. - CPU and memory can be specified as floats. Accelerator count must be int. - Examples: - 4CPU--16GB - 0.5CPU--1.5GB @@ -643,7 +765,6 @@ def _parse_instance_type( cls, name: str) -> Tuple[float, float, Optional[int], Optional[str]]: """Parses and returns resources from the given InstanceType name - Returns: cpus | float: Number of CPUs memory | float: Amount of memory in GB @@ -688,7 +809,6 @@ def from_resources(cls, accelerator_count: Union[float, int] = 0, accelerator_type: str = '') -> 'KubernetesInstanceType': """Returns an instance name object from the given resources. - If accelerator_count is not an int, it will be rounded up since GPU requests in Kubernetes must be int. """ @@ -1310,3 +1430,14 @@ def get_head_pod_name(cluster_name_on_cloud: str): # label, but since we know the naming convention, we can directly return the # head pod name. 
return f'{cluster_name_on_cloud}-head' + + +def get_autoscaler_type( +) -> Optional[kubernetes_enums.KubernetesAutoscalerType]: + """Returns the autoscaler type by reading from config""" + autoscaler_type = skypilot_config.get_nested(['kubernetes', 'autoscaler'], + None) + if autoscaler_type is not None: + autoscaler_type = kubernetes_enums.KubernetesAutoscalerType( + autoscaler_type) + return autoscaler_type diff --git a/sky/provision/paperspace/instance.py b/sky/provision/paperspace/instance.py index 12c581c8314..ce1a4768c24 100644 --- a/sky/provision/paperspace/instance.py +++ b/sky/provision/paperspace/instance.py @@ -251,7 +251,7 @@ def get_cluster_info( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None, ) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['ready']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -271,6 +271,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='paperspace', + provider_config=provider_config, ) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 764d197493a..df9a9fcc58a 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -24,7 +24,6 @@ from sky.provision import logging as provision_logging from sky.provision import metadata_utils from sky.skylet import constants -from sky.utils import command_runner from sky.utils import common_utils from sky.utils import rich_utils from sky.utils import ux_utils @@ -444,8 +443,6 @@ def _post_provision_setup( 'status with: sky status -r; and retry provisioning.') # TODO(suquark): Move wheel build here in future PRs. - ip_list = cluster_info.get_feasible_ips() - port_list = cluster_info.get_ssh_ports() # We don't set docker_user here, as we are configuring the VM itself. 
ssh_credentials = backend_utils.ssh_credential_from_yaml( cluster_yaml, ssh_user=cluster_info.ssh_user) @@ -505,9 +502,9 @@ def _post_provision_setup( cluster_name.name_on_cloud, config_from_yaml['setup_commands'], cluster_info, ssh_credentials) - head_runner = command_runner.SSHCommandRunner(ip_list[0], - port=port_list[0], - **ssh_credentials) + runners = provision.get_command_runners(cloud_name, cluster_info, + **ssh_credentials) + head_runner = runners[0] status.update( runtime_preparation_str.format(step=3, step_name='runtime')) @@ -544,7 +541,7 @@ def _post_provision_setup( # if provision_record.is_instance_just_booted(inst.instance_id): # worker_ips.append(inst.public_ip) - if len(ip_list) > 1: + if cluster_info.num_instances > 1: instance_setup.start_ray_on_worker_nodes( cluster_name.name_on_cloud, no_restart=not full_ray_setup, diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 3ae99dae8d5..d7cb20b57a6 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -154,7 +154,7 @@ def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config # unused + del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None @@ -174,6 +174,8 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='runpod', + provider_config=provider_config, ) diff --git a/sky/provision/vsphere/instance.py b/sky/provision/vsphere/instance.py index 69a544210b3..787d8c97f62 100644 --- a/sky/provision/vsphere/instance.py +++ b/sky/provision/vsphere/instance.py @@ -571,8 +571,6 @@ def get_cluster_info( cluster_name: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - if provider_config: - del provider_config # unused logger.info('New provision of Vsphere: get_cluster_info().') # Init the vsphere client @@ -610,4 +608,6 @@ def get_cluster_info( return common.ClusterInfo( instances=instances, head_instance_id=head_instance_id, + provider_name='vsphere', + provider_config=provider_config, ) diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 2ac8f9169ba..89ca683ada5 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -81,7 +81,7 @@ # automatically generated from this start port. CONTROLLER_PORT_START = 20001 LOAD_BALANCER_PORT_START = 30001 -LOAD_BALANCER_PORT_RANGE = '30001-30100' +LOAD_BALANCER_PORT_RANGE = '30001-30020' # Initial version of service. INITIAL_VERSION = 1 diff --git a/sky/serve/controller.py b/sky/serve/controller.py index b9d18d3eb58..8d7964f090b 100644 --- a/sky/serve/controller.py +++ b/sky/serve/controller.py @@ -3,6 +3,7 @@ Responsible for autoscaling and replica management. 
""" import logging +import os import threading import time import traceback @@ -39,7 +40,7 @@ class SkyServeController: """ def __init__(self, service_name: str, service_spec: serve.SkyServiceSpec, - task_yaml: str, port: int) -> None: + task_yaml: str, host: str, port: int) -> None: self._service_name = service_name self._replica_manager: replica_managers.ReplicaManager = ( replica_managers.SkyPilotReplicaManager(service_name=service_name, @@ -47,6 +48,7 @@ def __init__(self, service_name: str, service_spec: serve.SkyServiceSpec, task_yaml_path=task_yaml)) self._autoscaler: autoscalers.Autoscaler = ( autoscalers.Autoscaler.from_spec(service_name, service_spec)) + self._host = host self._port = port self._app = fastapi.FastAPI() @@ -150,15 +152,25 @@ def configure_logger(): threading.Thread(target=self._run_autoscaler).start() logger.info('SkyServe Controller started on ' - f'http://localhost:{self._port}') + f'http://{self._host}:{self._port}') - uvicorn.run(self._app, host='localhost', port=self._port) + uvicorn.run(self._app, host={self._host}, port=self._port) # TODO(tian): Probably we should support service that will stop the VM in # specific time period. def run_controller(service_name: str, service_spec: serve.SkyServiceSpec, task_yaml: str, controller_port: int): - controller = SkyServeController(service_name, service_spec, task_yaml, + # We expose the controller to the public network when running inside a + # kubernetes cluster to allow external load balancers (example, for + # high availability load balancers) to communicate with the controller. + def _get_host(): + if 'KUBERNETES_SERVICE_HOST' in os.environ: + return '0.0.0.0' + else: + return 'localhost' + + host = _get_host() + controller = SkyServeController(service_name, service_spec, task_yaml, host, controller_port) controller.run() diff --git a/sky/serve/core.py b/sky/serve/core.py index 0444cfa715d..f193a85285b 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -186,13 +186,14 @@ def up( # whether the service is already running. If the id is the same # with the current job id, we know the service is up and running # for the first time; otherwise it is a name conflict. + idle_minutes_to_autodown = constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP controller_job_id, controller_handle = sky.launch( task=controller_task, stream_logs=False, cluster_name=controller_name, detach_run=True, - idle_minutes_to_autostop=constants. - CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP, + idle_minutes_to_autostop=idle_minutes_to_autodown, + down=True, retry_until_up=True, _disable_controller_check=True, ) @@ -253,7 +254,10 @@ def up( else: lb_port = serve_utils.load_service_initialization_result( lb_port_payload) - endpoint = f'{controller_handle.head_ip}:{lb_port}' + endpoint = backend_utils.get_endpoints( + controller_handle.cluster_name, lb_port, + skip_status_check=True).get(lb_port) + assert endpoint is not None, 'Did not get endpoint for controller.' sky_logging.print( f'{fore.CYAN}Service name: ' @@ -470,7 +474,7 @@ def down( code, require_outputs=True, stream_logs=False) - except exceptions.FetchIPError as e: + except exceptions.FetchClusterInfoError as e: raise RuntimeError( 'Failed to fetch controller IP. 
Please refresh controller status ' f'by `sky status -r {serve_utils.SKY_SERVE_CONTROLLER_NAME}` ' diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index 0356f5b59d5..0e17119115c 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -91,7 +91,7 @@ def _sync_with_controller(self): # TODO(tian): Support HTTPS. self._client_pool[replica_url] = ( httpx.AsyncClient( - base_url=f'http://{replica_url}')) + base_url=replica_url)) urls_to_close = set( self._client_pool.keys()) - set(ready_replica_urls) client_to_close = [] @@ -216,3 +216,19 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int): load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, load_balancer_port=load_balancer_port) load_balancer.run() + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--controller-addr', + required=True, + default='127.0.0.1', + help='The address of the controller.') + parser.add_argument('--load-balancer-port', + type=int, + required=True, + default=8890, + help='The port where the load balancer listens to.') + args = parser.parse_args() + run_load_balancer(args.controller_addr, args.load_balancer_port) diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index 70e5ba2c6dd..efb3ba3cf48 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -17,6 +17,7 @@ import sky from sky import backends +from sky import core from sky import exceptions from sky import global_user_state from sky import sky_logging @@ -428,7 +429,20 @@ def url(self) -> Optional[str]: handle = self.handle() if handle is None: return None - return f'{handle.head_ip}:{self.replica_port}' + replica_port_int = int(self.replica_port) + try: + endpoint_dict = core.endpoints(handle.cluster_name, + replica_port_int) + except exceptions.ClusterNotUpError: + return None + endpoint = endpoint_dict.get(replica_port_int, None) + if not endpoint: + return None + assert isinstance(endpoint, str), endpoint + # If replica doesn't start with http or https, add http:// + if not endpoint.startswith('http'): + endpoint = 'http://' + endpoint + return endpoint @property def status(self) -> serve_state.ReplicaStatus: @@ -446,6 +460,7 @@ def to_info_dict(self, with_handle: bool) -> Dict[str, Any]: 'name': self.cluster_name, 'status': self.status, 'version': self.version, + 'endpoint': self.url, 'is_spot': self.is_spot, 'launched_at': (cluster_record['launched_at'] if cluster_record is not None else None), @@ -487,7 +502,13 @@ def probe( try: msg = '' # TODO(tian): Support HTTPS in the future. 
- readiness_path = (f'http://{self.url}{readiness_path}') + url = self.url + if url is None: + logger.info(f'Error when probing {replica_identity}: ' + 'Cannot get the endpoint.') + return self, False, probe_time + readiness_path = (f'{url}{readiness_path}') + logger.info(f'Probing {replica_identity} with {readiness_path}.') if post_data is not None: msg += 'POST' response = requests.post( diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 0814441eb79..8a4387b40c0 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -24,6 +24,7 @@ from sky import exceptions from sky import global_user_state from sky import status_lib +from sky.backends import backend_utils from sky.serve import constants from sky.serve import serve_state from sky.skylet import constants as skylet_constants @@ -725,12 +726,21 @@ def get_endpoint(service_record: Dict[str, Any]) -> str: handle = global_user_state.get_handle_from_cluster_name( SKY_SERVE_CONTROLLER_NAME) assert isinstance(handle, backends.CloudVmRayResourceHandle) - if handle is None or handle.head_ip is None: + if handle is None: return '-' load_balancer_port = service_record['load_balancer_port'] if load_balancer_port is None: return '-' - return f'{handle.head_ip}:{load_balancer_port}' + try: + endpoint = backend_utils.get_endpoints(handle.cluster_name, + load_balancer_port).get( + load_balancer_port, None) + except exceptions.ClusterNotUpError: + return '-' + if endpoint is None: + return '-' + assert isinstance(endpoint, str), endpoint + return endpoint def format_service_table(service_records: List[Dict[str, Any]], @@ -794,7 +804,7 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], return 'No existing replicas.' replica_columns = [ - 'SERVICE_NAME', 'ID', 'VERSION', 'IP', 'LAUNCHED', 'RESOURCES', + 'SERVICE_NAME', 'ID', 'VERSION', 'ENDPOINT', 'LAUNCHED', 'RESOURCES', 'STATUS', 'REGION' ] if show_all: @@ -808,10 +818,11 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], replica_records = replica_records[:_REPLICA_TRUNC_NUM] for record in replica_records: + endpoint = record.get('endpoint', '-') service_name = record['service_name'] replica_id = record['replica_id'] version = (record['version'] if 'version' in record else '-') - replica_ip = '-' + replica_endpoint = endpoint if endpoint else '-' launched_at = log_utils.readable_time_duration(record['launched_at']) resources_str = '-' replica_status = record['status'] @@ -821,8 +832,6 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], replica_handle: 'backends.CloudVmRayResourceHandle' = record['handle'] if replica_handle is not None: - if replica_handle.head_ip is not None: - replica_ip = replica_handle.head_ip resources_str = resources_utils.get_readable_resources_repr( replica_handle, simplify=not show_all) if replica_handle.launched_resources.region is not None: @@ -834,7 +843,7 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], service_name, replica_id, version, - replica_ip, + replica_endpoint, launched_at, resources_str, status_str, diff --git a/sky/skylet/autostop_lib.py b/sky/skylet/autostop_lib.py index 687c04f5211..130e39fb425 100644 --- a/sky/skylet/autostop_lib.py +++ b/sky/skylet/autostop_lib.py @@ -75,10 +75,16 @@ def set_autostopping_started() -> None: configs.set_config(_AUTOSTOP_INDICATOR, str(psutil.boot_time())) -def get_is_autostopping_payload() -> str: +def get_is_autostopping() -> bool: """Returns whether the cluster is in the process of autostopping.""" result = 
configs.get_config(_AUTOSTOP_INDICATOR) is_autostopping = (result == str(psutil.boot_time())) + return is_autostopping + + +def get_is_autostopping_payload() -> str: + """Payload for whether the cluster is in the process of autostopping.""" + is_autostopping = get_is_autostopping() return common_utils.encode_payload(is_autostopping) diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 22e86778570..c63b42cc438 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -17,6 +17,7 @@ from sky.jobs import utils as managed_job_utils from sky.serve import serve_utils from sky.skylet import autostop_lib +from sky.skylet import constants from sky.skylet import job_lib from sky.utils import cluster_yaml_utils from sky.utils import common_utils @@ -197,16 +198,19 @@ def _stop_cluster(self, autostop_config): logger.info('Running ray down.') # Stop the workers first to avoid orphan workers. subprocess.run( - ['ray', 'down', '-y', '--workers-only', config_path], + f'{constants.SKY_RAY_CMD} down -y --workers-only ' + f'{config_path}', check=True, + shell=True, # We pass env inherited from os.environ due to calling `ray # `. env=env) logger.info('Running final ray down.') subprocess.run( - ['ray', 'down', '-y', config_path], + f'{constants.SKY_RAY_CMD} down -y {config_path}', check=True, + shell=True, # We pass env inherited from os.environ due to calling `ray # `. env=env) @@ -228,7 +232,7 @@ def _stop_cluster_with_new_provisioner(self, autostop_config, # Stop the ray autoscaler to avoid scaling up, during # stopping/terminating of the cluster. logger.info('Stopping the ray cluster.') - subprocess.run('ray stop', shell=True, check=True) + subprocess.run(f'{constants.SKY_RAY_CMD} stop', shell=True, check=True) operation_fn = provision_lib.stop_instances if autostop_config.down: diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index ceed5a26024..93bbe99b3ce 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -883,7 +883,7 @@ def tail_logs(cls, follow: bool = True) -> str: # pylint: disable=line-too-long code = [ - f'job_id = {job_id} if {job_id} is not None else job_lib.get_latest_job_id()', + f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()', 'run_timestamp = job_lib.get_run_timestamp(job_id)', f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)', f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, ' diff --git a/sky/skylet/providers/azure/config.py b/sky/skylet/providers/azure/config.py index 0c1827a1141..a19273761ba 100644 --- a/sky/skylet/providers/azure/config.py +++ b/sky/skylet/providers/azure/config.py @@ -120,6 +120,8 @@ def _configure_resource_group(config): create_or_update = get_azure_sdk_function( client=resource_client.deployments, function_name="create_or_update" ) + # TODO (skypilot): this takes a long time (> 40 seconds) for stopping an + # azure VM, and this can be called twice during ray down. 
outputs = ( create_or_update( resource_group_name=resource_group, diff --git a/sky/skylet/providers/azure/node_provider.py b/sky/skylet/providers/azure/node_provider.py index 4b315f23589..068930eb390 100644 --- a/sky/skylet/providers/azure/node_provider.py +++ b/sky/skylet/providers/azure/node_provider.py @@ -15,6 +15,7 @@ bootstrap_azure, get_azure_sdk_function, ) +from sky.skylet import autostop_lib from sky.skylet.providers.command_runner import SkyDockerCommandRunner from sky.provision import docker_utils @@ -61,16 +62,23 @@ class AzureNodeProvider(NodeProvider): def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) - # TODO(suquark): This is a temporary patch for resource group. - # By default, Ray autoscaler assumes the resource group is still here even - # after the whole cluster is destroyed. However, now we deletes the resource - # group after tearing down the cluster. To comfort the autoscaler, we need - # to create/update it here, so the resource group always exists. - from sky.skylet.providers.azure.config import _configure_resource_group - - _configure_resource_group( - {"cluster_name": cluster_name, "provider": provider_config} - ) + if not autostop_lib.get_is_autostopping(): + # TODO(suquark): This is a temporary patch for resource group. + # By default, Ray autoscaler assumes the resource group is still + # here even after the whole cluster is destroyed. However, now we + # delete the resource group after tearing down the cluster. To + # satisfy the autoscaler, we need to create/update it here, so the + # resource group always exists. + # + # We should not re-configure the resource group when it is + # running on the remote VM and autostopping is in progress, + # because the VM is running, which guarantees the resource group + # exists. + from sky.skylet.providers.azure.config import _configure_resource_group + + _configure_resource_group( + {"cluster_name": cluster_name, "provider": provider_config} + ) subscription_id = provider_config["subscription_id"] self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True) # Sky only supports Azure CLI credential for now. diff --git a/sky/skylet/providers/lambda_cloud/node_provider.py b/sky/skylet/providers/lambda_cloud/node_provider.py index 8a9c5997a0b..bb8d40da62e 100644 --- a/sky/skylet/providers/lambda_cloud/node_provider.py +++ b/sky/skylet/providers/lambda_cloud/node_provider.py @@ -155,9 +155,10 @@ def _get_internal_ip(node: Dict[str, Any]): if node['external_ip'] is None or node['status'] != 'active': node['internal_ip'] = None return - runner = command_runner.SSHCommandRunner(node['external_ip'], - 'ubuntu', - self.ssh_key_path) + runner = command_runner.SSHCommandRunner( + node=(node['external_ip'], 22), + ssh_user='ubuntu', + ssh_private_key=self.ssh_key_path) rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, require_outputs=True, stream_logs=False) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 10fc90fa850..5b205e2692a 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -152,7 +152,9 @@ def _try_load_config() -> None: common_utils.validate_schema( _dict, schemas.get_config_schema(), - f'Invalid config YAML ({config_path}): ', + f'Invalid config YAML ({config_path}). See: ' + 'https://skypilot.readthedocs.io/en/latest/reference/config.html. 
' # pylint: disable=line-too-long + 'Error: ', skip_none=False) logger.debug('Config syntax check passed.') diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 6f1df43cfd5..66c01f53617 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -60,6 +60,10 @@ available_node_types: ray.head.default: resources: {} node_config: + {% if remote_identity not in ['LOCAL_CREDENTIALS', 'SERVICE_ACCOUNT'] %} + IamInstanceProfile: + Name: {{remote_identity}} + {% endif %} InstanceType: {{instance_type}} ImageId: {{image_id}} # Deep Learning AMI (Ubuntu 18.04); see aws.py. # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index c64436c512e..172c958bb3e 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -85,6 +85,81 @@ provider: name: skypilot-service-account-role apiGroup: rbac.authorization.k8s.io + + # Role to access ingress services for fetching IP + autoscaler_ingress_role: + kind: Role + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + namespace: ingress-nginx + name: skypilot-service-account-ingress-role + labels: + parent: skypilot + rules: + - apiGroups: [ "" ] + resources: [ "services" ] + verbs: [ "list", "get", "watch" ] + - apiGroups: [ "rbac.authorization.k8s.io" ] + resources: [ "roles", "rolebindings" ] + verbs: [ "get", "list", "watch" ] + + # RoleBinding to access ingress services for fetching IP + autoscaler_ingress_role_binding: + apiVersion: rbac.authorization.k8s.io/v1 + kind: RoleBinding + metadata: + namespace: ingress-nginx + name: skypilot-service-account-ingress-role-binding + labels: + parent: skypilot + subjects: + - kind: ServiceAccount + name: skypilot-service-account + roleRef: + kind: Role + name: skypilot-service-account-ingress-role + apiGroup: rbac.authorization.k8s.io + + # In addition to a role binding, we also need a cluster role binding to give + # the SkyPilot access to the cluster-wide resources such as nodes to get + # node resources. + autoscaler_cluster_role: + kind: ClusterRole + apiVersion: rbac.authorization.k8s.io/v1 + metadata: + labels: + parent: skypilot + name: skypilot-service-account-cluster-role + rules: + - apiGroups: [ "" ] + resources: [ "nodes" ] # Required for getting node resources. + verbs: [ "get", "list", "watch" ] + - apiGroups: [ "rbac.authorization.k8s.io" ] + resources: [ "clusterroles", "clusterrolebindings" ] # Required for launching more SkyPilot clusters from within the pod. + verbs: [ "get", "list", "watch" ] + - apiGroups: [ "node.k8s.io" ] + resources: [ "runtimeclasses" ] # Required for autodetecting the runtime class of the nodes. + verbs: [ "get", "list", "watch" ] + - apiGroups: [ "networking.k8s.io" ] # Required for exposing services. + resources: [ "ingressclasses" ] + verbs: [ "get", "list", "watch" ] + + # Bind cluster role to the service account + autoscaler_cluster_role_binding: + apiVersion: rbac.authorization.k8s.io/v1 + kind: ClusterRoleBinding + metadata: + labels: + parent: skypilot + name: skypilot-service-account-cluster-role-binding + subjects: + - kind: ServiceAccount + name: skypilot-service-account + roleRef: + kind: ClusterRole + name: skypilot-service-account-cluster-role + apiGroup: rbac.authorization.k8s.io + services: # Service to expose the head node pod's SSH port. 
- apiVersion: v1 @@ -154,9 +229,9 @@ available_node_types: container.apparmor.security.beta.kubernetes.io/ray-node: unconfined {% endif %} spec: - # Change this if you altered the autoscaler_service_account above - # or want to provide your own. - serviceAccountName: skypilot-service-account + # serviceAccountName: skypilot-service-account + serviceAccountName: {{k8s_service_account_name}} + automountServiceAccountToken: {{k8s_automount_sa_token}} restartPolicy: Never diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index c66b5dfe032..3aa87eda138 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -5,13 +5,14 @@ import pathlib import shlex import time -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, Union from sky import sky_logging from sky.skylet import constants from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import subprocess_utils +from sky.utils import timeline logger = sky_logging.init_logger(__name__) @@ -41,6 +42,12 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]: return path +# Disable sudo for root user. This is useful when the command is running in a +# docker container, i.e. image_id is a docker image. +ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD = ( + '{ [ "$(whoami)" == "root" ] && function sudo() { "$@"; } || true; }') + + def ssh_options_list( ssh_private_key: Optional[str], ssh_control_name: Optional[str], @@ -134,17 +141,155 @@ class SshMode(enum.Enum): LOGIN = 2 -class SSHCommandRunner: +class CommandRunner: + """Runner for commands to be executed on the cluster.""" + + def __init__(self, node: Tuple[Any, Any], **kwargs): + del kwargs # Unused. + self.node = node + + @property + def node_id(self) -> str: + return '-'.join(str(x) for x in self.node) + + def _get_command_to_run( + self, + cmd: Union[str, List[str]], + process_stream: bool, + separate_stderr: bool, + skip_lines: int, + source_bashrc: bool = False, + ) -> str: + """Returns the command to run.""" + if isinstance(cmd, list): + cmd = ' '.join(cmd) + + # We need this to correctly run the cmd, and get the output. + command = [ + 'bash', + '--login', + '-c', + ] + if source_bashrc: + command += [ + # Need this `-i` option to make sure `source ~/.bashrc` work. + # Sourcing bashrc may take a few seconds causing overheads. + '-i', + shlex.quote( + f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' + f'PYTHONWARNINGS=ignore && ({cmd})'), + ] + else: + # Optimization: this reduces the time for connecting to the remote + # cluster by 1 second. + # sourcing ~/.bashrc is not required for internal executions + command += [ + 'true && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore' + f' && ({cmd})' + ] + if not separate_stderr: + command.append('2>&1') + if not process_stream and skip_lines: + command += [ + # A hack to remove the following bash warnings (twice): + # bash: cannot set terminal process group + # bash: no job control in this shell + f'| stdbuf -o0 tail -n +{skip_lines}', + # This is required to make sure the executor of command can get + # correct returncode, since linux pipe is used. + '; exit ${PIPESTATUS[0]}' + ] + + command_str = ' '.join(command) + return command_str + + @timeline.event + def run( + self, + cmd: Union[str, List[str]], + *, + require_outputs: bool = False, + # Advanced options. + log_path: str = os.devnull, + # If False, do not redirect stdout/stderr to optimize performance. 
+ process_stream: bool = True, + stream_logs: bool = True, + ssh_mode: SshMode = SshMode.NON_INTERACTIVE, + separate_stderr: bool = False, + source_bashrc: bool = False, + **kwargs) -> Union[int, Tuple[int, str, str]]: + """Runs the command on the cluster. + + Args: + cmd: The command to run. + require_outputs: Whether to return the stdout/stderr of the command. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + ssh_mode: The mode to use for ssh. + See SSHMode for more details. + separate_stderr: Whether to separate stderr from stdout. + + Returns: + returncode + or + A tuple of (returncode, stdout, stderr). + """ + raise NotImplementedError + + @timeline.event + def rsync( + self, + source: str, + target: str, + *, + up: bool, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = True, + max_retry: int = 1, + ) -> None: + """Uses 'rsync' to sync 'source' to 'target'. + + Args: + source: The source path. + target: The target path. + up: The direction of the sync, True for local to cluster, False + for cluster to local. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + max_retry: The maximum number of retries for the rsync command. + This value should be non-negative. + + Raises: + exceptions.CommandError: rsync command failed. + """ + raise NotImplementedError + + @classmethod + def make_runner_list( + cls: Type['CommandRunner'], + node_list: Iterable[Any], + **kwargs, + ) -> List['CommandRunner']: + """Helper function for creating runners with the same credentials""" + return [cls(node, **kwargs) for node in node_list] + + def check_connection(self) -> bool: + """Check if the connection to the remote machine is successful.""" + returncode = self.run('true', connect_timeout=5, stream_logs=False) + return returncode == 0 + + +class SSHCommandRunner(CommandRunner): """Runner for SSH commands.""" def __init__( self, - ip: str, + node: Tuple[str, int], ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = '__default__', ssh_proxy_command: Optional[str] = None, - port: int = 22, docker_user: Optional[str] = None, disable_control_master: Optional[bool] = False, ): @@ -156,7 +301,7 @@ def __init__( runner.rsync(source, target, up=True) Args: - ip: The IP address of the remote machine. + node: (ip, port) The IP address and port of the remote machine. ssh_private_key: The path to the private key to use for ssh. ssh_user: The user to use for ssh. ssh_control_name: The files name of the ssh_control to use. This is @@ -174,6 +319,8 @@ def __init__( command will utilize ControlMaster. We currently disable it for k8s instance. 
""" + super().__init__(node) + ip, port = node self.ssh_private_key = ssh_private_key self.ssh_control_name = ( None if ssh_control_name is None else hashlib.md5( @@ -198,27 +345,6 @@ def __init__( self.port = port self._docker_ssh_proxy_command = None - @staticmethod - def make_runner_list( - ip_list: List[str], - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = None, - ssh_proxy_command: Optional[str] = None, - disable_control_master: Optional[bool] = False, - port_list: Optional[List[int]] = None, - docker_user: Optional[str] = None, - ) -> List['SSHCommandRunner']: - """Helper function for creating runners with the same ssh credentials""" - if not port_list: - port_list = [22] * len(ip_list) - return [ - SSHCommandRunner(ip, ssh_user, ssh_private_key, ssh_control_name, - ssh_proxy_command, port, docker_user, - disable_control_master) - for ip, port in zip(ip_list, port_list) - ] - def _ssh_base_command(self, *, ssh_mode: SshMode, port_forward: Optional[List[int]], connect_timeout: Optional[int]) -> List[str]: @@ -251,6 +377,7 @@ def _ssh_base_command(self, *, ssh_mode: SshMode, f'{self.ssh_user}@{self.ip}' ] + @timeline.event def run( self, cmd: Union[str, List[str]], @@ -265,11 +392,11 @@ def run( ssh_mode: SshMode = SshMode.NON_INTERACTIVE, separate_stderr: bool = False, connect_timeout: Optional[int] = None, + source_bashrc: bool = False, **kwargs) -> Union[int, Tuple[int, str, str]]: """Uses 'ssh' to run 'cmd' on a node with ip. Args: - ip: The IP address of the node. cmd: The command to run. port_forward: A list of ports to forward from the localhost to the remote host. @@ -299,39 +426,20 @@ def run( command = base_ssh_command + cmd proc = subprocess_utils.run(command, shell=False, check=False) return proc.returncode, '', '' - if isinstance(cmd, list): - cmd = ' '.join(cmd) + + command_str = self._get_command_to_run( + cmd, + process_stream, + separate_stderr, + # A hack to remove the following bash warnings (twice): + # bash: cannot set terminal process group + # bash: no job control in this shell + skip_lines=5 if source_bashrc else 0, + source_bashrc=source_bashrc) + command = base_ssh_command + [shlex.quote(command_str)] log_dir = os.path.expanduser(os.path.dirname(log_path)) os.makedirs(log_dir, exist_ok=True) - # We need this to correctly run the cmd, and get the output. - command = [ - 'bash', - '--login', - '-c', - # Need this `-i` option to make sure `source ~/.bashrc` work. - '-i', - ] - - command += [ - shlex.quote(f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' - f'PYTHONWARNINGS=ignore && ({cmd})'), - ] - if not separate_stderr: - command.append('2>&1') - if not process_stream and ssh_mode == SshMode.NON_INTERACTIVE: - command += [ - # A hack to remove the following bash warnings (twice): - # bash: cannot set terminal process group - # bash: no job control in this shell - '| stdbuf -o0 tail -n +5', - # This is required to make sure the executor of command can get - # correct returncode, since linux pipe is used. 
- '; exit ${PIPESTATUS[0]}' - ] - - command_str = ' '.join(command) - command = base_ssh_command + [shlex.quote(command_str)] executable = None if not process_stream: @@ -354,6 +462,7 @@ def run( executable=executable, **kwargs) + @timeline.event def rsync( self, source: str, @@ -456,10 +565,3 @@ def rsync( error_msg, stderr=stdout + stderr, stream_logs=stream_logs) - - def check_connection(self) -> bool: - """Check if the connection to the remote machine is successful.""" - returncode = self.run('true', connect_timeout=5, stream_logs=False) - if returncode: - return False - return True diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index f1b547927d3..e8f12ef6ebe 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -6,7 +6,7 @@ determine the return type based on the value of require_outputs. """ import enum import typing -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union from typing_extensions import Literal @@ -18,6 +18,7 @@ GIT_EXCLUDE: str RSYNC_DISPLAY_OPTION: str RSYNC_FILTER_OPTION: str RSYNC_EXCLUDE_OPTION: str +ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD: str def ssh_options_list( @@ -39,40 +40,91 @@ class SshMode(enum.Enum): LOGIN: int -class SSHCommandRunner: +class CommandRunner: + node_id: str + + def __init__( + self, + node: Tuple[Any, ...], + **kwargs, + ) -> None: + ... + + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: Literal[False] = ..., + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> int: + ... + + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: Literal[True], + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> Tuple[int, str, str]: + ... + + @typing.overload + def run(self, + cmd: Union[str, List[str]], + *, + require_outputs: bool = ..., + log_path: str = ..., + process_stream: bool = ..., + stream_logs: bool = ..., + separate_stderr: bool = ..., + **kwargs) -> Union[Tuple[int, str, str], int]: + ... + + def rsync(self, + source: str, + target: str, + *, + up: bool, + log_path: str = ..., + stream_logs: bool = ...) -> None: + ... + + @classmethod + def make_runner_list(cls: typing.Type[CommandRunner], + node_list: Iterable[Tuple[Any, ...]], + **kwargs) -> List[CommandRunner]: + ... + + def check_connection(self) -> bool: + ... + + +class SSHCommandRunner(CommandRunner): ip: str + port: int ssh_user: str ssh_private_key: str ssh_control_name: Optional[str] docker_user: str - port: int disable_control_master: Optional[bool] def __init__( self, - ip: str, + node: Tuple[str, int], ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = ..., - port: int = ..., docker_user: Optional[str] = ..., disable_control_master: Optional[bool] = ..., ) -> None: ... - @staticmethod - def make_runner_list( - ip_list: List[str], - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = ..., - ssh_proxy_command: Optional[str] = ..., - port_list: Optional[List[int]] = ..., - docker_user: Optional[str] = ..., - disable_control_master: Optional[bool] = ..., - ) -> List['SSHCommandRunner']: - ... - @typing.overload def run(self, cmd: Union[str, List[str]], @@ -123,6 +175,3 @@ class SSHCommandRunner: log_path: str = ..., stream_logs: bool = ...) -> None: ... - - def check_connection(self) -> bool: - ... 
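For context, here is a minimal usage sketch of the refactored runner API above: call sites now construct runners from (ip, port) node tuples via make_runner_list instead of passing separate ip and port lists. This sketch is illustrative only and not part of the patch; the node addresses, SSH user, and key path are assumed values.

    # Minimal sketch: building SSH runners with the new node-tuple API.
    # Addresses, user, and key path below are hypothetical.
    from sky.utils import command_runner

    nodes = [('10.0.0.1', 22), ('10.0.0.2', 2222)]  # (ip, port) per node
    runners = command_runner.SSHCommandRunner.make_runner_list(
        nodes,
        ssh_user='ubuntu',
        ssh_private_key='~/.ssh/sky-key',
    )
    head_runner = runners[0]
    # With require_outputs=True, run() returns (returncode, stdout, stderr).
    returncode, stdout, stderr = head_runner.run('echo hello',
                                                 require_outputs=True,
                                                 stream_logs=False)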
diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 2abefc6fea0..0dc78d8427c 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -61,11 +61,18 @@ def get_usage_run_id() -> str: return _usage_run_id -def get_user_hash() -> str: +def get_user_hash(force_fresh_hash: bool = False) -> str: """Returns a unique user-machine specific hash as a user id. We cache the user hash in a file to avoid potential user_name or hostname changes causing a new user hash to be generated. + + Args: + force_fresh_hash: Bypasses the cached hash in USER_HASH_FILE and the + hash in the USER_ID_ENV_VAR and forces a fresh user-machine hash + to be generated. Used by `kubernetes.ssh_key_secret_field_name` to + avoid controllers sharing the same ssh key field name as the + local client. """ def _is_valid_user_hash(user_hash: Optional[str]) -> bool: @@ -77,12 +84,13 @@ def _is_valid_user_hash(user_hash: Optional[str]) -> bool: return False return len(user_hash) == USER_HASH_LENGTH - user_hash = os.getenv(constants.USER_ID_ENV_VAR) - if _is_valid_user_hash(user_hash): - assert user_hash is not None - return user_hash + if not force_fresh_hash: + user_hash = os.getenv(constants.USER_ID_ENV_VAR) + if _is_valid_user_hash(user_hash): + assert user_hash is not None + return user_hash - if os.path.exists(_USER_HASH_FILE): + if not force_fresh_hash and os.path.exists(_USER_HASH_FILE): # Read from cached user hash file. with open(_USER_HASH_FILE, 'r', encoding='utf-8') as f: # Remove invalid characters. @@ -96,8 +104,13 @@ def _is_valid_user_hash(user_hash: Optional[str]) -> bool: # A fallback in case the hash is invalid. user_hash = uuid.uuid4().hex[:USER_HASH_LENGTH] os.makedirs(os.path.dirname(_USER_HASH_FILE), exist_ok=True) - with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f: - f.write(user_hash) + if not force_fresh_hash: + # Do not cache to file if force_fresh_hash is True since the file may + # be intentionally using a different hash, e.g. we want to keep the + # user_hash for usage collection the same on the jobs/serve controller + # as users' local client. + with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f: + f.write(user_hash) return user_hash @@ -439,9 +452,9 @@ def class_fullname(cls, skip_builtins: bool = True): """Get the full name of a class. Example: - >>> e = sky.exceptions.FetchIPError() + >>> e = sky.exceptions.FetchClusterInfoError() >>> class_fullname(e.__class__) - 'sky.exceptions.FetchIPError' + 'sky.exceptions.FetchClusterInfoError' Args: cls: The class to get the full name. diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index b4a312ac1ab..9908fa54286 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -228,6 +228,25 @@ def _get_cloud_dependencies_installation_commands( 'pip list | grep google-cloud-storage > /dev/null 2>&1 || ' 'pip install google-cloud-storage > /dev/null 2>&1') commands.append(f'{gcp.GOOGLE_SDK_INSTALLATION_COMMAND}') + elif isinstance(cloud, clouds.Kubernetes): + commands.append( + f'echo -en "\\r{prefix_str}Kubernetes{empty_str}" && ' + 'pip list | grep kubernetes > /dev/null 2>&1 || ' + 'pip install "kubernetes>=20.0.0" > /dev/null 2>&1 &&' + # Install k8s + skypilot dependencies + 'sudo bash -c "if ' + '! command -v curl &> /dev/null || ' + '! command -v socat &> /dev/null || ' + '! 
command -v netcat &> /dev/null; ' + 'then apt update && apt install curl socat netcat -y; ' + 'fi" && ' + # Install kubectl + '(command -v kubectl &>/dev/null || ' + '(curl -s -LO "https://dl.k8s.io/release/' + '$(curl -L -s https://dl.k8s.io/release/stable.txt)' + '/bin/linux/amd64/kubectl" && ' + 'sudo install -o root -g root -m 0755 ' + 'kubectl /usr/local/bin/kubectl))') if controller == Controllers.JOBS_CONTROLLER: if isinstance(cloud, clouds.IBM): commands.append( @@ -239,11 +258,6 @@ def _get_cloud_dependencies_installation_commands( commands.append(f'echo -en "\\r{prefix_str}OCI{empty_str}" && ' 'pip list | grep oci > /dev/null 2>&1 || ' 'pip install oci > /dev/null 2>&1') - elif isinstance(cloud, clouds.Kubernetes): - commands.append( - f'echo -en "\\r{prefix_str}Kubernetes{empty_str}" && ' - 'pip list | grep kubernetes > /dev/null 2>&1 || ' - 'pip install "kubernetes>=20.0.0" > /dev/null 2>&1') elif isinstance(cloud, clouds.RunPod): commands.append( f'echo -en "\\r{prefix_str}RunPod{empty_str}" && ' @@ -671,7 +685,28 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', # whenever task.storage_mounts is non-empty. logger.info(f'{colorama.Fore.YELLOW}Uploading sources to cloud storage.' f'{colorama.Style.RESET_ALL} See: sky storage ls') - task.sync_storage_mounts() + try: + task.sync_storage_mounts() + except ValueError as e: + if 'No enabled cloud for storage' in str(e): + data_src = None + if has_local_source_paths_file_mounts: + data_src = 'file_mounts' + if has_local_source_paths_workdir: + if data_src: + data_src += ' and workdir' + else: + data_src = 'workdir' + store_enabled_clouds = ', '.join(storage_lib.STORE_ENABLED_CLOUDS) + with ux_utils.print_exception_no_traceback(): + raise exceptions.NotSupportedError( + f'Unable to use {data_src} - no cloud with object store ' + 'is enabled. Please enable at least one cloud with ' + f'object store support ({store_enabled_clouds}) by running ' + f'`sky check`, or remove {data_src} from your task.' + '\nHint: If you do not have any cloud access, you may still' + ' download data and code over the network using curl or ' + 'other tools in the `setup` section of the task.') from None # Step 5: Add the file download into the file mounts, such as # /original-dst: s3://spot-fm-file-only-bucket-name/file-0 diff --git a/sky/utils/kubernetes/create_cluster.sh b/sky/utils/kubernetes/create_cluster.sh index b5911b1acbe..62fb700edf3 100755 --- a/sky/utils/kubernetes/create_cluster.sh +++ b/sky/utils/kubernetes/create_cluster.sh @@ -195,7 +195,7 @@ if $ENABLE_GPUS; then echo "Enabling GPU support..." # Run patch for missing ldconfig.real # https://github.com/NVIDIA/nvidia-docker/issues/614#issuecomment-423991632 - docker exec -ti skypilot-control-plane ln -s /sbin/ldconfig /sbin/ldconfig.real + docker exec -ti skypilot-control-plane /bin/bash -c '[ ! -f /sbin/ldconfig.real ] && ln -s /sbin/ldconfig /sbin/ldconfig.real || echo "/sbin/ldconfig.real already exists"' echo "Installing NVIDIA GPU operator..." # Install the NVIDIA GPU operator diff --git a/sky/utils/kubernetes/generate_static_kubeconfig.sh b/sky/utils/kubernetes/generate_static_kubeconfig.sh new file mode 100755 index 00000000000..30ea929177a --- /dev/null +++ b/sky/utils/kubernetes/generate_static_kubeconfig.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# This script creates a new k8s Service Account and generates a kubeconfig with +# its credentials. This Service Account has all the necessary permissions for +# SkyPilot. The kubeconfig is written in the current directory. 
+# +# You must configure your local kubectl to point to the right k8s cluster and +# have admin-level access. +# +# Note: all of the k8s resources are created in namespace "skypilot". If you +# delete any of these objects, SkyPilot will stop working. +# +# You can override the default namespace "skypilot" using the +# SKYPILOT_NAMESPACE environment variable. +# You can override the default service account name "skypilot-sa" using the +# SKYPILOT_SA_NAME environment variable. + +set -eu -o pipefail + +# Allow passing in common name and username in environment. If not provided, +# use default. +SKYPILOT_SA=${SKYPILOT_SA_NAME:-skypilot-sa} +NAMESPACE=${SKYPILOT_NAMESPACE:-default} + +# Set OS specific values. +if [[ "$OSTYPE" == "linux-gnu" ]]; then + BASE64_DECODE_FLAG="-d" +elif [[ "$OSTYPE" == "darwin"* ]]; then + BASE64_DECODE_FLAG="-D" +elif [[ "$OSTYPE" == "linux-musl" ]]; then + BASE64_DECODE_FLAG="-d" +else + echo "Unknown OS ${OSTYPE}" + exit 1 +fi + +echo "Creating the Kubernetes Service Account with minimal RBAC permissions." +kubectl apply -f - < kubeconfig < ~/controller.log 2>&1 & + python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 & sleep 10 echo 'Starting model worker...' python -u -m fastchat.serve.model_worker \ - --model-path lmsys/$MODEL_NAME 2>&1 \ + --host 127.0.0.1 \ + --model-path lmsys/$MODEL_NAME 2>&1 \ | tee model_worker.log & echo 'Waiting for model worker to start...' diff --git a/tests/skyserve/restart/user_bug.yaml b/tests/skyserve/restart/user_bug.yaml index b3cbf9e907d..959e725d23d 100644 --- a/tests/skyserve/restart/user_bug.yaml +++ b/tests/skyserve/restart/user_bug.yaml @@ -8,7 +8,6 @@ service: resources: ports: 8080 cpus: 2+ - use_spot: True workdir: tests/skyserve/restart diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 717eb2388c7..40d6cd14b70 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -284,11 +284,13 @@ def test_minimal(generic_cloud: str): # Ensure the raylet process has the correct file descriptor limit. f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', f'sky logs {name} 2 --status', # Ensure the job succeeded. + # Install jq for the next test. + f'sky exec {name} \'sudo apt-get update && sudo apt-get install -y jq\'', # Check the cluster info f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cluster_name | grep {name}\'', - f'sky logs {name} 3 --status', # Ensure the job succeeded. - f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cloud | grep -i {generic_cloud}\'', f'sky logs {name} 4 --status', # Ensure the job succeeded. + f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cloud | grep -i {generic_cloud}\'', + f'sky logs {name} 5 --status', # Ensure the job succeeded. ], f'sky down -y {name}', _get_timeout(generic_cloud), @@ -1918,8 +1920,9 @@ def test_azure_start_stop(): f'sky start -y {name} -i 1', f'sky exec {name} examples/azure_start_stop.yaml', f'sky logs {name} 3 --status', # Ensure the job succeeded. - 'sleep 200', + 'sleep 260', f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', timeout=30 * 60, # 30 mins @@ -2931,9 +2934,12 @@ def test_azure_start_stop_two_nodes(): f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. 
            f'sky stop -y {name}',
-            f'sky start -y {name}',
+            f'sky start -y {name} -i 1',
             f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml',
             f'sky logs {name} 2 --status',  # Ensure the job succeeded.
+            'sleep 200',
+            f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"'
+            f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}'
         ],
         f'sky down -y {name}',
         timeout=30 * 60,  # 30 mins (it takes around ~23 mins)
@@ -3243,8 +3249,16 @@ def test_skyserve_azure_http():
     run_one_test(test)
+@pytest.mark.kubernetes
+@pytest.mark.serve
+def test_skyserve_kubernetes_http():
+    """Test skyserve on Kubernetes"""
+    name = _get_service_name()
+    test = _get_skyserve_http_test(name, 'kubernetes', 30)
+    run_one_test(test)
+
+
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_llm(generic_cloud: str):
     """Test skyserve with real LLM usecase"""
     name = _get_service_name()
@@ -3272,7 +3286,7 @@ def generate_llm_test_command(prompt: str, expected_output: str) -> str:
             ],
         ],
         _TEARDOWN_SERVICE.format(name=name),
-        timeout=25 * 60,
+        timeout=40 * 60,
     )
     run_one_test(test)
@@ -3366,7 +3380,6 @@ def test_skyserve_dynamic_ondemand_fallback():
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_user_bug_restart(generic_cloud: str):
     """Tests that we restart the service after user bug."""
     # TODO(zhwu): this behavior needs some rethinking.
@@ -3400,7 +3413,7 @@ def test_skyserve_user_bug_restart(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
+@pytest.mark.no_kubernetes  # Replicas on k8s may be running on the same node and have the same public IP
 def test_skyserve_load_balancer(generic_cloud: str):
     """Test skyserve load balancer round-robin policy"""
     name = _get_service_name()
@@ -3466,7 +3479,6 @@ def test_skyserve_auto_restart():
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_cancel(generic_cloud: str):
     """Test skyserve with cancel"""
     name = _get_service_name()
@@ -3492,7 +3504,6 @@ def test_skyserve_cancel(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_streaming(generic_cloud: str):
     """Test skyserve with streaming"""
     name = _get_service_name()
@@ -3512,7 +3523,6 @@ def test_skyserve_streaming(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_update(generic_cloud: str):
     """Test skyserve with update"""
     name = _get_service_name()
@@ -3541,7 +3551,6 @@ def test_skyserve_update(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_rolling_update(generic_cloud: str):
     """Test skyserve with rolling update"""
     name = _get_service_name()
@@ -3578,7 +3587,6 @@ def test_skyserve_rolling_update(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_fast_update(generic_cloud: str):
     """Test skyserve with fast update (Increment version of old replicas)"""
     name = _get_service_name()
@@ -3591,7 +3599,7 @@ def test_skyserve_fast_update(generic_cloud: str):
             f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"',
             f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/bump_version_after.yaml',
             # sleep to wait for update to be registered.
-            'sleep 120',
+            'sleep 30',
             # 2 on-demand (ready) + 1 on-demand (provisioning).
            (
                _check_replica_in_status(
@@ -3605,7 +3613,7 @@ def test_skyserve_fast_update(generic_cloud: str):
             # Test rolling update
             f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/bump_version_before.yaml',
             # sleep to wait for update to be registered.
-            'sleep 30',
+            'sleep 15',
             # 2 on-demand (ready) + 1 on-demand (shutting down).
             _check_replica_in_status(name, [(2, False, 'READY'), (1, False, 'SHUTTING_DOWN')]),
@@ -3620,7 +3628,6 @@ def test_skyserve_fast_update(generic_cloud: str):
 @pytest.mark.serve
-@pytest.mark.no_kubernetes
 def test_skyserve_update_autoscale(generic_cloud: str):
     """Test skyserve update with autoscale"""
     name = _get_service_name()
@@ -3657,8 +3664,8 @@ def test_skyserve_update_autoscale(generic_cloud: str):
 @pytest.mark.serve
+@pytest.mark.no_kubernetes  # Spot instances are not supported in Kubernetes
 @pytest.mark.parametrize('mode', ['rolling', 'blue_green'])
-@pytest.mark.no_kubernetes
 def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str):
     """Test skyserve with update that changes autoscaler"""
     name = _get_service_name() + mode