diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 192e5270..d0194b57 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -11653,14 +11653,6 @@ "lineCount": 1 } }, - { - "code": "reportPrivateUsage", - "range": { - "startColumn": 53, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportPrivateUsage", "range": { @@ -11765,30 +11757,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 21, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 49, - "endColumn": 67, - "lineCount": 2 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 53, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -11829,14 +11797,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 16, - "endColumn": 58, - "lineCount": 1 - } - }, { "code": "reportPrivateUsage", "range": { @@ -11885,22 +11845,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 12, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 45, - "endColumn": 63, - "lineCount": 2 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -13045,14 +12989,6 @@ "lineCount": 1 } }, - { - "code": "reportPrivateUsage", - "range": { - "startColumn": 53, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -17283,14 +17219,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportPrivateUsage", "range": { @@ -17299,22 +17227,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -17331,14 +17243,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -17347,14 +17251,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportPrivateUsage", "range": { diff --git a/arraycontext/context.py b/arraycontext/context.py index f751413c..9fbe7ab1 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -167,7 +167,7 @@ """ from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable, Hashable, Mapping from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload from warnings import warn @@ -334,6 +334,7 @@ class ArrayContext(ABC): .. automethod:: tag .. automethod:: tag_axis .. automethod:: compile + .. automethod:: outline """ array_types: tuple[type, ...] = () @@ -583,6 +584,25 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + def outline(self, + f: Callable[..., Any], + *, + id: Hashable | None = None) -> Callable[..., Any]: # pyright: ignore[reportUnusedParameter] + """ + Returns a drop-in-replacement for *f*. The behavior of the returned + callable is specific to the derived class. + + The reason for the existence of such a routine is mainly for + arraycontexts that allow a lazy mode of execution. In such + arraycontexts, the computations within *f* maybe staged to potentially + enable additional compiler transformations. See + :func:`pytato.trace_call` or :func:`jax.named_call` for examples. + + :arg f: the function executing the computation to be staged. + :return: a function with the same signature as *f*. + """ + return f + # undocumented for now @property @abstractmethod diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index d55e86da..62471231 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -53,11 +53,12 @@ import abc import sys -from collections.abc import Callable +from collections.abc import Callable, Hashable from dataclasses import dataclass from typing import TYPE_CHECKING, Any import numpy as np +from typing_extensions import override from pytools import memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags @@ -154,7 +155,7 @@ def __init__( self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: dict[ pt.DictOfNamedArrays, - tuple[pt.DictOfNamedArrays, str]] = {} + tuple[pt.AbstractResultWithNamedArrays, str]] = {} if compile_trace_callback is None: def _compile_trace_callback(what, stage, ir): @@ -176,8 +177,8 @@ def _frozen_array_types(self) -> tuple[type, ...]: # {{{ compilation - def transform_dag(self, dag: pytato.DictOfNamedArrays - ) -> pytato.DictOfNamedArrays: + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: """ Returns a transformed version of *dag*. Sub-classes are supposed to override this method to implement context-specific transformations on @@ -232,6 +233,22 @@ def get_target(self): # }}} + @override + def outline(self, + f: Callable[..., Any], + *, + id: Hashable | None = None, + tags: frozenset[Tag] = frozenset() # pyright: ignore[reportCallInDefaultInitializer] + ) -> Callable[..., Any]: + from pytato.tags import FunctionIdentifier + + from .outline import OutlinedCall + id = id or getattr(f, "__name__", None) + if id is not None: + tags = tags | {FunctionIdentifier(id)} + + return OutlinedCall(self, f, tags) + # }}} @@ -514,8 +531,8 @@ def freeze(self, array): TaggableCLArray, to_tagged_cl_array, ) - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier from arraycontext.impl.pytato.utils import ( + _ary_container_key_stringifier, _normalize_pt_expr, get_cl_axes_from_pt_axes, ) @@ -576,10 +593,19 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: rec_keyed_map_array_container(_to_frozen, array), actx=None) - pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( + dag = pt.make_dict_of_named_arrays( key_to_pt_arrays) + + from pytato.transform import Deduplicator + dag = Deduplicator()(dag) + + # FIXME: Remove this if/when _normalize_pt_expr gets support for functions + dag = pt.tag_all_calls_to_be_inlined( + dag) + dag = pt.inline_calls(dag) + normalized_expr, bound_arguments = _normalize_pt_expr( - pt_dict_of_named_arrays) + dag) try: pt_prg = self._freeze_prg_cache[normalized_expr] @@ -731,11 +757,13 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller return LazilyPyOpenCLCompilingFunctionCaller(self, f) - def transform_dag(self, dag: pytato.DictOfNamedArrays - ) -> pytato.DictOfNamedArrays: + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: import pytato as pt - dag = pt.transform.materialize_with_mpms(dag) - return dag + tdag = pt.tag_all_calls_to_be_inlined(dag) + tdag = pt.inline_calls(tdag) + tdag = pt.transform.materialize_with_mpms(tdag) + return tdag def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt @@ -769,7 +797,7 @@ def preprocess_arg(name, arg): # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if (not isinstance(ary, pt.Placeholder) + if (not isinstance(ary, pt.Placeholder | pt.NamedArray) and not ary.tags_of_type(NameHint)): ary = ary.tagged(NameHint(name)) @@ -795,6 +823,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): An arraycontext that uses :mod:`pytato` to represent the thawed state of the arrays and compiles the expressions using :class:`pytato.target.python.JAXPythonTarget`. + + + .. automethod:: transform_dag """ def __init__(self, @@ -874,7 +905,7 @@ def freeze(self, array): import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} key_to_frozen_subary: dict[str, jnp.ndarray] = {} @@ -946,6 +977,15 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller return LazilyJAXCompilingFunctionCaller(self, f) + @override + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: + import pytato as pt + + dag = pt.tag_all_calls_to_be_inlined(dag) + dag = pt.inline_calls(dag) + return dag + def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): import jax.numpy as jnp diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e78c4e62..7285c22e 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -7,6 +7,9 @@ """ from __future__ import annotations +from pytato.array import AxesT +from pytato.transform import Deduplicator + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -113,28 +116,6 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def _ary_container_key_stringifier(keys: tuple[object, ...]) -> str: - """ - Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an - array-container's component's key. Goals of this routine: - - * No two different keys should have the same stringification - * Stringified key must a valid identifier according to :meth:`str.isidentifier` - * (informal) Shorter identifiers are preferred - """ - def _rec_str(key: object) -> str: - if isinstance(key, str | int): - return str(key) - elif isinstance(key, tuple): - # t in '_actx_t': stands for tuple - return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] - else: - raise NotImplementedError("Key-stringication unimplemented for " - f"'{type(key).__name__}'.") - - return "_".join(_rec_str(key) for key in keys) - - def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ @@ -211,7 +192,7 @@ def _to_input_for_compiled( """ from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array if isinstance(ary, pt.Array): - dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) + dag = Deduplicator()(pt.make_dict_of_named_arrays({"_actx_out": ary})) # Transform the DAG to give metadata inference a chance to do its job return actx.transform_dag(dag)["_actx_out"].expr elif isinstance(ary, TaggableCLArray): @@ -341,6 +322,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) @@ -427,12 +409,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): self.actx._compile_trace_callback( prg_id, "post_transform_dag", pt_dict_of_named_arrays) - name_in_program_to_tags = { - name: out.tags - for name, out in pt_dict_of_named_arrays._data.items()} - name_in_program_to_axes = { - name: out.axes - for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_tags: dict[str, frozenset[Tag]] = {} + name_in_program_to_axes: dict[str, AxesT] = {} + if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays): + name_in_program_to_tags.update({ + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()}) + + name_in_program_to_axes.update({ + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()}) self.actx._compile_trace_callback( prg_id, "pre_generate_loopy", pt_dict_of_named_arrays) @@ -524,12 +510,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): self.actx._compile_trace_callback( prg_id, "post_transform_dag", pt_dict_of_named_arrays) - name_in_program_to_tags = { - name: out.tags - for name, out in pt_dict_of_named_arrays._data.items()} - name_in_program_to_axes = { - name: out.axes - for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_tags: dict[str, frozenset[Tag]] = {} + name_in_program_to_axes: dict[str, AxesT] = {} + if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays): + name_in_program_to_tags.update({ + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()}) + + name_in_program_to_axes.update({ + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()}) self.actx._compile_trace_callback( prg_id, "pre_generate_jax", pt_dict_of_named_arrays) diff --git a/arraycontext/impl/pytato/outline.py b/arraycontext/impl/pytato/outline.py new file mode 100644 index 00000000..a8135ae3 --- /dev/null +++ b/arraycontext/impl/pytato/outline.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from pytato.transform import Deduplicator + + +__doc__ = """ +.. autoclass:: OutlinedCall +""" +__copyright__ = """ +Copyright (C) 2023-5 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import itertools +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import cast + +import numpy as np +from immutabledict import immutabledict + +import pytato as pt +from pymbolic import Scalar +from pytools.tag import Tag + +from arraycontext.container import is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ( + Array, + ArrayOrContainer, + ArrayT, +) +from arraycontext.impl.pytato import _BasePytatoArrayContext + + +def _get_arg_id_to_arg(args: tuple[object, ...], + kwargs: Mapping[str, object] + ) -> immutabledict[tuple[object, ...], pt.Array]: + """ + Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id + to argument values. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + arg_id_to_arg: dict[tuple[object, ...], object] = {} + + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): + if arg is None: + pass + elif np.isscalar(arg): + # do not make scalars as placeholders since we inline them. + pass + elif is_array_container_type(arg.__class__): + def id_collector(keys: tuple[object, ...], ary: ArrayT) -> ArrayT: + if np.isscalar(ary): + pass + else: + arg_id = (kw, *keys) # noqa: B023 + arg_id_to_arg[arg_id] = ary + return ary + + rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar, pt.Array or an array container. Got" + f" '{arg}'.") + + return immutabledict(arg_id_to_arg) + + +def _get_input_arg_id_str( + arg_id: tuple[object, ...], prefix: str | None = None) -> str: + if prefix is None: + prefix = "" + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}" + + +def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str: + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_out_{_ary_container_key_stringifier(arg_id)}" + + +def _get_arg_id_to_placeholder( + arg_id_to_arg: Mapping[tuple[object, ...], pt.Array], + prefix: str | None = None) -> immutabledict[tuple[object, ...], pt.Placeholder]: + """ + Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder` + for each argument in *arg_id_to_arg*. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + return immutabledict({ + arg_id: pt.make_placeholder( + _get_input_arg_id_str(arg_id, prefix=prefix), + arg.shape, + arg.dtype) + for arg_id, arg in arg_id_to_arg.items()}) + + +def _call_with_placeholders( + f: Callable[..., object], + args: tuple[object, ...], + kwargs: Mapping[str, object], + arg_id_to_placeholder: Mapping[tuple[object, ...], pt.Placeholder]) -> object: + """ + Construct placeholders analogous to *args* and *kwargs* and call *f*. + """ + def get_placeholder_replacement( + arg: ArrayOrContainer | Scalar | None, key: tuple[object, ...] + ) -> ArrayOrContainer | Scalar | None: + if arg is None: + return None + elif np.isscalar(arg): + return cast(Scalar, arg) + elif isinstance(arg, pt.Array): + return arg_id_to_placeholder[key] + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder( + keys: tuple[object, ...], ary: Array) -> Array: + return cast("Array", get_placeholder_replacement(ary, key + keys)) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + pl_args = [get_placeholder_replacement(arg, (iarg,)) + for iarg, arg in enumerate(args)] + pl_kwargs = {kw: get_placeholder_replacement(arg, (kw,)) + for kw, arg in kwargs.items()} + + return f(*pl_args, **pl_kwargs) + + +def _unpack_output( + output: ArrayOrContainer) -> immutabledict[str, pt.Array]: + """Unpack any array containers in *output*.""" + if isinstance(output, pt.Array): + return immutabledict({"_": output}) + elif is_array_container_type(output.__class__): + unpacked_output = {} + + def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT: + key_str = _get_output_arg_id_str(key) + unpacked_output[key_str] = ary + return ary + + rec_keyed_map_array_container(_unpack_container, output) + + return immutabledict(unpacked_output) + else: + raise NotImplementedError(type(output)) + + +def _pack_output( + output_template: ArrayOrContainer, + unpacked_output: pt.Array | immutabledict[str, pt.Array] + ) -> ArrayOrContainer: + """ + Pack *unpacked_output* into array containers according to *output_template*. + """ + if isinstance(output_template, pt.Array): + assert isinstance(unpacked_output, pt.Array) + return unpacked_output + elif is_array_container_type(output_template.__class__): + assert isinstance(unpacked_output, immutabledict) + + def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array: # pyright: ignore[reportUnusedParameter] + key_str = _get_output_arg_id_str(key) + return unpacked_output[key_str] + + return rec_keyed_map_array_container(_pack_into_container, output_template) + else: + raise NotImplementedError(type(output_template)) + + +@dataclass(frozen=True) +class OutlinedCall: + actx: _BasePytatoArrayContext + f: Callable[..., object] + tags: frozenset[Tag] + + def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer: + arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) + + if __debug__: + # Add a prefix to the names to distinguish them from any existing + # placeholders + arg_id_to_prefixed_placeholder = _get_arg_id_to_placeholder( + arg_id_to_arg, prefix="outlined_call") + + prefixed_output = _call_with_placeholders( + self.f, args, kwargs, arg_id_to_prefixed_placeholder) + unpacked_prefixed_output = Deduplicator()( + pt.make_dict_of_named_arrays(_unpack_output(prefixed_output))) + + prefixed_placeholders = frozenset( + arg_id_to_prefixed_placeholder.values()) + + found_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()(unpacked_prefixed_output) + if isinstance(arg, pt.Placeholder)}) + + extra_placeholders = found_placeholders - prefixed_placeholders + assert not extra_placeholders, \ + "Found non-argument placeholder " \ + f"'{next(iter(extra_placeholders)).name}' in outlined function." + + arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg) + + output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder) + unpacked_output = Deduplicator()( + pt.make_dict_of_named_arrays(_unpack_output(output))) + if len(unpacked_output) == 1 and "_" in unpacked_output: + ret_type = pt.function.ReturnType.ARRAY + else: + ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + + used_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()( + unpacked_output) + if isinstance(arg, pt.Placeholder)}) + + call_bindings = { + placeholder.name: arg_id_to_arg[arg_id] + for arg_id, placeholder in arg_id_to_placeholder.items() + if placeholder in used_placeholders} + + func_def = pt.function.FunctionDefinition( + parameters=frozenset(call_bindings.keys()), + return_type=ret_type, + returns=immutabledict(unpacked_output._data), + tags=self.tags, + ) + + call_site_output = func_def(**call_bindings) + + assert isinstance(call_site_output, pt.Array | immutabledict) + return _pack_output(output, call_site_output) + +# vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 8c6e7f5c..78167efd 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -41,9 +41,12 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, cast +from typing_extensions import override + import pytools +from pytato import AbstractResultWithNamedArrays +from pytato.analysis import get_num_call_sites from pytato.array import ( - AbstractResultWithNamedArrays, Array, Axis as PtAxis, DataWrapper, @@ -52,8 +55,15 @@ SizeParam, make_placeholder, ) +from pytato.function import FunctionDefinition from pytato.target.loopy import LoopyPyOpenCLTarget -from pytato.transform import ArrayOrNames, CopyMapper +from pytato.transform import ( + ArrayOrNames, + CopyMapper, + Deduplicator, + MappedT, + TransformMapperCache, +) from pytools import UniqueNameGenerator, memoize_method from arraycontext import ArrayContext @@ -71,12 +81,24 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): :class:`pytato.DataWrapper` is replaced with a deterministic copy of :class:`Placeholder`. """ - def __init__(self) -> None: - super().__init__() + def __init__( + self, + err_on_collision: bool = True, + err_on_created_duplicate: bool = True, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate, + _cache=_cache, + _function_cache=_function_cache) + self.bound_arguments: dict[str, Any] = {} self.vng = UniqueNameGenerator() self.seen_inputs: set[str] = set() + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: if expr.name is not None: if expr.name in self.seen_inputs: @@ -96,17 +118,27 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: axes=expr.axes, tags=expr.tags) + @override def map_size_param(self, expr: SizeParam) -> Array: raise NotImplementedError + @override def map_placeholder(self, expr: Placeholder) -> Array: raise ValueError("Placeholders cannot appear in" " DatawrapperToBoundPlaceholderMapper.") + @override + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + raise ValueError("Function definitions cannot appear in" + " DatawrapperToBoundPlaceholderMapper.") + +# FIXME: This strategy doesn't work if the DAG has functions, since function +# definitions can't contain non-argument placeholders def _normalize_pt_expr( - expr: DictOfNamedArrays - ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: + expr: AbstractResultWithNamedArrays + ) -> tuple[DictOfNamedArrays, Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -116,9 +148,16 @@ def _normalize_pt_expr( Deterministic naming of placeholders permits more effective caching of equivalent graphs. """ + expr = Deduplicator()(expr) + + if get_num_call_sites(expr): + raise NotImplementedError( + "_normalize_pt_expr is not compatible with expressions that " + "contain function calls.") + normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) - assert isinstance(normalized_expr, AbstractResultWithNamedArrays) + assert isinstance(normalized_expr, DictOfNamedArrays) return normalized_expr, normalize_mapper.bound_arguments @@ -156,6 +195,7 @@ def __init__(self, actx: ArrayContext) -> None: super().__init__() self.actx = actx + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: import numpy as np @@ -188,6 +228,7 @@ def __init__(self, actx: ArrayContext) -> None: super().__init__() self.actx = actx + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: import numpy as np @@ -207,7 +248,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: non_equality_tags=expr.non_equality_tags) -def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: +def transfer_from_numpy(expr: MappedT, actx: ArrayContext) -> MappedT: """Transfer arrays contained in :class:`~pytato.array.DataWrapper` instances to be device arrays, using :meth:`~arraycontext.ArrayContext.from_numpy`. @@ -215,7 +256,7 @@ def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: return TransferFromNumpyMapper(actx)(expr) -def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: +def transfer_to_numpy(expr: MappedT, actx: ArrayContext) -> MappedT: """Transfer arrays contained in :class:`~pytato.array.DataWrapper` instances to be :class:`numpy.ndarray` instances, using :meth:`~arraycontext.ArrayContext.to_numpy`. @@ -262,4 +303,30 @@ def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table: # }}} + +# {{{ compile/outline helpers + +def _ary_container_key_stringifier(keys: tuple[object, ...]) -> str: + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an + array-container's component's key. Goals of this routine: + + * No two different keys should have the same stringification + * Stringified key must a valid identifier according to :meth:`str.isidentifier` + * (informal) Shorter identifiers are preferred + """ + def _rec_str(key: object) -> str: + if isinstance(key, str | int): + return str(key) + elif isinstance(key, tuple): + # t in '_actx_t': stands for tuple + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + else: + raise NotImplementedError("Key-stringication unimplemented for " + f"'{type(key).__name__}'.") + + return "_".join(_rec_str(key) for key in keys) + +# }}} + # vim: foldmethod=marker diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py new file mode 100644 index 00000000..bb04968d --- /dev/null +++ b/examples/how_to_outline.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import dataclasses as dc + +import numpy as np + +import pytato as pt +from pytools.obj_array import make_obj_array + +from arraycontext import ( + Array, + PytatoJAXArrayContext as BasePytatoJAXArrayContext, + dataclass_array_container, + with_container_arithmetic, +) + + +Ncalls = 300 + + +class PytatoJAXArrayContext(BasePytatoJAXArrayContext): + def transform_dag(self, dag): + # Test 1: Test that the number of untransformed call sites are as + # expected + assert pt.analysis.get_num_call_sites(dag) == Ncalls + + dag = pt.tag_all_calls_to_be_inlined(dag) + # FIXME: Re-enable this when concatenation is added to pytato + # print("[Pre-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(pt.inline_calls(dag))) + # dag = pt.concatenate_calls( + # dag, + # lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags + # ) + # + # # Test 2: Test that only one call-sites is left post concatenation + # assert pt.analysis.get_num_call_sites(dag) == 1 + # + # dag = pt.inline_calls(dag) + # print("[Post-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(dag)) + dag = pt.inline_calls(dag) + + return dag + + +actx = PytatoJAXArrayContext() + + +@with_container_arithmetic( + bcast_obj_array=True, + eq_comparison=False, + rel_comparison=False, +) +@dataclass_array_container +@dc.dataclass(frozen=True) +class State: + mass: Array | np.ndarray + vel: np.ndarray # np array of Arrays or numpy arrays + + +@actx.outline +def foo(x1, x2): + return (2*x1 + 3*x2 + x1**3 + x2**4 + + actx.np.minimum(2*x1, 4*x2) + + actx.np.maximum(7*x1, 8*x2)) + + +rng = np.random.default_rng(0) +Ndof = 10 +Ndim = 3 + +results = [] + +for _ in range(Ncalls): + Nel = rng.integers(low=4, high=17) + state1_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + state2_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + + state1 = actx.from_numpy(state1_np) + state2 = actx.from_numpy(state2_np) + results.append(foo(state1, state2)) + +actx.to_numpy(make_obj_array(results)) diff --git a/requirements.txt b/requirements.txt index a4cb4025..d25b9dcc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/majosm/pytato.git@fix-transform-result-types#egg=pytato diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 31fa9e79..7e4949f3 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -24,8 +24,10 @@ """ import logging +from collections.abc import Callable from dataclasses import dataclass from functools import partial +from typing import TypeAlias import numpy as np import pytest @@ -34,6 +36,7 @@ from pytools.tag import Tag from arraycontext import ( + ArrayContext, BcastUntilActxArray, EagerJAXArrayContext, NumpyArrayContext, @@ -58,6 +61,9 @@ logger = logging.getLogger(__name__) +ArrayContextFactory: TypeAlias = Callable[[], ArrayContext] + + # {{{ array context fixture class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext): @@ -1165,6 +1171,40 @@ def my_rhs(scale, vel): np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) + +def test_actx_compile_with_outlined_function(actx_factory: ArrayContextFactory): + actx = actx_factory() + rng = np.random.default_rng() + + @actx.outline + def outlined_scale_and_orthogonalize(alpha: float, vel: Velocity2D) -> Velocity2D: + return scale_and_orthogonalize(alpha, vel) + + def multi_scale_and_orthogonalize( + alpha: float, vel1: Velocity2D, vel2: Velocity2D) -> np.ndarray: + return make_obj_array([ + outlined_scale_and_orthogonalize(alpha, vel1), + outlined_scale_and_orthogonalize(alpha, vel2)]) + + compiled_rhs = actx.compile(multi_scale_and_orthogonalize) + + v1_x = rng.uniform(size=10) + v1_y = rng.uniform(size=10) + v2_x = rng.uniform(size=10) + v2_y = rng.uniform(size=10) + + vel1 = actx.from_numpy(Velocity2D(v1_x, v1_y, actx)) + vel2 = actx.from_numpy(Velocity2D(v2_x, v2_y, actx)) + + scaled_speed1, scaled_speed2 = compiled_rhs(np.float64(3.14), vel1, vel2) + + result1 = actx.to_numpy(scaled_speed1) + result2 = actx.to_numpy(scaled_speed2) + np.testing.assert_allclose(result1.u, -3.14*v1_y) # pyright: ignore[reportAttributeAccessIssue] + np.testing.assert_allclose(result1.v, 3.14*v1_x) # pyright: ignore[reportAttributeAccessIssue] + np.testing.assert_allclose(result2.u, -3.14*v2_y) # pyright: ignore[reportAttributeAccessIssue] + np.testing.assert_allclose(result2.v, 3.14*v2_x) # pyright: ignore[reportAttributeAccessIssue] + # }}} diff --git a/test/test_utils.py b/test/test_utils.py index eeef7723..a44c6035 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -39,7 +39,7 @@ # {{{ test_pt_actx_key_stringification_uniqueness def test_pt_actx_key_stringification_uniqueness(): - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier assert (_ary_container_key_stringifier(((3, 2), 3)) != _ary_container_key_stringifier((3, (2, 3))))