diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index e6367f2b..d7133392 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -53,6 +53,7 @@
 from .impl.jax import EagerJAXArrayContext
 from .impl.pyopencl import PyOpenCLArrayContext
 from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
+from .impl.pytato.batched_einsum import BatchedEinsumPytatoPyOpenCLArrayContext
 from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext

 from .loopy import make_loopy_program  # deprecated, remove in 2022.
@@ -100,6 +101,7 @@
         "PyOpenCLArrayContext",
         "PytatoPyOpenCLArrayContext",
         "SplitPytatoPyOpenCLArrayContext",
+        "BatchedEinsumPytatoPyOpenCLArrayContext",
         "PytatoJAXArrayContext",
         "EagerJAXArrayContext",
diff --git a/arraycontext/impl/pytato/batched_einsum/__init__.py b/arraycontext/impl/pytato/batched_einsum/__init__.py
new file mode 100644
index 00000000..87f6273e
--- /dev/null
+++ b/arraycontext/impl/pytato/batched_einsum/__init__.py
@@ -0,0 +1,382 @@
+"""
+.. autoclass:: BatchedEinsumPytatoPyOpenCLArrayContext
+"""
+
+__copyright__ = """
+Copyright (C) 2023 Kaushik Kulkarni
+Copyright (C) 2022 Andreas Kloeckner
+Copyright (C) 2022 Matthias Diener
+Copyright (C) 2022 Matt Smith
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import logging
+import sys
+from typing import TYPE_CHECKING, Any, Callable, Optional, Type
+from warnings import warn
+
+import numpy as np
+
+import loopy as lp
+from pytools import ProcessLogger
+from pytools.tag import Tag
+
+from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
+
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
+    import pyopencl as cl
+    import pytato
+
+
+class BatchedEinsumPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
+    r"""
+    .. attribute:: loop_fusion_axis_tag_t
+
+        A subtype of :class:`pytato.tag.Tag` whose instances are attached to
+        the :class:`~pytato.array.Array`\ 's axes in an expression graph.
+        Loops that iterate over axes tagged with instances of the same such
+        tag type form the candidate loops for Kennedy's unweighted loop
+        fusion algorithm.
+
+    .. attribute:: fallback_to_no_fusion
+
+        If *True*, during the compilation of an array expression graph for
+        which loop fusion fails (see note), the transformation routines from
+        :class:`arraycontext.SplitPytatoPyOpenCLArrayContext` are invoked
+        instead.
+
+    .. attribute:: feinsum_db
+
+        An instance of :class:`str` pointing to the database of tuned batched
+        einsums. If *None*, a static transformation strategy is applied to
+        the batched einsum kernels.
+
+    .. attribute:: log_loopy_statistics
+
+        If *True*, statistics of each compiled
+        :class:`loopy.TranslationUnit` are logged, namely the FLOPS and the
+        global memory access footprint of each program. If *False*, nothing
+        is done.
+
+    .. note::
+
+        The conditions under which we fall back (or raise) are:
+
+        #. There exists an array that is to be materialized, but at least one
+           of its axes is not tagged with tags of
+           :attr:`loop_fusion_axis_tag_t`.
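+
+    Example (a minimal usage sketch; ``NamedAxisTag`` stands in for a
+    user-defined :class:`pytools.tag.Tag` subclass used to tag array axes):
+
+    .. code-block:: python
+
+        actx = BatchedEinsumPytatoPyOpenCLArrayContext(
+            queue,
+            loop_fusion_axis_tag_t=NamedAxisTag,
+        )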
+    """
+    def __init__(
+            self,
+            queue: "cl.CommandQueue", allocator=None,
+            *,
+            loop_fusion_axis_tag_t: Type[Tag],
+            fallback_to_no_fusion: bool = True,
+            assume_all_indirection_maps_as_non_negative: bool = False,
+            compile_trace_callback: Optional[Callable[[Any, str, Any],
+                                                      None]] = None,
+            feinsum_db: Optional[str] = None,
+            log_loopy_statistics: bool = False,
+            fused_loop_name_prefix_getter: Optional[Callable[[Tag],
+                                                             str]] = None,
+    ) -> None:
+        super().__init__(queue,
+                         allocator,
+                         compile_trace_callback=compile_trace_callback)
+
+        self.loop_fusion_axis_tag_t = loop_fusion_axis_tag_t
+        self.fallback_to_no_fusion = fallback_to_no_fusion
+        self.feinsum_db = feinsum_db
+        self.assume_all_indirection_maps_as_non_negative = (
+            assume_all_indirection_maps_as_non_negative)
+        self.log_loopy_statistics = log_loopy_statistics
+        if fused_loop_name_prefix_getter:
+            self.fused_loop_name_prefix_getter = fused_loop_name_prefix_getter
+        else:
+            self.fused_loop_name_prefix_getter = lambda tag_t: "ifused"
+
+    def transform_dag(self,
+            dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
+        import pytato as pt
+
+        from .utils import (
+            _make_passthrough_arg, get_indirection_maps,
+            get_inputs_and_outputs_of_reduction_nodes)
+        from arraycontext.impl.pytato.split_actx.utils import (
+            get_inputs_and_outputs_of_einsum)
+
+        # Step 1. Collapse equivalent nodes in the DAG.
+        # ---------------------------------------------
+        # type-ignore-reason: mypy is right, pytato provides imprecise types.
+        dag = pt.transform.deduplicate_data_wrappers(dag)  # type: ignore[assignment]
+
+        # Step 2. Materialize einsum/reduction outputs.
+        # ---------------------------------------------
+        _, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
+        _, reduction_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
+
+        def materialize_all_einsums_or_reduces(expr):
+            if (expr in einsum_outputs
+                    or expr in reduction_outputs):
+                return expr.tagged(pt.tags.ImplStored())
+            else:
+                return expr
+
+        # type-ignore-reason: mypy is right, pytato provides imprecise types.
+        dag = pt.transform.map_and_copy(dag,  # type: ignore[assignment]
+                                        materialize_all_einsums_or_reduces)
+
+        # Step 3. Materialize with MPMS.
+        # ------------------------------
+        dag = pt.transform.materialize_with_mpms(dag)
+
+        # Step 4. Mark all indirection maps as non-negative.
+        # --------------------------------------------------
+        if self.assume_all_indirection_maps_as_non_negative:
+            indirection_maps = get_indirection_maps(dag)
+
+            def tag_indices_as_non_negative(ary):
+                if ary in indirection_maps:
+                    return ary.tagged(pt.tags.AssumeNonNegative())
+                else:
+                    return ary
+
+            # type-ignore-reason: mypy is right, pytato provides imprecise
+            # types.
+            dag = pt.transform.map_and_copy(dag,  # type: ignore[assignment]
+                                            tag_indices_as_non_negative)
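+
+        # (Illustration, not from the original code: an indirection map is an
+        # integer-valued array used as an index, e.g. ``conn`` in
+        # ``u[conn[iel, idof]]``; asserting non-negativity can let the code
+        # generator avoid extra handling for possibly-negative indices.)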
+
+        # Step 5. Get rid of broadcasts in einsum expressions (helps feinsum).
+        # ---------------------------------------------------------------------
+        dag = pt.rewrite_einsums_with_no_broadcasts(dag)
+
+        # Step 6. Infer axis tags.
+        # ------------------------
+        # type-ignore-reason: mypy is right, pytato provides imprecise types.
+        dag = pt.unify_axes_tags(dag)  # type: ignore[assignment]
+
+        # Step 7. Implement all pt.einsum/pt.reduction inputs as substitutions.
+        # ---------------------------------------------------------------------
+        def implement_einsum_reduction_inputs_as_substs(expr):
+            from immutables import Map
+
+            from pytato.target.loopy import ImplSubstitution
+            if isinstance(expr, pt.Einsum):
+                # make the arguments passthrough to make use of already
+                # stored values.
+                # pylint and 'attrs' have poor compatibility
+                # pylint: disable=too-many-function-args,redundant-keyword-arg
+                # pylint: disable=unexpected-keyword-arg
+                return pt.Einsum(
+                    expr.access_descriptors,
+                    tuple(_make_passthrough_arg(arg, ImplSubstitution())
+                          for arg in expr.args),
+                    expr.redn_axis_to_redn_descr,
+                    expr.index_to_access_descr,
+                    tags=expr.tags,
+                    axes=expr.axes,
+                )
+            elif isinstance(expr, pt.IndexLambda) and expr.var_to_reduction_descr:
+                # make the arguments passthrough to make use of already
+                # stored values.
+                # pylint: disable=too-many-function-args,redundant-keyword-arg
+                # pylint: disable=unexpected-keyword-arg
+                return pt.IndexLambda(
+                    expr.expr,
+                    expr.shape,
+                    expr.dtype,
+                    Map({name: _make_passthrough_arg(bnd, ImplSubstitution())
+                         for name, bnd in expr.bindings.items()}),
+                    expr.var_to_reduction_descr,
+                    tags=expr.tags,
+                    axes=expr.axes,
+                )
+            else:
+                return expr
+
+        # type-ignore-reason: mypy is right, pytato provides imprecise types.
+        dag = pt.transform.map_and_copy(dag,  # type: ignore[assignment]
+                                        implement_einsum_reduction_inputs_as_substs)
+
+        return dag
+
+    def transform_loopy_program(self,
+            t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
+        knl_name = t_unit.default_entrypoint.name
+
+        logger.info(f"[{self.__class__}.transform_loopy_program]:"
+                    f" Transforming kernel '{knl_name}' with"
+                    f" {len(t_unit.default_entrypoint.instructions)} statements.")
+
+        # Step 0. Fall back if t_unit cannot be transformed.
+        # --------------------------------------------------
+        for iname in t_unit.default_entrypoint.all_inames():
+            if not t_unit.default_entrypoint.iname_tags_of_type(
+                    iname, self.loop_fusion_axis_tag_t):
+                if self.fallback_to_no_fusion:
+                    warn(f"[{knl_name}]: Falling back to a slower"
+                         " transformation strategy since it could not be"
+                         " inferred which mesh entity some loops iterate"
+                         " over.",
+                         stacklevel=2)
+                    from arraycontext.impl.pytato.split_actx import (
+                        SplitPytatoPyOpenCLArrayContext)
+
+                    # type-ignore-reason: mypy is right, we are passing
+                    # incorrect types, but knowing the implementation of
+                    # SplitPytatoPyOpenCLArrayContext this should be fine.
+                    return SplitPytatoPyOpenCLArrayContext.transform_loopy_program(
+                        self, t_unit)  # type: ignore[arg-type]
+                else:
+                    raise RuntimeError(f"Iname '{iname}' is not tagged with"
+                                       " tags of type"
+                                       f" '{self.loop_fusion_axis_tag_t}';"
+                                       " Kennedy's loop fusion cannot be"
+                                       " applied.")
+
+        # Step 0.5. Set argument offsets to 0.
+        # (FIXME: move this to the loopy kernel invocation.)
+        # --------------------------------------------------
+        knl = t_unit.default_entrypoint
+        knl = knl.copy(args=[arg.copy(offset=0) for arg in knl.args])
+        t_unit = t_unit.with_kernel(knl)
+        del knl
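+
+        # (Illustration, not from the original code: if inames i_el_0 and
+        # i_el_1 are tagged with the same loop_fusion_axis_tag_t instance,
+        # then, roughly,
+        #     for i_el_0: a[i_el_0] = f(...)
+        #     for i_el_1: b[i_el_1] = g(...)
+        # becomes, after the fusion below,
+        #     for ifused_0: a[ifused_0] = f(...)
+        #     for ifused_0: b[ifused_0] = g(...)
+        # i.e. both statements share one iname, enabling the later array
+        # contraction.)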
+
+        # Step 1. Fuse loops indexing over the same tag.
+        # ----------------------------------------------
+        with ProcessLogger(logger, f"[{knl_name}] Loop Fusion"):
+            from .utils import apply_kennedy_fusion_with_batched_einsum_extension
+            t_unit = apply_kennedy_fusion_with_batched_einsum_extension(
+                t_unit, self.loop_fusion_axis_tag_t,
+                self.fused_loop_name_prefix_getter)
+
+        # Step 2. Combine the domains of individual loop nests into
+        # individual BasicSets.
+        # ----------------------------------------------------------
+        from .utils import combine_domains_of_perfect_loop_nests
+        t_unit = combine_domains_of_perfect_loop_nests(t_unit)
+
+        # Step 3. Remove dead temporaries.
+        # --------------------------------
+        from .utils import remove_dead_temporaries
+        t_unit = remove_dead_temporaries(t_unit)
+
+        # Step 4. Contract arrays.
+        # ------------------------
+        with ProcessLogger(logger, f"[{knl_name}] Array Contraction"):
+            from .utils import contract_arrays
+            t_unit = contract_arrays(t_unit)
+
+        # Step 5. Collect statistics.
+        # ---------------------------
+
+        # {{{ compute stats
+
+        if self.log_loopy_statistics:
+
+            with ProcessLogger(logger, f"[{knl_name}] Count kernel metrics"):
+                from loopy.kernel.array import ArrayBase
+                from pytools import product
+                knl = t_unit.default_entrypoint
+                knl = knl.copy(
+                    silenced_warnings=(knl.silenced_warnings
+                                       + ["insn_count_subgroups_upper_bound",
+                                          "summing_if_branches_ops"]))
+
+                t_unit = t_unit.with_kernel(knl)
+
+                op_map = lp.get_op_map(t_unit, subgroup_size=32)
+
+                c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64],
+                                                      name=op_type,
+                                                      kernel_name=knl_name)
+                                     .eval_and_sum({}))
+                           for op_type in ["add", "mul", "div"]}
+                c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128],
+                                                       name=op_type,
+                                                       kernel_name=knl_name)
+                                      .eval_and_sum({}))
+                            for op_type in ["add", "mul", "div"]}
+                f32_ops = ((op_map.filter_by(dtype=[np.float32],
+                                             kernel_name=knl_name)
+                            .eval_and_sum({}))
+                           + (2 * c64_ops["add"]
+                              + 6 * c64_ops["mul"]
+                              + (6 + 3 + 2) * c64_ops["div"]))
+                f64_ops = ((op_map.filter_by(dtype=[np.float64],
+                                             kernel_name=knl_name)
+                            .eval_and_sum({}))
+                           + (2 * c128_ops["add"]
+                              + 6 * c128_ops["mul"]
+                              + (6 + 3 + 2) * c128_ops["div"]))
+
+                # {{{ footprint gathering
+
+                nfootprint_bytes = 0
+
+                for ary in knl.args:
+                    if (isinstance(ary, ArrayBase)
+                            and ary.address_space == lp.AddressSpace.GLOBAL):
+                        nfootprint_bytes += (product(ary.shape)
+                                             * ary.dtype.itemsize)
+
+                for ary in knl.temporary_variables.values():
+                    if ary.address_space == lp.AddressSpace.GLOBAL:
+                        # global temporaries are written once and read once
+                        nfootprint_bytes += (2 * product(ary.shape)
+                                             * ary.dtype.itemsize)
+
+                # }}}
+
+                if f32_ops:
+                    logger.info(f"Single-prec. GFLOPs: {f32_ops * 1e-9}")
+                if f64_ops:
+                    logger.info(f"Double-prec. GFLOPs: {f64_ops * 1e-9}")
+                logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}")
+
+        # }}}
+
+        # Step 6. Draw kernel boundaries between batched einsum kernels.
+        # --------------------------------------------------------------
+        from arraycontext.impl.pytato.split_actx.utils import (
+            add_gbarrier_between_disjoint_loop_nests)
+
+        t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)
+
+        # Step 7. Alias global temporaries with disjoint live intervals.
+        # --------------------------------------------------------------
+        from arraycontext.impl.pytato.split_actx.utils import (
+            alias_global_temporaries)
+        t_unit = alias_global_temporaries(t_unit)
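+
+        # (Illustration, not from the original code: if global temporaries
+        # tmp_0 and tmp_1 have disjoint live intervals, i.e. the last access
+        # of tmp_0 precedes the first write of tmp_1 across the kernel
+        # boundaries drawn above, the aliasing pass lets them share one
+        # device allocation, shrinking the memory footprint.)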
+
+        # Step 8. Macro-kernel optimizations.
+        # -----------------------------------
+        if self.feinsum_db:
+            from .utils import apply_feinsum_transformations
+            t_unit = apply_feinsum_transformations(
+                t_unit, self.feinsum_db, self.queue.device)
+        else:
+            from arraycontext.impl.pytato.split_actx.utils import (
+                parallelize_reduce_to_scalars,
+                split_iteration_domain_across_work_items)
+            t_unit = split_iteration_domain_across_work_items(t_unit,
+                                                              self.queue.device)
+            t_unit = parallelize_reduce_to_scalars(t_unit, self.queue.device)
+
+        return t_unit
+
+# vim: fdm=marker
diff --git a/arraycontext/impl/pytato/batched_einsum/utils.py b/arraycontext/impl/pytato/batched_einsum/utils.py
new file mode 100644
index 00000000..4bbe275f
--- /dev/null
+++ b/arraycontext/impl/pytato/batched_einsum/utils.py
@@ -0,0 +1,453 @@
+__copyright__ = """
+Copyright (C) 2023 Kaushik Kulkarni
+Copyright (C) 2022 Andreas Kloeckner
+Copyright (C) 2022 Matthias Diener
+Copyright (C) 2022 Matt Smith
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING, Any, Callable, Dict, FrozenSet, List, Mapping, Set, Tuple,
+    Type)
+
+from immutables import Map
+
+import loopy as lp
+import loopy.match as lp_match
+import loopy.symbolic as lp_symbolic
+import pymbolic.primitives as prim
+import pytato as pt
+from loopy.translation_unit import for_each_kernel
+from pytools import memoize_on_first_arg
+from pytools.tag import Tag, ToTagSetConvertible
+
+
+if TYPE_CHECKING:
+    import feinsum
+
+    import pyopencl
+
+
+# {{{ IndirectionMapsCollector
+
+class IndirectionMapsCollector(pt.transform.CachedWalkMapper):
+    """
+    .. note::
+
+        We deliberately avoid using :class:`pytato.transform.CombineMapper`
+        since the mapper's caching structure would still lead to recomputing
+        the union of sets for the results of a revisited node.
+    """
+    def __init__(self) -> None:
+        self.collected_indirection_maps: Set[pt.Array] = set()
+        super().__init__()
+
+    # type-ignore-reason: dropped the extra `*args, **kwargs`.
+    def get_cache_key(self,  # type: ignore[override]
+                      expr: pt.transform.ArrayOrNames) -> int:
+        return id(expr)
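+
+    # (Keying the visited-node cache on ``id(expr)`` counts each node object
+    # as visited exactly once and avoids potentially expensive structural
+    # equality/hash computations on large expression graphs.)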
+
+    # type-ignore-reason: dropped the extra `*args, **kwargs`.
+    def post_visit(self, expr: Any) -> None:  # type: ignore[override]
+        if isinstance(expr, pt.IndexBase):
+            for idx in expr.indices:
+                if isinstance(idx, pt.Array):
+                    self.collected_indirection_maps.add(idx)
+
+
+def get_indirection_maps(expr: pt.DictOfNamedArrays) -> FrozenSet[pt.Array]:
+    mapper = IndirectionMapsCollector()
+    mapper(expr)
+    return frozenset(mapper.collected_indirection_maps)
+
+# }}}
+
+
+# {{{ ReductionInputOutputCollector
+
+class ReductionInputOutputCollector(pt.transform.CachedWalkMapper):
+    """
+    .. note::
+
+        We deliberately avoid using :class:`pytato.transform.CombineMapper`
+        since the mapper's caching structure would still lead to recomputing
+        the union of sets for the results of a revisited node.
+    """
+    def __init__(self) -> None:
+        self.collected_outputs: Set[pt.Array] = set()
+        self.collected_inputs: Set[pt.Array] = set()
+        super().__init__()
+
+    # type-ignore-reason: dropped the extra `*args, **kwargs`.
+    def get_cache_key(self,  # type: ignore[override]
+                      expr: pt.transform.ArrayOrNames) -> int:
+        return id(expr)
+
+    # type-ignore-reason: dropped the extra `*args, **kwargs`.
+    def post_visit(self, expr: Any) -> None:  # type: ignore[override]
+        if isinstance(expr, pt.IndexLambda) and expr.var_to_reduction_descr:
+            self.collected_outputs.add(expr)
+            self.collected_inputs.update(expr.bindings.values())
+
+
+def get_inputs_and_outputs_of_reduction_nodes(
+        expr: pt.DictOfNamedArrays) -> Tuple[FrozenSet[pt.Array],
+                                             FrozenSet[pt.Array]]:
+    mapper = ReductionInputOutputCollector()
+    mapper(expr)
+    return frozenset(mapper.collected_inputs), frozenset(mapper.collected_outputs)
+
+# }}}
+
+
+def _make_passthrough_arg(ary: pt.Array,
+                          tags: ToTagSetConvertible = frozenset()) -> pt.Array:
+    from pytato.array import make_index_lambda
+    return make_index_lambda(
+        prim.Variable("in")[tuple(prim.Variable(f"_{idim}")
+                                  for idim in range(ary.ndim))],
+        bindings={"in": ary},
+        shape=ary.shape,
+        dtype=ary.dtype,
+    ).tagged(tags)
+
+
+@dataclass(frozen=True)
+class EinsumWithAxesTagged:
+    einsum: "feinsum.FusedEinsum"
+    index_tags: Mapping["feinsum.EinsumAxisAccess", Tag]
+
+    def __post_init__(self):
+        assert (frozenset(self.einsum.index_to_dim_length())
+                == frozenset(self.index_tags))
+
+
+def get_n_callable_kernels(t_unit: lp.TranslationUnit) -> int:
+    from loopy.kernel.function_interface import CallableKernel
+    return len([name
+                for name, clbl in t_unit.callables_table.items()
+                if isinstance(clbl, CallableKernel)])
+
+
+def apply_kennedy_fusion_with_batched_einsum_extension(
+        t_unit: lp.TranslationUnit,
+        tag_t: Type[Tag],
+        fused_loop_name_prefix_getter: Callable[[Tag], str]
+) -> lp.TranslationUnit:
+
+    import feinsum as fnsm
+
+    if get_n_callable_kernels(t_unit) > 1:
+        # We accept 't_unit' (instead of a kernel) to comply with feinsum's
+        # API.
+        raise NotImplementedError(
+            "'apply_kennedy_fusion_with_batched_einsum_extension'"
+            " does not support multiple callee kernels (yet).")
+
+    kernel = t_unit.default_entrypoint
+
+    assert all(len(kernel.iname_to_insns()[iname]) <= 1
+               for iname in kernel.all_inames())
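+
+    # (Illustration, not from the original code: two instructions realizing
+    # einsums that canonicalize to the same feinsum expression, with
+    # identically tagged axes, fall into the same bucket below; the loopy
+    # inames realizing a given einsum axis across those instructions then
+    # become candidates for fusion into one loop.)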
+
+    # Bucket inames by the tagged einsum and the axis's position within it.
+    bucket_to_inames: Dict[Tuple[EinsumWithAxesTagged, fnsm.EinsumAxisAccess],
+                           Set[str]] = {}
+
+    for insn in kernel.instructions:
+        if isinstance(insn, lp.Assignment):
+            # {{{ get matching einsum/subst_map
+
+            if insn.reduction_inames():
+                einsum, _ = fnsm.get_a_matched_einsum(
+                    t_unit, insn_match=lp_match.Id(insn.id))
+                einsum = fnsm.canonicalize_einsum(einsum)
+                subst_map = fnsm.match_t_unit_to_einsum(
+                    t_unit, einsum, insn_match=lp_match.Id(insn.id))
+            else:
+                # we treat any non-reduction einsum as a copy-einsum
+                assignee = insn.assignee
+                if isinstance(assignee, prim.Variable):
+                    lpy_dim_names = []
+                else:
+                    assert isinstance(assignee, prim.Subscript)
+                    lpy_dim_names = [index.name
+                                     for index in assignee.index_tuple]
+
+                dim_lengths = [kernel.get_constant_iname_length(dim_name)
+                               for dim_name in lpy_dim_names]
+                if len(lpy_dim_names) > 26:
+                    raise ValueError("The batched-einsum array context"
+                                     " supports up to 26 dimensions.")
+                einsum_dim_names = [chr(97+idim)
+                                    for idim in range(len(lpy_dim_names))]
+                einsum = fnsm.einsum(
+                    f"{''.join(einsum_dim_names)}->{''.join(einsum_dim_names)}",
+                    # deliberately fix dtype=float64: for copy-einsums the
+                    # dtype need not be precise.
+                    fnsm.array(shape=dim_lengths, dtype="float64"),
+                )
+                einsum = fnsm.canonicalize_einsum(einsum)
+                subst_map = {
+                    einsum.index_names[fnsm.FreeAxis(idim)]: lpy_dim_names[idim]
+                    for idim in range(len(einsum_dim_names))}
+
+            # }}}
+
+            idx_tags: Dict[fnsm.EinsumAxisAccess, Tag] = {}
+            for acc_descr, name_in_einsum in einsum.index_names.items():
+                lpy_iname = subst_map[name_in_einsum]
+                lpy_iname_tag, = kernel.iname_tags_of_type(lpy_iname, tag_t)
+                idx_tags[acc_descr] = lpy_iname_tag
+
+            tagged_einsum = EinsumWithAxesTagged(einsum, Map(idx_tags))
+
+            for acc_descr, name_in_einsum in einsum.index_names.items():
+                bucket = (tagged_einsum, acc_descr)
+                lpy_iname = subst_map[name_in_einsum]
+                bucket_to_inames.setdefault(bucket, set()).add(lpy_iname)
+
+        else:
+            # TODO: should this be a ValueError?
+            raise NotImplementedError(type(insn))
+
+    for inames in bucket_to_inames.values():
+        inames_tag, = kernel.iname_tags_of_type(next(iter(inames)),
+                                                tag_t)
+        # TODO: Enable pylint once these routines have been upstreamed to
+        # loopy.
+        kernel = lp.rename_inames_in_batch(  # pylint: disable=no-member
+            kernel,
+            lp.get_kennedy_unweighted_fusion_candidates(  # pylint: disable=no-member
+                kernel, inames,
+                prefix=fused_loop_name_prefix_getter(inames_tag),
+            ),
+        )
+
+    return t_unit.with_kernel(kernel)
+
+
+@for_each_kernel
+def remove_dead_temporaries(kernel: lp.LoopKernel) -> lp.LoopKernel:
+    wmap = kernel.writer_map()
+    rmap = kernel.reader_map()
+
+    new_tvs: Dict[str, lp.TemporaryVariable] = {}
+
+    for name, tv in kernel.temporary_variables.items():
+        writer_ids: FrozenSet[str] = wmap.get(name, frozenset())
+        reader_ids: FrozenSet[str] = rmap.get(name, frozenset())
+
+        if len(writer_ids) != 0 or len(reader_ids) != 0:
+            new_tvs[name] = tv
+
+    return kernel.copy(temporary_variables=new_tvs)
+
+
+class IndexingTupleCollector(lp_symbolic.WalkMapper):
+    def __init__(self, subscript_name: str) -> None:
+        self.subscript_name = subscript_name
+        super().__init__()
+
+        # mutable state:
+        self.collected_index_tuples: Set[Tuple[prim.Expression, ...]] = set()
+
+    def post_visit(self, expr: prim.Expression) -> None:
+        if (isinstance(expr, prim.Subscript)
+                and expr.aggregate == prim.Variable(self.subscript_name)):
+            self.collected_index_tuples.add(expr.index_tuple)
+
+
+@memoize_on_first_arg
+def _expand_substs(kernel):
+    # memoized wrapper for lp.expand_subst
+    return lp.expand_subst(kernel)
+
+
+@memoize_on_first_arg
+def can_temp_var_be_contracted(kernel: lp.LoopKernel, name: str) -> bool:
+    kernel = _expand_substs(kernel)
+    wmap = kernel.writer_map()
+    rmap = kernel.reader_map()
+
+    writer_ids: FrozenSet[str] = wmap.get(name, frozenset())
+    reader_ids: FrozenSet[str] = rmap.get(name, frozenset())
+
+    if kernel.temporary_variables[name].initializer:
+        # this is a constant literal => cannot be contracted
+        return False
+
+    if len(writer_ids) == 0:
+        assert len(reader_ids) == 0
+        return True
+    else:
+        mapper = IndexingTupleCollector(name)
+        for insn_id in writer_ids | reader_ids:
+            insn = kernel.id_to_insn[insn_id]
+            mapper((insn.expression,
+                    insn.assignees,
+                    tuple(insn.predicates)))
+
+        return len(mapper.collected_index_tuples) == 1
+
+
+class ArrayContracter(lp_symbolic.RuleAwareIdentityMapper):
+    def __init__(self,
+                 rule_mapping_context: lp_symbolic.SubstitutionRuleMappingContext,
+                 arrays_to_contract: FrozenSet[str]):
+        self.arrays_to_contract = arrays_to_contract
+        super().__init__(rule_mapping_context)
+
+    def map_subscript(self, expr, expn_state) -> prim.Expression:
+        if (isinstance(expr.aggregate, prim.Variable)
+                and expr.aggregate.name in self.arrays_to_contract):
+            return expr.aggregate
+        else:
+            return super().map_subscript(expr, expn_state)
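+
+
+# (Illustration, not from the original code: if a temporary is only ever
+# accessed through the single index tuple ``tmp[iel, idof]`` across all of
+# its reads and writes, ``contract_arrays`` below rewrites those accesses to
+# plain ``tmp`` and shrinks the variable to a private scalar.)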
+@for_each_kernel
+def contract_arrays(kernel: lp.LoopKernel) -> lp.LoopKernel:
+    # Note: lp.precompute could do this as well, but it would be
+    # unnecessarily expensive here.
+    new_tvs: Dict[str, lp.TemporaryVariable] = {}
+
+    rule_mapping_context = lp_symbolic.SubstitutionRuleMappingContext(
+        kernel.substitutions,
+        kernel.get_var_name_generator()
+    )
+    temps_to_contract = frozenset({
+        name for name, tv in kernel.temporary_variables.items()
+        if can_temp_var_be_contracted(kernel, name)})
+    array_contracter = ArrayContracter(rule_mapping_context,
+                                       temps_to_contract)
+
+    kernel = rule_mapping_context.finish_kernel(
+        array_contracter.map_kernel(
+            kernel, map_tvs=False, map_args=False)
+    )
+
+    for name, tv in kernel.temporary_variables.items():
+        if name in temps_to_contract:
+            tv = tv.copy(shape=(),
+                         dim_tags=(),
+                         address_space=lp.AddressSpace.PRIVATE)
+
+        new_tvs[name] = tv
+
+    return kernel.copy(temporary_variables=new_tvs)
+
+
+@for_each_kernel
+def combine_domains_of_perfect_loop_nests(kernel: lp.LoopKernel) -> lp.LoopKernel:
+    import islpy as isl
+
+    from arraycontext.impl.pytato.split_actx.utils import (
+        _is_a_perfect_loop_nest)
+
+    new_domains: List[isl.BasicSet] = []
+
+    seen_loop_nests: Set[FrozenSet[str]] = set()
+
+    for insn in kernel.instructions:
+        assert _is_a_perfect_loop_nest(kernel, insn.within_inames)
+        loop_nest = insn.within_inames | insn.reduction_inames()
+
+        if loop_nest in seen_loop_nests:
+            continue
+
+        domain = kernel.get_inames_domain(loop_nest)
+        new_domains.append(domain.project_out_except(sorted(loop_nest),
+                                                     [isl.dim_type.set]))
+        seen_loop_nests.add(loop_nest)
+
+    return kernel.copy(domains=new_domains)
+
+
+def _apply_feinsum_transformations_to_single_kernel(
+        t_unit: lp.TranslationUnit, kernel_name: str, feinsum_db: str,
+        cl_device: "pyopencl.Device",
+) -> lp.TranslationUnit:
+    import feinsum as fnsm
+
+    from arraycontext.impl.pytato.split_actx.utils import (
+        InsnIds, _get_call_kernel_insn_ids, _LoopNest,
+        _split_loop_nest_across_work_items, get_iname_length)
+    call_kernel_insn_ids = _get_call_kernel_insn_ids(t_unit[kernel_name])
+    iname_to_length = {iname: get_iname_length(t_unit[kernel_name], iname)
+                       for iname in t_unit[kernel_name].all_inames()}
+
+    for insn_ids in call_kernel_insn_ids:
+        within_inames, = {t_unit[kernel_name].id_to_insn[insn_id].within_inames
+                          for insn_id in insn_ids}
+        redn_inames, = {t_unit[kernel_name].id_to_insn[insn_id].reduction_inames()
+                        for insn_id in insn_ids}
+        if redn_inames:
+            einsum, _ = fnsm.get_a_matched_einsum(t_unit,
+                                                  insn_match=InsnIds(insn_ids),
+                                                  kernel_name=kernel_name,
+                                                  long_dim_length=128,
+                                                  )
+            available_facts = fnsm.query(einsum,
+                                         fnsm.make_fake_cl_context(
+                                             [cl_device.name]),
+                                         database=feinsum_db)
+            if available_facts:
+                best_query = max(
+                    available_facts,
+                    key=lambda q: sum(q.giga_op_info.values())/q.runtime_in_sec)
+                # type-ignore-reason: mypy is right here, the callable
+                # returned by feinsum is imprecisely typed.
+                t_unit = best_query.transform(
+                    t_unit,
+                    insn_match=InsnIds(insn_ids),  # type: ignore[call-arg]
+                    kernel_name=kernel_name)
+            else:
+                from warnings import warn
+                warn(f"The database at '{feinsum_db}' has no tuned instances"
+                     f" for {einsum}.")
+                t_unit = t_unit.with_kernel(
+                    _split_loop_nest_across_work_items(t_unit[kernel_name],
+                                                       _LoopNest(
+                                                           within_inames,
+                                                           insn_ids),
+                                                       iname_to_length,
+                                                       cl_device))
+        else:
+            # TODO: read the grid/block size from the database.
+            t_unit = t_unit.with_kernel(
+                _split_loop_nest_across_work_items(t_unit[kernel_name],
+                                                   _LoopNest(
+                                                       within_inames,
+                                                       insn_ids),
+                                                   iname_to_length,
+                                                   cl_device))
+
+    return t_unit
+
+
+def apply_feinsum_transformations(t_unit: lp.TranslationUnit,
+                                  feinsum_db: str,
+                                  cl_device: "pyopencl.Device"
+                                  ) -> lp.TranslationUnit:
+    from loopy.kernel.function_interface import CallableKernel
+    kernel_names = {name
+                    for name, clbl in t_unit.callables_table.items()
+                    if isinstance(clbl, CallableKernel)}
+    for kernel_name in kernel_names:
+        t_unit = _apply_feinsum_transformations_to_single_kernel(
+            t_unit, kernel_name, feinsum_db, cl_device)
+    return t_unit
+
+# vim: fdm=marker