diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index af6356a32..1665f5125 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -85,7 +85,7 @@ async def lifespan(app: FastAPI): ) else: with console.status(f"Applying [code]{server_config_dir}[/]..."): - await server_config_manager.apply_config(session=session) + await server_config_manager.apply_config(session=session, owner=admin) console.print(f"[code]✓[/] Applied [code]{server_config_dir}[/]") gateway_connections_pool.server_port = SERVER_PORT with console.status("Initializing gateways..."): diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index 90c0d0db7..58230efa2 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -18,7 +18,7 @@ from dstack._internal.core.models.backends.vastai import AnyVastAICreds from dstack._internal.core.models.common import ForbidExtra from dstack._internal.server import settings -from dstack._internal.server.models import ProjectModel +from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import projects as projects_services from dstack._internal.server.utils.common import run_async @@ -181,14 +181,21 @@ async def sync_config(self, session: AsyncSession): if self.config is not None: self._save_config(self.config) - async def apply_config(self, session: AsyncSession): + async def apply_config(self, session: AsyncSession, owner: UserModel): if self.config is None: raise ValueError("Config is not loaded") for project_config in self.config.projects: - project = await projects_services.get_project_model_by_name_or_error( + project = await projects_services.get_project_model_by_name( session=session, project_name=project_config.name, ) + if not project: + await projects_services.create_project_model( + session=session, owner=owner, project_name=project_config.name + ) + project = await projects_services.get_project_model_by_name_or_error( + session=session, project_name=project_config.name + ) backends_to_delete = backends_services.list_available_backend_types() for backend_config in project_config.backends: config_info = _config_to_internal_config(backend_config) diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 05f4e8330..24c1f1fb6 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -69,7 +69,7 @@ async def create_project(session: AsyncSession, user: UserModel, project_name: s await _check_projects_quota(session=session, user=user) project = await create_project_model( session=session, - user=user, + owner=user, project_name=project_name, ) await add_project_member( @@ -239,14 +239,14 @@ async def get_project_model_by_id_or_error( async def create_project_model( - session: AsyncSession, user: UserModel, project_name: str + session: AsyncSession, owner: UserModel, project_name: str ) -> ProjectModel: private_bytes, public_bytes = await run_async( generate_rsa_key_pair_bytes, f"{project_name}@dstack" ) project = ProjectModel( id=uuid.uuid4(), - owner_id=user.id, + owner_id=owner.id, name=project_name, ssh_private_key=private_bytes.decode(), ssh_public_key=public_bytes.decode(), diff --git a/src/tests/_internal/server/services/test_config.py b/src/tests/_internal/server/services/test_config.py index 27e55243a..efee07b13 100644 --- a/src/tests/_internal/server/services/test_config.py +++ b/src/tests/_internal/server/services/test_config.py @@ -9,10 +9,11 @@ from dstack._internal.core.models.backends.azure import AzureConfigInfoWithCreds, AzureDefaultCreds from dstack._internal.core.models.backends.base import BackendType from dstack._internal.server import settings -from dstack._internal.server.models import BackendModel +from dstack._internal.server.models import BackendModel, ProjectModel from dstack._internal.server.services.config import AzureConfig, ServerConfigManager from dstack._internal.server.testing.common import ( create_project, + create_user, ) @@ -54,7 +55,8 @@ async def test_inits_backend(self, test_db, session: AsyncSession, tmp_path: Pat class TestApplyConfig: @pytest.mark.asyncio async def test_creates_backend(self, test_db, session: AsyncSession, tmp_path: Path): - await create_project(session=session, name="main") + owner = await create_user(session=session, name="test_owner") + await create_project(session=session, owner=owner, name="main") config_filepath = tmp_path / "config.yml" config = { "projects": [ @@ -71,7 +73,21 @@ async def test_creates_backend(self, test_db, session: AsyncSession, tmp_path: P "regions": ["us-west-1"], } ], - } + }, + { + "name": "test", + "backends": [ + { + "type": "aws", + "creds": { + "type": "access_key", + "access_key": "4321", + "secret_key": "4321", + }, + "regions": ["eu-west-1"], + } + ], + }, ] } with open(config_filepath, "w+") as f: @@ -81,6 +97,10 @@ async def test_creates_backend(self, test_db, session: AsyncSession, tmp_path: P ): manager = ServerConfigManager() manager.load_config() - await manager.apply_config(session) - res = await session.execute(select(BackendModel)) - assert len(res.scalars().all()) == 1 + await manager.apply_config(session, owner) + p_res = await session.execute(select(ProjectModel)) + projects = p_res.scalars().all() + assert len(projects) == 2 + b_res = await session.execute(select(BackendModel)) + backends = b_res.scalars().all() + assert len(backends) == 2