Skip to content

Commit

Permalink
Support services with head node setup
Browse files Browse the repository at this point in the history
* Add proxy chain support to `SSHTunnel`
* Add optional head proxy fields to `Replica`
* Extend gateway API to support head proxy fields

Closes: #2010
  • Loading branch information
un-def committed Feb 13, 2025
1 parent 5cc97ae commit 006155f
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 62 deletions.
2 changes: 0 additions & 2 deletions docs/docs/concepts/fleets.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,6 @@ add a front node key (`~/.ssh/head_node_key`) to an SSH agent or configure a key
where `Host` must match `ssh_config.proxy_jump.hostname` or `ssh_config.hosts[n].proxy_jump.hostname` if you configure head nodes
on a per-worker basis.

> Currently, [services](services.md) do not work on instances with a head node setup.

!!! info "Reference"
For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md).

Expand Down
86 changes: 47 additions & 39 deletions src/dstack/_internal/core/services/ssh/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,44 +69,38 @@ def __init__(
options: Dict[str, str] = SSH_DEFAULT_OPTIONS,
ssh_config_path: Union[PathLike, Literal["none"]] = "none",
port: Optional[int] = None,
ssh_proxy: Optional[SSHConnectionParams] = None,
ssh_proxy_identity: Optional[FilePathOrContent] = None,
ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
):
"""
:param forwarded_sockets: Connections to the specified local sockets will be
forwarded to their corresponding remote sockets
:param reverse_forwarded_sockets: Connections to the specified remote sockets
will be forwarded to their corresponding local sockets
:param ssh_proxies: pairs of SSH connections params and optional identities,
in order from outer to inner. If an identity is `None`, the `identity` param
is used instead.
"""
self.destination = destination
self.forwarded_sockets = list(forwarded_sockets)
self.reverse_forwarded_sockets = list(reverse_forwarded_sockets)
self.options = options
self.port = port
self.ssh_config_path = normalize_path(ssh_config_path)
self.ssh_proxy = ssh_proxy
temp_dir = tempfile.TemporaryDirectory()
self.temp_dir = temp_dir
if control_sock_path is None:
control_sock_path = os.path.join(temp_dir.name, "control.sock")
self.control_sock_path = normalize_path(control_sock_path)
if isinstance(identity, FilePath):
identity_path = identity.path
else:
identity_path = os.path.join(temp_dir.name, "identity")
with open(
identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
) as f:
f.write(identity.content)
self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
if ssh_proxy_identity is not None:
self.ssh_proxy_identity_path = normalize_path(
self._get_identity_path(ssh_proxy_identity, "proxy_identity")
)
elif ssh_proxy is not None:
self.ssh_proxy_identity_path = self.identity_path
else:
self.ssh_proxy_identity_path = None
self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = []
for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies):
if proxy_identity is None:
proxy_identity_path = self.identity_path
else:
proxy_identity_path = self._get_identity_path(
proxy_identity, f"proxy_identity_{proxy_index}"
)
self.ssh_proxies.append((proxy_params, proxy_identity_path))
self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
self.ssh_client_info = get_ssh_client_info()
self.ssh_exec_path = str(self.ssh_client_info.path)
Expand Down Expand Up @@ -151,8 +145,8 @@ def open_command(self) -> List[str]:
command += ["-p", str(self.port)]
for k, v in self.options.items():
command += ["-o", f"{k}={v}"]
if proxy_command := self.proxy_command():
command += ["-o", "ProxyCommand=" + shlex.join(proxy_command)]
if proxy_command := self._get_proxy_command():
command += ["-o", proxy_command]
for socket_pair in self.forwarded_sockets:
command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"]
for socket_pair in self.reverse_forwarded_sockets:
Expand All @@ -169,24 +163,6 @@ def check_command(self) -> List[str]:
def exec_command(self) -> List[str]:
return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination]

def proxy_command(self) -> Optional[List[str]]:
if self.ssh_proxy is None:
return None
return [
self.ssh_exec_path,
"-i",
self.ssh_proxy_identity_path,
"-W",
"%h:%p",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"-p",
str(self.ssh_proxy.port),
f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}",
]

def open(self) -> None:
# We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
# close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
Expand Down Expand Up @@ -260,6 +236,38 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def _get_proxy_command(self) -> Optional[str]:
proxy_command: Optional[str] = None
for params, identity_path in self.ssh_proxies:
proxy_command = self._build_proxy_command(params, identity_path, proxy_command)
return proxy_command

def _build_proxy_command(
self,
params: SSHConnectionParams,
identity_path: PathLike,
prev_proxy_command: Optional[str],
) -> Optional[str]:
command = [
self.ssh_exec_path,
"-i",
identity_path,
"-W",
"%h:%p",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
]
if prev_proxy_command is not None:
command += ["-o", prev_proxy_command.replace("%", "%%")]
command += [
"-p",
str(params.port),
f"{params.username}@{params.hostname}",
]
return "ProxyCommand=" + shlex.join(command)

def _read_log_file(self) -> bytes:
with open(self.log_path, "rb") as f:
return f.read()
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ async def register_replica(
ssh_destination=body.ssh_host,
ssh_port=body.ssh_port,
ssh_proxy=body.ssh_proxy,
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel):
ssh_host: str
ssh_port: int
ssh_proxy: Optional[SSHConnectionParams]
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]


class RegisterEntrypointRequest(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ async def register_replica(
ssh_destination: str,
ssh_port: int,
ssh_proxy: Optional[SSHConnectionParams],
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
Expand All @@ -133,6 +135,8 @@ async def register_replica(
ssh_destination=ssh_destination,
ssh_port=ssh_port,
ssh_proxy=ssh_proxy,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
)

async with lock:
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Replica(ImmutableModel):
ssh_destination: str
ssh_port: int
ssh_proxy: Optional[SSHConnectionParams]
# Optional outer proxy, a head node/bastion
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None


class Service(ImmutableModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dstack._internal.proxy.lib.errors import UnexpectedProxyError
from dstack._internal.proxy.lib.models import Project, Replica, Service
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.utils.common import get_or_error
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import FileContent

Expand Down Expand Up @@ -45,10 +46,16 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None
os.chmod(self._temp_dir.name, 0o755)
options["StreamLocalBindMask"] = "0111"
self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
ssh_proxies = []
if replica.ssh_head_proxy is not None:
ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
if replica.ssh_proxy is not None:
ssh_proxies.append((replica.ssh_proxy, None))
self._tunnel = SSHTunnel(
destination=replica.ssh_destination,
port=replica.ssh_port,
ssh_proxy=replica.ssh_proxy,
ssh_proxies=ssh_proxies,
identity=FileContent(project.ssh_private_key),
forwarded_sockets=[
SocketPair(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import NetworkMode, RegistryAuth, is_core_model_instance
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo
from dstack._internal.core.models.instances import (
InstanceStatus,
RemoteConnectionInfo,
SSHConnectionParams,
)
from dstack._internal.core.models.repos import RemoteRepoCreds
from dstack._internal.core.models.runs import (
ClusterInfo,
Expand Down Expand Up @@ -308,8 +312,24 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
and job_model.job_num == 0 # gateway connects only to the first node
and run.run_spec.configuration.type == "service"
):
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = common_utils.get_or_error(job_model.instance)
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
try:
await services.register_replica(session, run_model.gateway_id, run, job_model)
await services.register_replica(
session,
run_model.gateway_id,
run,
job_model,
ssh_head_proxy,
ssh_head_proxy_private_key,
)
except GatewayError as e:
logger.warning(
"%s: failed to register service replica: %s, age=%s",
Expand Down
10 changes: 9 additions & 1 deletion src/dstack/_internal/server/services/gateways/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,18 @@ async def unregister_service(self, project: str, run_name: str):
resp.raise_for_status()
self.is_server_ready = True

async def register_replica(self, run: Run, job_submission: JobSubmission):
async def register_replica(
self,
run: Run,
job_submission: JobSubmission,
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
payload = {
"job_id": job_submission.id.hex,
"app_port": run.run_spec.configuration.port.container_port,
"ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None,
"ssh_head_proxy_private_key": ssh_head_proxy_private_key,
}
jpd = job_submission.job_provisioning_data
if not jpd.dockerized:
Expand Down
20 changes: 17 additions & 3 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
from dstack._internal.core.models.runs import (
JobProvisioningData,
JobStatus,
Expand All @@ -30,6 +30,7 @@
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
from dstack._internal.utils.common import get_or_error


class ServerProxyRepo(BaseProxyRepo):
Expand All @@ -53,9 +54,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
JobModel.status == JobStatus.RUNNING,
JobModel.job_num == 0,
)
.options(joinedload(JobModel.run))
.options(
joinedload(JobModel.run),
joinedload(JobModel.instance),
)
)
jobs = res.scalars().all()
jobs = res.unique().scalars().all()
if not len(jobs):
return None
run = jobs[0].run
Expand Down Expand Up @@ -83,12 +87,22 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
username=jpd.username,
port=jpd.ssh_port,
)
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
replica = Replica(
id=job.id.hex,
app_port=run_spec.configuration.port.container_port,
ssh_destination=ssh_destination,
ssh_port=ssh_port,
ssh_proxy=ssh_proxy,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
)
replicas.append(replica)
return Service(
Expand Down
7 changes: 5 additions & 2 deletions src/dstack/_internal/server/services/runner/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def wrapper(
else:
proxy_identity = None

ssh_proxies = []
if job_provisioning_data.ssh_proxy is not None:
ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity))

for attempt in range(retries):
last = attempt == retries - 1
# remote_host:local mapping
Expand All @@ -91,8 +95,7 @@ def wrapper(
port=job_provisioning_data.ssh_port,
forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map),
identity=identity,
ssh_proxy=job_provisioning_data.ssh_proxy,
ssh_proxy_identity=proxy_identity,
ssh_proxies=ssh_proxies,
):
return func(runner_ports_map, *args, **kwargs)
except SSHError:
Expand Down
10 changes: 9 additions & 1 deletion src/dstack/_internal/server/services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration
from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec
from dstack._internal.server import settings
from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel
Expand Down Expand Up @@ -155,7 +156,12 @@ def get_service_spec(


async def register_replica(
session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel
session: AsyncSession,
gateway_id: Optional[uuid.UUID],
run: Run,
job_model: JobModel,
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
if gateway_id is None: # in-server proxy
return
Expand All @@ -167,6 +173,8 @@ async def register_replica(
await client.register_replica(
run=run,
job_submission=job_submission,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
)
logger.info("%s: replica is registered for service %s", fmt(job_model), run.id.hex)
except (httpx.RequestError, SSHError) as e:
Expand Down
Loading

0 comments on commit 006155f

Please sign in to comment.