Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors mongodb persister a little #472

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
logger = logging.getLogger(__name__)


class MongoDBPersister(persistence.BaseStatePersister):
class MongoDBBasePersister(persistence.BaseStatePersister):
"""A class used to represent a MongoDB Persister.

Example usage:

.. code-block:: python

persister = MongoDBPersister(uri='mongodb://user:pass@localhost:27017', db_name='mydatabase', collection_name='mystates')
persister = MongoDBBasePersister.from_values(uri='mongodb://user:pass@localhost:27017',
db_name='mydatabase',
collection_name='mystates')
persister.save(
partition_key='example_partition',
app_id='example_app',
Expand All @@ -28,20 +30,46 @@ class MongoDBPersister(persistence.BaseStatePersister):
)
loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
print(loaded_state)

Note: this is called MongoDBBasePersister because we had to change the constructor and wanted to make
this change backwards compatible.
"""

def __init__(
self,
@classmethod
def from_values(
cls,
uri="mongodb://localhost:27017",
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
mongo_client_kwargs: dict = None,
):
"""Initializes the MongoDBPersister class."""
) -> "MongoDBBasePersister":
"""Initializes the MongoDBBasePersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
self.client = MongoClient(uri, **mongo_client_kwargs)
client = MongoClient(uri, **mongo_client_kwargs)
return cls(
client=client,
db_name=db_name,
collection_name=collection_name,
serde_kwargs=serde_kwargs,
)

def __init__(
self,
client,
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
):
"""Initializes the MongoDBBasePersister class.

:param client: the mongodb client to use
:param db_name: the name of the database to use
:param collection_name: the name of the collection to use
:param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
"""
self.client = client
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.serde_kwargs = serde_kwargs or {}
Expand Down Expand Up @@ -101,3 +129,29 @@ def save(

def __del__(self):
self.client.close()


class MongoDBPersister(MongoDBBasePersister):
"""A class used to represent a MongoDB Persister.

This class is deprecated. Please use MongoDBBasePersister instead.
"""

def __init__(
self,
uri="mongodb://localhost:27017",
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
mongo_client_kwargs: dict = None,
):
"""Initializes the MongoDBPersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
client = MongoClient(uri, **mongo_client_kwargs)
super(MongoDBPersister, self).__init__(
client=client,
db_name=db_name,
collection_name=collection_name,
serde_kwargs=serde_kwargs,
)
2 changes: 1 addition & 1 deletion docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Currently we support the following, although we highly recommend you contribute

.. automethod:: __init__

.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBPersister
.. autoclass:: burr.integrations.persisters.b_mongodb.MongoDBBasePersister
:members:

.. automethod:: __init__
Expand Down
15 changes: 13 additions & 2 deletions tests/integrations/persisters/test_b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_mongodb import MongoDBPersister
from burr.integrations.persisters.b_mongodb import MongoDBBasePersister, MongoDBPersister

if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true":
pytest.skip("Skipping integration tests", allow_module_level=True)


@pytest.fixture
def mongodb_persister():
persister = MongoDBPersister(
persister = MongoDBBasePersister.from_values(
uri="mongodb://localhost:27017", db_name="testdb", collection_name="testcollection"
)
yield persister
Expand All @@ -35,3 +35,14 @@ def test_list_app_ids(mongodb_persister):
def test_load_nonexistent_key(mongodb_persister):
state_data = mongodb_persister.load("pk", "nonexistent_key")
assert state_data is None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def test_backwards_compatible_persister():
persister = MongoDBPersister(
uri="mongodb://localhost:27017", db_name="testdb", collection_name="backwardscompatible"
)
persister.save("pk", "app_id", 5, "pos", state.State({"a": 5, "b": 5}), "completed")
data = persister.load("pk", "app_id", 5)
assert data["state"].get_all() == {"a": 5, "b": 5}

persister.collection.drop()
Loading