Skip to content

Commit

Permalink
working prototype of direct-to-pod port-forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
romilbhardwaj committed May 21, 2024
1 parent 19e8ed1 commit 98dc040
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 58 deletions.
36 changes: 21 additions & 15 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,26 +442,32 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
if network_mode == nodeport_mode:
service_type = kubernetes_enums.KubernetesServiceType.NODEPORT
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
service_type)
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_jump_name, nodeport_mode,
private_key_path=PRIVATE_SSH_KEY_PATH,
namespace=namespace)
elif network_mode == port_forward_mode:
# Using `kubectl port-forward` creates a direct tunnel to the pod and
# does not require a ssh jump pod.
kubernetes_utils.check_port_forward_mode_dependencies()
# Using `kubectl port-forward` creates a direct tunnel to jump pod and
# does not require opening any ports on Kubernetes nodes. As a result,
# the service can be a simple ClusterIP service which we access with
# `kubectl port-forward`.
service_type = kubernetes_enums.KubernetesServiceType.CLUSTERIP
# TODO(romilb): Handling multi-node clusters here might be tricky.
# We would need to either port-forward to the head node and use
# it as a jump host, or port-forward to each node individually.
# In the either case, we would need to have a proxycommand on a per
# node basis. This is currently not supported in upstream code.
# Similarly in the downstream provisioning code, we would need to
# set the worked pod name deterministically instead of using uuid.
ssh_target = config['cluster_name'] + '-head'
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_target, port_forward_mode)
else:
# This should never happen because we check for this in from_str above.
raise ValueError(f'Unsupported networking mode: {network_mode_str}')
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type)

ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
PRIVATE_SSH_KEY_PATH, ssh_jump_name, network_mode, namespace,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_PATH,
clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_TEMPLATE)

config['auth']['ssh_proxy_command'] = ssh_proxy_cmd

return config
Expand Down
7 changes: 7 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from sky.provision import instance_setup
from sky.provision import metadata_utils
from sky.provision import provisioner
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.skylet import autostop_lib
from sky.skylet import constants
from sky.skylet import job_lib
Expand Down Expand Up @@ -4121,6 +4122,12 @@ def post_teardown_cleanup(self,
auth_config,
handle.docker_user)

# If cloud was kubernetes, remove the ProxyCommand script used for
# port-forwarding.
if isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
kubernetes_utils.remove_proxy_command_script(handle.cluster_name_on_cloud + '-head')


global_user_state.remove_cluster(handle.cluster_name,
terminate=terminate)

Expand Down
4 changes: 1 addition & 3 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class Kubernetes(clouds.Cloud):

SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys'
SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
PORT_FORWARD_PROXY_CMD_TEMPLATE = \
'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
# Timeout for resource provisioning. This timeout determines how long to
# wait for pod to be in pending status before giving up.
# Larger timeout may be required for autoscaling clusters, since autoscaler
Expand Down Expand Up @@ -304,6 +301,7 @@ def make_deploy_resources_variables(
'k8s_namespace':
kubernetes_utils.get_current_kube_config_context_namespace(),
'k8s_port_mode': port_mode.value,
'k8s_networking_mode': network_utils.get_networking_mode().value,
'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
'k8s_acc_label_key': k8s_acc_label_key,
'k8s_acc_label_value': k8s_acc_label_value,
Expand Down
6 changes: 5 additions & 1 deletion sky/provision/kubernetes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from sky.adaptors import kubernetes
from sky.provision import common
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.kubernetes import network_utils
from sky.utils import kubernetes_enums

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +27,9 @@ def bootstrap_instances(

_configure_services(namespace, config.provider_config)

config = _configure_ssh_jump(namespace, config)
networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
config = _configure_ssh_jump(namespace, config)

requested_service_account = config.node_config['spec']['serviceAccountName']
if (requested_service_account ==
Expand Down
17 changes: 11 additions & 6 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sky.provision import common
from sky.provision.kubernetes import config as config_lib
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.kubernetes import network_utils
from sky.utils import common_utils
from sky.utils import kubernetes_enums
from sky.utils import ux_utils
Expand Down Expand Up @@ -520,14 +521,18 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
if head_pod_name is None:
head_pod_name = pod.metadata.name

# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)

wait_pods_dict = _filter_pods(namespace, tags, ['Pending'])
wait_pods = list(wait_pods_dict.values())
wait_pods.append(jump_pod)

networking_mode = network_utils.get_networking_mode(config.provider_config.get('networking_mode'))
if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
# Adding the jump pod to the new_nodes list as well so it can be
# checked if it's scheduled and running along with other pods.
ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
jump_pod = kubernetes.core_api().read_namespaced_pod(
ssh_jump_pod_name, namespace)
wait_pods.append(jump_pod)
provision_timeout = provider_config['timeout']

wait_str = ('indefinitely'
Expand Down
18 changes: 18 additions & 0 deletions sky/provision/kubernetes/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ def get_port_mode(
return port_mode


def get_networking_mode(
mode_str: Optional[str] = None) -> kubernetes_enums.KubernetesNetworkingMode:
"""Get the networking mode from the provider config."""
mode_str = mode_str or skypilot_config.get_nested(
('kubernetes', 'networking_mode'),
kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value)
try:
networking_mode = kubernetes_enums.KubernetesNetworkingMode.from_str(mode_str)
except ValueError as e:
# Add message saying "Please check: ~/.sky/config.yaml" to the error
# message.
with ux_utils.print_exception_no_traceback():
raise ValueError(str(e) + ' Please check: ~/.sky/config.yaml.') \
from None
return networking_mode



def fill_loadbalancer_template(namespace: str, service_name: str,
ports: List[int], selector_key: str,
selector_value: str) -> Dict:
Expand Down
103 changes: 71 additions & 32 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@

KIND_CONTEXT_NAME = 'kind-skypilot' # Context name used by sky local up

# Port-forward proxy command constants
PORT_FORWARD_PROXY_CMD_TEMPLATE = 'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/generated/kubernetes/proxy_{}.sh'

logger = sky_logging.init_logger(__name__)


Expand Down Expand Up @@ -852,10 +856,14 @@ def construct_ssh_jump_command(private_key_path: str,


def get_ssh_proxy_command(
private_key_path: str, ssh_jump_name: str,
network_mode: kubernetes_enums.KubernetesNetworkingMode, namespace: str,
port_fwd_proxy_cmd_path: str, port_fwd_proxy_cmd_template: str) -> str:
"""Generates the SSH proxy command to connect through the SSH jump pod.
k8s_ssh_target: str,
network_mode: kubernetes_enums.KubernetesNetworkingMode,
private_key_path: Optional[str] = None,
namespace: Optional[str] = None) -> str:
"""Generates the SSH proxy command to connect to the pod.
Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
if the network mode is PORTFORWARD.
By default, establishing an SSH connection creates a communication
channel to a remote node by setting up a TCP connection. When a
Expand All @@ -871,57 +879,88 @@ def get_ssh_proxy_command(
With the NodePort networking mode, a NodePort service is launched. This
service opens an external port on the node which redirects to the desired
port within the pod. When establishing an SSH session in this mode, the
port to a SSH jump pod. When establishing an SSH session in this mode, the
ProxyCommand makes use of this external port to create a communication
channel directly to port 22, which is the default port ssh server listens
on, of the jump pod.
With Port-forward mode, instead of directly exposing an external port,
'kubectl port-forward' sets up a tunnel between a local port
(127.0.0.1:23100) and port 22 of the jump pod. Then we establish a TCP
(127.0.0.1:23100) and port 22 of the provisioned pod. Then we establish TCP
connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'.
This is setup in the inner ProxyCommand of the nested ProxyCommand, and the
rest is the same as NodePort approach, which the outer ProxyCommand
establishes a communication channel between 127.0.0.1:23100 and port 22 on
the jump pod. Consequently, any stdin provided on the local machine is
forwarded through this tunnel to the application (SSH server) listening in
the pod. Similarly, any output from the application in the pod is tunneled
back and displayed in the terminal on the local machine.
All of this is done in a ProxyCommand script. Any stdin provided on the
local machine is forwarded through this tunnel to the application
(SSH server) listening in the pod. Similarly, any output from the
application in the pod is tunneled back and displayed in the terminal on
the local machine.
Args:
private_key_path: str; Path to the private key to use for SSH.
This key must be authorized to access the SSH jump pod.
ssh_jump_name: str; Name of the SSH jump service to use
k8s_ssh_target: str; The Kubernetes object that will be used as the
target for SSH. If network_mode is NODEPORT, this is the name of the
service. If network_mode is PORTFORWARD, this is the pod name.
network_mode: KubernetesNetworkingMode; networking mode for ssh
session. It is either 'NODEPORT' or 'PORTFORWARD'
namespace: Kubernetes namespace to use
port_fwd_proxy_cmd_path: str; path to the script used as Proxycommand
with 'kubectl port-forward'
port_fwd_proxy_cmd_template: str; template used to create
'kubectl port-forward' Proxycommand
private_key_path: str; Path to the private key to use for SSH.
This key must be authorized to access the SSH jump pod.
Required for NODEPORT networking mode.
namespace: Kubernetes namespace to use.
Required for NODEPORT networking mode.
"""
# Fetch IP to connect to for the jump svc
ssh_jump_ip = get_external_ip(network_mode)
if network_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
ssh_jump_port = get_port(ssh_jump_name, namespace)
assert namespace is not None, 'Namespace must be provided for NodePort'
assert private_key_path is not None, 'Private key path must be provided'
ssh_jump_port = get_port(k8s_ssh_target, namespace)
ssh_jump_proxy_command = construct_ssh_jump_command(
private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
# Setting kubectl port-forward/socat to establish ssh session using
# ClusterIP service to disallow any ports opened
else:
vars_to_fill = {
'ssh_jump_name': ssh_jump_name,
}
common_utils.fill_template(port_fwd_proxy_cmd_template,
vars_to_fill,
output_path=port_fwd_proxy_cmd_path)
ssh_jump_proxy_command = construct_ssh_jump_command(
private_key_path,
ssh_jump_ip,
proxy_cmd_path=port_fwd_proxy_cmd_path)
ssh_jump_proxy_command = create_proxy_command_script(k8s_ssh_target)
return ssh_jump_proxy_command


def create_proxy_command_script(k8s_ssh_target: str) -> str:
"""Creates a ProxyCommand script that uses kubectl port-forward to setup
a tunnel between a local port and the SSH server in the pod.
Args:
k8s_ssh_target: str; The pod name to use as the target for SSH.
Returns:
str: Path to the ProxyCommand script.
"""
vars_to_fill = {
'pod_name': k8s_ssh_target,
}
port_fwd_proxy_cmd_path = os.path.expanduser(
PORT_FORWARD_PROXY_CMD_PATH.format(k8s_ssh_target))
os.makedirs(os.path.dirname(port_fwd_proxy_cmd_path), exist_ok=True,
mode=0o700)
common_utils.fill_template(PORT_FORWARD_PROXY_CMD_TEMPLATE,
vars_to_fill,
output_path=port_fwd_proxy_cmd_path)
# Set the permissions to 700 to ensure only the owner can read, write,
# and execute the file.
os.chmod(port_fwd_proxy_cmd_path, 0o700)
return port_fwd_proxy_cmd_path


def remove_proxy_command_script(k8s_ssh_target: str):
"""Removes the ProxyCommand script used for port-forwarding.
Args:
k8s_ssh_target: str; The pod name to use as the target for SSH.
"""
port_fwd_proxy_cmd_path = os.path.expanduser(
PORT_FORWARD_PROXY_CMD_PATH.format(k8s_ssh_target))
print('Removing proxy command script at path:', port_fwd_proxy_cmd_path)
if os.path.exists(port_fwd_proxy_cmd_path):
print('File exists, removing.')
os.remove(port_fwd_proxy_cmd_path)


def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
service_type: kubernetes_enums.KubernetesServiceType):
"""Sets up Kubernetes service resource to access for SSH jump pod.
Expand Down
2 changes: 1 addition & 1 deletion sky/templates/kubernetes-port-forward-proxy-command.sh.j2
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fi
# This is preferred because of socket re-use issues in kubectl port-forward,
# see - https://github.com/kubernetes/kubernetes/issues/74551#issuecomment-769185879
KUBECTL_OUTPUT=$(mktemp)
kubectl port-forward svc/{{ ssh_jump_name }} :22 > "${KUBECTL_OUTPUT}" 2>&1 &
kubectl port-forward pod/{{ pod_name }} :22 > "${KUBECTL_OUTPUT}" 2>&1 &

# Capture the PID for the backgrounded kubectl command
K8S_PORT_FWD_PID=$!
Expand Down
3 changes: 3 additions & 0 deletions sky/templates/kubernetes-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ provider:
# This should be one of KubernetesPortMode
port_mode: {{k8s_port_mode}}

# The networking mode used to ssh to pods. One of KubernetesNetworkingMode.
networking_mode: {{k8s_networking_mode}}

# We use internal IPs since we set up a port-forward between the kubernetes
# cluster and the local machine, or directly use NodePort to reach the
# head node.
Expand Down

0 comments on commit 98dc040

Please sign in to comment.