diff --git a/packages/opal-client/opal_client/data/updater.py b/packages/opal-client/opal_client/data/updater.py index 5841c77d..607df28f 100644 --- a/packages/opal-client/opal_client/data/updater.py +++ b/packages/opal-client/opal_client/data/updater.py @@ -409,13 +409,17 @@ async def update_policy_data( and isinstance(policy_data, dict) ): await self._set_split_policy_data( - store_transaction, url=url, data=policy_data + store_transaction, + url=url, + save_method=entry.save_method, + data=policy_data, ) else: await self._set_policy_data( store_transaction, url=url, path=policy_store_path, + save_method=entry.save_method, data=policy_data, ) # No exception we we're able to save to the policy-store @@ -447,19 +451,27 @@ async def update_policy_data( ) ) - async def _set_split_policy_data(self, tx, url: str, data: Dict[str, Any]): + async def _set_split_policy_data( + self, tx, url: str, save_method: str, data: Dict[str, Any] + ): """Split data writes to root ("/") path, so they won't overwrite other sources.""" logger.info("Splitting root data to {n} keys", n=len(data)) for prefix, obj in data.items(): - await self._set_policy_data(tx, url=url, path=f"/{prefix}", data=obj) + await self._set_policy_data( + tx, url=url, path=f"/{prefix}", save_method=save_method, data=obj + ) - async def _set_policy_data(self, tx, url: str, path: str, data: JsonableValue): + async def _set_policy_data( + self, tx, url: str, path: str, save_method: str, data: JsonableValue + ): logger.info( "Saving fetched data to policy-store: source url='{url}', destination path='{path}'", url=url, path=path or "/", ) - - await tx.set_policy_data(data, path=path) + if save_method == "PUT": + await tx.set_policy_data(data, path=path) + else: + await tx.patch_policy_data(data, path=path) diff --git a/packages/opal-client/opal_client/policy_store/opa_client.py b/packages/opal-client/opal_client/policy_store/opa_client.py index b59e2f19..6db531c0 100644 --- a/packages/opal-client/opal_client/policy_store/opa_client.py +++ b/packages/opal-client/opal_client/policy_store/opa_client.py @@ -8,6 +8,7 @@ import aiohttp import dpath +import jsonpatch from aiofiles.threadpool.text import AsyncTextIOWrapper from fastapi import Response, status from opal_client.config import opal_client_config @@ -21,6 +22,7 @@ from opal_common.engine.parsing import get_rego_package from opal_common.git.bundle_utils import BundleUtils from opal_common.paths import PathUtils +from opal_common.schemas.data import custom_encoder from opal_common.schemas.policy import DataModule, PolicyBundle, RegoModule from opal_common.schemas.store import JSONPatchAction, StoreTransaction, TransactionType from pydantic import BaseModel @@ -262,6 +264,16 @@ def set(self, path, data): # This would overwrite already existing paths dpath.new(self._root_data, path, data) + def patch(self, path, data: List[JSONPatchAction]): + for i, _ in enumerate(data): + if not path == "/": + data[i].path = path + data[i].path + data_str = json.dumps( + data, default=custom_encoder(by_alias=True, exclude_none=True) + ) + patch = jsonpatch.JsonPatch.from_string(data_str) + patch.apply(self._root_data, in_place=True) + def delete(self, path): if not path or path == "/": self._root_data = {} @@ -741,10 +753,60 @@ async def set_policy_data( async with aiohttp.ClientSession() as session: try: headers = await self._get_auth_headers() + data = json.dumps( + policy_data, + default=custom_encoder(by_alias=True, exclude_none=True), + ) async with session.put( f"{self._opa_url}/data{path}", - data=json.dumps(policy_data, default=str), + data=data, + headers=headers, + **self._ssl_context_kwargs, + ) as opa_response: + response = await proxy_response_unless_invalid( + opa_response, + accepted_status_codes=[ + status.HTTP_204_NO_CONTENT, + status.HTTP_304_NOT_MODIFIED, + ], + ) + if self._policy_data_cache: + self._policy_data_cache.set(path, json.loads(data)) + return response + except aiohttp.ClientError as e: + logger.warning("Opa connection error: {err}", err=repr(e)) + raise + + @affects_transaction + @retry(**RETRY_CONFIG) + async def patch_policy_data( + self, + policy_data: List[JSONPatchAction], + path: str = "", + transaction_id: Optional[str] = None, + ): + path = self._safe_data_module_path(path) + + # in OPA, the root document must be an object, so we must wrap list values + if not path and isinstance(policy_data, list): + logger.warning( + "OPAL client was instructed to put a list on OPA's root document. In OPA the root document must be an object so the original value was wrapped." + ) + policy_data = {"items": policy_data} + + async with aiohttp.ClientSession() as session: + try: + headers = await self._get_auth_headers() + headers["Content-Type"] = "application/json-patch+json" + data = json.dumps( + policy_data, + default=custom_encoder(by_alias=True, exclude_none=True), + ) + + async with session.patch( + f"{self._opa_url}/data{path}", + data=data, headers=headers, **self._ssl_context_kwargs, ) as opa_response: @@ -756,7 +818,7 @@ async def set_policy_data( ], ) if self._policy_data_cache: - self._policy_data_cache.set(path, policy_data) + self._policy_data_cache.patch(path, policy_data) return response except aiohttp.ClientError as e: logger.warning("Opa connection error: {err}", err=repr(e)) diff --git a/packages/opal-client/requires.txt b/packages/opal-client/requires.txt index 0c7bc332..0f0cbcf0 100644 --- a/packages/opal-client/requires.txt +++ b/packages/opal-client/requires.txt @@ -4,3 +4,4 @@ psutil>=5.9.1,<6 tenacity>=8.0.1,<9 websockets>=10.3,<11 dpath>=2.1.5,<3 +jsonpatch>=1.33,<2 diff --git a/packages/opal-common/opal_common/schemas/data.py b/packages/opal-common/opal_common/schemas/data.py index f5ef20cb..cc968bb8 100644 --- a/packages/opal-common/opal_common/schemas/data.py +++ b/packages/opal-common/opal_common/schemas/data.py @@ -4,25 +4,46 @@ from opal_common.fetcher.events import FetcherConfig from opal_common.fetcher.providers.http_fetch_provider import HttpFetcherConfig -from pydantic import AnyHttpUrl, BaseModel, Field, root_validator +from opal_common.schemas.store import JSONPatchAction +from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator +from pydantic.json import pydantic_encoder + +JsonableValue = Union[List[JSONPatchAction], List[Any], Dict[str, Any]] -JsonableValue = Union[Dict[str, Any], List[Any]] DEFAULT_DATA_TOPIC = "policy_data" +# custom encoder for doing a json.dumps on JsonableValue to pass in additional +# kwargs like by_alias, exclude_none etc to avoid field name being sent to OPA instead of alias +# and to exclude default fields from the JSON being sent to OPA +def custom_encoder(**kwargs): + def base_encoder(obj): + if isinstance(obj, BaseModel): + return obj.dict(**kwargs) + else: + return pydantic_encoder(obj) + + return base_encoder + class DataSourceEntry(BaseModel): """ Data source configuration - where client's should retrieve data from and how they should store it """ + @validator("data") + def name_must_contain_space(cls, value, values): + if values["save_method"] == "PATCH" and ( + not isinstance(value, list) + or not all(isinstance(elem, JSONPatchAction) for elem in value) + ): + raise TypeError( + "'data' must be of type JSON patch request when save_method is PATCH" + ) + return value + # How to obtain the data url: str = Field(..., description="Url source to query for data") - data: Optional[JsonableValue] = Field( - None, - description="Data payload to embed within the data update (instead of having " - "the client fetch it from the url).", - ) config: dict = Field( None, description="Suggested fetcher configuration (e.g. auth or method) to fetch data with", @@ -37,6 +58,11 @@ class DataSourceEntry(BaseModel): save_method: str = Field( "PUT", description="Method used to write into OPA - PUT/PATCH" ) + data: Optional[JsonableValue] = Field( + None, + description="Data payload to embed within the data update (instead of having " + "the client fetch it from the url).", + ) class DataSourceEntryWithPollingInterval(DataSourceEntry): diff --git a/packages/opal-common/opal_common/schemas/store.py b/packages/opal-common/opal_common/schemas/store.py index 5ed9a9a9..1bfeca7e 100644 --- a/packages/opal-common/opal_common/schemas/store.py +++ b/packages/opal-common/opal_common/schemas/store.py @@ -1,8 +1,8 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, root_validator class TransactionType(str, Enum): @@ -47,9 +47,18 @@ class JSONPatchAction(BaseModel): op: str = Field(..., description="patch action to perform") path: str = Field(..., description="target location in modified json") - value: Dict[str, Any] = Field( - ..., description="json document, the operand of the action" + value: Optional[Any] = Field( + None, description="json document, the operand of the action" ) + from_field: Optional[str] = Field( + None, description="source location in json", alias="from" + ) + + @root_validator + def value_must_be_present(cls, values): + if values.get("op") in ["add", "replace"] and values.get("value") is None: + raise TypeError("'value' must be present when op is either add or replace") + return values class ArrayAppendAction(JSONPatchAction): diff --git a/packages/opal-server/opal_server/data/data_update_publisher.py b/packages/opal-server/opal_server/data/data_update_publisher.py index 572d1d21..f0f8df1d 100644 --- a/packages/opal-server/opal_server/data/data_update_publisher.py +++ b/packages/opal-server/opal_server/data/data_update_publisher.py @@ -109,7 +109,9 @@ async def publish_data_updates(self, update: DataUpdate): ) async with self._publisher: - await self._publisher.publish(list(all_topic_combos), update) + await self._publisher.publish( + list(all_topic_combos), update.dict(by_alias=True) + ) async def _periodic_update_callback( self, update: DataSourceEntryWithPollingInterval