From 5a2f1b855e0a8b0a099390882209914f34ea8b98 Mon Sep 17 00:00:00 2001 From: Tian Xia Date: Tue, 14 May 2024 21:44:03 +0800 Subject: [PATCH 01/20] [Serve] Proxy w/ retry (#3395) * init * support streaming * max reetry num * upd comments * remove -L in documentations * streaming smoke test. TODO: debug and make sure it works * Apply suggestions from code review Co-authored-by: Zhanghao Wu * comments and expose exceptions in smoke test * upd smoke test and passed * timeout * yield error * remove -L * apply suggestions from code review * add threading lock * apply suggestions from code review * comments for limit on client * Update sky/serve/load_balancer.py Co-authored-by: Zhanghao Wu * Update sky/serve/load_balancer.py Co-authored-by: Zhanghao Wu * Update sky/serve/load_balancer.py Co-authored-by: Zhanghao Wu * format * retry for no replicas as well * check disconnect if no replcias * format * minor * async probe controller; close clients in the background * async * comments * Update sky/serve/load_balancer.py Co-authored-by: Zhanghao Wu * format * fix --------- Co-authored-by: Zhanghao Wu --- docs/source/serving/sky-serve.rst | 22 +- examples/cog/README.md | 2 +- examples/serve/misc/cancel/README.md | 4 +- examples/serve/stable_diffusion_service.yaml | 2 +- .../stable_diffusion_docker.yaml | 2 +- llm/codellama/README.md | 6 +- llm/dbrx/README.md | 2 +- llm/gemma/README.md | 8 +- llm/mixtral/README.md | 8 +- llm/qwen/README.md | 8 +- llm/sglang/README.md | 4 +- llm/tgi/README.md | 4 +- llm/vllm/README.md | 4 +- sky/serve/README.md | 6 +- sky/serve/constants.py | 12 + sky/serve/core.py | 2 +- sky/serve/load_balancer.py | 207 ++++++++++++++---- sky/serve/load_balancing_policies.py | 50 +++-- sky/templates/sky-serve-controller.yaml.j2 | 2 + tests/skyserve/auto_restart.yaml | 1 + tests/skyserve/llm/get_response.py | 2 + tests/skyserve/streaming/example.txt | 1 + .../streaming/send_streaming_request.py | 24 ++ tests/skyserve/streaming/server.py | 24 ++ tests/skyserve/streaming/streaming.yaml | 13 ++ tests/test_smoke.py | 59 +++-- 26 files changed, 348 insertions(+), 131 deletions(-) create mode 100644 tests/skyserve/streaming/example.txt create mode 100644 tests/skyserve/streaming/send_streaming_request.py create mode 100644 tests/skyserve/streaming/server.py create mode 100644 tests/skyserve/streaming/streaming.yaml diff --git a/docs/source/serving/sky-serve.rst b/docs/source/serving/sky-serve.rst index 1c4ee3f2751..3ccbed140c0 100644 --- a/docs/source/serving/sky-serve.rst +++ b/docs/source/serving/sky-serve.rst @@ -22,7 +22,7 @@ Why SkyServe? How it works: -- Each service gets an endpoint that automatically redirects requests to its replicas. +- Each service gets an endpoint that automatically distributes requests to its replicas. - Replicas of the same service can run in different regions and clouds — reducing cloud costs and increasing availability. - SkyServe handles the load balancing, recovery, and autoscaling of the replicas. @@ -127,7 +127,7 @@ Run :code:`sky serve up service.yaml` to deploy the service with automatic price If you see the :code:`STATUS` column becomes :code:`READY`, then the service is ready to accept traffic! -Simply ``curl -L`` the service endpoint, which automatically load-balances across the two replicas: +Simply ``curl`` the service endpoint, which automatically load-balances across the two replicas: .. tab-set:: @@ -136,7 +136,7 @@ Simply ``curl -L`` the service endpoint, which automatically load-balances acros .. 
code-block:: console - $ curl -L 3.84.15.251:30001/v1/chat/completions \ + $ curl 3.84.15.251:30001/v1/chat/completions \ -X POST \ -d '{"model": "mistralai/Mixtral-8x7B-Instruct-v0.1", "messages": [{"role": "user", "content": "Who are you?"}]}' \ -H 'Content-Type: application/json' @@ -149,7 +149,7 @@ Simply ``curl -L`` the service endpoint, which automatically load-balances acros .. code-block:: console - $ curl -L 44.211.131.51:30001/generate \ + $ curl 44.211.131.51:30001/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' @@ -240,7 +240,7 @@ Under the hood, :code:`sky serve up`: #. Launches a controller which handles autoscaling, monitoring and load balancing; #. Returns a Service Endpoint which will be used to accept traffic; #. Meanwhile, the controller provisions replica VMs which later run the services; -#. Once any replica is ready, the requests sent to the Service Endpoint will be **HTTP-redirect** to one of the endpoint replicas. +#. Once any replica is ready, the requests sent to the Service Endpoint will be distributed to one of the endpoint replicas. After the controller is provisioned, you'll see the following in :code:`sky serve status` output: @@ -264,7 +264,7 @@ sending requests to :code:`` (e.g., ``44.201.119.3:30001``): .. code-block:: console - $ curl -L + $ curl My First SkyServe Service @@ -274,12 +274,6 @@ sending requests to :code:`` (e.g., ``44.201.119.3:30001``): -.. note:: - - Since we are using HTTP-redirect, we need to use :code:`curl -L - `. The :code:`curl` command by default won't follow the - redirect. - Tutorial: Serve a Chatbot LLM! ------------------------------ @@ -368,7 +362,7 @@ Send a request using the following cURL command: .. code-block:: console - $ curl -L http:///v1/chat/completions \ + $ curl http:///v1/chat/completions \ -X POST \ -d '{"model":"vicuna-7b-v1.3","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Who are you?"}],"temperature":0}' \ -H 'Content-Type: application/json' @@ -468,7 +462,7 @@ SkyServe has a centralized controller VM that manages the deployment of your ser It is composed of the following components: #. **Controller**: The controller will monitor the status of the replicas and re-launch a new replica if one of them fails. It also autoscales the number of replicas if autoscaling config is set (see :ref:`Service YAML spec ` for more information). -#. **Load Balancer**: The load balancer will route the traffic to all ready replicas. It is a lightweight HTTP server that listens on the service endpoint and **HTTP-redirects** the requests to one of the replicas. +#. **Load Balancer**: The load balancer will route the traffic to all ready replicas. It is a lightweight HTTP server that listens on the service endpoint and distribute the requests to one of the replicas. All of the process group shares a single controller VM. The controller VM will be launched in the cloud with the best price/performance ratio. You can also :ref:`customize the controller resources ` based on your needs. 
diff --git a/examples/cog/README.md b/examples/cog/README.md index 4fa4890420f..b2193e2e18f 100644 --- a/examples/cog/README.md +++ b/examples/cog/README.md @@ -28,7 +28,7 @@ After the service is launched, access the deployment with the following: ```console ENDPOINT=$(sky serve status --endpoint cog) -curl -L http://$ENDPOINT/predictions -X POST \ +curl http://$ENDPOINT/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://blog.skypilot.co/introducing-sky-serve/images/sky-serve-thumbnail.png"}}' \ | jq -r '.output | split(",")[1]' | base64 --decode > output.png diff --git a/examples/serve/misc/cancel/README.md b/examples/serve/misc/cancel/README.md index 65b88c2d540..61c24383909 100644 --- a/examples/serve/misc/cancel/README.md +++ b/examples/serve/misc/cancel/README.md @@ -1,6 +1,6 @@ # SkyServe cancel example -This example demonstrates the redirect support canceling a request. +This example demonstrates the SkyServe load balancer support canceling a request. ## Running the example @@ -33,7 +33,7 @@ Client disconnected, stopping computation. You can also run ```bash -curl -L http:/// +curl http:/// ``` and manually Ctrl + C to cancel the request and see logs. diff --git a/examples/serve/stable_diffusion_service.yaml b/examples/serve/stable_diffusion_service.yaml index 86ef257e7ca..2adaf6e4ca6 100644 --- a/examples/serve/stable_diffusion_service.yaml +++ b/examples/serve/stable_diffusion_service.yaml @@ -18,7 +18,7 @@ file_mounts: /stable_diffusion: examples/stable_diffusion setup: | - sudo curl -L "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose + sudo curl "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose sudo chmod +x /usr/local/bin/docker-compose cd stable-diffusion-webui-docker sudo rm -r stable-diffusion-webui-docker diff --git a/examples/stable_diffusion/stable_diffusion_docker.yaml b/examples/stable_diffusion/stable_diffusion_docker.yaml index 9c07790ba6b..47499fa2ea4 100644 --- a/examples/stable_diffusion/stable_diffusion_docker.yaml +++ b/examples/stable_diffusion/stable_diffusion_docker.yaml @@ -7,7 +7,7 @@ file_mounts: /stable_diffusion: . setup: | - sudo curl -L "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose + sudo curl "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose sudo chmod +x /usr/local/bin/docker-compose cd stable-diffusion-webui-docker sudo rm -r stable-diffusion-webui-docker diff --git a/llm/codellama/README.md b/llm/codellama/README.md index 1ed02e301d1..8e5025d22b5 100644 --- a/llm/codellama/README.md +++ b/llm/codellama/README.md @@ -68,7 +68,7 @@ Launching a cluster 'code-llama'. Proceed? [Y/n]: ```bash IP=$(sky status --ip code-llama) -curl -L http://$IP:8000/v1/completions \ +curl http://$IP:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "codellama/CodeLlama-70b-Instruct-hf", @@ -131,7 +131,7 @@ availability of the service while minimizing the cost. ```bash ENDPOINT=$(sky serve status --endpoint code-llama) -curl -L http://$ENDPOINT/v1/completions \ +curl http://$ENDPOINT/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "codellama/CodeLlama-70b-Instruct-hf", @@ -146,7 +146,7 @@ We can also access the Code Llama service with the openAI Chat API. 
```bash ENDPOINT=$(sky serve status --endpoint code-llama) -curl -L http://$ENDPOINT/v1/chat/completions \ +curl http://$ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "codellama/CodeLlama-70b-Instruct-hf", diff --git a/llm/dbrx/README.md b/llm/dbrx/README.md index 4cb6ad47d6e..e0ad216e92c 100644 --- a/llm/dbrx/README.md +++ b/llm/dbrx/README.md @@ -256,7 +256,7 @@ ENDPOINT=$(sky serve status --endpoint dbrx) To curl the endpoint: ```console -curl -L $ENDPOINT/v1/chat/completions \ +curl $ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "databricks/dbrx-instruct", diff --git a/llm/gemma/README.md b/llm/gemma/README.md index 2801c4fd6f3..ef5027b2807 100644 --- a/llm/gemma/README.md +++ b/llm/gemma/README.md @@ -37,7 +37,7 @@ After the cluster is launched, we can access the model with the following comman ```bash IP=$(sky status --ip gemma) -curl -L http://$IP:8000/v1/completions \ +curl http://$IP:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-7b-it", @@ -50,7 +50,7 @@ Chat API is also supported: ```bash IP=$(sky status --ip gemma) -curl -L http://$IP:8000/v1/chat/completions \ +curl http://$IP:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-7b-it", @@ -78,7 +78,7 @@ After the cluster is launched, we can access the model with the following comman ```bash ENDPOINT=$(sky serve status --endpoint gemma) -curl -L http://$ENDPOINT/v1/completions \ +curl http://$ENDPOINT/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-7b-it", @@ -89,7 +89,7 @@ curl -L http://$ENDPOINT/v1/completions \ Chat API is also supported: ```bash -curl -L http://$ENDPOINT/v1/chat/completions \ +curl http://$ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "google/gemma-7b-it", diff --git a/llm/mixtral/README.md b/llm/mixtral/README.md index 208b40ca14b..0bddb77c665 100644 --- a/llm/mixtral/README.md +++ b/llm/mixtral/README.md @@ -53,7 +53,7 @@ We can now access the model through the OpenAI API with the IP and port: ```bash IP=$(sky status --ip mixtral) -curl -L http://$IP:8000/v1/completions \ +curl http://$IP:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -66,7 +66,7 @@ Chat API is also supported: ```bash IP=$(sky status --ip mixtral) -curl -L http://$IP:8000/v1/chat/completions \ +curl http://$IP:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -119,7 +119,7 @@ After the `sky serve up` command, there will be a single endpoint for the servic ```bash ENDPOINT=$(sky serve status --endpoint mixtral) -curl -L http://$ENDPOINT/v1/completions \ +curl http://$ENDPOINT/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -132,7 +132,7 @@ Chat API is also supported: ```bash ENDPOINT=$(sky serve status --endpoint mixtral) -curl -L http://$ENDPOINT/v1/chat/completions \ +curl http://$ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", diff --git a/llm/qwen/README.md b/llm/qwen/README.md index 6ab9bb22ffc..113bbd9e740 100644 --- a/llm/qwen/README.md +++ b/llm/qwen/README.md @@ -34,7 +34,7 @@ sky launch -c qwen serve-110b.yaml ```bash IP=$(sky status --ip qwen) -curl -L http://$IP:8000/v1/completions \ +curl 
http://$IP:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen1.5-110B-Chat", @@ -45,7 +45,7 @@ curl -L http://$IP:8000/v1/completions \ 3. Send a request for chat completion: ```bash -curl -L http://$IP:8000/v1/chat/completions \ +curl http://$IP:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen1.5-110B-Chat", @@ -92,11 +92,11 @@ As shown, the service is now backed by 2 replicas, one on Azure and one on GCP, type is chosen to be **the cheapest available one** on the clouds. That said, it maximizes the availability of the service while minimizing the cost. -3. To access the model, we use a `curl -L` command (`-L` to follow redirect) to send the request to the endpoint: +3. To access the model, we use a `curl` command to send the request to the endpoint: ```bash ENDPOINT=$(sky serve status --endpoint qwen) -curl -L http://$ENDPOINT/v1/chat/completions \ +curl http://$ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen1.5-72B-Chat", diff --git a/llm/sglang/README.md b/llm/sglang/README.md index 3ffcc2f484b..fc79529148a 100644 --- a/llm/sglang/README.md +++ b/llm/sglang/README.md @@ -68,7 +68,7 @@ ENDPOINT=$(sky serve status --endpoint sglang-llava) ```bash -curl -L $ENDPOINT/v1/chat/completions \ +curl $ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "liuhaotian/llava-v1.6-vicuna-7b", @@ -149,7 +149,7 @@ ENDPOINT=$(sky serve status --endpoint sglang-llama2) 4. Once it status is `READY`, you can use the endpoint to interact with the model: ```bash -curl -L $ENDPOINT/v1/chat/completions \ +curl $ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Llama-2-7b-chat-hf", diff --git a/llm/tgi/README.md b/llm/tgi/README.md index 8fb68222d68..8c8360d0465 100644 --- a/llm/tgi/README.md +++ b/llm/tgi/README.md @@ -17,7 +17,7 @@ A user can access the model with the following command: ```bash ENDPOINT=$(sky status --endpoint 8080 tgi) -curl -L $(sky serve status tgi --endpoint)/generate \ +curl $(sky serve status tgi --endpoint)/generate \ -H 'Content-Type: application/json' \ -d '{ "inputs": "What is Deep Learning?", @@ -51,7 +51,7 @@ After the service is launched, we can access the model with the following comman ```bash ENDPOINT=$(sky serve status --endpoint tgi) -curl -L $ENDPOINT/generate \ +curl $ENDPOINT/generate \ -H 'Content-Type: application/json' \ -d '{ "inputs": "What is Deep Learning?", diff --git a/llm/vllm/README.md b/llm/vllm/README.md index 568b8ff70bd..61932cd8571 100644 --- a/llm/vllm/README.md +++ b/llm/vllm/README.md @@ -154,7 +154,7 @@ ENDPOINT=$(sky serve status --endpoint vllm-llama2) 4. Once it status is `READY`, you can use the endpoint to interact with the model: ```bash -curl -L $ENDPOINT/v1/chat/completions \ +curl $ENDPOINT/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Llama-2-7b-chat-hf", @@ -171,7 +171,7 @@ curl -L $ENDPOINT/v1/chat/completions \ }' ``` -Notice that it is the same with previously curl command, except for thr `-L` argument. You should get a similar response as the following: +Notice that it is the same with previously curl command. You should get a similar response as the following: ```console { diff --git a/sky/serve/README.md b/sky/serve/README.md index 1131849a8d3..838f4dd6d3b 100644 --- a/sky/serve/README.md +++ b/sky/serve/README.md @@ -2,7 +2,7 @@ Serving library for SkyPilot. 
-The goal of Sky Serve is simple - expose one endpoint, that redirects to serving endpoints running on different resources, regions and clouds. +The goal of Sky Serve is simple - exposing one endpoint, that distributes any incoming traffic to serving endpoints running on different resources, regions, and clouds. Sky Serve transparently handles load balancing, failover and autoscaling of the serving endpoints. @@ -11,8 +11,8 @@ Sky Serve transparently handles load balancing, failover and autoscaling of the ![Architecture](../../docs/source/images/sky-serve-architecture.png) Sky Serve has four key components: -1. Redirector - receiving requests and redirecting them to healthy endpoints. -2. Load balancers - spread requests across healthy endpoints according to different policies. +1. Load Balancers - receiving requests and distributing them to healthy endpoints. +2. Load Balancing Policies - spread requests across healthy endpoints according to different policies. 3. Autoscalers - scale up and down the number of serving endpoints according to different policies. 4. Replica Managers - monitoring replica status and handle recovery of unhealthy endpoints. diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 07f3e837ed4..89ca683ada5 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -21,6 +21,18 @@ # interval. LB_CONTROLLER_SYNC_INTERVAL_SECONDS = 20 +# The maximum retry times for load balancer for each request. After changing to +# proxy implementation, we do retry for failed requests. +# TODO(tian): Expose this option to users in yaml file. +LB_MAX_RETRY = 3 + +# The timeout in seconds for load balancer to wait for a response from replica. +# Large LLMs like Llama2-70b is able to process the request within ~30 seconds. +# We set the timeout to 120s to be safe. For reference, FastChat uses 100s: +# https://github.com/lm-sys/FastChat/blob/f2e6ca964af7ad0585cadcf16ab98e57297e2133/fastchat/constants.py#L39 # pylint: disable=line-too-long +# TODO(tian): Expose this option to users in yaml file. +LB_STREAM_TIMEOUT = 120 + # Interval in seconds to probe replica endpoint. ENDPOINT_PROBE_INTERVAL_SECONDS = 10 diff --git a/sky/serve/core.py b/sky/serve/core.py index 79aa53f7b58..f193a85285b 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -285,7 +285,7 @@ def up( f'{backend_utils.BOLD}watch -n10 sky serve status {service_name}' f'{backend_utils.RESET_BOLD}' '\nTo send a test request:\t\t' - f'{backend_utils.BOLD}curl -L {endpoint}' + f'{backend_utils.BOLD}curl {endpoint}' f'{backend_utils.RESET_BOLD}' '\n' f'\n{fore.GREEN}SkyServe is spinning up your service now.' diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index 7864e242148..24d0958489d 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -1,29 +1,30 @@ -"""LoadBalancer: redirect any incoming request to an endpoint replica.""" +"""LoadBalancer: Distribute any incoming request to all ready replicas.""" +import asyncio import logging import threading -import time +from typing import Dict, Union +import aiohttp import fastapi -import requests +import httpx +from starlette import background import uvicorn from sky import sky_logging from sky.serve import constants from sky.serve import load_balancing_policies as lb_policies from sky.serve import serve_utils +from sky.utils import common_utils logger = sky_logging.init_logger(__name__) class SkyServeLoadBalancer: - """SkyServeLoadBalancer: redirect incoming traffic. 
+ """SkyServeLoadBalancer: distribute incoming traffic with proxy. - This class accept any traffic to the controller and redirect it + This class accept any traffic to the controller and proxies it to the appropriate endpoint replica according to the load balancing policy. - - NOTE: HTTP redirect is used. Thus, when using `curl`, be sure to use - `curl -L`. """ def __init__(self, controller_url: str, load_balancer_port: int) -> None: @@ -34,14 +35,27 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None: load_balancer_port: The port where the load balancer listens to. """ self._app = fastapi.FastAPI() - self._controller_url = controller_url - self._load_balancer_port = load_balancer_port + self._controller_url: str = controller_url + self._load_balancer_port: int = load_balancer_port self._load_balancing_policy: lb_policies.LoadBalancingPolicy = ( lb_policies.RoundRobinPolicy()) self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) - - def _sync_with_controller(self): + # TODO(tian): httpx.Client has a resource limit of 100 max connections + # for each client. We should wait for feedback on the best max + # connections. + # Reference: https://www.python-httpx.org/advanced/resource-limits/ + # + # If more than 100 requests are sent to the same replica, the + # httpx.Client will queue the requests and send them when a + # connection is available. + # Reference: https://github.com/encode/httpcore/blob/a8f80980daaca98d556baea1783c5568775daadc/httpcore/_async/connection_pool.py#L69-L71 # pylint: disable=line-too-long + self._client_pool: Dict[str, httpx.AsyncClient] = dict() + # We need this lock to avoid getting from the client pool while + # updating it from _sync_with_controller. + self._client_pool_lock: threading.Lock = threading.Lock() + + async def _sync_with_controller(self): """Sync with controller periodically. Every `constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS` seconds, the @@ -51,58 +65,157 @@ def _sync_with_controller(self): autoscaling decisions. """ # Sleep for a while to wait the controller bootstrap. - time.sleep(5) + await asyncio.sleep(5) while True: - with requests.Session() as session: + close_client_tasks = [] + async with aiohttp.ClientSession() as session: try: # Send request information - response = session.post( - self._controller_url + '/controller/load_balancer_sync', - json={ - 'request_aggregator': - self._request_aggregator.to_dict() - }, - timeout=5) - # Clean up after reporting request information to avoid OOM. - self._request_aggregator.clear() - response.raise_for_status() - ready_replica_urls = response.json().get( - 'ready_replica_urls') - except requests.RequestException as e: - print(f'An error occurred: {e}') + async with session.post( + self._controller_url + + '/controller/load_balancer_sync', + json={ + 'request_aggregator': + self._request_aggregator.to_dict() + }, + timeout=5, + ) as response: + # Clean up after reporting request info to avoid OOM. 
+ self._request_aggregator.clear() + response.raise_for_status() + response_json = await response.json() + ready_replica_urls = response_json.get( + 'ready_replica_urls', []) + except aiohttp.ClientError as e: + logger.error('An error occurred when syncing with ' + f'the controller: {e}') else: logger.info(f'Available Replica URLs: {ready_replica_urls}') - self._load_balancing_policy.set_ready_replicas( - ready_replica_urls) - time.sleep(constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS) - - async def _redirect_handler(self, request: fastapi.Request): + with self._client_pool_lock: + self._load_balancing_policy.set_ready_replicas( + ready_replica_urls) + for replica_url in ready_replica_urls: + if replica_url not in self._client_pool: + self._client_pool[replica_url] = ( + httpx.AsyncClient(base_url=replica_url)) + urls_to_close = set( + self._client_pool.keys()) - set(ready_replica_urls) + client_to_close = [] + for replica_url in urls_to_close: + client_to_close.append( + self._client_pool.pop(replica_url)) + for client in client_to_close: + close_client_tasks.append(client.aclose()) + + await asyncio.sleep(constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS) + # Await those tasks after the interval to avoid blocking. + await asyncio.gather(*close_client_tasks) + + async def _proxy_request_to( + self, url: str, request: fastapi.Request + ) -> Union[fastapi.responses.Response, Exception]: + """Proxy the request to the specified URL. + + Returns: + The response from the endpoint replica. Return the exception + encountered if anything goes wrong. + """ + logger.info(f'Proxy request to {url}') + try: + # We defer the get of the client here on purpose, for case when the + # replica is ready in `_proxy_with_retries` but refreshed before + # entering this function. In that case we will return an error here + # and retry to find next ready replica. We also need to wait for the + # update of the client pool to finish before getting the client. + with self._client_pool_lock: + client = self._client_pool.get(url, None) + if client is None: + return RuntimeError(f'Client for {url} not found.') + worker_url = httpx.URL(path=request.url.path, + query=request.url.query.encode('utf-8')) + proxy_request = client.build_request( + request.method, + worker_url, + headers=request.headers.raw, + content=await request.body(), + timeout=constants.LB_STREAM_TIMEOUT) + proxy_response = await client.send(proxy_request, stream=True) + return fastapi.responses.StreamingResponse( + content=proxy_response.aiter_raw(), + status_code=proxy_response.status_code, + headers=proxy_response.headers, + background=background.BackgroundTask(proxy_response.aclose)) + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.error(f'Error when proxy request to {url}: ' + f'{common_utils.format_exception(e)}') + return e + + async def _proxy_with_retries( + self, request: fastapi.Request) -> fastapi.responses.Response: + """Try to proxy the request to the endpoint replica with retries.""" self._request_aggregator.add(request) - ready_replica_url = self._load_balancing_policy.select_replica(request) - - if ready_replica_url is None: - raise fastapi.HTTPException(status_code=503, - detail='No ready replicas. ' - 'Use "sky serve status [SERVICE_NAME]" ' - 'to check the replica status.') - - path = f'{ready_replica_url}{request.url.path}' - logger.info(f'Redirecting request to {path}') - return fastapi.responses.RedirectResponse(url=path) + # TODO(tian): Finetune backoff parameters. 
+ backoff = common_utils.Backoff(initial_backoff=1) + # SkyServe supports serving on Spot Instances. To avoid preemptions + # during request handling, we add a retry here. + retry_cnt = 0 + while True: + retry_cnt += 1 + with self._client_pool_lock: + ready_replica_url = self._load_balancing_policy.select_replica( + request) + if ready_replica_url is None: + response_or_exception = fastapi.HTTPException( + # 503 means that the server is currently + # unable to handle the incoming requests. + status_code=503, + detail='No ready replicas. ' + 'Use "sky serve status [SERVICE_NAME]" ' + 'to check the replica status.') + else: + response_or_exception = await self._proxy_request_to( + ready_replica_url, request) + if not isinstance(response_or_exception, Exception): + return response_or_exception + # When the user aborts the request during streaming, the request + # will be disconnected. We do not need to retry for this case. + if await request.is_disconnected(): + # 499 means a client terminates the connection + # before the server is able to respond. + return fastapi.responses.Response(status_code=499) + # TODO(tian): Fail fast for errors like 404 not found. + if retry_cnt == constants.LB_MAX_RETRY: + if isinstance(response_or_exception, fastapi.HTTPException): + raise response_or_exception + exception = common_utils.remove_color( + common_utils.format_exception(response_or_exception, + use_bracket=True)) + raise fastapi.HTTPException( + # 500 means internal server error. + status_code=500, + detail=f'Max retries {constants.LB_MAX_RETRY} exceeded. ' + f'Last error encountered: {exception}. Please use ' + '"sky serve logs [SERVICE_NAME] --load-balancer" ' + 'for more information.') + current_backoff = backoff.current_backoff() + logger.error(f'Retry in {current_backoff} seconds.') + await asyncio.sleep(current_backoff) def run(self): self._app.add_api_route('/{path:path}', - self._redirect_handler, + self._proxy_with_retries, methods=['GET', 'POST', 'PUT', 'DELETE']) @self._app.on_event('startup') - def configure_logger(): + async def startup(): + # Configure logger uvicorn_access_logger = logging.getLogger('uvicorn.access') for handler in uvicorn_access_logger.handlers: handler.setFormatter(sky_logging.FORMATTER) - threading.Thread(target=self._sync_with_controller, daemon=True).start() + # Register controller synchronization task + asyncio.create_task(self._sync_with_controller()) logger.info('SkyServe Load Balancer started on ' f'http://0.0.0.0:{self._load_balancer_port}') diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index c8c9aa07765..34c1fa4249b 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -11,6 +11,14 @@ logger = sky_logging.init_logger(__name__) +def _request_repr(request: 'fastapi.Request') -> str: + return ('') + + class LoadBalancingPolicy: """Abstract class for load balancing policies.""" @@ -20,39 +28,43 @@ def __init__(self) -> None: def set_ready_replicas(self, ready_replicas: List[str]) -> None: raise NotImplementedError + def select_replica(self, request: 'fastapi.Request') -> Optional[str]: + replica = self._select_replica(request) + if replica is not None: + logger.info(f'Selected replica {replica} ' + f'for request {_request_repr(request)}') + else: + logger.warning('No replica selected for request ' + f'{_request_repr(request)}') + return replica + # TODO(tian): We should have an abstract class for Request to # compatible with all frameworks. 
- def select_replica(self, request: 'fastapi.Request') -> Optional[str]: + def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: raise NotImplementedError class RoundRobinPolicy(LoadBalancingPolicy): """Round-robin load balancing policy.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self) -> None: + super().__init__() self.index = 0 def set_ready_replicas(self, ready_replicas: List[str]) -> None: - if set(ready_replicas) != set(self.ready_replicas): - # If the autoscaler keeps scaling up and down the replicas, - # we need this shuffle to not let the first replica have the - # most of the load. - random.shuffle(ready_replicas) - self.ready_replicas = ready_replicas - self.index = 0 + if set(self.ready_replicas) == set(ready_replicas): + return + # If the autoscaler keeps scaling up and down the replicas, + # we need this shuffle to not let the first replica have the + # most of the load. + random.shuffle(ready_replicas) + self.ready_replicas = ready_replicas + self.index = 0 - def select_replica(self, request: 'fastapi.Request') -> Optional[str]: + def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: + del request # Unused. if not self.ready_replicas: return None ready_replica_url = self.ready_replicas[self.index] self.index = (self.index + 1) % len(self.ready_replicas) - request_repr = ('') - logger.info(f'Selected replica {ready_replica_url} ' - f'for request {request_repr}') return ready_replica_url diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index 351a89ae7f6..8f79b653a2b 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -11,8 +11,10 @@ setup: | {%- endfor %} # Install serve dependencies. + # TODO(tian): Gather those into serve constants. pip list | grep uvicorn > /dev/null 2>&1 || pip install uvicorn > /dev/null 2>&1 pip list | grep fastapi > /dev/null 2>&1 || pip install fastapi > /dev/null 2>&1 + pip list | grep httpx > /dev/null 2>&1 || pip install httpx > /dev/null 2>&1 file_mounts: {{remote_task_yaml_path}}: {{local_task_yaml_path}} diff --git a/tests/skyserve/auto_restart.yaml b/tests/skyserve/auto_restart.yaml index f7dc2a13f07..2a3a31051b9 100644 --- a/tests/skyserve/auto_restart.yaml +++ b/tests/skyserve/auto_restart.yaml @@ -7,6 +7,7 @@ service: resources: ports: 8080 + cloud: gcp cpus: 2+ workdir: examples/serve/http_server diff --git a/tests/skyserve/llm/get_response.py b/tests/skyserve/llm/get_response.py index f0fa530effc..9dd6ea53804 100644 --- a/tests/skyserve/llm/get_response.py +++ b/tests/skyserve/llm/get_response.py @@ -27,4 +27,6 @@ 'temperature': 0, }) + if resp.status_code != 200: + raise RuntimeError(f'Failed to get response: {resp.text}') print(resp.json()['choices'][0]['message']['content']) diff --git a/tests/skyserve/streaming/example.txt b/tests/skyserve/streaming/example.txt new file mode 100644 index 00000000000..0e9cd7421d3 --- /dev/null +++ b/tests/skyserve/streaming/example.txt @@ -0,0 +1 @@ +Hello! How can I help you today? 
\ No newline at end of file diff --git a/tests/skyserve/streaming/send_streaming_request.py b/tests/skyserve/streaming/send_streaming_request.py new file mode 100644 index 00000000000..7c56d929761 --- /dev/null +++ b/tests/skyserve/streaming/send_streaming_request.py @@ -0,0 +1,24 @@ +import argparse + +import requests + +with open('tests/skyserve/streaming/example.txt', 'r') as f: + WORD_TO_STREAM = f.read() + +parser = argparse.ArgumentParser() +parser.add_argument('--endpoint', type=str, required=True) +args = parser.parse_args() +url = f'http://{args.endpoint}/' + +expected = WORD_TO_STREAM.split() +index = 0 +with requests.get(url, stream=True) as response: + response.raise_for_status() + for chunk in response.iter_content(chunk_size=8192): + if chunk: + current = chunk.decode().strip() + assert current == expected[index], (current, expected[index]) + index += 1 +assert index == len(expected) + +print('Streaming test passed') diff --git a/tests/skyserve/streaming/server.py b/tests/skyserve/streaming/server.py new file mode 100644 index 00000000000..d9528af2205 --- /dev/null +++ b/tests/skyserve/streaming/server.py @@ -0,0 +1,24 @@ +import asyncio + +import fastapi +import uvicorn + +with open('example.txt', 'r') as f: + WORD_TO_STREAM = f.read() + +app = fastapi.FastAPI() + + +@app.get('/') +async def stream(): + + async def generate_words(): + for word in WORD_TO_STREAM.split(): + yield word + "\n" + await asyncio.sleep(0.2) + + return fastapi.responses.StreamingResponse(generate_words(), + media_type="text/plain") + + +uvicorn.run(app, host='0.0.0.0', port=8080) diff --git a/tests/skyserve/streaming/streaming.yaml b/tests/skyserve/streaming/streaming.yaml new file mode 100644 index 00000000000..a352d120dde --- /dev/null +++ b/tests/skyserve/streaming/streaming.yaml @@ -0,0 +1,13 @@ +service: + readiness_probe: / + replicas: 1 + +resources: + cpus: 2+ + ports: 8080 + +workdir: tests/skyserve/streaming + +setup: pip install fastapi uvicorn + +run: python server.py diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 2d542001eb7..284ac5aa471 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3183,7 +3183,7 @@ def _get_skyserve_http_test(name: str, cloud: str, f'sky serve up -n {name} -y tests/skyserve/http/{cloud}.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl -L http://$endpoint | grep "Hi, SkyPilot here"', + 'curl http://$endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=timeout_minutes * 60, @@ -3305,11 +3305,11 @@ def test_skyserve_spot_recovery(): f'sky serve up -n {name} -y tests/skyserve/spot/recovery.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl -L http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', _terminate_gcp_replica(name, zone, 1), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl -L http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -3404,7 +3404,7 @@ def test_skyserve_user_bug_restart(generic_cloud: 
str): f'echo "$s" | grep -B 100 "NO_REPLICA" | grep "0/0"', f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/auto_restart.yaml', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl -L http://$endpoint | grep "Hi, SkyPilot here!"; do sleep 2; done; sleep 2; ' + 'until curl http://$endpoint | grep "Hi, SkyPilot here!"; do sleep 2; done; sleep 2; ' + _check_replica_in_status(name, [(1, False, 'READY'), (1, False, 'FAILED')]), ], @@ -3452,7 +3452,7 @@ def test_skyserve_auto_restart(): f'sky serve up -n {name} -y tests/skyserve/auto_restart.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl -L http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', # sleep for 20 seconds (initial delay) to make sure it will # be restarted f'sleep 20', @@ -3472,7 +3472,7 @@ def test_skyserve_auto_restart(): ' sleep 10;' f'done); sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS};', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl -L http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -3505,6 +3505,25 @@ def test_skyserve_cancel(generic_cloud: str): run_one_test(test) +@pytest.mark.serve +def test_skyserve_streaming(generic_cloud: str): + """Test skyserve with streaming""" + name = _get_service_name() + test = Test( + f'test-skyserve-streaming', + [ + f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/streaming/streaming.yaml', + _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' + 'python3 tests/skyserve/streaming/send_streaming_request.py ' + '--endpoint $endpoint | grep "Streaming test passed"', + ], + _TEARDOWN_SERVICE.format(name=name), + timeout=20 * 60, + ) + run_one_test(test) + + @pytest.mark.serve def test_skyserve_update(generic_cloud: str): """Test skyserve with update""" @@ -3514,14 +3533,14 @@ def test_skyserve_update(generic_cloud: str): [ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/new.yaml', # sleep before update is registered. 
'sleep 20', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl -L http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;' + 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;' # Make sure the traffic is not mixed - 'curl -L http://$endpoint | grep "Hi, new SkyPilot here"', + 'curl http://$endpoint | grep "Hi, new SkyPilot here"', # The latest 2 version should be READY and the older versions should be shutting down (_check_replica_in_status(name, [(2, False, 'READY'), (2, False, 'SHUTTING_DOWN')]) + @@ -3545,14 +3564,14 @@ def test_skyserve_rolling_update(generic_cloud: str): [ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/new.yaml', # Make sure the traffic is mixed across two versions, the replicas # with even id will sleep 60 seconds before being ready, so we # should be able to get observe the period that the traffic is mixed # across two versions. f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl -L http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; ' + 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; ' # The latest version should have one READY and the one of the older versions should be shutting down f'{single_new_replica} {_check_service_version(name, "1,2")} ' # Check the output from the old version, immediately after the @@ -3561,7 +3580,7 @@ def test_skyserve_rolling_update(generic_cloud: str): # TODO(zhwu): we should have a more generalized way for checking the # mixed version of replicas to avoid depending on the specific # round robin load balancing policy. - 'curl -L http://$endpoint | grep "Hi, SkyPilot here"', + 'curl http://$endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -3579,7 +3598,7 @@ def test_skyserve_fast_update(generic_cloud: str): [ f'sky serve up -n {name} -y --cloud {generic_cloud} tests/skyserve/update/bump_version_before.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/bump_version_after.yaml', # sleep to wait for update to be registered. 'sleep 30', @@ -3592,7 +3611,7 @@ def test_skyserve_fast_update(generic_cloud: str): _check_service_version(name, "2")), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3) + _check_service_version(name, "2"), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', # Test rolling update f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/bump_version_before.yaml', # sleep to wait for update to be registered. 
@@ -3602,7 +3621,7 @@ def test_skyserve_fast_update(generic_cloud: str): (1, False, 'SHUTTING_DOWN')]), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "3"), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=30 * 60, @@ -3621,7 +3640,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "1"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl -L http://$endpoint | grep "Hi, SkyPilot here"', + 'curl http://$endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/num_min_one.yaml', # sleep before update is registered. 'sleep 20', @@ -3629,7 +3648,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1) + _check_service_version(name, "2"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl -L http://$endpoint | grep "Hi, SkyPilot here!"', + 'curl http://$endpoint | grep "Hi, SkyPilot here!"', # Rolling Update f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml', # sleep before update is registered. @@ -3638,7 +3657,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "3"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl -L http://$endpoint | grep "Hi, SkyPilot here!"', + 'curl http://$endpoint | grep "Hi, SkyPilot here!"', ], _TEARDOWN_SERVICE.format(name=name), timeout=30 * 60, @@ -3680,7 +3699,7 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "1"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 's=$(curl -L http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"', + 's=$(curl http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode {mode} -y tests/skyserve/update/new_autoscaler_after.yaml', # Wait for update to be registered f'sleep 120', @@ -3691,7 +3710,7 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): *update_check, _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=5), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl -L http://$endpoint | grep "Hi, SkyPilot here"', + 'curl http://$endpoint | grep "Hi, SkyPilot here"', _check_replica_in_status(name, [(4, True, 'READY'), (1, False, 'READY')]), ], From f24425c62299170d377268e91fee8a8fb5d43b12 Mon Sep 17 00:00:00 2001 From: David Tran Date: Thu, 16 May 2024 02:29:31 -0400 Subject: [PATCH 02/20] [UX] Support checking user specified cloud(s) with `sky check` from CLI (#3229) * add to cli * update docstr and move echo * rename import and vars to clouds * cache previously enabled clouds and account for disabled clouds * fix * filter output to user specified clouds * separate if for readability * remove unnecessary comment * show all enabled clouds, not just one being currently checked * wip * typing checks for cloudflare support * comments --------- Co-authored-by: David Tran Co-authored-by: Romil Bhardwaj --- sky/check.py | 79 ++++++++++++++++++++++++++++++++++++---------------- sky/cli.py | 16 +++++++---- 2 files 
changed, 65 insertions(+), 30 deletions(-) diff --git a/sky/check.py b/sky/check.py index 6818d80f3bf..d90fdffefb7 100644 --- a/sky/check.py +++ b/sky/check.py @@ -1,26 +1,32 @@ """Credential checks: check cloud credentials and enable clouds.""" import traceback -from typing import Dict, Iterable, List, Optional, Tuple +from types import ModuleType +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import click import colorama import rich -from sky import clouds +from sky import clouds as sky_clouds from sky import exceptions from sky import global_user_state from sky.adaptors import cloudflare from sky.utils import ux_utils -# TODO(zhwu): add check for a single cloud to improve performance -def check(quiet: bool = False, verbose: bool = False) -> None: +def check( + quiet: bool = False, + verbose: bool = False, + clouds: Optional[Tuple[str]] = None, +) -> None: echo = (lambda *_args, **_kwargs: None) if quiet else click.echo echo('Checking credentials to enable clouds for SkyPilot.') - enabled_clouds = [] + disabled_clouds = [] - def check_one_cloud(cloud_tuple: Tuple[str, clouds.Cloud]) -> None: + def check_one_cloud( + cloud_tuple: Tuple[str, Union[sky_clouds.Cloud, + ModuleType]]) -> None: cloud_repr, cloud = cloud_tuple echo(f' Checking {cloud_repr}...', nl=False) try: @@ -43,25 +49,47 @@ def check_one_cloud(cloud_tuple: Tuple[str, clouds.Cloud]) -> None: if reason is not None: echo(f' Hint: {reason}') else: + disabled_clouds.append(cloud_repr) echo(f' Reason: {reason}') - clouds_to_check = [ - (repr(cloud), cloud) for cloud in clouds.CLOUD_REGISTRY.values() - ] - clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare)) + if clouds is not None: + clouds_to_check: List[Tuple[str, Any]] = [] + for cloud in clouds: + if cloud.lower() == 'cloudflare': + clouds_to_check.append( + ('Cloudflare, for R2 object store', cloudflare)) + else: + cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud) + assert cloud_obj is not None, f'Cloud {cloud!r} not found' + clouds_to_check.append((repr(cloud_obj), cloud_obj)) + else: + clouds_to_check = [(repr(cloud_obj), cloud_obj) + for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()] + clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare)) for cloud_tuple in sorted(clouds_to_check): check_one_cloud(cloud_tuple) - # Cloudflare is not a real cloud in clouds.CLOUD_REGISTRY, and should not be - # inserted into the DB (otherwise `sky launch` and other code would error - # out when it's trying to look it up in the registry). - enabled_clouds = [ + # Cloudflare is not a real cloud in sky_clouds.CLOUD_REGISTRY, and should + # not be inserted into the DB (otherwise `sky launch` and other code would + # error out when it's trying to look it up in the registry). + enabled_clouds_set = { cloud for cloud in enabled_clouds if not cloud.startswith('Cloudflare') - ] - global_user_state.set_enabled_clouds(enabled_clouds) - - if len(enabled_clouds) == 0: + } + disabled_clouds_set = { + cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare') + } + previously_enabled_clouds_set = { + repr(cloud) for cloud in global_user_state.get_cached_enabled_clouds() + } + + # Determine the set of enabled clouds: previously enabled clouds + newly + # enabled clouds - newly disabled clouds. 
+ all_enabled_clouds = ((previously_enabled_clouds_set | enabled_clouds_set) - + disabled_clouds_set) + global_user_state.set_enabled_clouds(list(all_enabled_clouds)) + + if len(all_enabled_clouds) == 0: echo( click.style( 'No cloud is enabled. SkyPilot will not be able to run any ' @@ -70,11 +98,13 @@ def check_one_cloud(cloud_tuple: Tuple[str, clouds.Cloud]) -> None: bold=True)) raise SystemExit() else: + clouds_arg = (' ' + + ' '.join(disabled_clouds) if clouds is not None else '') echo( click.style( '\nTo enable a cloud, follow the hints above and rerun: ', - dim=True) + click.style('sky check', bold=True) + '\n' + - click.style( + dim=True) + click.style(f'sky check{clouds_arg}', bold=True) + + '\n' + click.style( 'If any problems remain, refer to detailed docs at: ' 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long dim=True)) @@ -82,13 +112,13 @@ def check_one_cloud(cloud_tuple: Tuple[str, clouds.Cloud]) -> None: # Pretty print for UX. if not quiet: enabled_clouds_str = '\n :heavy_check_mark: '.join( - [''] + sorted(enabled_clouds)) + [''] + sorted(all_enabled_clouds)) rich.print('\n[green]:tada: Enabled clouds :tada:' f'{enabled_clouds_str}[/green]') def get_cached_enabled_clouds_or_refresh( - raise_if_no_cloud_access: bool = False) -> List[clouds.Cloud]: + raise_if_no_cloud_access: bool = False) -> List[sky_clouds.Cloud]: """Returns cached enabled clouds and if no cloud is enabled, refresh. This function will perform a refresh if no public cloud is enabled. @@ -120,7 +150,8 @@ def get_cached_enabled_clouds_or_refresh( def get_cloud_credential_file_mounts( - excluded_clouds: Optional[Iterable[clouds.Cloud]]) -> Dict[str, str]: + excluded_clouds: Optional[Iterable[sky_clouds.Cloud]] +) -> Dict[str, str]: """Returns the files necessary to access all enabled clouds. Returns a dictionary that will be added to a task's file mounts @@ -130,7 +161,7 @@ def get_cloud_credential_file_mounts( file_mounts = {} for cloud in enabled_clouds: if (excluded_clouds is not None and - clouds.cloud_in_iterable(cloud, excluded_clouds)): + sky_clouds.cloud_in_iterable(cloud, excluded_clouds)): continue cloud_file_mounts = cloud.get_credential_file_mounts() file_mounts.update(cloud_file_mounts) diff --git a/sky/cli.py b/sky/cli.py index 5b180d25dc8..5abd12f4caa 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -47,7 +47,7 @@ import sky from sky import backends from sky import check as sky_check -from sky import clouds +from sky import clouds as sky_clouds from sky import core from sky import exceptions from sky import global_user_state @@ -479,7 +479,7 @@ def _parse_override_params( if cloud.lower() == 'none': override_params['cloud'] = None else: - override_params['cloud'] = clouds.CLOUD_REGISTRY.from_str(cloud) + override_params['cloud'] = sky_clouds.CLOUD_REGISTRY.from_str(cloud) if region is not None: if region.lower() == 'none': override_params['region'] = None @@ -2863,23 +2863,27 @@ def _down_or_stop(name: str): @cli.command() +@click.argument('clouds', required=False, type=str, nargs=-1) @click.option('--verbose', '-v', is_flag=True, default=False, help='Show the activated account for each cloud.') @usage_lib.entrypoint -def check(verbose: bool): +def check(clouds: Tuple[str], verbose: bool): """Check which clouds are available to use. This checks access credentials for all clouds supported by SkyPilot. If a cloud is detected to be inaccessible, the reason and correction steps will be shown. 
+ If CLOUDS are specified, checks credentials for only those clouds. + The enabled clouds are cached and form the "search space" to be considered for each task. """ - sky_check.check(verbose=verbose) + clouds_arg = clouds if len(clouds) > 0 else None + sky_check.check(verbose=verbose, clouds=clouds_arg) @cli.command() @@ -2958,7 +2962,7 @@ def show_gpus( '--all-regions and --region flags cannot be used simultaneously.') # This will validate 'cloud' and raise if not found. - cloud_obj = clouds.CLOUD_REGISTRY.from_str(cloud) + cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud) service_catalog.validate_region_zone(region, None, clouds=cloud) show_all = all if show_all and accelerator_str is not None: @@ -2978,7 +2982,7 @@ def _output(): name, quantity = None, None # Kubernetes specific bools - cloud_is_kubernetes = isinstance(cloud_obj, clouds.Kubernetes) + cloud_is_kubernetes = isinstance(cloud_obj, sky_clouds.Kubernetes) kubernetes_autoscaling = kubernetes_utils.get_autoscaler_type( ) is not None From 1285e0314568ee11a53ae662b3ce7e35eb62f4f7 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Fri, 17 May 2024 02:18:17 +0900 Subject: [PATCH 03/20] [Serve] Return service name and endpoint from `sky.serve.up` (#3546) * Return service name and endpoint when calling sky serve up * Fix a dumb editor error * re-adding tuple return type * Adding docstr * Change format of return docstr * Fix pylint issue --- sky/serve/core.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sky/serve/core.py b/sky/serve/core.py index f193a85285b..09b6c9b5151 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -1,7 +1,7 @@ """SkyServe core APIs.""" import re import tempfile -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import colorama @@ -94,7 +94,7 @@ def _validate_service_task(task: 'sky.Task') -> None: def up( task: 'sky.Task', service_name: Optional[str] = None, -) -> None: +) -> Tuple[str, str]: """Spin up a service. Please refer to the sky.cli.serve_up for the document. @@ -102,6 +102,11 @@ def up( Args: task: sky.Task to serve up. service_name: Name of the service. + + Returns: + service_name: str; The name of the service. Same if passed in as an + argument. + endpoint: str; The service endpoint. """ if service_name is None: service_name = serve_utils.generate_service_name() @@ -292,6 +297,7 @@ def up( f'{style.RESET_ALL}' f'\n{fore.GREEN}The replicas should be ready within a ' f'short time.{style.RESET_ALL}') + return service_name, endpoint @usage_lib.entrypoint From 4a66806c4f3d907254f79b659499b66bf6868cca Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Thu, 16 May 2024 10:26:33 -0700 Subject: [PATCH 04/20] [docs] Docs for sky check specific cloud (#3558) docs --- docs/source/getting-started/installation.rst | 4 ++++ sky/cli.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index 490d2d76311..42336bcd5cc 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -164,6 +164,10 @@ section :ref:`below `. If your clouds show ``enabled`` --- |:tada:| |:tada:| **Congratulations!** |:tada:| |:tada:| You can now head over to :ref:`Quickstart ` to get started with SkyPilot. +.. tip:: + + To check credentials only for specific clouds, pass the clouds as arguments: :code:`sky check aws gcp` + .. 
_cloud-account-setup:

Cloud account setup

diff --git a/sky/cli.py b/sky/cli.py
index 5abd12f4caa..365468f0bba 100644
--- a/sky/cli.py
+++ b/sky/cli.py
@@ -2862,7 +2862,7 @@ def _down_or_stop(name: str):
         progress.refresh()


-@cli.command()
+@cli.command(cls=_DocumentedCodeCommand)
 @click.argument('clouds', required=False, type=str, nargs=-1)
 @click.option('--verbose',
               '-v',
@@ -2881,6 +2881,16 @@ def check(clouds: Tuple[str], verbose: bool):

     The enabled clouds are cached and form the "search space" to be considered
     for each task.
+
+    Examples:
+
+    .. code-block:: bash
+
+        # Check credentials for all supported clouds.
+        sky check
+        \b
+        # Check only specific clouds - AWS and GCP.
+        sky check aws gcp
     """
     clouds_arg = clouds if len(clouds) > 0 else None
     sky_check.check(verbose=verbose, clouds=clouds_arg)

From eae8fc5740c0563b458fcb2cec6df5a7f0d0f9d1 Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Thu, 16 May 2024 10:29:09 -0700
Subject: [PATCH 05/20] [UX] Error out for null env var (#3557)

* [UX] Error out for null env var

* format

* Fix examples for env, including HF_TOKEN and WANDB_API_KEY

* fix

* Add test

* format

* fix

* type

* fix

* remove print

* add doc

* fix comment

* minor fix
---
 .../running-jobs/environment-variables.rst     | 14 ++++++++++++++
 examples/serve/llama2/llama2.yaml              |  2 +-
 examples/spot_pipeline/bert_qa_train_eval.yaml |  4 ++--
 llm/axolotl/axolotl-spot.yaml                  |  4 ++--
 llm/axolotl/axolotl.yaml                       |  2 +-
 llm/dbrx/README.md                             |  2 +-
 llm/dbrx/dbrx.yaml                             |  2 +-
 llm/falcon/falcon.yaml                         |  4 ++--
 llm/gemma/serve.yaml                           |  2 +-
 llm/llama-2/README.md                          |  2 +-
 llm/llama-2/chatbot-hf.yaml                    |  2 +-
 llm/llama-2/chatbot-meta.yaml                  |  2 +-
 llm/llama-3/README.md                          |  2 +-
 llm/llama-3/llama3.yaml                        |  2 +-
 llm/sglang/llama2.yaml                         |  2 +-
 llm/vicuna-llama-2/README.md                   |  2 +-
 llm/vicuna-llama-2/train.yaml                  |  6 +++---
 llm/vicuna/train.yaml                          | 14 +++++++-------
 llm/vllm/serve-openai-api.yaml                 |  2 +-
 llm/vllm/service.yaml                          |  2 +-
 sky/task.py                                    | 18 ++++++++++++++++--
 sky/utils/schemas.py                           |  2 +-
 tests/test_yaml_parser.py                      | 12 ++++++++++++
 23 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/docs/source/running-jobs/environment-variables.rst b/docs/source/running-jobs/environment-variables.rst
index 16502f70818..2f3427c1bf5 100644
--- a/docs/source/running-jobs/environment-variables.rst
+++ b/docs/source/running-jobs/environment-variables.rst
@@ -12,6 +12,20 @@ You can specify environment variables to be made available to a task in two ways

 - The ``envs`` field (dict) in a :ref:`task YAML `
 - The ``--env`` flag in the ``sky launch/exec`` :ref:`CLI ` (takes precedence over the above)

+.. tip::
+
+  If an environment variable must be provided via ``--env`` at
+  ``sky launch/exec`` time, set it to ``null`` in the task YAML so that an
+  error is raised whenever it is left unspecified. For example,
+  ``WANDB_API_KEY`` and ``HF_TOKEN`` in the following task YAML:
+
+  .. code-block:: yaml
+
+    envs:
+      WANDB_API_KEY:
+      HF_TOKEN: null
+      MYVAR: val
+
 The ``file_mounts``, ``setup``, and ``run`` sections of a task YAML can access the variables via the ``${MYVAR}`` syntax.

 Using in ``file_mounts``
diff --git a/examples/serve/llama2/llama2.yaml b/examples/serve/llama2/llama2.yaml
index 5eaaea449d0..42c82ea0cc9 100644
--- a/examples/serve/llama2/llama2.yaml
+++ b/examples/serve/llama2/llama2.yaml
@@ -25,7 +25,7 @@ resources:

 envs:
   MODEL_SIZE: 7
-  HF_TOKEN: # TODO: Replace with huggingface token
+  HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
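+  # Leaving HF_TOKEN empty makes it null, so with the sky/task.py change
+  # below, launching without --env HF_TOKEN=... now fails fast instead of
+  # proceeding with a missing token.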
setup: | conda activate chatbot diff --git a/examples/spot_pipeline/bert_qa_train_eval.yaml b/examples/spot_pipeline/bert_qa_train_eval.yaml index 32fb526ca91..62bd34c3b76 100644 --- a/examples/spot_pipeline/bert_qa_train_eval.yaml +++ b/examples/spot_pipeline/bert_qa_train_eval.yaml @@ -42,7 +42,7 @@ run: | echo Model saved to /checkpoint/bert_qa/$SKYPILOT_TASK_ID envs: - WANDB_API_KEY: # NOTE: Fill in your wandb key + WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass. --- @@ -84,4 +84,4 @@ run: | --save_steps 1000 envs: - WANDB_API_KEY: # NOTE: Fill in your wandb key + WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass. diff --git a/llm/axolotl/axolotl-spot.yaml b/llm/axolotl/axolotl-spot.yaml index b6c81b742c9..942f4ccc4ba 100644 --- a/llm/axolotl/axolotl-spot.yaml +++ b/llm/axolotl/axolotl-spot.yaml @@ -38,8 +38,8 @@ run: | accelerate launch -m axolotl.cli.train /sky_workdir/qlora-checkpoint.yaml envs: - HF_TOKEN: # TODO: Replace with huggingface token - BUCKET: + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. + BUCKET: # TODO: Fill with your unique bucket name, or use --env to pass. diff --git a/llm/axolotl/axolotl.yaml b/llm/axolotl/axolotl.yaml index d9cfd91aa6d..9cec1d1f331 100644 --- a/llm/axolotl/axolotl.yaml +++ b/llm/axolotl/axolotl.yaml @@ -26,7 +26,7 @@ run: | accelerate launch -m axolotl.cli.train /sky_workdir/qlora.yaml envs: - HF_TOKEN: # TODO: Replace with huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. diff --git a/llm/dbrx/README.md b/llm/dbrx/README.md index e0ad216e92c..3011af9d4e6 100644 --- a/llm/dbrx/README.md +++ b/llm/dbrx/README.md @@ -22,7 +22,7 @@ In this recipe, you will serve `databricks/dbrx-instruct` on your own infra -- ```yaml envs: MODEL_NAME: databricks/dbrx-instruct - HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. service: replicas: 2 diff --git a/llm/dbrx/dbrx.yaml b/llm/dbrx/dbrx.yaml index ffa777ab86d..0c9abd06d30 100644 --- a/llm/dbrx/dbrx.yaml +++ b/llm/dbrx/dbrx.yaml @@ -31,7 +31,7 @@ envs: MODEL_NAME: databricks/dbrx-instruct - HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. service: replicas: 2 diff --git a/llm/falcon/falcon.yaml b/llm/falcon/falcon.yaml index 256d936d61b..b752db5256b 100644 --- a/llm/falcon/falcon.yaml +++ b/llm/falcon/falcon.yaml @@ -7,7 +7,7 @@ workdir: . envs: MODEL_NAME: tiiuae/falcon-7b # [ybelkada/falcon-7b-sharded-bf16, tiiuae/falcon-7b, tiiuae/falcon-40b] - WANDB_API_KEY: $WANDB_KEY # Change to your own wandb key + WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass. OUTPUT_BUCKET_NAME: # Set a unique name for the bucket which will store model weights file_mounts: @@ -39,4 +39,4 @@ run: | --bnb_4bit_compute_dtype bfloat16 \ --max_steps 500 \ --dataset_name timdettmers/openassistant-guanaco \ - --output_dir /results \ No newline at end of file + --output_dir /results diff --git a/llm/gemma/serve.yaml b/llm/gemma/serve.yaml index 73f5b9c2b5d..4c5a2c984c5 100644 --- a/llm/gemma/serve.yaml +++ b/llm/gemma/serve.yaml @@ -17,7 +17,7 @@ service: envs: MODEL_NAME: google/gemma-7b-it - HF_TOKEN: # TODO: Replace with huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. 
resources: accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} diff --git a/llm/llama-2/README.md b/llm/llama-2/README.md index 7b20ea4aed7..d8f8151572e 100644 --- a/llm/llama-2/README.md +++ b/llm/llama-2/README.md @@ -33,7 +33,7 @@ Fill the access token in the [chatbot-hf.yaml](https://github.com/skypilot-org/s ```yaml envs: MODEL_SIZE: 7 - HF_TOKEN: + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. ``` diff --git a/llm/llama-2/chatbot-hf.yaml b/llm/llama-2/chatbot-hf.yaml index 992c01346e6..ee9d0281296 100644 --- a/llm/llama-2/chatbot-hf.yaml +++ b/llm/llama-2/chatbot-hf.yaml @@ -6,7 +6,7 @@ resources: envs: MODEL_SIZE: 7 - HF_TOKEN: # TODO: Replace with huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. setup: | conda activate chatbot diff --git a/llm/llama-2/chatbot-meta.yaml b/llm/llama-2/chatbot-meta.yaml index a0481fe760f..733a2a867d2 100644 --- a/llm/llama-2/chatbot-meta.yaml +++ b/llm/llama-2/chatbot-meta.yaml @@ -6,7 +6,7 @@ resources: envs: MODEL_SIZE: 7 - HF_TOKEN: # TODO: Replace with huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. setup: | set -ex diff --git a/llm/llama-3/README.md b/llm/llama-3/README.md index 7b3b6cb56e5..decff6054bf 100644 --- a/llm/llama-3/README.md +++ b/llm/llama-3/README.md @@ -44,7 +44,7 @@ envs: MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct # MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct - HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. service: replicas: 2 diff --git a/llm/llama-3/llama3.yaml b/llm/llama-3/llama3.yaml index 0974d4db51b..1e9b236efd4 100644 --- a/llm/llama-3/llama3.yaml +++ b/llm/llama-3/llama3.yaml @@ -59,7 +59,7 @@ envs: MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct # MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct - HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. service: replicas: 2 diff --git a/llm/sglang/llama2.yaml b/llm/sglang/llama2.yaml index 08427ab2001..8b58c4365d6 100644 --- a/llm/sglang/llama2.yaml +++ b/llm/sglang/llama2.yaml @@ -6,7 +6,7 @@ service: envs: MODEL_NAME: meta-llama/Llama-2-7b-chat-hf - HF_TOKEN: # Change to your own huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. resources: accelerators: {L4:1, A10G:1, A10:1, A100:1, A100-80GB:1} diff --git a/llm/vicuna-llama-2/README.md b/llm/vicuna-llama-2/README.md index 0fc5da6c4ba..899792c299d 100644 --- a/llm/vicuna-llama-2/README.md +++ b/llm/vicuna-llama-2/README.md @@ -31,7 +31,7 @@ cd skypilot/llm/vicuna-llama-2 Paste the access token into [train.yaml](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna-llama-2/train.yaml): ```yaml envs: - HF_TOKEN: # Change to your own huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. 
``` ## Train your own Vicuna on Llama-2 diff --git a/llm/vicuna-llama-2/train.yaml b/llm/vicuna-llama-2/train.yaml index e23d5797e76..8d35c2dff85 100644 --- a/llm/vicuna-llama-2/train.yaml +++ b/llm/vicuna-llama-2/train.yaml @@ -1,7 +1,7 @@ envs: - HF_TOKEN: # Change to your own huggingface token - ARTIFACT_BUCKET_NAME: YOUR_OWN_BUCKET_NAME # Change to your own bucket name - WANDB_API_KEY: "" # Change to your own wandb api key + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. + ARTIFACT_BUCKET_NAME: # TODO: Fill with your unique bucket name, or use --env to pass. + WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass. MODEL_SIZE: 7 USE_XFORMERS: 1 diff --git a/llm/vicuna/train.yaml b/llm/vicuna/train.yaml index c577561e858..a2121aaf8fd 100644 --- a/llm/vicuna/train.yaml +++ b/llm/vicuna/train.yaml @@ -1,3 +1,10 @@ +envs: + MODEL_SIZE: 7 + SEQ_LEN: 2048 + GC_SCALE: 4 + USE_FLASH_ATTN: 0 + WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass. + resources: accelerators: A100-80GB:8 disk_size: 1000 @@ -109,10 +116,3 @@ run: | gsutil -m rsync -r -x 'checkpoint-*' $LOCAL_CKPT_PATH/ $CKPT_PATH/ exit $returncode - -envs: - MODEL_SIZE: 7 - SEQ_LEN: 2048 - GC_SCALE: 4 - USE_FLASH_ATTN: 0 - WANDB_API_KEY: "" diff --git a/llm/vllm/serve-openai-api.yaml b/llm/vllm/serve-openai-api.yaml index 9ddf7b280ba..a68f476edc7 100644 --- a/llm/vllm/serve-openai-api.yaml +++ b/llm/vllm/serve-openai-api.yaml @@ -1,6 +1,6 @@ envs: MODEL_NAME: meta-llama/Llama-2-7b-chat-hf - HF_TOKEN: # Change to your own huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. resources: accelerators: {L4:1, A10G:1, A10:1, A100:1, A100-80GB:1} diff --git a/llm/vllm/service.yaml b/llm/vllm/service.yaml index 335f8a50650..1e5d92a60e5 100644 --- a/llm/vllm/service.yaml +++ b/llm/vllm/service.yaml @@ -9,7 +9,7 @@ service: # Fields below are the same with `serve-openai-api.yaml`. envs: MODEL_NAME: meta-llama/Llama-2-7b-chat-hf - HF_TOKEN: # Change to your own huggingface token + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. resources: accelerators: {L4:1, A10G:1, A10:1, A100:1, A100-80GB:1} diff --git a/sky/task.py b/sky/task.py index b6a71581a15..3dd254838f0 100644 --- a/sky/task.py +++ b/sky/task.py @@ -353,8 +353,13 @@ def from_yaml_config( # as int causing validate_schema() to fail. envs = config.get('envs') if envs is not None and isinstance(envs, dict): - config['envs'] = {str(k): str(v) for k, v in envs.items()} - + new_envs: Dict[str, Optional[str]] = {} + for k, v in envs.items(): + if v is not None: + new_envs[str(k)] = str(v) + else: + new_envs[str(k)] = None + config['envs'] = new_envs common_utils.validate_schema(config, schemas.get_task_schema(), 'Invalid task YAML: ') if env_overrides is not None: @@ -368,6 +373,15 @@ def from_yaml_config( new_envs.update(env_overrides) config['envs'] = new_envs + for k, v in config.get('envs', {}).items(): + if v is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Environment variable {k!r} is None. Please set a ' + 'value for it in task YAML or with --env flag. ' + f'To set it to be empty, use an empty string ({k}: "" ' + f'in task YAML or --env {k}="" in CLI).') + # Fill in any Task.envs into file_mounts (src/dst paths, storage # name/source). 
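+    # (Illustration: with envs {'MODEL_NAME': 'llama'}, a file_mounts source
+    # such as s3://my-bucket/${MODEL_NAME} is resolved here to
+    # s3://my-bucket/llama; the bucket and model names are hypothetical.)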
if config.get('file_mounts') is not None: diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index c50e15185a3..878fe67178e 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -402,7 +402,7 @@ def get_task_schema(): 'patternProperties': { # Checks env keys are valid env var names. '^[a-zA-Z_][a-zA-Z0-9_]*$': { - 'type': 'string' + 'type': ['string', 'null'] } }, 'additionalProperties': False, diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index 0338084925e..1453cfe1620 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -134,3 +134,15 @@ def test_invalid_envs_type(tmp_path): with pytest.raises(ValueError) as e: Task.from_yaml(config_path) assert 'is not of type \'dict\'' in e.value.args[0] + + +def test_invalid_empty_envs(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + envs: + env_key1: abc + env_key2: + """), tmp_path) + with pytest.raises(ValueError) as e: + Task.from_yaml(config_path) + assert 'Environment variable \'env_key2\' is None.' in e.value.args[0] From 4bf71d5007e4799f95af02eb4be9d3fccb038a07 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 16 May 2024 12:52:50 -0700 Subject: [PATCH 06/20] [CLI] Add alias for CLIs for convenience and consistency (#3011) * Add alias for CLIs for convenience and consistency * remove auto_exec for now * remove detach-setup * format * format * Fix * Add test for exec -c * address comments --- sky/cli.py | 52 ++++++++++++++++++++++++++++++++++++--------- tests/test_smoke.py | 9 ++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index 365468f0bba..9ad1bcfcb9c 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -687,7 +687,7 @@ def _pop_and_ignore_fields_in_override_params( def _make_task_or_dag_from_entrypoint_with_overrides( - entrypoint: List[str], + entrypoint: Tuple[str, ...], *, entrypoint_name: str = 'Task', name: Optional[str] = None, @@ -1028,7 +1028,7 @@ def cli(): 'the same data on the boot disk as an existing cluster.')) @usage_lib.entrypoint def launch( - entrypoint: List[str], + entrypoint: Tuple[str, ...], cluster: Optional[str], dryrun: bool, detach_setup: bool, @@ -1130,11 +1130,19 @@ def launch( @cli.command(cls=_DocumentedCodeCommand) @click.argument('cluster', - required=True, + required=False, type=str, **_get_shell_complete_args(_complete_cluster_name)) +@click.option( + '--cluster', + '-c', + 'cluster_option', + hidden=True, + type=str, + help='This is the same as the positional argument, just for consistency.', + **_get_shell_complete_args(_complete_cluster_name)) @click.argument('entrypoint', - required=True, + required=False, type=str, nargs=-1, **_get_shell_complete_args(_complete_file_name)) @@ -1149,8 +1157,9 @@ def launch( @usage_lib.entrypoint # pylint: disable=redefined-builtin def exec( - cluster: str, - entrypoint: List[str], + cluster: Optional[str], + cluster_option: Optional[str], + entrypoint: Tuple[str, ...], detach_run: bool, name: Optional[str], cloud: Optional[str], @@ -1228,6 +1237,17 @@ def exec( sky exec mycluster --env WANDB_API_KEY python train_gpu.py """ + if cluster_option is None and cluster is None: + raise click.UsageError('Missing argument \'[CLUSTER]\' and ' + '\'[ENTRYPOINT]...\'') + if cluster_option is not None: + if cluster is not None: + entrypoint = (cluster,) + entrypoint + cluster = cluster_option + if not entrypoint: + raise click.UsageError('Missing argument \'[ENTRYPOINT]...\'') + assert cluster is not None, (cluster, cluster_option, entrypoint) + env = 
_merge_env_vars(env_file, env)
    controller_utils.check_cluster_name_not_controller(
        cluster, operation_str='Executing task on it')
@@ -3284,6 +3304,12 @@ def jobs():
              **_get_shell_complete_args(_complete_file_name))
 # TODO(zhwu): Add --dryrun option to test the launch command.
 @_add_click_options(_TASK_OPTIONS_WITH_NAME + _EXTRA_RESOURCES_OPTIONS)
+@click.option('--cluster',
+              '-c',
+              default=None,
+              type=str,
+              hidden=True,
+              help=('Alias for --name, the name of the spot job.'))
 @click.option('--job-recovery',
               default=None,
               type=str,
@@ -3316,8 +3342,9 @@ def jobs():
 @timeline.event
 @usage_lib.entrypoint
 def jobs_launch(
-    entrypoint: List[str],
+    entrypoint: Tuple[str, ...],
     name: Optional[str],
+    cluster: Optional[str],
     workdir: Optional[str],
     cloud: Optional[str],
     region: Optional[str],
@@ -3353,6 +3380,11 @@ def jobs_launch(

       sky jobs launch 'echo hello!'
     """
+    if cluster is not None:
+        if name is not None and name != cluster:
+            raise click.UsageError('Cannot specify both --name and --cluster. '
+                                   'Use one of the flags as they are aliases.')
+        name = cluster
     env = _merge_env_vars(env_file, env)
     task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides(
         entrypoint,
@@ -3697,7 +3729,7 @@ def serve():

 def _generate_task_with_service(
     service_name: str,
-    service_yaml_args: List[str],
+    service_yaml_args: Tuple[str, ...],
     workdir: Optional[str],
     cloud: Optional[str],
     region: Optional[str],
@@ -3802,7 +3834,7 @@ def _generate_task_with_service(
 @timeline.event
 @usage_lib.entrypoint
 def serve_up(
-    service_yaml: List[str],
+    service_yaml: Tuple[str, ...],
     service_name: Optional[str],
     workdir: Optional[str],
     cloud: Optional[str],
@@ -3920,7 +3952,7 @@ def serve_up(
 @usage_lib.entrypoint
 def serve_update(
     service_name: str,
-    service_yaml: List[str],
+    service_yaml: Tuple[str, ...],
     workdir: Optional[str],
     cloud: Optional[str],
     region: Optional[str],
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 284ac5aa471..73ad2c0f46a 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -291,6 +291,15 @@ def test_minimal(generic_cloud: str):
         f'sky logs {name} 4 --status',  # Ensure the job succeeded.
         f'sky exec {name} \'echo "$SKYPILOT_CLUSTER_INFO" | jq .cloud | grep -i {generic_cloud}\'',
         f'sky logs {name} 5 --status',  # Ensure the job succeeded.
+        # Test '-c' for exec
+        f'sky exec -c {name} echo',
+        f'sky logs {name} 6 --status',
+        f'sky exec echo -c {name}',
+        f'sky logs {name} 7 --status',
+        f'sky exec -c {name} echo hi test',
+        f'sky logs {name} 8 | grep "hi test"',
+        f'sky exec {name} && exit 1 || true',
+        f'sky exec -c {name} && exit 1 || true',
     ],
     f'sky down -y {name}',
     _get_timeout(generic_cloud),

From e134a35805048b373303dc55f5608e37aa285aa6 Mon Sep 17 00:00:00 2001
From: Tian Xia
Date: Fri, 17 May 2024 11:11:39 +0800
Subject: [PATCH 07/20] [Serve] Change back from autodown to autostop (#3535)

* fix

* skip autostop for k8s

* comments

* fix skip autostop

* fix
---
 sky/backends/cloud_vm_ray_backend.py | 9 ++++-----
 sky/serve/core.py                    | 5 ++---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index a0f746a7098..6d2447fe89b 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -2011,10 +2011,10 @@ def provision_with_retries(
             cloud_user = to_provision.cloud.get_current_user_identity()

         requested_features = self._requested_features.copy()
-        # Skip stop feature for Kubernetes jobs controller.
+        # Skip stop feature for Kubernetes controllers.
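+        # Stopping instances is not supported on Kubernetes, so the STOP
+        # requirement is dropped for any controller (jobs or serve) launched
+        # there; set_autostop() below likewise skips autostop for them.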
if (isinstance(to_provision.cloud, clouds.Kubernetes) and controller_utils.Controllers.from_name(cluster_name) - == controller_utils.Controllers.JOBS_CONTROLLER): + is not None): assert (clouds.CloudImplementationFeatures.STOP in requested_features), requested_features requested_features.remove( @@ -4152,11 +4152,10 @@ def set_autostop(self, # Skip auto-stop for Kubernetes clusters. if (isinstance(handle.launched_resources.cloud, clouds.Kubernetes) and not down and idle_minutes_to_autostop >= 0): - # We should hit this code path only for the jobs controller on + # We should hit this code path only for the controllers on # Kubernetes clusters. assert (controller_utils.Controllers.from_name( - handle.cluster_name) == controller_utils.Controllers. - JOBS_CONTROLLER), handle.cluster_name + handle.cluster_name) is not None), handle.cluster_name logger.info('Auto-stop is not supported for Kubernetes ' 'clusters. Skipping.') return diff --git a/sky/serve/core.py b/sky/serve/core.py index 09b6c9b5151..4f15413cf7f 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -191,14 +191,13 @@ def up( # whether the service is already running. If the id is the same # with the current job id, we know the service is up and running # for the first time; otherwise it is a name conflict. - idle_minutes_to_autodown = constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP + idle_minutes_to_autostop = constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP controller_job_id, controller_handle = sky.launch( task=controller_task, stream_logs=False, cluster_name=controller_name, detach_run=True, - idle_minutes_to_autostop=idle_minutes_to_autodown, - down=True, + idle_minutes_to_autostop=idle_minutes_to_autostop, retry_until_up=True, _disable_controller_check=True, ) From d09b6dcd0998a91000e2e95ff1a1ce5b4010d086 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 17 May 2024 00:10:53 -0700 Subject: [PATCH 08/20] [Usage] Collect failure message (#3560) * Collect failure message * fix comments * Fix unit tests * readability --- sky/cli.py | 49 +++++++++++++++++++----------------- tests/test_jobs_and_serve.py | 10 +++++--- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index 9ad1bcfcb9c..c26b434b025 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -565,9 +565,10 @@ def _launch_with_confirm( raise_if_no_cloud_access=True) except exceptions.NoCloudAccessError as e: # Catch the exception where the public cloud is not enabled, and - # only print the error message without the error type. - click.secho(e, fg='yellow') - sys.exit(1) + # make it yellow for better visibility. 
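+        # Raising here (rather than calling sys.exit) also lets usage
+        # collection record the failure message.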
+ with ux_utils.print_exception_no_traceback(): + raise RuntimeError(f'{colorama.Fore.YELLOW}{e}' + f'{colorama.Style.RESET_ALL}') from e dag = sky.optimize(dag) task = dag.tasks[0] @@ -2094,16 +2095,16 @@ def cancel(cluster: str, all: bool, jobs: List[int], yes: bool): # pylint: disa try: core.cancel(cluster, all=all, job_ids=job_ids_to_cancel) - except exceptions.NotSupportedError: + except exceptions.NotSupportedError as e: controller = controller_utils.Controllers.from_name(cluster) assert controller is not None, cluster - click.echo(controller.value.decline_cancel_hint) - sys.exit(1) + with ux_utils.print_exception_no_traceback(): + raise click.UsageError(controller.value.decline_cancel_hint) from e except ValueError as e: raise click.UsageError(str(e)) - except exceptions.ClusterNotUpError as e: - click.echo(str(e)) - sys.exit(1) + except exceptions.ClusterNotUpError: + with ux_utils.print_exception_no_traceback(): + raise @cli.command(cls=_DocumentedCodeCommand) @@ -3638,9 +3639,9 @@ def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool, follow=follow) else: managed_jobs.tail_logs(name=name, job_id=job_id, follow=follow) - except exceptions.ClusterNotUpError as e: - click.echo(e) - sys.exit(1) + except exceptions.ClusterNotUpError: + with ux_utils.print_exception_no_traceback(): + raise @jobs.command('dashboard', cls=_DocumentedCodeCommand) @@ -4298,9 +4299,9 @@ def serve_logs( target=target_component, replica_id=replica_id, follow=follow) - except exceptions.ClusterNotUpError as e: - click.echo(e) - sys.exit(1) + except exceptions.ClusterNotUpError: + with ux_utils.print_exception_no_traceback(): + raise # ============================== @@ -4987,10 +4988,11 @@ def local_up(gpus: bool): f'exists.{style.RESET_ALL}\n' 'If you want to delete it instead, run: sky local down') else: - click.echo('Failed to create local cluster. ' - f'Full log: {log_path}' - f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') - sys.exit(1) + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + 'Failed to create local cluster. ' + f'Full log: {log_path}' + f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') # Run sky check with rich_utils.safe_status('[bold cyan]Running sky check...'): sky_check.check(quiet=True) @@ -5087,10 +5089,11 @@ def local_down(): elif returncode == 100: click.echo('\nLocal cluster does not exist.') else: - click.echo('Failed to create local cluster. ' - f'Stdout: {stdout}' - f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') - sys.exit(1) + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + 'Failed to create local cluster. ' + f'Stdout: {stdout}' + f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') if cluster_removed: # Run sky check with rich_utils.safe_status('[bold cyan]Running sky check...'): diff --git a/tests/test_jobs_and_serve.py b/tests/test_jobs_and_serve.py index 61d8a9f0a98..a599fb7ba88 100644 --- a/tests/test_jobs_and_serve.py +++ b/tests/test_jobs_and_serve.py @@ -254,7 +254,7 @@ def test_cancel_on_jobs_controller(self, _mock_cluster_state, _mock_jobs_controller): cli_runner = cli_testing.CliRunner() result = cli_runner.invoke(cli.cancel, [jobs.JOB_CONTROLLER_NAME, '-a']) - assert result.exit_code == 1 + assert result.exit_code == click.UsageError.exit_code assert 'Cancelling the jobs controller\'s jobs is not allowed.' 
in str( result.output) @@ -272,7 +272,8 @@ def test_logs(self, _mock_db_conn): result = cli_runner.invoke(cli.jobs_logs, ['1']) assert result.exit_code == 1 assert controller_utils.Controllers.JOBS_CONTROLLER.value.default_hint_if_non_existent in str( - result.output), (result.exception, result.output, result.exc_info) + result.exception), (result.exception, result.output, + result.exc_info) @pytest.mark.timeout(60) def test_queue(self, _mock_db_conn): @@ -397,7 +398,7 @@ def test_cancel_on_serve_controller(self, _mock_cluster_state, cli_runner = cli_testing.CliRunner() result = cli_runner.invoke(cli.cancel, [serve.SKY_SERVE_CONTROLLER_NAME, '-a']) - assert result.exit_code == 1 + assert result.exit_code == click.UsageError.exit_code assert 'Cancelling the sky serve controller\'s jobs is not allowed.' in str( result.output) @@ -416,7 +417,8 @@ def test_logs(self, _mock_db_conn): cli_runner = cli_testing.CliRunner() result = cli_runner.invoke(cli.serve_logs, ['test', '--controller']) assert controller_utils.Controllers.SKY_SERVE_CONTROLLER.value.default_hint_if_non_existent in str( - result.output), (result.exception, result.output, result.exc_info) + result.exception), (result.exception, result.output, + result.exc_info) @pytest.mark.timeout(60) def test_status(self, _mock_db_conn): From 6968be534782a768e8440ba967f0a30ceb46c042 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Fri, 17 May 2024 01:10:17 -0700 Subject: [PATCH 09/20] [core] Add `allowed_clouds` to config to check only specific clouds (#3556) * candidate_clouds * Working allowed_clouds * Working allowed_clouds * comments * lint * change skipped clouds to disabled clouds * lint --- docs/source/reference/config.rst | 13 ++++++ sky/adaptors/cloudflare.py | 1 + sky/check.py | 76 +++++++++++++++++++++++++------- sky/utils/schemas.py | 12 +++++ 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 53c983edfad..dce0ce1f643 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -27,6 +27,19 @@ Available fields and semantics: cpus: 4+ # number of vCPUs, max concurrent spot jobs = 2 * cpus disk_size: 100 + # Allow list for clouds to be used in `sky check` + # + # This field is used to restrict the clouds that SkyPilot will check and use + # when running `sky check`. Any cloud already enabled but not specified here + # will be disabled on the next `sky check` run. + # If this field is not set, SkyPilot will check and use all supported clouds. + # + # Default: null (use all supported clouds). + allowed_clouds: + - aws + - gcp + - kubernetes + # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. 
aws: diff --git a/sky/adaptors/cloudflare.py b/sky/adaptors/cloudflare.py index 2a49dc6fff0..864248614f3 100644 --- a/sky/adaptors/cloudflare.py +++ b/sky/adaptors/cloudflare.py @@ -23,6 +23,7 @@ R2_PROFILE_NAME = 'r2' _INDENT_PREFIX = ' ' NAME = 'Cloudflare' +SKY_CHECK_NAME = 'Cloudflare (for R2 object store)' @contextlib.contextmanager diff --git a/sky/check.py b/sky/check.py index d90fdffefb7..f4ecd5a8b18 100644 --- a/sky/check.py +++ b/sky/check.py @@ -1,7 +1,7 @@ """Credential checks: check cloud credentials and enable clouds.""" import traceback from types import ModuleType -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import click import colorama @@ -10,6 +10,7 @@ from sky import clouds as sky_clouds from sky import exceptions from sky import global_user_state +from sky import skypilot_config from sky.adaptors import cloudflare from sky.utils import ux_utils @@ -52,20 +53,42 @@ def check_one_cloud( disabled_clouds.append(cloud_repr) echo(f' Reason: {reason}') + def get_cloud_tuple( + cloud_name: str) -> Tuple[str, Union[sky_clouds.Cloud, ModuleType]]: + # Validates cloud_name and returns a tuple of the cloud's name and + # the cloud object. Includes special handling for Cloudflare. + if cloud_name.lower().startswith('cloudflare'): + return cloudflare.SKY_CHECK_NAME, cloudflare + else: + cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud_name) + assert cloud_obj is not None, f'Cloud {cloud_name!r} not found' + return repr(cloud_obj), cloud_obj + + def get_all_clouds(): + return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] + + [cloudflare.SKY_CHECK_NAME]) + if clouds is not None: - clouds_to_check: List[Tuple[str, Any]] = [] - for cloud in clouds: - if cloud.lower() == 'cloudflare': - clouds_to_check.append( - ('Cloudflare, for R2 object store', cloudflare)) - else: - cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud) - assert cloud_obj is not None, f'Cloud {cloud!r} not found' - clouds_to_check.append((repr(cloud_obj), cloud_obj)) + cloud_list = clouds else: - clouds_to_check = [(repr(cloud_obj), cloud_obj) - for cloud_obj in sky_clouds.CLOUD_REGISTRY.values()] - clouds_to_check.append(('Cloudflare, for R2 object store', cloudflare)) + cloud_list = get_all_clouds() + clouds_to_check = [get_cloud_tuple(c) for c in cloud_list] + + # Use allowed_clouds from config if it exists, otherwise check all clouds. + # Also validate names with get_cloud_tuple. + config_allowed_cloud_names = [ + get_cloud_tuple(c)[0] for c in skypilot_config.get_nested( + ['allowed_clouds'], get_all_clouds()) + ] + # Use disallowed_cloud_names for logging the clouds that will be disabled + # because they are not included in allowed_clouds in config.yaml. + disallowed_cloud_names = [ + c for c in get_all_clouds() if c not in config_allowed_cloud_names + ] + # Check only the clouds which are allowed in the config. 
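+    # Note that a cloud passed explicitly as an argument to `sky check` is
+    # also dropped here if it is not present in allowed_clouds.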
+ clouds_to_check = [ + c for c in clouds_to_check if c[0] in config_allowed_cloud_names + ] for cloud_tuple in sorted(clouds_to_check): check_one_cloud(cloud_tuple) @@ -79,16 +102,30 @@ def check_one_cloud( disabled_clouds_set = { cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare') } + config_allowed_clouds_set = { + cloud for cloud in config_allowed_cloud_names + if not cloud.startswith('Cloudflare') + } previously_enabled_clouds_set = { repr(cloud) for cloud in global_user_state.get_cached_enabled_clouds() } - # Determine the set of enabled clouds: previously enabled clouds + newly - # enabled clouds - newly disabled clouds. - all_enabled_clouds = ((previously_enabled_clouds_set | enabled_clouds_set) - - disabled_clouds_set) + # Determine the set of enabled clouds: (previously enabled clouds + newly + # enabled clouds - newly disabled clouds) intersected with + # config_allowed_clouds, if specified in config.yaml. + # This means that if a cloud is already enabled and is not included in + # allowed_clouds in config.yaml, it will be disabled. + all_enabled_clouds = (config_allowed_clouds_set & ( + (previously_enabled_clouds_set | enabled_clouds_set) - + disabled_clouds_set)) global_user_state.set_enabled_clouds(list(all_enabled_clouds)) + disallowed_clouds_hint = None + if disallowed_cloud_names: + disallowed_clouds_hint = ( + '\nNote: The following clouds were disabled because they were not ' + 'included in allowed_clouds in ~/.sky/config.yaml: ' + f'{", ".join([c for c in disallowed_cloud_names])}') if len(all_enabled_clouds) == 0: echo( click.style( @@ -96,6 +133,8 @@ def check_one_cloud( 'task. Run `sky check` for more info.', fg='red', bold=True)) + if disallowed_clouds_hint: + echo(click.style(disallowed_clouds_hint, dim=True)) raise SystemExit() else: clouds_arg = (' ' + @@ -109,6 +148,9 @@ def check_one_cloud( 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html', # pylint: disable=line-too-long dim=True)) + if disallowed_clouds_hint: + echo(click.style(disallowed_clouds_hint, dim=True)) + # Pretty print for UX. if not quiet: enabled_clouds_str = '\n :heavy_check_mark: '.join( diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 878fe67178e..42e0da96211 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -587,6 +587,7 @@ def get_default_remote_identity(cloud: str) -> str: def get_config_schema(): # pylint: disable=import-outside-toplevel + from sky.clouds import service_catalog from sky.utils import kubernetes_enums resources_schema = { @@ -722,6 +723,16 @@ def get_config_schema(): }, } + allowed_clouds = { + # A list of cloud names that are allowed to be used + 'type': 'array', + 'items': { + 'type': 'string', + 'case_insensitive_enum': + (list(service_catalog.ALL_CLOUDS) + ['cloudflare']) + } + } + for cloud, config in cloud_configs.items(): if cloud == 'aws': config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS) @@ -738,6 +749,7 @@ def get_config_schema(): 'jobs': controller_resources_schema, 'spot': controller_resources_schema, 'serve': controller_resources_schema, + 'allowed_clouds': allowed_clouds, **cloud_configs, }, # Avoid spot and jobs being present at the same time. 
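For reference, a minimal sketch of how the new field is used; the cloud names
are illustrative, and the behavior described is read off the check.py and
config.rst hunks above:

```yaml
# ~/.sky/config.yaml
allowed_clouds:
  - aws
  - kubernetes
```

With this config, `sky check` probes only AWS and Kubernetes. Any other cloud
is skipped even when named explicitly (e.g. `sky check gcp`), any cloud already
enabled but not listed is disabled on the next `sky check` run, and the output
ends with the note listing the clouds disabled because they were not included
in `allowed_clouds`.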
From ed053c1d45017197977600f5f7f14c07f04ba0b5 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Fri, 17 May 2024 22:10:01 -0700 Subject: [PATCH 10/20] [k8s] Check only kubernetes after `sky local up` (#3563) * check only kubernetes for `sky local up` * lint --- sky/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/cli.py b/sky/cli.py index c26b434b025..894ccc7257c 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -4995,7 +4995,7 @@ def local_up(gpus: bool): f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') # Run sky check with rich_utils.safe_status('[bold cyan]Running sky check...'): - sky_check.check(quiet=True) + sky_check.check(clouds=('kubernetes',), quiet=True) if cluster_created: # Prepare completion message which shows CPU and GPU count # Get number of CPUs From 19e8ed16a8cdb34f9f1963da19db1ce317798eb1 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 18 May 2024 23:26:19 -0700 Subject: [PATCH 11/20] [Docs] Clarify managed spot jobs against sky launch (#3561) * Clarify managed spot jobs against sky launch * change to table * change to table * Update docs/source/examples/managed-jobs.rst Co-authored-by: Zongheng Yang * Update docs/source/examples/managed-jobs.rst Co-authored-by: Zongheng Yang * Update docs/source/examples/managed-jobs.rst Co-authored-by: Zongheng Yang * Update docs/source/examples/managed-jobs.rst Co-authored-by: Zongheng Yang * fix --------- Co-authored-by: Zongheng Yang --- docs/source/examples/managed-jobs.rst | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst index ba449c1f087..a47b4345b9f 100644 --- a/docs/source/examples/managed-jobs.rst +++ b/docs/source/examples/managed-jobs.rst @@ -7,7 +7,7 @@ Managed Jobs This feature is great for scaling out: running a single job for long durations, or running many jobs (pipelines). -SkyPilot supports **managed jobs**, which can automatically recover from any spot preemptions or hardware failures. +SkyPilot supports **managed jobs** (:code:`sky jobs`), which can automatically recover from any spot preemptions or hardware failures. It can be used in three modes: #. :ref:`Managed Spot Jobs `: Jobs run on auto-recovering spot instances. This can **save significant costs** (e.g., up to 70\% for GPU VMs) by making preemptible spot instances useful for long-running jobs. @@ -20,9 +20,29 @@ It can be used in three modes: Managed Spot Jobs ----------------- -SkyPilot automatically finds available spot resources across regions and clouds to maximize availability. +In this mode, :code:`sky jobs launch --use-spot` is used to launch a managed spot job. SkyPilot automatically finds available spot resources across regions and clouds to maximize availability. Any spot preemptions are automatically handled by SkyPilot without user intervention. + +Quick comparison between *unmanaged spot clusters* vs. *managed spot jobs*: + +.. list-table:: + :widths: 30 18 12 35 + :header-rows: 1 + + * - Command + - Managed? + - SSH-able? + - Best for + * - :code:`sky launch --use-spot` + - Unmanaged spot cluster + - Yes + - Interactive dev on spot instances (especially for hardware with low preemption rates) + * - :code:`sky jobs launch --use-spot` + - Managed spot job (auto-recovery) + - No + - Scaling out long-running jobs (e.g., data processing, training, batch inference) + Here is an example of a BERT training job failing over different regions across AWS and GCP. .. 
image:: https://i.imgur.com/Vteg3fK.gif From 6a624b2fa3533496bbed3bf013ae608e834b807f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 20 May 2024 12:34:30 -0700 Subject: [PATCH 12/20] [Minor] Refactor `is_same_cloud` (#3568) * minor refactor for is_same_cloud * format * remove from azure.py --- sky/clouds/aws.py | 3 --- sky/clouds/azure.py | 3 --- sky/clouds/cloud.py | 4 ++-- sky/clouds/cudo.py | 4 ---- sky/clouds/fluidstack.py | 4 ---- sky/clouds/gcp.py | 3 --- sky/clouds/ibm.py | 3 --- sky/clouds/kubernetes.py | 3 --- sky/clouds/lambda_cloud.py | 4 ---- sky/clouds/oci.py | 4 ---- sky/clouds/paperspace.py | 4 ---- sky/clouds/runpod.py | 4 ---- sky/clouds/scp.py | 4 ---- sky/clouds/vsphere.py | 4 ---- 14 files changed, 2 insertions(+), 49 deletions(-) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index b2d76e7b7df..b2b55e14d5e 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -340,9 +340,6 @@ def get_egress_cost(self, num_gigabytes: float): cost += 0.0 return cost - def is_same_cloud(self, other: clouds.Cloud): - return isinstance(other, AWS) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index edf7eb1a060..4df1cd4a4bf 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -134,9 +134,6 @@ def get_egress_cost(self, num_gigabytes: float): cost += 0.0 return cost - def is_same_cloud(self, other): - return isinstance(other, Azure) - @classmethod def get_image_size(cls, image_id: str, region: Optional[str]) -> float: if region is None: diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 08045e28ab9..c5ff78e1c79 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -247,8 +247,8 @@ def get_egress_cost(self, num_gigabytes: float): """ raise NotImplementedError - def is_same_cloud(self, other: 'Cloud'): - raise NotImplementedError + def is_same_cloud(self, other: 'Cloud') -> bool: + return isinstance(other, self.__class__) def make_deploy_resources_variables( self, diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index ad7a22e6e03..7b4c13699d1 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -155,10 +155,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: # `return 0.0` is a good placeholder.) return 0.0 - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, Cudo) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 4d6b7f1a2ec..d7921a3f51a 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -140,10 +140,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: def __repr__(self): return 'Fluidstack' - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. 
- return isinstance(other, Fluidstack) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 7babf34ac52..93260533f27 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -315,9 +315,6 @@ def get_egress_cost(self, num_gigabytes: float): else: return 0.08 * num_gigabytes - def is_same_cloud(self, other): - return isinstance(other, GCP) - @classmethod def _is_machine_image(cls, image_id: str) -> bool: find_machine = re.match(r'projects/.*/.*/machineImages/.*', image_id) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index 880ad212e25..86e325a437b 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -165,9 +165,6 @@ def get_egress_cost(self, num_gigabytes: float): num_gigabytes -= (num_gigabytes - price_threshold['threshold']) return cost - def is_same_cloud(self, other): - return isinstance(other, IBM) - def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index fcf8c2f87ac..c0b25232f84 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -140,9 +140,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: def __repr__(self): return self._REPR - def is_same_cloud(self, other: clouds.Cloud) -> bool: - return isinstance(other, Kubernetes) - @classmethod def get_port(cls, svc_name) -> int: ns = kubernetes_utils.get_current_kube_config_context_namespace() diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 37750355a88..979b4833354 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -121,10 +121,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: def __repr__(self): return 'Lambda' - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, Lambda) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 03351fc4cf6..5fb0111bf01 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -160,10 +160,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: # return 0.0 return (num_gigabytes - 10 * 1024) * 0.0085 - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, OCI) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index f76772ab8b7..f67a9a27176 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -147,10 +147,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: def __repr__(self): return self._REPR - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, Paperspace) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0f9e5c68169..c7a24e274dd 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -140,10 +140,6 @@ def accelerators_to_hourly_cost(self, def get_egress_cost(self, num_gigabytes: float) -> float: return 0.0 - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. 
- return isinstance(other, RunPod) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 1d6cb6cf20f..6a3daf2712a 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -144,10 +144,6 @@ def accelerators_to_hourly_cost(self, def get_egress_cost(self, num_gigabytes: float) -> float: return 0.0 - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, SCP) - @classmethod def get_default_instance_type( cls, diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 02a794d7d58..872b8df9d70 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -136,10 +136,6 @@ def get_egress_cost(self, num_gigabytes: float) -> float: def __repr__(self): return 'vSphere' - def is_same_cloud(self, other: clouds.Cloud) -> bool: - # Returns true if the two clouds are the same cloud type. - return isinstance(other, Vsphere) - @classmethod def get_default_instance_type( cls, From d259ddcc26b0ad79d524aa5b357c5e351f43c3f8 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 20 May 2024 14:02:16 -0700 Subject: [PATCH 13/20] [Core] Allow nonexistent cloud in candidate resources and speed up optimization (#3567) * first version * refactor * fix * revert is_same_cloud * Update sky/optimizer.py Co-authored-by: Tian Xia * address comments * check once for a dag * fix doc str * format --------- Co-authored-by: Tian Xia --- sky/check.py | 2 +- sky/optimizer.py | 195 ++++++++++++++++++++++++++++++----------------- 2 files changed, 125 insertions(+), 72 deletions(-) diff --git a/sky/check.py b/sky/check.py index f4ecd5a8b18..e8a61317d63 100644 --- a/sky/check.py +++ b/sky/check.py @@ -18,7 +18,7 @@ def check( quiet: bool = False, verbose: bool = False, - clouds: Optional[Tuple[str]] = None, + clouds: Optional[Iterable[str]] = None, ) -> None: echo = (lambda *_args, **_kwargs: None) if quiet else click.echo echo('Checking credentials to enable clouds for SkyPilot.') diff --git a/sky/optimizer.py b/sky/optimizer.py index 1cb9bc0890c..c2f1f7b223e 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -4,7 +4,7 @@ import enum import json import typing -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import colorama import numpy as np @@ -120,6 +120,8 @@ def optimize(dag: 'dag_lib.Dag', for a task. exceptions.NoCloudAccessError: if no public clouds are enabled. """ + _check_specified_clouds(dag) + # This function is effectful: mutates every node in 'dag' by setting # node.best_resources if it is None. Optimizer._add_dummy_source_sink_nodes(dag) @@ -269,7 +271,6 @@ def _estimate_nodes_cost_or_time( _fill_in_launchable_resources( task=node, blocked_resources=blocked_resources, - try_fix_with_sky_check=True, quiet=quiet)) node_to_candidate_map[node] = cloud_candidates else: @@ -992,18 +993,18 @@ def ordinal_number(n): for task_id in range(len(dag.tasks)): task = dag.tasks[task_id] if isinstance(task.resources, list): + # For ordered resources, we try the resources in the order + # specified by the user. 
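+            # The first resource whose launchable set is non-empty is kept;
+            # the loop below mutates only the deep-copied local_dag, leaving
+            # the user's dag untouched.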
local_task = local_dag.tasks[task_id] for resources in task.resources: # Check if there exists launchable resources local_task.set_resources(resources) - launchable_resources_map, _ , _ = \ + launchable_resources_map, _, _ = ( _fill_in_launchable_resources( - task = local_task, - blocked_resources = blocked_resources, - try_fix_with_sky_check = True, - quiet = False - ) - if len(launchable_resources_map[resources]) != 0: + task=local_task, + blocked_resources=blocked_resources, + quiet=False)) + if launchable_resources_map.get(resources, []): break local_graph = local_dag.get_graph() @@ -1043,24 +1044,25 @@ def ordinal_number(n): if has_resources_ordered: best_plan = {} # We have to manually set the best_resources for the tasks in the - # original dag, to pass the optimization results - # to the caller, as we deep copied the dag - # when the dag has nodes with ordered resources. + # original dag, to pass the optimization results to the caller, as + # we deep copied the dag when the dag has nodes with ordered + # resources. for task, resources in local_best_plan.items(): task_idx = local_dag.tasks.index(task) dag.tasks[task_idx].best_resources = resources best_plan[dag.tasks[task_idx]] = resources + + topo_order = list(nx.topological_sort(graph)) + # Get the cost of each specified resources for display purpose. + node_to_cost_map, _ = Optimizer._estimate_nodes_cost_or_time( + topo_order=topo_order, + minimize_cost=minimize_cost, + blocked_resources=blocked_resources, + quiet=True) else: best_plan = local_best_plan - - topo_order = list(nx.topological_sort(graph)) if has_resources_ordered \ - else local_topo_order - node_to_cost_map, _ = (Optimizer._estimate_nodes_cost_or_time( - topo_order=topo_order, - minimize_cost=minimize_cost, - blocked_resources=blocked_resources, - quiet=True)) if has_resources_ordered else ( - local_node_to_cost_map, local_node_to_candidate_map) + topo_order = local_topo_order + node_to_cost_map = local_node_to_cost_map if not quiet: Optimizer.print_optimized_plan(graph, topo_order, best_plan, @@ -1146,10 +1148,63 @@ def _filter_out_blocked_launchable_resources( return available_resources +def _check_specified_clouds(dag: 'dag_lib.Dag') -> None: + """Check if specified clouds are enabled in cache and refresh if needed. + + Our enabled cloud list is cached in a local database, and if a user + specified a cloud that is not enabled, we should refresh the cache for that + cloud in case the cloud access has been enabled since the last cache update. + + Args: + dag: The DAG specified by a user. + """ + enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh( + raise_if_no_cloud_access=True) + + global_disabled_clouds: Set[str] = set() + for task in dag.tasks: + # Recheck the enabled clouds if the task's requested resources are on a + # cloud that is not enabled in the cached enabled_clouds. + all_clouds_specified: Set[str] = set() + clouds_need_recheck: Set[str] = set() + for resources in task.resources: + cloud_str = str(resources.cloud) + if (resources.cloud is not None and not clouds.cloud_in_iterable( + resources.cloud, enabled_clouds)): + # Explicitly check again to update the enabled cloud list. + clouds_need_recheck.add(cloud_str) + all_clouds_specified.add(cloud_str) + + # Explicitly check again to update the enabled cloud list. 
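+        # Clouds already found to be disabled for an earlier task in this dag
+        # are excluded, so the same credential check is not repeated.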
+ sky_check.check(quiet=True, + clouds=list(clouds_need_recheck - + global_disabled_clouds)) + enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh( + raise_if_no_cloud_access=True) + disabled_clouds = (clouds_need_recheck - + {str(c) for c in enabled_clouds}) + global_disabled_clouds.update(disabled_clouds) + if disabled_clouds: + is_or_are = 'is' if len(disabled_clouds) == 1 else 'are' + task_name = f' {task.name!r}' if task.name is not None else '' + msg = (f'Task{task_name} requires {", ".join(disabled_clouds)} ' + f'which {is_or_are} not enabled. To enable access, change ' + f'the task cloud requirement or run: {colorama.Style.BRIGHT}' + f'sky check {" ".join(disabled_clouds)}' + f'{colorama.Style.RESET_ALL}') + if all_clouds_specified == disabled_clouds: + # If all resources are specified with a disabled cloud, we + # should raise an error as no resource can satisfy the + # requirement. Otherwise, we should just skip the resource. + with ux_utils.print_exception_no_traceback(): + raise exceptions.ResourcesUnavailableError(msg) + logger.warning( + f'{colorama.Fore.YELLOW}{msg}{colorama.Style.RESET_ALL}') + + def _fill_in_launchable_resources( task: task_lib.Task, blocked_resources: Optional[Iterable[resources_lib.Resources]], - try_fix_with_sky_check: bool = True, quiet: bool = False ) -> Tuple[Dict[resources_lib.Resources, List[resources_lib.Resources]], _PerCloudCandidates, List[str]]: @@ -1161,10 +1216,15 @@ def _fill_in_launchable_resources( Resources, Dict mapping Cloud to a list of feasible Resources (for printing), Sorted list of fuzzy candidates (alternative GPU names). + Raises: + ResourcesUnavailableError: if all resources required by the task are on + a cloud that is not enabled. """ enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh( raise_if_no_cloud_access=True) - launchable = collections.defaultdict(list) + + launchable: Dict[resources_lib.Resources, List[resources_lib.Resources]] = ( + collections.defaultdict(list)) all_fuzzy_candidates = set() cloud_candidates: _PerCloudCandidates = collections.defaultdict( List[resources_lib.Resources]) @@ -1173,58 +1233,51 @@ def _fill_in_launchable_resources( for resources in task.resources: if (resources.cloud is not None and not clouds.cloud_in_iterable(resources.cloud, enabled_clouds)): - if try_fix_with_sky_check: - # Explicitly check again to update the enabled cloud list. - sky_check.check(quiet=True) - return _fill_in_launchable_resources(task, blocked_resources, - False) - with ux_utils.print_exception_no_traceback(): - raise exceptions.ResourcesUnavailableError( - f'Task requires {resources.cloud} which is ' - f'not enabled: {task}.\nTo enable access, run ' - f'{colorama.Style.BRIGHT}' - f'sky check {colorama.Style.RESET_ALL}, or change the ' - 'cloud requirement') - else: - clouds_list = ([resources.cloud] - if resources.cloud is not None else enabled_clouds) - for cloud in clouds_list: - (feasible_resources, fuzzy_candidate_list) = ( - cloud.get_feasible_launchable_resources( - resources, num_nodes=task.num_nodes)) - if len(feasible_resources) > 0: - # Assume feasible_resources is sorted by prices. - cheapest = feasible_resources[0] - # Generate region/zone-specified resources. 
- launchable[resources].extend( - _make_launchables_for_valid_region_zones(cheapest)) - cloud_candidates[cloud] = feasible_resources - else: - all_fuzzy_candidates.update(fuzzy_candidate_list) - if len(launchable[resources]) == 0: - clouds_str = str(clouds_list) if len(clouds_list) > 1 else str( - clouds_list[0]) - num_node_str = '' - if task.num_nodes > 1: - num_node_str = f'{task.num_nodes}x ' - if not quiet: - logger.info( - f'No resource satisfying {num_node_str}' - f'{resources.repr_with_region_zone} on {clouds_str}.') - if len(all_fuzzy_candidates) > 0: + # Skip the resources that are on a cloud that is not enabled. The + # hint has been printed in _check_specified_clouds. + launchable[resources] = [] + continue + clouds_list = ([resources.cloud] + if resources.cloud is not None else enabled_clouds) + for cloud in clouds_list: + (feasible_resources, + fuzzy_candidate_list) = cloud.get_feasible_launchable_resources( + resources, num_nodes=task.num_nodes) + if len(feasible_resources) > 0: + # Assume feasible_resources is sorted by prices. Guaranteed by + # the implementation of get_feasible_launchable_resources and + # the underlying service_catalog filtering + cheapest = feasible_resources[0] + # Generate region/zone-specified resources. + launchable[resources].extend( + _make_launchables_for_valid_region_zones(cheapest)) + cloud_candidates[cloud] = feasible_resources + else: + all_fuzzy_candidates.update(fuzzy_candidate_list) + if len(launchable[resources]) == 0: + clouds_str = str(clouds_list) if len(clouds_list) > 1 else str( + clouds_list[0]) + num_node_str = '' + if task.num_nodes > 1: + num_node_str = f'{task.num_nodes}x ' + if not quiet: + logger.info( + f'No resource satisfying {num_node_str}' + f'{resources.repr_with_region_zone} on {clouds_str}.') + if all_fuzzy_candidates: logger.info('Did you mean: ' f'{colorama.Fore.CYAN}' f'{sorted(all_fuzzy_candidates)}' f'{colorama.Style.RESET_ALL}') - else: - if resources.cpus is not None: - logger.info('Try specifying a different CPU count, ' - 'or add "+" to the end of the CPU count ' - 'to allow for larger instances.') - if resources.memory is not None: - logger.info('Try specifying a different memory size, ' - 'or add "+" to the end of the memory size ' - 'to allow for larger instances.') + else: + if resources.cpus is not None: + logger.info('Try specifying a different CPU count, ' + 'or add "+" to the end of the CPU count ' + 'to allow for larger instances.') + if resources.memory is not None: + logger.info('Try specifying a different memory size, ' + 'or add "+" to the end of the memory size ' + 'to allow for larger instances.') launchable[resources] = _filter_out_blocked_launchable_resources( launchable[resources], blocked_resources) From 7be7f6a7c2c91ebebdfdede1e4e76133bb7e28a4 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 20 May 2024 15:04:08 -0700 Subject: [PATCH 14/20] [minor] Use lower case for sky check hint in optimizer (#3571) use lower case for sky check hint --- sky/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/optimizer.py b/sky/optimizer.py index c2f1f7b223e..9c11511a38b 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -1190,7 +1190,7 @@ def _check_specified_clouds(dag: 'dag_lib.Dag') -> None: msg = (f'Task{task_name} requires {", ".join(disabled_clouds)} ' f'which {is_or_are} not enabled. 
To enable access, change '
                   f'the task cloud requirement or run: {colorama.Style.BRIGHT}'
-                   f'sky check {" ".join(disabled_clouds)}'
+                   f'sky check {" ".join(c.lower() for c in disabled_clouds)}'
                    f'{colorama.Style.RESET_ALL}')
            if all_clouds_specified == disabled_clouds:
                # If all resources are specified with a disabled cloud, we

From cf840dc7fdbaf59bf6ba31b302fd12f2a1dd62d3 Mon Sep 17 00:00:00 2001
From: Tian Xia
Date: Tue, 21 May 2024 12:31:25 +0800
Subject: [PATCH 15/20] [Serve] Support headers in Readiness Probe (#3552)

* init
* probe_str: remove headers and delete env vars
* remove header values in replica manager logging
---
 llm/vllm/service-with-auth.yaml | 42 +++++++++++++++++++++++++++++++++
 sky/serve/replica_managers.py   | 19 ++++++++++++---
 sky/serve/service_spec.py       | 18 ++++++++++++--
 sky/utils/schemas.py            |  8 ++++++-
 4 files changed, 81 insertions(+), 6 deletions(-)
 create mode 100644 llm/vllm/service-with-auth.yaml

diff --git a/llm/vllm/service-with-auth.yaml b/llm/vllm/service-with-auth.yaml
new file mode 100644
index 00000000000..0a40df29293
--- /dev/null
+++ b/llm/vllm/service-with-auth.yaml
@@ -0,0 +1,42 @@
+# service.yaml
+# The newly-added `service` section to the `serve-openai-api.yaml` file.
+service:
+  # Specifies the path to the endpoint used to check the readiness of the service.
+  readiness_probe:
+    path: /v1/models
+    # Set authorization headers here if needed.
+    headers:
+      Authorization: Bearer $AUTH_TOKEN
+  # How many replicas to manage.
+  replicas: 1
+
+# Fields below are the same as in `serve-openai-api.yaml`.
+envs:
+  MODEL_NAME: meta-llama/Llama-2-7b-chat-hf
+  HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
+  AUTH_TOKEN: # TODO: Fill with your own auth token (a random string), or use --env to pass.
+
+resources:
+  accelerators: {L4:1, A10G:1, A10:1, A100:1, A100-80GB:1}
+  ports: 8000
+
+setup: |
+  conda activate vllm
+  if [ $? -ne 0 ]; then
+    conda create -n vllm python=3.10 -y
+    conda activate vllm
+  fi
+
+  pip install transformers==4.38.0
+  pip install vllm==0.3.2
+
+  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
+
+
+run: |
+  conda activate vllm
+  echo 'Starting vllm openai api server...'
+  python -m vllm.entrypoints.openai.api_server \
+    --model $MODEL_NAME --tokenizer hf-internal-testing/llama-tokenizer \
+    --host 0.0.0.0 --port 8000 --api-key $AUTH_TOKEN
+
diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py
index efb3ba3cf48..b4732d36153 100644
--- a/sky/serve/replica_managers.py
+++ b/sky/serve/replica_managers.py
@@ -488,6 +488,7 @@ def probe(
         self,
         readiness_path: str,
         post_data: Optional[Dict[str, Any]],
+        headers: Optional[Dict[str, str]],
     ) -> Tuple['ReplicaInfo', bool, float]:
         """Probe the readiness of the replica.
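# For reference, the `headers` argument threaded through the surrounding
# hunks ends up on the probe request itself. A standalone sketch of that
# request follows (URL, token, and helper name are illustrative only):
from typing import Any, Dict, Optional

import requests


def probe_once(url: str,
               post_data: Optional[Dict[str, Any]] = None,
               headers: Optional[Dict[str, str]] = None,
               timeout: int = 15) -> requests.Response:
    # POST when a payload is configured, otherwise a plain GET; the same
    # optional headers (e.g. Authorization) are sent either way.
    if post_data is not None:
        return requests.post(url, json=post_data, headers=headers,
                             timeout=timeout)
    return requests.get(url, headers=headers, timeout=timeout)


# probe_once('http://1.2.3.4:8000/v1/models',
#            headers={'Authorization': 'Bearer my-auth-token'})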
@@ -513,12 +514,14 @@ def probe( msg += 'POST' response = requests.post( readiness_path, + headers=headers, json=post_data, timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS) else: msg += 'GET' response = requests.get( readiness_path, + headers=headers, timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS) msg += (f' request to {replica_identity} returned status ' f'code {response.status_code}') @@ -565,9 +568,13 @@ def __init__(self, service_name: str, self._service_name: str = service_name self._uptime: Optional[float] = None self._update_mode = serve_utils.DEFAULT_UPDATE_MODE + header_keys = None + if spec.readiness_headers is not None: + header_keys = list(spec.readiness_headers.keys()) logger.info(f'Readiness probe path: {spec.readiness_path}\n' f'Initial delay seconds: {spec.initial_delay_seconds}\n' - f'Post data: {spec.post_data}') + f'Post data: {spec.post_data}\n' + f'Readiness header keys: {header_keys}') # Newest version among the currently provisioned and launched replicas self.latest_version: int = serve_constants.INITIAL_VERSION @@ -1033,8 +1040,11 @@ def _probe_all_replicas(self) -> None: probe_futures.append( pool.apply_async( info.probe, - (self._get_readiness_path( - info.version), self._get_post_data(info.version)), + ( + self._get_readiness_path(info.version), + self._get_post_data(info.version), + self._get_readiness_headers(info.version), + ), ),) logger.info(f'Replicas to probe: {", ".join(replica_to_probe)}') @@ -1215,5 +1225,8 @@ def _get_readiness_path(self, version: int) -> str: def _get_post_data(self, version: int) -> Optional[Dict[str, Any]]: return self._get_version_spec(version).post_data + def _get_readiness_headers(self, version: int) -> Optional[Dict[str, str]]: + return self._get_version_spec(version).readiness_headers + def _get_initial_delay_seconds(self, version: int) -> int: return self._get_version_spec(version).initial_delay_seconds diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index eba38aa5a79..80217acfff8 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -23,6 +23,7 @@ def __init__( max_replicas: Optional[int] = None, target_qps_per_replica: Optional[float] = None, post_data: Optional[Dict[str, Any]] = None, + readiness_headers: Optional[Dict[str, str]] = None, dynamic_ondemand_fallback: Optional[bool] = None, base_ondemand_fallback_replicas: Optional[int] = None, upscale_delay_seconds: Optional[int] = None, @@ -81,6 +82,7 @@ def __init__( self._max_replicas: Optional[int] = max_replicas self._target_qps_per_replica: Optional[float] = target_qps_per_replica self._post_data: Optional[Dict[str, Any]] = post_data + self._readiness_headers: Optional[Dict[str, str]] = readiness_headers self._dynamic_ondemand_fallback: Optional[ bool] = dynamic_ondemand_fallback self._base_ondemand_fallback_replicas: Optional[ @@ -111,11 +113,13 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['readiness_path'] = readiness_section initial_delay_seconds = None post_data = None + readiness_headers = None else: service_config['readiness_path'] = readiness_section['path'] initial_delay_seconds = readiness_section.get( 'initial_delay_seconds', None) post_data = readiness_section.get('post_data', None) + readiness_headers = readiness_section.get('headers', None) if initial_delay_seconds is None: initial_delay_seconds = constants.DEFAULT_INITIAL_DELAY_SECONDS service_config['initial_delay_seconds'] = initial_delay_seconds @@ -129,6 +133,7 @@ def from_yaml_config(config: Dict[str, Any]) 
-> 'SkyServiceSpec': '`readiness_probe` section of your service YAML.' ) from e service_config['post_data'] = post_data + service_config['readiness_headers'] = readiness_headers policy_section = config.get('replica_policy', None) simplified_policy_section = config.get('replicas', None) @@ -204,6 +209,7 @@ def add_if_not_none(section, key, value, no_empty: bool = False): add_if_not_none('readiness_probe', 'initial_delay_seconds', self.initial_delay_seconds) add_if_not_none('readiness_probe', 'post_data', self.post_data) + add_if_not_none('readiness_probe', 'headers', self._readiness_headers) add_if_not_none('replica_policy', 'min_replicas', self.min_replicas) add_if_not_none('replica_policy', 'max_replicas', self.max_replicas) add_if_not_none('replica_policy', 'target_qps_per_replica', @@ -220,8 +226,12 @@ def add_if_not_none(section, key, value, no_empty: bool = False): def probe_str(self): if self.post_data is None: - return f'GET {self.readiness_path}' - return f'POST {self.readiness_path} {json.dumps(self.post_data)}' + method = f'GET {self.readiness_path}' + else: + method = f'POST {self.readiness_path} {json.dumps(self.post_data)}' + headers = ('' if self.readiness_headers is None else + ' with custom headers') + return f'{method}{headers}' def spot_policy_str(self): policy_strs = [] @@ -287,6 +297,10 @@ def target_qps_per_replica(self) -> Optional[float]: def post_data(self) -> Optional[Dict[str, Any]]: return self._post_data + @property + def readiness_headers(self) -> Optional[Dict[str, str]]: + return self._readiness_headers + @property def base_ondemand_fallback_replicas(self) -> Optional[int]: return self._base_ondemand_fallback_replicas diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 42e0da96211..8bbe1d54e60 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -303,7 +303,13 @@ def get_service_schema(): }, { 'type': 'object', }] - } + }, + 'headers': { + 'type': 'object', + 'additionalProperties': { + 'type': 'string' + } + }, } }] }, From ccfc3b1e55cc88d8527e4270706e1dd6d4d595d3 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 22 May 2024 02:47:10 -0700 Subject: [PATCH 16/20] [Core] Fix quote in command runner (#3572) * Quote the command correctly when source_bashrc is not set * Remove unnecessary source bashrc * format * Fix setup script for conda * Add comment * format * address comments --- sky/backends/cloud_vm_ray_backend.py | 5 ++--- sky/provision/instance_setup.py | 4 ++-- sky/skylet/constants.py | 11 +++++++---- sky/utils/command_runner.py | 10 ++++++---- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 6d2447fe89b..6ff9731033e 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3175,7 +3175,8 @@ def _setup_node(node_id: int) -> None: f'{create_script_code} && {setup_cmd}', log_path=setup_log_path, process_stream=False, - source_bashrc=True, + # We do not source bashrc for setup, since bashrc is sourced + # in the script already. 
) def error_message() -> str: @@ -3724,7 +3725,6 @@ def tail_managed_job_logs(self, process_stream=False, ssh_mode=command_runner.SshMode.INTERACTIVE, stdin=subprocess.DEVNULL, - source_bashrc=True, ) def tail_serve_logs(self, handle: CloudVmRayResourceHandle, @@ -3762,7 +3762,6 @@ def tail_serve_logs(self, handle: CloudVmRayResourceHandle, process_stream=False, ssh_mode=command_runner.SshMode.INTERACTIVE, stdin=subprocess.DEVNULL, - source_bashrc=True, ) def teardown_no_lock(self, diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 1e5e6285fef..2e07f026616 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -198,8 +198,8 @@ def _setup_node(runner: command_runner.CommandRunner, log_path: str): stream_logs=False, log_path=log_path, require_outputs=True, - # Installing depencies requires source bashrc to access the PATH - # in bashrc. + # Installing dependencies requires source bashrc to access + # conda. source_bashrc=True) retry_cnt = 0 while returncode == 255 and retry_cnt < _MAX_RETRY: diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 578629ea3e2..0f2d7540007 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -91,15 +91,18 @@ # AWS's Deep Learning AMI's default conda environment. CONDA_INSTALLATION_COMMANDS = ( 'which conda > /dev/null 2>&1 || ' - '(wget -nc https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -O Miniconda3-Linux-x86_64.sh && ' # pylint: disable=line-too-long + '{ wget -nc https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -O Miniconda3-Linux-x86_64.sh && ' # pylint: disable=line-too-long 'bash Miniconda3-Linux-x86_64.sh -b && ' 'eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && ' - 'conda config --set auto_activate_base true); ' - 'grep "# >>> conda initialize >>>" ~/.bashrc || conda init;' + 'conda config --set auto_activate_base true && ' + # Use $(echo ~) instead of ~ to avoid the error "no such file or directory". + # Also, not using $HOME to avoid the error HOME variable not set. + f'echo "$(echo ~)/miniconda3/bin/python" > {SKY_PYTHON_PATH_FILE}; }}; ' + 'grep "# >>> conda initialize >>>" ~/.bashrc || ' + '{ conda init && source ~/.bashrc; };' '(type -a python | grep -q python3) || ' 'echo \'alias python=python3\' >> ~/.bashrc;' '(type -a pip | grep -q pip3) || echo \'alias pip=pip3\' >> ~/.bashrc;' - 'source ~/.bashrc;' # Writes Python path to file if it does not exist or the file is empty. f'[ -s {SKY_PYTHON_PATH_FILE} ] || which python3 > {SKY_PYTHON_PATH_FILE};') diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 3aa87eda138..3ed69d0e2a1 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -184,8 +184,8 @@ def _get_command_to_run( # cluster by 1 second. # sourcing ~/.bashrc is not required for internal executions command += [ - 'true && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore' - f' && ({cmd})' + shlex.quote('true && export OMP_NUM_THREADS=1 ' + f'PYTHONWARNINGS=ignore && ({cmd})') ] if not separate_stderr: command.append('2>&1') @@ -431,10 +431,12 @@ def run( cmd, process_stream, separate_stderr, - # A hack to remove the following bash warnings (twice): + # A hack to remove the following SSH warning+bash warnings (twice): + # Warning: Permanently added 'xx.xx.xx.xx' to the list of known... 
# bash: cannot set terminal process group # bash: no job control in this shell - skip_lines=5 if source_bashrc else 0, + # When not source_bashrc, the bash warning will only show once. + skip_lines=5 if source_bashrc else 3, source_bashrc=source_bashrc) command = base_ssh_command + [shlex.quote(command_str)] From 89e8d484ac82718842f483d644ac392d8345162a Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 22 May 2024 10:49:04 -0700 Subject: [PATCH 17/20] [k8s] Add `skypilot-user` label to pods (#3576) add skypilot-user label --- sky/templates/kubernetes-ray.yml.j2 | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index b05c8b589f6..7078a6ca787 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -205,6 +205,7 @@ provider: labels: parent: skypilot skypilot-cluster: {{cluster_name_on_cloud}} + skypilot-user: {{ user }} name: {{cluster_name_on_cloud}}-head-ssh spec: selector: @@ -220,6 +221,7 @@ provider: labels: parent: skypilot skypilot-cluster: {{cluster_name_on_cloud}} + skypilot-user: {{ user }} # NOTE: If you're running multiple Ray clusters with services # on one Kubernetes cluster, they must have unique service # names. @@ -256,6 +258,7 @@ available_node_types: skypilot-cluster: {{cluster_name_on_cloud}} # Identifies the SSH jump pod used by this pod. Used in life cycle management of the ssh jump pod. skypilot-ssh-jump: {{k8s_ssh_jump_name}} + skypilot-user: {{ user }} # Custom tags for the pods {%- for label_key, label_value in labels.items() %} {{ label_key }}: {{ label_value }} From 211386fb4c17be1669b9e924ead123f5586a3711 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 22 May 2024 10:49:20 -0700 Subject: [PATCH 18/20] [k8s] Check only kubernetes for `sky local down` (#3578) check only kubernetes for local down --- sky/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index 894ccc7257c..9a45a35ae55 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -4995,7 +4995,7 @@ def local_up(gpus: bool): f'\nError: {style.BRIGHT}{stderr}{style.RESET_ALL}') # Run sky check with rich_utils.safe_status('[bold cyan]Running sky check...'): - sky_check.check(clouds=('kubernetes',), quiet=True) + sky_check.check(clouds=['kubernetes'], quiet=True) if cluster_created: # Prepare completion message which shows CPU and GPU count # Get number of CPUs @@ -5097,7 +5097,7 @@ def local_down(): if cluster_removed: # Run sky check with rich_utils.safe_status('[bold cyan]Running sky check...'): - sky_check.check(quiet=True) + sky_check.check(clouds=['kubernetes'], quiet=True) click.echo( f'{colorama.Fore.GREEN}Local cluster removed.{style.RESET_ALL}') From 97cba000dbf29de6170bc937623c0fc697f67964 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 23 May 2024 10:27:34 -0700 Subject: [PATCH 19/20] [Core] Install SkyPilot runtime in separate env (#3575) * Quote the command correctly when source_bashrc is not set * Remove unnecessary source bashrc * format * Fix setup script for conda * Add comment * format * Separate env for skypilot * add test smoke * add system site-packages * add test for default to non-base conda env * Fix controllers and ray node providers * move activate to maybe_skylet * Make axolotl example work for kubernetes * fix axolotl * add test for 3.12 * format * Fix docker PATH * format * add axolotl image in test * address comments * revert grpcio version as it is only installed in our runtime env * refactor command for env 
set up * switch to curl as CentOS may not have wget installed but have curl * add l4 in command * fix dependency for test * fix python path for ray executable * Fix azure launch * add comments * fix test * fix smoke * fix name * fix * fix usage * fix usage for accelerators * fix event * fix name * fix * address comments --- llm/axolotl/axolotl-docker.yaml | 29 +++++++++++++ llm/axolotl/axolotl-spot.yaml | 22 ++-------- llm/axolotl/axolotl.yaml | 19 ++------- llm/axolotl/mistral/qlora-checkpoint.yaml | 3 +- llm/axolotl/mistral/qlora.yaml | 3 +- sky/backends/backend_utils.py | 9 +++- sky/jobs/core.py | 1 - sky/provision/docker_utils.py | 16 +++++++- sky/provision/instance_setup.py | 5 ++- sky/provision/kubernetes/instance.py | 13 +++--- sky/skylet/attempt_skylet.py | 3 ++ sky/skylet/constants.py | 41 ++++++++++++++----- sky/skylet/events.py | 15 ++++++- sky/templates/azure-ray.yml.j2 | 4 +- sky/templates/ibm-ray.yml.j2 | 4 +- sky/templates/jobs-controller.yaml.j2 | 3 ++ sky/templates/lambda-ray.yml.j2 | 4 +- sky/templates/oci-ray.yml.j2 | 4 +- sky/templates/scp-ray.yml.j2 | 4 +- sky/templates/sky-serve-controller.yaml.j2 | 5 +++ sky/usage/usage_lib.py | 6 +++ sky/utils/controller_utils.py | 6 ++- tests/conftest.py | 2 +- tests/kubernetes/README.md | 8 +++- tests/skyserve/cancel/cancel.yaml | 2 + tests/test_smoke.py | 38 ++++++++++++----- .../different_default_conda_env.yaml | 11 +++++ 27 files changed, 195 insertions(+), 85 deletions(-) create mode 100644 llm/axolotl/axolotl-docker.yaml create mode 100644 tests/test_yamls/different_default_conda_env.yaml diff --git a/llm/axolotl/axolotl-docker.yaml b/llm/axolotl/axolotl-docker.yaml new file mode 100644 index 00000000000..b883ebdde46 --- /dev/null +++ b/llm/axolotl/axolotl-docker.yaml @@ -0,0 +1,29 @@ +# Usage: +# HF_TOKEN=abc sky launch -c axolotl axolotl.yaml --env HF_TOKEN -y -i30 --down + +name: axolotl + +resources: + accelerators: L4:1 + cloud: gcp # optional + +workdir: mistral + +setup: | + docker pull winglian/axolotl:main-py3.10-cu118-2.0.1 + +run: | + docker run --gpus all \ + -v ~/sky_workdir:/sky_workdir \ + -v /root/.cache:/root/.cache \ + winglian/axolotl:main-py3.10-cu118-2.0.1 \ + huggingface-cli login --token ${HF_TOKEN} + + docker run --gpus all \ + -v ~/sky_workdir:/sky_workdir \ + -v /root/.cache:/root/.cache \ + winglian/axolotl:main-py3.10-cu118-2.0.1 \ + accelerate launch -m axolotl.cli.train /sky_workdir/qlora.yaml + +envs: + HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. 
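The docker-based YAML above still wraps each command in `docker run`; the diffs below simplify the spot and on-demand variants by setting `image_id: docker:...`, so the setup and run commands execute inside the container directly. As an illustration, the `huggingface-cli login --token ${HF_TOKEN}` step is roughly equivalent to this Python sketch (assuming `huggingface_hub` is installed and `HF_TOKEN` is exported):

import os

import huggingface_hub

# Non-interactive login using the token passed via `--env HF_TOKEN`.
huggingface_hub.login(token=os.environ['HF_TOKEN'])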
diff --git a/llm/axolotl/axolotl-spot.yaml b/llm/axolotl/axolotl-spot.yaml index 942f4ccc4ba..8970737483d 100644 --- a/llm/axolotl/axolotl-spot.yaml +++ b/llm/axolotl/axolotl-spot.yaml @@ -12,6 +12,7 @@ resources: accelerators: A100:1 cloud: gcp # optional use_spot: True + image_id: docker:winglian/axolotl:main-py3.10-cu118-2.0.1 workdir: mistral @@ -20,29 +21,12 @@ file_mounts: name: ${BUCKET} mode: MOUNT -setup: | - docker pull winglian/axolotl:main-py3.10-cu118-2.0.1 - run: | - docker run --gpus all \ - -v ~/sky_workdir:/sky_workdir \ - -v /root/.cache:/root/.cache \ - winglian/axolotl:main-py3.10-cu118-2.0.1 \ - huggingface-cli login --token ${HF_TOKEN} + huggingface-cli login --token ${HF_TOKEN} - docker run --gpus all \ - -v ~/sky_workdir:/sky_workdir \ - -v /root/.cache:/root/.cache \ - -v /sky-notebook:/sky-notebook \ - winglian/axolotl:main-py3.10-cu118-2.0.1 \ - accelerate launch -m axolotl.cli.train /sky_workdir/qlora-checkpoint.yaml + accelerate launch -m axolotl.cli.train qlora-checkpoint.yaml envs: HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. BUCKET: # TODO: Fill with your unique bucket name, or use --env to pass. - - - - - diff --git a/llm/axolotl/axolotl.yaml b/llm/axolotl/axolotl.yaml index 9cec1d1f331..f46588e9aae 100644 --- a/llm/axolotl/axolotl.yaml +++ b/llm/axolotl/axolotl.yaml @@ -5,25 +5,14 @@ name: axolotl resources: accelerators: L4:1 - cloud: gcp # optional + image_id: docker:winglian/axolotl:main-py3.10-cu118-2.0.1 workdir: mistral -setup: | - docker pull winglian/axolotl:main-py3.10-cu118-2.0.1 - run: | - docker run --gpus all \ - -v ~/sky_workdir:/sky_workdir \ - -v /root/.cache:/root/.cache \ - winglian/axolotl:main-py3.10-cu118-2.0.1 \ - huggingface-cli login --token ${HF_TOKEN} - - docker run --gpus all \ - -v ~/sky_workdir:/sky_workdir \ - -v /root/.cache:/root/.cache \ - winglian/axolotl:main-py3.10-cu118-2.0.1 \ - accelerate launch -m axolotl.cli.train /sky_workdir/qlora.yaml + huggingface-cli login --token ${HF_TOKEN} + + accelerate launch -m axolotl.cli.train qlora.yaml envs: HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass. diff --git a/llm/axolotl/mistral/qlora-checkpoint.yaml b/llm/axolotl/mistral/qlora-checkpoint.yaml index 278a5d72b9a..1f1cc67446c 100644 --- a/llm/axolotl/mistral/qlora-checkpoint.yaml +++ b/llm/axolotl/mistral/qlora-checkpoint.yaml @@ -71,6 +71,7 @@ warmup_steps: 10 eval_steps: 0.05 eval_table_size: eval_table_max_new_tokens: 128 +eval_sample_packing: false save_steps: 2 ## increase based on your dataset save_strategy: steps debug: @@ -81,4 +82,4 @@ fsdp_config: special_tokens: bos_token: "" eos_token: "" - unk_token: "" \ No newline at end of file + unk_token: "" diff --git a/llm/axolotl/mistral/qlora.yaml b/llm/axolotl/mistral/qlora.yaml index 42c3742b52d..39b2c55b1ce 100644 --- a/llm/axolotl/mistral/qlora.yaml +++ b/llm/axolotl/mistral/qlora.yaml @@ -69,6 +69,7 @@ warmup_steps: 10 eval_steps: 0.05 eval_table_size: eval_table_max_new_tokens: 128 +eval_sample_packing: false save_steps: debug: deepspeed: @@ -78,4 +79,4 @@ fsdp_config: special_tokens: bos_token: "" eos_token: "" - unk_token: "" \ No newline at end of file + unk_token: "" diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index cf43cfdf2ed..b1598c7c039 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -925,7 +925,14 @@ def write_cluster_config( 'dump_port_command': dump_port_command, # Sky-internal constants. 
            'sky_ray_cmd': constants.SKY_RAY_CMD,
-            'sky_pip_cmd': constants.SKY_PIP_CMD,
+            # pip install needs to have the python env activated to make sure
+            # installed packages are within the env path.
+            'sky_pip_cmd': f'{constants.SKY_PIP_CMD}',
+            # Activate the SkyPilot runtime environment when starting the ray
+            # cluster, so that the ray autoscaler can access the cloud SDKs
+            # and CLIs on the remote node.
+            'sky_activate_python_env':
+                constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV,
            'ray_version': constants.SKY_REMOTE_RAY_VERSION,
            # Command for waiting ray cluster to be ready on head.
            'ray_head_wait_initialized_command':
diff --git a/sky/jobs/core.py b/sky/jobs/core.py
index ff9953489d5..7f9e0d757ea 100644
--- a/sky/jobs/core.py
+++ b/sky/jobs/core.py
@@ -98,7 +98,6 @@ def launch(
         'dag_name': dag.name,
         'retry_until_up': retry_until_up,
         'remote_user_config_path': remote_user_config_path,
-        'sky_python_cmd': skylet_constants.SKY_PYTHON_CMD,
         'modified_catalogs':
             service_catalog_common.get_modified_catalog_file_mounts(),
         **controller_utils.shared_controller_vars_to_fill(
diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py
index 10ae5dafc07..b9ed689fdaf 100644
--- a/sky/provision/docker_utils.py
+++ b/sky/provision/docker_utils.py
@@ -15,6 +15,17 @@
 DOCKER_PERMISSION_DENIED_STR = ('permission denied while trying to connect to '
                                 'the Docker daemon socket')
 
+# Configure environment variables. A docker image can have environment
+# variables set in the Dockerfile with `ENV`. We need to export these
+# variables to the shell environment, so that our ssh session can access them.
+SETUP_ENV_VARS_CMD = (
+    'prefix_cmd() '
+    '{ if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } && '
+    'printenv | while IFS=\'=\' read -r key value; do echo "export $key=\\\"$value\\\""; done > '  # pylint: disable=line-too-long
+    '~/container_env_var.sh && '
+    '$(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh'
+)
+
 
 @dataclasses.dataclass
 class DockerLoginConfig:
@@ -244,6 +255,8 @@ def initialize(self) -> str:
             self._run(start_command)
 
         # SkyPilot: Setup Commands.
+        # TODO(zhwu): the following setups should be aligned with the kubernetes
+        # pod setup, like provision.kubernetes.instance::_set_env_vars_in_pods
         # TODO(tian): These setup commands assumed that the container is
         # debian-based. We should make it more general.
         # Most of docker images are using root as default user, so we set an
@@ -296,7 +309,8 @@ def initialize(self) -> str:
             'mkdir -p ~/.ssh;'
             'cat /tmp/host_ssh_authorized_keys >> ~/.ssh/authorized_keys;'
             'sudo service ssh start;'
-            'sudo sed -i "s/mesg n/tty -s \&\& mesg n/" ~/.profile;',
+            'sudo sed -i "s/mesg n/tty -s \&\& mesg n/" ~/.profile;'
+            f'{SETUP_ENV_VARS_CMD}',
             run_env='docker')
         # SkyPilot: End of Setup Commands.
diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py
index 2e07f026616..c81ecd78db4 100644
--- a/sky/provision/instance_setup.py
+++ b/sky/provision/instance_setup.py
@@ -61,7 +61,10 @@
     'done;')
 
 # Restart skylet when the version does not match to keep the skylet up-to-date.
-MAYBE_SKYLET_RESTART_CMD = (f'{constants.SKY_PYTHON_CMD} -m '
+# We need to activate the python environment to make sure autostop in skylet
+# can find the cloud SDK/CLI in PATH.
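# The activate-then-run pattern assembled below can be sketched in plain
# Python as follows (hypothetical helper; the venv path is an assumption):
import subprocess

ACTIVATE_SKY_ENV = 'source ~/skypilot-runtime/bin/activate'


def run_in_sky_env(cmd: str) -> int:
    # Activate the runtime venv first so skylet (and its autostop handler)
    # finds the cloud SDKs/CLIs installed there on PATH.
    full_cmd = f'{ACTIVATE_SKY_ENV}; {cmd}'
    return subprocess.run(full_cmd, shell=True, executable='/bin/bash',
                          check=False).returncode


# run_in_sky_env('python -m sky.skylet.attempt_skylet')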
+MAYBE_SKYLET_RESTART_CMD = (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV}; ' + f'{constants.SKY_PYTHON_CMD} -m ' 'sky.skylet.attempt_skylet;') diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 9068079701f..4f88293525f 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -10,6 +10,7 @@ from sky import status_lib from sky.adaptors import kubernetes from sky.provision import common +from sky.provision import docker_utils from sky.provision.kubernetes import config as config_lib from sky.provision.kubernetes import utils as kubernetes_utils from sky.utils import common_utils @@ -241,7 +242,7 @@ def _wait_for_pods_to_run(namespace, new_nodes): 'the node. Error details: ' f'{container_status.state.waiting.message}.') # Reaching this point means that one of the pods had an issue, - # so break out of the loop + # so break out of the loop, and wait until next second. break if all_pods_running: @@ -301,13 +302,7 @@ def _set_env_vars_in_pods(namespace: str, new_pods: List): set_k8s_env_var_cmd = [ '/bin/sh', '-c', - ( - 'prefix_cmd() ' - '{ if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } && ' - 'printenv | while IFS=\'=\' read -r key value; do echo "export $key=\\\"$value\\\""; done > ' # pylint: disable=line-too-long - '~/k8s_env_var.sh && ' - 'mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh || ' - '$(prefix_cmd) mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh') + docker_utils.SETUP_ENV_VARS_CMD, ] for new_pod in new_pods: @@ -540,6 +535,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str, _wait_for_pods_to_schedule(namespace, wait_pods, provision_timeout) # Wait until the pods and their containers are up and running, and # fail early if there is an error + logger.debug(f'run_instances: waiting for pods to be running (pulling ' + f'images): {list(wait_pods_dict.keys())}') _wait_for_pods_to_run(namespace, wait_pods) logger.debug(f'run_instances: all pods are scheduled and running: ' f'{list(wait_pods_dict.keys())}') diff --git a/sky/skylet/attempt_skylet.py b/sky/skylet/attempt_skylet.py index 609cfa09141..54df4986080 100644 --- a/sky/skylet/attempt_skylet.py +++ b/sky/skylet/attempt_skylet.py @@ -21,6 +21,9 @@ def restart_skylet(): shell=True, check=False) subprocess.run( + # We have made sure that `attempt_skylet.py` is executed with the + # skypilot runtime env activated, so that skylet can access the cloud + # CLI tools. f'nohup {constants.SKY_PYTHON_CMD} -m sky.skylet.skylet' ' >> ~/.sky/skylet.log 2>&1 &', shell=True, diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 0f2d7540007..0c68fd7f6e6 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -37,8 +37,18 @@ SKY_PYTHON_CMD = f'$({SKY_GET_PYTHON_PATH_CMD})' SKY_PIP_CMD = f'{SKY_PYTHON_CMD} -m pip' # Ray executable, e.g., /opt/conda/bin/ray -SKY_RAY_CMD = (f'$([ -s {SKY_RAY_PATH_FILE} ] && ' +# We need to add SKY_PYTHON_CMD before ray executable because: +# The ray executable is a python script with a header like: +# #!/opt/conda/bin/python3 +# When we create the skypilot-runtime venv, the previously installed ray +# executable will be reused (due to --system-site-packages), and that will cause +# running ray CLI commands to use the wrong python executable. +SKY_RAY_CMD = (f'{SKY_PYTHON_CMD} $([ -s {SKY_RAY_PATH_FILE} ] && ' f'cat {SKY_RAY_PATH_FILE} 2> /dev/null || which ray)') +# Separate env for SkyPilot runtime dependencies. 
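# SKY_RAY_CMD above prefixes the ray console script with SKY_PYTHON_CMD
# because a console script's shebang pins it to the interpreter that
# installed it, and the --system-site-packages venv defined by the constants
# that follow reuses the previously installed script. A self-contained
# illustration of the workaround (using sys.executable is an assumption):
import shutil
import subprocess
import sys

ray_script = shutil.which('ray')  # e.g. '/opt/conda/bin/ray'
if ray_script is not None:
    # Invoke ray under an explicit interpreter instead of its shebang.
    subprocess.run([sys.executable, ray_script, '--version'], check=False)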
+SKY_REMOTE_PYTHON_ENV_NAME = 'skypilot-runtime'
+SKY_REMOTE_PYTHON_ENV = f'~/{SKY_REMOTE_PYTHON_ENV_NAME}'
+ACTIVATE_SKY_REMOTE_PYTHON_ENV = f'source {SKY_REMOTE_PYTHON_ENV}/bin/activate'
 
 # The name for the environment variable that stores the unique ID of the
 # current task. This will stay the same across multiple recoveries of the
@@ -91,20 +101,27 @@
 # AWS's Deep Learning AMI's default conda environment.
 CONDA_INSTALLATION_COMMANDS = (
     'which conda > /dev/null 2>&1 || '
-    '{ wget -nc https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -O Miniconda3-Linux-x86_64.sh && '  # pylint: disable=line-too-long
+    '{ curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && '  # pylint: disable=line-too-long
     'bash Miniconda3-Linux-x86_64.sh -b && '
     'eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && '
     'conda config --set auto_activate_base true && '
-    # Use $(echo ~) instead of ~ to avoid the error "no such file or directory".
-    # Also, not using $HOME to avoid the error HOME variable not set.
-    f'echo "$(echo ~)/miniconda3/bin/python" > {SKY_PYTHON_PATH_FILE}; }}; '
+    f'conda activate base; }}; '
     'grep "# >>> conda initialize >>>" ~/.bashrc || '
     '{ conda init && source ~/.bashrc; };'
-    '(type -a python | grep -q python3) || '
-    'echo \'alias python=python3\' >> ~/.bashrc;'
-    '(type -a pip | grep -q pip3) || echo \'alias pip=pip3\' >> ~/.bashrc;'
-    # Writes Python path to file if it does not exist or the file is empty.
-    f'[ -s {SKY_PYTHON_PATH_FILE} ] || which python3 > {SKY_PYTHON_PATH_FILE};')
+    # If the Python version is greater than or equal to 3.12, create a new
+    # conda env with Python 3.10.
+    # We don't use a separate conda env for SkyPilot dependencies because it
+    # is costly to create a new conda env; venv is a lightweight and faster
+    # alternative when the python version satisfies the requirement.
+    '[[ $(python3 --version | cut -d " " -f 2 | cut -d "." -f 2) -ge 12 ]] && '
+    f'echo "Creating conda env with Python 3.10" && '
+    f'conda create -y -n {SKY_REMOTE_PYTHON_ENV_NAME} python=3.10 && '
+    f'conda activate {SKY_REMOTE_PYTHON_ENV_NAME};'
+    # Create a separate virtual environment (venv) for SkyPilot dependencies.
+    f'[ -d {SKY_REMOTE_PYTHON_ENV} ] || '
+    f'{{ {SKY_PYTHON_CMD} -m venv {SKY_REMOTE_PYTHON_ENV} --system-site-packages && '
+    f'echo "$(echo {SKY_REMOTE_PYTHON_ENV})/bin/python" > {SKY_PYTHON_PATH_FILE}; }};'
+)
 
 _sky_version = str(version.parse(sky.__version__))
 RAY_STATUS = f'RAY_ADDRESS=127.0.0.1:{SKY_REMOTE_RAY_PORT} {SKY_RAY_CMD} status'
@@ -142,7 +159,9 @@
     # mentioned above are resolved.
     'export PATH=$PATH:$HOME/.local/bin; '
     # Writes ray path to file if it does not exist or the file is empty.
- f'[ -s {SKY_RAY_PATH_FILE} ] || which ray > {SKY_RAY_PATH_FILE}; ' + f'[ -s {SKY_RAY_PATH_FILE} ] || ' + f'{{ {ACTIVATE_SKY_REMOTE_PYTHON_ENV} && ' + f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ' # END ray package check and installation f'{{ {SKY_PIP_CMD} list | grep "skypilot " && ' '[ "$(cat ~/.sky/wheels/current_sky_wheel_hash)" == "{sky_wheel_hash}" ]; } || ' # pylint: disable=line-too-long diff --git a/sky/skylet/events.py b/sky/skylet/events.py index c63b42cc438..b6e99707dab 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -3,7 +3,6 @@ import os import re import subprocess -import sys import time import traceback @@ -193,7 +192,10 @@ def _stop_cluster(self, autostop_config): # Passing env inherited from os.environ is technically not # needed, because we call `python