add support for save_method PATCH
thilak reddy committed Jul 3, 2023
1 parent 3afcd1a commit f5a12b6
Showing 6 changed files with 132 additions and 20 deletions.
24 changes: 18 additions & 6 deletions packages/opal-client/opal_client/data/updater.py
@@ -409,13 +409,17 @@ async def update_policy_data(
and isinstance(policy_data, dict)
):
await self._set_split_policy_data(
store_transaction, url=url, data=policy_data
store_transaction,
url=url,
save_method=entry.save_method,
data=policy_data,
)
else:
await self._set_policy_data(
store_transaction,
url=url,
path=policy_store_path,
save_method=entry.save_method,
data=policy_data,
)
# No exception, we were able to save to the policy-store
@@ -447,19 +451,27 @@ async def update_policy_data(
)
)

async def _set_split_policy_data(self, tx, url: str, data: Dict[str, Any]):
async def _set_split_policy_data(
self, tx, url: str, save_method: str, data: Dict[str, Any]
):
"""Split data writes to root ("/") path, so they won't overwrite other
sources."""
logger.info("Splitting root data to {n} keys", n=len(data))

for prefix, obj in data.items():
await self._set_policy_data(tx, url=url, path=f"/{prefix}", data=obj)
await self._set_policy_data(
tx, url=url, path=f"/{prefix}", save_method=save_method, data=obj
)

async def _set_policy_data(self, tx, url: str, path: str, data: JsonableValue):
async def _set_policy_data(
self, tx, url: str, path: str, save_method: str, data: JsonableValue
):
logger.info(
"Saving fetched data to policy-store: source url='{url}', destination path='{path}'",
url=url,
path=path or "/",
)

await tx.set_policy_data(data, path=path)
if save_method == "PUT":
await tx.set_policy_data(data, path=path)
else:
await tx.patch_policy_data(data, path=path)
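
With this change the updater threads each entry's `save_method` down to the store transaction, so PATCH entries reach `patch_policy_data` while everything else keeps going through the PUT-based `set_policy_data`. A minimal sketch of a data-source entry that would take the new PATCH branch; the URL and destination path are illustrative, and `dst_path` is assumed from the rest of the schema rather than shown in this hunk:

```python
from opal_common.schemas.data import DataSourceEntry

# Hypothetical entry: apply a JSON Patch to the data already stored under /users
# instead of replacing the whole document with a PUT.
entry = DataSourceEntry(
    url="https://example.com/policy-data/users",
    dst_path="/users",
    save_method="PATCH",
    data=[{"op": "add", "path": "/alice", "value": {"roles": ["admin"]}}],
)
```

When the updater processes such an entry, `entry.save_method` is forwarded to `_set_policy_data`, which selects between `tx.set_policy_data` and `tx.patch_policy_data` accordingly.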
66 changes: 64 additions & 2 deletions packages/opal-client/opal_client/policy_store/opa_client.py
@@ -8,6 +8,7 @@

import aiohttp
import dpath
import jsonpatch
from aiofiles.threadpool.text import AsyncTextIOWrapper
from fastapi import Response, status
from opal_client.config import opal_client_config
@@ -21,6 +22,7 @@
from opal_common.engine.parsing import get_rego_package
from opal_common.git.bundle_utils import BundleUtils
from opal_common.paths import PathUtils
from opal_common.schemas.data import custom_encoder
from opal_common.schemas.policy import DataModule, PolicyBundle, RegoModule
from opal_common.schemas.store import JSONPatchAction, StoreTransaction, TransactionType
from pydantic import BaseModel
@@ -262,6 +264,16 @@ def set(self, path, data):
# This would overwrite already existing paths
dpath.new(self._root_data, path, data)

def patch(self, path, data: List[JSONPatchAction]):
for i, _ in enumerate(data):
if not path == "/":
data[i].path = path + data[i].path
data_str = json.dumps(
data, default=custom_encoder(by_alias=True, exclude_none=True)
)
patch = jsonpatch.JsonPatch.from_string(data_str)
patch.apply(self._root_data, in_place=True)

def delete(self, path):
if not path or path == "/":
self._root_data = {}
@@ -741,10 +753,60 @@ async def set_policy_data(
async with aiohttp.ClientSession() as session:
try:
headers = await self._get_auth_headers()
data = json.dumps(
policy_data,
default=custom_encoder(by_alias=True, exclude_none=True),
)

async with session.put(
f"{self._opa_url}/data{path}",
data=json.dumps(policy_data, default=str),
data=data,
headers=headers,
**self._ssl_context_kwargs,
) as opa_response:
response = await proxy_response_unless_invalid(
opa_response,
accepted_status_codes=[
status.HTTP_204_NO_CONTENT,
status.HTTP_304_NOT_MODIFIED,
],
)
if self._policy_data_cache:
self._policy_data_cache.set(path, json.loads(data))
return response
except aiohttp.ClientError as e:
logger.warning("Opa connection error: {err}", err=repr(e))
raise

@affects_transaction
@retry(**RETRY_CONFIG)
async def patch_policy_data(
self,
policy_data: List[JSONPatchAction],
path: str = "",
transaction_id: Optional[str] = None,
):
path = self._safe_data_module_path(path)

# in OPA, the root document must be an object, so we must wrap list values
if not path and isinstance(policy_data, list):
logger.warning(
"OPAL client was instructed to put a list on OPA's root document. In OPA the root document must be an object so the original value was wrapped."
)
policy_data = {"items": policy_data}

async with aiohttp.ClientSession() as session:
try:
headers = await self._get_auth_headers()
headers["Content-Type"] = "application/json-patch+json"
data = json.dumps(
policy_data,
default=custom_encoder(by_alias=True, exclude_none=True),
)

async with session.patch(
f"{self._opa_url}/data{path}",
data=data,
headers=headers,
**self._ssl_context_kwargs,
) as opa_response:
@@ -756,7 +818,7 @@ async def set_policy_data(
],
)
if self._policy_data_cache:
self._policy_data_cache.set(path, policy_data)
self._policy_data_cache.patch(path, policy_data)
return response
except aiohttp.ClientError as e:
logger.warning("Opa connection error: {err}", err=repr(e))
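
`patch_policy_data` mirrors `set_policy_data`, but sends an RFC 6902 patch document with the `application/json-patch+json` content type, which OPA's Data API accepts for partial updates (operation paths are resolved relative to the URL path). A rough sketch of the equivalent raw request; the OPA address and target path are placeholders, not values from the commit:

```python
import json

import aiohttp


async def patch_users_doc():
    # Placeholder address; OPAL resolves the real one from its configuration.
    opa_url = "http://localhost:8181/v1"
    # Relative to /users in the URL, so this touches /users/alice in OPA's data store.
    patch_doc = [{"op": "add", "path": "/alice", "value": {"roles": ["admin"]}}]
    headers = {"Content-Type": "application/json-patch+json"}
    async with aiohttp.ClientSession() as session:
        async with session.patch(
            f"{opa_url}/data/users",
            data=json.dumps(patch_doc),
            headers=headers,
        ) as resp:
            # OPA answers 204 No Content when the patch applies cleanly.
            assert resp.status == 204
```

The in-memory policy-data cache replays the same operations locally via `jsonpatch`, prefixing each operation's path with the destination path so the cached copy stays aligned with what OPA now holds.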
1 change: 1 addition & 0 deletions packages/opal-client/requires.txt
@@ -4,3 +4,4 @@ psutil>=5.9.1,<6
tenacity>=8.0.1,<9
websockets>=10.3,<11
dpath>=2.1.5,<3
jsonpatch>=1.33,<2
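
The new `jsonpatch` dependency is what the policy-data cache above uses to replay patch operations against its local copy. A small standalone sketch of that pattern, independent of OPAL:

```python
import jsonpatch

doc = {"users": {"bob": {"roles": ["viewer"]}}}

# An RFC 6902 patch document, applied in place just like the cache does.
patch = jsonpatch.JsonPatch(
    [{"op": "add", "path": "/users/alice", "value": {"roles": ["admin"]}}]
)
patch.apply(doc, in_place=True)

assert doc["users"]["alice"] == {"roles": ["admin"]}
```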
40 changes: 33 additions & 7 deletions packages/opal-common/opal_common/schemas/data.py
@@ -4,25 +4,46 @@

from opal_common.fetcher.events import FetcherConfig
from opal_common.fetcher.providers.http_fetch_provider import HttpFetcherConfig
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator
from opal_common.schemas.store import JSONPatchAction
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from pydantic.json import pydantic_encoder

JsonableValue = Union[List[JSONPatchAction], List[Any], Dict[str, Any]]

JsonableValue = Union[Dict[str, Any], List[Any]]

DEFAULT_DATA_TOPIC = "policy_data"

# Custom encoder for json.dumps on JsonableValue: forwards kwargs such as by_alias and
# exclude_none so that field aliases (not field names) are sent to OPA and unset/default
# fields are left out of the JSON payload.
def custom_encoder(**kwargs):
def base_encoder(obj):
if isinstance(obj, BaseModel):
return obj.dict(**kwargs)
else:
return pydantic_encoder(obj)

return base_encoder


class DataSourceEntry(BaseModel):
"""
Data source configuration - where clients should retrieve data from and how they should store it
"""

@validator("data")
def data_must_match_save_method(cls, value, values):
if values["save_method"] == "PATCH" and (
not isinstance(value, list)
or not all(isinstance(elem, JSONPatchAction) for elem in value)
):
raise TypeError(
"'data' must be of type JSON patch request when save_method is PATCH"
)
return value

# How to obtain the data
url: str = Field(..., description="Url source to query for data")
data: Optional[JsonableValue] = Field(
None,
description="Data payload to embed within the data update (instead of having "
"the client fetch it from the url).",
)
config: dict = Field(
None,
description="Suggested fetcher configuration (e.g. auth or method) to fetch data with",
@@ -37,6 +58,11 @@ class DataSourceEntry(BaseModel):
save_method: str = Field(
"PUT", description="Method used to write into OPA - PUT/PATCH"
)
data: Optional[JsonableValue] = Field(
None,
description="Data payload to embed within the data update (instead of having "
"the client fetch it from the url).",
)


class DataSourceEntryWithPollingInterval(DataSourceEntry):
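
Note that `JsonableValue` now admits `List[JSONPatchAction]`, and the `data` field is re-declared after `save_method` so the validator can read `values["save_method"]` (pydantic v1 validators only see fields declared earlier). A short sketch of how the validation behaves; the URLs are illustrative:

```python
from pydantic import ValidationError

from opal_common.schemas.data import DataSourceEntry

# With save_method="PATCH", data must parse as a list of JSONPatchAction objects.
DataSourceEntry(
    url="https://example.com/roles",
    save_method="PATCH",
    data=[{"op": "replace", "path": "/roles", "value": ["admin"]}],
)

# A plain dict is fine for the default PUT, but rejected for PATCH.
try:
    DataSourceEntry(
        url="https://example.com/roles",
        save_method="PATCH",
        data={"roles": ["admin"]},
    )
except ValidationError as err:
    print(err)  # 'data' must be of type JSON patch request when save_method is PATCH
```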
17 changes: 13 additions & 4 deletions packages/opal-common/opal_common/schemas/store.py
@@ -1,8 +1,8 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, root_validator


class TransactionType(str, Enum):
@@ -47,9 +47,18 @@ class JSONPatchAction(BaseModel):

op: str = Field(..., description="patch action to perform")
path: str = Field(..., description="target location in modified json")
value: Dict[str, Any] = Field(
..., description="json document, the operand of the action"
value: Optional[Any] = Field(
None, description="json document, the operand of the action"
)
from_field: Optional[str] = Field(
None, description="source location in json", alias="from"
)

@root_validator
def value_must_be_present(cls, values):
if values.get("op") in ["add", "replace"] and values.get("value") is None:
raise TypeError("'value' must be present when op is either add or replace")
return values


class ArrayAppendAction(JSONPatchAction):
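
`JSONPatchAction` now covers the full RFC 6902 operation set: `value` is optional (so `remove`, `move`, and `copy` validate), `from` is exposed through the `from_field` attribute behind a `from` alias (since `from` is a Python keyword), and a root validator still insists on a `value` for `add` and `replace`. A short sketch of the resulting behavior, assuming pydantic v1 semantics as used by OPAL:

```python
from pydantic import ValidationError

from opal_common.schemas.store import JSONPatchAction

# Populated via the "from" alias; by_alias=True restores the RFC 6902 spelling.
move = JSONPatchAction(op="move", path="/users/alice", **{"from": "/staging/alice"})
assert move.dict(by_alias=True, exclude_none=True) == {
    "op": "move",
    "path": "/users/alice",
    "from": "/staging/alice",
}

# value is optional now, so remove-style operations validate fine...
JSONPatchAction(op="remove", path="/users/bob")

# ...but add/replace without a value is still rejected by the root validator.
try:
    JSONPatchAction(op="add", path="/users/carol")
except ValidationError:
    pass
```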
@@ -109,7 +109,9 @@ async def publish_data_updates(self, update: DataUpdate):
)

async with self._publisher:
await self._publisher.publish(list(all_topic_combos), update)
await self._publisher.publish(
list(all_topic_combos), update.dict(by_alias=True)
)

async def _periodic_update_callback(
self, update: DataSourceEntryWithPollingInterval
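
Publishing the update as a dict serialized with `by_alias=True` matters precisely because of the `from` alias above: serializing by field name would put `from_field` on the wire, which subscribers, and ultimately OPA, would not recognize as a JSON Patch key. A small illustration using a PATCH entry; the URL is made up:

```python
from opal_common.schemas.data import DataSourceEntry

entry = DataSourceEntry(
    url="https://example.com/roles",
    save_method="PATCH",
    data=[{"op": "copy", "path": "/b", "from": "/a"}],
)

print(entry.dict()["data"][0])                                   # contains 'from_field'
print(entry.dict(by_alias=True, exclude_none=True)["data"][0])   # contains 'from'
```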
