From 25915ab105f6c410adf3570cbd1b83de3a6b6c11 Mon Sep 17 00:00:00 2001 From: Aleksandar Makelov Date: Sat, 24 Aug 2024 19:47:21 +0300 Subject: [PATCH] add support for arguments that should always be ignored when hashing --- mandala/model.py | 12 ++++++ mandala/storage.py | 62 ++++++++++++++++--------------- mandala/tests/test_memoization.py | 23 ++++++++++++ 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/mandala/model.py b/mandala/model.py index ad00341..94f1661 100644 --- a/mandala/model.py +++ b/mandala/model.py @@ -100,6 +100,7 @@ def __init__( nout: Union[Literal["var", "auto"], int] = "auto", output_names: Optional[List[str]] = None, version: Optional[int] = 0, + ignore_args: Optional[Tuple[str,...]] = None, # ignore these arguments when hashing __structural__: bool = False, __allow_side_effects__: bool = False, ) -> None: @@ -107,6 +108,7 @@ def __init__( self.nout = nout self.version = version self.output_names = output_names + self.ignore_args = ignore_args self.__structural__ = __structural__ self.__allow_side_effects__ = __allow_side_effects__ self.f = f @@ -405,15 +407,25 @@ def make_ref_set(resf: Iterable[Ref]) -> SetRef: def op( output_names: Union[Optional[List[str]], Callable] = None, nout: Union[Literal["var", "auto"], int] = "auto", + ignore_args: Optional[Tuple[str,...]] = None, __structural__: bool = False, __allow_side_effects__: bool = False, ): + """ + Decorator used to make a function memoized by the storage. Some options: + + - `ignore_args` is a tuple of argument names (keyword or positional) that + should be ignored when hashing the function. This is useful when the + function has arguments that are not relevant to the output, like a batch + size. + """ def decorator(f: Callable, output_names = None) -> 'f': # some IDE magic to make it recognize that @op(f) has the same type as f res = Op( f.__name__, f, output_names=output_names, nout=nout, + ignore_args=ignore_args, __structural__=__structural__, __allow_side_effects__=__allow_side_effects__, ) diff --git a/mandala/storage.py b/mandala/storage.py index 206186e..c13a2e2 100644 --- a/mandala/storage.py +++ b/mandala/storage.py @@ -636,36 +636,37 @@ def get_defaults(self, f: Callable) -> Dict[str, Any]: if v.default is not inspect.Parameter.empty } - def preprocess_args_kwargs(self, - f: Callable, - args: Tuple[Any, ...], - kwargs: Dict[str, Any]): - """ - Find which inputs should be ignored by the storage, and replace them - with their underlying objects. - """ - sig = inspect.signature(f) - defaults = self.get_defaults(f) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - ignored_inputs = set() - for k, v in bound_args.arguments.items(): - if isinstance(v, _Ignore): - ignored_inputs.add(k) - # now, check for defaults we should ignore - for k, v in defaults.items(): - given_value = bound_args.arguments[v] - if isinstance(given_value, _Ignore): - ignored_inputs.add(k) - elif isinstance(v, _NewArgDefault): - if isinstance(given_value, Ref): - if self.unwrap(given_value) == v.value: - ignored_inputs.add(k) - else: - if given_value == v.value: - ignored_inputs.add(k) + # def preprocess_args_kwargs(self, + # f: Callable, + # args: Tuple[Any, ...], + # kwargs: Dict[str, Any]): + # """ + # Find which inputs should be ignored by the storage, and replace them + # with their underlying objects. + # """ + # sig = inspect.signature(f) + # defaults = self.get_defaults(f) + # bound_args = sig.bind(*args, **kwargs) + # bound_args.apply_defaults() + # ignored_inputs = set() + # for k, v in bound_args.arguments.items(): + # if isinstance(v, _Ignore): + # ignored_inputs.add(k) + # # now, check for defaults we should ignore + # for k, v in defaults.items(): + # given_value = bound_args.arguments[v] + # if isinstance(given_value, _Ignore): + # ignored_inputs.add(k) + # elif isinstance(v, _NewArgDefault): + # if isinstance(given_value, Ref): + # if self.unwrap(given_value) == v.value: + # ignored_inputs.add(k) + # else: + # if given_value == v.value: + # ignored_inputs.add(k) - def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool) -> Tuple[inspect.BoundArguments, Dict[str, Any], Dict[str, Any]]: + def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool, + ignore_args: Optional[Tuple[str,...]] = None) -> Tuple[inspect.BoundArguments, Dict[str, Any], Dict[str, Any]]: """ Given the inputs passed to an @op call (could be wrapped or unwrapped), figure out the inputs we should pass to storage functions, their type @@ -708,6 +709,8 @@ def parse_args(self, sig: inspect.Signature, args, kwargs, apply_defaults: bool) storage_inputs[name] = val storage_annotations[name] = var_keyword.annotation else: + if ignore_args is not None and k in ignore_args: + v = Ignore(v) # could have a default if isinstance(v, _Ignore): # regardless of defaults, any _Ignore instance should be ignored @@ -1130,6 +1133,7 @@ def call( args=args, kwargs=kwargs, apply_defaults=True, + ignore_args=__op__.ignore_args, ) storage_tps = { diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py index 623c7b6..7f2cd23 100644 --- a/mandala/tests/test_memoization.py +++ b/mandala/tests/test_memoization.py @@ -131,3 +131,26 @@ def chunked_square(elts: MList[int]) -> MList[int]: with storage: elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] squares = chunked_square(elts) + + + +def test_ignore(): + + storage = Storage() + + @op(ignore_args=('irrelevant',)) + def inc(x, irrelevant): + return x + 1 + + + with storage: + inc(23, 0) + + df = storage.cf(inc).df() + assert len(df) == 1 + + with storage: + inc(23, 1) + + df = storage.cf(inc).df() + assert len(df) == 1 \ No newline at end of file