Skip to content

Commit be021fe

Browse files
committed
add _verify_is_dag to fix pyright errors about assigning ArrayOrNames to DictOfNamedArrays
1 parent 2bca755 commit be021fe

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,13 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
745745
def transform_dag(self, dag: pytato.DictOfNamedArrays
746746
) -> pytato.DictOfNamedArrays:
747747
import pytato as pt
748-
dag = pt.tag_all_calls_to_be_inlined(dag)
749-
dag = pt.inline_calls(dag)
750-
dag = pt.transform.materialize_with_mpms(dag)
748+
749+
# FIXME: Having to use _verify_is_dag seems clunky, but I'm not sure how to
750+
# avoid it
751+
from .utils import _verify_is_dag
752+
dag = _verify_is_dag(pt.tag_all_calls_to_be_inlined(dag))
753+
dag = _verify_is_dag(pt.inline_calls(dag))
754+
dag = _verify_is_dag(pt.transform.materialize_with_mpms(dag))
751755
return dag
752756

753757
def einsum(self, spec, *args, arg_names=None, tagged=()):

arraycontext/impl/pytato/outline.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ class OutlinedCall:
193193
def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
194194
arg_id_to_arg = _get_arg_id_to_arg(args, kwargs)
195195

196+
from .utils import _verify_is_dag
197+
196198
if __debug__:
197199
# Add a prefix to the names to distinguish them from any existing
198200
# placeholders
@@ -201,9 +203,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
201203

202204
prefixed_output = _call_with_placeholders(
203205
self.f, args, kwargs, arg_id_to_prefixed_placeholder)
204-
unpacked_prefixed_output = pt.transform.Deduplicator()(
205-
pt.make_dict_of_named_arrays(
206-
_unpack_output(prefixed_output)))
206+
unpacked_prefixed_output = _verify_is_dag(
207+
pt.transform.Deduplicator()(
208+
pt.make_dict_of_named_arrays(
209+
_unpack_output(prefixed_output))))
207210

208211
prefixed_placeholders = frozenset(
209212
arg_id_to_prefixed_placeholder.values())
@@ -220,9 +223,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
220223
arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg)
221224

222225
output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder)
223-
unpacked_output = pt.transform.Deduplicator()(
224-
pt.make_dict_of_named_arrays(
225-
_unpack_output(output)))
226+
unpacked_output = _verify_is_dag(
227+
pt.transform.Deduplicator()(
228+
pt.make_dict_of_named_arrays(
229+
_unpack_output(output))))
226230
if len(unpacked_output) == 1 and "_" in unpacked_output:
227231
ret_type = pt.function.ReturnType.ARRAY
228232
else:

arraycontext/impl/pytato/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@
7272
import loopy as lp
7373

7474

75+
def _verify_is_dag(dag: ArrayOrNames) -> DictOfNamedArrays:
76+
assert isinstance(dag, DictOfNamedArrays)
77+
return dag
78+
79+
7580
class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
7681
"""
7782
Helper mapper for :func:`normalize_pt_expr`. Every
@@ -141,7 +146,7 @@ def _normalize_pt_expr(
141146
Deterministic naming of placeholders permits more effective caching of
142147
equivalent graphs.
143148
"""
144-
expr = Deduplicator()(expr)
149+
expr = _verify_is_dag(Deduplicator()(expr))
145150

146151
if get_num_call_sites(expr):
147152
raise NotImplementedError(

0 commit comments

Comments
 (0)