From 63c29f33976c01f7caf2cfe4ef19beaf3cfae853 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 29 Aug 2024 16:57:13 -0500 Subject: [PATCH] add collision/duplication checks to CachedMapper/TransformMapper/TransformMapperWithExtraArgs --- pytato/distributed/partition.py | 2 +- pytato/transform/__init__.py | 383 +++++++++++++++++++++++++++++--- pytato/transform/metadata.py | 4 +- test/test_apps.py | 5 +- 4 files changed, 362 insertions(+), 32 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 9911e8606..c6730b67a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -346,7 +346,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override] key = self._cache.get_key(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(expr, key=key) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 304a1f2e1..cd33183dd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -171,6 +171,14 @@ class UnsupportedArrayError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class CacheNoOpDuplicationError(ValueError): + pass + + # {{{ mapper base class class Mapper: @@ -262,15 +270,21 @@ class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT]): """ def __init__( self, - key_func: Callable[[CacheExprT], CacheKeyT]) -> None: + key_func: Callable[[CacheExprT], CacheKeyT], + err_on_collision: bool) -> None: """ Initialize the cache. :arg key_func: Function to compute a hashable cache key from an input expression. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. """ + self.err_on_collision = err_on_collision self._key_func = key_func self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._expr_key_to_expr: dict[CacheKeyT, CacheExprT] = {} # FIXME: Can this be inlined? def get_key(self, expr: CacheExprT) -> CacheKeyT: @@ -291,6 +305,8 @@ def add( key = self._key_func(expr) self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = expr return result @@ -302,7 +318,13 @@ def retrieve( if key is None: key = self._key_func(expr) - return self._expr_key_to_result[key] + result = self._expr_key_to_result[key] + + if self.err_on_collision: + if expr is not self._expr_key_to_expr[key]: + raise CacheCollisionError + + return result class CachedMapper(Mapper, Generic[CachedMapperT, CachedMapperFunctionT]): @@ -335,34 +357,73 @@ class CachedMapper(Mapper, Generic[CachedMapperT, CachedMapperFunctionT]): def __init__( self, + err_on_collision: bool = False, # Arrays are cached separately for each call stack frame, but # functions are cached globally _function_cache: _FunctionCacheT[CachedMapperFunctionT] | None = None ) -> None: super().__init__() self._cache: CachedMapper._CacheT[CachedMapperT] = \ - CachedMapper._CacheType(lambda expr: expr) + CachedMapper._CacheType( + lambda expr: expr, + err_on_collision=err_on_collision) if _function_cache is None: - _function_cache = CachedMapper._FunctionCacheType(lambda expr: expr) + _function_cache = CachedMapper._FunctionCacheType( + lambda expr: expr, + err_on_collision=err_on_collision) self._function_cache: CachedMapper._FunctionCacheT[CachedMapperFunctionT] = \ _function_cache + def _cache_add( + self, + expr: ArrayOrNames, + result: CachedMapperT, + key: Hashable | None = None) -> CachedMapperT: + return self._cache.add(expr, result, key=key) + + def _function_cache_add( + self, + expr: FunctionDefinition, + result: CachedMapperFunctionT, + key: Hashable | None = None) -> CachedMapperFunctionT: + return self._function_cache.add(expr, result, key=key) + + def _cache_retrieve( + self, + expr: ArrayOrNames, + key: Hashable | None = None) -> CachedMapperT: + try: + return self._cache.retrieve(expr, key=key) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(expr)} in {type(self)}.") from e + + def _function_cache_retrieve( + self, + expr: FunctionDefinition, + key: Hashable | None = None) -> CachedMapperFunctionT: + try: + return self._function_cache.retrieve(expr, key=key) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(expr)} in {type(self)}.") from e + def rec(self, expr: ArrayOrNames) -> CachedMapperT: key = self._cache.get_key(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(expr, key=key) except KeyError: - return self._cache.add(expr, super().rec(expr), key=key) + return self._cache_add(expr, super().rec(expr), key=key) def rec_function_definition( self, expr: FunctionDefinition) -> CachedMapperFunctionT: key = self._function_cache.get_key(expr) try: - return self._function_cache.retrieve(expr, key=key) + return self._function_cache_retrieve(expr, key=key) except KeyError: - return self._function_cache.add( + return self._function_cache_add( expr, super().rec_function_definition(expr), key=key) if TYPE_CHECKING: @@ -378,6 +439,7 @@ def clone_for_callee( # type-ignore-reason: self.__init__ has a different function signature # than Mapper.__init__ return type(self)( # type: ignore[call-arg] + err_on_collision=self._cache.err_on_collision, # type: ignore[attr-defined] _function_cache=self._function_cache) # type: ignore[attr-defined] # }}} @@ -385,6 +447,74 @@ def clone_for_callee( # {{{ TransformMapper +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheKeyT, CacheExprT]): + """ + Cache for :class:`TransformMapper`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + key_func: Callable[[CacheExprT], CacheKeyT], + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: + """ + Initialize the cache. + + :arg key_func: Function to compute a hashable cache key from an input + expression. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(key_func, err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication + + def add( + self, + expr: CacheExprT, + result: CacheExprT, + key: CacheKeyT | None = None, + result_key: CacheKeyT | None = None) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + if key is None: + key = self._key_func(expr) + if result_key is None: + result_key = self._key_func(result) + + if ( + self.err_on_no_op_duplication + and hash(result_key) == hash(key) + and result_key == key + and result is not expr + # This is questionable in two ways: + # 1) It will not detect duplication of things that are not + # considered direct predecessors (e.g. a Call's + # FunctionDefinition). Not sure how to handle such cases + # 2) DirectPredecessorsGetter doesn't accept FunctionDefinitions, + # but CacheExprT is allowed to be one + and all( + result_pred is pred + for pred, result_pred in zip( + DirectPredecessorsGetter()(expr), + DirectPredecessorsGetter()(result)))): + raise CacheNoOpDuplicationError from None + + self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = expr + + return result + + class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. @@ -392,14 +522,17 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ - _CacheType: type[Any] = CachedMapperCache[ArrayOrNames, Hashable, ArrayOrNames] - _CacheT: TypeAlias = CachedMapperCache[ArrayOrNames, Hashable, ArrayOrNames] + _CacheType: type[Any] = TransformMapperCache[ArrayOrNames, Hashable] + _CacheT: TypeAlias = TransformMapperCache[ArrayOrNames, Hashable] - _FunctionCacheType: type[Any] = CachedMapperCache[ - FunctionDefinition, Hashable, FunctionDefinition] - _FunctionCacheT: TypeAlias = CachedMapperCache[ - FunctionDefinition, Hashable, FunctionDefinition] + _FunctionCacheType: type[Any] = TransformMapperCache[ + FunctionDefinition, Hashable] + _FunctionCacheT: TypeAlias = TransformMapperCache[ + FunctionDefinition, Hashable] if TYPE_CHECKING: def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT: @@ -408,6 +541,72 @@ def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT: def __call__(self, expr: TransformMapperResultT) -> TransformMapperResultT: return self.rec(expr) + def __init__( + self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, + _function_cache: _FunctionCacheT | None = None + ) -> None: + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _function_cache is None: + _function_cache = TransformMapper._FunctionCacheType( + lambda expr: expr, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _function_cache=_function_cache) + + self._cache: TransformMapper._CacheT = TransformMapper._CacheType( + lambda expr: expr, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + self._function_cache: TransformMapper._FunctionCacheT = self._function_cache + + def _cache_add( + self, + expr: TransformMapperResultT, + result: TransformMapperResultT, + key: Hashable | None = None) -> TransformMapperResultT: + try: + return self._cache.add(expr, result, key=key) # type: ignore[return-value] + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + expr: FunctionDefinition, + result: FunctionDefinition, + key: Hashable | None = None) -> FunctionDefinition: + try: + return self._function_cache.add(expr, result, key=key) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(expr)} in " + f"{type(self)}.") from e + + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ + return type(self)( # type: ignore[call-arg] + err_on_collision=self._cache.err_on_collision, # type: ignore[attr-defined] + err_on_no_op_duplication=self._cache.err_on_no_op_duplication, # type: ignore[attr-defined] + _function_cache=self._function_cache) # type: ignore[attr-defined] + # }}} @@ -425,14 +624,22 @@ class TransformMapperWithExtraArgsCache( """ def __init__( self, - key_func: Callable[..., CacheKeyT]) -> None: + key_func: Callable[..., CacheKeyT], + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: """ Initialize the cache. :arg key_func: Function to compute a hashable cache key from an input expression and extra arguments. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. """ - super().__init__(key_func) + super().__init__(key_func, err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication def get_key(self, expr: CacheExprT, *args: Any, **kwargs: Any) -> CacheKeyT: """Compute the key for an input expression.""" @@ -444,12 +651,40 @@ def add( # type: ignore[override] key_args: tuple[Any, ...], key_kwargs: dict[str, Any], result: CacheExprT, - key: CacheKeyT | None = None) -> CacheExprT: - """Cache a mapping result.""" + key: CacheKeyT | None = None, + result_key: CacheKeyT | None = None) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ if key is None: key = self._key_func(expr, *key_args, **key_kwargs) + if result_key is None: + result_key = self._key_func(result, *key_args, **key_kwargs) + + if ( + self.err_on_no_op_duplication + and hash(result_key) == hash(key) + and result_key == key + and result is not expr + # This is questionable in two ways: + # 1) It will not detect duplication of things that are not + # considered direct predecessors (e.g. a Call's + # FunctionDefinition). Not sure how to handle such cases + # 2) DirectPredecessorsGetter doesn't accept FunctionDefinitions, + # but CacheExprT is allowed to be one + and all( + result_pred is pred + for pred, result_pred in zip( + DirectPredecessorsGetter()(expr), + DirectPredecessorsGetter()(result)))): + raise CacheNoOpDuplicationError from None self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = expr return result @@ -463,7 +698,13 @@ def retrieve( # type: ignore[override] if key is None: key = self._key_func(expr, *key_args, **key_kwargs) - return self._expr_key_to_result[key] + result = self._expr_key_to_result[key] + + if self.err_on_collision: + if expr is not self._expr_key_to_expr[key]: + raise CacheCollisionError + + return result class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames, FunctionDefinition]): @@ -473,6 +714,9 @@ class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames, FunctionDefinition The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ _CacheType: type[Any] = TransformMapperWithExtraArgsCache[ ArrayOrNames, Hashable] @@ -492,8 +736,16 @@ def __call__( def __init__( self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, _function_cache: _FunctionCacheT | None = None ) -> None: + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ def key_func( expr: ArrayOrNames | FunctionDefinition, *args: Any, **kwargs: Any) -> Hashable: @@ -501,25 +753,87 @@ def key_func( if _function_cache is None: _function_cache = TransformMapperWithExtraArgs._FunctionCacheType( - key_func) + key_func, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) - super().__init__(_function_cache=_function_cache) + super().__init__( + err_on_collision=err_on_collision, + _function_cache=_function_cache) self._cache: TransformMapperWithExtraArgs._CacheT = \ - TransformMapperWithExtraArgs._CacheType(key_func) + TransformMapperWithExtraArgs._CacheType( + key_func, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) self._function_cache: TransformMapperWithExtraArgs._FunctionCacheT = \ self._function_cache + def _cache_add( # type: ignore[override] + self, + expr: TransformMapperResultT, + key_args: tuple[Any, ...], + key_kwargs: dict[str, Any], + result: TransformMapperResultT, + key: Hashable | None = None) -> TransformMapperResultT: + try: + return self._cache.add(expr, key_args, key_kwargs, result, key=key) # type: ignore[return-value] + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( # type: ignore[override] + self, + expr: FunctionDefinition, + key_args: tuple[Any, ...], + key_kwargs: dict[str, Any], + result: FunctionDefinition, + key: Hashable | None = None) -> FunctionDefinition: + try: + return self._function_cache.add( + expr, key_args, key_kwargs, result, key=key) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(expr)} in " + f"{type(self)}.") from e + + def _cache_retrieve( # type: ignore[override] + self, + expr: TransformMapperResultT, + key_args: tuple[Any, ...], + key_kwargs: dict[str, Any], + key: Hashable | None = None) -> TransformMapperResultT: + try: + return self._cache.retrieve( # type: ignore[return-value] + expr, key_args, key_kwargs, key=key) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(expr)} in {type(self)}.") from e + + def _function_cache_retrieve( # type: ignore[override] + self, + expr: FunctionDefinition, + key_args: tuple[Any, ...], + key_kwargs: dict[str, Any], + key: Hashable | None = None) -> FunctionDefinition: + try: + return self._function_cache.retrieve( + expr, key_args, key_kwargs, key=key) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(expr)} in {type(self)}.") from e + def rec( self, expr: TransformMapperResultT, *args: Any, **kwargs: Any) -> TransformMapperResultT: key = self._cache.get_key(expr, *args, **kwargs) try: - return self._cache.retrieve(expr, args, kwargs, key=key) # type: ignore[return-value] + return self._cache_retrieve(expr, args, kwargs, key=key) except KeyError: - return self._cache.add( # type: ignore[return-value] + return self._cache_add( expr, args, kwargs, Mapper.rec(self, expr, *args, **kwargs), key=key) def rec_function_definition( @@ -528,12 +842,25 @@ def rec_function_definition( *args: Any, **kwargs: Any) -> FunctionDefinition: key = self._function_cache.get_key(expr, *args, **kwargs) try: - return self._function_cache.retrieve(expr, args, kwargs, key=key) + return self._function_cache_retrieve(expr, args, kwargs, key=key) except KeyError: - return self._function_cache.add( + return self._function_cache_add( expr, args, kwargs, Mapper.rec_function_definition(self, expr, *args, **kwargs), key=key) + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ + return type(self)( # type: ignore[call-arg] + err_on_collision=self._cache.err_on_collision, # type: ignore[attr-defined] + err_on_no_op_duplication=self._cache.err_on_no_op_duplication, # type: ignore[attr-defined] + _function_cache=self._function_cache) # type: ignore[attr-defined] + # }}} @@ -1655,9 +1982,9 @@ def clone_for_callee( def rec(self, expr: MappedT) -> MappedT: key = self._cache.get_key(expr) try: - return self._cache.retrieve(expr, key=key) # type: ignore[return-value] + return self._cache_retrieve(expr, key=key) # type: ignore[return-value] except KeyError: - return self._cache.add( # type: ignore[return-value] + return self._cache_add( expr, super().rec(self.map_fn(expr)), key=key) if TYPE_CHECKING: diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 6e8ccf19b..c86db81e4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -645,14 +645,14 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: def rec(self, expr: ArrayOrNames) -> Any: key = self._cache.get_key(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(expr, key=key) except KeyError: result = Mapper.rec(self, expr) if not isinstance(expr, (AbstractResultWithNamedArrays, DistributedSendRefHolder)): assert isinstance(expr, Array) result = self._attach_tags(expr, result) - return self._cache.add(expr, result, key=key) + return self._cache_add(expr, result, key=key) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( diff --git a/test/test_apps.py b/test/test_apps.py index 89d52218e..9cc29733d 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -94,7 +94,10 @@ def __init__(self, fft_vec_gatherer): arrays = fft_vec_gatherer.level_to_arrays[lev] rec_arrays = [self.rec(ary) for ary in arrays] # reset cache so that the partial subs are not stored - self._cache = type(self._cache)(lambda expr: expr) + self._cache = type(self._cache)( + lambda expr: expr, + err_on_collision=self._cache.err_on_collision, + err_on_no_op_duplication=self._cache.err_on_no_op_duplication) lev_array = pt.concatenate(rec_arrays, axis=0) assert lev_array.shape == (fft_vec_gatherer.n,)