Skip to content

Implements a Loop Fusion Transformation #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,616 changes: 168 additions & 4,448 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions doc/ref_transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,10 @@ TODO: Matching instruction tags

.. automodule:: loopy.match


Fusing Loops
------------

.. automodule:: loopy.transform.loop_fusion

.. vim: tw=75:spell
6 changes: 6 additions & 0 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@
simplify_indices,
tag_instructions,
)
from loopy.transform.loop_fusion import (
get_kennedy_unweighted_fusion_candidates,
rename_inames_in_batch,
)
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
from loopy.transform.padding import (
add_padding,
Expand Down Expand Up @@ -336,6 +340,7 @@
"get_dot_dependency_graph",
"get_global_barrier_order",
"get_iname_duplication_options",
"get_kennedy_unweighted_fusion_candidates",
"get_mem_access_map",
"get_one_linearized_kernel",
"get_one_scheduled_kernel",
Expand Down Expand Up @@ -382,6 +387,7 @@
"rename_callable",
"rename_iname",
"rename_inames",
"rename_inames_in_batch",
"replace_instruction_ids",
"save_and_reload_temporaries",
"set_argument_order",
Expand Down
12 changes: 6 additions & 6 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
check_each_kernel,
)
from loopy.type_inference import TypeReader
from loopy.typing import auto, not_none
from loopy.typing import auto, not_none, set_union


if TYPE_CHECKING:
Expand Down Expand Up @@ -1107,10 +1107,10 @@ def _check_variable_access_ordered_inner(kernel: LoopKernel) -> None:
address_space = _get_address_space(kernel, var)
eq_class = aliasing_equiv_classes[var]

readers = set.union(
*[rmap.get(eq_name, set()) for eq_name in eq_class])
writers = set.union(
*[wmap.get(eq_name, set()) for eq_name in eq_class])
readers = set_union(
rmap.get(eq_name, set()) for eq_name in eq_class)
writers = set_union(
wmap.get(eq_name, set()) for eq_name in eq_class)

for writer in writers:
required_deps = (readers | writers) - {writer}
Expand Down Expand Up @@ -1676,7 +1676,7 @@ def _get_sub_array_ref_swept_range(
return get_access_map(
domain.to_set(),
sar.swept_inames,
kernel.assumptions.to_set()).range()
kernel.assumptions).range()


def _are_sub_array_refs_equivalent(
Expand Down
10 changes: 5 additions & 5 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)
from loopy.tools import update_persistent_hash
from loopy.types import LoopyType, NumpyType
from loopy.typing import PreambleGenerator, SymbolMangler, fset_union, not_none
from loopy.typing import InsnId, PreambleGenerator, SymbolMangler, fset_union, not_none


if TYPE_CHECKING:
Expand Down Expand Up @@ -612,8 +612,8 @@ def insn_inames(self, insn: str | InstructionBase) -> frozenset[InameStr]:
return insn.within_inames

@memoize_method
def iname_to_insns(self):
result = {
def iname_to_insns(self) -> Mapping[InameStr, Set[InsnId]]:
result: dict[InameStr, set[InsnId]] = {
iname: set() for iname in self.all_inames()}
for insn in self.instructions:
for iname in insn.within_inames:
Expand Down Expand Up @@ -692,7 +692,7 @@ def compute_deps(insn_id):
# {{{ read and written variables

@memoize_method
def reader_map(self):
def reader_map(self) -> Mapping[str, Set[InsnId]]:
"""
:return: a dict that maps variable names to ids of insns that read that
variable.
Expand All @@ -710,7 +710,7 @@ def reader_map(self):
return result

@memoize_method
def writer_map(self):
def writer_map(self) -> Mapping[str, Set[InsnId]]:
"""
:return: a dict that maps variable names to ids of insns that write
to that variable.
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
floord mod ceil floor""".split())


def _gather_isl_identifiers(s):
def _gather_isl_identifiers(s: str):
return set(_IDENTIFIER_RE.findall(s)) - _ISL_KEYWORDS


Expand Down Expand Up @@ -2461,7 +2461,7 @@ def make_function(
# does something.
knl = add_inferred_inames(knl)
from loopy.transform.parameter import fix_parameters
knl = fix_parameters(knl, **fixed_parameters)
knl = fix_parameters(knl, within=None, **fixed_parameters)

# -------------------------------------------------------------------------
# Ordering dependency:
Expand Down
7 changes: 4 additions & 3 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,15 +924,16 @@ def with_descrs(self,
for arg in subkernel.args:
kw = arg.name
if isinstance(arg, ArrayBase):
new_arg_id_to_descr[kw] = (
new_arg_descriptor = (
ArrayArgDescriptor(shape=arg.shape,
dim_tags=arg.dim_tags,
address_space=arg.address_space))
else:
assert isinstance(arg, ValueArg)
new_arg_id_to_descr[kw] = ValueArgDescriptor()
new_arg_descriptor = ValueArgDescriptor()

new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw]
# FIXME: Should decide what the canonical arg identifiers are
new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_descriptor

# }}}

Expand Down
119 changes: 98 additions & 21 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,30 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import dataclasses
import itertools
import logging
import sys
from collections.abc import Set
from functools import reduce
from sys import intern
from typing import (
TYPE_CHECKING,
Concatenate,
Generic,
ParamSpec,
TypeVar,
cast,
)

import numpy as np
from typing_extensions import deprecated
from typing_extensions import deprecated, override

import islpy as isl
import pymbolic.primitives as p
from islpy import dim_type
from pymbolic import Expression
from pytools import memoize_on_first_arg, natsorted

from loopy.diagnostic import LoopyError, warn_with_kernel
Expand All @@ -59,13 +65,13 @@
TUnitOrKernelT,
for_each_kernel,
)
from loopy.typing import fset_union, set_union


if TYPE_CHECKING:
from collections.abc import Callable, Collection, Mapping, Sequence, Set
from collections.abc import Callable, Collection, Iterable, Mapping, Sequence

import pymbolic.primitives as p
from pymbolic import ArithmeticExpression, Expression
from pymbolic import ArithmeticExpression
from pytools.tag import Tag

from loopy.types import ToLoopyTypeConvertible
Expand All @@ -75,6 +81,9 @@
logger = logging.getLogger(__name__)


T = TypeVar("T")


# {{{ add and infer argument dtypes

def add_dtypes(
Expand Down Expand Up @@ -719,7 +728,7 @@ def show_dependency_graph(*args, **kwargs):
def is_domain_dependent_on_inames(kernel: LoopKernel,
domain_index: int, inames: Set[str]) -> bool:
dom = kernel.domains[domain_index]
dom_parameters = set(dom.get_var_names(dim_type.param))
dom_parameters = set(dom.get_var_names_not_none(dim_type.param))

# {{{ check for parenthood by loop bound iname

Expand Down Expand Up @@ -1952,7 +1961,7 @@ def get_subkernel_extra_inames(kernel: LoopKernel) -> Mapping[str, frozenset[str

# {{{ find aliasing equivalence classes

class DisjointSets:
class DisjointSets(Generic[T]):
"""
.. automethod:: __getitem__
.. automethod:: find_leader_or_create_group
Expand All @@ -1963,10 +1972,10 @@ class DisjointSets:
# https://en.wikipedia.org/wiki/Disjoint-set_data_structure

def __init__(self):
self.leader_to_group = {}
self.element_to_leader = {}
self.leader_to_group: dict[T, set[T]] = {}
self.element_to_leader: dict[T, T] = {}

def __getitem__(self, item):
def __getitem__(self, item: T):
"""
:arg item: A representative of an equivalence class.
:returns: the equivalence class, given as a set of elements
Expand All @@ -1978,7 +1987,7 @@ def __getitem__(self, item):
else:
return self.leader_to_group[leader]

def find_leader_or_create_group(self, el):
def find_leader_or_create_group(self, el: T):
try:
return self.element_to_leader[el]
except KeyError:
Expand All @@ -1988,7 +1997,7 @@ def find_leader_or_create_group(self, el):
self.leader_to_group[el] = {el}
return el

def union(self, a, b):
def union(self, a: T, b: T):
leader_a = self.find_leader_or_create_group(a)
leader_b = self.find_leader_or_create_group(b)

Expand All @@ -2003,7 +2012,7 @@ def union(self, a, b):
self.leader_to_group[leader_a].update(self.leader_to_group[leader_b])
del self.leader_to_group[leader_b]

def union_many(self, relation):
def union_many(self, relation: Iterable[tuple[T, T]]):
"""
:arg relation: an iterable of 2-tuples enumerating the elements of the
relation. The relation is assumed to be an equivalence relation
Expand All @@ -2021,8 +2030,8 @@ def union_many(self, relation):
return self


def find_aliasing_equivalence_classes(kernel):
return DisjointSets().union_many(
def find_aliasing_equivalence_classes(kernel: LoopKernel):
return DisjointSets[str]().union_many(
(tv.base_storage, tv.name)
for tv in kernel.temporary_variables.values()
if tv.base_storage is not None)
Expand All @@ -2032,7 +2041,7 @@ def find_aliasing_equivalence_classes(kernel):

# {{{ direction helper tools

def infer_args_are_input_output(kernel):
def infer_args_are_input_output(kernel: LoopKernel):
"""
Returns a copy of *kernel* with the attributes ``is_input`` and
``is_output`` of the arguments set.
Expand Down Expand Up @@ -2088,22 +2097,22 @@ def infer_args_are_input_output(kernel):

# {{{ CallablesIDCollector

class CallablesIDCollector(CombineMapper):
class CallablesIDCollector(CombineMapper[frozenset[CallableId], []]):
"""
Mapper to collect function identifiers of all resolved callables in an
expression.
"""
def combine(self, values):
import operator
return reduce(operator.or_, values, frozenset())
@override
def combine(self, values: Iterable[frozenset[CallableId]]):
return fset_union(values)

def map_resolved_function(self, expr):
return frozenset([expr.name])

def map_constant(self, expr):
def map_constant(self, expr: object):
return frozenset()

def map_kernel(self, kernel):
def map_kernel(self, kernel: LoopKernel) -> frozenset[CallableId]:
callables_in_insn = frozenset()

for insn in kernel.instructions:
Expand Down Expand Up @@ -2224,4 +2233,72 @@ def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff:
constants_only=False)
return lower_bound


# {{{ get access map from an instruction

@dataclasses.dataclass
class _IndexCollector(CombineMapper[Set[tuple[Expression, ...]], []]):
var: str

def __post_init__(self) -> None:
super().__init__()

@override
def combine(self,
values: Iterable[Set[tuple[Expression, ...]]]
) -> Set[tuple[Expression, ...]]:
return set_union(values)

@override
def map_subscript(self, expr: p.Subscript) -> Set[tuple[Expression, ...]]:
assert isinstance(expr.aggregate, p.Variable)
if expr.aggregate.name == self.var:
return (super().map_subscript(expr) | frozenset([expr.index_tuple]))
else:
return super().map_subscript(expr)

@override
def map_algebraic_leaf(
self, expr: p.AlgebraicLeaf,
) -> frozenset[tuple[Expression, ...]]:
return frozenset()

@override
def map_constant(
self, expr: object
) -> frozenset[tuple[Expression, ...]]:
return frozenset()


def _union_amaps(amaps: Sequence[isl.Map]):
import islpy as isl
return reduce(isl.Map.union, amaps[1:], amaps[0])


def get_insn_access_map(kernel: LoopKernel, insn_id: str, var: str):
from loopy.match import Id
from loopy.symbolic import get_access_map
from loopy.transform.subst import expand_subst

insn = kernel.id_to_insn[insn_id]

kernel = expand_subst(kernel, within=Id(insn_id))
indices = tuple(
_IndexCollector(var)(
(insn.expression, insn.assignees, tuple(insn.predicates))
)
)

amaps = [
get_access_map(
kernel.get_inames_domain(insn.within_inames).to_set(),
idx, kernel.assumptions
)
for idx in indices
]

return _union_amaps(amaps)

# }}}

# vim: foldmethod=marker
Loading
Loading