Skip to content

Add a pass to outline computations in a function #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b7e2d63
move _ary_container_key_stringifier to utils.py
kaushikcfd Mar 14, 2023
0f24581
Add outlining pass to array expression
kaushikcfd Mar 14, 2023
3fc1859
adds an outlining example
kaushikcfd Mar 14, 2023
1fc80f4
cosmetic-ish refactor of OutlinedCall.__call__
majosm Jun 14, 2024
c56d019
check for non-argument placeholders in outlined function
majosm Jun 14, 2024
59c93c6
drop unused function arguments
majosm Mar 8, 2024
e7d4a36
pass hashable instead of string as id to outline
majosm Jan 17, 2025
7e66296
handle optional arguments that are passed as None explicitly
majosm Mar 14, 2024
04e7175
don't tag NamedArray (they inherit tags from their corresponding _con…
majosm Jun 13, 2024
b1b096b
change Map -> immutabledict in outlining
majosm Jun 6, 2024
e3f3749
forbid calling _normalize_pt_expr on a DAG with function calls for now
majosm Sep 10, 2024
ea7d1a4
inline calls in freeze before _normalize_pt_expr
majosm Oct 24, 2024
31dbefb
disable concatenation in outlining example for now
majosm Jan 17, 2025
bc7902e
remove duplicates in _normalize_pt_expr
majosm Jun 13, 2024
8e8164e
remove duplicates in _to_input_for_compiled
majosm Jun 13, 2024
d1ac9c1
remove duplicates when creating FunctionDefinition
majosm Jun 26, 2024
a6bd5ff
deduplicate in freeze
majosm Nov 14, 2024
df254fb
update _DatawrapperToBoundPlaceholderMapper with CopyMapper changes
majosm Sep 5, 2024
a6d7eda
add function inlining step to transform_dag
majosm May 16, 2025
3e6886d
add test for outline
majosm May 16, 2025
f121329
add transform_dag implementation to pytato JAX array context in order…
majosm Jun 2, 2025
6bf6ad9
fix some pyright errors
majosm Jun 3, 2025
636639d
Towards typing of outlining
inducer Jun 4, 2025
74675c7
more pyright fixes
majosm Jun 4, 2025
befb2d3
temporarily change pytato branch
majosm Jun 4, 2025
62e3586
Better transform typing
inducer Jun 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 0 additions & 104 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -11653,14 +11653,6 @@
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
"startColumn": 53,
"endColumn": 83,
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
Expand Down Expand Up @@ -11765,30 +11757,6 @@
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 21,
"endColumn": 60,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 49,
"endColumn": 67,
"lineCount": 2
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 53,
"endColumn": 68,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down Expand Up @@ -11829,14 +11797,6 @@
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 16,
"endColumn": 58,
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
Expand Down Expand Up @@ -11885,22 +11845,6 @@
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 12,
"endColumn": 51,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 45,
"endColumn": 63,
"lineCount": 2
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down Expand Up @@ -13045,14 +12989,6 @@
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
"startColumn": 53,
"endColumn": 83,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
Expand Down Expand Up @@ -17283,14 +17219,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
Expand All @@ -17299,22 +17227,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 22,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand All @@ -17331,14 +17243,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand All @@ -17347,14 +17251,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
Expand Down
22 changes: 21 additions & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
"""

from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Callable, Hashable, Mapping
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
from warnings import warn

Expand Down Expand Up @@ -334,6 +334,7 @@ class ArrayContext(ABC):
.. automethod:: tag
.. automethod:: tag_axis
.. automethod:: compile
.. automethod:: outline
"""

array_types: tuple[type, ...] = ()
Expand Down Expand Up @@ -583,6 +584,25 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
"""
return f

def outline(self,
f: Callable[..., Any],
*,
id: Hashable | None = None) -> Callable[..., Any]: # pyright: ignore[reportUnusedParameter]
"""
Returns a drop-in-replacement for *f*. The behavior of the returned
callable is specific to the derived class.

The reason for the existence of such a routine is mainly for
arraycontexts that allow a lazy mode of execution. In such
arraycontexts, the computations within *f* maybe staged to potentially
enable additional compiler transformations. See
:func:`pytato.trace_call` or :func:`jax.named_call` for examples.

:arg f: the function executing the computation to be staged.
:return: a function with the same signature as *f*.
"""
return f

# undocumented for now
@property
@abstractmethod
Expand Down
66 changes: 53 additions & 13 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@

import abc
import sys
from collections.abc import Callable
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np
from typing_extensions import override

from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
Expand Down Expand Up @@ -154,7 +155,7 @@
self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
self._dag_transform_cache: dict[
pt.DictOfNamedArrays,
tuple[pt.DictOfNamedArrays, str]] = {}
tuple[pt.AbstractResultWithNamedArrays, str]] = {}

if compile_trace_callback is None:
def _compile_trace_callback(what, stage, ir):
Expand All @@ -176,8 +177,8 @@

# {{{ compilation

def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
Expand Down Expand Up @@ -232,6 +233,22 @@

# }}}

@override
def outline(self,
f: Callable[..., Any],
*,
id: Hashable | None = None,
tags: frozenset[Tag] = frozenset() # pyright: ignore[reportCallInDefaultInitializer]
) -> Callable[..., Any]:
from pytato.tags import FunctionIdentifier

from .outline import OutlinedCall
id = id or getattr(f, "__name__", None)
if id is not None:
tags = tags | {FunctionIdentifier(id)}

return OutlinedCall(self, f, tags)

# }}}


Expand Down Expand Up @@ -514,8 +531,8 @@
TaggableCLArray,
to_tagged_cl_array,
)
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import (
_ary_container_key_stringifier,
_normalize_pt_expr,
get_cl_axes_from_pt_axes,
)
Expand Down Expand Up @@ -576,10 +593,19 @@
rec_keyed_map_array_container(_to_frozen, array),
actx=None)

pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
dag = pt.make_dict_of_named_arrays(
key_to_pt_arrays)

from pytato.transform import Deduplicator
dag = Deduplicator()(dag)

# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
dag = pt.tag_all_calls_to_be_inlined(
dag)
dag = pt.inline_calls(dag)

normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays)
dag)

Check failure on line 608 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument of type "Array | AbstractResultWithNamedArrays" cannot be assigned to parameter "expr" of type "AbstractResultWithNamedArrays" in function "_normalize_pt_expr"   Type "Array | AbstractResultWithNamedArrays" is not assignable to type "AbstractResultWithNamedArrays"     "Array" is not assignable to "AbstractResultWithNamedArrays" (reportArgumentType)

try:
pt_prg = self._freeze_prg_cache[normalized_expr]
Expand Down Expand Up @@ -608,7 +634,7 @@
opts = _DEFAULT_LOOPY_OPTIONS
assert opts.return_dict

pt_prg = pt.generate_loopy(transformed_dag,

Check failure on line 637 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument of type "AbstractResultWithNamedArrays" cannot be assigned to parameter "result" of type "Array | DictOfNamedArrays | dict[str, Array]" in function "generate_loopy"   Type "AbstractResultWithNamedArrays" is not assignable to type "Array | DictOfNamedArrays | dict[str, Array]"     "AbstractResultWithNamedArrays" is not assignable to "Array"     "AbstractResultWithNamedArrays" is not assignable to "DictOfNamedArrays"     "AbstractResultWithNamedArrays" is not assignable to "dict[str, Array]" (reportArgumentType)
options=opts,
function_name=function_name,
target=self.get_target()
Expand Down Expand Up @@ -731,11 +757,13 @@
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)

def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
import pytato as pt
dag = pt.transform.materialize_with_mpms(dag)
return dag
tdag = pt.tag_all_calls_to_be_inlined(dag)
tdag = pt.inline_calls(tdag)
tdag = pt.transform.materialize_with_mpms(tdag)

Check failure on line 765 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument of type "AbstractResultWithNamedArrays" cannot be assigned to parameter "expr" of type "DictOfNamedArrays" in function "materialize_with_mpms"   "AbstractResultWithNamedArrays" is not assignable to "DictOfNamedArrays" (reportArgumentType)

Check failure on line 765 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

No overloads for "materialize_with_mpms" match the provided arguments (reportCallIssue)

Check warning on line 765 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "tdag" is unknown (reportUnknownVariableType)
return tdag

Check warning on line 766 in arraycontext/impl/pytato/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Return type is unknown (reportUnknownVariableType)

def einsum(self, spec, *args, arg_names=None, tagged=()):
import pytato as pt
Expand Down Expand Up @@ -769,7 +797,7 @@
# multiple placeholders with the same name that are not
# also the same object are not allowed, and this would produce
# a different Placeholder object of the same name.
if (not isinstance(ary, pt.Placeholder)
if (not isinstance(ary, pt.Placeholder | pt.NamedArray)
and not ary.tags_of_type(NameHint)):
ary = ary.tagged(NameHint(name))

Expand All @@ -795,6 +823,9 @@
An arraycontext that uses :mod:`pytato` to represent the thawed state of
the arrays and compiles the expressions using
:class:`pytato.target.python.JAXPythonTarget`.


.. automethod:: transform_dag
"""

def __init__(self,
Expand Down Expand Up @@ -874,7 +905,7 @@
import pytato as pt

from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier

array_as_dict: dict[str, jnp.ndarray | pt.Array] = {}
key_to_frozen_subary: dict[str, jnp.ndarray] = {}
Expand Down Expand Up @@ -946,6 +977,15 @@
from .compile import LazilyJAXCompilingFunctionCaller
return LazilyJAXCompilingFunctionCaller(self, f)

@override
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
import pytato as pt

dag = pt.tag_all_calls_to_be_inlined(dag)
dag = pt.inline_calls(dag)
return dag

def tag(self, tags: ToTagSetConvertible, array):
def _tag(ary):
import jax.numpy as jnp
Expand Down
Loading
Loading