diff --git a/mandala/storage.py b/mandala/storage.py index ac2413b..a337108 100644 --- a/mandala/storage.py +++ b/mandala/storage.py @@ -4,7 +4,7 @@ from .model import * import sqlite3 from .model import __make_list__, __list_getitem__, __make_dict__, __dict_getitem__, _Ignore, _NewArgDefault -from .utils import dataframe_to_prettytable, parse_returns +from .utils import dataframe_to_prettytable, parse_returns, _conservative_equality_check from .viz import _get_colorized_diff from .deps.versioner import Versioner, CodeState from .deps.utils import get_dep_key_from_func, extract_func_obj @@ -51,6 +51,10 @@ def __init__(self, else: self.overflow_storage = None + # storage for ref values: {cid -> serialized object} + # storing the serialized object prevents accidental modification of the + # object in memory, but also means that we have to deserialize it when + # we want to use it. self.atoms = CachedDictStorage( persistent=SQLiteDictStorage(self.db, table="atoms", overflow_storage=self.overflow_storage, @@ -122,7 +126,7 @@ def vacuum(self): ############################################################################ ### managing the caches ############################################################################ - def clear_cache(self, allow_uncommitted: False): + def clear_cache(self, allow_uncommitted: bool = False): self.atoms.clear(allow_uncommited=allow_uncommitted) self.shapes.clear(allow_uncommited=allow_uncommitted) self.ops.clear(allow_uncommited=allow_uncommitted) @@ -223,6 +227,12 @@ def save_ref(self, ref: Ref): raise NotImplementedError def load_ref(self, hid: str, in_memory: bool = False) -> Ref: + """ + Loads the Ref with the given `hid`, *and* caches it in the `.atoms` and + `.shapes`. + + TODO: add option to disable automatic caching of atoms. 
+ """ shape = self.shapes[hid] if isinstance(shape, AtomRef): if in_memory: @@ -435,10 +445,13 @@ def get_unreferenced_cids(self) -> Set[str]: ############################################################################ ### ############################################################################ - def _unwrap_atom(self, obj: Any) -> Any: + def _unwrap_atom(self, obj: Any, cache: bool = True) -> Any: + # TODO: implement `cache = False` in `load_ref` assert isinstance(obj, AtomRef) if not obj.in_memory: ref = self.load_ref(hid=obj.hid, in_memory=False) + if cache: + self.atoms[obj.cid] = serialize(ref.obj) return ref.obj else: return obj.obj @@ -458,7 +471,7 @@ def _attach_atom(self, ref: AtomRef, inplace: bool = False) -> Optional[AtomRef] else: return ref.attached(obj=deserialize(self.atoms[ref.cid])) - def unwrap(self, obj: Any) -> Any: + def unwrap(self, obj: Any, cache: bool = True) -> Any: """ Given a `Ref` or a nested python collection containing `Ref`s, return the "unwrapped" object, where all `Ref`s are replaced by the objects @@ -467,7 +480,7 @@ def unwrap(self, obj: Any) -> Any: NOTE: will trigger a load from the storage backend when some of the objects are not in memory. 
""" - return recurse_on_ref_collections(self._unwrap_atom, obj) + return recurse_on_ref_collections(self._unwrap_atom, obj, **{"cache": cache}) def attach(self, obj: T, inplace: bool = False) -> Optional[T]: """ @@ -740,7 +753,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool, if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value: # the value is wrapped bound_arguments.arguments[k] = default_values[k].value - elif v == default_values[k].value: + # elif v == default_values[k].value: + elif _conservative_equality_check(safe_value=default_values[k].value, unknown_value=v): # the value is unwrapped bound_arguments.arguments[k] = default_values[k].value else: diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py index b8d8362..0ff0e18 100644 --- a/mandala/tests/test_memoization.py +++ b/mandala/tests/test_memoization.py @@ -1,4 +1,5 @@ from mandala.imports import * +import numpy as np def test_storage(): @@ -176,4 +177,50 @@ def inc(x): # now clear the atoms cache after committing storage.commit() - storage.atoms.clear() \ No newline at end of file + storage.atoms.clear() + + + +def test_newargdefault(): + storage = Storage() + + @op + def add(x,): + return x + 1 + + with storage: + add(1) + + @op + def add(x, y=NewArgDefault(1)): + return x + y + + with storage: + add(1) + # check that we didn't make a new call + assert len(storage.cf(add).calls) == 1 + + with storage: + add(1, 1) + # check that we didn't make a new call + assert len(storage.cf(add).calls) == 1 + + with storage: + add(1, 2) + # now this should have made a new call! 
+    assert len(storage.cf(add).calls) == 2
+
+def test_newargdefault_compound_types():
+    storage = Storage()
+
+    @op
+    def add_array(x:np.ndarray):
+        return x
+    with storage:
+        add_array(np.array([1, 2, 3]))
+
+    @op
+    def add_array(x:np.ndarray, y=NewArgDefault(None)):
+        return x + y
+    with storage:
+        add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))
\ No newline at end of file
diff --git a/mandala/utils.py b/mandala/utils.py
index 1589d6a..86aeeb1 100644
--- a/mandala/utils.py
+++ b/mandala/utils.py
@@ -47,6 +47,33 @@ def deserialize(value: bytes) -> Any:
     return joblib.load(buffer)
 
 
+def _conservative_equality_check(safe_value: Any, unknown_value: Any) -> bool:
+    """
+    An equality checker that treats `safe_value` as a "simple" type, but is
+    conservative about how __eq__ can be applied to `unknown_value`. This is
+    necessary when comparing against e.g. numpy arrays.
+
+    Always returns a plain `bool`: values of different (exact) types are never
+    equal, and a comparison that raises or has no unambiguous truth value is
+    treated as "not equal".
+    """
+    if type(safe_value) is not type(unknown_value):
+        return False
+    if isinstance(unknown_value, (int, float, str, bytes, bool, type(None))):
+        return safe_value == unknown_value
+    # handle some common cases whose `==` is element-wise
+    if isinstance(unknown_value, np.ndarray):
+        return bool(np.array_equal(safe_value, unknown_value))
+    if isinstance(unknown_value, (pd.DataFrame, pd.Series, pd.Index)):
+        return bool(safe_value.equals(unknown_value))
+    # fall back to the default equality check; `__eq__` may raise or return an
+    # array-like (e.g. containers of arrays), so be conservative about it
+    try:
+        return bool(safe_value == unknown_value)
+    except (TypeError, ValueError, RuntimeError):
+        return False
+
+
 def get_content_hash(obj: Any) -> str:
     if hasattr(obj, "__get_mandala_dict__"):
         obj = obj.__get_mandala_dict__()