Skip to content

Commit

Permalink
add support for arguments that should always be ignored when hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Aug 24, 2024
1 parent d628910 commit 25915ab
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 29 deletions.
12 changes: 12 additions & 0 deletions mandala/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ def __init__(
nout: Union[Literal["var", "auto"], int] = "auto",
output_names: Optional[List[str]] = None,
version: Optional[int] = 0,
ignore_args: Optional[Tuple[str,...]] = None, # ignore these arguments when hashing
__structural__: bool = False,
__allow_side_effects__: bool = False,
) -> None:
self.name = name
self.nout = nout
self.version = version
self.output_names = output_names
self.ignore_args = ignore_args
self.__structural__ = __structural__
self.__allow_side_effects__ = __allow_side_effects__
self.f = f
Expand Down Expand Up @@ -405,15 +407,25 @@ def make_ref_set(resf: Iterable[Ref]) -> SetRef:
def op(
output_names: Union[Optional[List[str]], Callable] = None,
nout: Union[Literal["var", "auto"], int] = "auto",
ignore_args: Optional[Tuple[str,...]] = None,
__structural__: bool = False,
__allow_side_effects__: bool = False,
):
"""
Decorator used to make a function memoized by the storage. Some options:
- `ignore_args` is a tuple of argument names (keyword or positional) that
should be ignored when hashing the function. This is useful when the
function has arguments that are not relevant to the output, like a batch
size.
"""
def decorator(f: Callable, output_names = None) -> 'f': # some IDE magic to make it recognize that @op(f) has the same type as f
res = Op(
f.__name__,
f,
output_names=output_names,
nout=nout,
ignore_args=ignore_args,
__structural__=__structural__,
__allow_side_effects__=__allow_side_effects__,
)
Expand Down
62 changes: 33 additions & 29 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,36 +636,37 @@ def get_defaults(self, f: Callable) -> Dict[str, Any]:
if v.default is not inspect.Parameter.empty
}

def preprocess_args_kwargs(self,
                           f: Callable,
                           args: Tuple[Any, ...],
                           kwargs: Dict[str, Any]):
    """
    Find which inputs of a call to `f` should be ignored by the storage.

    The call `f(*args, **kwargs)` is bound against the signature of `f`
    with defaults applied. An input name is marked as ignored when either:
    - the value bound to it is an `_Ignore` instance; or
    - `f` declares a `_NewArgDefault` default for it and the bound value
      (unwrapped via `self.unwrap` first when it is a `Ref`) equals the
      default's `.value`.

    Raises:
        TypeError: propagated from `inspect.Signature.bind` when `args` and
            `kwargs` do not match the signature of `f`.

    NOTE(review): the set of ignored names is computed but neither returned
    nor stored, and no input is replaced by its underlying object here;
    confirm the intended result before relying on this method.
    """
    sig = inspect.signature(f)
    defaults = self.get_defaults(f)
    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    # Anything explicitly wrapped in `_Ignore` is ignored, default or not.
    ignored_inputs = {
        name
        for name, value in bound_args.arguments.items()
        if isinstance(value, _Ignore)
    }

    # Now, check for defaults we should ignore.
    for k, v in defaults.items():
        # `BoundArguments.arguments` is keyed by parameter *name*; after
        # `apply_defaults()` every parameter with a default has an entry.
        given_value = bound_args.arguments[k]
        if isinstance(given_value, _Ignore):
            ignored_inputs.add(k)
            continue
        if not isinstance(v, _NewArgDefault):
            continue
        # Compare the underlying object, not the `Ref` wrapper, to the default.
        if isinstance(given_value, Ref):
            given_value = self.unwrap(given_value)
        if given_value == v.value:
            ignored_inputs.add(k)
# def preprocess_args_kwargs(self,
# f: Callable,
# args: Tuple[Any, ...],
# kwargs: Dict[str, Any]):
# """
# Find which inputs should be ignored by the storage, and replace them
# with their underlying objects.
# """
# sig = inspect.signature(f)
# defaults = self.get_defaults(f)
# bound_args = sig.bind(*args, **kwargs)
# bound_args.apply_defaults()
# ignored_inputs = set()
# for k, v in bound_args.arguments.items():
# if isinstance(v, _Ignore):
# ignored_inputs.add(k)
# # now, check for defaults we should ignore
# for k, v in defaults.items():
# given_value = bound_args.arguments[v]
# if isinstance(given_value, _Ignore):
# ignored_inputs.add(k)
# elif isinstance(v, _NewArgDefault):
# if isinstance(given_value, Ref):
# if self.unwrap(given_value) == v.value:
# ignored_inputs.add(k)
# else:
# if given_value == v.value:
# ignored_inputs.add(k)

def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool) -> Tuple[inspect.BoundArguments, Dict[str, Any], Dict[str, Any]]:
def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool,
ignore_args: Optional[Tuple[str,...]] = None) -> Tuple[inspect.BoundArguments, Dict[str, Any], Dict[str, Any]]:
"""
Given the inputs passed to an @op call (could be wrapped or unwrapped),
figure out the inputs we should pass to storage functions, their type
Expand Down Expand Up @@ -708,6 +709,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool)
storage_inputs[name] = val
storage_annotations[name] = var_keyword.annotation
else:
if ignore_args is not None and k in ignore_args:
v = Ignore(v)
# could have a default
if isinstance(v, _Ignore):
# regardless of defaults, any _Ignore instance should be ignored
Expand Down Expand Up @@ -1130,6 +1133,7 @@ def call(
args=args,
kwargs=kwargs,
apply_defaults=True,
ignore_args=__op__.ignore_args,
)

storage_tps = {
Expand Down
23 changes: 23 additions & 0 deletions mandala/tests/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,26 @@ def chunked_square(elts: MList[int]) -> MList[int]:
with storage:
elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
squares = chunked_square(elts)



def test_ignore():
    """
    Calls that differ only in an argument listed in `ignore_args` should
    show up as a single row in the op's computation frame.
    """
    storage = Storage()

    @op(ignore_args=('irrelevant',))
    def inc(x, irrelevant):
        return x + 1

    # Same `x`, different ignored value: the row count must stay at one
    # after the first call and after the second.
    for ignored_value in (0, 1):
        with storage:
            inc(23, ignored_value)

        assert len(storage.cf(inc).df()) == 1

0 comments on commit 25915ab

Please sign in to comment.