Skip to content

Commit

Permalink
Add AwsSystemsManagerParameterStoreSettingsSource
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Sep 10, 2024
1 parent 818d56e commit af06b53
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 0 deletions.
57 changes: 57 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,63 @@ class AzureKeyVaultSettings(BaseSettings):
)
```

## AWS Systems Manager Parameter Store

You must set the following parameters:

- `ssm_client`: An initialized [`boto3` SSM Client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#client).

Optionally, you may specify the following parameters:

- `ssm_path`: The hierarchy for the parameter. Hierarchies start with a forward slash (/). The hierarchy is the parameter name except the last part of the parameter. Under the hood, we make use of the [`get_parameters_by_path` method](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm/client/get_parameters_by_path.html) to recursively retrieve all parameters within the a specified path hierarchy.

```py
import os
from typing import Tuple, Type

import boto3
from pydantic import BaseModel

from pydantic_settings import (
AwsSystemsManagerParameterStoreSettingsSource,
BaseSettings,
PydanticBaseSettingsSource,
)


class SubModel(BaseModel):
a: str


class AzureKeyVaultSettings(BaseSettings):
foo: str
bar: int
sub: SubModel

@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
client = boto3.client('ssm')
ssm_param_store_settings = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls,
ssm_client=client,
ssm_path=os.environ.get('SSM_PREFIX', '/api/dev/')
)
return (
init_settings,
env_settings,
dotenv_settings,
file_secret_settings,
ssm_param_store_settings,
)
``` -->

## Other settings source

Other settings sources are available for common configuration files:
Expand Down
56 changes: 56 additions & 0 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,62 @@ def __repr__(self) -> str:
return f'AzureKeyVaultSettingsSource(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})'


class AwsSystemsManagerParameterStoreSettingsSource(EnvSettingsSource):
_ssm_client: "SSMClient" # type: ignore
_ssm_path: str

def __init__(
self,
settings_cls: type[BaseSettings],
ssm_client: "SSMClient", # type: ignore
ssm_path: str = "/",
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str = "/",
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
self._ssm_client = ssm_client
self._ssm_path = ssm_path
super().__init__(
settings_cls,
case_sensitive,
env_prefix,
env_nested_delimiter,
env_ignore_empty,
env_parse_none_str,
env_parse_enums,
)

def _load_env_vars(self) -> Mapping[str, Optional[str]]:
paginator = self._ssm_client.get_paginator("get_parameters_by_path")
response_iterator = paginator.paginate(
Path=self._ssm_path, WithDecryption=True, Recursive=True
)

output = {}
try:
for page in response_iterator:
for parameter in page["Parameters"]:
name = Path(parameter["Name"])
key = name.relative_to(self._ssm_path).as_posix()

if not self.case_sensitive:
first_key, *rest = key.split(self.env_nested_delimiter)
key = self.env_nested_delimiter.join([first_key.lower(), *rest])

output[key] = parameter["Value"]

except self._ssm_client.exceptions.ClientError as e:
warnings.warn(f"Unable to get parameters from {self._ssm_path!r}: {e}")

return output

def __repr__(self) -> str:
return f"AwsSystemsManagerParameterStoreSettingsSource(ssm_path={self._ssm_path!r})"


def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
return key if case_sensitive else key.lower()

Expand Down
194 changes: 194 additions & 0 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pydantic_settings.main import BaseSettings, SettingsConfigDict
from pydantic_settings.sources import (
AwsSystemsManagerParameterStoreSettingsSource,
AzureKeyVaultSettingsSource,
PydanticBaseSettingsSource,
PyprojectTomlConfigSettingsSource,
Expand All @@ -31,6 +32,12 @@
except ImportError:
azure_key_vault = False

try:
aws = True
import boto3
except ImportError:
aws = False

if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -210,3 +217,190 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name:
raise ResourceNotFoundError()

return key_vault_secret


@pytest.mark.skipif(not aws, reason="boto3 is not installed")
class TestAwsSystemsManagerParameterStoreSettingsSource:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""

def test___init__(self, mocker: MockerFixture) -> None:
"""Test __init__."""

class AwsSettings(BaseSettings):
"""AWS settings."""

mock_parameters = []
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings, ssm_client=client_mock, ssm_path="/my/path"
)

def test___call__case_sensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

obj = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
case_sensitive=True,
)

settings = obj()

assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test___call__case_insensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

obj = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path="/my/path",
case_sensitive=False,
)
settings = obj()

assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path="/my/path",
),
)

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'

def test_aws_ssm_settings_source__delimiter(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer__Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path='/my/path',
env_nested_delimiter='__',
),
)

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'

0 comments on commit af06b53

Please sign in to comment.