Skip to content

Commit

Permalink
refactor: swap deprecated @validator() for new @field_validator()
Browse files Browse the repository at this point in the history
lpm0073 committed Dec 20, 2023
1 parent 2ab7e65 commit 2507b26
Showing 4 changed files with 33 additions and 29 deletions.
43 changes: 17 additions & 26 deletions terraform/python/rekognition_api/conf.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
# 3rd party stuff
import boto3 # AWS SDK for Python https://boto3.amazonaws.com/v1/documentation/api/latest/index.html
from dotenv import load_dotenv
from pydantic import Field, ValidationError, validator
from pydantic import Field, ValidationError, ValidationInfo, field_validator
from pydantic_settings import BaseSettings
from rekognition_api.const import HERE, IS_USING_TFVARS, TFVARS

@@ -75,11 +75,11 @@ class SettingsDefaults:
AWS_PROFILE = TFVARS.get("aws_profile", None)
DUMP_DEFAULTS = TFVARS.get("dump_defaults", False)
AWS_REGION = TFVARS.get("aws_region", "us-east-1")
DEBUG_MODE = TFVARS.get("debug_mode", False)
DEBUG_MODE: bool = bool(TFVARS.get("debug_mode", False))
TABLE_ID = "rekognition"
COLLECTION_ID = TABLE_ID + "-collection"
FACE_DETECT_MAX_FACES_COUNT = TFVARS.get("max_faces_count", 10)
FACE_DETECT_THRESHOLD = TFVARS.get("face_detect_threshold", 80)
FACE_DETECT_MAX_FACES_COUNT: int = int(TFVARS.get("max_faces_count", 10))
FACE_DETECT_THRESHOLD: int = int(TFVARS.get("face_detect_threshold", 10))
FACE_DETECT_ATTRIBUTES = TFVARS.get("face_detect_attributes", "DEFAULT")
FACE_DETECT_QUALITY_FILTER = TFVARS.get("face_detect_quality_filter", "AUTO")
SHARED_RESOURCE_IDENTIFIER = TFVARS.get("shared_resource_identifier", "rekognition_api")
@@ -308,53 +308,54 @@ class Config:

frozen = True

@validator("shared_resource_identifier", pre=True)
@field_validator("shared_resource_identifier")
def validate_shared_resource_identifier(cls, v) -> str:
"""Validate shared_resource_identifier"""
if v in [None, ""]:
return SettingsDefaults.SHARED_RESOURCE_IDENTIFIER
return v

@validator("aws_profile", pre=True)
@field_validator("aws_profile")
# pylint: disable=no-self-argument,unused-argument
def validate_aws_profile(cls, v, values, **kwargs) -> str:
"""Validate aws_profile"""
if v in [None, ""]:
return SettingsDefaults.AWS_PROFILE
return v

@validator("aws_region", pre=True)
@field_validator("aws_region")
# pylint: disable=no-self-argument,unused-argument
def validate_aws_region(cls, v, values, **kwargs) -> str:
def validate_aws_region(cls, v, values: ValidationInfo, **kwargs) -> str:
"""Validate aws_region"""
valid_regions = values.data.get("aws_regions", [])
if v in [None, ""]:
return SettingsDefaults.AWS_REGION
if "aws_regions" in values and v not in values["aws_regions"]:
if v not in valid_regions:
raise RekognitionValueError(f"aws_region {v} not in aws_regions")
return v

@validator("table_id", pre=True)
@field_validator("table_id")
def validate_table_id(cls, v) -> str:
"""Validate table_id"""
if v in [None, ""]:
return SettingsDefaults.TABLE_ID
return v

@validator("collection_id", pre=True)
@field_validator("collection_id")
def validate_collection_id(cls, v) -> str:
"""Validate collection_id"""
if v in [None, ""]:
return SettingsDefaults.COLLECTION_ID
return v

@validator("face_detect_attributes", pre=True)
@field_validator("face_detect_attributes")
def validate_face_detect_attributes(cls, v) -> str:
"""Validate face_detect_attributes"""
if v in [None, ""]:
return SettingsDefaults.FACE_DETECT_ATTRIBUTES
return v

@validator("debug_mode", pre=True)
@field_validator("debug_mode")
def parse_debug_mode(cls, v) -> bool:
"""Parse debug_mode"""
if isinstance(v, bool):
@@ -363,7 +364,7 @@ def parse_debug_mode(cls, v) -> bool:
return SettingsDefaults.DEBUG_MODE
return v.lower() in ["true", "1", "t", "y", "yes"]

@validator("dump_defaults", pre=True)
@field_validator("dump_defaults")
def parse_dump_defaults(cls, v) -> bool:
"""Parse dump_defaults"""
if isinstance(v, bool):
@@ -372,14 +373,14 @@ def parse_dump_defaults(cls, v) -> bool:
return SettingsDefaults.DUMP_DEFAULTS
return v.lower() in ["true", "1", "t", "y", "yes"]

@validator("face_detect_max_faces_count", pre=True)
@field_validator("face_detect_max_faces_count")
def check_face_detect_max_faces_count(cls, v) -> int:
"""Check face_detect_max_faces_count"""
if v in [None, ""]:
return SettingsDefaults.FACE_DETECT_MAX_FACES_COUNT
return int(v)

@validator("face_detect_threshold", pre=True)
@field_validator("face_detect_threshold")
def check_face_detect_threshold(cls, v) -> int:
"""Check face_detect_threshold"""
if isinstance(v, int):
@@ -394,13 +395,3 @@ def check_face_detect_threshold(cls, v) -> int:
settings = Settings()
except ValidationError as e:
raise RekognitionConfigurationError("Invalid configuration: " + str(e)) from e

logger = logging.getLogger(__name__)
logger.debug("DEBUG_MODE: %s", settings.debug_mode)
logger.debug("AWS_REGION: %s", settings.aws_region)
logger.debug("TABLE_ID: %s", settings.table_id)
logger.debug("COLLECTION_ID: %s", settings.collection_id)
logger.debug("FACE_DETECT_MAX_FACES_COUNT: %s", settings.face_detect_max_faces_count)
logger.debug("FACE_DETECT_ATTRIBUTES: %s", settings.face_detect_attributes)
logger.debug("FACE_DETECT_QUALITY_FILTER: %s", settings.face_detect_quality_filter)
logger.debug("FACE_DETECT_THRESHOLD: %s", settings.face_detect_threshold)
5 changes: 5 additions & 0 deletions terraform/python/rekognition_api/tests/.env.test_legal_nulls
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
FACE_DETECT_ATTRIBUTES=
QUALITY_FILTER=
TABLE_ID=
AWS_REGION=
COLLECTION_ID=
14 changes: 11 additions & 3 deletions terraform/python/rekognition_api/tests/test_configuration.py
Original file line number Diff line number Diff line change
@@ -47,10 +47,20 @@ def test_conf_defaults(self):
self.assertEqual(mock_settings.face_detect_quality_filter, SettingsDefaults.FACE_DETECT_QUALITY_FILTER)
self.assertEqual(mock_settings.face_detect_threshold, SettingsDefaults.FACE_DETECT_THRESHOLD)

def test_env_illegal_nulls(self):
"""Test that settings handles missing .env values."""
os.environ.clear()
env_path = self.env_path(".env.test_illegal_nulls")
loaded = load_dotenv(env_path)
self.assertTrue(loaded)

with self.assertRaises(PydanticValidationError):
Settings()

def test_env_nulls(self):
"""Test that settings handles missing .env values."""
os.environ.clear()
env_path = self.env_path(".env.test_nulls")
env_path = self.env_path(".env.test_legal_nulls")
loaded = load_dotenv(env_path)
self.assertTrue(loaded)

@@ -59,10 +69,8 @@ def test_env_nulls(self):
self.assertEqual(mock_settings.aws_region, SettingsDefaults.AWS_REGION)
self.assertEqual(mock_settings.table_id, SettingsDefaults.TABLE_ID)
self.assertEqual(mock_settings.collection_id, SettingsDefaults.COLLECTION_ID)
self.assertEqual(mock_settings.face_detect_max_faces_count, SettingsDefaults.FACE_DETECT_MAX_FACES_COUNT)
self.assertEqual(mock_settings.face_detect_attributes, SettingsDefaults.FACE_DETECT_ATTRIBUTES)
self.assertEqual(mock_settings.face_detect_quality_filter, SettingsDefaults.FACE_DETECT_QUALITY_FILTER)
self.assertEqual(mock_settings.face_detect_threshold, SettingsDefaults.FACE_DETECT_THRESHOLD)

def test_env_overrides(self):
"""Test that settings takes custom .env values."""

0 comments on commit 2507b26

Please sign in to comment.