Skip to content

Commit

Permalink
fix: Explicitly specify protected service ports not to expose publicly (
Browse files Browse the repository at this point in the history
#2797)

Backported-from: main (24.09)
Backported-to: 24.03
Backport-of: 2797
  • Loading branch information
achimnol committed Sep 2, 2024
1 parent 1acd768 commit cd09e89
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/2797.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Explicitly set the protected service ports depending on the resource group type and the service types
16 changes: 16 additions & 0 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,22 @@ def update_user_bootstrap_script(self, script: str) -> None:
"""
self.kernel_config["bootstrap_script"] = script

@property
@abstractmethod
def repl_ports(self) -> Sequence[int]:
"""
Return the list of intrinsic REPL ports to exclude from public mapping.
"""
raise NotImplementedError

@property
@abstractmethod
def protected_services(self) -> Sequence[str]:
"""
Return the list of protected (intrinsic) service names to exclude from public mapping.
"""
raise NotImplementedError

@abstractmethod
async def apply_network(self, cluster_info: ClusterInfo) -> None:
"""
Expand Down
41 changes: 35 additions & 6 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Tuple,
Union,
cast,
override,
)
from uuid import UUID

Expand Down Expand Up @@ -63,6 +64,7 @@
KernelId,
MountPermission,
MountTypes,
ResourceGroupType,
ResourceSlot,
Sentinel,
ServicePort,
Expand Down Expand Up @@ -713,6 +715,21 @@ def _populate_ssh_config():
)
return kernel_obj

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
rgtype: ResourceGroupType = self.local_config["agent"]["scaling-group-type"]
match rgtype:
case ResourceGroupType.COMPUTE:
return ()
case ResourceGroupType.STORAGE:
return ("ttyd",)

async def start_container(
self,
kernel_obj: AbstractKernel,
Expand All @@ -729,13 +746,13 @@ async def start_container(
# PHASE 4: Run!
container_bind_host = self.local_config["container"]["bind-host"]
advertised_kernel_host = self.local_config["container"].get("advertised-host")
repl_ports = [2000, 2001]
if len(service_ports) + len(repl_ports) > len(self.port_pool):
if len(service_ports) + len(self.repl_ports) > len(self.port_pool):
raise RuntimeError(
f"Container ports are not sufficiently available. (remaining ports: {self.port_pool})"
)
exposed_ports = repl_ports
host_ports = [self.port_pool.pop() for _ in repl_ports]
exposed_ports = [*self.repl_ports]
host_ports = [self.port_pool.pop() for _ in self.repl_ports]
host_ips = []
for sport in service_ports:
exposed_ports.extend(sport["container_ports"])
if (
Expand All @@ -751,6 +768,18 @@ async def start_container(
else:
hport = self.port_pool.pop()
host_ports.append(hport)
protected_service_ports: set[int] = set()
for sport in service_ports:
if sport["name"] in self.protected_services:
protected_service_ports.update(sport["container_ports"])
for eport in exposed_ports:
if eport in self.repl_ports: # always protected
host_ips.append("127.0.0.1")
elif eport in protected_service_ports: # check if protected by resource group type
host_ips.append("127.0.0.1")
else:
host_ips.append(str(container_bind_host))
assert len(host_ips) == len(host_ports) == len(exposed_ports)

container_log_size = self.local_config["agent"]["container-logs"]["max-length"]
container_log_file_count = 5
Expand Down Expand Up @@ -778,8 +807,8 @@ async def start_container(
"HostConfig": {
"Init": True,
"PortBindings": {
f"{eport}/tcp": [{"HostPort": str(hport), "HostIp": str(container_bind_host)}]
for eport, hport in zip(exposed_ports, host_ports)
f"{eport}/tcp": [{"HostPort": str(hport), "HostIp": hip}]
for eport, hport, hip in zip(exposed_ports, host_ports, host_ips)
},
"PublishAllPorts": False, # we manage port mapping manually!
"CapAdd": [
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def create_code_runner(
self.kernel_id,
self.session_id,
event_producer,
kernel_host=self.data["kernel_host"],
kernel_host="127.0.0.1", # repl ports are always bound to 127.0.0.1
repl_in_port=self.data["repl_in_port"],
repl_out_port=self.data["repl_out_port"],
exec_timeout=0,
Expand Down
11 changes: 11 additions & 0 deletions src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Optional,
Sequence,
Tuple,
override,
)

from ai.backend.common.config import read_from_file
Expand Down Expand Up @@ -108,6 +109,16 @@ async def prepare_scratch(self) -> None:
async def get_intrinsic_mounts(self) -> Sequence[Mount]:
return []

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
return ()

async def apply_network(self, cluster_info: ClusterInfo) -> None:
return

Expand Down
14 changes: 13 additions & 1 deletion src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Sequence,
Tuple,
Union,
override,
)

import aiotools
Expand Down Expand Up @@ -303,6 +304,17 @@ async def get_intrinsic_mounts(self) -> Sequence[Mount]:

return mounts

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
# NOTE: Currently K8s does not support binding container ports to 127.0.0.1 when using NodePort.
return ()

async def apply_network(self, cluster_info: ClusterInfo) -> None:
pass

Expand Down Expand Up @@ -655,7 +667,7 @@ async def start_container(
await kube_config.load_kube_config()
core_api = kube_client.CoreV1Api()
apps_api = kube_client.AppsV1Api()
exposed_ports = [2000, 2001]
exposed_ports = [*self.repl_ports]
for sport in service_ports:
exposed_ports.extend(sport["container_ports"])

Expand Down

0 comments on commit cd09e89

Please sign in to comment.