Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support auth: true services in dstack-proxy #1885

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/dstack/_internal/proxy/deps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional

from fastapi import Depends, Request
from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing_extensions import Annotated

from dstack._internal.proxy.repos.base import BaseProxyRepo
Expand Down Expand Up @@ -34,3 +35,41 @@ async def get_proxy_repo(
) -> AsyncGenerator[BaseProxyRepo, None]:
async for repo in injector.get_repo():
yield repo


class ProxyAuthContext:
def __init__(self, project_name: str, token: Optional[str], repo: BaseProxyRepo):
self._project_name = project_name
self._token = token
self._repo = repo

async def enforce(self) -> None:
if self._token is None or not await self._repo.is_project_member(
self._project_name, self._token
):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
f"Unauthenticated or unauthorized to access project {self._project_name}",
)


class ProxyAuth:
def __init__(self, auto_enforce: bool):
self._auto_enforce = auto_enforce

async def __call__(
self,
project_name: str,
token: Annotated[
Optional[HTTPAuthorizationCredentials], Security(HTTPBearer(auto_error=False))
],
repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)],
) -> ProxyAuthContext:
context = ProxyAuthContext(
project_name=project_name,
token=token.credentials if token is not None else None,
repo=repo,
)
if self._auto_enforce:
await context.enforce()
return context
4 changes: 4 additions & 0 deletions src/dstack/_internal/proxy/repos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,7 @@ async def get_project(self, name: str) -> Optional[Project]:
@abstractmethod
async def add_project(self, project: Project) -> None:
pass

@abstractmethod
async def is_project_member(self, project_name: str, token: str) -> bool:
pass
5 changes: 5 additions & 0 deletions src/dstack/_internal/proxy/repos/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ async def get_project(self, name: str) -> Optional[Project]:

async def add_project(self, project: Project) -> None:
self.projects[project.name] = project

async def is_project_member(self, project_name: str, token: str) -> bool:
# TODO(#1595): when this class is used for gateways,
# implement a network request to dstack-server to check authorization
raise NotImplementedError
5 changes: 3 additions & 2 deletions src/dstack/_internal/proxy/routers/service_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi.responses import RedirectResponse, Response
from typing_extensions import Annotated

from dstack._internal.proxy.deps import get_proxy_repo
from dstack._internal.proxy.deps import ProxyAuth, ProxyAuthContext, get_proxy_repo
from dstack._internal.proxy.repos.base import BaseProxyRepo
from dstack._internal.proxy.services import service_proxy

Expand All @@ -27,9 +27,10 @@ async def service_reverse_proxy(
run_name: str,
path: str,
request: Request,
auth: Annotated[ProxyAuthContext, Depends(ProxyAuth(auto_enforce=False))],
repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)],
) -> Response:
return await service_proxy.proxy(project_name, run_name, path, request, repo)
return await service_proxy.proxy(project_name, run_name, path, request, auth, repo)


# TODO(#1595): support websockets
8 changes: 3 additions & 5 deletions src/dstack/_internal/proxy/services/service_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import httpx
from starlette.requests import ClientDisconnect

from dstack._internal.proxy.deps import ProxyAuthContext
from dstack._internal.proxy.repos.base import BaseProxyRepo, Replica, Service
from dstack._internal.proxy.services.service_connection import service_replica_connection_pool
from dstack._internal.utils.logging import get_logger
Expand All @@ -17,6 +18,7 @@ async def proxy(
run_name: str,
path: str,
request: fastapi.Request,
auth: ProxyAuthContext,
repo: BaseProxyRepo,
) -> fastapi.responses.Response:
if "Upgrade" in request.headers:
Expand All @@ -31,11 +33,7 @@ async def proxy(
f"Service {project_name}/{run_name} not found",
)
if service.auth:
# TODO(#1595): support auth
raise fastapi.HTTPException(
fastapi.status.HTTP_400_BAD_REQUEST,
f"Service {project_name}/{run_name} requires auth, which is not yet supported",
)
await auth.enforce()

replica = random.choice(service.replicas)
client = await get_replica_client(project_name, service, replica, repo)
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions src/dstack/_internal/proxy/testing/repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Container, Dict, Optional

from dstack._internal.proxy.repos.memory import InMemoryProxyRepo


class ProxyTestRepo(InMemoryProxyRepo):
def __init__(self, project_to_tokens: Optional[Dict[str, Container[str]]] = None) -> None:
super().__init__()
self._project_to_tokens = project_to_tokens or {}

async def is_project_member(self, project_name: str, token: str) -> bool:
return token in self._project_to_tokens.get(project_name, set())
40 changes: 27 additions & 13 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple

from fastapi import Depends, Security
from fastapi import Depends, HTTPException, Security
from fastapi.security import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -96,15 +96,29 @@ async def __call__(
project_name: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
if project is None:
raise error_not_found()
if user.global_role == GlobalRole.ADMIN:
return user, project
project_role = get_user_project_role(user=user, project=project)
if project_role is not None:
return user, project
raise error_forbidden()
return await get_project_member(session, project_name, token.credentials)


async def get_project_member(
session: AsyncSession, project_name: str, token: str
) -> Tuple[UserModel, ProjectModel]:
user = await log_in_with_token(session=session, token=token)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
if project is None:
raise error_not_found()
if user.global_role == GlobalRole.ADMIN:
return user, project
project_role = get_user_project_role(user=user, project=project)
if project_role is not None:
return user, project
raise error_forbidden()


async def is_project_member(session: AsyncSession, project_name: str, token: str) -> bool:
try:
await get_project_member(session, project_name, token)
return True
except HTTPException:
return False
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dstack._internal.core.models.runs import JobProvisioningData, JobStatus, RunSpec
from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
from dstack._internal.server.security.permissions import is_project_member


class DBProxyRepo(BaseProxyRepo):
Expand Down Expand Up @@ -91,3 +92,6 @@ async def get_project(self, name: str) -> Optional[Project]:

async def add_project(self, project: Project) -> None:
pass

async def is_project_member(self, project_name: str, token: str) -> bool:
return await is_project_member(self.session, project_name, token)
41 changes: 29 additions & 12 deletions src/tests/_internal/proxy/routers/test_service_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from dstack._internal.proxy.deps import BaseProxyDependencyInjector
from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service
from dstack._internal.proxy.repos.memory import InMemoryProxyRepo
from dstack._internal.proxy.routers.service_proxy import router
from dstack._internal.proxy.testing.repo import ProxyTestRepo


def make_app(repo: BaseProxyRepo) -> FastAPI:
Expand Down Expand Up @@ -36,11 +36,11 @@ def make_project(name: str) -> Project:
return Project(name=name, ssh_private_key="secret")


def make_service(run_name: str) -> Service:
def make_service(run_name: str, auth: bool = False) -> Service:
return Service(
id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
run_name=run_name,
auth=False,
auth=auth,
app_port=80,
replicas=[
Replica(
Expand Down Expand Up @@ -71,7 +71,7 @@ def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]:
@pytest.mark.parametrize("method", ["get", "post", "put", "patch", "delete"])
async def test_proxy(mock_replica_client_httpbin, method: str) -> None:
methods_without_body = "get", "delete"
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -96,7 +96,7 @@ async def test_proxy(mock_replica_client_httpbin, method: str) -> None:

@pytest.mark.asyncio
async def test_proxy_method_head(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -111,7 +111,7 @@ async def test_proxy_method_head(mock_replica_client_httpbin) -> None:

@pytest.mark.asyncio
async def test_proxy_method_options(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -124,7 +124,7 @@ async def test_proxy_method_options(mock_replica_client_httpbin) -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("code", [204, 304, 418, 503])
async def test_proxy_status_codes(mock_replica_client_httpbin, code: int) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -134,7 +134,7 @@ async def test_proxy_status_codes(mock_replica_client_httpbin, code: int) -> Non

@pytest.mark.asyncio
async def test_proxy_not_leaks_cookies(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
app = make_app(repo)
Expand All @@ -152,7 +152,7 @@ async def test_proxy_not_leaks_cookies(mock_replica_client_httpbin) -> None:

@pytest.mark.asyncio
async def test_proxy_gateway_timeout(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -164,7 +164,7 @@ async def test_proxy_gateway_timeout(mock_replica_client_httpbin) -> None:

@pytest.mark.asyncio
async def test_proxy_run_not_found(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("test-run"))
_, client = make_app_client(repo)
Expand All @@ -175,15 +175,15 @@ async def test_proxy_run_not_found(mock_replica_client_httpbin) -> None:

@pytest.mark.asyncio
async def test_proxy_project_not_found(mock_replica_client_httpbin) -> None:
_, client = make_app_client(InMemoryProxyRepo())
_, client = make_app_client(ProxyTestRepo())
resp = await client.get("http://test-host/services/unknown/test-run/")
assert resp.status_code == 404
assert resp.json()["detail"] == "Service unknown/test-run not found"


@pytest.mark.asyncio
async def test_redirect_to_service_root(mock_replica_client_httpbin) -> None:
repo = InMemoryProxyRepo()
repo = ProxyTestRepo()
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
_, client = make_app_client(repo)
Expand All @@ -194,3 +194,20 @@ async def test_redirect_to_service_root(mock_replica_client_httpbin) -> None:
resp = await client.get(url, follow_redirects=True)
assert resp.status_code == 200
assert resp.request.url == url + "/"


@pytest.mark.asyncio
@pytest.mark.parametrize(
("token", "status"), [("correct-token", 200), ("incorrect-token", 403), ("", 403), (None, 403)]
)
async def test_auth(mock_replica_client_httpbin, token: str, status: int) -> None:
repo = ProxyTestRepo(project_to_tokens={"test-proj": {"correct-token"}})
await repo.add_project(make_project("test-proj"))
await repo.add_service(project_name="test-proj", service=make_service("httpbin", auth=True))
_, client = make_app_client(repo)
url = "http://test-host/services/test-proj/httpbin/"
headers = {}
if token is not None:
headers["Authorization"] = f"Bearer {token}"
resp = await client.get(url, headers=headers)
assert resp.status_code == status