Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
Browse files Browse the repository at this point in the history
…o azure-query-status
  • Loading branch information
Michaelvll committed Jul 1, 2024
2 parents 6c92072 + d3c1f8c commit 08f6789
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 71 deletions.
37 changes: 36 additions & 1 deletion sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

# pylint: disable=import-outside-toplevel

import logging
import os

from sky.adaptors import common
from sky.sky_logging import set_logging_level
from sky.utils import env_options
from sky.utils import ux_utils

Expand All @@ -28,6 +30,33 @@
API_TIMEOUT = 5


def _decorate_methods(obj, decorator):
for attr_name in dir(obj):
attr = getattr(obj, attr_name)
if callable(attr) and not attr_name.startswith('__'):
setattr(obj, attr_name, decorator(attr))
return obj


def _api_logging_decorator(logger: str, level: int):
"""Decorator to set logging level for API calls.
This is used to suppress the verbose logging from urllib3 when calls to the
Kubernetes API timeout.
"""

def decorated_api(api):

def wrapped(*args, **kwargs):
obj = api(*args, **kwargs)
_decorate_methods(obj, set_logging_level(logger, level))
return obj

return wrapped

return decorated_api


def _load_config():
global _configured
if _configured:
Expand Down Expand Up @@ -65,15 +94,16 @@ def _load_config():
_configured = True


@_api_logging_decorator('urllib3', logging.ERROR)
def core_api():
global _core_api
if _core_api is None:
_load_config()
_core_api = kubernetes.client.CoreV1Api()

return _core_api


@_api_logging_decorator('urllib3', logging.ERROR)
def auth_api():
global _auth_api
if _auth_api is None:
Expand All @@ -83,6 +113,7 @@ def auth_api():
return _auth_api


@_api_logging_decorator('urllib3', logging.ERROR)
def networking_api():
global _networking_api
if _networking_api is None:
Expand All @@ -92,6 +123,7 @@ def networking_api():
return _networking_api


@_api_logging_decorator('urllib3', logging.ERROR)
def custom_objects_api():
global _custom_objects_api
if _custom_objects_api is None:
Expand All @@ -101,6 +133,7 @@ def custom_objects_api():
return _custom_objects_api


@_api_logging_decorator('urllib3', logging.ERROR)
def node_api():
global _node_api
if _node_api is None:
Expand All @@ -110,6 +143,7 @@ def node_api():
return _node_api


@_api_logging_decorator('urllib3', logging.ERROR)
def apps_api():
global _apps_api
if _apps_api is None:
Expand All @@ -119,6 +153,7 @@ def apps_api():
return _apps_api


@_api_logging_decorator('urllib3', logging.ERROR)
def api_client():
global _api_client
if _api_client is None:
Expand Down
41 changes: 25 additions & 16 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,29 +439,38 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
f'Key {secret_name} does not exist in the cluster, creating it...')
kubernetes.core_api().create_namespaced_secret(namespace, secret)

ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
private_key_path, _ = get_or_generate_keys()
if network_mode == nodeport_mode:
ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
service_type = kubernetes_enums.KubernetesServiceType.NODEPORT
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
service_type)
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_jump_name,
nodeport_mode,
private_key_path=private_key_path,
namespace=namespace)
elif network_mode == port_forward_mode:
# Using `kubectl port-forward` creates a direct tunnel to the pod and
# does not require a ssh jump pod.
kubernetes_utils.check_port_forward_mode_dependencies()
# Using `kubectl port-forward` creates a direct tunnel to jump pod and
# does not require opening any ports on Kubernetes nodes. As a result,
# the service can be a simple ClusterIP service which we access with
# `kubectl port-forward`.
service_type = kubernetes_enums.KubernetesServiceType.CLUSTERIP
# TODO(romilb): This can be further optimized. Instead of using the
# head node as a jump pod for worker nodes, we can also directly
# set the ssh_target to the worker node. However, that requires
# changes in the downstream code to return a mapping of node IPs to
# pod names (to be used as ssh_target) and updating the upstream
# SSHConfigHelper to use a different ProxyCommand for each pod.
# This optimization can reduce SSH time from ~0.35s to ~0.25s, tested
# on GKE.
ssh_target = config['cluster_name'] + '-head'
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_target, port_forward_mode, private_key_path=private_key_path)
else:
# This should never happen because we check for this in from_str above.
raise ValueError(f'Unsupported networking mode: {network_mode_str}')
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type)

ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
PRIVATE_SSH_KEY_PATH, ssh_jump_name, network_mode, namespace,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_PATH,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_TEMPLATE)

config['auth']['ssh_proxy_command'] = ssh_proxy_cmd

return config
Expand Down
6 changes: 6 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,12 @@ def ssh_credential_from_yaml(
ssh_private_key = auth_section.get('ssh_private_key')
ssh_control_name = config.get('cluster_name', '__default__')
ssh_proxy_command = auth_section.get('ssh_proxy_command')

# Update the ssh_user placeholder in proxy command, if required
if (ssh_proxy_command is not None and
constants.SKY_SSH_USER_PLACEHOLDER in ssh_proxy_command):
ssh_proxy_command = ssh_proxy_command.replace(
constants.SKY_SSH_USER_PLACEHOLDER, ssh_user)
credentials = {
'ssh_user': ssh_user,
'ssh_private_key': ssh_private_key,
Expand Down
5 changes: 4 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3065,7 +3065,10 @@ def _update_after_cluster_provisioned(
)
usage_lib.messages.usage.update_final_cluster_status(
status_lib.ClusterStatus.UP)
auth_config = common_utils.read_yaml(handle.cluster_yaml)['auth']
auth_config = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml,
ssh_user=handle.ssh_user,
docker_user=handle.docker_user)
backend_utils.SSHConfigHelper.add_cluster(handle.cluster_name,
ip_list, auth_config,
ssh_port_list,
Expand Down
4 changes: 1 addition & 3 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class Kubernetes(clouds.Cloud):

SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys'
SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
PORT_FORWARD_PROXY_CMD_TEMPLATE = \
'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
# Timeout for resource provisioning. This timeout determines how long to
# wait for pod to be in pending status before giving up.
# Larger timeout may be required for autoscaling clusters, since autoscaler
Expand Down Expand Up @@ -323,6 +320,7 @@ def make_deploy_resources_variables(
'k8s_namespace':
kubernetes_utils.get_current_kube_config_context_namespace(),
'k8s_port_mode': port_mode.value,
'k8s_networking_mode': network_utils.get_networking_mode().value,
'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
'k8s_acc_label_key': k8s_acc_label_key,
'k8s_acc_label_value': k8s_acc_label_value,
Expand Down
7 changes: 6 additions & 1 deletion sky/provision/kubernetes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from sky.adaptors import kubernetes
from sky.provision import common
from sky.provision.kubernetes import network_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import kubernetes_enums

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +27,10 @@ def bootstrap_instances(

_configure_services(namespace, config.provider_config)

config = _configure_ssh_jump(namespace, config)
networking_mode = network_utils.get_networking_mode(
config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
config = _configure_ssh_jump(namespace, config)

requested_service_account = config.node_config['spec']['serviceAccountName']
if (requested_service_account ==
Expand Down
17 changes: 11 additions & 6 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sky.provision import common
from sky.provision import docker_utils
from sky.provision.kubernetes import config as config_lib
from sky.provision.kubernetes import network_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import command_runner
from sky.utils import common_utils
Expand Down Expand Up @@ -495,14 +496,18 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
if head_pod_name is None:
head_pod_name = pod.metadata.name

# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)
wait_pods_dict = _filter_pods(namespace, tags, ['Pending'])
wait_pods = list(wait_pods_dict.values())
wait_pods.append(jump_pod)

networking_mode = network_utils.get_networking_mode(
config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)
wait_pods.append(jump_pod)
provision_timeout = provider_config['timeout']

wait_str = ('indefinitely'
Expand Down
17 changes: 17 additions & 0 deletions sky/provision/kubernetes/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def get_port_mode(
return port_mode


def get_networking_mode(
mode_str: Optional[str] = None
) -> kubernetes_enums.KubernetesNetworkingMode:
"""Get the networking mode from the provider config."""
mode_str = mode_str or skypilot_config.get_nested(
('kubernetes', 'networking_mode'),
kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value)
try:
networking_mode = kubernetes_enums.KubernetesNetworkingMode.from_str(
mode_str)
except ValueError as e:
with ux_utils.print_exception_no_traceback():
raise ValueError(str(e) +
' Please check: ~/.sky/config.yaml.') from None
return networking_mode


def fill_loadbalancer_template(namespace: str, service_name: str,
ports: List[int], selector_key: str,
selector_value: str) -> Dict:
Expand Down
Loading

0 comments on commit 08f6789

Please sign in to comment.