Skip to content

Commit

Permalink
fix: Explicitly specify protected service ports not to expose publicly (
Browse files Browse the repository at this point in the history
#2797) (#2802)

Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
lablup-octodog and achimnol authored Sep 2, 2024
1 parent d2510a1 commit aa8b69d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/2797.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Explicitly set the protected service ports depending on the resource group type and the service types
16 changes: 16 additions & 0 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,22 @@ def update_user_bootstrap_script(self, script: str) -> None:
"""
self.kernel_config["bootstrap_script"] = script

@property
@abstractmethod
def repl_ports(self) -> Sequence[int]:
"""
Return the list of intrinsic REPL ports to exclude from public mapping.
"""
raise NotImplementedError

@property
@abstractmethod
def protected_services(self) -> Sequence[str]:
"""
Return the list of protected (intrinsic) service names to exclude from public mapping.
"""
raise NotImplementedError

@abstractmethod
async def apply_network(self, cluster_info: ClusterInfo) -> None:
"""
Expand Down
41 changes: 35 additions & 6 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from aiodocker.exceptions import DockerError
from aiomonitor.task import preserve_termination_log
from async_timeout import timeout
from typing_extensions import override

from ai.backend.common.cgroup import get_cgroup_mount_point
from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef
Expand All @@ -61,6 +62,7 @@
KernelId,
MountPermission,
MountTypes,
ResourceGroupType,
ResourceSlot,
Sentinel,
ServicePort,
Expand Down Expand Up @@ -687,6 +689,21 @@ def _populate_ssh_config():
)
return kernel_obj

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
rgtype: ResourceGroupType = self.local_config["agent"]["scaling-group-type"]
match rgtype:
case ResourceGroupType.COMPUTE:
return ()
case ResourceGroupType.STORAGE:
return ("ttyd",)

async def start_container(
self,
kernel_obj: AbstractKernel,
Expand All @@ -703,13 +720,13 @@ async def start_container(
# PHASE 4: Run!
container_bind_host = self.local_config["container"]["bind-host"]
advertised_kernel_host = self.local_config["container"].get("advertised-host")
repl_ports = [2000, 2001]
if len(service_ports) + len(repl_ports) > len(self.port_pool):
if len(service_ports) + len(self.repl_ports) > len(self.port_pool):
raise RuntimeError(
f"Container ports are not sufficiently available. (remaining ports: {self.port_pool})"
)
exposed_ports = repl_ports
host_ports = [self.port_pool.pop() for _ in repl_ports]
exposed_ports = [*self.repl_ports]
host_ports = [self.port_pool.pop() for _ in self.repl_ports]
host_ips = []
for sport in service_ports:
exposed_ports.extend(sport["container_ports"])
if (
Expand All @@ -725,6 +742,18 @@ async def start_container(
else:
hport = self.port_pool.pop()
host_ports.append(hport)
protected_service_ports: set[int] = set()
for sport in service_ports:
if sport["name"] in self.protected_services:
protected_service_ports.update(sport["container_ports"])
for eport in exposed_ports:
if eport in self.repl_ports: # always protected
host_ips.append("127.0.0.1")
elif eport in protected_service_ports: # check if protected by resource group type
host_ips.append("127.0.0.1")
else:
host_ips.append(str(container_bind_host))
assert len(host_ips) == len(host_ports) == len(exposed_ports)

container_log_size = self.local_config["agent"]["container-logs"]["max-length"]
container_log_file_count = 5
Expand Down Expand Up @@ -752,8 +781,8 @@ async def start_container(
"HostConfig": {
"Init": True,
"PortBindings": {
f"{eport}/tcp": [{"HostPort": str(hport), "HostIp": str(container_bind_host)}]
for eport, hport in zip(exposed_ports, host_ports)
f"{eport}/tcp": [{"HostPort": str(hport), "HostIp": hip}]
for eport, hport, hip in zip(exposed_ports, host_ports, host_ips)
},
"PublishAllPorts": False, # we manage port mapping manually!
"CapAdd": [
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def create_code_runner(
self.kernel_id,
self.session_id,
event_producer,
kernel_host=self.data["kernel_host"],
kernel_host="127.0.0.1", # repl ports are always bound to 127.0.0.1
repl_in_port=self.data["repl_in_port"],
repl_out_port=self.data["repl_out_port"],
exec_timeout=0,
Expand Down
12 changes: 12 additions & 0 deletions src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Tuple,
)

from typing_extensions import override

from ai.backend.common.config import read_from_file
from ai.backend.common.docker import ImageRef
from ai.backend.common.events import EventProducer
Expand Down Expand Up @@ -108,6 +110,16 @@ async def prepare_scratch(self) -> None:
async def get_intrinsic_mounts(self) -> Sequence[Mount]:
return []

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
return ()

async def apply_network(self, cluster_info: ClusterInfo) -> None:
return

Expand Down
14 changes: 13 additions & 1 deletion src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import pkg_resources
from kubernetes_asyncio import client as kube_client
from kubernetes_asyncio import config as kube_config
from typing_extensions import override

from ai.backend.common.asyncio import current_loop
from ai.backend.common.docker import ImageRef
Expand Down Expand Up @@ -303,6 +304,17 @@ async def get_intrinsic_mounts(self) -> Sequence[Mount]:

return mounts

@property
@override
def repl_ports(self) -> Sequence[int]:
return (2000, 2001)

@property
@override
def protected_services(self) -> Sequence[str]:
# NOTE: Currently K8s does not support binding container ports to 127.0.0.1 when using NodePort.
return ()

async def apply_network(self, cluster_info: ClusterInfo) -> None:
pass

Expand Down Expand Up @@ -655,7 +667,7 @@ async def start_container(
await kube_config.load_kube_config()
core_api = kube_client.CoreV1Api()
apps_api = kube_client.AppsV1Api()
exposed_ports = [2000, 2001]
exposed_ports = [*self.repl_ports]
for sport in service_ports:
exposed_ports.extend(sport["container_ports"])

Expand Down

0 comments on commit aa8b69d

Please sign in to comment.