From e3787f95b6e1057e992499eeda6052f9a6a0a36e Mon Sep 17 00:00:00 2001 From: Aleksandar Makelov Date: Mon, 9 Sep 2024 03:24:29 +0300 Subject: [PATCH] another fix relevant to #28 --- mandala/storage.py | 3 ++- mandala/tests/test_memoization.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mandala/storage.py b/mandala/storage.py index a337108..cd383e6 100644 --- a/mandala/storage.py +++ b/mandala/storage.py @@ -750,7 +750,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool, # regardless of defaults, any _Ignore instance should be ignored bound_arguments.arguments[k] = v.value elif k in default_values and isinstance(default_values[k], _NewArgDefault): - if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value: + # if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value: + if isinstance(v, Ref) and _conservative_equality_check(safe_value=default_values[k].value, unknown_value=self.unwrap(v)): # the value is wrapped bound_arguments.arguments[k] = default_values[k].value # elif v == default_values[k].value: diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py index 0ff0e18..727c057 100644 --- a/mandala/tests/test_memoization.py +++ b/mandala/tests/test_memoization.py @@ -222,5 +222,10 @@ def add_array(x:np.ndarray): @op def add_array(x:np.ndarray, y=NewArgDefault(None)): return x + y + # test passing a raw value with storage: - add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6])) \ No newline at end of file + add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6])) + + # now test passing a wrapped value + with storage: + add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9]))) \ No newline at end of file