Skip to content

Commit

Permalink
allow passing FunctionDefinition directly to Mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 31, 2025
1 parent 26b2223 commit 5af89bc
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
35 changes: 31 additions & 4 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TypeAlias,
TypeVar,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -262,10 +263,36 @@ def rec_function_definition(
assert method is not None
return method(expr, *args, **kwargs)

def __call__(self,
expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
@overload
def __call__(
self,
expr: ArrayOrNames,
*args: P.args,
**kwargs: P.kwargs) -> ResultT:
...

@overload
def __call__(
self,
expr: FunctionDefinition,
*args: P.args,
**kwargs: P.kwargs) -> FunctionResultT:
...

def __call__(
self,
expr: ArrayOrNames | FunctionDefinition,
*args: P.args,
**kwargs: P.kwargs) -> ResultT | FunctionResultT:
"""Handle the mapping of *expr*."""
return self.rec(expr, *args, **kwargs)
if isinstance(expr, ArrayOrNames):
return self.rec(expr, *args, **kwargs)
elif isinstance(expr, FunctionDefinition):
return self.rec_function_definition(expr, *args, **kwargs)
else:
raise ForeignObjectError(
f"{type(self).__name__} encountered invalid foreign "
f"object: {expr!r}") from None

# }}}

Expand Down Expand Up @@ -1839,7 +1866,7 @@ def __init__(self) -> None:
self.node_to_users: dict[ArrayOrNames,
set[DistributedSend | ArrayOrNames]] = {}

def __call__(self, expr: ArrayOrNames) -> None:
def __call__(self, expr: ArrayOrNames) -> None: # type: ignore[override]
# Root node has no predecessor
self.node_to_users[expr] = set()
self.rec(expr)
Expand Down
3 changes: 2 additions & 1 deletion pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,8 @@ def handle_unsupported_array(self, expr: Array) -> Array:
def rec(self, expr: Array) -> Array: # type: ignore[override]
return expr

__call__ = Mapper.rec
def __call__(self, expr: Array) -> Array: # type: ignore[override]
return self.rec(expr)


def to_index_lambda(expr: Array) -> IndexLambda:
Expand Down

0 comments on commit 5af89bc

Please sign in to comment.