Skip to content

Commit

Permalink
Normalize Storage wrapper interfaces and allow extra options
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Kanakarakis <[email protected]>
  • Loading branch information
c00kiemon5ter committed Sep 3, 2021
1 parent 0fbb18c commit 04fee31
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
49 changes: 35 additions & 14 deletions src/pyop/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,23 @@ def pop(self, key, default=None):
return data

@classmethod
def from_uri(cls, db_uri, collection, db_name=None, ttl=None):
def from_uri(cls, db_uri, collection, db_name=None, ttl=None, **kwargs):
if db_uri.startswith("mongodb"):
return MongoWrapper(
db_uri=db_uri, db_name=db_name, collection=collection, ttl=ttl
db_uri=db_uri,
db_name=db_name,
collection=collection,
ttl=ttl,
extra_options=kwargs,
)
elif db_uri.startswith("redis") or db_uri.startswith("unix"):
return RedisWrapper(db_uri=db_uri, collection=collection, ttl=ttl)
return RedisWrapper(
db_uri=db_uri,
db_name=db_name,
collection=collection,
ttl=ttl,
extra_options=kwargs,
)

return ValueError(f"Invalid DB URI: {db_uri}")

Expand All @@ -68,12 +78,18 @@ def ttl(self):


class MongoWrapper(StorageBase):
def __init__(self, db_uri, db_name, collection, ttl=None):
def __init__(self, db_uri, db_name, collection, ttl=None, extra_options=None):
if not _has_pymongo:
raise ImportError("pymongo module is required but it is not available")

if not extra_options:
extra_options = {}

mongo_options = extra_options.pop("mongo_kwargs", None) or {}

self._db_uri = db_uri
self._coll_name = collection
self._db = MongoDB(db_uri, db_name=db_name)
self._db = MongoDB(db_uri, db_name=db_name, **mongo_options)
self._coll = self._db.get_collection(collection)
self._coll.create_index('lookup_key', unique=True)

Expand Down Expand Up @@ -120,13 +136,21 @@ class RedisWrapper(StorageBase):
Supports JSON-serializable data types.
"""

def __init__(self, db_uri, collection, ttl=None, options={}):
if options is None:
options = {}

def __init__(
self, db_uri, *, db_name=None, collection, ttl=None, extra_options=None
):
if not _has_redis:
raise ImportError("redis module is required but it is not available")
self._db = Redis.from_url(db_uri, decode_responses=True, **options.get('redis_kwargs', {}))

if not extra_options:
extra_options = {}

redis_kwargs = extra_options.pop("redis_kwargs", None) or {}
redis_options = {
"decode_responses": True, "db": db_name, **redis_kwargs
}

self._db = Redis.from_url(db_uri, **redis_options)
self._collection = collection
if ttl is None or (isinstance(ttl, int) and ttl >= 0):
self._ttl = ttl
Expand Down Expand Up @@ -173,14 +197,11 @@ def items(self):
class MongoDB(object):
"""Simple wrapper to get pymongo real objects from the settings uri"""

def __init__(self, db_uri, db_name=None,
connection_factory=None, **kwargs):

def __init__(self, db_uri, db_name=None, connection_factory=None, **kwargs):
if db_uri is None:
raise ValueError('db_uri not supplied')

self._sanitized_uri = None

self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri)

db_name = self._parsed_uri.get('database') or db_name
Expand Down
37 changes: 18 additions & 19 deletions tests/pyop/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
__author__ = 'lundberg'


uri_list = ["mongodb://localhost:1234/pyop", "redis://localhost/0"]
db_specs_list = [
{"uri": "mongodb://localhost:1234/pyop", "name": "pyop"},
{"uri": "redis://localhost/0", "name": 0},
]


@pytest.fixture(autouse=True)
def mock_redis(monkeypatch):
def mockreturn(*args, **kwargs):
return fakeredis.FakeStrictRedis(decode_responses=True)
return fakeredis.FakeStrictRedis(*args, **kwargs)
monkeypatch.setattr(Redis, "from_url", mockreturn)

@pytest.fixture(autouse=True)
Expand All @@ -30,10 +34,10 @@ def mock_mongo():


class TestStorage(object):
@pytest.fixture(params=uri_list)
@pytest.fixture(params=db_specs_list)
def db(self, request):
return pyop.storage.StorageBase.from_uri(
request.param, db_name="pyop", collection="test"
request.param["uri"], db_name=request.param["name"], collection="test"
)

def test_write(self, db):
Expand Down Expand Up @@ -69,15 +73,15 @@ def test_items(self, db):
@pytest.mark.parametrize(
"args,kwargs",
[
(["redis://localhost"], {"collection": "test"}),
(["redis://localhost", "test"], {}),
(["unix://localhost/0"], {"collection": "test", "ttl": 3}),
(["mongodb://localhost/pyop"], {"collection": "test", "ttl": 3}),
(["mongodb://localhost"], {"db_name": "pyop", "collection": "test"}),
(["mongodb://localhost", "test", "pyop"], {}),
(["mongodb://localhost/pyop", "test"], {}),
(["mongodb://localhost/pyop"], {"db_name": "other", "collection": "test"}),
(["redis://localhost/0"], {"db_name": "pyop", "collection": "test"}),
(["redis://localhost"], {"collection": "test"}),
(["redis://localhost", "test"], {}),
(["redis://localhost"], {"db_name": 2, "collection": "test"}),
(["unix://localhost/0"], {"collection": "test", "ttl": 3}),
],
)
def test_from_uri(self, args, kwargs):
Expand All @@ -88,11 +92,7 @@ def test_from_uri(self, args, kwargs):
@pytest.mark.parametrize(
"error,args,kwargs",
[
(
TypeError,
["redis://localhost", "ouch"],
{"db_name": 3, "collection": "test", "ttl": None},
),
(ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}),
(
TypeError,
["mongodb://localhost", "ouch"],
Expand All @@ -110,12 +110,11 @@ def test_from_uri(self, args, kwargs):
),
(
TypeError,
["mongodb://localhost"],
{"db_name": "pyop", "collection": "test", "ttl": None, "extra": True},
["redis://localhost", "ouch"],
{"db_name": 3, "collection": "test", "ttl": None},
),
(TypeError, ["redis://localhost/0"], {}),
(TypeError, ["redis://localhost/0"], {"db_name": "pyop"}),
(ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}),
],
)
def test_from_uri_invalid_parameters(self, error, args, kwargs):
Expand Down Expand Up @@ -153,11 +152,11 @@ def execute_ttl_test(self, uri, ttl):
with pytest.raises(KeyError):
self.db["foo"]

@pytest.mark.parametrize("uri", uri_list)
@pytest.mark.parametrize("spec", db_specs_list)
@pytest.mark.parametrize("ttl", ["invalid", -1, 2.3, {}])
def test_invalid_ttl(self, uri, ttl):
def test_invalid_ttl(self, spec, ttl):
with pytest.raises(ValueError):
self.prepare_db(uri, ttl)
self.prepare_db(spec["uri"], ttl)


class TestRedisTTL(StorageTTLTest):
Expand Down

0 comments on commit 04fee31

Please sign in to comment.