diff --git a/loopy/match.py b/loopy/match.py index 63879c911..ac8d3315b 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -24,8 +24,14 @@ THE SOFTWARE. """ +from abc import abstractmethod, ABC +from dataclasses import dataclass +from typing import AbstractSet, FrozenSet, List, Sequence, Tuple, Union, Protocol from sys import intern +from loopy.kernel import LoopKernel +from loopy.kernel.instruction import InstructionBase + NoneType = type(None) @@ -33,6 +39,10 @@ import pytools.tag __doc__ = """ +.. autoclass:: Matchable +.. autoclass:: StackMatchComponent +.. autoclass:: StackMatch + .. autofunction:: parse_match .. autofunction:: parse_stack_match @@ -117,8 +127,18 @@ def re_from_glob(s): # {{{ match expression -class MatchExpressionBase: - def __call__(self, kernel, matchable): +class Matchable(Protocol): + """ + .. attribute:: tags + """ + @property + def tags(self) -> FrozenSet[pytools.tag.Tag]: + ... + + +class MatchExpressionBase(ABC): + @abstractmethod + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: raise NotImplementedError def __ne__(self, other): @@ -135,7 +155,7 @@ def __inv__(self): class All(MatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: return True def __str__(self): @@ -154,9 +174,9 @@ def __hash__(self): return hash(type(self)) +@dataclass(frozen=True, eq=True) class MultiChildMatchExpressionBase(MatchExpressionBase): - def __init__(self, children): - self.children = children + children: Sequence[MatchExpressionBase] def __str__(self): joiner = " %s " % type(self).__name__.lower() @@ -167,33 +187,22 @@ def __repr__(self): type(self).__name__, ", ".join(repr(ch) for ch in self.children)) - def update_persistent_hash(self, key_hash, key_builder): - key_builder.rec(key_hash, type(self).__name__) - key_builder.rec(key_hash, self.children) - - def __eq__(self, other): - return (type(self) is type(other) - and self.children == other.children) - - def __hash__(self): - return hash((type(self), self.children)) - class And(MultiChildMatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: return all(ch(kernel, matchable) for ch in self.children) class Or(MultiChildMatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: return any(ch(kernel, matchable) for ch in self.children) +@dataclass(frozen=True, eq=True) class Not(MatchExpressionBase): - def __init__(self, child): - self.child = child + child: MatchExpressionBase - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: return not self.child(kernel, matchable) def __str__(self): @@ -202,18 +211,8 @@ def __str__(self): def __repr__(self): return "{}({!r})".format(type(self).__name__, self.child) - def update_persistent_hash(self, key_hash, key_builder): - key_builder.rec(key_hash, "not_match_expr") - key_builder.rec(key_hash, self.child) - - def __eq__(self, other): - return (type(self) is type(other) - and self.child == other.child) - - def __hash__(self): - return hash((type(self), self.child)) - +@dataclass(frozen=True, eq=True) class ObjTagged(MatchExpressionBase): """Match if the object is tagged with a given :class:`~pytools.tag.Tag`. @@ -222,19 +221,14 @@ class ObjTagged(MatchExpressionBase): These instance-based tags will, in the not-too-distant future, replace the string-based tags matched by :class:`Tagged`. """ - def __init__(self, tag: pytools.tag.Tag): - self.tag = tag + tag: pytools.tag.Tag - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: return self.tag in matchable.tags - def update_persistent_hash(self, key_hash, key_builder): - key_builder.rec(key_hash, type(self).__name__) - key_builder.rec(key_hash, self.tag) - class GlobMatchExpressionBase(MatchExpressionBase): - def __init__(self, glob): + def __init__(self, glob: str) -> None: self.glob = glob import re @@ -273,7 +267,8 @@ class Tagged(GlobMatchExpressionBase): These string-based tags will, in the not-too-distant future, be replace by instance-based tags matched by :class:`ObjTagged`. """ - def __call__(self, kernel, matchable): + + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: from loopy.kernel.instruction import LegacyStringInstructionTag if matchable.tags: return any( @@ -289,13 +284,17 @@ def __call__(self, kernel, matchable): class Writes(GlobMatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: + if not isinstance(matchable, InstructionBase): + return False return any(self.re.match(name) for name in matchable.assignee_var_names()) class Reads(GlobMatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: + if not isinstance(matchable, InstructionBase): + return False return any(self.re.match(name) for name in matchable.read_dependency_names()) @@ -306,7 +305,10 @@ def __call__(self, kernel, matchable): class Iname(GlobMatchExpressionBase): - def __call__(self, kernel, matchable): + def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: + if not isinstance(matchable, InstructionBase): + return False + return any(self.re.match(name) for name in matchable.within_inames) @@ -421,13 +423,21 @@ def inner_parse(pstate, min_precedence=0): # {{{ stack match objects -class StackMatchComponent: +class StackMatchComponent(ABC): + """ + .. automethod:: __call__ + """ + + @abstractmethod + def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool: + pass + def __ne__(self, other): return not self.__eq__(other) class StackAllMatchComponent(StackMatchComponent): - def __call__(self, kernel, stack): + def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool: return True def update_persistent_hash(self, key_hash, key_builder): @@ -438,7 +448,7 @@ def __eq__(self, other): class StackBottomMatchComponent(StackMatchComponent): - def __call__(self, kernel, stack): + def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool: return not stack def update_persistent_hash(self, key_hash, key_builder): @@ -448,12 +458,12 @@ def __eq__(self, other): return type(self) is type(other) +@dataclass(eq=True, frozen=True) class StackItemMatchComponent(StackMatchComponent): - def __init__(self, match_expr, inner_match): - self.match_expr = match_expr - self.inner_match = inner_match + match_expr: MatchExpressionBase + inner_match: StackMatchComponent - def __call__(self, kernel, stack): + def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool: if not stack: return False @@ -463,22 +473,12 @@ def __call__(self, kernel, stack): return self.inner_match(kernel, stack[1:]) - def update_persistent_hash(self, key_hash, key_builder): - key_builder.rec(key_hash, "item_match") - key_builder.rec(key_hash, self.match_expr) - key_builder.rec(key_hash, self.inner_match) - - def __eq__(self, other): - return (type(self) is type(other) - and self.match_expr == other.match_expr - and self.inner_match == other.inner_match) - +@dataclass(eq=True, frozen=True) class StackWildcardMatchComponent(StackMatchComponent): - def __init__(self, inner_match): - self.inner_match = inner_match + inner_match: StackMatchComponent - def __call__(self, kernel, stack): + def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool: for i in range(0, len(stack)): if self.inner_match(kernel, stack[i:]): return True @@ -490,10 +490,10 @@ def __call__(self, kernel, stack): # {{{ stack matcher +@dataclass(eq=True, frozen=True) class RuleInvocationMatchable: - def __init__(self, id, tags): - self.id = id - self.tags = tags + id: str + tags: FrozenSet[pytools.tag.Tag] def write_dependency_names(self): raise TypeError("writes: query may not be applied to rule invocations") @@ -505,27 +505,21 @@ def inames(self, kernel): raise TypeError("inames: query may not be applied to rule invocations") +@dataclass(eq=True, frozen=True) class StackMatch: - def __init__(self, root_component): - self.root_component = root_component - - def update_persistent_hash(self, key_hash, key_builder): - key_builder.rec(key_hash, self.root_component) - - def __eq__(self, other): - return ( - type(self) is type(other) - and - self.root_component == other.root_component) + """ + .. autofunction:: __call__ + """ - def __ne__(self, other): - return not self.__eq__(other) + root_component: StackMatchComponent - def __call__(self, kernel, insn, rule_stack): + def __call__( + self, kernel: LoopKernel, insn: InstructionBase, + rule_stack: Sequence[Tuple[str, FrozenSet[pytools.tag.Tag]]]) -> bool: """ :arg rule_stack: a tuple of (name, tags) rule invocation, outermost first """ - stack_of_matchables = [insn] + stack_of_matchables: List[Matchable] = [insn] for id, tags in rule_stack: stack_of_matchables.append(RuleInvocationMatchable(id, tags)) @@ -536,7 +530,10 @@ def __call__(self, kernel, insn, rule_stack): # {{{ stack match parsing -def parse_stack_match(smatch): +ToStackMatchCovertible = Union[StackMatch, str, None] + + +def parse_stack_match(smatch: ToStackMatchCovertible) -> StackMatch: """Syntax example:: ... > outer > ... > next > innermost $ @@ -561,7 +558,7 @@ def parse_stack_match(smatch): smatch = smatch.strip() - match = StackAllMatchComponent() + match: StackMatchComponent = StackAllMatchComponent() if smatch[-1] == "$": match = StackBottomMatchComponent() smatch = smatch[:-1]