Skip to content

Commit

Permalink
[k8s] Support for switching k8s contexts (#3913)
Browse files Browse the repository at this point in the history
* wip

* wip

* fix rsync

* fix user identity checks

* newline

* lint and tests

* fix

* typing

* fix

* lint

* Update identity logic

* update node_id

* fix context/namespace passing to helper scripts

* lint

* backward compatibility

* backward compatibility

* lint

* Add k8s logging

* comments

* comments

* comments

* comments

* newline
  • Loading branch information
romilbhardwaj authored Sep 11, 2024
1 parent 0de763c commit bad7dab
Show file tree
Hide file tree
Showing 30 changed files with 525 additions and 334 deletions.
99 changes: 32 additions & 67 deletions sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Kubernetes adaptors"""

# pylint: disable=import-outside-toplevel

import functools
import logging
import os
from typing import Any, Callable, Set
from typing import Any, Callable, Optional, Set

from sky.adaptors import common
from sky.sky_logging import set_logging_level
Expand All @@ -18,15 +16,6 @@
urllib3 = common.LazyImport('urllib3',
import_error_message=_IMPORT_ERROR_MESSAGE)

_configured = False
_core_api = None
_auth_api = None
_networking_api = None
_custom_objects_api = None
_node_api = None
_apps_api = None
_api_client = None

# Timeout to use for API calls
API_TIMEOUT = 5

Expand Down Expand Up @@ -66,10 +55,7 @@ def wrapped(*args, **kwargs):
return decorated_api


def _load_config():
global _configured
if _configured:
return
def _load_config(context: Optional[str] = None):
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
try:
# Load in-cluster config if running in a pod
Expand All @@ -82,7 +68,7 @@ def _load_config():
kubernetes.config.load_incluster_config()
except kubernetes.config.config_exception.ConfigException:
try:
kubernetes.config.load_kube_config()
kubernetes.config.load_kube_config(context=context)
except kubernetes.config.config_exception.ConfigException as e:
suffix = ''
if env_options.Options.SHOW_DEBUG_INFO.get():
Expand All @@ -101,76 +87,55 @@ def _load_config():
err_str += '\nTo disable Kubernetes for SkyPilot: run `sky check`.'
with ux_utils.print_exception_no_traceback():
raise ValueError(err_str) from None
_configured = True


@_api_logging_decorator('urllib3', logging.ERROR)
def core_api():
global _core_api
if _core_api is None:
_load_config()
_core_api = kubernetes.client.CoreV1Api()
return _core_api
@functools.lru_cache()
def core_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.CoreV1Api()


@_api_logging_decorator('urllib3', logging.ERROR)
def auth_api():
global _auth_api
if _auth_api is None:
_load_config()
_auth_api = kubernetes.client.RbacAuthorizationV1Api()

return _auth_api
@functools.lru_cache()
def auth_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.RbacAuthorizationV1Api()


@_api_logging_decorator('urllib3', logging.ERROR)
def networking_api():
global _networking_api
if _networking_api is None:
_load_config()
_networking_api = kubernetes.client.NetworkingV1Api()

return _networking_api
@functools.lru_cache()
def networking_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.NetworkingV1Api()


@_api_logging_decorator('urllib3', logging.ERROR)
def custom_objects_api():
global _custom_objects_api
if _custom_objects_api is None:
_load_config()
_custom_objects_api = kubernetes.client.CustomObjectsApi()

return _custom_objects_api
@functools.lru_cache()
def custom_objects_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.CustomObjectsApi()


@_api_logging_decorator('urllib3', logging.ERROR)
def node_api():
global _node_api
if _node_api is None:
_load_config()
_node_api = kubernetes.client.NodeV1Api()

return _node_api
@functools.lru_cache()
def node_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.NodeV1Api()


@_api_logging_decorator('urllib3', logging.ERROR)
def apps_api():
global _apps_api
if _apps_api is None:
_load_config()
_apps_api = kubernetes.client.AppsV1Api()

return _apps_api
@functools.lru_cache()
def apps_api(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.AppsV1Api()


@_api_logging_decorator('urllib3', logging.ERROR)
def api_client():
global _api_client
if _api_client is None:
_load_config()
_api_client = kubernetes.client.ApiClient()

return _api_client
@functools.lru_cache()
def api_client(context: Optional[str] = None):
_load_config(context)
return kubernetes.client.ApiClient()


def api_exception():
Expand Down
19 changes: 12 additions & 7 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
namespace = kubernetes_utils.get_current_kube_config_context_namespace()
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_current_kube_config_context_namespace())
context = config['provider'].get(
'context', kubernetes_utils.get_current_kube_config_context_name())
k8s = kubernetes.kubernetes
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
Expand All @@ -399,14 +403,14 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
secret = k8s.client.V1Secret(
metadata=k8s.client.V1ObjectMeta(**secret_metadata),
string_data={secret_field_name: public_key})
if kubernetes_utils.check_secret_exists(secret_name, namespace):
if kubernetes_utils.check_secret_exists(secret_name, namespace, context):
logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
kubernetes.core_api().patch_namespaced_secret(secret_name, namespace,
secret)
kubernetes.core_api(context).patch_namespaced_secret(
secret_name, namespace, secret)
else:
logger.debug(
f'Key {secret_name} does not exist in the cluster, creating it...')
kubernetes.core_api().create_namespaced_secret(namespace, secret)
kubernetes.core_api(context).create_namespaced_secret(namespace, secret)

private_key_path, _ = get_or_generate_keys()
if network_mode == nodeport_mode:
Expand All @@ -415,13 +419,14 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, context,
service_type)
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_jump_name,
nodeport_mode,
private_key_path=private_key_path,
namespace=namespace)
namespace=namespace,
context=context)
elif network_mode == port_forward_mode:
# Using `kubectl port-forward` creates a direct tunnel to the pod and
# does not require a ssh jump pod.
Expand Down
73 changes: 40 additions & 33 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,58 +1558,65 @@ def check_owner_identity(cluster_name: str) -> None:
return

cloud = handle.launched_resources.cloud
current_user_identity = cloud.get_current_user_identity()
user_identities = cloud.get_user_identities()
owner_identity = record['owner']
if current_user_identity is None:
if user_identities is None:
# Skip the check if the cloud does not support user identity.
return
# The user identity can be None, if the cluster is created by an older
# version of SkyPilot. In that case, we set the user identity to the
# current one.
# current active one.
# NOTE: a user who upgrades SkyPilot and switches to a new cloud identity
# immediately without `sky status --refresh` first, will cause a leakage
# of the existing cluster. We deem this an acceptable tradeoff mainly
# because multi-identity is not common (at least at the moment).
if owner_identity is None:
global_user_state.set_owner_identity_for_cluster(
cluster_name, current_user_identity)
cluster_name, user_identities[0])
else:
assert isinstance(owner_identity, list)
# It is OK if the owner identity is shorter, which will happen when
# the cluster is launched before #1808. In that case, we only check
# the same length (zip will stop at the shorter one).
for i, (owner,
current) in enumerate(zip(owner_identity,
current_user_identity)):
# Clean up the owner identity for the backslash and newlines, caused
# by the cloud CLI output, e.g. gcloud.
owner = owner.replace('\n', '').replace('\\', '')
if owner == current:
if i != 0:
logger.warning(
f'The cluster was owned by {owner_identity}, but '
f'a new identity {current_user_identity} is activated. We still '
'allow the operation as the two identities are likely to have '
'the same access to the cluster. Please be aware that this can '
'cause unexpected cluster leakage if the two identities are not '
'actually equivalent (e.g., belong to the same person).'
)
if i != 0 or len(owner_identity) != len(current_user_identity):
# We update the owner of a cluster, when:
# 1. The strictest identty (i.e. the first one) does not
# match, but the latter ones match.
# 2. The length of the two identities are different, which
# will only happen when the cluster is launched before #1808.
# Update the user identity to avoid showing the warning above
# again.
global_user_state.set_owner_identity_for_cluster(
cluster_name, current_user_identity)
return # The user identity matches.
for identity in user_identities:
for i, (owner, current) in enumerate(zip(owner_identity, identity)):
# Clean up the owner identity for the backslash and newlines, caused
# by the cloud CLI output, e.g. gcloud.
owner = owner.replace('\n', '').replace('\\', '')
if owner == current:
if i != 0:
logger.warning(
f'The cluster was owned by {owner_identity}, but '
f'a new identity {identity} is activated. We still '
'allow the operation as the two identities are '
'likely to have the same access to the cluster. '
'Please be aware that this can cause unexpected '
'cluster leakage if the two identities are not '
'actually equivalent (e.g., belong to the same '
'person).')
if i != 0 or len(owner_identity) != len(identity):
# We update the owner of a cluster, when:
# 1. The strictest identty (i.e. the first one) does not
# match, but the latter ones match.
# 2. The length of the two identities are different,
# which will only happen when the cluster is launched
# before #1808. Update the user identity to avoid
# showing the warning above again.
global_user_state.set_owner_identity_for_cluster(
cluster_name, identity)
return # The user identity matches.
# Generate error message if no match found
if len(user_identities) == 1:
err_msg = f'the activated identity is {user_identities[0]!r}.'
else:
err_msg = (f'available identities are {user_identities!r}.')
if cloud.is_same_cloud(clouds.Kubernetes()):
err_msg += (' Check your kubeconfig file and make sure the '
'correct context is available.')
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterOwnerIdentityMismatchError(
f'{cluster_name!r} ({cloud}) is owned by account '
f'{owner_identity!r}, but the activated account '
f'is {current_user_identity!r}.')
f'{owner_identity!r}, but ' + err_msg)


def tag_filter_for_cluster(cluster_name: str) -> Dict[str, str]:
Expand Down
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,7 +1945,7 @@ def provision_with_retries(
if dryrun:
cloud_user = None
else:
cloud_user = to_provision.cloud.get_current_user_identity()
cloud_user = to_provision.cloud.get_active_user_identity()

requested_features = self._requested_features.copy()
# Skip stop feature for Kubernetes and RunPod controllers.
Expand Down
2 changes: 1 addition & 1 deletion sky/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def check_one_cloud(
if ok:
enabled_clouds.append(cloud_repr)
if verbose and cloud is not cloudflare:
activated_account = cloud.get_current_user_identity_str()
activated_account = cloud.get_active_user_identity_str()
if activated_account is not None:
echo(f' Activated account: {activated_account}')
if reason is not None:
Expand Down
14 changes: 8 additions & 6 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
# Checks if AWS credentials 1) exist and 2) are valid.
# https://stackoverflow.com/questions/53548737/verify-aws-credentials-with-boto3
try:
identity_str = cls.get_current_user_identity_str()
identity_str = cls.get_active_user_identity_str()
except exceptions.CloudUserIdentityError as e:
return False, str(e)

Expand Down Expand Up @@ -584,7 +584,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
else:
# This file is required because it is required by the VMs launched on
# other clouds to access private s3 buckets and resources like EC2.
# `get_current_user_identity` does not guarantee this file exists.
# `get_active_user_identity` does not guarantee this file exists.
if not static_credential_exists:
return (False, '~/.aws/credentials does not exist. ' +
cls._STATIC_CREDENTIAL_HELP_STR)
Expand Down Expand Up @@ -648,7 +648,7 @@ def _is_access_key_of_type(type_str: str) -> bool:
return AWSIdentityType.SHARED_CREDENTIALS_FILE

@classmethod
def get_current_user_identity(cls) -> Optional[List[str]]:
def get_user_identities(cls) -> Optional[List[List[str]]]:
"""Returns a [UserId, Account] list that uniquely identifies the user.
These fields come from `aws sts get-caller-identity`. We permit the same
Expand Down Expand Up @@ -752,11 +752,13 @@ def get_current_user_identity(cls) -> Optional[List[str]]:
f'Failed to get AWS user.\n'
f' Reason: {common_utils.format_exception(e, use_bracket=True)}.'
) from None
return user_ids
# TODO: Return a list of identities in the profile when we support
# automatic switching for AWS. Currently we only support one identity.
return [user_ids]

@classmethod
def get_current_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_current_user_identity()
def get_active_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_active_user_identity()
if user_identity is None:
return None
identity_str = f'{user_identity[0]} [account={user_identity[1]}]'
Expand Down
12 changes: 7 additions & 5 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
# If Azure is properly logged in, this will return the account email
# address + subscription ID.
try:
cls.get_current_user_identity()
cls.get_active_user_identity()
except exceptions.CloudUserIdentityError as e:
return False, (f'Getting user\'s Azure identity failed.{help_str}\n'
f'{cls._INDENT_PREFIX}Details: '
Expand Down Expand Up @@ -516,7 +516,7 @@ def instance_type_exists(self, instance_type):

@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def get_current_user_identity(cls) -> Optional[List[str]]:
def get_user_identities(cls) -> Optional[List[List[str]]]:
"""Returns the cloud user identity."""
# This returns the user's email address + [subscription_id].
retry_cnt = 0
Expand Down Expand Up @@ -558,11 +558,13 @@ def get_current_user_identity(cls) -> Optional[List[str]]:
with ux_utils.print_exception_no_traceback():
raise exceptions.CloudUserIdentityError(
'Failed to get Azure project ID.') from e
return [f'{account_email} [subscription_id={project_id}]']
# TODO: Return a list of identities in the profile when we support
# automatic switching for Az. Currently we only support one identity.
return [[f'{account_email} [subscription_id={project_id}]']]

@classmethod
def get_current_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_current_user_identity()
def get_active_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_active_user_identity()
if user_identity is None:
return None
return user_identity[0]
Expand Down
Loading

0 comments on commit bad7dab

Please sign in to comment.