diff --git a/packages/opal-client/opal_client/policy_store/mock_policy_store_client.py b/packages/opal-client/opal_client/policy_store/mock_policy_store_client.py index e19b5b46..bae93222 100644 --- a/packages/opal-client/opal_client/policy_store/mock_policy_store_client.py +++ b/packages/opal-client/opal_client/policy_store/mock_policy_store_client.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import jsonpatch -from opal_common.schemas.data import custom_encoder +from opal_client.utils import exclude_none_fields from opal_common.schemas.policy import PolicyBundle from opal_common.schemas.store import JSONPatchAction, StoreTransaction from pydantic import BaseModel @@ -65,10 +65,9 @@ async def patch_policy_data( for i, _ in enumerate(policy_data): if not path == "/": policy_data[i].path = path + policy_data[i].path - data_str = json.dumps( - policy_data, default=custom_encoder(by_alias=True, exclude_none=True) + patch = jsonpatch.JsonPatch.from_string( + json.dumps(exclude_none_fields(policy_data)) ) - patch = jsonpatch.JsonPatch.from_string(data_str) patch.apply(self._data, in_place=True) self.has_data_event.set() diff --git a/packages/opal-client/opal_client/policy_store/opa_client.py b/packages/opal-client/opal_client/policy_store/opa_client.py index 6db531c0..0fa390f6 100644 --- a/packages/opal-client/opal_client/policy_store/opa_client.py +++ b/packages/opal-client/opal_client/policy_store/opa_client.py @@ -18,11 +18,10 @@ JsonableValue, ) from opal_client.policy_store.schemas import PolicyStoreAuth -from opal_client.utils import proxy_response +from opal_client.utils import exclude_none_fields, proxy_response from opal_common.engine.parsing import get_rego_package from opal_common.git.bundle_utils import BundleUtils from opal_common.paths import PathUtils -from opal_common.schemas.data import custom_encoder from opal_common.schemas.policy import DataModule, PolicyBundle, RegoModule from opal_common.schemas.store import JSONPatchAction, StoreTransaction, TransactionType from pydantic import BaseModel @@ -268,10 +267,7 @@ def patch(self, path, data: List[JSONPatchAction]): for i, _ in enumerate(data): if not path == "/": data[i].path = path + data[i].path - data_str = json.dumps( - data, default=custom_encoder(by_alias=True, exclude_none=True) - ) - patch = jsonpatch.JsonPatch.from_string(data_str) + patch = jsonpatch.JsonPatch.from_string(json.dumps(exclude_none_fields(data))) patch.apply(self._root_data, in_place=True) def delete(self, path): @@ -753,11 +749,7 @@ async def set_policy_data( async with aiohttp.ClientSession() as session: try: headers = await self._get_auth_headers() - data = json.dumps( - policy_data, - default=custom_encoder(by_alias=True, exclude_none=True), - ) - + data = json.dumps(exclude_none_fields(policy_data)) async with session.put( f"{self._opa_url}/data{path}", data=data, @@ -799,14 +791,10 @@ async def patch_policy_data( try: headers = await self._get_auth_headers() headers["Content-Type"] = "application/json-patch+json" - data = json.dumps( - policy_data, - default=custom_encoder(by_alias=True, exclude_none=True), - ) async with session.patch( f"{self._opa_url}/data{path}", - data=data, + data=json.dumps(exclude_none_fields(policy_data)), headers=headers, **self._ssl_context_kwargs, ) as opa_response: diff --git a/packages/opal-client/opal_client/utils.py b/packages/opal-client/opal_client/utils.py index 41d62c40..c6d0de39 100644 --- a/packages/opal-client/opal_client/utils.py +++ b/packages/opal-client/opal_client/utils.py @@ -1,5 +1,6 @@ import aiohttp from fastapi import Response +from fastapi.encoders import jsonable_encoder async def proxy_response(response: aiohttp.ClientResponse) -> Response: @@ -10,3 +11,9 @@ async def proxy_response(response: aiohttp.ClientResponse) -> Response: headers=dict(response.headers), media_type="application/json", ) + + +def exclude_none_fields(data): + # remove default values from the pydatic model with a None value and also + # convert the model to a valid JSON serializable type using jsonable_encoder + return jsonable_encoder(data, exclude_none=True) diff --git a/packages/opal-common/opal_common/schemas/data.py b/packages/opal-common/opal_common/schemas/data.py index 5e20e267..ba7bedff 100644 --- a/packages/opal-common/opal_common/schemas/data.py +++ b/packages/opal-common/opal_common/schemas/data.py @@ -6,25 +6,12 @@ from opal_common.fetcher.providers.http_fetch_provider import HttpFetcherConfig from opal_common.schemas.store import JSONPatchAction from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator -from pydantic.json import pydantic_encoder JsonableValue = Union[List[JSONPatchAction], List[Any], Dict[str, Any]] DEFAULT_DATA_TOPIC = "policy_data" -# custom encoder for doing a json.dumps on JsonableValue to pass in additional -# kwargs like by_alias, exclude_none etc to avoid field name being sent to OPA instead of alias -# and to exclude default fields from the JSON being sent to OPA -def custom_encoder(**kwargs): - def base_encoder(obj): - if isinstance(obj, BaseModel): - return obj.dict(**kwargs) - else: - return pydantic_encoder(obj) - - return base_encoder - class DataSourceEntry(BaseModel): """ @@ -33,6 +20,8 @@ class DataSourceEntry(BaseModel): @validator("data") def name_must_contain_space(cls, value, values): + if values["save_method"] not in ["PUT", "PATCH"]: + raise ValueError("'save_method' must be either PUT or PATCH") if values["save_method"] == "PATCH" and ( not isinstance(value, list) or not all(isinstance(elem, JSONPatchAction) for elem in value)