diff --git a/src/dstack/_internal/proxy/deps.py b/src/dstack/_internal/proxy/deps.py index 36bd17b14..0d49be8ee 100644 --- a/src/dstack/_internal/proxy/deps.py +++ b/src/dstack/_internal/proxy/deps.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional -from fastapi import Depends, Request +from fastapi import Depends, HTTPException, Request, Security, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from typing_extensions import Annotated from dstack._internal.proxy.repos.base import BaseProxyRepo @@ -34,3 +35,41 @@ async def get_proxy_repo( ) -> AsyncGenerator[BaseProxyRepo, None]: async for repo in injector.get_repo(): yield repo + + +class ProxyAuthContext: + def __init__(self, project_name: str, token: Optional[str], repo: BaseProxyRepo): + self._project_name = project_name + self._token = token + self._repo = repo + + async def enforce(self) -> None: + if self._token is None or not await self._repo.is_project_member( + self._project_name, self._token + ): + raise HTTPException( + status.HTTP_403_FORBIDDEN, + f"Unauthenticated or unauthorized to access project {self._project_name}", + ) + + +class ProxyAuth: + def __init__(self, auto_enforce: bool): + self._auto_enforce = auto_enforce + + async def __call__( + self, + project_name: str, + token: Annotated[ + Optional[HTTPAuthorizationCredentials], Security(HTTPBearer(auto_error=False)) + ], + repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)], + ) -> ProxyAuthContext: + context = ProxyAuthContext( + project_name=project_name, + token=token.credentials if token is not None else None, + repo=repo, + ) + if self._auto_enforce: + await context.enforce() + return context diff --git a/src/dstack/_internal/proxy/repos/base.py b/src/dstack/_internal/proxy/repos/base.py index 2da002f2c..1f40e8a2f 100644 --- a/src/dstack/_internal/proxy/repos/base.py +++ b/src/dstack/_internal/proxy/repos/base.py @@ -42,3 +42,7 @@ async def get_project(self, name: str) -> Optional[Project]: @abstractmethod async def add_project(self, project: Project) -> None: pass + + @abstractmethod + async def is_project_member(self, project_name: str, token: str) -> bool: + pass diff --git a/src/dstack/_internal/proxy/repos/memory.py b/src/dstack/_internal/proxy/repos/memory.py index 7b36ad881..cb60d6223 100644 --- a/src/dstack/_internal/proxy/repos/memory.py +++ b/src/dstack/_internal/proxy/repos/memory.py @@ -19,3 +19,8 @@ async def get_project(self, name: str) -> Optional[Project]: async def add_project(self, project: Project) -> None: self.projects[project.name] = project + + async def is_project_member(self, project_name: str, token: str) -> bool: + # TODO(#1595): when this class is used for gateways, + # implement a network request to dstack-server to check authorization + raise NotImplementedError diff --git a/src/dstack/_internal/proxy/routers/service_proxy.py b/src/dstack/_internal/proxy/routers/service_proxy.py index 6d9cf0b04..74fa5af23 100644 --- a/src/dstack/_internal/proxy/routers/service_proxy.py +++ b/src/dstack/_internal/proxy/routers/service_proxy.py @@ -3,7 +3,7 @@ from fastapi.responses import RedirectResponse, Response from typing_extensions import Annotated -from dstack._internal.proxy.deps import get_proxy_repo +from dstack._internal.proxy.deps import ProxyAuth, ProxyAuthContext, get_proxy_repo from dstack._internal.proxy.repos.base import BaseProxyRepo from dstack._internal.proxy.services import service_proxy @@ -27,9 +27,10 @@ async def service_reverse_proxy( run_name: str, path: str, request: Request, + auth: Annotated[ProxyAuthContext, Depends(ProxyAuth(auto_enforce=False))], repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)], ) -> Response: - return await service_proxy.proxy(project_name, run_name, path, request, repo) + return await service_proxy.proxy(project_name, run_name, path, request, auth, repo) # TODO(#1595): support websockets diff --git a/src/dstack/_internal/proxy/services/service_proxy.py b/src/dstack/_internal/proxy/services/service_proxy.py index 9910370b9..1a332df79 100644 --- a/src/dstack/_internal/proxy/services/service_proxy.py +++ b/src/dstack/_internal/proxy/services/service_proxy.py @@ -5,6 +5,7 @@ import httpx from starlette.requests import ClientDisconnect +from dstack._internal.proxy.deps import ProxyAuthContext from dstack._internal.proxy.repos.base import BaseProxyRepo, Replica, Service from dstack._internal.proxy.services.service_connection import service_replica_connection_pool from dstack._internal.utils.logging import get_logger @@ -17,6 +18,7 @@ async def proxy( run_name: str, path: str, request: fastapi.Request, + auth: ProxyAuthContext, repo: BaseProxyRepo, ) -> fastapi.responses.Response: if "Upgrade" in request.headers: @@ -31,11 +33,7 @@ async def proxy( f"Service {project_name}/{run_name} not found", ) if service.auth: - # TODO(#1595): support auth - raise fastapi.HTTPException( - fastapi.status.HTTP_400_BAD_REQUEST, - f"Service {project_name}/{run_name} requires auth, which is not yet supported", - ) + await auth.enforce() replica = random.choice(service.replicas) client = await get_replica_client(project_name, service, replica, repo) diff --git a/src/dstack/_internal/proxy/testing/__init__.py b/src/dstack/_internal/proxy/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dstack/_internal/proxy/testing/repo.py b/src/dstack/_internal/proxy/testing/repo.py new file mode 100644 index 000000000..852ecfbef --- /dev/null +++ b/src/dstack/_internal/proxy/testing/repo.py @@ -0,0 +1,12 @@ +from typing import Container, Dict, Optional + +from dstack._internal.proxy.repos.memory import InMemoryProxyRepo + + +class ProxyTestRepo(InMemoryProxyRepo): + def __init__(self, project_to_tokens: Optional[Dict[str, Container[str]]] = None) -> None: + super().__init__() + self._project_to_tokens = project_to_tokens or {} + + async def is_project_member(self, project_name: str, token: str) -> bool: + return token in self._project_to_tokens.get(project_name, set()) diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index 9954e0c34..1fa938c0d 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -1,6 +1,6 @@ from typing import Tuple -from fastapi import Depends, Security +from fastapi import Depends, HTTPException, Security from fastapi.security import HTTPBearer from fastapi.security.http import HTTPAuthorizationCredentials from sqlalchemy.ext.asyncio import AsyncSession @@ -96,15 +96,29 @@ async def __call__( project_name: str, token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> Tuple[UserModel, ProjectModel]: - user = await log_in_with_token(session=session, token=token.credentials) - if user is None: - raise error_invalid_token() - project = await get_project_model_by_name(session=session, project_name=project_name) - if project is None: - raise error_not_found() - if user.global_role == GlobalRole.ADMIN: - return user, project - project_role = get_user_project_role(user=user, project=project) - if project_role is not None: - return user, project - raise error_forbidden() + return await get_project_member(session, project_name, token.credentials) + + +async def get_project_member( + session: AsyncSession, project_name: str, token: str +) -> Tuple[UserModel, ProjectModel]: + user = await log_in_with_token(session=session, token=token) + if user is None: + raise error_invalid_token() + project = await get_project_model_by_name(session=session, project_name=project_name) + if project is None: + raise error_not_found() + if user.global_role == GlobalRole.ADMIN: + return user, project + project_role = get_user_project_role(user=user, project=project) + if project_role is not None: + return user, project + raise error_forbidden() + + +async def is_project_member(session: AsyncSession, project_name: str, token: str) -> bool: + try: + await get_project_member(session, project_name, token) + return True + except HTTPException: + return False diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index e5d792855..e2bd160c3 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -10,6 +10,7 @@ from dstack._internal.core.models.runs import JobProvisioningData, JobStatus, RunSpec from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service from dstack._internal.server.models import JobModel, ProjectModel, RunModel +from dstack._internal.server.security.permissions import is_project_member class DBProxyRepo(BaseProxyRepo): @@ -91,3 +92,6 @@ async def get_project(self, name: str) -> Optional[Project]: async def add_project(self, project: Project) -> None: pass + + async def is_project_member(self, project_name: str, token: str) -> bool: + return await is_project_member(self.session, project_name, token) diff --git a/src/tests/_internal/proxy/routers/test_service_proxy.py b/src/tests/_internal/proxy/routers/test_service_proxy.py index 3299c59f3..e4051d723 100644 --- a/src/tests/_internal/proxy/routers/test_service_proxy.py +++ b/src/tests/_internal/proxy/routers/test_service_proxy.py @@ -7,8 +7,8 @@ from dstack._internal.proxy.deps import BaseProxyDependencyInjector from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service -from dstack._internal.proxy.repos.memory import InMemoryProxyRepo from dstack._internal.proxy.routers.service_proxy import router +from dstack._internal.proxy.testing.repo import ProxyTestRepo def make_app(repo: BaseProxyRepo) -> FastAPI: @@ -36,11 +36,11 @@ def make_project(name: str) -> Project: return Project(name=name, ssh_private_key="secret") -def make_service(run_name: str) -> Service: +def make_service(run_name: str, auth: bool = False) -> Service: return Service( id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", run_name=run_name, - auth=False, + auth=auth, app_port=80, replicas=[ Replica( @@ -71,7 +71,7 @@ def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]: @pytest.mark.parametrize("method", ["get", "post", "put", "patch", "delete"]) async def test_proxy(mock_replica_client_httpbin, method: str) -> None: methods_without_body = "get", "delete" - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -96,7 +96,7 @@ async def test_proxy(mock_replica_client_httpbin, method: str) -> None: @pytest.mark.asyncio async def test_proxy_method_head(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -111,7 +111,7 @@ async def test_proxy_method_head(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio async def test_proxy_method_options(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -124,7 +124,7 @@ async def test_proxy_method_options(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio @pytest.mark.parametrize("code", [204, 304, 418, 503]) async def test_proxy_status_codes(mock_replica_client_httpbin, code: int) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -134,7 +134,7 @@ async def test_proxy_status_codes(mock_replica_client_httpbin, code: int) -> Non @pytest.mark.asyncio async def test_proxy_not_leaks_cookies(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) app = make_app(repo) @@ -152,7 +152,7 @@ async def test_proxy_not_leaks_cookies(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio async def test_proxy_gateway_timeout(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -164,7 +164,7 @@ async def test_proxy_gateway_timeout(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio async def test_proxy_run_not_found(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("test-run")) _, client = make_app_client(repo) @@ -175,7 +175,7 @@ async def test_proxy_run_not_found(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio async def test_proxy_project_not_found(mock_replica_client_httpbin) -> None: - _, client = make_app_client(InMemoryProxyRepo()) + _, client = make_app_client(ProxyTestRepo()) resp = await client.get("http://test-host/services/unknown/test-run/") assert resp.status_code == 404 assert resp.json()["detail"] == "Service unknown/test-run not found" @@ -183,7 +183,7 @@ async def test_proxy_project_not_found(mock_replica_client_httpbin) -> None: @pytest.mark.asyncio async def test_redirect_to_service_root(mock_replica_client_httpbin) -> None: - repo = InMemoryProxyRepo() + repo = ProxyTestRepo() await repo.add_project(make_project("test-proj")) await repo.add_service(project_name="test-proj", service=make_service("httpbin")) _, client = make_app_client(repo) @@ -194,3 +194,20 @@ async def test_redirect_to_service_root(mock_replica_client_httpbin) -> None: resp = await client.get(url, follow_redirects=True) assert resp.status_code == 200 assert resp.request.url == url + "/" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("token", "status"), [("correct-token", 200), ("incorrect-token", 403), ("", 403), (None, 403)] +) +async def test_auth(mock_replica_client_httpbin, token: str, status: int) -> None: + repo = ProxyTestRepo(project_to_tokens={"test-proj": {"correct-token"}}) + await repo.add_project(make_project("test-proj")) + await repo.add_service(project_name="test-proj", service=make_service("httpbin", auth=True)) + _, client = make_app_client(repo) + url = "http://test-host/services/test-proj/httpbin/" + headers = {} + if token is not None: + headers["Authorization"] = f"Bearer {token}" + resp = await client.get(url, headers=headers) + assert resp.status_code == status