Skip to content

Commit

Permalink
partially fix #28
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Sep 8, 2024
1 parent 1311267 commit 1f98b1d
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
26 changes: 20 additions & 6 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .model import *
import sqlite3
from .model import __make_list__, __list_getitem__, __make_dict__, __dict_getitem__, _Ignore, _NewArgDefault
from .utils import dataframe_to_prettytable, parse_returns
from .utils import dataframe_to_prettytable, parse_returns, _conservative_equality_check
from .viz import _get_colorized_diff
from .deps.versioner import Versioner, CodeState
from .deps.utils import get_dep_key_from_func, extract_func_obj
Expand Down Expand Up @@ -51,6 +51,10 @@ def __init__(self,
else:
self.overflow_storage = None

# storage for ref values: {cid -> serialized object}
# storing the serialized object prevents accidental modification of the
# object in memory, but also means that we have to deserialize it when
# we want to use it.
self.atoms = CachedDictStorage(
persistent=SQLiteDictStorage(self.db, table="atoms",
overflow_storage=self.overflow_storage,
Expand Down Expand Up @@ -122,7 +126,7 @@ def vacuum(self):
############################################################################
### managing the caches
############################################################################
def clear_cache(self, allow_uncommitted: False):
def clear_cache(self, allow_uncommitted: bool = False):
self.atoms.clear(allow_uncommited=allow_uncommitted)
self.shapes.clear(allow_uncommited=allow_uncommitted)
self.ops.clear(allow_uncommited=allow_uncommitted)
Expand Down Expand Up @@ -223,6 +227,12 @@ def save_ref(self, ref: Ref):
raise NotImplementedError

def load_ref(self, hid: str, in_memory: bool = False) -> Ref:
"""
Loads the Ref with the given `hid`, *and* caches it in the `.atoms` and
`.shapes`.
TODO: add option to disable automatic caching of atoms.
"""
shape = self.shapes[hid]
if isinstance(shape, AtomRef):
if in_memory:
Expand Down Expand Up @@ -435,10 +445,13 @@ def get_unreferenced_cids(self) -> Set[str]:
############################################################################
###
############################################################################
def _unwrap_atom(self, obj: Any) -> Any:
def _unwrap_atom(self, obj: Any, cache: bool = True) -> Any:
# TODO: implement `cache = False` in `load_ref`
assert isinstance(obj, AtomRef)
if not obj.in_memory:
ref = self.load_ref(hid=obj.hid, in_memory=False)
if cache:
self.atoms[obj.cid] = serialize(ref.obj)
return ref.obj
else:
return obj.obj
Expand All @@ -458,7 +471,7 @@ def _attach_atom(self, ref: AtomRef, inplace: bool = False) -> Optional[AtomRef]
else:
return ref.attached(obj=deserialize(self.atoms[ref.cid]))

def unwrap(self, obj: Any) -> Any:
def unwrap(self, obj: Any, cache: bool = True) -> Any:
"""
Given a `Ref` or a nested python collection containing `Ref`s, return
the "unwrapped" object, where all `Ref`s are replaced by the objects
Expand All @@ -467,7 +480,7 @@ def unwrap(self, obj: Any) -> Any:
NOTE: will trigger a load from the storage backend when some of the
objects are not in memory.
"""
return recurse_on_ref_collections(self._unwrap_atom, obj)
return recurse_on_ref_collections(self._unwrap_atom, obj, **{"cache": cache})

def attach(self, obj: T, inplace: bool = False) -> Optional[T]:
"""
Expand Down Expand Up @@ -740,7 +753,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool,
if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value:
# the value is wrapped
bound_arguments.arguments[k] = default_values[k].value
elif v == default_values[k].value:
# elif v == default_values[k].value:
elif _conservative_equality_check(safe_value=default_values[k].value, unknown_value=v):
# the value is unwrapped
bound_arguments.arguments[k] = default_values[k].value
else:
Expand Down
49 changes: 48 additions & 1 deletion mandala/tests/test_memoization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from mandala.imports import *
import numpy as np


def test_storage():
Expand Down Expand Up @@ -176,4 +177,50 @@ def inc(x):

# now clear the atoms cache after committing
storage.commit()
storage.atoms.clear()
storage.atoms.clear()



def test_newargdefault():
storage = Storage()

@op
def add(x,):
return x + 1

with storage:
add(1)

@op
def add(x, y=NewArgDefault(1)):
return x + y

with storage:
add(1)
# check that we didn't make a new call
assert len(storage.cf(add).calls) == 1

with storage:
add(1, 1)
# check that we didn't make a new call
assert len(storage.cf(add).calls) == 1

with storage:
add(1, 2)
# now this should have made a new call!
assert len(storage.cf(add).calls) == 2

def test_newargdefault_compound_types():
storage = Storage()

@op
def add_array(x:np.ndarray):
return x
with storage:
add_array(np.array([1, 2, 3]))

@op
def add_array(x:np.ndarray, y=NewArgDefault(None)):
return x + y
with storage:
add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))
20 changes: 20 additions & 0 deletions mandala/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ def deserialize(value: bytes) -> Any:
return joblib.load(buffer)


def _conservative_equality_check(safe_value: Any, unknown_value: Any) -> bool:
"""
An equality checker that treats `safe_value` as a "simple" type, but is
conservative about how __eq__ can be applied to `unknown_value`. This is
necessary when comparing against e.g. numpy arrays.
"""
if type(safe_value) != type(unknown_value):
return False
if isinstance(unknown_value, (int, float, str, bytes, bool, type(None))):
return safe_value == unknown_value
# handle some common cases
if isinstance(unknown_value, np.ndarray):
return np.array_equal(safe_value, unknown_value)
elif isinstance(unknown_value, pd.DataFrame):
return safe_value.equals(unknown_value)
else:
# fall back to the default equality check
return safe_value == unknown_value


def get_content_hash(obj: Any) -> str:
if hasattr(obj, "__get_mandala_dict__"):
obj = obj.__get_mandala_dict__()
Expand Down

0 comments on commit 1f98b1d

Please sign in to comment.