From 006155fbb84236d5287ce5d9f5df392b0b165b1d Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Thu, 13 Feb 2025 08:02:49 +0000 Subject: [PATCH] Support services with head node setup * Add proxy chain support to `SSHTunnel` * Add optional head proxy fields to `Replica` * Extend gateway API to support head proxy fields Closes: https://github.com/dstackai/dstack/issues/2010 --- docs/docs/concepts/fleets.md | 2 - .../_internal/core/services/ssh/tunnel.py | 86 ++++++++++--------- .../proxy/gateway/routers/registry.py | 2 + .../proxy/gateway/schemas/registry.py | 2 + .../proxy/gateway/services/registry.py | 4 + src/dstack/_internal/proxy/lib/models.py | 3 + .../proxy/lib/services/service_connection.py | 9 +- .../background/tasks/process_running_jobs.py | 24 +++++- .../server/services/gateways/client.py | 10 ++- .../_internal/server/services/proxy/repo.py | 20 ++++- .../_internal/server/services/runner/ssh.py | 7 +- .../server/services/services/__init__.py | 10 ++- .../core/services/ssh/test_tunnel.py | 60 ++++++++++++- .../proxy/gateway/routers/test_registry.py | 37 ++++++-- 14 files changed, 214 insertions(+), 62 deletions(-) diff --git a/docs/docs/concepts/fleets.md b/docs/docs/concepts/fleets.md index 15c9bf78a..d0873c209 100644 --- a/docs/docs/concepts/fleets.md +++ b/docs/docs/concepts/fleets.md @@ -339,8 +339,6 @@ add a front node key (`~/.ssh/head_node_key`) to an SSH agent or configure a key where `Host` must match `ssh_config.proxy_jump.hostname` or `ssh_config.hosts[n].proxy_jump.hostname` if you configure head nodes on a per-worker basis. -> Currently, [services](services.md) do not work on instances with a head node setup. - !!! info "Reference" For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md). diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 03c9894b4..e00307b11 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -69,14 +69,16 @@ def __init__( options: Dict[str, str] = SSH_DEFAULT_OPTIONS, ssh_config_path: Union[PathLike, Literal["none"]] = "none", port: Optional[int] = None, - ssh_proxy: Optional[SSHConnectionParams] = None, - ssh_proxy_identity: Optional[FilePathOrContent] = None, + ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (), ): """ :param forwarded_sockets: Connections to the specified local sockets will be forwarded to their corresponding remote sockets :param reverse_forwarded_sockets: Connections to the specified remote sockets will be forwarded to their corresponding local sockets + :param ssh_proxies: pairs of SSH connections params and optional identities, + in order from outer to inner. If an identity is `None`, the `identity` param + is used instead. """ self.destination = destination self.forwarded_sockets = list(forwarded_sockets) @@ -84,29 +86,21 @@ def __init__( self.options = options self.port = port self.ssh_config_path = normalize_path(ssh_config_path) - self.ssh_proxy = ssh_proxy temp_dir = tempfile.TemporaryDirectory() self.temp_dir = temp_dir if control_sock_path is None: control_sock_path = os.path.join(temp_dir.name, "control.sock") self.control_sock_path = normalize_path(control_sock_path) - if isinstance(identity, FilePath): - identity_path = identity.path - else: - identity_path = os.path.join(temp_dir.name, "identity") - with open( - identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w" - ) as f: - f.write(identity.content) self.identity_path = normalize_path(self._get_identity_path(identity, "identity")) - if ssh_proxy_identity is not None: - self.ssh_proxy_identity_path = normalize_path( - self._get_identity_path(ssh_proxy_identity, "proxy_identity") - ) - elif ssh_proxy is not None: - self.ssh_proxy_identity_path = self.identity_path - else: - self.ssh_proxy_identity_path = None + self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = [] + for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies): + if proxy_identity is None: + proxy_identity_path = self.identity_path + else: + proxy_identity_path = self._get_identity_path( + proxy_identity, f"proxy_identity_{proxy_index}" + ) + self.ssh_proxies.append((proxy_params, proxy_identity_path)) self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log")) self.ssh_client_info = get_ssh_client_info() self.ssh_exec_path = str(self.ssh_client_info.path) @@ -151,8 +145,8 @@ def open_command(self) -> List[str]: command += ["-p", str(self.port)] for k, v in self.options.items(): command += ["-o", f"{k}={v}"] - if proxy_command := self.proxy_command(): - command += ["-o", "ProxyCommand=" + shlex.join(proxy_command)] + if proxy_command := self._get_proxy_command(): + command += ["-o", proxy_command] for socket_pair in self.forwarded_sockets: command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"] for socket_pair in self.reverse_forwarded_sockets: @@ -169,24 +163,6 @@ def check_command(self) -> List[str]: def exec_command(self) -> List[str]: return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination] - def proxy_command(self) -> Optional[List[str]]: - if self.ssh_proxy is None: - return None - return [ - self.ssh_exec_path, - "-i", - self.ssh_proxy_identity_path, - "-W", - "%h:%p", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", - "-p", - str(self.ssh_proxy.port), - f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}", - ] - def open(self) -> None: # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe @@ -260,6 +236,38 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + def _get_proxy_command(self) -> Optional[str]: + proxy_command: Optional[str] = None + for params, identity_path in self.ssh_proxies: + proxy_command = self._build_proxy_command(params, identity_path, proxy_command) + return proxy_command + + def _build_proxy_command( + self, + params: SSHConnectionParams, + identity_path: PathLike, + prev_proxy_command: Optional[str], + ) -> Optional[str]: + command = [ + self.ssh_exec_path, + "-i", + identity_path, + "-W", + "%h:%p", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + ] + if prev_proxy_command is not None: + command += ["-o", prev_proxy_command.replace("%", "%%")] + command += [ + "-p", + str(params.port), + f"{params.username}@{params.hostname}", + ] + return "ProxyCommand=" + shlex.join(command) + def _read_log_file(self) -> bytes: with open(self.log_path, "rb") as f: return f.read() diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index c8e2c55b7..6a4032921 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -76,6 +76,8 @@ async def register_replica( ssh_destination=body.ssh_host, ssh_port=body.ssh_port, ssh_proxy=body.ssh_proxy, + ssh_head_proxy=body.ssh_head_proxy, + ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 7150039c6..6a841c7e9 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel): ssh_host: str ssh_port: int ssh_proxy: Optional[SSHConnectionParams] + ssh_head_proxy: Optional[SSHConnectionParams] + ssh_head_proxy_private_key: Optional[str] class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index fa935a147..191824c4f 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -123,6 +123,8 @@ async def register_replica( ssh_destination: str, ssh_port: int, ssh_proxy: Optional[SSHConnectionParams], + ssh_head_proxy: Optional[SSHConnectionParams], + ssh_head_proxy_private_key: Optional[str], repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, @@ -133,6 +135,8 @@ async def register_replica( ssh_destination=ssh_destination, ssh_port=ssh_port, ssh_proxy=ssh_proxy, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, ) async with lock: diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 8079a0868..efc093f9c 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -23,6 +23,9 @@ class Replica(ImmutableModel): ssh_destination: str ssh_port: int ssh_proxy: Optional[SSHConnectionParams] + # Optional outer proxy, a head node/bastion + ssh_head_proxy: Optional[SSHConnectionParams] = None + ssh_head_proxy_private_key: Optional[str] = None class Service(ImmutableModel): diff --git a/src/dstack/_internal/proxy/lib/services/service_connection.py b/src/dstack/_internal/proxy/lib/services/service_connection.py index a6d14ad13..6c1cc49af 100644 --- a/src/dstack/_internal/proxy/lib/services/service_connection.py +++ b/src/dstack/_internal/proxy/lib/services/service_connection.py @@ -18,6 +18,7 @@ from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.proxy.lib.models import Project, Replica, Service from dstack._internal.proxy.lib.repo import BaseProxyRepo +from dstack._internal.utils.common import get_or_error from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import FileContent @@ -45,10 +46,16 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None os.chmod(self._temp_dir.name, 0o755) options["StreamLocalBindMask"] = "0111" self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute() + ssh_proxies = [] + if replica.ssh_head_proxy is not None: + ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key) + ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key))) + if replica.ssh_proxy is not None: + ssh_proxies.append((replica.ssh_proxy, None)) self._tunnel = SSHTunnel( destination=replica.ssh_destination, port=replica.ssh_port, - ssh_proxy=replica.ssh_proxy, + ssh_proxies=ssh_proxies, identity=FileContent(project.ssh_private_key), forwarded_sockets=[ SocketPair( diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 9a303f9cf..510c26509 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -11,7 +11,11 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode, RegistryAuth, is_core_model_instance from dstack._internal.core.models.configurations import DevEnvironmentConfiguration -from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo +from dstack._internal.core.models.instances import ( + InstanceStatus, + RemoteConnectionInfo, + SSHConnectionParams, +) from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.runs import ( ClusterInfo, @@ -308,8 +312,24 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): and job_model.job_num == 0 # gateway connects only to the first node and run.run_spec.configuration.type == "service" ): + ssh_head_proxy: Optional[SSHConnectionParams] = None + ssh_head_proxy_private_key: Optional[str] = None + instance = common_utils.get_or_error(job_model.instance) + if instance.remote_connection_info is not None: + rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) + if rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys) + ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private try: - await services.register_replica(session, run_model.gateway_id, run, job_model) + await services.register_replica( + session, + run_model.gateway_id, + run, + job_model, + ssh_head_proxy, + ssh_head_proxy_private_key, + ) except GatewayError as e: logger.warning( "%s: failed to register service replica: %s, age=%s", diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index 109c7bf6d..fd8bfe0dd 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -74,10 +74,18 @@ async def unregister_service(self, project: str, run_name: str): resp.raise_for_status() self.is_server_ready = True - async def register_replica(self, run: Run, job_submission: JobSubmission): + async def register_replica( + self, + run: Run, + job_submission: JobSubmission, + ssh_head_proxy: Optional[SSHConnectionParams], + ssh_head_proxy_private_key: Optional[str], + ): payload = { "job_id": job_submission.id.hex, "app_port": run.run_spec.configuration.port.container_port, + "ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None, + "ssh_head_proxy_private_key": ssh_head_proxy_private_key, } jpd = job_submission.job_provisioning_data if not jpd.dockerized: diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 0f63cb76f..eca431935 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -9,7 +9,7 @@ from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.common import is_core_model_instance from dstack._internal.core.models.configurations import ServiceConfiguration -from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams from dstack._internal.core.models.runs import ( JobProvisioningData, JobStatus, @@ -30,6 +30,7 @@ from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.server.models import JobModel, ProjectModel, RunModel from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE +from dstack._internal.utils.common import get_or_error class ServerProxyRepo(BaseProxyRepo): @@ -53,9 +54,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic JobModel.status == JobStatus.RUNNING, JobModel.job_num == 0, ) - .options(joinedload(JobModel.run)) + .options( + joinedload(JobModel.run), + joinedload(JobModel.instance), + ) ) - jobs = res.scalars().all() + jobs = res.unique().scalars().all() if not len(jobs): return None run = jobs[0].run @@ -83,12 +87,22 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic username=jpd.username, port=jpd.ssh_port, ) + ssh_head_proxy: Optional[SSHConnectionParams] = None + ssh_head_proxy_private_key: Optional[str] = None + instance = get_or_error(job.instance) + if instance.remote_connection_info is not None: + rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) + if rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private replica = Replica( id=job.id.hex, app_port=run_spec.configuration.port.container_port, ssh_destination=ssh_destination, ssh_port=ssh_port, ssh_proxy=ssh_proxy, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, ) replicas.append(replica) return Service( diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 6d3c11359..95d5f1ab5 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -75,6 +75,10 @@ def wrapper( else: proxy_identity = None + ssh_proxies = [] + if job_provisioning_data.ssh_proxy is not None: + ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity)) + for attempt in range(retries): last = attempt == retries - 1 # remote_host:local mapping @@ -91,8 +95,7 @@ def wrapper( port=job_provisioning_data.ssh_port, forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map), identity=identity, - ssh_proxy=job_provisioning_data.ssh_proxy, - ssh_proxy_identity=proxy_identity, + ssh_proxies=ssh_proxies, ): return func(runner_ports_map, *args, **kwargs) except SSHError: diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 8235d0ea1..352efb733 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -21,6 +21,7 @@ from dstack._internal.core.models.common import is_core_model_instance from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus +from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -155,7 +156,12 @@ def get_service_spec( async def register_replica( - session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel + session: AsyncSession, + gateway_id: Optional[uuid.UUID], + run: Run, + job_model: JobModel, + ssh_head_proxy: Optional[SSHConnectionParams], + ssh_head_proxy_private_key: Optional[str], ): if gateway_id is None: # in-server proxy return @@ -167,6 +173,8 @@ async def register_replica( await client.register_replica( run=run, job_submission=job_submission, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, ) logger.info("%s: replica is registered for service %s", fmt(job_model), run.id.hex) except (httpx.RequestError, SSHError) as e: diff --git a/src/tests/_internal/core/services/ssh/test_tunnel.py b/src/tests/_internal/core/services/ssh/test_tunnel.py index 5bd5e1d8b..b97c86a1e 100644 --- a/src/tests/_internal/core/services/ssh/test_tunnel.py +++ b/src/tests/_internal/core/services/ssh/test_tunnel.py @@ -32,7 +32,9 @@ def sample_tunnel_with_all_params(self, ssh_client_info: SSHClientInfo) -> SSHTu options={"Opt1": "opt1"}, ssh_config_path="/home/user/.ssh/config", port=10022, - ssh_proxy=SSHConnectionParams(hostname="proxy", username="test", port=10022), + ssh_proxies=[ + (SSHConnectionParams(hostname="proxy", username="test", port=10022), None) + ], forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], reverse_forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], ) @@ -105,13 +107,18 @@ def test_open_command_with_temp_control_socket(self) -> None: ) @pytest.mark.usefixtures("ssh_client_info") - def test_open_command_with_proxy(self) -> None: + def test_open_command_with_one_proxy(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", identity=FilePath("/home/user/.ssh/id_rsa"), control_sock_path="/tmp/control.sock", options={}, - ssh_proxy=SSHConnectionParams(hostname="proxy", username="test", port=10022), + ssh_proxies=[ + ( + SSHConnectionParams(hostname="proxy", username="test", port=10022), + FilePath("/home/user/.ssh/proxy"), + ) + ], ) assert tunnel.open_command() == [ "/usr/bin/ssh", @@ -130,12 +137,57 @@ def test_open_command_with_proxy(self) -> None: "-o", ( "ProxyCommand=" - "/usr/bin/ssh -i /home/user/.ssh/id_rsa -W %h:%p -o StrictHostKeyChecking=no" + "/usr/bin/ssh -i /home/user/.ssh/proxy -W %h:%p -o StrictHostKeyChecking=no" " -o UserKnownHostsFile=/dev/null -p 10022 test@proxy" ), "ubuntu@my-server", ] + @pytest.mark.usefixtures("ssh_client_info") + def test_open_command_with_two_proxies(self) -> None: + tunnel = SSHTunnel( + destination="ubuntu@my-server", + identity=FilePath("/home/user/.ssh/id_rsa"), + control_sock_path="/tmp/control.sock", + options={}, + ssh_proxies=[ + ( + SSHConnectionParams(hostname="proxy1", username="test1", port=10022), + None, + ), + ( + SSHConnectionParams(hostname="proxy2", username="test2", port=20022), + FilePath("/home/user/.ssh/proxy2"), + ), + ], + ) + assert tunnel.open_command() == [ + "/usr/bin/ssh", + "-F", + "none", + "-i", + "/home/user/.ssh/id_rsa", + "-E", + f"{tunnel.temp_dir.name}/tunnel.log", + "-N", + "-f", + "-o", + "ControlMaster=auto", + "-S", + "/tmp/control.sock", + "-o", + ( + "ProxyCommand=" + "/usr/bin/ssh -i /home/user/.ssh/proxy2 -W %h:%p -o StrictHostKeyChecking=no" + " -o UserKnownHostsFile=/dev/null" + " -o 'ProxyCommand=/usr/bin/ssh -i /home/user/.ssh/id_rsa -W %%h:%%p" + " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null" + " -p 10022 test1@proxy1'" + " -p 20022 test2@proxy2" + ), + "ubuntu@my-server", + ] + @pytest.mark.usefixtures("ssh_client_info") def test_open_command_with_forwarding(self) -> None: tunnel = SSHTunnel( diff --git a/src/tests/_internal/proxy/gateway/routers/test_registry.py b/src/tests/_internal/proxy/gateway/routers/test_registry.py index 706c9ec63..17ba28a08 100644 --- a/src/tests/_internal/proxy/gateway/routers/test_registry.py +++ b/src/tests/_internal/proxy/gateway/routers/test_registry.py @@ -48,6 +48,24 @@ def register_replica_payload(job_id: str = "xxx-xxx") -> dict: "ssh_host": "host.test", "ssh_port": 22, "ssh_proxy": None, + "ssh_head_proxy": None, + "ssh_head_proxy_private_key": None, + } + + +def register_replica_payload_with_head_proxy(job_id: str = "xxx-xxx") -> dict: + return { + "job_id": job_id, + "app_port": 8888, + "ssh_host": "host.test", + "ssh_port": 22, + "ssh_proxy": None, + "ssh_head_proxy": { + "hostname": "proxy.test", + "username": "debian", + "port": 222, + }, + "ssh_head_proxy_private_key": "private-key", } @@ -190,13 +208,18 @@ async def test_register(self, tmp_path: Path, system_mocks: Mocks) -> None: conf = (tmp_path / "443-test-run.gtw.test.conf").read_text() assert "upstream test-run" not in conf # register 2 replicas - for job_id in ("xxx-xxx", "yyy-yyy"): - resp = await client.post( - "/api/registry/test-proj/services/test-run/replicas/register", - json=register_replica_payload(job_id=job_id), - ) - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + resp = await client.post( + "/api/registry/test-proj/services/test-run/replicas/register", + json=register_replica_payload(job_id="xxx-xxx"), + ) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + resp = await client.post( + "/api/registry/test-proj/services/test-run/replicas/register", + json=register_replica_payload_with_head_proxy(job_id="yyy-yyy"), + ) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} conf = (tmp_path / "443-test-run.gtw.test.conf").read_text() assert "upstream test-run" in conf assert (m1 := re.search(r"server unix:/(.+)/replica.sock; # replica xxx-xxx", conf))