Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add aws key pair to conf.Settings() #59

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion terraform/python/rekognition_api/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Managed via automated CI/CD in .github/workflows/semanticVersionBump.yml.
__version__ = "0.2.3-next.1"
__version__ = "0.2.4-next.1"
88 changes: 80 additions & 8 deletions terraform/python/rekognition_api/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# 3rd party stuff
import boto3 # AWS SDK for Python https://boto3.amazonaws.com/v1/documentation/api/latest/index.html
from dotenv import load_dotenv
from pydantic import Field, ValidationError, ValidationInfo, field_validator
from pydantic import Field, SecretStr, ValidationError, ValidationInfo, field_validator
from pydantic_settings import BaseSettings
from rekognition_api.const import HERE, IS_USING_TFVARS, TFVARS

Expand All @@ -31,6 +31,8 @@
)


logger = logging.getLogger(__name__)
TFVARS = TFVARS or {}
DOT_ENV_LOADED = load_dotenv()


Expand Down Expand Up @@ -73,8 +75,11 @@ class SettingsDefaults:
"""Default values for Settings"""

AWS_PROFILE = TFVARS.get("aws_profile", None)
DUMP_DEFAULTS = TFVARS.get("dump_defaults", False)
AWS_ACCESS_KEY_ID = SecretStr(None)
AWS_SECRET_ACCESS_KEY = SecretStr(None)
AWS_REGION = TFVARS.get("aws_region", "us-east-1")

DUMP_DEFAULTS = TFVARS.get("dump_defaults", False)
DEBUG_MODE: bool = bool(TFVARS.get("debug_mode", False))
SHARED_RESOURCE_IDENTIFIER = TFVARS.get("shared_resource_identifier", "rekognition_api")

Expand Down Expand Up @@ -123,10 +128,23 @@ class Settings(BaseSettings):
"""Settings for Lambda functions"""

_aws_session: boto3.Session = None
_aws_access_key_id_source: str = "unset"
_aws_secret_access_key_source: str = "unset"
_dump: dict = None
_initialized: bool = False

def __init__(self, **data: Any):
super().__init__(**data)
aws_profile = str(os.environ.get("AWS_PROFILE", "")).strip()
if len(aws_profile) > 0:
logger.debug("Using AWS_PROFILE: %s", aws_profile)
self._aws_access_key_id_source = "aws_profile"
self._aws_secret_access_key_source = "aws_profile"
else:
if "AWS_ACCESS_KEY_ID" in os.environ:
self._aws_access_key_id_source = "environ"
if "AWS_SECRET_ACCESS_KEY" in os.environ:
self._aws_secret_access_key_source = "environ"
self._initialized = True

debug_mode: Optional[bool] = Field(
Expand All @@ -145,6 +163,14 @@ def __init__(self, **data: Any):
SettingsDefaults.AWS_PROFILE,
env="AWS_PROFILE",
)
aws_access_key_id: Optional[SecretStr] = Field(
SettingsDefaults.AWS_ACCESS_KEY_ID,
env="AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: Optional[SecretStr] = Field(
SettingsDefaults.AWS_SECRET_ACCESS_KEY,
env="AWS_SECRET_ACCESS_KEY",
)
aws_regions: Optional[List[str]] = Field(AWS_REGIONS, description="The list of AWS regions")
aws_region: Optional[str] = Field(
SettingsDefaults.AWS_REGION,
Expand Down Expand Up @@ -185,14 +211,32 @@ def __init__(self, **data: Any):
SettingsDefaults.SHARED_RESOURCE_IDENTIFIER, env="SHARED_RESOURCE_IDENTIFIER"
)

@property
def aws_access_key_id_source(self):
"""Source of aws_access_key_id"""
return self._aws_access_key_id_source

@property
def aws_secret_access_key_source(self):
"""Source of aws_secret_access_key"""
return self._aws_secret_access_key_source

@property
def aws_session(self):
"""AWS session"""
if not self._aws_session:
if self.aws_profile:
self._aws_session = boto3.Session(profile_name=self.aws_profile, region_name=self.aws_region)
else:
self._aws_session = boto3.Session(region_name=self.aws_region)
if self.aws_access_key_id_source == "unset" or self.aws_secret_access_key_source == "unset":
raise RekognitionConfigurationError(
"aws_access_key_id and aws_secret_access_key must be set when aws_profile is not set."
)
self._aws_session = boto3.Session(
region_name=self.aws_region,
aws_access_key_id=self.aws_access_key_id.get_secret_value(),
aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
)
return self._aws_session

@property
Expand Down Expand Up @@ -261,7 +305,6 @@ def recursive_sort_dict(d):
return self._dump

self._dump = {
"secrets": {},
"environment": {
"is_using_tfvars_file": self.is_using_tfvars_file,
"is_using_dotenv_file": self.is_using_dotenv_file,
Expand All @@ -277,19 +320,20 @@ def recursive_sort_dict(d):
"version": self.version,
},
"aws": {
"profile": self.aws_profile,
"region": self.aws_region,
"aws_profile": self.aws_profile,
"aws_access_key_id_source": self.aws_access_key_id_source,
"aws_secret_access_key_source": self.aws_secret_access_key_source,
"aws_region": self.aws_region,
},
"rekognition": {
"aws_rekognition_collection_id": self.aws_rekognition_collection_id,
"aws_dynamodb_table_id": self.aws_dynamodb_table_id,
"aws_rekognition_face_detect_max_faces_count": self.aws_rekognition_face_detect_max_faces_count,
"aws_rekognition_face_detect_attributes": self.aws_rekognition_face_detect_attributes,
"aws_rekognition_face_detect_quality_filter": self.aws_rekognition_face_detect_quality_filter,
"aws_rekognition_face_detect_threshold": self.aws_rekognition_face_detect_threshold,
},
"dynamodb": {
"table": self.aws_dynamodb_table_id,
"aws_dynamodb_table_id": self.aws_dynamodb_table_id,
},
}
if self.dump_defaults:
Expand Down Expand Up @@ -325,6 +369,34 @@ def validate_aws_profile(cls, v, values, **kwargs) -> str:
return SettingsDefaults.AWS_PROFILE
return v

@field_validator("aws_access_key_id")
def validate_aws_access_key_id(cls, v, values: ValidationInfo) -> str:
"""Validate aws_access_key_id"""
if not isinstance(v, SecretStr):
v = SecretStr(v)
if v.get_secret_value() in [None, ""]:
return SettingsDefaults.AWS_ACCESS_KEY_ID
if "aws_profile" in values.data and values.data["aws_profile"] != SettingsDefaults.AWS_PROFILE:
logger.warning("aws_access_key_id is ignored when aws_profile is set")
return SettingsDefaults.AWS_ACCESS_KEY_ID
if cls.aws_access_key_id_source == "unset":
cls._aws_access_key_id_source = "constructor"
return v

@field_validator("aws_secret_access_key")
def validate_aws_secret_access_key(cls, v, values: ValidationInfo) -> str:
"""Validate aws_secret_access_key"""
if not isinstance(v, SecretStr):
v = SecretStr(v)
if v.get_secret_value() in [None, ""]:
return SettingsDefaults.AWS_SECRET_ACCESS_KEY
if "aws_profile" in values.data and values.data["aws_profile"] != SettingsDefaults.AWS_PROFILE:
logger.warning("aws_secret_access_key is ignored when aws_profile is set")
return SettingsDefaults.AWS_SECRET_ACCESS_KEY
if cls.aws_secret_access_key_source == "unset":
cls._aws_secret_access_key_source = "constructor"
return v

@field_validator("aws_region")
# pylint: disable=no-self-argument,unused-argument
def validate_aws_region(cls, v, values: ValidationInfo, **kwargs) -> str:
Expand Down
43 changes: 41 additions & 2 deletions terraform/python/rekognition_api/tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

# our stuff
from rekognition_api.conf import Settings, SettingsDefaults # noqa: E402
from rekognition_api.exceptions import RekognitionValueError # noqa: E402
from rekognition_api.exceptions import ( # noqa: E402
RekognitionConfigurationError,
RekognitionValueError,
)


class TestConfiguration(unittest.TestCase):
Expand Down Expand Up @@ -107,6 +110,43 @@ def test_env_overrides(self):
self.assertEqual(mock_settings.aws_rekognition_face_detect_threshold, 100)
self.assertEqual(mock_settings.debug_mode, True)

@patch.dict(
os.environ,
{"AWS_PROFILE": "TEST_PROFILE", "AWS_ACCESS_KEY_ID": "TEST_KEY", "AWS_SECRET_ACCESS_KEY": "TEST_SECRET"},
)
def test_aws_credentials_with_profile(self):
"""Test that key and secret are unset when using profile."""

mock_settings = Settings()
self.assertEqual(mock_settings.aws_access_key_id_source, "aws_profile")
self.assertEqual(mock_settings.aws_secret_access_key_source, "aws_profile")

@patch.dict(
os.environ, {"AWS_PROFILE": "", "AWS_ACCESS_KEY_ID": "TEST_KEY", "AWS_SECRET_ACCESS_KEY": "TEST_SECRET"}
)
def test_aws_credentials_without_profile(self):
"""Test that key and secret are unset when using profile."""

mock_settings = Settings()
# pylint: disable=no-member
self.assertEqual(mock_settings.aws_access_key_id.get_secret_value(), "TEST_KEY")
# pylint: disable=no-member
self.assertEqual(mock_settings.aws_secret_access_key.get_secret_value(), "TEST_SECRET")
self.assertEqual(mock_settings.aws_access_key_id_source, "environ")
self.assertEqual(mock_settings.aws_secret_access_key_source, "environ")

def test_aws_credentials_noinfo(self):
"""Test that key and secret are unset when using profile."""
os.environ.clear()
mock_settings = Settings()
self.assertEqual(mock_settings.aws_profile, None)
# pylint: disable=no-member
self.assertEqual(mock_settings.aws_access_key_id.get_secret_value(), None)
# pylint: disable=no-member
self.assertEqual(mock_settings.aws_secret_access_key.get_secret_value(), None)
self.assertEqual(mock_settings.aws_access_key_id_source, "unset")
self.assertEqual(mock_settings.aws_secret_access_key_source, "unset")

@patch.dict(os.environ, {"AWS_REGION": "invalid-region"})
def test_invalid_aws_region_code(self):
"""Test that Pydantic raises a validation error for environment variable with non-existent aws region code."""
Expand Down Expand Up @@ -198,7 +238,6 @@ def test_dump_keys(self):
"""Test that dump contains the expected keys."""

dump = Settings().dump
self.assertIn("secrets", dump)
self.assertIn("environment", dump)
self.assertIn("aws", dump)
self.assertIn("rekognition", dump)
Expand Down
6 changes: 4 additions & 2 deletions terraform/python/rekognition_api/tests/test_lambda_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
"""Test Search Lambda function."""

# python stuff
import logging
import os
import sys
import unittest


logger = logging.getLogger(__file__)
HERE = os.path.abspath(os.path.dirname(__file__))
PYTHON_ROOT = os.path.dirname(os.path.dirname(HERE))
sys.path.append(PYTHON_ROOT) # noqa: E402
Expand Down Expand Up @@ -44,7 +46,7 @@ def test_get_image_from_event(self):
# image_from_event = get_image_from_event(self.search_event)

# self.assertEqual(image_from_event, self.image)
print("Not implemented")
logger.debug("Not implemented")
assert True

def test_get_faces(self):
Expand All @@ -57,5 +59,5 @@ def test_get_faces(self):
# QualityFilter=settings.aws_rekognition_face_detect_quality_filter,
# )
# faces = get_faces(self.image_packed)
print("Not implemented")
logger.debug("Not implemented")
assert True