Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Docker] Support docker login on RunPod. #4287

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/getting-started/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ RunPod

.. code-block:: shell

pip install "runpod>=1.5.1"
pip install "runpod>=1.6.1"
runpod config


Expand Down
11 changes: 11 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,17 @@ def _retry_zones(
config_dict['provision_record'] = provision_record
config_dict['resources_vars'] = resources_vars
config_dict['handle'] = handle
if provision_record.ephemeral_resources:
# Some ephemeral resources are created during the launch
# process. Add them to the provider config so that they
# can be cleaned up later.
original_config_content = common_utils.read_yaml(
cluster_config_file)
original_config_content['provider'][
'ephemeral_resources'] = (
provision_record.ephemeral_resources)
common_utils.dump_yaml(cluster_config_file,
original_config_content)
return config_dict
except provision_common.StopFailoverError:
with ux_utils.print_exception_no_traceback():
Expand Down
1 change: 1 addition & 0 deletions sky/provision/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ProvisionRecord:
resumed_instance_ids: List[InstanceId]
# The IDs of all just created instances.
created_instance_ids: List[InstanceId]
ephemeral_resources: List[Any] = dataclasses.field(default_factory=list)

def is_instance_just_booted(self, instance_id: InstanceId) -> bool:
"""Whether or not the instance is just booted.
Expand Down
11 changes: 8 additions & 3 deletions sky/provision/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ class DockerLoginConfig:
password: str
server: str

def format_image(self, image: str) -> str:
    """Format the image name with the server prefix.

    Args:
        image: The image name, with or without the registry server prefix.

    Returns:
        The image name prefixed with ``<server>/``. The image is returned
        unchanged if it already starts with that prefix or if no server is
        configured.
    """
    # Normalize a trailing slash so that a server such as 'registry.io/'
    # neither yields 'registry.io//image' nor fails to match an image that
    # already carries the 'registry.io/' prefix.
    server = self.server.rstrip('/')
    if not server:
        # Nothing to prepend; avoid producing an invalid '/image' reference.
        return image
    server_prefix = f'{server}/'
    if image.startswith(server_prefix):
        return image
    return f'{server_prefix}{image}'

@classmethod
def from_env_vars(cls, d: Dict[str, str]) -> 'DockerLoginConfig':
return cls(
Expand Down Expand Up @@ -220,9 +227,7 @@ def initialize(self) -> str:
wait_for_docker_daemon=True)
# We automatically add the server prefix to the image name if
# the user did not add it.
server_prefix = f'{docker_login_config.server}/'
if not specific_image.startswith(server_prefix):
specific_image = f'{server_prefix}{specific_image}'
specific_image = docker_login_config.format_image(specific_image)

if self.docker_config.get('pull_before_run', True):
assert specific_image, ('Image must be included in config if ' +
Expand Down
23 changes: 19 additions & 4 deletions sky/provision/runpod/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ def run_instances(region: str, cluster_name_on_cloud: str,
created_instance_ids=[])

created_instance_ids = []
ephemeral_resources = []
for _ in range(to_start_count):
node_type = 'head' if head_instance_id is None else 'worker'
try:
instance_id = utils.launch(
name=f'{cluster_name_on_cloud}-{node_type}',
instance_id, ers = utils.launch(
cluster_name=cluster_name_on_cloud,
node_type=node_type,
instance_type=config.node_config['InstanceType'],
region=region,
disk_size=config.node_config['DiskSize'],
Expand All @@ -92,7 +94,12 @@ def run_instances(region: str, cluster_name_on_cloud: str,
public_key=config.node_config['PublicKey'],
preemptible=config.node_config['Preemptible'],
bid_per_gpu=config.node_config['BidPerGPU'],
docker_login_config=config.provider_config.get(
'docker_login_config'),
)
for er in ers:
if er is not None:
ephemeral_resources.append(er)
except Exception as e: # pylint: disable=broad-except
logger.warning(f'run_instances error: {e}')
raise
Expand Down Expand Up @@ -121,7 +128,8 @@ def run_instances(region: str, cluster_name_on_cloud: str,
zone=None,
head_instance_id=head_instance_id,
resumed_instance_ids=[],
created_instance_ids=created_instance_ids)
created_instance_ids=created_instance_ids,
ephemeral_resources=ephemeral_resources)


def wait_instances(region: str, cluster_name_on_cloud: str,
Expand All @@ -143,7 +151,8 @@ def terminate_instances(
worker_only: bool = False,
) -> None:
"""See sky/provision/__init__.py"""
del provider_config # unused
assert provider_config is not None, (cluster_name_on_cloud, provider_config)
ephemeral_resources = provider_config.get('ephemeral_resources', [])
instances = _filter_instances(cluster_name_on_cloud, None)
for inst_id, inst in instances.items():
logger.debug(f'Terminating instance {inst_id}: {inst}')
Expand All @@ -157,6 +166,12 @@ def terminate_instances(
f'Failed to terminate instance {inst_id}: '
f'{common_utils.format_exception(e, use_bracket=False)}'
) from e
if ephemeral_resources:
# See sky/provision/runpod/utils.py::launch for details
assert len(ephemeral_resources) == 2, ephemeral_resources
template_name, registry_auth_id = ephemeral_resources
utils.delete_pod_template(template_name)
utils.delete_register_auth(registry_auth_id)


def get_cluster_info(
Expand Down
100 changes: 87 additions & 13 deletions sky/provision/runpod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import base64
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from sky import sky_logging
from sky.adaptors import runpod
from sky.provision import docker_utils
import sky.provision.runpod.api.commands as runpod_commands
from sky.skylet import constants
from sky.utils import common_utils
Expand Down Expand Up @@ -100,14 +101,82 @@ def list_instances() -> Dict[str, Dict[str, Any]]:
return instance_dict


def launch(name: str, instance_type: str, region: str, disk_size: int,
image_name: str, ports: Optional[List[int]], public_key: str,
preemptible: Optional[bool], bid_per_gpu: float) -> str:
def delete_pod_template(template_name: str) -> None:
    """Deletes a pod template by name.

    Deletion is best effort: a RunPod ``QueryError`` is logged as a warning
    and not re-raised.

    Args:
        template_name: The name of the template to delete.
    """
    try:
        runpod.runpod.api.graphql.run_graphql_query(
            f'mutation {{deleteTemplate(templateName: "{template_name}")}}')
    except runpod.runpod.error.QueryError as e:
        # Keep an explicit separator between the error text and the hint;
        # adjacent string literals are concatenated without any whitespace.
        logger.warning(f'Failed to delete template {template_name}: {e}. '
                       'Please delete it manually.')


def delete_register_auth(registry_auth_id: str) -> None:
    """Deletes a container registry auth by ID.

    Deletion is best effort: a RunPod ``QueryError`` is logged as a warning
    and not re-raised.

    Args:
        registry_auth_id: The ID of the registry auth to delete.
    """
    try:
        runpod.runpod.delete_container_registry_auth(registry_auth_id)
    except runpod.runpod.error.QueryError as e:
        # Keep an explicit separator between the error text and the hint;
        # adjacent string literals are concatenated without any whitespace.
        logger.warning(
            f'Failed to delete registry auth {registry_auth_id}: {e}. '
            'Please delete it manually.')


def _create_template_for_docker_login(
    cluster_name: str,
    image_name: str,
    docker_login_config: Optional[Dict[str, str]],
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
    """Creates a template for the given image with the docker login config.

    Args:
        cluster_name: The cluster name, used to derive the names of the
            registry auth and the template.
        image_name: The docker image name requested by the user.
        docker_login_config: Keyword arguments for
            `docker_utils.DockerLoginConfig`, or None if no docker login is
            required.

    Returns:
        formatted_image_name: The formatted image name.
        # following fields are None for no docker login config.
        template_id: The template ID.
        template_name: The template name.
        registry_auth_id: The registry auth ID.

    Raises:
        Any exception raised while creating the registry auth or the
        template. If template creation fails, the registry auth created by
        this call is deleted (best effort) before the exception propagates.
    """
    if docker_login_config is None:
        return image_name, None, None, None
    login_config = docker_utils.DockerLoginConfig(**docker_login_config)
    container_registry_auth_name = f'{cluster_name}-registry-auth'
    container_template_name = f'{cluster_name}-docker-login-template'
    # The `name` argument is only for display purposes, and the registry
    # server will be split from the docker image name (tested with AWS ECR).
    # Here we only need the username and password to create the registry auth.
    # TODO(tian): RunPod python API does not provide a way to get the registry
    # auth and template ID by the name, and the only way to get the ID is when
    # we create it. So we use a separate auth and template per cluster. This
    # also assumes that every cluster has only one node, so no extra worker
    # nodes will reuse the same auth and template name.
    create_auth_resp = runpod.runpod.create_container_registry_auth(
        name=container_registry_auth_name,
        username=login_config.username,
        password=login_config.password,
    )
    registry_auth_id = create_auth_resp['id']
    try:
        create_template_resp = runpod.runpod.create_template(
            name=container_template_name,
            image_name=None,
            registry_auth_id=registry_auth_id,
        )
        template_id = create_template_resp['id']
    except Exception:  # pylint: disable=broad-except
        # The auth ID is only known here (see the TODO above). If we fail
        # before returning it to the caller, nobody can clean it up later,
        # so remove the credential-bearing registry auth now.
        delete_register_auth(registry_auth_id)
        raise
    return (login_config.format_image(image_name), template_id,
            container_template_name, registry_auth_id)


def launch(
cluster_name: str, node_type: str, instance_type: str, region: str,
disk_size: int, image_name: str, ports: Optional[List[int]],
public_key: str, preemptible: Optional[bool], bid_per_gpu: float,
docker_login_config: Optional[Dict[str, str]]) -> Tuple[str, List[Any]]:
"""Launches an instance with the given parameters.

Converts the instance_type to the RunPod GPU name, finds the specs for the
GPU, and launches the instance.

Returns:
instance_id: The instance ID.
ephemeral_resources: A list of ephemeral resources.
"""
name = f'{cluster_name}-{node_type}'
gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]]
gpu_quantity = int(instance_type.split('_')[0].replace('x', ''))
cloud_type = instance_type.split('_')[2]
Expand Down Expand Up @@ -139,31 +208,36 @@ def launch(name: str, instance_type: str, region: str, disk_size: int,
# Use base64 to deal with the tricky quoting issues caused by runpod API.
encoded = base64.b64encode(setup_cmd.encode('utf-8')).decode('utf-8')

docker_args = (f'bash -c \'echo {encoded} | base64 --decode > init.sh; '
f'bash init.sh\'')

# Port 8081 is occupied for nginx in the base image.
custom_ports_str = ''
if ports is not None:
custom_ports_str = ''.join([f'{p}/tcp,' for p in ports])
ports_str = (f'22/tcp,'
f'{custom_ports_str}'
f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,'
f'{constants.SKY_REMOTE_RAY_PORT}/http')

docker_args = (f'bash -c \'echo {encoded} | base64 --decode > init.sh; '
f'bash init.sh\'')
ports = (f'22/tcp,'
f'{custom_ports_str}'
f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,'
f'{constants.SKY_REMOTE_RAY_PORT}/http')
image_name_formatted, template_id, template_name, registry_auth_id = (
_create_template_for_docker_login(cluster_name, image_name,
docker_login_config))

params = {
'name': name,
'image_name': image_name,
'image_name': image_name_formatted,
'gpu_type_id': gpu_type,
'cloud_type': cloud_type,
'container_disk_in_gb': disk_size,
'min_vcpu_count': 4 * gpu_quantity,
'min_memory_in_gb': gpu_specs['memoryInGb'] * gpu_quantity,
'gpu_count': gpu_quantity,
'country_code': region,
'ports': ports,
'ports': ports_str,
'support_public_ip': True,
'docker_args': docker_args,
'template_id': template_id,
}

if preemptible is None or not preemptible:
Expand All @@ -174,7 +248,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int,
**params,
)

return new_instance['id']
return new_instance['id'], [template_name, registry_auth_id]


def remove(instance_id: str) -> None:
Expand Down
4 changes: 3 additions & 1 deletion sky/setup_files/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@
'oci': ['oci'] + local_ray,
'kubernetes': ['kubernetes>=20.0.0'],
'remote': remote,
'runpod': ['runpod>=1.5.1'],
# For the container registry auth api. Reference:
# https://github.com/runpod/runpod-python/releases/tag/1.6.1
'runpod': ['runpod>=1.6.1'],
'fluidstack': [], # No dependencies needed for fluidstack
'cudo': ['cudo-compute>=0.1.10'],
'paperspace': [], # No dependencies needed for paperspace
Expand Down
12 changes: 5 additions & 7 deletions sky/skylet/providers/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def docker_start_cmds(
docker_cmd,
):
"""Generating docker start command without --rm.

The code is borrowed from `ray.autoscaler._private.docker`.

Changes we made:
Expand Down Expand Up @@ -159,19 +159,17 @@ def run_init(self, *, as_head: bool, file_mounts: Dict[str, str],
return True

# SkyPilot: Docker login if user specified a private docker registry.
if "docker_login_config" in self.docker_config:
if 'docker_login_config' in self.docker_config:
# TODO(tian): Maybe support a command to get the login password?
docker_login_config: docker_utils.DockerLoginConfig = self.docker_config[
"docker_login_config"]
docker_login_config: docker_utils.DockerLoginConfig = (
self.docker_config['docker_login_config'])
self._run_with_retry(
f'{self.docker_cmd} login --username '
f'{docker_login_config.username} --password '
f'{docker_login_config.password} {docker_login_config.server}')
# We automatically add the server prefix to the image name if
# the user did not add it.
server_prefix = f'{docker_login_config.server}/'
if not specific_image.startswith(server_prefix):
specific_image = f'{server_prefix}{specific_image}'
specific_image = docker_login_config.format_image(specific_image)

if self.docker_config.get('pull_before_run', True):
assert specific_image, ('Image must be included in config if '
Expand Down
13 changes: 13 additions & 0 deletions sky/templates/runpod-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ provider:
module: sky.provision.runpod
region: "{{region}}"
disable_launch_config_check: true
# For RunPod, the image id is set directly on the pod to support docker as the
# runtime environment, so we must keep the DockerInitializer from detecting a
# docker field and performing its own initialization. Therefore, the docker
# login config is placed in the provider config here.
{%- if docker_login_config is not none %}
docker_login_config:
username: |-
{{docker_login_config.username}}
password: |-
{{docker_login_config.password}}
server: |-
{{docker_login_config.server}}
{%- endif %}

auth:
ssh_user: root
Expand Down
Loading