Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a class for CachedMapper-derived mappers instead of a dict #549

Merged
merged 5 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@
CachedWalkMapper,
CopyMapper,
SubsetDependencyMapper,
TransformMapperCache,
)
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from collections.abc import Mapping

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target
Expand Down Expand Up @@ -137,9 +138,10 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
Expand Down
16 changes: 11 additions & 5 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
CachedWalkMapper,
CombineMapper,
CopyMapper,
TransformMapperCache,
_verify_is_array,
)

Expand Down Expand Up @@ -239,9 +240,11 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition] | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)
super().__init__(_cache=_cache, _function_cache=_function_cache)

self.recvd_ary_to_name = recvd_ary_to_name
self.sptpo_ary_to_name = sptpo_ary_to_name
Expand All @@ -255,7 +258,10 @@ def clone_for_callee(
self, function: FunctionDefinition) -> _DistributedInputReplacer:
# Function definitions aren't allowed to contain receives,
# stored arrays promoted to part outputs, or part outputs
return type(self)({}, {}, {}, _function_cache=self._function_cache)
return type(self)(
{}, {}, {},
_function_cache=cast(
"TransformMapperCache[FunctionDefinition]", self._function_cache))

def map_placeholder(self, expr: Placeholder) -> Placeholder:
self.user_input_names.add(expr.name)
Expand Down Expand Up @@ -288,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
pass

Expand Down
Loading
Loading