Skip to content

Commit

Permalink
refactor: introduce @validate_body_root_model (#423)
Browse files Browse the repository at this point in the history
Add `@validate_body_root_model` which is a drop-in replacement for Sanic's `@validate` when validating a `RootModel` or a model deriving from it.

The decorator should be replaced by `@validate` once sanic-org/sanic-ext#198 is fixed.
  • Loading branch information
leafty authored Sep 27, 2024
1 parent 0136831 commit 63b22b2
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 28 deletions.
28 changes: 27 additions & 1 deletion components/renku_data_services/base_api/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import wraps
from typing import Any, Concatenate, NoReturn, ParamSpec, TypeVar, cast

from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from sanic import Request, json
from sanic.response import JSONResponse
from sanic_ext import validate
Expand Down Expand Up @@ -97,3 +97,29 @@ async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwar
return decorated_function

return decorator


def validate_body_root_model(
json: type[RootModel],
) -> Callable[
[Callable[Concatenate[Request, _P], Awaitable[_T]]],
Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]],
]:
"""Decorator for sanic json payload validation when the model is derived from RootModel.
Should be removed once sanic fixes this error in their validation code.
Issue link: https://github.com/sanic-org/sanic-ext/issues/198
"""

def decorator(
f: Callable[Concatenate[Request, _P], Awaitable[_T]],
) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]:
@wraps(f)
async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if request.json is not None:
request.parsed_json = {"root": request.parsed_json} # type: ignore[assignment]
return await validate(json=json)(f)(request, *args, **kwargs)

return decorated_function

return decorator
30 changes: 17 additions & 13 deletions components/renku_data_services/crc/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from renku_data_services import errors
from renku_data_services.base_api.auth import authenticate, only_admins
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_db_ids, validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_db_ids, validate_query
from renku_data_services.base_models.validation import validated_json
from renku_data_services.crc import apispec, models
from renku_data_services.crc.db import ResourcePoolRepository, UserRepository
Expand Down Expand Up @@ -161,9 +161,11 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@only_admins
@validate_db_ids
async def _post(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse:
users = apispec.PoolUsersWithId.model_validate(request.json) # validation
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=True)
@validate_body_root_model(json=apispec.PoolUsersWithId)
async def _post(
_: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId
) -> HTTPResponse:
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=True)

return "/resource_pools/<resource_pool_id>/users", ["POST"], _post

Expand All @@ -173,9 +175,11 @@ def put(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@only_admins
@validate_db_ids
async def _put(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse:
users = apispec.PoolUsersWithId.model_validate(request.json) # validation
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=False)
@validate_body_root_model(json=apispec.PoolUsersWithId)
async def _put(
_: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId
) -> HTTPResponse:
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=False)

return "/resource_pools/<resource_pool_id>/users", ["PUT"], _put

Expand Down Expand Up @@ -528,9 +532,9 @@ def post(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _post(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse:
ids = apispec.IntegerIds.model_validate(request.json) # validation
return await self._post_put(user_id=user_id, post=True, resource_pool_ids=ids, api_user=user)
@validate_body_root_model(json=apispec.IntegerIds)
async def _post(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse:
return await self._post_put(user_id=user_id, post=True, resource_pool_ids=body, api_user=user)

return "/users/<user_id>/resource_pools", ["POST"], _post

Expand All @@ -539,9 +543,9 @@ def put(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _put(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse:
ids = apispec.IntegerIds.model_validate(request.json) # validation
return await self._post_put(user_id=user_id, post=False, resource_pool_ids=ids, api_user=user)
@validate_body_root_model(json=apispec.IntegerIds)
async def _put(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse:
return await self._post_put(user_id=user_id, post=False, resource_pool_ids=body, api_user=user)

return "/users/<user_id>/resource_pools", ["PUT"], _put

Expand Down
12 changes: 6 additions & 6 deletions components/renku_data_services/namespace/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from renku_data_services.authz.models import Role, UnsavedMember
from renku_data_services.base_api.auth import authenticate, only_authenticated
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.base_models.validation import validate_and_dump, validated_json
from renku_data_services.errors import errors
Expand Down Expand Up @@ -118,11 +118,11 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_authenticated
async def _update_members(request: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
# TODO: sanic validation does not support validating top-level json lists, switch this to @validate
# once sanic-org/sanic-ext/issues/198 is fixed
body_validated = apispec.GroupMemberPatchRequestList.model_validate(request.json)
members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body_validated.root]
@validate_body_root_model(json=apispec.GroupMemberPatchRequestList)
async def _update_members(
_: Request, user: base_models.APIUser, slug: str, body: apispec.GroupMemberPatchRequestList
) -> JSONResponse:
members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body.root]
res = await self.group_repo.update_group_members(
user=user,
slug=slug,
Expand Down
10 changes: 6 additions & 4 deletions components/renku_data_services/project/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.etag import if_match_required
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.errors import errors
from renku_data_services.project import apispec
Expand Down Expand Up @@ -259,9 +259,11 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@validate_path_project_id
async def _update_members(request: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse:
body_dump = apispec.ProjectMemberListPatchRequest.model_validate(request.json)
members = [Member(Role(i.role.value), i.id, project_id) for i in body_dump.root]
@validate_body_root_model(json=apispec.ProjectMemberListPatchRequest)
async def _update_members(
_: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectMemberListPatchRequest
) -> HTTPResponse:
members = [Member(Role(i.role.value), i.id, project_id) for i in body.root]
await self.project_member_repo.update_members(user, ULID.from_str(project_id), members)
return HTTPResponse(status=200)

Expand Down
9 changes: 5 additions & 4 deletions components/renku_data_services/storage/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from renku_data_services import errors
from renku_data_services.base_api.auth import authenticate
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_query
from renku_data_services.storage import apispec, models
from renku_data_services.storage.db import StorageRepository, StorageV2Repository
from renku_data_services.storage.rclone import RCloneValidator
Expand Down Expand Up @@ -302,9 +302,10 @@ def upsert_secrets(self) -> BlueprintFactoryResponse:
"""Create/update secrets for a cloud storage."""

@authenticate(self.authenticator)
async def _upsert_secrets(request: Request, user: base_models.APIUser, storage_id: ULID) -> JSONResponse:
# TODO: use @validate once sanic supports validating json lists
body = apispec.CloudStorageSecretPostList.model_validate(request.json)
@validate_body_root_model(json=apispec.CloudStorageSecretPostList)
async def _upsert_secrets(
_: Request, user: base_models.APIUser, storage_id: ULID, body: apispec.CloudStorageSecretPostList
) -> JSONResponse:
secrets = [models.CloudStorageSecretUpsert.model_validate(s.model_dump()) for s in body.root]
result = await self.storage_v2_repo.upsert_storage_secrets(
storage_id=storage_id, user=user, secrets=secrets
Expand Down

0 comments on commit 63b22b2

Please sign in to comment.