Skip to content

Commit

Permalink
migrate to pydantic>=2
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Nov 22, 2023
1 parent f4ce0e8 commit e32a58f
Show file tree
Hide file tree
Showing 19 changed files with 1,873 additions and 1,639 deletions.
3 changes: 1 addition & 2 deletions constraints.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pydantic>=1.8.2,<2
aiohttp>=3.7.4
anyio>=3.7.1,<4.0.0
1,138 changes: 611 additions & 527 deletions dev_requirements.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ xmltodict
loguru
motor
pysaml2
pki-tools
pydantic-settings
uvicorn[standard]
gunicorn
# jwcrypto needs six
Expand Down
1,056 changes: 578 additions & 478 deletions requirements.txt

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions src/auth_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import AnyUrl, BaseModel, BaseSettings, Field, ValidationError, validator
from pydantic import AnyUrl, BaseModel, Field, ValidationError, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from auth_server.models.gnap import Proof
from auth_server.models.jose import ECJWK, RSAJWK, SymmetricJWK
Expand Down Expand Up @@ -78,15 +79,16 @@ class AuthServerConfig(BaseSettings):
pysaml2_config_name: str = "SAML_CONFIG"
saml2_discovery_service_url: Optional[AnyUrl] = None
saml2_single_idp: Optional[str] = None
ca_certs_path: Optional[Path] = None # all files ending with .crt will be loaded recursively. PEM and DER supported

@validator("application_root")
@field_validator("application_root")
@classmethod
def application_root_must_not_end_with_slash(cls, v: str):
if v.endswith("/"):
v = v.removesuffix("/")
return v

class Config:
frozen = True # make hashable
model_config = SettingsConfigDict(frozen=True)


def read_config_file(config_file: str, config_ns: str = "") -> Dict:
Expand Down
12 changes: 5 additions & 7 deletions src/auth_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@
from fastapi import Request, Response
from fastapi.routing import APIRoute
from jwcrypto import jws
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

__author__ = "lundberg"


class Context(BaseModel):
jws_verified: bool = False
client_cert: Optional[str]
jws_obj: Optional[jws.JWS]
detached_jws: Optional[str]

class Config:
arbitrary_types_allowed = True
client_cert: Optional[str] = None
jws_obj: Optional[jws.JWS] = None
detached_jws: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

def to_dict(self):
return self.dict()
Expand Down
16 changes: 8 additions & 8 deletions src/auth_server/db/transaction_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,30 @@ def from_dict(cls: Type[T], state: Mapping[str, Any]) -> T:
return cls(**state)

def to_dict(self) -> Dict[str, Any]:
return self.dict(exclude_none=True)
return self.model_dump(exclude_none=True)


class TestState(TransactionState):
auth_source = AuthSource.TEST
auth_source: AuthSource = AuthSource.TEST


class InteractionState(TransactionState):
auth_source = AuthSource.INTERACTION
auth_source: AuthSource = AuthSource.INTERACTION


class ConfigState(TransactionState):
auth_source = AuthSource.CONFIG
auth_source: AuthSource = AuthSource.CONFIG
config_claims: Dict[str, Any] = Field(default_factory=dict)


class MDQState(TransactionState):
auth_source = AuthSource.MDQ
mdq_data: Optional[MDQData]
auth_source: AuthSource = AuthSource.MDQ
mdq_data: Optional[MDQData] = None


class TLSFEDState(TransactionState):
auth_source = AuthSource.TLSFED
entity: Optional[MetadataEntity]
auth_source: AuthSource = AuthSource.TLSFED
entity: Optional[MetadataEntity] = None


class TransactionStateDB(BaseDB):
Expand Down
44 changes: 11 additions & 33 deletions src/auth_server/mdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
from base64 import b64encode
from collections import OrderedDict as _OrderedDict
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, OrderedDict, Union
from typing import Any, List, Optional, OrderedDict

import aiohttp
import xmltodict
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.x509 import Certificate
from loguru import logger
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer
from pyexpat import ExpatError

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny

from auth_server.models.gnap import Key, Proof, ProofMethod
from auth_server.utils import get_values, hash_with, load_cert_from_str, serialize_certificate

Expand All @@ -27,45 +24,26 @@ class KeyUse(str, Enum):


class MDQBase(BaseModel):
class Config:
allow_mutation = False # should not change after load
arbitrary_types_allowed = True # needed for x509.Certificate
json_encoders = {Certificate: serialize_certificate}
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)


class MDQCert(MDQBase):
use: KeyUse
cert: Certificate

@validator("cert", pre=True)
@field_validator("cert", mode="before")
@classmethod
def deserialize_cert(cls, v: str) -> Certificate:
if isinstance(v, Certificate):
return v
return load_cert_from_str(v)

def dict(
self,
*,
include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny":
# serialize Certificate on dict use
d = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
d["cert"] = serialize_certificate(d["cert"])
return d
@model_serializer
def serialize_mdq_cert(self) -> dict[str, Any]:
"""
serialize Certificate on model_dump
"""
return {"use": self.use.value, "cert": serialize_certificate(self.cert)}


class MDQData(MDQBase):
Expand Down
10 changes: 4 additions & 6 deletions src/auth_server/models/claims.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

from pydantic import Extra
from pydantic import ConfigDict

from auth_server.config import FlowName
from auth_server.models.gnap import Access
from auth_server.models.jose import RegisteredClaims
from auth_server.saml2 import AuthnInfo

__author__ = "lundberg"

Expand All @@ -20,8 +18,8 @@ class Claims(RegisteredClaims):


class ConfigClaims(Claims):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")



class MDQClaims(Claims):
Expand Down
16 changes: 7 additions & 9 deletions src/auth_server/models/gnap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from auth_server.models.jose import (
ECJWK,
Expand All @@ -23,8 +23,7 @@


class GnapBaseModel(BaseModel):
class Config:
allow_population_by_field_name = True
model_config = ConfigDict(populate_by_name=True)


class ProofMethod(str, Enum):
Expand All @@ -47,7 +46,8 @@ class Key(GnapBaseModel):
cert: Optional[str] = None
cert_S256: Optional[str] = Field(default=None, alias="cert#S256")

@validator("proof", pre=True)
@field_validator("proof", mode="before")
@classmethod
def expand_proof(cls, v: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
# If additional parameters are not required or used for a specific method,
# the method MAY be passed as a string instead of an object.
Expand Down Expand Up @@ -139,9 +139,7 @@ class SubjectIdentifier(GnapBaseModel):
# {"format": "email", "email": "[email protected]"}
# see ietf-secevent-subject-identifiers
format: SubjectIdentifierFormat

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")


class SubjectAssertion(GnapBaseModel):
Expand Down Expand Up @@ -190,8 +188,8 @@ class Hints(GnapBaseModel):

class InteractionRequest(GnapBaseModel):
start: List[StartInteractionMethod]
finish: Optional[FinishInteraction]
hints: Optional[Hints]
finish: Optional[FinishInteraction] = None
hints: Optional[Hints] = None


class GrantRequest(GnapBaseModel):
Expand Down
10 changes: 2 additions & 8 deletions src/auth_server/models/jose.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class RegisteredClaims(BaseModel):
iss: Optional[str] = None # Issuer
sub: Optional[str] = None # Subject
aud: Optional[str] = None # Audience
exp: Optional[timedelta] # Expiration Time
exp: Optional[timedelta] = None # Expiration Time
nbf: Optional[datetime] = Field(default_factory=utc_now) # Not Before
iat: Optional[datetime] = Field(default_factory=utc_now) # Issued At
jti: Optional[str] = None # JWT ID
Expand Down Expand Up @@ -106,12 +106,6 @@ class SymmetricJWK(JWK):
k: Optional[str] = None


# Workaround for mypy not liking Union[ECJWK, RSAJWK, SymmetricJWK] as response_model. It should work.
# https://github.com/tiangolo/fastapi/issues/2279
class JWKTypes(BaseModel):
__root__: Union[ECJWK, RSAJWK, SymmetricJWK]


class JWKS(BaseModel):
keys: List[Union[ECJWK, RSAJWK, SymmetricJWK]]

Expand All @@ -129,7 +123,7 @@ class JOSEHeader(BaseModel):
x5u: Optional[str] = None
x5c: Optional[str] = None
x5t: Optional[str] = None
x5tS256: Optional[str] = Field(alias="x5t#S256")
x5tS256: Optional[str] = Field(default=None, alias="x5t#S256")
typ: Optional[str] = None
cty: Optional[str] = None
crit: Optional[List] = None
46 changes: 24 additions & 22 deletions src/auth_server/models/tls_fed_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import Enum
from typing import List, Optional

from pydantic import AnyUrl, BaseModel, Extra, Field, conint, constr
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, PositiveInt

from auth_server.models.jose import JOSEHeader

Expand All @@ -28,9 +28,7 @@ class SAMLScopeExtension(BaseModel):


class Extensions(BaseModel):
class Config:
extra = Extra.allow
allow_population_by_field_name = True # allow registered extension to also be set by name, not only by alias
model_config = ConfigDict(extra="allow", populate_by_name=True)

saml_scope: Optional[SAMLScopeExtension] = Field(default=None, alias=RegisteredExtensions.SAML_SCOPE.value) # type: ignore[literal-required]

Expand All @@ -44,42 +42,42 @@ class Alg(str, Enum):


class PinDirective(BaseModel):
alg: Alg = Field(..., example="sha256", title="Directive name")
digest: constr(regex=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$") = Field( # type: ignore
alg: Alg = Field(..., examples=["sha256"], title="Directive name")
digest: str = Field(
...,
example="HiMkrb4phPSP+OvGqmZd6sGvy7AUn4k3XEe8OMBrzt8=",
examples=["HiMkrb4phPSP+OvGqmZd6sGvy7AUn4k3XEe8OMBrzt8="],
title="Directive value (Base64)",
pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$",
)


class Endpoint(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

description: Optional[str] = Field(None, example="SCIM Server 1", title="Endpoint description")
tags: Optional[List[constr(regex=r"^[a-z0-9]{1,64}$")]] = Field( # type: ignore
description: Optional[str] = Field(None, examples=["SCIM Server 1"], title="Endpoint description")
tags: Optional[List[str]] = Field(
None,
description="A list of strings that describe the endpoint's capabilities.\n",
title="Endpoint tags",
pattern=r"^[a-z0-9]{1,64}$",
)
base_uri: Optional[AnyUrl] = Field(None, example="https://scim.example.com", title="Endpoint base URI")
base_uri: Optional[AnyUrl] = Field(None, examples=["https://scim.example.com"], title="Endpoint base URI")
pins: List[PinDirective] = Field(..., title="Certificate pin set")


class Entity(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

entity_id: AnyUrl = Field(
entity_id: str = Field(
...,
description="Globally unique identifier for the entity.",
example="https://example.com",
examples=["https://example.com"],
title="Entity identifier",
)
organization: Optional[str] = Field(
None,
description="Name identifying the organization that the entity’s\nmetadata represents.\n",
example="Example Org",
examples=["Example Org"],
title="Name of entity organization",
)
issuers: List[CertIssuers] = Field(
Expand All @@ -95,14 +93,18 @@ class Config:


class Model(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

version: constr(regex=r"^\d+\.\d+\.\d+$") = Field(..., example="1.0.0", title="Metadata schema version") # type: ignore
cache_ttl: Optional[conint(ge=0)] = Field( # type: ignore
version: str = Field(
...,
examples=["1.0.0"],
title="Metadata schema version",
pattern=r"^\d+\.\d+\.\d+$",
)
cache_ttl: Optional[PositiveInt] = Field(
None,
description="How long (in seconds) to cache metadata.\nEffective maximum TTL is the minimum of HTTP Expire and TTL\n",
example=3600,
examples=[3600],
title="Metadata cache TTL",
)
entities: List[Entity]
Loading

0 comments on commit e32a58f

Please sign in to comment.