diff --git a/src/pyop/storage.py b/src/pyop/storage.py index 071c50b..d3906b0 100644 --- a/src/pyop/storage.py +++ b/src/pyop/storage.py @@ -52,13 +52,23 @@ def pop(self, key, default=None): return data @classmethod - def from_uri(cls, db_uri, collection, db_name=None, ttl=None): + def from_uri(cls, db_uri, collection, db_name=None, ttl=None, **kwargs): if db_uri.startswith("mongodb"): return MongoWrapper( - db_uri=db_uri, db_name=db_name, collection=collection, ttl=ttl + db_uri=db_uri, + db_name=db_name, + collection=collection, + ttl=ttl, + extra_options=kwargs, ) elif db_uri.startswith("redis") or db_uri.startswith("unix"): - return RedisWrapper(db_uri=db_uri, collection=collection, ttl=ttl) + return RedisWrapper( + db_uri=db_uri, + db_name=db_name, + collection=collection, + ttl=ttl, + extra_options=kwargs, + ) return ValueError(f"Invalid DB URI: {db_uri}") @@ -68,12 +78,18 @@ def ttl(self): class MongoWrapper(StorageBase): - def __init__(self, db_uri, db_name, collection, ttl=None): + def __init__(self, db_uri, db_name, collection, ttl=None, extra_options=None): if not _has_pymongo: raise ImportError("pymongo module is required but it is not available") + + if not extra_options: + extra_options = {} + + mongo_options = extra_options.pop("mongo_kwargs", None) or {} + self._db_uri = db_uri self._coll_name = collection - self._db = MongoDB(db_uri, db_name=db_name) + self._db = MongoDB(db_uri, db_name=db_name, **mongo_options) self._coll = self._db.get_collection(collection) self._coll.create_index('lookup_key', unique=True) @@ -120,10 +136,21 @@ class RedisWrapper(StorageBase): Supports JSON-serializable data types. """ - def __init__(self, db_uri, collection, ttl=None): + def __init__( + self, db_uri, *, db_name=None, collection, ttl=None, extra_options=None + ): if not _has_redis: raise ImportError("redis module is required but it is not available") - self._db = Redis.from_url(db_uri, decode_responses=True) + + if not extra_options: + extra_options = {} + + redis_kwargs = extra_options.pop("redis_kwargs", None) or {} + redis_options = { + "decode_responses": True, "db": db_name, **redis_kwargs + } + + self._db = Redis.from_url(db_uri, **redis_options) self._collection = collection if ttl is None or (isinstance(ttl, int) and ttl >= 0): self._ttl = ttl @@ -170,14 +197,11 @@ def items(self): class MongoDB(object): """Simple wrapper to get pymongo real objects from the settings uri""" - def __init__(self, db_uri, db_name=None, - connection_factory=None, **kwargs): - + def __init__(self, db_uri, db_name=None, connection_factory=None, **kwargs): if db_uri is None: raise ValueError('db_uri not supplied') self._sanitized_uri = None - self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri) db_name = self._parsed_uri.get('database') or db_name diff --git a/tests/pyop/test_storage.py b/tests/pyop/test_storage.py index 62b8a41..c3d72a5 100644 --- a/tests/pyop/test_storage.py +++ b/tests/pyop/test_storage.py @@ -16,12 +16,16 @@ __author__ = 'lundberg' -uri_list = ["mongodb://localhost:1234/pyop", "redis://localhost/0"] +db_specs_list = [ + {"uri": "mongodb://localhost:1234/pyop", "name": "pyop"}, + {"uri": "redis://localhost/0", "name": 0}, +] + @pytest.fixture(autouse=True) def mock_redis(monkeypatch): def mockreturn(*args, **kwargs): - return fakeredis.FakeStrictRedis(decode_responses=True) + return fakeredis.FakeStrictRedis(*args, **kwargs) monkeypatch.setattr(Redis, "from_url", mockreturn) @pytest.fixture(autouse=True) @@ -30,10 +34,10 @@ def mock_mongo(): class TestStorage(object): - @pytest.fixture(params=uri_list) + @pytest.fixture(params=db_specs_list) def db(self, request): return pyop.storage.StorageBase.from_uri( - request.param, db_name="pyop", collection="test" + request.param["uri"], db_name=request.param["name"], collection="test" ) def test_write(self, db): @@ -69,15 +73,15 @@ def test_items(self, db): @pytest.mark.parametrize( "args,kwargs", [ - (["redis://localhost"], {"collection": "test"}), - (["redis://localhost", "test"], {}), - (["unix://localhost/0"], {"collection": "test", "ttl": 3}), (["mongodb://localhost/pyop"], {"collection": "test", "ttl": 3}), (["mongodb://localhost"], {"db_name": "pyop", "collection": "test"}), (["mongodb://localhost", "test", "pyop"], {}), (["mongodb://localhost/pyop", "test"], {}), (["mongodb://localhost/pyop"], {"db_name": "other", "collection": "test"}), - (["redis://localhost/0"], {"db_name": "pyop", "collection": "test"}), + (["redis://localhost"], {"collection": "test"}), + (["redis://localhost", "test"], {}), + (["redis://localhost"], {"db_name": 2, "collection": "test"}), + (["unix://localhost/0"], {"collection": "test", "ttl": 3}), ], ) def test_from_uri(self, args, kwargs): @@ -88,11 +92,7 @@ def test_from_uri(self, args, kwargs): @pytest.mark.parametrize( "error,args,kwargs", [ - ( - TypeError, - ["redis://localhost", "ouch"], - {"db_name": 3, "collection": "test", "ttl": None}, - ), + (ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}), ( TypeError, ["mongodb://localhost", "ouch"], @@ -110,12 +110,11 @@ def test_from_uri(self, args, kwargs): ), ( TypeError, - ["mongodb://localhost"], - {"db_name": "pyop", "collection": "test", "ttl": None, "extra": True}, + ["redis://localhost", "ouch"], + {"db_name": 3, "collection": "test", "ttl": None}, ), (TypeError, ["redis://localhost/0"], {}), (TypeError, ["redis://localhost/0"], {"db_name": "pyop"}), - (ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}), ], ) def test_from_uri_invalid_parameters(self, error, args, kwargs): @@ -153,11 +152,11 @@ def execute_ttl_test(self, uri, ttl): with pytest.raises(KeyError): self.db["foo"] - @pytest.mark.parametrize("uri", uri_list) + @pytest.mark.parametrize("spec", db_specs_list) @pytest.mark.parametrize("ttl", ["invalid", -1, 2.3, {}]) - def test_invalid_ttl(self, uri, ttl): + def test_invalid_ttl(self, spec, ttl): with pytest.raises(ValueError): - self.prepare_db(uri, ttl) + self.prepare_db(spec["uri"], ttl) class TestRedisTTL(StorageTTLTest):