From 1311267c1482cf835447f2f718168566d6b2d584 Mon Sep 17 00:00:00 2001 From: Aleksandar Makelov Date: Tue, 3 Sep 2024 16:03:58 +0300 Subject: [PATCH] prevent clearing uncommitted data by default --- mandala/storage.py | 7 +++++++ mandala/storage_utils.py | 16 ++++++++++++++-- mandala/tests/test_memoization.py | 25 ++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/mandala/storage.py b/mandala/storage.py index 5c457db..ac2413b 100644 --- a/mandala/storage.py +++ b/mandala/storage.py @@ -122,6 +122,13 @@ def vacuum(self): ############################################################################ ### managing the caches ############################################################################ + def clear_cache(self, allow_uncommitted: False): + self.atoms.clear(allow_uncommited=allow_uncommitted) + self.shapes.clear(allow_uncommited=allow_uncommitted) + self.ops.clear(allow_uncommited=allow_uncommitted) + self.calls.clear(allow_uncommited=allow_uncommitted) + print("Cleared all caches.") + def cache_info(self) -> str: """ Display information about the contents of the cache in a pretty table. diff --git a/mandala/storage_utils.py b/mandala/storage_utils.py index 2d8acfa..2f9528a 100644 --- a/mandala/storage_utils.py +++ b/mandala/storage_utils.py @@ -242,7 +242,9 @@ def values(self, conn: Optional[sqlite3.Connection] = None) -> List[Any]: class CachedDictStorage(DictStorage): def __init__(self, persistent: DictStorage): self.persistent = persistent + # keep a cache of the values for faster lookups self.cache: Dict[str, Any] = {} + # keep track of keys that have been added but not yet persisted self.dirty_keys: Set[str] = set() def load_all(self) -> Dict[str, Any]: @@ -271,7 +273,13 @@ def commit(self, conn: Optional[sqlite3.Connection] = None) -> None: self.persistent.set(key, self.cache[key], conn=conn) self.dirty_keys.clear() - def clear(self) -> None: + def clear(self, allow_uncommited: bool = False) -> None: + if len(self.dirty_keys) > 0 and not allow_uncommited: + # we add this as a precaution to avoid data loss. Otherwise, it's + # easy to shoot yourself in the foot by calling `clear()` before + # `commit()` + msg = "Cannot clear cache with uncommitted changes; call `commit()` first, or use `allow_uncommited=True`" + raise ValueError(msg) self.cache.clear() self.dirty_keys.clear() @@ -739,6 +747,10 @@ def commit(self, conn: Optional[sqlite3.Connection] = None): self.persistent.save(self.cache.get_data(hid), conn=conn) self.dirty_hids.clear() - def clear(self): + def clear(self, allow_uncommited: bool = False): + if len(self.dirty_hids) > 0 and not allow_uncommited: + # see `CachedDictStorage.clear` for an explanation + msg = "Cannot clear cache with uncommitted changes; call `commit()` first, or use `allow_uncommited=True`" + raise ValueError(msg) self.cache = InMemCallStorage() self.dirty_hids.clear() diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py index 7f2cd23..b8d8362 100644 --- a/mandala/tests/test_memoization.py +++ b/mandala/tests/test_memoization.py @@ -153,4 +153,27 @@ def inc(x, irrelevant): inc(23, 1) df = storage.cf(inc).df() - assert len(df) == 1 \ No newline at end of file + assert len(df) == 1 + + +def test_clear_uncommitted(): + storage = Storage() + + @op + def inc(x): + return x + 1 + + with storage: + for i in range(10): + inc(i) + # attempt to clear the atoms cache without having committed; this should + # fail by default + try: + storage.atoms.clear() + assert False + except ValueError: + pass + + # now clear the atoms cache after committing + storage.commit() + storage.atoms.clear() \ No newline at end of file