Skip to content

Commit

Permalink
another fix relevant to #28
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Sep 9, 2024
1 parent 1f98b1d commit e3787f9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool,
# regardless of defaults, any _Ignore instance should be ignored
bound_arguments.arguments[k] = v.value
elif k in default_values and isinstance(default_values[k], _NewArgDefault):
if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value:
# if isinstance(v, Ref) and self.unwrap(v) == default_values[k].value:
if isinstance(v, Ref) and _conservative_equality_check(safe_value=default_values[k].value, unknown_value=self.unwrap(v)):
# the value is wrapped
bound_arguments.arguments[k] = default_values[k].value
# elif v == default_values[k].value:
Expand Down
7 changes: 6 additions & 1 deletion mandala/tests/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,10 @@ def add_array(x:np.ndarray):
@op
def add_array(x:np.ndarray, y=NewArgDefault(None)):
return x + y
# test passing a raw value
with storage:
add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))
add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))

# now test passing a wrapped value
with storage:
add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9])))

0 comments on commit e3787f9

Please sign in to comment.