From 98dc040e4b5c9d1a7b4bd61a6599a52f0cc63432 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 21 May 2024 16:44:16 -0700 Subject: [PATCH] working prototype of direct-to-pod port-forwarding --- sky/authentication.py | 36 +++--- sky/backends/cloud_vm_ray_backend.py | 7 ++ sky/clouds/kubernetes.py | 4 +- sky/provision/kubernetes/config.py | 6 +- sky/provision/kubernetes/instance.py | 17 ++- sky/provision/kubernetes/network_utils.py | 18 +++ sky/provision/kubernetes/utils.py | 103 ++++++++++++------ ...ubernetes-port-forward-proxy-command.sh.j2 | 2 +- sky/templates/kubernetes-ray.yml.j2 | 3 + 9 files changed, 138 insertions(+), 58 deletions(-) diff --git a/sky/authentication.py b/sky/authentication.py index 966dad670c5..c01c6560616 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -442,26 +442,32 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME if network_mode == nodeport_mode: service_type = kubernetes_enums.KubernetesServiceType.NODEPORT + # Setup service for SSH jump pod. We create the SSH jump service here + # because we need to know the service IP address and port to set the + # ssh_proxy_command in the autoscaler config. + kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, + service_type) + ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command( + ssh_jump_name, nodeport_mode, + private_key_path=PRIVATE_SSH_KEY_PATH, + namespace=namespace) elif network_mode == port_forward_mode: + # Using `kubectl port-forward` creates a direct tunnel to the pod and + # does not require a ssh jump pod. kubernetes_utils.check_port_forward_mode_dependencies() - # Using `kubectl port-forward` creates a direct tunnel to jump pod and - # does not require opening any ports on Kubernetes nodes. As a result, - # the service can be a simple ClusterIP service which we access with - # `kubectl port-forward`. - service_type = kubernetes_enums.KubernetesServiceType.CLUSTERIP + # TODO(romilb): Handling multi-node clusters here might be tricky. + # We would need to either port-forward to the head node and use + # it as a jump host, or port-forward to each node individually. + # In the either case, we would need to have a proxycommand on a per + # node basis. This is currently not supported in upstream code. + # Similarly in the downstream provisioning code, we would need to + # set the worked pod name deterministically instead of using uuid. + ssh_target = config['cluster_name'] + '-head' + ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command( + ssh_target, port_forward_mode) else: # This should never happen because we check for this in from_str above. raise ValueError(f'Unsupported networking mode: {network_mode_str}') - # Setup service for SSH jump pod. We create the SSH jump service here - # because we need to know the service IP address and port to set the - # ssh_proxy_command in the autoscaler config. - kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type) - - ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command( - PRIVATE_SSH_KEY_PATH, ssh_jump_name, network_mode, namespace, - clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_PATH, - clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_TEMPLATE) - config['auth']['ssh_proxy_command'] = ssh_proxy_cmd return config diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 6d2447fe89b..c1a38f5b78b 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -47,6 +47,7 @@ from sky.provision import instance_setup from sky.provision import metadata_utils from sky.provision import provisioner +from sky.provision.kubernetes import utils as kubernetes_utils from sky.skylet import autostop_lib from sky.skylet import constants from sky.skylet import job_lib @@ -4121,6 +4122,12 @@ def post_teardown_cleanup(self, auth_config, handle.docker_user) + # If cloud was kubernetes, remove the ProxyCommand script used for + # port-forwarding. + if isinstance(handle.launched_resources.cloud, clouds.Kubernetes): + kubernetes_utils.remove_proxy_command_script(handle.cluster_name_on_cloud + '-head') + + global_user_state.remove_cluster(handle.cluster_name, terminate=terminate) diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index fcf8c2f87ac..838726ff99e 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -38,9 +38,6 @@ class Kubernetes(clouds.Cloud): SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys' SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod' - PORT_FORWARD_PROXY_CMD_TEMPLATE = \ - 'kubernetes-port-forward-proxy-command.sh.j2' - PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh' # Timeout for resource provisioning. This timeout determines how long to # wait for pod to be in pending status before giving up. # Larger timeout may be required for autoscaling clusters, since autoscaler @@ -304,6 +301,7 @@ def make_deploy_resources_variables( 'k8s_namespace': kubernetes_utils.get_current_kube_config_context_namespace(), 'k8s_port_mode': port_mode.value, + 'k8s_networking_mode': network_utils.get_networking_mode().value, 'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME, 'k8s_acc_label_key': k8s_acc_label_key, 'k8s_acc_label_value': k8s_acc_label_value, diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index 65c494fcebf..77ad8c5825b 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -10,6 +10,8 @@ from sky.adaptors import kubernetes from sky.provision import common from sky.provision.kubernetes import utils as kubernetes_utils +from sky.provision.kubernetes import network_utils +from sky.utils import kubernetes_enums logger = logging.getLogger(__name__) @@ -25,7 +27,9 @@ def bootstrap_instances( _configure_services(namespace, config.provider_config) - config = _configure_ssh_jump(namespace, config) + networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode')) + if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: + config = _configure_ssh_jump(namespace, config) requested_service_account = config.node_config['spec']['serviceAccountName'] if (requested_service_account == diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 9068079701f..451659a10c3 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -12,6 +12,7 @@ from sky.provision import common from sky.provision.kubernetes import config as config_lib from sky.provision.kubernetes import utils as kubernetes_utils +from sky.provision.kubernetes import network_utils from sky.utils import common_utils from sky.utils import kubernetes_enums from sky.utils import ux_utils @@ -520,14 +521,18 @@ def _create_pods(region: str, cluster_name_on_cloud: str, if head_pod_name is None: head_pod_name = pod.metadata.name - # Adding the jump pod to the new_nodes list as well so it can be - # checked if it's scheduled and running along with other pods. - ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump'] - jump_pod = kubernetes.core_api().read_namespaced_pod( - ssh_jump_pod_name, namespace) + wait_pods_dict = _filter_pods(namespace, tags, ['Pending']) wait_pods = list(wait_pods_dict.values()) - wait_pods.append(jump_pod) + + networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode')) + if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: + # Adding the jump pod to the new_nodes list as well so it can be + # checked if it's scheduled and running along with other pods. + ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump'] + jump_pod = kubernetes.core_api().read_namespaced_pod( + ssh_jump_pod_name, namespace) + wait_pods.append(jump_pod) provision_timeout = provider_config['timeout'] wait_str = ('indefinitely' diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index 836d75af41f..7deb3dde843 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -43,6 +43,24 @@ def get_port_mode( return port_mode +def get_networking_mode( + mode_str: Optional[str] = None) -> kubernetes_enums.KubernetesNetworkingMode: + """Get the networking mode from the provider config.""" + mode_str = mode_str or skypilot_config.get_nested( + ('kubernetes', 'networking_mode'), + kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value) + try: + networking_mode = kubernetes_enums.KubernetesNetworkingMode.from_str(mode_str) + except ValueError as e: + # Add message saying "Please check: ~/.sky/config.yaml" to the error + # message. + with ux_utils.print_exception_no_traceback(): + raise ValueError(str(e) + ' Please check: ~/.sky/config.yaml.') \ + from None + return networking_mode + + + def fill_loadbalancer_template(namespace: str, service_name: str, ports: List[int], selector_key: str, selector_value: str) -> Dict: diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index d5140d8846b..d90b880e5d5 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -53,6 +53,10 @@ KIND_CONTEXT_NAME = 'kind-skypilot' # Context name used by sky local up +# Port-forward proxy command constants +PORT_FORWARD_PROXY_CMD_TEMPLATE = 'kubernetes-port-forward-proxy-command.sh.j2' +PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/generated/kubernetes/proxy_{}.sh' + logger = sky_logging.init_logger(__name__) @@ -852,10 +856,14 @@ def construct_ssh_jump_command(private_key_path: str, def get_ssh_proxy_command( - private_key_path: str, ssh_jump_name: str, - network_mode: kubernetes_enums.KubernetesNetworkingMode, namespace: str, - port_fwd_proxy_cmd_path: str, port_fwd_proxy_cmd_template: str) -> str: - """Generates the SSH proxy command to connect through the SSH jump pod. + k8s_ssh_target: str, + network_mode: kubernetes_enums.KubernetesNetworkingMode, + private_key_path: Optional[str] = None, + namespace: Optional[str] = None) -> str: + """Generates the SSH proxy command to connect to the pod. + + Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding + if the network mode is PORTFORWARD. By default, establishing an SSH connection creates a communication channel to a remote node by setting up a TCP connection. When a @@ -871,57 +879,88 @@ def get_ssh_proxy_command( With the NodePort networking mode, a NodePort service is launched. This service opens an external port on the node which redirects to the desired - port within the pod. When establishing an SSH session in this mode, the + port to a SSH jump pod. When establishing an SSH session in this mode, the ProxyCommand makes use of this external port to create a communication channel directly to port 22, which is the default port ssh server listens on, of the jump pod. With Port-forward mode, instead of directly exposing an external port, 'kubectl port-forward' sets up a tunnel between a local port - (127.0.0.1:23100) and port 22 of the jump pod. Then we establish a TCP + (127.0.0.1:23100) and port 22 of the provisioned pod. Then we establish TCP connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'. - This is setup in the inner ProxyCommand of the nested ProxyCommand, and the - rest is the same as NodePort approach, which the outer ProxyCommand - establishes a communication channel between 127.0.0.1:23100 and port 22 on - the jump pod. Consequently, any stdin provided on the local machine is - forwarded through this tunnel to the application (SSH server) listening in - the pod. Similarly, any output from the application in the pod is tunneled - back and displayed in the terminal on the local machine. + All of this is done in a ProxyCommand script. Any stdin provided on the + local machine is forwarded through this tunnel to the application + (SSH server) listening in the pod. Similarly, any output from the + application in the pod is tunneled back and displayed in the terminal on + the local machine. Args: - private_key_path: str; Path to the private key to use for SSH. - This key must be authorized to access the SSH jump pod. - ssh_jump_name: str; Name of the SSH jump service to use + k8s_ssh_target: str; The Kubernetes object that will be used as the + target for SSH. If network_mode is NODEPORT, this is the name of the + service. If network_mode is PORTFORWARD, this is the pod name. network_mode: KubernetesNetworkingMode; networking mode for ssh session. It is either 'NODEPORT' or 'PORTFORWARD' - namespace: Kubernetes namespace to use - port_fwd_proxy_cmd_path: str; path to the script used as Proxycommand - with 'kubectl port-forward' - port_fwd_proxy_cmd_template: str; template used to create - 'kubectl port-forward' Proxycommand + private_key_path: str; Path to the private key to use for SSH. + This key must be authorized to access the SSH jump pod. + Required for NODEPORT networking mode. + namespace: Kubernetes namespace to use. + Required for NODEPORT networking mode. """ # Fetch IP to connect to for the jump svc ssh_jump_ip = get_external_ip(network_mode) if network_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: - ssh_jump_port = get_port(ssh_jump_name, namespace) + assert namespace is not None, 'Namespace must be provided for NodePort' + assert private_key_path is not None, 'Private key path must be provided' + ssh_jump_port = get_port(k8s_ssh_target, namespace) ssh_jump_proxy_command = construct_ssh_jump_command( private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port) # Setting kubectl port-forward/socat to establish ssh session using # ClusterIP service to disallow any ports opened else: - vars_to_fill = { - 'ssh_jump_name': ssh_jump_name, - } - common_utils.fill_template(port_fwd_proxy_cmd_template, - vars_to_fill, - output_path=port_fwd_proxy_cmd_path) - ssh_jump_proxy_command = construct_ssh_jump_command( - private_key_path, - ssh_jump_ip, - proxy_cmd_path=port_fwd_proxy_cmd_path) + ssh_jump_proxy_command = create_proxy_command_script(k8s_ssh_target) return ssh_jump_proxy_command +def create_proxy_command_script(k8s_ssh_target: str) -> str: + """Creates a ProxyCommand script that uses kubectl port-forward to setup + a tunnel between a local port and the SSH server in the pod. + + Args: + k8s_ssh_target: str; The pod name to use as the target for SSH. + + Returns: + str: Path to the ProxyCommand script. + """ + vars_to_fill = { + 'pod_name': k8s_ssh_target, + } + port_fwd_proxy_cmd_path = os.path.expanduser( + PORT_FORWARD_PROXY_CMD_PATH.format(k8s_ssh_target)) + os.makedirs(os.path.dirname(port_fwd_proxy_cmd_path), exist_ok=True, + mode=0o700) + common_utils.fill_template(PORT_FORWARD_PROXY_CMD_TEMPLATE, + vars_to_fill, + output_path=port_fwd_proxy_cmd_path) + # Set the permissions to 700 to ensure only the owner can read, write, + # and execute the file. + os.chmod(port_fwd_proxy_cmd_path, 0o700) + return port_fwd_proxy_cmd_path + + +def remove_proxy_command_script(k8s_ssh_target: str): + """Removes the ProxyCommand script used for port-forwarding. + + Args: + k8s_ssh_target: str; The pod name to use as the target for SSH. + """ + port_fwd_proxy_cmd_path = os.path.expanduser( + PORT_FORWARD_PROXY_CMD_PATH.format(k8s_ssh_target)) + print('Removing proxy command script at path:', port_fwd_proxy_cmd_path) + if os.path.exists(port_fwd_proxy_cmd_path): + print('File exists, removing.') + os.remove(port_fwd_proxy_cmd_path) + + def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str, service_type: kubernetes_enums.KubernetesServiceType): """Sets up Kubernetes service resource to access for SSH jump pod. diff --git a/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 b/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 index 39159eb15b9..70bb8f5a543 100644 --- a/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 +++ b/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 @@ -18,7 +18,7 @@ fi # This is preferred because of socket re-use issues in kubectl port-forward, # see - https://github.com/kubernetes/kubernetes/issues/74551#issuecomment-769185879 KUBECTL_OUTPUT=$(mktemp) -kubectl port-forward svc/{{ ssh_jump_name }} :22 > "${KUBECTL_OUTPUT}" 2>&1 & +kubectl port-forward pod/{{ pod_name }} :22 > "${KUBECTL_OUTPUT}" 2>&1 & # Capture the PID for the backgrounded kubectl command K8S_PORT_FWD_PID=$! diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index b05c8b589f6..4f0d80f0478 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -24,6 +24,9 @@ provider: # This should be one of KubernetesPortMode port_mode: {{k8s_port_mode}} + # The networking mode used to ssh to pods. One of KubernetesNetworkingMode. + networking_mode: {{k8s_networking_mode}} + # We use internal IPs since we set up a port-forward between the kubernetes # cluster and the local machine, or directly use NodePort to reach the # head node.