def setup_do_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
    """Set up SSH authentication for a DigitalOcean cluster config.

    Reads the public key stored at ``PUBLIC_SSH_KEY_PATH``, hands it to
    ``do_utils.get_or_create_ssh_keys`` and records the key path in the
    cluster config before delegating to ``configure_ssh_info``.

    Args:
        config: Cluster config; ``config['auth']['ssh_public_key']`` is
            overwritten with ``PUBLIC_SSH_KEY_PATH``.

    Returns:
        Whatever ``configure_ssh_info`` returns for the updated config.
    """
    # Called before the key file is read; presumably this creates the key
    # pair when it is missing -- confirm against get_or_generate_keys().
    get_or_generate_keys()

    key_file = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
    with open(key_file, 'r', encoding='utf-8') as key_stream:
        key_contents = key_stream.read()

    do_utils.get_or_create_ssh_keys(key_contents)
    config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
    return configure_ssh_info(config)
def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
    """Instantiate ``cloud_cls`` and store it under its lowercased name.

    The class is returned unchanged, so the method can be applied as a
    class decorator.

    Args:
        cloud_cls: Class to register; it is called with no arguments and
            the resulting instance is stored in this mapping.

    Returns:
        ``cloud_cls`` itself.

    Raises:
        AssertionError: If an entry with the same lowercased class name
            is already present.
    """
    name = cloud_cls.__name__.lower()
    if name in self:
        # Raised explicitly rather than through ``assert`` so that the
        # duplicate check is not stripped when running under ``python -O``.
        raise AssertionError(f'{name} already registered')
    self[name] = cloud_cls()
    return cloud_cls
f'{constants.DISK_SIZES}, ' - f'upsizing from {config.node_config["DiskSize"]} ' - f'to {possible_size}') - config.node_config['DiskSize'] = possible_size - break - - client = utils.DigitalOceanCloudClient() - network_id = client.setup_network(cluster_name, region)['id'] - config.node_config['NetworkId'] = network_id - - # Add pubkey to machines via startup - public_key = config.authentication_config['ssh_public_key'] - client.set_sky_key_script(public_key) - - return config diff --git a/sky/provision/do/instance.py b/sky/provision/do/instance.py index 8660dbfe17c..04cabf8fb3e 100644 --- a/sky/provision/do/instance.py +++ b/sky/provision/do/instance.py @@ -1,5 +1,6 @@ """DigitalOcean instance provisioning.""" +import pydo import time from typing import Any, Dict, List, Optional @@ -21,14 +22,13 @@ def _filter_instances(cluster_name_on_cloud: str, status_filters: Optional[List[str]]) -> Dict[str, Any]: - client = utils.DigitalOceanCloudClient() - instances = client.list_instances() + response = utils.client.droplets.list() possible_names = [ f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker', ] - filtered_instances = {} + filtered_instances: Dict[str, Any] = {} for instance in instances: instance_id = instance['id'] if status_filters is not None and instance[ diff --git a/sky/provision/do/utils.py b/sky/provision/do/utils.py index a3a8178372b..688ac3ff1ac 100644 --- a/sky/provision/do/utils.py +++ b/sky/provision/do/utils.py @@ -2,250 +2,70 @@ import json import os +import pydo import requests import time import yaml from typing import Any, Dict, List, Optional, Union - +from azure.core.pipeline import policies from sky import sky_logging from sky.utils import common_utils logger = sky_logging.init_logger(__name__) -API_ENDPOINT = 'https://api.digitalocean.com/v2' -INITIAL_BACKOFF_SECONDS = 10 -MAX_BACKOFF_FACTOR = 10 -MAX_ATTEMPTS = 6 +MAX_BACKOFF_FACTOR: int = 10 +MAX_ATTEMPTS: int = 6 -POSSIBLE_CREDENTIALS = [ +# locations set by doctl CLI 
POSSIBLE_CREDENTIALS: List[str] = [
    '~/Library/Application Support/doctl/config.yaml',  # Mac OS
    os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config'),
                 'doctl/config.yaml'),  # Linux
]
_APPDATA = os.getenv('APPDATA')
if _APPDATA:
    # Windows. The components are passed separately: a segment that starts
    # with a backslash is rooted on Windows and makes os.path.join() throw
    # away the %APPDATA% prefix. The entry is skipped when APPDATA is unset
    # so that no path relative to the working directory is probed.
    POSSIBLE_CREDENTIALS.append(
        os.path.join(_APPDATA, 'doctl', 'config.yaml'))

SSH_KEY_NAME = f'skypilot-ssh-{common_utils.get_user_hash()}'
# Module-level pydo client; stays None until init_client() succeeds.
client: Optional[Any] = None


class DigitalOceanCloudError(Exception):
    """Error type for DigitalOcean-related failures."""


def init_client() -> None:
    """Create the module-level ``client`` from the doctl config file.

    The first existing path in ``POSSIBLE_CREDENTIALS`` is parsed as YAML
    and the token stored under ``auth-contexts`` -> ``skypilot`` is used to
    build a ``pydo.Client`` with a retry policy configured from
    ``MAX_ATTEMPTS`` and ``MAX_BACKOFF_FACTOR``.

    Raises:
        ValueError: If none of the candidate config files exists, or the
            config has no ``skypilot`` entry under ``auth-contexts``.
    """
    global client

    credentials_path = None
    for candidate in POSSIBLE_CREDENTIALS:
        candidate = os.path.expanduser(candidate)
        if os.path.exists(candidate):
            credentials_path = candidate
            break

    if credentials_path is None:
        raise ValueError(
            f'no valid DigitalOcean config found from {POSSIBLE_CREDENTIALS}')
    logger.debug(f'using DigitalOcean config located at {credentials_path}')
    with open(credentials_path, 'r', encoding='utf-8') as f:
        credentials: Dict[str, Any] = yaml.safe_load(f)
    try:
        api_token = credentials['auth-contexts']['skypilot']
    except (KeyError, TypeError) as e:
        # TypeError covers an empty config file (safe_load returns None)
        # and an 'auth-contexts' value that is not a mapping.
        raise ValueError(
            f'DigitalOcean config {credentials_path} has no "skypilot" '
            'entry under "auth-contexts"') from e
    client = pydo.Client(token=api_token,
                         retry_policy=policies.RetryPolicy(
                             retry_total=MAX_ATTEMPTS,
                             retry_backoff_factor=MAX_BACKOFF_FACTOR))


def get_or_create_ssh_keys(public_key: str) -> None:
    """Ensure an SSH key named ``SSH_KEY_NAME`` exists on the account.

    Args:
        public_key: Public key material uploaded when no key with the
            expected name is found.

    Raises:
        ValueError: If the client has to be initialized here and no usable
            DigitalOcean config is available.
    """
    if client is None:
        # Import-time initialization may have been skipped (no config yet).
        init_client()

    # NOTE(review): list() is called without pagination arguments; confirm
    # that a key beyond the first page of results cannot be missed.
    ssh_keys = client.ssh_keys.list()
    if any(key['name'] == SSH_KEY_NAME for key in ssh_keys['ssh_keys']):
        return

    # The create operation lives on the client instance, not on the pydo
    # module itself.
    client.ssh_keys.create(body={
        'public_key': public_key,
        'name': SSH_KEY_NAME,
    })


# Initialize eagerly when credentials are available so existing users of
# ``utils.client`` keep working, but do not let a missing or incomplete
# doctl config make importing this module fail.
try:
    init_client()
except ValueError as e:
    logger.debug(f'DigitalOcean client not initialized at import time: {e}')