diff --git a/src/banks/config.py b/src/banks/config.py index 7fe5e83..0f4686c 100644 --- a/src/banks/config.py +++ b/src/banks/config.py @@ -1,14 +1,38 @@ +import json import os from pathlib import Path +from typing import Any from platformdirs import user_data_path from .utils import strtobool -class BanksConfig: - ASYNC_ENABLED: bool = strtobool(os.environ.get("BANKS_ASYNC_ENABLED", "false")) - USER_DATA_PATH: Path = Path(os.environ.get("BANKS_USER_DATA_PATH", "")) or user_data_path("banks") +class _BanksConfig: + ASYNC_ENABLED: bool = False + USER_DATA_PATH: Path = user_data_path("banks") + def __init__(self, env_var_prefix: str = "BANKS_"): + self._env_var_prefix = env_var_prefix -config = BanksConfig() + def __getattribute__(self, name: str) -> Any: + # Raise an attribute error if the name of the config is unknown + original_value = super().__getattribute__(name) + + # Env var takes precedence + prefix = super().__getattribute__("_env_var_prefix") + value = os.environ.get(f"{prefix}{name}") + if value is None: + return original_value + + # Convert string from env var to the actual type + t = super().__getattribute__("__annotations__")[name] + if t == bool: + value = strtobool(value) + else: + value = t(value) + + return value + + +config = _BanksConfig() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..7ada35c --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from platformdirs import user_data_path + +from banks.config import _BanksConfig + + +def test_config_defaults(): + c = _BanksConfig() + assert c.ASYNC_ENABLED == False + assert c.USER_DATA_PATH == user_data_path("banks") + + +def test_config_env_override(monkeypatch): + c = _BanksConfig() + monkeypatch.setenv("BANKS_ASYNC_ENABLED", "true") + assert c.ASYNC_ENABLED == True + monkeypatch.setenv("BANKS_ASYNC_ENABLED", "false") + assert c.ASYNC_ENABLED == False + monkeypatch.setenv("BANKS_USER_DATA_PATH", "/") + assert c.USER_DATA_PATH == Path("/") + + class TestConfig(_BanksConfig): + FOO: int = 0 + + c = TestConfig() + assert c.FOO == 0 + monkeypatch.setenv("BANKS_FOO", "42") + assert c.FOO == 42 + + +def test_config_env_prefix(monkeypatch): + c = _BanksConfig("BANKS_TEST_") + monkeypatch.setenv("BANKS_ASYNC_ENABLED", "true") + assert c.ASYNC_ENABLED == False + monkeypatch.setenv("BANKS_TEST_ASYNC_ENABLED", "true") + assert c.ASYNC_ENABLED == True