Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
packages=[
"nginx",
"python3.10-venv",
"python3-pip", # Add pip for sglang-router installation
],
snap={"commands": [["install", "--classic", "certbot"]]},
runcmd=[
Expand All @@ -850,6 +851,8 @@ def get_gateway_user_data(authorized_key: str) -> str:
"s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/",
"/etc/nginx/nginx.conf",
],
# Install sglang-router system-wide. Can be conditionally installed in the future.
["pip", "install", "sglang-router"],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())],
],
ssh_authorized_keys=[authorized_key],
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class GatewayConfiguration(CoreModel):
default: Annotated[bool, Field(description="Make the gateway default")] = False
backend: Annotated[BackendType, Field(description="The gateway backend")]
region: Annotated[str, Field(description="The gateway region")]
router: Annotated[Optional[str], Field(description="The router type, e.g. `sglang`")] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m;

{% if replicas %}
upstream {{ domain }}.upstream {
{% if router == "sglang" %}
server 127.0.0.1:3000; # SGLang router on the gateway
{% else %}
{% for replica in replicas %}
server unix:{{ replica.socket }}; # replica {{ replica.id }}
{% endfor %}
{% endif %}
}
{% else %}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{% for replica in replicas %}
# Worker {{ loop.index }}
upstream sglang_worker_{{ loop.index }}_upstream {
server unix:{{ replica.socket }};
}

server {
listen 127.0.0.1:{{ 10000 + loop.index }};
access_log off; # disable access logs for this internal endpoint

proxy_read_timeout 300s;
proxy_send_timeout 300s;

location / {
proxy_pass http://sglang_worker_{{ loop.index }}_upstream;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Connection "";
proxy_set_header Upgrade $http_upgrade;
}
}
{% endfor %}
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def register_service(
model=body.options.openai.model if body.options.openai is not None else None,
ssh_private_key=body.ssh_private_key,
repo=repo,
router=body.router,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class RegisterServiceRequest(BaseModel):
options: Options
ssh_private_key: str
rate_limits: tuple[RateLimit, ...] = ()
router: Optional[str] = None


class RegisterReplicaRequest(BaseModel):
Expand Down
218 changes: 217 additions & 1 deletion src/dstack/_internal/proxy/gateway/services/nginx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib.resources
import json
import subprocess
import tempfile
import urllib.parse
from asyncio import Lock
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -64,6 +66,7 @@ class ServiceConfig(SiteConfig):
limit_req_zones: list[LimitReqZoneConfig]
locations: list[LocationConfig]
replicas: list[ReplicaConfig]
router: Optional[str] = None


class ModelEntrypointConfig(SiteConfig):
Expand All @@ -81,11 +84,14 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None:
async def register(self, conf: SiteConfig, acme: ACMESettings) -> None:
logger.debug("Registering %s domain %s", conf.type, conf.domain)
conf_name = self.get_config_name(conf.domain)

async with self._lock:
if conf.https:
await run_async(self.run_certbot, conf.domain, acme)
await run_async(self.write_conf, conf.render(), conf_name)
if hasattr(conf, "router") and conf.router == "sglang":
replicas = len(conf.replicas)
await run_async(self.write_sglang_workers_conf, conf)
await run_async(self.start_or_update_sglang_router, replicas)

logger.info("Registered %s domain %s", conf.type, conf.domain)

Expand All @@ -96,6 +102,10 @@ async def unregister(self, domain: str) -> None:
return
async with self._lock:
await run_async(sudo_rm, conf_path)
workers_conf_path = self._conf_dir / f"sglang-workers.{domain}.conf"
if workers_conf_path.exists():
await run_async(sudo_rm, workers_conf_path)
await run_async(self.stop_sglang_router)
await run_async(self.reload)
logger.info("Unregistered domain %s", domain)

Expand All @@ -106,6 +116,197 @@ def reload() -> None:
if r.returncode != 0:
raise UnexpectedProxyError("Failed to reload nginx")

@staticmethod
def start_or_update_sglang_router(replicas: int) -> None:
if not Nginx.is_sglang_router_running():
Nginx.start_sglang_router()
Nginx.update_sglang_router_workers(replicas)

@staticmethod
def is_sglang_router_running() -> bool:
"""Check if sglang router is running and responding to HTTP requests."""
try:
result = subprocess.run(
["curl", "-s", "http://localhost:3000/workers"], capture_output=True, timeout=5
)
return result.returncode == 0
except Exception as e:
logger.error(f"Error checking sglang router status: {e}")
return False

@staticmethod
def start_sglang_router() -> None:
try:
logger.info("Starting sglang-router...")
cmd = [
"python3",
"-m",
"sglang_router.launch_router",
"--host",
"0.0.0.0",
"--port",
"3000",
"--log-level",
"debug",
"--log-dir",
"./router_logs",
]
subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

import time

time.sleep(2)

# Verify router is running
if not Nginx.is_sglang_router_running():
raise Exception("Failed to start sglang router")

logger.info("Sglang router started successfully")

except Exception as e:
logger.error(f"Failed to start sglang-router: {e}")
raise

@staticmethod
def get_sglang_router_workers() -> list[dict]:
try:
result = subprocess.run(
["curl", "-s", "http://localhost:3000/workers"], capture_output=True, timeout=5
)
if result.returncode == 0:
response = json.loads(result.stdout.decode())
return response.get("workers", [])
return []
except Exception as e:
logger.error(f"Error getting sglang router workers: {e}")
return []

@staticmethod
def update_sglang_router_workers(replicas: int) -> None:
"""Update sglang router workers via HTTP API"""
try:
# Get current workers
current_workers = Nginx.get_sglang_router_workers()
current_worker_urls = {worker["url"] for worker in current_workers}

# Calculate target worker URLs
target_worker_urls = {f"http://127.0.0.1:{10000 + i}" for i in range(1, replicas + 1)}

# Workers to add
workers_to_add = target_worker_urls - current_worker_urls
# Workers to remove
workers_to_remove = current_worker_urls - target_worker_urls

if workers_to_add:
logger.info("Sglang router update: adding %d workers", len(workers_to_add))
if workers_to_remove:
logger.info("Sglang router update: removing %d workers", len(workers_to_remove))

# Add workers
for worker_url in sorted(workers_to_add):
success = Nginx.add_sglang_router_worker(worker_url)
if not success:
logger.warning("Failed to add worker %s, continuing with others", worker_url)

# Remove workers
for worker_url in sorted(workers_to_remove):
success = Nginx.remove_sglang_router_worker(worker_url)
if not success:
logger.warning(
"Failed to remove worker %s, continuing with others", worker_url
)

except Exception as e:
logger.error(f"Error updating sglang router workers: {e}")
raise

@staticmethod
def add_sglang_router_worker(worker_url: str) -> bool:
try:
payload = {"url": worker_url, "worker_type": "regular"}
result = subprocess.run(
[
"curl",
"-X",
"POST",
"http://localhost:3000/workers",
"-H",
"Content-Type: application/json",
"-d",
json.dumps(payload),
],
capture_output=True,
timeout=5,
)

if result.returncode == 0:
response = json.loads(result.stdout.decode())
if response.get("status") == "accepted":
logger.info("Added worker %s to sglang router", worker_url)
return True
else:
logger.error("Failed to add worker %s: %s", worker_url, response)
return False
else:
logger.error("Failed to add worker %s: %s", worker_url, result.stderr.decode())
return False
except Exception as e:
logger.error(f"Error adding worker {worker_url}: {e}")
return False

@staticmethod
def remove_sglang_router_worker(worker_url: str) -> bool:
"""Remove a single worker from sglang router"""
try:
# URL encode the worker URL for the DELETE request
encoded_url = urllib.parse.quote(worker_url, safe="")

result = subprocess.run(
["curl", "-X", "DELETE", f"http://localhost:3000/workers/{encoded_url}"],
capture_output=True,
timeout=5,
)

if result.returncode == 0:
response = json.loads(result.stdout.decode())
if response.get("status") == "accepted":
logger.info("Removed worker %s from sglang router", worker_url)
return True
else:
logger.error("Failed to remove worker %s: %s", worker_url, response)
return False
else:
logger.error("Failed to remove worker %s: %s", worker_url, result.stderr.decode())
return False
except Exception as e:
logger.error(f"Error removing worker {worker_url}: {e}")
return False

@staticmethod
def stop_sglang_router() -> None:
try:
result = subprocess.run(
["pgrep", "-f", "sglang::router"], capture_output=True, timeout=5
)
if result.returncode == 0:
logger.info("Stopping sglang-router process...")
subprocess.run(["pkill", "-f", "sglang::router"], timeout=5)
else:
logger.debug("No sglang-router process found to stop")

log_dir = Path("./router_logs")
if log_dir.exists():
logger.debug("Cleaning up router logs...")
import shutil

shutil.rmtree(log_dir, ignore_errors=True)
else:
logger.debug("No router logs directory found to clean up")

except Exception as e:
logger.error(f"Failed to stop sglang-router: {e}")
raise

def write_conf(self, conf: str, conf_name: str) -> None:
"""Update config and reload nginx. Rollback changes on error."""
conf_path = self._conf_dir / conf_name
Expand Down Expand Up @@ -168,6 +369,21 @@ def write_global_conf(self) -> None:
conf = read_package_resource("00-log-format.conf")
self.write_conf(conf, "00-log-format.conf")

def write_sglang_workers_conf(self, conf: SiteConfig) -> None:
workers_config = generate_sglang_workers_config(conf)
workers_conf_name = f"sglang-workers.{conf.domain}.conf"
workers_conf_path = self._conf_dir / workers_conf_name
sudo_write(workers_conf_path, workers_config)
self.reload()


def generate_sglang_workers_config(conf: SiteConfig) -> str:
template = read_package_resource("sglang_workers.jinja2")
return jinja2.Template(template).render(
replicas=conf.replicas,
proxy_port=PROXY_PORT_ON_GATEWAY,
)


def read_package_resource(file: str) -> str:
return (
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def register_service(
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
router: Optional[str] = None,
) -> None:
service = models.Service(
project_name=project_name,
Expand All @@ -54,6 +55,7 @@ async def register_service(
auth=auth,
client_max_body_size=client_max_body_size,
replicas=(),
router=router,
)

async with lock:
Expand Down Expand Up @@ -335,6 +337,7 @@ async def get_nginx_service_config(
limit_req_zones=limit_req_zones,
locations=locations,
replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs
router=service.router,
)


Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Service(ImmutableModel):
client_max_body_size: int # only enforced on gateways
strip_prefix: bool = True # only used in-server
replicas: tuple[Replica, ...]
router: Optional[str] = None

@property
def domain_safe(self) -> str:
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/services/gateways/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async def register_service(
options: dict,
rate_limits: list[RateLimit],
ssh_private_key: str,
router: Optional[str] = None,
):
if "openai" in options:
entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}"
Expand All @@ -59,6 +60,7 @@ async def register_service(
"options": options,
"rate_limits": [limit.dict() for limit in rate_limits],
"ssh_private_key": ssh_private_key,
"router": router,
}
resp = await self._client.post(
self._url(f"/api/registry/{project}/services/register"), json=payload
Expand Down
Loading