diff --git a/sky/authentication.py b/sky/authentication.py index c01c6560616..e7f65d09c87 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -448,7 +448,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type) ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command( - ssh_jump_name, nodeport_mode, + ssh_jump_name, + nodeport_mode, private_key_path=PRIVATE_SSH_KEY_PATH, namespace=namespace) elif network_mode == port_forward_mode: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index c1a38f5b78b..b39e155fa1f 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -4125,8 +4125,8 @@ def post_teardown_cleanup(self, # If cloud was kubernetes, remove the ProxyCommand script used for # port-forwarding. if isinstance(handle.launched_resources.cloud, clouds.Kubernetes): - kubernetes_utils.remove_proxy_command_script(handle.cluster_name_on_cloud + '-head') - + kubernetes_utils.remove_proxy_command_script( + handle.cluster_name_on_cloud + '-head') global_user_state.remove_cluster(handle.cluster_name, terminate=terminate) diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index 77ad8c5825b..d7658cff75a 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -9,8 +9,8 @@ from sky.adaptors import kubernetes from sky.provision import common -from sky.provision.kubernetes import utils as kubernetes_utils from sky.provision.kubernetes import network_utils +from sky.provision.kubernetes import utils as kubernetes_utils from sky.utils import kubernetes_enums logger = logging.getLogger(__name__) @@ -27,7 +27,8 @@ def bootstrap_instances( _configure_services(namespace, config.provider_config) - networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode')) + networking_mode = network_utils.get_networking_mode( + config.provider_config.get('networking_mode')) if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: config = _configure_ssh_jump(namespace, config) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 451659a10c3..8481d1eab3b 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -11,8 +11,8 @@ from sky.adaptors import kubernetes from sky.provision import common from sky.provision.kubernetes import config as config_lib -from sky.provision.kubernetes import utils as kubernetes_utils from sky.provision.kubernetes import network_utils +from sky.provision.kubernetes import utils as kubernetes_utils from sky.utils import common_utils from sky.utils import kubernetes_enums from sky.utils import ux_utils @@ -521,11 +521,11 @@ def _create_pods(region: str, cluster_name_on_cloud: str, if head_pod_name is None: head_pod_name = pod.metadata.name - wait_pods_dict = _filter_pods(namespace, tags, ['Pending']) wait_pods = list(wait_pods_dict.values()) - networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode')) + networking_mode = network_utils.get_networking_mode( + config.provider_config.get('networking_mode')) if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: # Adding the jump pod to the new_nodes list as well so it can be # checked if it's scheduled and running along with other pods. diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index 7deb3dde843..cb5df354ca6 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -44,13 +44,15 @@ def get_port_mode( def get_networking_mode( - mode_str: Optional[str] = None) -> kubernetes_enums.KubernetesNetworkingMode: + mode_str: Optional[str] = None +) -> kubernetes_enums.KubernetesNetworkingMode: """Get the networking mode from the provider config.""" mode_str = mode_str or skypilot_config.get_nested( ('kubernetes', 'networking_mode'), kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value) try: - networking_mode = kubernetes_enums.KubernetesNetworkingMode.from_str(mode_str) + networking_mode = kubernetes_enums.KubernetesNetworkingMode.from_str( + mode_str) except ValueError as e: # Add message saying "Please check: ~/.sky/config.yaml" to the error # message. @@ -60,7 +62,6 @@ def get_networking_mode( return networking_mode - def fill_loadbalancer_template(namespace: str, service_name: str, ports: List[int], selector_key: str, selector_value: str) -> Dict: diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index d90b880e5d5..8f0bdf6fe0e 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -936,7 +936,8 @@ def create_proxy_command_script(k8s_ssh_target: str) -> str: } port_fwd_proxy_cmd_path = os.path.expanduser( PORT_FORWARD_PROXY_CMD_PATH.format(k8s_ssh_target)) - os.makedirs(os.path.dirname(port_fwd_proxy_cmd_path), exist_ok=True, + os.makedirs(os.path.dirname(port_fwd_proxy_cmd_path), + exist_ok=True, mode=0o700) common_utils.fill_template(PORT_FORWARD_PROXY_CMD_TEMPLATE, vars_to_fill,