From 1e30fd131a56f7e8a448db99aba0c6c2cefdf531 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 23 Jan 2025 10:25:23 -0600 Subject: [PATCH 1/5] make CombineMapper inherit from CachedMapper --- pytato/transform/__init__.py | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 40cb2b6e2..43d050c65 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -768,44 +768,23 @@ def map_named_call_result(self, expr: NamedCallResult, # {{{ CombineMapper -class CombineMapper(Mapper[ResultT, FunctionResultT, []]): +class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]): """ Abstract mapper that recursively combines the results of user nodes of a given expression. .. automethod:: combine """ - def __init__( - self, - _function_cache: dict[FunctionDefinition, FunctionResultT] | None = None - ) -> None: - super().__init__() - self.cache: dict[ArrayOrNames, ResultT] = {} - self.function_cache: dict[FunctionDefinition, FunctionResultT] = \ - _function_cache if _function_cache is not None else {} + def get_cache_key(self, expr: ArrayOrNames) -> Hashable: + return expr + + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> Hashable: + return expr def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[ResultT, ...]: return tuple(self.rec(s) for s in situp if isinstance(s, Array)) - def rec(self, expr: ArrayOrNames) -> ResultT: - if expr in self.cache: - return self.cache[expr] - result: ResultT = super().rec(expr) - self.cache[expr] = result - return result - - def rec_function_definition( - self, expr: FunctionDefinition) -> FunctionResultT: - if expr in self.function_cache: - return self.function_cache[expr] - result: FunctionResultT = super().rec_function_definition(expr) - self.function_cache[expr] = result - return result - - def __call__(self, expr: ArrayOrNames) -> ResultT: - return self.rec(expr) - def combine(self, *args: ResultT) -> ResultT: """Combine the arguments.""" raise NotImplementedError From 59d8d0433ca1a0ff81c520db5e9e07b1c5159e42 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 2 Jul 2024 07:57:09 -0500 Subject: [PATCH 2/5] move axis tag attaching code into a separate method in AxisTagAttacher --- pytato/transform/metadata.py | 103 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 49190c76e..635a478d8 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -601,56 +601,61 @@ def __init__(self, self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr + def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: + assert rec_expr.ndim == expr.ndim + + result = rec_expr + + for iaxis in range(expr.ndim): + result = result.with_tagged_axis( + iaxis, self.axis_to_tags.get((expr, iaxis), [])) + + # {{{ tag reduction descrs + + if self.tag_corresponding_redn_descr: + if isinstance(expr, Einsum): + assert isinstance(result, Einsum) + for arg, access_descrs in zip(expr.args, + expr.access_descriptors, + strict=True): + for iaxis, access_descr in enumerate(access_descrs): + if isinstance(access_descr, EinsumReductionAxis): + result = result.with_tagged_reduction( + access_descr, + self.axis_to_tags.get((arg, iaxis), []) + ) + + if isinstance(expr, IndexLambda): + 
assert isinstance(result, IndexLambda) + try: + hlo = index_lambda_to_high_level_op(expr) + except UnknownIndexLambdaExpr: + pass + else: + if isinstance(hlo, ReduceOp): + for iaxis, redn_var in hlo.axes.items(): + result = result.with_tagged_reduction( + redn_var, + self.axis_to_tags.get((hlo.x, iaxis), []) + ) + + # }}} + + return result + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - if isinstance(expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): - return super().rec(expr) - else: - assert isinstance(expr, Array) - key = self.get_cache_key(expr) - try: - return self._cache[key] - except KeyError: - expr_copy = Mapper.rec(self, expr) - assert isinstance(expr_copy, Array) - assert expr_copy.ndim == expr.ndim - - for iaxis in range(expr.ndim): - expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) - - # {{{ tag reduction descrs - - if self.tag_corresponding_redn_descr: - if isinstance(expr, Einsum): - assert isinstance(expr_copy, Einsum) - for arg, access_descrs in zip(expr.args, - expr.access_descriptors, - strict=True): - for iaxis, access_descr in enumerate(access_descrs): - if isinstance(access_descr, EinsumReductionAxis): - expr_copy = expr_copy.with_tagged_reduction( - access_descr, - self.axis_to_tags.get((arg, iaxis), []) - ) - - if isinstance(expr, IndexLambda): - assert isinstance(expr_copy, IndexLambda) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - pass - else: - if isinstance(hlo, ReduceOp): - for iaxis, redn_var in hlo.axes.items(): - expr_copy = expr_copy.with_tagged_reduction( - redn_var, - self.axis_to_tags.get((hlo.x, iaxis), []) - ) - - # }}} - - self._cache[key] = expr_copy - return expr_copy + key = self.get_cache_key(expr) + try: + return self._cache[key] + except KeyError: + result = Mapper.rec(self, expr) + if not isinstance( + expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): + assert isinstance(expr, Array) + # type-ignore reason: passed "ArrayOrNames"; expected "Array" + result = self._attach_tags(expr, result) # type: ignore[arg-type] + self._cache[key] = result + return result def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From 530f7e1671a7e183da8b3fc48050815ebee8d61e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 20 Dec 2024 15:24:12 -0600 Subject: [PATCH 3/5] use a class for CachedMapper caches instead of using a dict directly --- pytato/codegen.py | 8 +- pytato/distributed/partition.py | 16 ++- pytato/transform/__init__.py | 187 ++++++++++++++++++++++++++------ pytato/transform/metadata.py | 17 +-- test/test_apps.py | 2 +- 5 files changed, 178 insertions(+), 52 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index eae7ea286..cb957f076 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -63,12 +63,13 @@ CachedWalkMapper, CopyMapper, SubsetDependencyMapper, + TransformMapperCache, ) from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin if TYPE_CHECKING: - from collections.abc import Hashable, Mapping + from collections.abc import Mapping from pytato.function import FunctionDefinition, NamedCallResult from pytato.target import Target @@ -137,9 +138,10 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _function_cache: dict[Hashable, FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: 
- super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 7662c7205..2ab1163a1 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -94,6 +94,7 @@ CachedWalkMapper, CombineMapper, CopyMapper, + TransformMapperCache, _verify_is_array, ) @@ -239,9 +240,11 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _function_cache: dict[Hashable, FunctionDefinition] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.recvd_ary_to_name = recvd_ary_to_name self.sptpo_ary_to_name = sptpo_ary_to_name @@ -255,7 +258,10 @@ def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, # stored arrays promoted to part outputs, or part outputs - return type(self)({}, {}, {}, _function_cache=self._function_cache) + return type(self)( + {}, {}, {}, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -288,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self.get_cache_key(expr) + key = self._cache.get_key(expr) try: - return self._cache[key] + return self._cache.retrieve(expr, key=key) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 43d050c65..811c625b6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -92,7 +92,9 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CachedMapperCache .. autoclass:: CachedMapper +.. autoclass:: TransformMapperCache .. autoclass:: TransformMapper .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper @@ -150,9 +152,27 @@ A type variable representing the result type of a :class:`Mapper` when mapping a :class:`pytato.function.FunctionDefinition`. +.. class:: CacheExprT + + A type variable representing an input from which to compute a cache key in order + to cache a result. + +.. class:: CacheKeyT + + A type variable representing a key computed from an input expression. + +.. class:: CacheResultT + + A type variable representing a result to be cached. + .. class:: Scalar See :data:`pymbolic.Scalar`. + +.. class:: P + + A :class:`typing.ParamSpec` used to annotate `*args` and `**kwargs`. + """ transform_logger = logging.getLogger(__file__) @@ -252,6 +272,77 @@ def __call__(self, # {{{ CachedMapper +CacheExprT = TypeVar("CacheExprT") +CacheResultT = TypeVar("CacheResultT") +CacheKeyT: TypeAlias = Hashable + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): + """ + Cache for mappers. + + .. automethod:: __init__ + .. method:: get_key + + Compute the key for an input expression. + + .. automethod:: add + .. automethod:: retrieve + .. 
automethod:: clear + """ + def __init__( + self, + key_func: Callable[..., CacheKeyT]) -> None: + """ + Initialize the cache. + + :arg key_func: Function to compute a hashable cache key from an input + expression and any extra arguments. + """ + self.get_key = key_func + + self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + + def add( + self, + key_inputs: + CacheExprT + | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], + result: CacheResultT, + key: CacheKeyT | None = None) -> CacheResultT: + """Cache a mapping result.""" + if key is None: + if isinstance(key_inputs, tuple): + expr, key_args, key_kwargs = key_inputs + key = self.get_key(expr, *key_args, **key_kwargs) + else: + key = self.get_key(key_inputs) + + self._expr_key_to_result[key] = result + + return result + + def retrieve( + self, + key_inputs: + CacheExprT + | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], + key: CacheKeyT | None = None) -> CacheResultT: + """Retrieve the cached mapping result.""" + if key is None: + if isinstance(key_inputs, tuple): + expr, key_args, key_kwargs = key_inputs + key = self.get_key(expr, *key_args, **key_kwargs) + else: + key = self.get_key(key_inputs) + + return self._expr_key_to_result[key] + + def clear(self) -> None: + """Reset the cache.""" + self._expr_key_to_result = {} + + class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from @@ -261,18 +352,23 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): .. automethod:: get_function_definition_cache_key .. automethod:: clone_for_callee """ - def __init__( self, - # Arrays are cached separately for each call stack frame, but - # functions are cached globally - _function_cache: dict[Hashable, FunctionResultT] | None = None + _cache: + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, + _function_cache: + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: dict[Hashable, ResultT] = {} - self._function_cache: dict[Hashable, FunctionResultT] = \ - _function_cache if _function_cache is not None else {} + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( + _cache if _cache is not None + else CachedMapperCache(self.get_cache_key)) + + self._function_cache: CachedMapperCache[ + FunctionDefinition, FunctionResultT, P] = ( + _function_cache if _function_cache is not None + else CachedMapperCache(self.get_function_definition_cache_key)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -285,24 +381,26 @@ def get_function_definition_cache_key( return (expr, *args, tuple(sorted(kwargs.items()))) def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self.get_cache_key(expr, *args, **kwargs) + key = self._cache.get_key(expr, *args, **kwargs) try: - return self._cache[key] + return self._cache.retrieve((expr, args, kwargs), key=key) except KeyError: - result = super().rec(expr, *args, **kwargs) - self._cache[key] = result - return result + return self._cache.add( + (expr, args, kwargs), + super().rec(expr, *args, **kwargs), + key=key) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self.get_function_definition_cache_key(expr, *args, **kwargs) + key = self._function_cache.get_key(expr, *args, **kwargs) try: - return self._function_cache[key] + return 
self._function_cache.retrieve((expr, args, kwargs), key=key) except KeyError: - result = super().rec_function_definition(expr, *args, **kwargs) - self._function_cache[key] = result - return result + return self._function_cache.add( + (expr, args, kwargs), + super().rec_function_definition(expr, *args, **kwargs), + key=key) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -310,6 +408,7 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ + # Functions are cached globally, but arrays aren't return type(self)(_function_cache=self._function_cache) # }}} @@ -317,16 +416,24 @@ def clone_for_callee( # {{{ TransformMapper +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): + pass + + class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. Enables certain operations that can only be done if the mapping results are also - arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not - implement default mapper methods; for that, see :class:`CopyMapper`. - + arrays (e.g., computing a cache key from them). Does not implement default + mapper methods; for that, see :class:`CopyMapper`. """ - pass + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) # }}} @@ -343,7 +450,13 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. """ - pass + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition, P] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) # }}} @@ -775,10 +888,10 @@ class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]): .. automethod:: combine """ - def get_cache_key(self, expr: ArrayOrNames) -> Hashable: + def get_cache_key(self, expr: ArrayOrNames) -> CacheKeyT: return expr - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> Hashable: + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> CacheKeyT: return expr def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] 
@@ -1369,22 +1482,26 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _function_cache: dict[Hashable, FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_function_cache=_function_cache) + super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn def clone_for_callee( self, function: FunctionDefinition) -> Self: - return type(self)(self.map_fn, _function_cache=self._function_cache) + return type(self)( + self.map_fn, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - if expr in self._cache: - return self._cache[expr] - - result = super().rec(self.map_fn(expr)) - self._cache[expr] = result - return result + key = self._cache.get_key(expr) + try: + return self._cache.retrieve(expr, key=key) + except KeyError: + return self._cache.add( + expr, super().rec(self.map_fn(expr)), key=key) # }}} @@ -1894,7 +2011,7 @@ def rec_get_user_nodes(expr: ArrayOrNames, # {{{ deduplicate_data_wrappers -def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: +def _get_data_dedup_cache_key(ary: DataInterface) -> CacheKeyT: import sys if "pyopencl" in sys.modules: from pyopencl import MemoryObjectHolder @@ -1953,7 +2070,7 @@ def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: job of deduplication. """ - data_wrapper_cache: dict[Hashable, DataWrapper] = {} + data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} data_wrappers_encountered = 0 def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 635a478d8..8fdc5b4ea 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -82,7 +82,7 @@ index_lambda_to_high_level_op, ) from pytato.scalar_expr import SCALAR_CLASSES -from pytato.transform import ArrayOrNames, CopyMapper, Mapper +from pytato.transform import ArrayOrNames, CopyMapper, Mapper, TransformMapperCache from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -90,7 +90,7 @@ if TYPE_CHECKING: - from collections.abc import Collection, Hashable, Mapping + from collections.abc import Collection, Mapping from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall @@ -596,8 +596,10 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], tag_corresponding_redn_descr: bool, - _function_cache: dict[Hashable, FunctionDefinition] | None = None): - super().__init__(_function_cache=_function_cache) + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition, []] | None = None): + super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr @@ -644,9 +646,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self.get_cache_key(expr) + key = self._cache.get_key(expr) try: - return self._cache[key] + return self._cache.retrieve(expr, key=key) except KeyError: result = Mapper.rec(self, 
expr) if not isinstance( @@ -654,8 +656,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - self._cache[key] = result - return result + return self._cache.add(expr, result, key=key) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( diff --git a/test/test_apps.py b/test/test_apps.py index fe1ba18bb..f39be848c 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -94,7 +94,7 @@ def __init__(self, fft_vec_gatherer): arrays = fft_vec_gatherer.level_to_arrays[lev] rec_arrays = [self.rec(ary) for ary in arrays] # reset cache so that the partial subs are not stored - self._cache = {} + self._cache.clear() lev_array = pt.concatenate(rec_arrays, axis=0) assert lev_array.shape == (fft_vec_gatherer.n,) From 56b162ad03781504b8fa5d81e618492508d386fa Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 27 Jan 2025 15:14:39 -0600 Subject: [PATCH 4/5] drop P from mapper caches, since it's not being used --- pytato/codegen.py | 4 ++-- pytato/distributed/partition.py | 6 +++--- pytato/transform/__init__.py | 26 +++++++++++++------------- pytato/transform/metadata.py | 4 ++-- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index cb957f076..86a328929 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,8 +138,8 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 2ab1163a1..27b1e2cee 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition, []] | None = None, + TransformMapperCache[FunctionDefinition] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + "TransformMapperCache[FunctionDefinition]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 811c625b6..1cb5dbb74 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -277,7 +277,7 @@ def __call__(self, CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): +class CachedMapperCache(Generic[CacheExprT, CacheResultT]): """ Cache for mappers. 
@@ -355,18 +355,18 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, _cache: - CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, + CachedMapperCache[ArrayOrNames, ResultT] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( _cache if _cache is not None else CachedMapperCache(self.get_cache_key)) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT, P] = ( + FunctionDefinition, FunctionResultT] = ( _function_cache if _function_cache is not None else CachedMapperCache(self.get_function_definition_cache_key)) @@ -416,7 +416,7 @@ def clone_for_callee( # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): pass @@ -430,8 +430,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -452,9 +452,9 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames, P] | None = None, + _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition, P] | None = None + TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -1482,8 +1482,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1493,7 +1493,7 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + "TransformMapperCache[FunctionDefinition]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: key = self._cache.get_key(expr) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 8fdc5b4ea..e96cd1ee4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -596,9 +596,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition, []] | None = None): + TransformMapperCache[FunctionDefinition] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], 
Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr From 0bfb2b2254d2affbeb48851212db88d82f9bc87e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 27 Jan 2025 18:23:07 -0600 Subject: [PATCH 5/5] add comment about type annotations in CachedMapperCache --- pytato/transform/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 1cb5dbb74..dbab9148c 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -307,6 +307,10 @@ def add( self, key_inputs: CacheExprT + # Currently, Python's type system doesn't have a way to annotate + # containers of args/kwargs (ParamSpec won't work here). So we have + # to fall back to using Any. More details here: + # https://github.com/python/typing/issues/1252 | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], result: CacheResultT, key: CacheKeyT | None = None) -> CacheResultT:
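
A minimal usage sketch (not part of the patch series, and assuming the series is applied) of the caching pattern these commits converge on: compute the key once with get_key, try retrieve, and on a miss store the freshly computed result under the same key via add. The str/int toy types and the cached_length helper below are illustrative only; the real mappers key on ArrayOrNames / FunctionDefinition expressions via get_cache_key.

    from pytato.transform import CachedMapperCache

    # Toy cache: the key function returns the expression itself,
    # mirroring CombineMapper.get_cache_key from PATCH 1/5.
    cache: CachedMapperCache[str, int] = CachedMapperCache(lambda expr: expr)

    def cached_length(expr: str) -> int:
        key = cache.get_key(expr)
        try:
            return cache.retrieve(expr, key=key)
        except KeyError:
            # Compute once, store under the precomputed key, and return.
            return cache.add(expr, len(expr), key=key)

    assert cached_length("abc") == 3
    assert cached_length("abc") == 3  # second call is a cache hit
    cache.clear()  # reset, as test_apps.py now does via self._cache.clear()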