Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Guyzyl/per 7303 update pydantic to v2 #488

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions packages/opal-client/opal_client/engine/options.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, List, Optional

from pydantic import BaseModel, Field, validator
from pydantic import field_validator, BaseModel, Field, validator


class LogLevel(str, Enum):
Expand Down Expand Up @@ -63,6 +63,8 @@ class OpaServerOptions(BaseModel):
description="list of built-in rego policies and data.json files that must be loaded into OPA on startup. e.g: system.authz policy when using --authorization=basic, see: https://www.openpolicyagent.org/docs/latest/security/#authentication-and-authorization",
)

# TODO[pydantic]: We couldn't refactor this class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config:
use_enum_values = True
allow_population_by_field_name = True
Expand Down Expand Up @@ -104,6 +106,8 @@ class CedarServerOptions(BaseModel):
description="list of built-in policies files that must be loaded on startup.",
)

# TODO[pydantic]: We couldn't refactor this class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config:
use_enum_values = True
allow_population_by_field_name = True
Expand All @@ -114,12 +118,15 @@ def alias_generator(cls, string: str) -> str:
file (to be used by opa cli)"""
return "--{}".format(string.replace("_", "-"))

@validator("authentication")
@field_validator("authentication")
@classmethod
def validate_authentication(cls, v: AuthenticationScheme):
if v not in [AuthenticationScheme.off, AuthenticationScheme.token]:
raise ValueError("Invalid AuthenticationScheme for Cedar.")
return v

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("authentication_token")
def validate_authentication_token(cls, v: Optional[str], values: dict[str, Any]):
if values["authentication"] == AuthenticationScheme.token and v is None:
Expand Down
10 changes: 4 additions & 6 deletions packages/opal-client/opal_client/policy_store/schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field


class PolicyStoreTypes(Enum):
Expand Down Expand Up @@ -53,14 +53,12 @@ class PolicyStoreDetails(BaseModel):
None, description="optional OAuth server required by the policy store"
)

@validator("type")
@field_validator("type")
@classmethod
def force_enum(cls, v):
if isinstance(v, str):
return PolicyStoreTypes(v)
if isinstance(v, PolicyStoreTypes):
return v
raise ValueError(f"invalid value: {v}")

class Config:
use_enum_values = True
allow_population_by_field_name = True
model_config = ConfigDict(use_enum_values=True, populate_by_name=True)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from aiohttp import ClientResponse, ClientSession
from pydantic import validator
from pydantic import field_validator, ConfigDict

from ...http import is_http_error_response
from ...security.sslcontext import get_custom_ssl_context
Expand Down Expand Up @@ -33,16 +33,15 @@ class HttpFetcherConfig(FetcherConfig):
method: HttpMethods = HttpMethods.GET
data: Any = None

@validator("method")
@field_validator("method")
@classmethod
def force_enum(cls, v):
if isinstance(v, str):
return HttpMethods(v)
if isinstance(v, HttpMethods):
return v
raise ValueError(f"invalid value: {v}")

class Config:
use_enum_values = True
model_config = ConfigDict(use_enum_values=True)


class HttpFetchEvent(FetchEvent):
Expand Down
7 changes: 5 additions & 2 deletions packages/opal-common/opal_common/schemas/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from opal_common.fetcher.events import FetcherConfig
from opal_common.fetcher.providers.http_fetch_provider import HttpFetcherConfig
from opal_common.schemas.store import JSONPatchAction
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from pydantic import AnyHttpUrl, BaseModel, Field, model_validator, validator

JsonableValue = Union[List[JSONPatchAction], List[Any], Dict[str, Any]]

Expand All @@ -18,6 +18,8 @@ class DataSourceEntry(BaseModel):
Data source configuration - where client's should retrieve data from and how they should store it
"""

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("data")
def validate_save_method(cls, value, values):
if values["save_method"] not in ["PUT", "PATCH"]:
Expand Down Expand Up @@ -102,7 +104,8 @@ class ServerDataSourceConfig(BaseModel):
+ " if set, the clients will be redirected to this url when requesting to fetch data sources.",
)

@root_validator
@model_validator(mode="before")
@classmethod
def check_passwords_match(cls, values):
config, redirect_url = values.get("config"), values.get("external_source_url")
if config is None and redirect_url is None:
Expand Down
7 changes: 3 additions & 4 deletions packages/opal-common/opal_common/schemas/policy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import ConfigDict, BaseModel, Field


class BaseSchema(BaseModel):
class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)


class DataModule(BaseSchema):
Expand Down Expand Up @@ -38,7 +37,7 @@ class PolicyBundle(BaseSchema):
)
data_modules: List[DataModule]
policy_modules: List[RegoModule]
deleted_files: Optional[DeletedFiles]
deleted_files: Optional[DeletedFiles] = None


class PolicyUpdateMessage(BaseSchema):
Expand Down
12 changes: 5 additions & 7 deletions packages/opal-common/opal_common/schemas/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field

PEER_TYPE_DESCRIPTION = (
"The peer type we generate access token for, i.e: opal client, data provider, etc."
Expand All @@ -29,17 +29,15 @@ class AccessTokenRequest(BaseModel):
ttl: timedelta = Field(timedelta(days=365), description=TTL_DESCRIPTION)
claims: dict = Field({}, description=CLAIMS_DESCRIPTION)

@validator("type")
@field_validator("type")
@classmethod
def force_enum(cls, v):
if isinstance(v, str):
return PeerType(v)
if isinstance(v, PeerType):
return v
raise ValueError(f"invalid value: {v}")

class Config:
use_enum_values = True
allow_population_by_field_name = True
model_config = ConfigDict(use_enum_values=True, populate_by_name=True)


class TokenDetails(BaseModel):
Expand All @@ -52,4 +50,4 @@ class TokenDetails(BaseModel):
class AccessToken(BaseModel):
token: str
type: str = "bearer"
details: Optional[TokenDetails]
details: Optional[TokenDetails] = None
8 changes: 4 additions & 4 deletions packages/opal-common/opal_common/schemas/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator


class TransactionType(str, Enum):
Expand Down Expand Up @@ -54,7 +53,8 @@ class JSONPatchAction(BaseModel):
None, description="source location in json", alias="from"
)

@root_validator
@model_validator(mode="before")
@classmethod
def value_must_be_present(cls, values):
if values.get("op") in ["add", "replace"] and values.get("value") is None:
raise TypeError("'value' must be present when op is either add or replace")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import os
from typing import List

from fastapi_utils.tasks import repeat_every
from opal_common.logger import logger
from opal_common.schemas.data import (
DataSourceEntryWithPollingInterval,
DataUpdate,
ServerDataSourceConfig,
)
from opal_common.topics.publisher import TopicPublisher
from opal_common.utils import repeat_every

TOPIC_DELIMITER = "/"
PREFIX_DELIMITER = ":"
Expand Down
81 changes: 81 additions & 0 deletions packages/opal-server/opal_server/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Taken from - https://github.com/dmontagu/fastapi-utils/blob/3ef27a6f67ac10fae6a8b4816549c0c44567a451/fastapi_utils/tasks.py
# Was copied because fastapi-utils doesn't support Pydantic 2.0+ yet.
import asyncio
import logging
from functools import wraps
from traceback import format_exception
from typing import Callable, Coroutine, Any, Union

from starlette.concurrency import run_in_threadpool

NoArgsNoReturnFuncT = Callable[[], None]
NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]]
NoArgsNoReturnDecorator = Callable[[Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]], NoArgsNoReturnAsyncFuncT]


def repeat_every(
*,
seconds: float,
wait_first: bool = False,
logger: logging.Logger | None = None,
raise_exceptions: bool = False,
max_repetitions: int | None = None,
) -> NoArgsNoReturnDecorator:
"""
This function returns a decorator that modifies a function so it is periodically re-executed after its first call.

The function it decorates should accept no arguments and return nothing. If necessary, this can be accomplished
by using `functools.partial` or otherwise wrapping the target function prior to decoration.

Parameters
----------
seconds: float
The number of seconds to wait between repeated calls
wait_first: bool (default False)
If True, the function will wait for a single period before the first call
logger: Optional[logging.Logger] (default None)
The logger to use to log any exceptions raised by calls to the decorated function.
If not provided, exceptions will not be logged by this function (though they may be handled by the event loop).
raise_exceptions: bool (default False)
If True, errors raised by the decorated function will be raised to the event loop's exception handler.
Note that if an error is raised, the repeated execution will stop.
Otherwise, exceptions are just logged and the execution continues to repeat.
See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.set_exception_handler for more info.
max_repetitions: Optional[int] (default None)
The maximum number of times to call the repeated function. If `None`, the function is repeated forever.
"""

def decorator(func: NoArgsNoReturnAsyncFuncT | NoArgsNoReturnFuncT) -> NoArgsNoReturnAsyncFuncT:
"""
Converts the decorated function into a repeated, periodically-called version of itself.
"""
is_coroutine = asyncio.iscoroutinefunction(func)

@wraps(func)
async def wrapped() -> None:
repetitions = 0

async def loop() -> None:
nonlocal repetitions
if wait_first:
await asyncio.sleep(seconds)
while max_repetitions is None or repetitions < max_repetitions:
try:
if is_coroutine:
await func() # type: ignore
else:
await run_in_threadpool(func)
repetitions += 1
except Exception as exc:
if logger is not None:
formatted_exception = "".join(format_exception(type(exc), exc, exc.__traceback__))
logger.error(formatted_exception)
if raise_exceptions:
raise exc
await asyncio.sleep(seconds)

asyncio.ensure_future(loop())

return wrapped

return decorator
10 changes: 5 additions & 5 deletions packages/requires.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
charset-normalizer>=2.0.12,<3
idna>=3.3,<4
typer>=0.4.1,<1
fastapi>=0.78.0,<1
fastapi_websocket_pubsub==0.3.4
fastapi_websocket_rpc>=0.1.21,<1
fastapi>=0.101.0,<1
fastapi_websocket_pubsub @ git+https://github.com/permitio/fastapi_websocket_pubsub.git@guyzyl/per-7303-update-pydantic-to-v2 # TODO: Restore to use package from pypi
fastapi_websocket_rpc @ git+https://github.com/permitio/fastapi_websocket_rpc.git@guyzyl/per-7303-update-pydantic-to-v2 # TODO: Restore to use package
gunicorn>=20.1.0,<21
pydantic[email]>=1.9.1,<2
pydantic[email]>=2.1.1,<3
typing-extensions;python_version<'3.8'
uvicorn[standard]>=0.17.6,<1
fastapi-utils>=0.2.1,<1
starlette>=0.26.1,<1
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
Loading