Skip to content

Commit

Permalink
use a class for CachedMapper caches instead of using a dict directly
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 21, 2025
1 parent 453ef7b commit ae86606
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 64 deletions.
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@
["py:class", r"P\.kwargs"],
["py:class", r"lp\.LoopKernel"],
["py:class", r"_dtype_any"],
["py:class", r"(.+)\._CacheT"],
["py:class", r"(.+)\._FunctionCacheT"],
]
4 changes: 2 additions & 2 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from collections.abc import Mapping

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_function_cache: CodeGenPreprocessor._FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
Expand Down
8 changes: 4 additions & 4 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,12 @@ class _DistributedInputReplacer(CopyMapper):
instances for their assigned names. Also gathers names for
user-supplied inputs needed by the part
"""

def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
_function_cache:
_DistributedInputReplacer._FunctionCacheT | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)

Expand Down Expand Up @@ -344,9 +344,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
pass

Expand Down
225 changes: 175 additions & 50 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@

__doc__ = """
.. autoclass:: Mapper
.. autoclass:: CachedMapperCache
.. autoclass:: CachedMapper
.. autoclass:: TransformMapper
.. autoclass:: TransformMapperWithExtraArgs
Expand Down Expand Up @@ -150,9 +151,41 @@
A type variable representing the result type of a :class:`Mapper` when mapping
a :class:`pytato.function.FunctionDefinition`.
.. class:: CacheExprT
A type variable representing an input from which to compute a cache key in order
to cache a result.
.. class:: CacheKeyT
A type variable representing a key computed from an input expression.
.. class:: CacheResultT
A type variable representing a result to be cached.
.. class:: Scalar
See :data:`pymbolic.Scalar`.
.. class:: P
A :class:`typing.ParamSpec` used to annotate ``*args`` and ``**kwargs``.
.. class:: _OtherResultT
Duplicate of :class:`pytato.transform.ResultT`, used for defining class-local
type aliases.
.. class:: _OtherFunctionResultT
Duplicate of :class:`pytato.transform.FunctionResultT`, used for defining
class-local type aliases.
.. class:: _OtherP
Duplicate of :class:`P`, used for defining class-local type aliases.
"""

transform_logger = logging.getLogger(__file__)
Expand All @@ -172,6 +205,12 @@ class ForeignObjectError(ValueError):
FunctionResultT = TypeVar("FunctionResultT")
P = ParamSpec("P")

# Duplicates of type variables, mainly used for defining aliases of parameterized
# types inside mapper classes
_OtherResultT = TypeVar("_OtherResultT")
_OtherFunctionResultT = TypeVar("_OtherFunctionResultT")
_OtherP = ParamSpec("_OtherP")


def _verify_is_array(expr: ArrayOrNames) -> Array:
assert isinstance(expr, Array)
Expand Down Expand Up @@ -252,57 +291,130 @@ def __call__(self,

# {{{ CachedMapper

CacheExprT = TypeVar("CacheExprT")
CacheKeyT = TypeVar("CacheKeyT")
CacheResultT = TypeVar("CacheResultT")


class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
    """
    Cache for :class:`CachedMapper`.

    Results are stored in a dictionary indexed by the key that the key
    function passed to the constructor computes from an input expression and
    any extra arguments.

    .. automethod:: __init__
    .. automethod:: get_key
    .. automethod:: add
    .. automethod:: retrieve
    """
    def __init__(
            self,
            # FIXME: Figure out the right way to type annotate this
            key_func: Callable[..., CacheKeyT]) -> None:
        """
        Initialize the cache.

        :arg key_func: Function to compute a hashable cache key from an input
            expression and any extra arguments.
        """
        self._key_func = key_func
        self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {}

    # FIXME: Can this be inlined?
    def get_key(
            self, expr: CacheExprT, *args: P.args, **kwargs: P.kwargs) -> CacheKeyT:
        """Compute the key for an input expression."""
        return self._key_func(expr, *args, **kwargs)

    def _resolve_key(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            key: CacheKeyT | None) -> CacheKeyT:
        """
        Return *key* if it was supplied, otherwise compute it from *key_inputs*.

        Shared by :meth:`add` and :meth:`retrieve` so that both derive keys in
        exactly the same way.

        :arg key_inputs: Either a bare expression, or a tuple
            ``(expr, args, kwargs)`` holding the expression together with the
            extra positional and keyword arguments for the key function.
        :arg key: A precomputed key, or *None* to have it computed here.
        """
        if key is not None:
            return key

        # NOTE: Any tuple is taken to be the (expr, args, kwargs) form, so a
        # bare expression that is itself a tuple cannot be passed unwrapped.
        if isinstance(key_inputs, tuple):
            expr, key_args, key_kwargs = key_inputs
            return self._key_func(expr, *key_args, **key_kwargs)

        return self._key_func(key_inputs)

    def add(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            result: CacheResultT,
            key: CacheKeyT | None = None) -> CacheResultT:
        """
        Cache a mapping result.

        :arg key_inputs: Either a bare expression, or a tuple
            ``(expr, args, kwargs)`` from which to compute the key.
        :arg result: The result to store. An existing entry under the same key
            is overwritten.
        :arg key: The precomputed key for *key_inputs*; computed with the key
            function if *None*.
        :returns: *result*, unchanged, so that a caller can store and return it
            in a single expression.
        """
        self._expr_key_to_result[self._resolve_key(key_inputs, key)] = result

        return result

    def retrieve(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            key: CacheKeyT | None = None) -> CacheResultT:
        """
        Retrieve the cached mapping result.

        :arg key_inputs: Either a bare expression, or a tuple
            ``(expr, args, kwargs)`` from which to compute the key.
        :arg key: The precomputed key for *key_inputs*; computed with the key
            function if *None*.
        :raises KeyError: If no result has been cached under the key.
        """
        return self._expr_key_to_result[self._resolve_key(key_inputs, key)]


class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
"""Mapper class that maps each node in the DAG exactly once. This loses some
information compared to :class:`Mapper` as a node is visited only from
one of its predecessors.
.. automethod:: get_cache_key
.. automethod:: get_function_definition_cache_key
.. automethod:: clone_for_callee
"""
_CacheT: TypeAlias = CachedMapperCache[
ArrayOrNames, Hashable, _OtherResultT, _OtherP]
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition, Hashable, _OtherFunctionResultT, _OtherP]

def __init__(
self,
# Arrays are cached separately for each call stack frame, but
# functions are cached globally
_function_cache: dict[Hashable, FunctionResultT] | None = None
_function_cache:
CachedMapper._FunctionCacheT[FunctionResultT, P] | None = None
) -> None:
super().__init__()
self._cache: dict[Hashable, ResultT] = {}

self._function_cache: dict[Hashable, FunctionResultT] = \
_function_cache if _function_cache is not None else {}
def key_func(
expr: ArrayOrNames | FunctionDefinition,
*args: P.args, **kwargs: P.kwargs) -> Hashable:
return (expr, args, tuple(sorted(kwargs.items())))

def get_cache_key(
self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))
self._cache: CachedMapper._CacheT[ResultT, P] = CachedMapperCache(key_func)

def get_function_definition_cache_key(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))
self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = (
_function_cache if _function_cache is not None
else CachedMapperCache(key_func))

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
key = self.get_cache_key(expr, *args, **kwargs)
key = self._cache.get_key(expr, *args, **kwargs)
try:
return self._cache[key]
return self._cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec(expr, *args, **kwargs)
self._cache[key] = result
return result
return self._cache.add(
(expr, args, kwargs),
super().rec(expr, *args, **kwargs),
key=key)

def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> FunctionResultT:
key = self.get_function_definition_cache_key(expr, *args, **kwargs)
key = self._function_cache.get_key(expr, *args, **kwargs)
try:
return self._function_cache[key]
return self._function_cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec_function_definition(expr, *args, **kwargs)
self._function_cache[key] = result
return result
return self._function_cache.add(
(expr, args, kwargs),
super().rec_function_definition(expr, *args, **kwargs),
key=key)

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
Expand All @@ -322,11 +434,11 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
other :class:`pytato.array.Array`\\ s.
Enables certain operations that can only be done if the mapping results are also
arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not
implement default mapper methods; for that, see :class:`CopyMapper`.
arrays (e.g., computing a cache key from them). Does not implement default
mapper methods; for that, see :class:`CopyMapper`.
"""
pass
_CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, []]
_FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[FunctionDefinition, []]

# }}}

Expand All @@ -343,7 +455,9 @@ class TransformMapperWithExtraArgs(
The logic in :class:`TransformMapper` purposely does not take the extra
arguments to keep the cost of each of its call frames low.
"""
pass
_CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, _OtherP]
_FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[
FunctionDefinition, _OtherP]

# }}}

Expand Down Expand Up @@ -768,40 +882,51 @@ def map_named_call_result(self, expr: NamedCallResult,

# {{{ CombineMapper

# FIXME: Can this just inherit from CachedMapper?
class CombineMapper(Mapper[ResultT, FunctionResultT, []]):
"""
Abstract mapper that recursively combines the results of user nodes
of a given expression.
.. automethod:: combine
"""
_CacheT: TypeAlias = CachedMapperCache[ArrayOrNames, Hashable, _OtherResultT, []]
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition, Hashable, _OtherFunctionResultT, []]

def __init__(
self,
_function_cache: dict[FunctionDefinition, FunctionResultT] | None = None
_function_cache:
CombineMapper._FunctionCacheT[FunctionResultT] | None = None
) -> None:
super().__init__()
self.cache: dict[ArrayOrNames, ResultT] = {}
self.function_cache: dict[FunctionDefinition, FunctionResultT] = \
_function_cache if _function_cache is not None else {}

self.cache: CombineMapper._CacheT[ResultT] = CachedMapperCache(
lambda expr: expr)

self.function_cache: CombineMapper._FunctionCacheT[FunctionResultT] = (
_function_cache if _function_cache is not None
else CachedMapperCache(lambda expr: expr))

def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...]
) -> tuple[ResultT, ...]:
return tuple(self.rec(s) for s in situp if isinstance(s, Array))

def rec(self, expr: ArrayOrNames) -> ResultT:
if expr in self.cache:
return self.cache[expr]
result: ResultT = super().rec(expr)
self.cache[expr] = result
return result
key = self.cache.get_key(expr)
try:
return self.cache.retrieve(expr, key=key)
except KeyError:
return self.cache.add(expr, super().rec(expr), key=key)

def rec_function_definition(
self, expr: FunctionDefinition) -> FunctionResultT:
if expr in self.function_cache:
return self.function_cache[expr]
result: FunctionResultT = super().rec_function_definition(expr)
self.function_cache[expr] = result
return result
key = self.function_cache.get_key(expr)
try:
return self.function_cache.retrieve(expr, key=key)
except KeyError:
return self.function_cache.add(
expr, super().rec_function_definition(expr), key=key)

def __call__(self, expr: ArrayOrNames) -> ResultT:
return self.rec(expr)
Expand Down Expand Up @@ -1390,7 +1515,7 @@ class CachedMapAndCopyMapper(CopyMapper):
def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_function_cache: CachedMapAndCopyMapper._FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
Expand All @@ -1400,12 +1525,12 @@ def clone_for_callee(
return type(self)(self.map_fn, _function_cache=self._function_cache)

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if expr in self._cache:
return self._cache[expr]

result = super().rec(self.map_fn(expr))
self._cache[expr] = result
return result
key = self._cache.get_key(expr)
try:
return self._cache.retrieve(expr, key=key)
except KeyError:
return self._cache.add(
expr, super().rec(self.map_fn(expr)), key=key)

# }}}

Expand Down
Loading

0 comments on commit ae86606

Please sign in to comment.