Skip to content

Commit

Permalink
Type-annotate loopy.match (at least partially)
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Sep 8, 2023
1 parent 67506c8 commit 5d68dc8
Showing 1 changed file with 81 additions and 84 deletions.
165 changes: 81 additions & 84 deletions loopy/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,25 @@
THE SOFTWARE.
"""

from abc import abstractmethod, ABC
from dataclasses import dataclass
from typing import AbstractSet, FrozenSet, List, Sequence, Tuple, Union, Protocol
from sys import intern

from loopy.kernel import LoopKernel
from loopy.kernel.instruction import InstructionBase


NoneType = type(None)

from pytools.lex import RE
import pytools.tag

__doc__ = """
.. autoclass:: Matchable
.. autoclass:: StackMatchComponent
.. autoclass:: StackMatch
.. autofunction:: parse_match
.. autofunction:: parse_stack_match
Expand Down Expand Up @@ -117,8 +127,18 @@ def re_from_glob(s):

# {{{ match expression

class MatchExpressionBase:
def __call__(self, kernel, matchable):
class Matchable(Protocol):
"""
.. attribute:: tags
"""
@property
def tags(self) -> FrozenSet[pytools.tag.Tag]:
...


class MatchExpressionBase(ABC):
@abstractmethod
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
raise NotImplementedError

def __ne__(self, other):
Expand All @@ -135,7 +155,7 @@ def __inv__(self):


class All(MatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
return True

def __str__(self):
Expand All @@ -154,9 +174,9 @@ def __hash__(self):
return hash(type(self))


@dataclass(frozen=True, eq=True)
class MultiChildMatchExpressionBase(MatchExpressionBase):
def __init__(self, children):
self.children = children
children: Sequence[MatchExpressionBase]

def __str__(self):
joiner = " %s " % type(self).__name__.lower()
Expand All @@ -167,33 +187,22 @@ def __repr__(self):
type(self).__name__,
", ".join(repr(ch) for ch in self.children))

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.children)

def __eq__(self, other):
return (type(self) is type(other)
and self.children == other.children)

def __hash__(self):
return hash((type(self), self.children))


class And(MultiChildMatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
return all(ch(kernel, matchable) for ch in self.children)


class Or(MultiChildMatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
return any(ch(kernel, matchable) for ch in self.children)


@dataclass(frozen=True, eq=True)
class Not(MatchExpressionBase):
def __init__(self, child):
self.child = child
child: MatchExpressionBase

def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
return not self.child(kernel, matchable)

def __str__(self):
Expand All @@ -202,18 +211,8 @@ def __str__(self):
def __repr__(self):
return "{}({!r})".format(type(self).__name__, self.child)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "not_match_expr")
key_builder.rec(key_hash, self.child)

def __eq__(self, other):
return (type(self) is type(other)
and self.child == other.child)

def __hash__(self):
return hash((type(self), self.child))


@dataclass(frozen=True, eq=True)
class ObjTagged(MatchExpressionBase):
"""Match if the object is tagged with a given :class:`~pytools.tag.Tag`.
Expand All @@ -222,19 +221,14 @@ class ObjTagged(MatchExpressionBase):
These instance-based tags will, in the not-too-distant future, replace
the string-based tags matched by :class:`Tagged`.
"""
def __init__(self, tag: pytools.tag.Tag):
self.tag = tag
tag: pytools.tag.Tag

def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
return self.tag in matchable.tags

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.tag)


class GlobMatchExpressionBase(MatchExpressionBase):
def __init__(self, glob):
def __init__(self, glob: str) -> None:
self.glob = glob

import re
Expand Down Expand Up @@ -273,7 +267,8 @@ class Tagged(GlobMatchExpressionBase):
These string-based tags will, in the not-too-distant future, be replace
by instance-based tags matched by :class:`ObjTagged`.
"""
def __call__(self, kernel, matchable):

def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
from loopy.kernel.instruction import LegacyStringInstructionTag
if matchable.tags:
return any(
Expand All @@ -289,13 +284,17 @@ def __call__(self, kernel, matchable):


class Writes(GlobMatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
if not isinstance(matchable, InstructionBase):
return False
return any(self.re.match(name)
for name in matchable.assignee_var_names())


class Reads(GlobMatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
if not isinstance(matchable, InstructionBase):
return False
return any(self.re.match(name)
for name in matchable.read_dependency_names())

Expand All @@ -306,7 +305,10 @@ def __call__(self, kernel, matchable):


class Iname(GlobMatchExpressionBase):
def __call__(self, kernel, matchable):
def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool:
if not isinstance(matchable, InstructionBase):
return False

return any(self.re.match(name)
for name in matchable.within_inames)

Expand Down Expand Up @@ -421,13 +423,21 @@ def inner_parse(pstate, min_precedence=0):

# {{{ stack match objects

class StackMatchComponent:
class StackMatchComponent(ABC):
"""
.. automethod:: __call__
"""

@abstractmethod
def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool:
pass

def __ne__(self, other):
return not self.__eq__(other)


class StackAllMatchComponent(StackMatchComponent):
def __call__(self, kernel, stack):
def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool:
return True

def update_persistent_hash(self, key_hash, key_builder):
Expand All @@ -438,7 +448,7 @@ def __eq__(self, other):


class StackBottomMatchComponent(StackMatchComponent):
def __call__(self, kernel, stack):
def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool:
return not stack

def update_persistent_hash(self, key_hash, key_builder):
Expand All @@ -448,12 +458,12 @@ def __eq__(self, other):
return type(self) is type(other)


@dataclass(eq=True, frozen=True)
class StackItemMatchComponent(StackMatchComponent):
def __init__(self, match_expr, inner_match):
self.match_expr = match_expr
self.inner_match = inner_match
match_expr: MatchExpressionBase
inner_match: StackMatchComponent

def __call__(self, kernel, stack):
def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool:
if not stack:
return False

Expand All @@ -463,22 +473,12 @@ def __call__(self, kernel, stack):

return self.inner_match(kernel, stack[1:])

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "item_match")
key_builder.rec(key_hash, self.match_expr)
key_builder.rec(key_hash, self.inner_match)

def __eq__(self, other):
return (type(self) is type(other)
and self.match_expr == other.match_expr
and self.inner_match == other.inner_match)


@dataclass(eq=True, frozen=True)
class StackWildcardMatchComponent(StackMatchComponent):
def __init__(self, inner_match):
self.inner_match = inner_match
inner_match: StackMatchComponent

def __call__(self, kernel, stack):
def __call__(self, kernel: LoopKernel, stack: Sequence[Matchable]) -> bool:
for i in range(0, len(stack)):
if self.inner_match(kernel, stack[i:]):
return True
Expand All @@ -490,10 +490,10 @@ def __call__(self, kernel, stack):

# {{{ stack matcher

@dataclass(eq=True, frozen=True)
class RuleInvocationMatchable:
def __init__(self, id, tags):
self.id = id
self.tags = tags
id: str
tags: FrozenSet[pytools.tag.Tag]

def write_dependency_names(self):
raise TypeError("writes: query may not be applied to rule invocations")
Expand All @@ -505,27 +505,21 @@ def inames(self, kernel):
raise TypeError("inames: query may not be applied to rule invocations")


@dataclass(eq=True, frozen=True)
class StackMatch:
def __init__(self, root_component):
self.root_component = root_component

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.root_component)

def __eq__(self, other):
return (
type(self) is type(other)
and
self.root_component == other.root_component)
"""
.. autofunction:: __call__
"""

def __ne__(self, other):
return not self.__eq__(other)
root_component: StackMatchComponent

def __call__(self, kernel, insn, rule_stack):
def __call__(
self, kernel: LoopKernel, insn: InstructionBase,
rule_stack: Sequence[Tuple[str, FrozenSet[pytools.tag.Tag]]]) -> bool:
"""
:arg rule_stack: a tuple of (name, tags) rule invocation, outermost first
"""
stack_of_matchables = [insn]
stack_of_matchables: List[Matchable] = [insn]
for id, tags in rule_stack:
stack_of_matchables.append(RuleInvocationMatchable(id, tags))

Expand All @@ -536,7 +530,10 @@ def __call__(self, kernel, insn, rule_stack):

# {{{ stack match parsing

def parse_stack_match(smatch):
ToStackMatchCovertible = Union[StackMatch, str, None]


def parse_stack_match(smatch: ToStackMatchCovertible) -> StackMatch:
"""Syntax example::
... > outer > ... > next > innermost $
Expand All @@ -561,7 +558,7 @@ def parse_stack_match(smatch):

smatch = smatch.strip()

match = StackAllMatchComponent()
match: StackMatchComponent = StackAllMatchComponent()
if smatch[-1] == "$":
match = StackBottomMatchComponent()
smatch = smatch[:-1]
Expand Down

0 comments on commit 5d68dc8

Please sign in to comment.