diff --git a/components/renku_data_services/base_api/misc.py b/components/renku_data_services/base_api/misc.py index 5053c3190..906e9087b 100644 --- a/components/renku_data_services/base_api/misc.py +++ b/components/renku_data_services/base_api/misc.py @@ -5,7 +5,7 @@ from functools import wraps from typing import Any, Concatenate, NoReturn, ParamSpec, TypeVar, cast -from pydantic import BaseModel +from pydantic import BaseModel, RootModel from sanic import Request, json from sanic.response import JSONResponse from sanic_ext import validate @@ -97,3 +97,29 @@ async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwar return decorated_function return decorator + + +def validate_body_root_model( + json: type[RootModel], +) -> Callable[ + [Callable[Concatenate[Request, _P], Awaitable[_T]]], + Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]], +]: + """Decorator for sanic json payload validation when the model is derived from RootModel. + + Should be removed once sanic fixes this error in their validation code. + Issue link: https://github.com/sanic-org/sanic-ext/issues/198 + """ + + def decorator( + f: Callable[Concatenate[Request, _P], Awaitable[_T]], + ) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]: + @wraps(f) + async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if request.json is not None: + request.parsed_json = {"root": request.parsed_json} # type: ignore[assignment] + return await validate(json=json)(f)(request, *args, **kwargs) + + return decorated_function + + return decorator diff --git a/components/renku_data_services/crc/blueprints.py b/components/renku_data_services/crc/blueprints.py index c71da7490..ba0e4a31c 100644 --- a/components/renku_data_services/crc/blueprints.py +++ b/components/renku_data_services/crc/blueprints.py @@ -10,7 +10,7 @@ from renku_data_services import errors from renku_data_services.base_api.auth import authenticate, only_admins from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.base_api.misc import validate_db_ids, validate_query +from renku_data_services.base_api.misc import validate_body_root_model, validate_db_ids, validate_query from renku_data_services.base_models.validation import validated_json from renku_data_services.crc import apispec, models from renku_data_services.crc.db import ResourcePoolRepository, UserRepository @@ -161,9 +161,11 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_admins @validate_db_ids - async def _post(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse: - users = apispec.PoolUsersWithId.model_validate(request.json) # validation - return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=True) + @validate_body_root_model(json=apispec.PoolUsersWithId) + async def _post( + _: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId + ) -> HTTPResponse: + return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=True) return "/resource_pools//users", ["POST"], _post @@ -173,9 +175,11 @@ def put(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_admins @validate_db_ids - async def _put(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse: - users = apispec.PoolUsersWithId.model_validate(request.json) # validation - return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=False) + @validate_body_root_model(json=apispec.PoolUsersWithId) + async def _put( + _: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId + ) -> HTTPResponse: + return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=False) return "/resource_pools//users", ["PUT"], _put @@ -528,9 +532,9 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_admins - async def _post(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse: - ids = apispec.IntegerIds.model_validate(request.json) # validation - return await self._post_put(user_id=user_id, post=True, resource_pool_ids=ids, api_user=user) + @validate_body_root_model(json=apispec.IntegerIds) + async def _post(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse: + return await self._post_put(user_id=user_id, post=True, resource_pool_ids=body, api_user=user) return "/users//resource_pools", ["POST"], _post @@ -539,9 +543,9 @@ def put(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_admins - async def _put(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse: - ids = apispec.IntegerIds.model_validate(request.json) # validation - return await self._post_put(user_id=user_id, post=False, resource_pool_ids=ids, api_user=user) + @validate_body_root_model(json=apispec.IntegerIds) + async def _put(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse: + return await self._post_put(user_id=user_id, post=False, resource_pool_ids=body, api_user=user) return "/users//resource_pools", ["PUT"], _put diff --git a/components/renku_data_services/namespace/blueprints.py b/components/renku_data_services/namespace/blueprints.py index f62cfad06..a9c2deccb 100644 --- a/components/renku_data_services/namespace/blueprints.py +++ b/components/renku_data_services/namespace/blueprints.py @@ -10,7 +10,7 @@ from renku_data_services.authz.models import Role, UnsavedMember from renku_data_services.base_api.auth import authenticate, only_authenticated from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.base_api.misc import validate_query +from renku_data_services.base_api.misc import validate_body_root_model, validate_query from renku_data_services.base_api.pagination import PaginationRequest, paginate from renku_data_services.base_models.validation import validate_and_dump, validated_json from renku_data_services.errors import errors @@ -118,11 +118,11 @@ def update_members(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_authenticated - async def _update_members(request: Request, user: base_models.APIUser, slug: str) -> JSONResponse: - # TODO: sanic validation does not support validating top-level json lists, switch this to @validate - # once sanic-org/sanic-ext/issues/198 is fixed - body_validated = apispec.GroupMemberPatchRequestList.model_validate(request.json) - members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body_validated.root] + @validate_body_root_model(json=apispec.GroupMemberPatchRequestList) + async def _update_members( + _: Request, user: base_models.APIUser, slug: str, body: apispec.GroupMemberPatchRequestList + ) -> JSONResponse: + members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body.root] res = await self.group_repo.update_group_members( user=user, slug=slug, diff --git a/components/renku_data_services/project/blueprints.py b/components/renku_data_services/project/blueprints.py index e9ccb540f..cfa136d39 100644 --- a/components/renku_data_services/project/blueprints.py +++ b/components/renku_data_services/project/blueprints.py @@ -18,7 +18,7 @@ ) from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint from renku_data_services.base_api.etag import if_match_required -from renku_data_services.base_api.misc import validate_query +from renku_data_services.base_api.misc import validate_body_root_model, validate_query from renku_data_services.base_api.pagination import PaginationRequest, paginate from renku_data_services.errors import errors from renku_data_services.project import apispec @@ -259,9 +259,11 @@ def update_members(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate_path_project_id - async def _update_members(request: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse: - body_dump = apispec.ProjectMemberListPatchRequest.model_validate(request.json) - members = [Member(Role(i.role.value), i.id, project_id) for i in body_dump.root] + @validate_body_root_model(json=apispec.ProjectMemberListPatchRequest) + async def _update_members( + _: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectMemberListPatchRequest + ) -> HTTPResponse: + members = [Member(Role(i.role.value), i.id, project_id) for i in body.root] await self.project_member_repo.update_members(user, ULID.from_str(project_id), members) return HTTPResponse(status=200) diff --git a/components/renku_data_services/storage/blueprints.py b/components/renku_data_services/storage/blueprints.py index 5fb3c4b38..6c7c327df 100644 --- a/components/renku_data_services/storage/blueprints.py +++ b/components/renku_data_services/storage/blueprints.py @@ -12,7 +12,7 @@ from renku_data_services import errors from renku_data_services.base_api.auth import authenticate from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.base_api.misc import validate_query +from renku_data_services.base_api.misc import validate_body_root_model, validate_query from renku_data_services.storage import apispec, models from renku_data_services.storage.db import StorageRepository, StorageV2Repository from renku_data_services.storage.rclone import RCloneValidator @@ -302,9 +302,10 @@ def upsert_secrets(self) -> BlueprintFactoryResponse: """Create/update secrets for a cloud storage.""" @authenticate(self.authenticator) - async def _upsert_secrets(request: Request, user: base_models.APIUser, storage_id: ULID) -> JSONResponse: - # TODO: use @validate once sanic supports validating json lists - body = apispec.CloudStorageSecretPostList.model_validate(request.json) + @validate_body_root_model(json=apispec.CloudStorageSecretPostList) + async def _upsert_secrets( + _: Request, user: base_models.APIUser, storage_id: ULID, body: apispec.CloudStorageSecretPostList + ) -> JSONResponse: secrets = [models.CloudStorageSecretUpsert.model_validate(s.model_dump()) for s in body.root] result = await self.storage_v2_repo.upsert_storage_secrets( storage_id=storage_id, user=user, secrets=secrets