diff --git a/cacholote/config.py b/cacholote/config.py index 978ddf3..2274fdf 100644 --- a/cacholote/config.py +++ b/cacholote/config.py @@ -88,23 +88,20 @@ def make_cache_dir(self) -> "Settings": fs.mkdirs(urlpath, exist_ok=True) return self - @pydantic.model_validator(mode="after") - def check_mutually_exclusive(self) -> "Settings": - mutually_exclusive = (self.sessionmaker, self.cache_db_urlpath) - if all(mutually_exclusive) or not any(mutually_exclusive): - raise ValueError( - "Provide either `sessionmaker` or `cache_db_urlpath` (mutually exclusive)." - ) - return self - @property def instantiated_sessionmaker(self) -> sa.orm.sessionmaker: # type: ignore[type-arg] if self.sessionmaker is None: + if self.cache_db_urlpath is None: + raise ValueError("Provide either `sessionmaker` or `cache_db_urlpath`.") self.sessionmaker = database.cached_sessionmaker( self.cache_db_urlpath, **self.create_engine_kwargs ) self.cache_db_urlpath = None self.create_engine_kwargs = {} + elif self.cache_db_urlpath is not None: + raise ValueError( + "`sessionmaker` and `cache_db_urlpath` are mutually exclusive." + ) return self.sessionmaker @property diff --git a/cacholote/database.py b/cacholote/database.py index 3c1ae6e..52b5128 100644 --- a/cacholote/database.py +++ b/cacholote/database.py @@ -18,7 +18,7 @@ import functools import json import warnings -from typing import Any +from typing import Any, Dict import sqlalchemy as sa import sqlalchemy.orm @@ -77,8 +77,32 @@ def _commit_or_rollback(session: sa.orm.Session) -> None: session.rollback() +def _encode_kwargs(**kwargs: Any) -> Dict[str, Any]: + encoded_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, dict): + encoded_kwargs["_encoded_" + key] = json.dumps(value) + else: + encoded_kwargs[key] = value + return encoded_kwargs + + +def _decode_kwargs(**kwargs: Any) -> Dict[str, Any]: + decoded_kwargs = {} + for key, value in kwargs.items(): + if key.startswith("_encoded_"): + decoded_kwargs[key.replace("_encoded_", "", 1)] = json.loads(value) + else: + decoded_kwargs[key] = value + return decoded_kwargs + + @functools.lru_cache() -def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker: # type: ignore[type-arg] - engine = sa.create_engine(url, **kwargs) +def _cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker: # type: ignore[type-arg] + engine = sa.create_engine(url, **_decode_kwargs(**kwargs)) Base.metadata.create_all(engine) return sa.orm.sessionmaker(engine) + + +def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker: # type: ignore[type-arg] + return _cached_sessionmaker(url, **_encode_kwargs(**kwargs)) diff --git a/tests/conftest.py b/tests/conftest.py index 0e913c0..1d31f38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ def set_cache( ) -> Iterator[str]: param = getattr(request, "param", "file") if param.lower() == "cads": - database.cached_sessionmaker.cache_clear() + database._cached_sessionmaker.cache_clear() test_bucket_name = "test-bucket" client_kwargs = create_test_bucket(s3_server, test_bucket_name) with config.set( diff --git a/tests/test_01_settings.py b/tests/test_01_settings.py index f7f42e8..1066b1f 100644 --- a/tests/test_01_settings.py +++ b/tests/test_01_settings.py @@ -137,3 +137,9 @@ def test_set_expiration( ) -> None: with raises: config.set(expiration=expiration) + + +def test_create_engine_dict_kwargs() -> None: + old_session_maker = config.get().instantiated_sessionmaker + config.set(create_engine_kwargs={"connect_args": {"timeout": 30}}) + assert config.get().instantiated_sessionmaker is not old_session_maker