From 46bda513df4fc8c8eb64d2d8ad3e33f2a12a4e97 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 17:21:19 -0500 Subject: [PATCH 01/12] Remove conceptually unsound {set,fset}_intersection --- loopy/typing.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/loopy/typing.py b/loopy/typing.py index 88ec7692e..447794da5 100644 --- a/loopy/typing.py +++ b/loopy/typing.py @@ -132,11 +132,3 @@ def set_union(iterable: Iterable[Iterable[T]]): def fset_union(iterable: Iterable[Iterable[T]]): return cast("frozenset[T]", frozenset()).union(*iterable) - - -def set_intersection(iterable: Iterable[Iterable[T]]): - return cast("set[T]", set()).intersection(*iterable) - - -def fset_intersection(iterable: Iterable[Iterable[T]]): - return cast("frozenset[T]", frozenset()).intersection(*iterable) From e9ae8006bfc3565181ae73dcdc183a171845d372 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 21:19:33 -0500 Subject: [PATCH 02/12] loopy.match: Distinguish StackMatch and ConcreteStackMatch --- loopy/kernel/creation.py | 4 ++-- loopy/match.py | 42 +++++++++++++++++++++++++++--------- loopy/symbolic.py | 38 +++++++++++++++++++++++--------- loopy/transform/iname.py | 9 +++++--- loopy/transform/parameter.py | 25 ++++++++++++++++----- loopy/transform/subst.py | 8 +++++-- 6 files changed, 94 insertions(+), 32 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 7dab5a401..f62ee4b30 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -91,7 +91,7 @@ floord mod ceil floor""".split()) -def _gather_isl_identifiers(s): +def _gather_isl_identifiers(s: str): return set(_IDENTIFIER_RE.findall(s)) - _ISL_KEYWORDS @@ -2461,7 +2461,7 @@ def make_function( # does something. knl = add_inferred_inames(knl) from loopy.transform.parameter import fix_parameters - knl = fix_parameters(knl, **fixed_parameters) + knl = fix_parameters(knl, within=None, **fixed_parameters) # ------------------------------------------------------------------------- # Ordering dependency: diff --git a/loopy/match.py b/loopy/match.py index 8cddb7f4c..2bd1ead1d 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -1,7 +1,15 @@ """ .. autoclass:: Matchable +.. autoclass:: ConcreteMatchable +.. autodata:: RuleStack + :noindex: + +.. class:: RuleStack + + See above. .. autoclass:: StackMatchComponent .. autoclass:: StackMatch +.. autoclass:: ConcreteStackMatch .. autofunction:: parse_match @@ -68,7 +76,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from sys import intern -from typing import TYPE_CHECKING, Protocol, TypeAlias, cast +from typing import TYPE_CHECKING, NamedTuple, Protocol, TypeAlias, cast from typing_extensions import override @@ -81,7 +89,7 @@ if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence import pytools.tag @@ -149,6 +157,11 @@ def re_from_glob(s: str) -> re.Pattern[str]: # {{{ match expression +class ConcreteMatchable(NamedTuple): + id: str + tags: frozenset[pytools.tag.Tag] + + class Matchable(Protocol): """ .. attribute:: tags @@ -162,6 +175,13 @@ def tags(self) -> frozenset[pytools.tag.Tag]: ... +RuleStack: TypeAlias = "Sequence[ConcreteMatchable]" +StackMatch: TypeAlias = """Callable[ + [LoopKernel, InstructionBase, RuleStack], + bool + ]""" + + class MatchExpressionBase(ABC): @abstractmethod def __call__(self, kernel: LoopKernel, matchable: Matchable) -> bool: @@ -561,7 +581,7 @@ def inames(self, kernel): @dataclass(eq=True, frozen=True) -class StackMatch: +class ConcreteStackMatch: """ .. automethod:: __call__ """ @@ -570,7 +590,7 @@ class StackMatch: def __call__( self, kernel: LoopKernel, insn: InstructionBase, - rule_stack: Sequence[tuple[str, frozenset[pytools.tag.Tag]]]) -> bool: + rule_stack: RuleStack) -> bool: """ :arg rule_stack: a tuple of (name, tags) rule invocation, outermost first """ @@ -585,10 +605,12 @@ def __call__( # {{{ stack match parsing -ToStackMatchConvertible: TypeAlias = MatchExpressionBase | StackMatch | str | None +ToStackMatchConvertible: TypeAlias = ( + MatchExpressionBase | ConcreteStackMatch | str | None + ) -def parse_stack_match(smatch: ToStackMatchConvertible) -> StackMatch: +def parse_stack_match(smatch: ToStackMatchConvertible) -> ConcreteStackMatch: """Syntax example:: ... > outer > ... > next > innermost $ @@ -601,15 +623,15 @@ def parse_stack_match(smatch: ToStackMatchConvertible) -> StackMatch: :func:`parse_match`. """ - if isinstance(smatch, StackMatch): + if isinstance(smatch, ConcreteStackMatch): return smatch if isinstance(smatch, MatchExpressionBase): - return StackMatch( + return ConcreteStackMatch( StackItemMatchComponent( smatch, StackAllMatchComponent())) if smatch is None: - return StackMatch(StackAllMatchComponent()) + return ConcreteStackMatch(StackAllMatchComponent()) smatch = smatch.strip() @@ -629,7 +651,7 @@ def parse_stack_match(smatch: ToStackMatchConvertible) -> StackMatch: else: match = StackItemMatchComponent(parse_match(comp), match) - return StackMatch(match) + return ConcreteStackMatch(match) # }}} diff --git a/loopy/symbolic.py b/loopy/symbolic.py index f370e2a97..539c123d2 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -99,6 +99,7 @@ from loopy.kernel.data import KernelArgument, SubstitutionRule, TemporaryVariable from loopy.kernel.instruction import InstructionBase from loopy.library.reduction import ReductionOperation, ReductionOpFunction + from loopy.match import ConcreteMatchable, RuleStack, StackMatch from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible @@ -1147,7 +1148,7 @@ class ExpansionState: """ kernel: LoopKernel instruction: InstructionBase - stack: tuple[tuple[str, Tag], ...] + stack: tuple[ConcreteMatchable, ...] arg_context: Mapping[str, Expression] def __post_init__(self) -> None: @@ -1321,7 +1322,7 @@ def _get_new_substitutions_and_renames(self): return renamed_result, renames - def finish_kernel(self, kernel): + def finish_kernel(self, kernel: LoopKernel): new_substs, renames = self._get_new_substitutions_and_renames() if not renames: return kernel.copy(substitutions=new_substs) @@ -1401,8 +1402,9 @@ def map_subst_rule( rec_arguments = self.rec(arguments, expn_state, *args, **kwargs) assert isinstance(rec_arguments, tuple) + from loopy.match import ConcreteMatchable new_expn_state = expn_state.copy( - stack=(*expn_state.stack, (name, tags)), + stack=(*expn_state.stack, ConcreteMatchable(name, tags)), arg_context=self.make_new_arg_context( name, rule.arguments, rec_arguments, expn_state.arg_context)) @@ -1439,8 +1441,12 @@ def __call__(self, expr, kernel, insn): def map_instruction(self, kernel, insn): return insn - def map_kernel(self, kernel: LoopKernel, within=lambda *args: True, - map_args: bool = True, map_tvs: bool = True) -> LoopKernel: + def map_kernel(self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True + ) -> LoopKernel: new_insns = [ # While subst rules are not allowed in assignees, the mapper # may perform tasks entirely unrelated to subst rules, so @@ -1509,13 +1515,19 @@ class RuleAwareSubstitutionMapper(RuleAwareIdentityMapper[[]]): :meth:`SubstitutionRuleMappingContext.finish_kernel` to perform any renaming mandated by the rule expression divergences. """ - def __init__(self, rule_mapping_context, subst_func, within): + def __init__(self, + rule_mapping_context: SubstitutionRuleMappingContext, + subst_func, + within: StackMatch): super().__init__(rule_mapping_context) self.subst_func = subst_func - self._within = within + self._within: StackMatch = within - def within(self, kernel, instruction, stack): + def within(self, + kernel: LoopKernel, + instruction: InstructionBase | None, + stack: RuleStack): if instruction is None: # always perform substitutions on expressions not coming from # instructions. @@ -1523,6 +1535,7 @@ def within(self, kernel, instruction, stack): else: return self._within(kernel, instruction, stack) + @override def map_variable(self, expr: Variable, expn_state: ExpansionState) -> Expression: if (expr.name in expn_state.arg_context or not self.within( @@ -1540,16 +1553,21 @@ def map_variable(self, expr: Variable, expn_state: ExpansionState) -> Expression class RuleAwareSubstitutionRuleExpander(RuleAwareIdentityMapper[[]]): - def __init__(self, rule_mapping_context, rules, within): + def __init__(self, + rule_mapping_context: SubstitutionRuleMappingContext, + rules, + within: StackMatch): super().__init__(rule_mapping_context) self.rules = rules self.within = within + @override def map_subst_rule( self, name: str, tags, arguments, expn_state: ExpansionState ) -> Expression: - new_stack = (*expn_state.stack, (name, tags)) + from loopy.match import ConcreteMatchable + new_stack = (*expn_state.stack, ConcreteMatchable(name, tags)) if self.within(expn_state.kernel, expn_state.instruction, new_stack): # expand diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 407b2b6da..933b7ec4e 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -50,7 +50,7 @@ if TYPE_CHECKING: from loopy.kernel.data import GroupInameTag, LocalInameTag, ToInameTagConvertible from loopy.kernel.instruction import InstructionBase - from loopy.match import ToMatchConvertible, ToStackMatchConvertible + from loopy.match import RuleStack, ToMatchConvertible, ToStackMatchConvertible from loopy.typing import InameStr @@ -2542,8 +2542,11 @@ def rename_inames( from loopy.kernel.instruction import MultiAssignmentBase - def does_insn_involve_iname(kernel, insn, *args): - return (not isinstance(insn, MultiAssignmentBase) + def does_insn_involve_iname( + kernel: LoopKernel, + insn: InstructionBase, + stack: RuleStack): + return bool(not isinstance(insn, MultiAssignmentBase) or frozenset(old_inames) & insn.dependency_names() or frozenset(old_inames) & insn.reduction_inames()) diff --git a/loopy/transform/parameter.py b/loopy/transform/parameter.py index a29fb4a3f..9f19e5729 100644 --- a/loopy/transform/parameter.py +++ b/loopy/transform/parameter.py @@ -24,6 +24,8 @@ """ +from typing import TYPE_CHECKING + import islpy as isl from loopy.kernel import LoopKernel @@ -31,6 +33,10 @@ from loopy.translation_unit import for_each_kernel +if TYPE_CHECKING: + from loopy.match import ToStackMatchConvertible + + __doc__ = """ .. currentmodule:: loopy @@ -71,8 +77,12 @@ def assume(kernel: LoopKernel, assumptions: str | isl.Set | isl.BasicSet) -> Loo # {{{ fix_parameter -def _fix_parameter(kernel, name, value, within=None): - def process_set(s): +def _fix_parameter( + kernel: LoopKernel, + name: str, + value: int | float, + within: ToStackMatchConvertible = None): + def process_set(s: isl.BasicSet): var_dict = s.get_var_dict() try: @@ -80,6 +90,9 @@ def process_set(s): except KeyError: return s + if not isinstance(value, int): + raise ValueError(f"parameter '{name}' is used in a set, must be an integer") + value_aff = isl.Aff.zero_on_domain(s.space) + value from loopy.isl_helpers import iname_rel_aff @@ -141,7 +154,11 @@ def map_expr(expr): @for_each_kernel -def fix_parameters(kernel, **value_dict): +def fix_parameters( + kernel: LoopKernel, + *, within: ToStackMatchConvertible = None, + **value_dict: int | float, + ): """Fix the values of the arguments to specific constants. *value_dict* consists of *name*/*value* pairs, where *name* will be fixed @@ -150,8 +167,6 @@ def fix_parameters(kernel, **value_dict): """ assert isinstance(kernel, LoopKernel) - within = value_dict.pop("within", None) - for name, value in value_dict.items(): kernel = _fix_parameter(kernel, name, value, within) diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index 26e1ebbaa..4d196de61 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -37,8 +37,9 @@ if TYPE_CHECKING: + from loopy.kernel import LoopKernel from loopy.kernel.instruction import InstructionBase - from loopy.match import ToMatchConvertible + from loopy.match import RuleStack, ToMatchConvertible logger = logging.getLogger(__name__) @@ -413,7 +414,10 @@ def get_relevant_definition_insn_id(usage_insn_id): lhs_name, assigning_insn_ids, usage_to_definition, extra_arguments, within) - def _accesses_lhs(kernel, insn, *args): + def _accesses_lhs( + kernel: LoopKernel, + insn: InstructionBase, + stack: RuleStack): return lhs_name in insn.read_dependency_names() kernel = rule_mapping_context.finish_kernel( From 78138a8efba1646aff8cb99a7d125c28e56627cd Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 17:20:58 -0500 Subject: [PATCH 03/12] Type reader_map, writer_map, iname_to_insns --- loopy/kernel/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 16a5f3538..fb62a79b5 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -72,7 +72,7 @@ ) from loopy.tools import update_persistent_hash from loopy.types import LoopyType, NumpyType -from loopy.typing import PreambleGenerator, SymbolMangler, fset_union, not_none +from loopy.typing import InsnId, PreambleGenerator, SymbolMangler, fset_union, not_none if TYPE_CHECKING: @@ -612,8 +612,8 @@ def insn_inames(self, insn: str | InstructionBase) -> frozenset[InameStr]: return insn.within_inames @memoize_method - def iname_to_insns(self): - result = { + def iname_to_insns(self) -> Mapping[InameStr, Set[InsnId]]: + result: dict[InameStr, set[InsnId]] = { iname: set() for iname in self.all_inames()} for insn in self.instructions: for iname in insn.within_inames: @@ -692,7 +692,7 @@ def compute_deps(insn_id): # {{{ read and written variables @memoize_method - def reader_map(self): + def reader_map(self) -> Mapping[str, Set[InsnId]]: """ :return: a dict that maps variable names to ids of insns that read that variable. @@ -710,7 +710,7 @@ def reader_map(self): return result @memoize_method - def writer_map(self): + def writer_map(self) -> Mapping[str, Set[InsnId]]: """ :return: a dict that maps variable names to ids of insns that write to that variable. From 0dfe034eaf35b271c0448f6a08338f828d157abe Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 31 Oct 2021 16:54:49 -0500 Subject: [PATCH 04/12] Add a utility to get the instruction access map --- loopy/kernel/tools.py | 79 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 2342b9c57..c57e11233 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -23,9 +23,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + + +import dataclasses import itertools import logging import sys +from collections.abc import Set from functools import reduce from sys import intern from typing import ( @@ -37,10 +41,12 @@ ) import numpy as np -from typing_extensions import deprecated +from typing_extensions import deprecated, override import islpy as isl +import pymbolic.primitives as p from islpy import dim_type +from pymbolic import Expression from pytools import memoize_on_first_arg, natsorted from loopy.diagnostic import LoopyError, warn_with_kernel @@ -62,7 +68,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Collection, Mapping, Sequence, Set + from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Set import pymbolic.primitives as p from pymbolic import ArithmeticExpression, Expression @@ -2224,4 +2230,73 @@ def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff: constants_only=False) return lower_bound + +# {{{ get access map from an instruction + +@dataclasses.dataclass +class _IndexCollector(CombineMapper[Set[tuple[Expression, ...]], []]): + var: str + + def __post_init__(self) -> None: + super().__init__() + + @override + def combine(self, + values: Iterable[Set[tuple[Expression, ...]]] + ) -> Set[tuple[Expression, ...]]: + import operator + from functools import reduce + return reduce(operator.or_, values, set()) + + @override + def map_subscript(self, expr: p.Subscript) -> Set[tuple[Expression, ...]]: + assert isinstance(expr.aggregate, p.Variable) + if expr.aggregate.name == self.var: + return (super().map_subscript(expr) | frozenset([expr.index_tuple])) + else: + return super().map_subscript(expr) + + @override + def map_algebraic_leaf( + self, expr: p.AlgebraicLeaf, + ) -> frozenset[tuple[Expression, ...]]: + return frozenset() + + @override + def map_constant( + self, expr: object + ) -> frozenset[tuple[Expression, ...]]: + return frozenset() + + +def _union_amaps(amaps): + import islpy as isl + return reduce(isl.Map.union, amaps[1:], amaps[0]) + + +def get_insn_access_map(kernel: LoopKernel, insn_id: str, var: str): + from loopy.match import Id + from loopy.symbolic import get_access_map + from loopy.transform.subst import expand_subst + + insn = kernel.id_to_insn[insn_id] + + kernel = expand_subst(kernel, within=Id(insn_id)) + indices = tuple( + _IndexCollector(var)( + (insn.expression, insn.assignees, tuple(insn.predicates)) + ) + ) + + amaps = [ + get_access_map( + kernel.get_inames_domain(insn.within_inames), idx, kernel.assumptions + ) + for idx in indices + ] + + return _union_amaps(amaps) + +# }}} + # vim: foldmethod=marker From 637a8f17c4bf63a1812a5fd79e3749f9569aa7ce Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 17:49:25 -0500 Subject: [PATCH 05/12] Fix a few more misc type issues --- loopy/check.py | 12 ++++++------ loopy/kernel/tools.py | 31 +++++++++++++++---------------- loopy/schedule/__init__.py | 2 +- loopy/schedule/tools.py | 12 +++--------- loopy/symbolic.py | 4 ++-- 5 files changed, 27 insertions(+), 34 deletions(-) diff --git a/loopy/check.py b/loopy/check.py index 3acf8f08a..17dd13270 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -72,7 +72,7 @@ check_each_kernel, ) from loopy.type_inference import TypeReader -from loopy.typing import auto, not_none +from loopy.typing import auto, not_none, set_union if TYPE_CHECKING: @@ -1107,10 +1107,10 @@ def _check_variable_access_ordered_inner(kernel: LoopKernel) -> None: address_space = _get_address_space(kernel, var) eq_class = aliasing_equiv_classes[var] - readers = set.union( - *[rmap.get(eq_name, set()) for eq_name in eq_class]) - writers = set.union( - *[wmap.get(eq_name, set()) for eq_name in eq_class]) + readers = set_union( + rmap.get(eq_name, set()) for eq_name in eq_class) + writers = set_union( + wmap.get(eq_name, set()) for eq_name in eq_class) for writer in writers: required_deps = (readers | writers) - {writer} @@ -1676,7 +1676,7 @@ def _get_sub_array_ref_swept_range( return get_access_map( domain.to_set(), sar.swept_inames, - kernel.assumptions.to_set()).range() + kernel.assumptions).range() def _are_sub_array_refs_equivalent( diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index c57e11233..937fcc81e 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -65,13 +65,13 @@ TUnitOrKernelT, for_each_kernel, ) +from loopy.typing import fset_union, set_union if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Set + from collections.abc import Callable, Collection, Iterable, Mapping, Sequence - import pymbolic.primitives as p - from pymbolic import ArithmeticExpression, Expression + from pymbolic import ArithmeticExpression from pytools.tag import Tag from loopy.types import ToLoopyTypeConvertible @@ -725,7 +725,7 @@ def show_dependency_graph(*args, **kwargs): def is_domain_dependent_on_inames(kernel: LoopKernel, domain_index: int, inames: Set[str]) -> bool: dom = kernel.domains[domain_index] - dom_parameters = set(dom.get_var_names(dim_type.param)) + dom_parameters = set(dom.get_var_names_not_none(dim_type.param)) # {{{ check for parenthood by loop bound iname @@ -2038,7 +2038,7 @@ def find_aliasing_equivalence_classes(kernel): # {{{ direction helper tools -def infer_args_are_input_output(kernel): +def infer_args_are_input_output(kernel: LoopKernel): """ Returns a copy of *kernel* with the attributes ``is_input`` and ``is_output`` of the arguments set. @@ -2094,22 +2094,22 @@ def infer_args_are_input_output(kernel): # {{{ CallablesIDCollector -class CallablesIDCollector(CombineMapper): +class CallablesIDCollector(CombineMapper[frozenset[CallableId], []]): """ Mapper to collect function identifiers of all resolved callables in an expression. """ - def combine(self, values): - import operator - return reduce(operator.or_, values, frozenset()) + @override + def combine(self, values: Iterable[frozenset[CallableId]]): + return fset_union(values) def map_resolved_function(self, expr): return frozenset([expr.name]) - def map_constant(self, expr): + def map_constant(self, expr: object): return frozenset() - def map_kernel(self, kernel): + def map_kernel(self, kernel: LoopKernel) -> frozenset[CallableId]: callables_in_insn = frozenset() for insn in kernel.instructions: @@ -2244,9 +2244,7 @@ def __post_init__(self) -> None: def combine(self, values: Iterable[Set[tuple[Expression, ...]]] ) -> Set[tuple[Expression, ...]]: - import operator - from functools import reduce - return reduce(operator.or_, values, set()) + return set_union(values) @override def map_subscript(self, expr: p.Subscript) -> Set[tuple[Expression, ...]]: @@ -2269,7 +2267,7 @@ def map_constant( return frozenset() -def _union_amaps(amaps): +def _union_amaps(amaps: Sequence[isl.Map]): import islpy as isl return reduce(isl.Map.union, amaps[1:], amaps[0]) @@ -2290,7 +2288,8 @@ def get_insn_access_map(kernel: LoopKernel, insn_id: str, var: str): amaps = [ get_access_map( - kernel.get_inames_domain(insn.within_inames), idx, kernel.assumptions + kernel.get_inames_domain(insn.within_inames).to_set(), + idx, kernel.assumptions ) for idx in indices ] diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index 4bb8b879f..87970e04c 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -1490,7 +1490,7 @@ def insn_sort_key(insn_id: InsnId): iname_home_domain = kernel.domains[kernel.get_home_domain_index(iname)] from islpy import dim_type iname_home_domain_params = set( - iname_home_domain.get_var_names(dim_type.param)) + iname_home_domain.get_var_names_not_none(dim_type.param)) # The previous check should have ensured this is true, because # the loop_nest_around_map takes the domain dependency graph into diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py index 5b4420be1..b51f1b43d 100644 --- a/loopy/schedule/tools.py +++ b/loopy/schedule/tools.py @@ -979,7 +979,7 @@ def get_partial_loop_nest_tree(kernel: LoopKernel) -> LoopNestTree: for insn in kernel.instructions} root: InameStrSet = frozenset() - tree = Tree.from_root(root) + tree = Tree[InameStrSet].from_root(root) # mapping from iname to the innermost loop nest they are part of in *tree*. iname_to_tree_node_id: dict[InameStr, InameStrSet] = {} @@ -1085,14 +1085,8 @@ def get_loop_tree(kernel: LoopKernel) -> LoopTree: # {{{ impose constraints by the domain tree - # FIXME: These three could be one statement if it weren't for - # - https://github.com/python/mypy/issues/17693 - # - https://github.com/python/mypy/issues/17694 - emptyset: InameStrSet = frozenset() - loop_inames = reduce(frozenset.union, - (insn.within_inames - for insn in kernel.instructions), - emptyset) + loop_inames = fset_union( + insn.within_inames for insn in kernel.instructions) loop_inames = loop_inames - _get_parallel_inames(kernel) for dom in kernel.domains: diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 539c123d2..f0f43bda6 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -2564,7 +2564,7 @@ def map_tagged_variable(self, expr): def get_access_map( domain: isl.Set, subscript: tuple[Expression, ...], - assumptions: isl.Set | None = None, + assumptions: isl.BasicSet | None = None, shape: ShapeType | None = None, allowed_constant_names: Collection[str] | None = None ) -> isl.Map: @@ -2730,7 +2730,7 @@ def map_subscript(self, expr: p.Subscript, inames: Set[str]) -> None: try: access_map = get_access_map( - domain.to_set(), subscript, self.kernel.assumptions.to_set(), + domain.to_set(), subscript, self.kernel.assumptions, shape=cast("ShapeType | None", descriptor.shape) if self._overestimate else None, allowed_constant_names=self.kernel.get_unwritten_value_args()) From 066c26c339be703c2d5734ea547916473c7248db Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 17:46:40 -0500 Subject: [PATCH 06/12] Type DisjointSets, find_aliasing_equivalence_classes --- loopy/kernel/tools.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 937fcc81e..ef46cec2a 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -24,7 +24,6 @@ THE SOFTWARE. """ - import dataclasses import itertools import logging @@ -35,6 +34,7 @@ from typing import ( TYPE_CHECKING, Concatenate, + Generic, ParamSpec, TypeVar, cast, @@ -81,6 +81,9 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") + + # {{{ add and infer argument dtypes def add_dtypes( @@ -1958,7 +1961,7 @@ def get_subkernel_extra_inames(kernel: LoopKernel) -> Mapping[str, frozenset[str # {{{ find aliasing equivalence classes -class DisjointSets: +class DisjointSets(Generic[T]): """ .. automethod:: __getitem__ .. automethod:: find_leader_or_create_group @@ -1969,10 +1972,10 @@ class DisjointSets: # https://en.wikipedia.org/wiki/Disjoint-set_data_structure def __init__(self): - self.leader_to_group = {} - self.element_to_leader = {} + self.leader_to_group: dict[T, set[T]] = {} + self.element_to_leader: dict[T, T] = {} - def __getitem__(self, item): + def __getitem__(self, item: T): """ :arg item: A representative of an equivalence class. :returns: the equivalence class, given as a set of elements @@ -1984,7 +1987,7 @@ def __getitem__(self, item): else: return self.leader_to_group[leader] - def find_leader_or_create_group(self, el): + def find_leader_or_create_group(self, el: T): try: return self.element_to_leader[el] except KeyError: @@ -1994,7 +1997,7 @@ def find_leader_or_create_group(self, el): self.leader_to_group[el] = {el} return el - def union(self, a, b): + def union(self, a: T, b: T): leader_a = self.find_leader_or_create_group(a) leader_b = self.find_leader_or_create_group(b) @@ -2009,7 +2012,7 @@ def union(self, a, b): self.leader_to_group[leader_a].update(self.leader_to_group[leader_b]) del self.leader_to_group[leader_b] - def union_many(self, relation): + def union_many(self, relation: Iterable[tuple[T, T]]): """ :arg relation: an iterable of 2-tuples enumerating the elements of the relation. The relation is assumed to be an equivalence relation @@ -2027,8 +2030,8 @@ def union_many(self, relation): return self -def find_aliasing_equivalence_classes(kernel): - return DisjointSets().union_many( +def find_aliasing_equivalence_classes(kernel: LoopKernel): + return DisjointSets[str]().union_many( (tv.base_storage, tv.name) for tv in kernel.temporary_variables.values() if tv.base_storage is not None) From 27e157ff95675e3ede853dd2d897746909562768 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 31 Oct 2021 16:55:30 -0500 Subject: [PATCH 07/12] Implement loop fusion transformation --- doc/ref_transform.rst | 6 + loopy/__init__.py | 6 + loopy/schedule/tools.py | 2 +- loopy/transform/loop_fusion.py | 959 +++++++++++++++++++++++++++++++++ 4 files changed, 972 insertions(+), 1 deletion(-) create mode 100644 loopy/transform/loop_fusion.py diff --git a/doc/ref_transform.rst b/doc/ref_transform.rst index 9ef012d66..3c209db9e 100644 --- a/doc/ref_transform.rst +++ b/doc/ref_transform.rst @@ -143,4 +143,10 @@ TODO: Matching instruction tags .. automodule:: loopy.match + +Fusing Loops +------------ + +.. automodule:: loopy.transform.loop_fusion + .. vim: tw=75:spell diff --git a/loopy/__init__.py b/loopy/__init__.py index 799253956..de50eb2d3 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -187,6 +187,10 @@ simplify_indices, tag_instructions, ) +from loopy.transform.loop_fusion import ( + get_kennedy_unweighted_fusion_candidates, + rename_inames_in_batch, +) from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call from loopy.transform.padding import ( add_padding, @@ -336,6 +340,7 @@ "get_dot_dependency_graph", "get_global_barrier_order", "get_iname_duplication_options", + "get_kennedy_unweighted_fusion_candidates", "get_mem_access_map", "get_one_linearized_kernel", "get_one_scheduled_kernel", @@ -382,6 +387,7 @@ "rename_callable", "rename_iname", "rename_inames", + "rename_inames_in_batch", "replace_instruction_ids", "save_and_reload_temporaries", "set_argument_order", diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py index b51f1b43d..0c1c1e607 100644 --- a/loopy/schedule/tools.py +++ b/loopy/schedule/tools.py @@ -1049,7 +1049,7 @@ def get_partial_loop_nest_tree(kernel: LoopKernel) -> LoopNestTree: def _get_iname_to_tree_node_id_from_partial_loop_nest_tree( tree: LoopNestTree, - ) -> Mapping[str, frozenset[str]]: + ) -> Mapping[InameStr, frozenset[str]]: """ Returns the mapping from the iname to the *tree*'s node that it was a part of. diff --git a/loopy/transform/loop_fusion.py b/loopy/transform/loop_fusion.py new file mode 100644 index 000000000..85aad27f3 --- /dev/null +++ b/loopy/transform/loop_fusion.py @@ -0,0 +1,959 @@ +# pyright: reportAny=warning + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2021-25 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast, final + +from constantdict import constantdict +from typing_extensions import override + +from pytools import memoize_on_first_arg + +from loopy.diagnostic import LoopyError +from loopy.kernel import LoopKernel +from loopy.symbolic import ( + ExpansionState, + Reduction, + RuleAwareIdentityMapper, + SubstitutionRuleMappingContext, +) +from loopy.typing import fset_union, not_none + + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Iterable, Mapping, Set + + from loopy.kernel.instruction import InstructionBase + from loopy.match import RuleStack + from loopy.schedule.tools import LoopNestTree + from loopy.typing import InameStr, InameStrSet + + +__doc__ = """ +.. autofunction:: rename_inames_in_batch +.. autofunction:: get_kennedy_unweighted_fusion_candidates +""" + +_EMPTY_INAME_SET: InameStrSet = frozenset() + + +# {{{ Loop Dependence graph class + builder + +@dataclass(frozen=True, eq=True) +class LoopDependenceGraph: + """ + .. attribute:: successors + + A mapping from iname (``i``) to the collection of inames that can be + scheduled only after the loop corresponding to ``i`` has been exited. + + .. attribute:: predecessors + + A mapping from iname (``i``) to the collection of inames that must have + been exited before entering ``i``. + + .. attribute:: is_infusible + + A mapping from the edges in the loop dependence graph to their + fusibility criterion. An edge in this mapping is represented by a pair + of inames``(iname_i, iname_j)`` such that the edge ``iname_i -> + iname_j`` is present in the graph. + + .. note:: + + Both :attr:`successors` and :attr:`predecessors` are maintained to + reduce the complexity of graph primitive operations (like remove node, + add edge, etc.). + """ + + successors: Mapping[InameStr, Set[InameStr]] + predecessors: Mapping[InameStr, frozenset[InameStr]] + is_infusible: Mapping[tuple[InameStr, InameStr], bool] + + @classmethod + def new(cls, + successors: Mapping[InameStr, Set[InameStr]], + is_infusible: Mapping[tuple[InameStr, InameStr], bool] + ): + predecessors = {node: cast("set[InameStr]", set()) for node in successors} + for node, succs in successors.items(): + for succ in succs: + predecessors[succ].add(node) + + predecessors = { + node: frozenset(preds) for node, preds in predecessors.items() + } + successors = {node: frozenset(succs) for node, succs in successors.items()} + + return LoopDependenceGraph(successors, predecessors, is_infusible) + + def is_empty(self): + """ + Returns *True* only if the loop dependence graph contains no nodes. + """ + return len(self.successors) == 0 + + def get_loops_with_no_predecessors(self): + return { + loop for loop, preds in self.predecessors.items() if len(preds) == 0 + } + + def remove_nodes(self, nodes_to_remove: Set[InameStr]): + """ + Returns a copy of *self* after removing *nodes_to_remove* in the graph. + This routine adds necessary edges after removing *nodes_to_remove* to + conserve the scheduling constraints present in the graph. + """ + # {{{ Step 1. Remove the nodes + + new_successors = { + node: succs + for node, succs in self.successors.items() + if node not in nodes_to_remove + } + new_predecessors = { + node: preds + for node, preds in self.predecessors.items() + if node not in nodes_to_remove + } + + new_is_infusible = { + (from_, to): v + for (from_, to), v in self.is_infusible.items() + if (from_ not in nodes_to_remove and to not in nodes_to_remove) + } + + # }}} + + # {{{ Step 2. Propagate dependencies + + # For every Node 'R' to be removed and every pair (S, P) such that + # 1. there exists an edge 'P' -> 'R' in the original graph, and, + # 2. there exits an edge 'R' -> 'S' in the original graph. + # add the edge 'P' -> 'S' in the new graph. + + for node_to_remove in nodes_to_remove: + for succ in self.successors[node_to_remove] - nodes_to_remove: + new_predecessors[succ] = new_predecessors[succ] - frozenset( + [node_to_remove] + ) + + for pred in self.predecessors[node_to_remove] - nodes_to_remove: + new_successors[pred] = new_successors[pred] - frozenset( + [node_to_remove] + ) + + # }}} + + return LoopDependenceGraph( + new_successors, new_predecessors, new_is_infusible + ) + + +@dataclass +class LoopDependenceGraphBuilder: + """ + A mutable type to act as a helper to instantiate a + :class:`LoopDependenceGraphBuilder`. + """ + + _dag: dict[InameStr, set[InameStr]] + _is_infusible: dict[tuple[InameStr, InameStr], bool] + + @classmethod + def new(cls, candidates: Iterable[InameStr]): + return LoopDependenceGraphBuilder( + {iname: set() for iname in candidates}, {} + ) + + def add_edge(self, from_: str, to: str, is_infusible: bool): + self._dag[from_].add(to) + self._is_infusible[from_, to] = is_infusible or self._is_infusible.get( + (from_, to), False + ) + + def done(self): + """ + Returns the built :class:`LoopDependenceGraph`. + """ + return LoopDependenceGraph.new(self._dag, self._is_infusible) + + +# }}} + + +# {{{ _build_ldg + + +@dataclass(frozen=True, eq=True, repr=True) +class PreLDGNode: + """ + A node in the graph representing the dependencies before building + :class:`LoopDependenceGraph`. + """ + + +@dataclass(frozen=True, eq=True, repr=True) +class CandidateLoop(PreLDGNode): + iname: str + + +@dataclass(frozen=True, eq=True, repr=True) +class NonCandidateLoop(PreLDGNode): + loop_nest: frozenset[str] + + +@dataclass(frozen=True, eq=True, repr=True) +class OuterLoopNestStatement(PreLDGNode): + insn_id: str + + +def _remove_non_candidate_pre_ldg_nodes( + predecessors: Mapping[PreLDGNode, frozenset[PreLDGNode]], + successors: Mapping[PreLDGNode, frozenset[PreLDGNode]], +) -> tuple[ + dict[str, frozenset[str]], + dict[str, frozenset[str]], + set[tuple[str, str]], +]: + """ + Returns ``(new_successors, new_predecessors, infusible_edge)`` where + ``(new_successors, new_predecessors)`` is the graph describing the + dependencies between the *candidates* loops that has been obtained by + removing instances of :class:`NonCandidateLoop` and + :class:`OuterLoopNestStatement` from the graph described by *predecessors*, + *successors*. + + New dependency edges are added in the new graph to preserve the transitive + dependencies that exists in the original graph. + """ + # {{{ input validation + + assert set(predecessors) == set(successors) + assert all(isinstance(val, frozenset) for val in predecessors.values()) + assert all(isinstance(val, frozenset) for val in successors.values()) + + # }}} + + nodes_to_remove = { + node + for node in predecessors + if isinstance(node, (NonCandidateLoop, OuterLoopNestStatement)) + } + new_predecessors = dict(predecessors) + new_successors = dict(successors) + infusible_edges_in_statement_dag: set[tuple[str, str]] = set() + + for node_to_remove in nodes_to_remove: + for pred in new_predecessors[node_to_remove]: + new_successors[pred] = ( + new_successors[pred] - frozenset([node_to_remove]) + ) | new_successors[node_to_remove] + + for succ in new_successors[node_to_remove]: + new_predecessors[succ] = ( + new_predecessors[succ] - frozenset([node_to_remove]) + ) | new_predecessors[node_to_remove] + + for pred in new_predecessors[node_to_remove]: + for succ in new_successors[node_to_remove]: + # now mark the edge from pred -> succ infusible iff both 'pred' and + # 'succ' are *not* in insns_to_remove + if (pred not in nodes_to_remove) and (succ not in nodes_to_remove): + assert isinstance(pred, CandidateLoop) + assert isinstance(succ, CandidateLoop) + infusible_edges_in_statement_dag.add((pred.iname, succ.iname)) + + del new_predecessors[node_to_remove] + del new_successors[node_to_remove] + + assert all(isinstance(pred, CandidateLoop) for pred in new_predecessors) + assert all(isinstance(succ, CandidateLoop) for succ in new_successors) + assert all( + all(isinstance(v, CandidateLoop) for v in vals) + for vals in new_predecessors.values() + ) + assert all( + all(isinstance(v, CandidateLoop) for v in vals) + for vals in new_successors.values() + ) + new_predecessors = cast( + "Mapping[CandidateLoop, frozenset[CandidateLoop]]", + new_predecessors) + new_successors = cast( + "Mapping[CandidateLoop, frozenset[CandidateLoop]]", + new_successors) + + return ( + { + key.iname: frozenset({n.iname for n in value}) # type: ignore[attr-defined] + for key, value in new_predecessors.items() + }, + { + key.iname: frozenset({n.iname for n in value}) # type: ignore[attr-defined] + for key, value in new_successors.items() + }, + infusible_edges_in_statement_dag, + ) + + +@memoize_on_first_arg +def _get_ldg_nodes_from_loopy_insn( + insn: InstructionBase, + candidates: frozenset[str], + non_candidates: frozenset[frozenset[str]], + just_outer_loop_nest: frozenset[str], +) -> tuple[PreLDGNode, ...]: + """ + Helper used in :func:`_build_ldg`. + + :arg just_outer_loop_nest: A :class:`frozenset` of the loop nest that appears + just outer to the *candidates* in the partial loop nest tree. + """ + if (insn.within_inames | insn.reduction_inames()) & candidates: + # => the statement containing + return tuple( + CandidateLoop(candidate) + for candidate in ( + (insn.within_inames | insn.reduction_inames()) & candidates + ) + ) + elif { + loop_nest + for loop_nest in non_candidates + if (loop_nest & insn.within_inames) + }: + (non_candidate,) = { + loop_nest + for loop_nest in non_candidates + if (loop_nest & insn.within_inames) + } + + return (NonCandidateLoop(frozenset(non_candidate)),) + else: + assert (insn.within_inames & just_outer_loop_nest) or ( + insn.within_inames == just_outer_loop_nest + ) + assert insn.id is not None + return (OuterLoopNestStatement(insn.id),) + + +@memoize_on_first_arg +def _compute_isinfusible_via_access_map( + kernel: LoopKernel, + insn_pred: str, + candidate_pred: str, + insn_succ: str, + candidate_succ: str, + outer_inames: frozenset[str], + var: str, +) -> bool: + """ + Returns *True* if the inames *candidate_pred* and *candidate_succ* are fused then + that might lead to a loop carried dependency for *var*. + + Helper used in :func:`_build_ldg`. + """ + import islpy as isl + import pymbolic.primitives as prim + + from loopy.diagnostic import UnableToDetermineAccessRangeError + from loopy.kernel.tools import get_insn_access_map + from loopy.symbolic import isl_set_from_expr + + try: + amap_pred = get_insn_access_map(kernel, insn_pred, var) + amap_succ = get_insn_access_map(kernel, insn_succ, var) + except UnableToDetermineAccessRangeError: + # either predecessors or successors has a non-affine access i.e. + # fallback to the safer option => infusible + return True + + amap_pred = amap_pred.project_out_except( + outer_inames | {candidate_pred}, [isl.dim_type.param, isl.dim_type.in_] + ) + amap_succ = amap_succ.project_out_except( + outer_inames | {candidate_succ}, [isl.dim_type.param, isl.dim_type.in_] + ) + + for outer_iname in sorted(outer_inames): + amap_pred = amap_pred.move_dims( + dst_type=isl.dim_type.param, + dst_pos=amap_pred.dim(isl.dim_type.param), + src_type=isl.dim_type.in_, + src_pos=amap_pred.get_var_dict()[outer_iname][1], + n=1, + ) + amap_succ = amap_succ.move_dims( + dst_type=isl.dim_type.param, + dst_pos=amap_succ.dim(isl.dim_type.param), + src_type=isl.dim_type.in_, + src_pos=amap_succ.get_var_dict()[outer_iname][1], + n=1, + ) + + # since both ranges denote the same variable they must be subscripted with + # the same number of indices. + assert amap_pred.dim(isl.dim_type.out) == amap_succ.dim(isl.dim_type.out) + assert amap_pred.dim(isl.dim_type.in_) == 1 + assert amap_succ.dim(isl.dim_type.in_) == 1 + + if amap_pred == amap_succ: + return False + + ndim = amap_pred.dim(isl.dim_type.out) + + # {{{ set the out dim names as `amap_a_dim0`, `amap_a_dim1`, ... + + for idim in range(ndim): + amap_pred = amap_pred.set_dim_name( + isl.dim_type.out, idim, f"_lpy_amap_a_dim{idim}" + ) + amap_succ = amap_succ.set_dim_name( + isl.dim_type.out, idim, f"_lpy_amap_b_dim{idim}" + ) + + # }}} + + # {{{ amap_pred -> set_pred, amap_succ -> set_succ + + amap_pred = amap_pred.move_dims( + isl.dim_type.in_, + amap_pred.dim(isl.dim_type.in_), + isl.dim_type.out, + 0, + amap_pred.dim(isl.dim_type.out), + ) + + amap_succ = amap_succ.move_dims( + isl.dim_type.in_, + amap_succ.dim(isl.dim_type.in_), + isl.dim_type.out, + 0, + amap_succ.dim(isl.dim_type.out), + ) + set_pred, set_succ = amap_pred.domain(), amap_succ.domain() + set_pred, set_succ = isl.align_two(set_pred, set_succ) + + # }}} + + # {{{ build the bset, both accesses access the same element + + accesses_same_index_set = isl.BasicSet.universe(set_pred.space) + for idim in range(ndim): + cnstrnt = isl.Constraint.eq_from_names( + set_pred.space, + {f"_lpy_amap_a_dim{idim}": 1, f"_lpy_amap_b_dim{idim}": -1}, + ) + accesses_same_index_set = accesses_same_index_set.add_constraint(cnstrnt) + + # }}} + + candidates_not_equal = isl_set_from_expr( + set_pred.space, + prim.Comparison( + prim.Variable(candidate_pred), ">", prim.Variable(candidate_succ) + ), + ) + result = not ( + set_pred & set_succ & accesses_same_index_set & candidates_not_equal + ).is_empty() + + return result + + +def _build_ldg( + kernel: LoopKernel, candidates: frozenset[str], outer_inames: frozenset[str] +): + """ + Returns an instance of :class:`LoopDependenceGraph` needed while fusing + *candidates*. Invoked as a helper function in + :func:`get_kennedy_unweighted_fusion_candidates`. + """ + + from pytools.graph import compute_topological_order + + loop_nest_tree = _get_partial_loop_nest_tree_for_fusion(kernel) + + non_candidate_loop_nests = frozenset( + { + child_loop_nest + for child_loop_nest in loop_nest_tree.children(outer_inames) + if len(child_loop_nest & candidates) == 0 + } + ) + + insns = frozenset(kernel.id_to_insn).intersection(* + (frozenset(kernel.iname_to_insns()[iname]) for iname in outer_inames)) + predecessors: dict[PreLDGNode, set[PreLDGNode]] = {} + successors: dict[PreLDGNode, set[PreLDGNode]] = {} + + for insn in insns: + for successor in _get_ldg_nodes_from_loopy_insn( + kernel.id_to_insn[insn], + candidates, + non_candidate_loop_nests, + outer_inames, + ): + predecessors.setdefault(successor, set()) + successors.setdefault(successor, set()) + for dep in kernel.id_to_insn[insn].depends_on: + if ( + kernel.id_to_insn[dep].within_inames & outer_inames + ) != outer_inames: + # this is not an instruction in 'outer_inames' => bogus dep. + continue + for predecessor in _get_ldg_nodes_from_loopy_insn( + kernel.id_to_insn[dep], + candidates, + non_candidate_loop_nests, + outer_inames, + ): + if predecessor != successor: + predecessors.setdefault(successor, set()).add(predecessor) + successors.setdefault(predecessor, set()).add(successor) + + _, ldg_successors, infusible_edges = _remove_non_candidate_pre_ldg_nodes( + {key: frozenset(value) for key, value in predecessors.items()}, + {key: frozenset(value) for key, value in successors.items()}, + ) + del predecessors, successors + + builder = LoopDependenceGraphBuilder.new(candidates) + + # Interpret the statement DAG as LDG + for pred, succs in ldg_successors.items(): + for succ in succs: + builder.add_edge(pred, succ, (pred, succ) in infusible_edges) + + # {{{ add infusible edges to the LDG depending on memory deps. + + all_candidate_insns: frozenset[str] = fset_union( + kernel.iname_to_insns()[iname] for iname in candidates) + + dep_inducing_vars: frozenset[str] = fset_union( + frozenset(kernel.id_to_insn[insn].assignee_var_names()) + for insn in all_candidate_insns + ) + wmap = kernel.writer_map() + rmap = kernel.reader_map() + + topo_order = { + el: i for i, el in enumerate(compute_topological_order(ldg_successors)) + } + + for var in dep_inducing_vars: + for writer_id in wmap.get(var, frozenset()) & all_candidate_insns: + for access_id in ( + rmap.get(var, frozenset()) | wmap.get(var, frozenset()) + ) & all_candidate_insns: + if writer_id == access_id: + # no need to add self dependence + continue + + (writer_candidate,) = ( + kernel.id_to_insn[writer_id].within_inames & candidates + ) + (access_candidate,) = ( + kernel.id_to_insn[access_id].within_inames & candidates + ) + (pred_candidate, pred), (succ_candidate, succ) = sorted( + [(writer_candidate, writer_id), (access_candidate, access_id)], + key=lambda x: topo_order[x[0]], + ) + + is_infusible = _compute_isinfusible_via_access_map( + kernel, + pred, + pred_candidate, + succ, + succ_candidate, + outer_inames, + var, + ) + + builder.add_edge(pred_candidate, succ_candidate, is_infusible) + + # }}} + + return builder.done() + + +# }}} + + +def _fuse_sequential_loops_with_outer_loops( + kernel: LoopKernel, + candidates: frozenset[str], + outer_inames: frozenset[str], + name_gen: Callable[[str], str], + prefix: str, +): + from collections import deque + + ldg = _build_ldg(kernel, candidates, outer_inames) + + fused_chunks: dict[str, frozenset[str]] = {} + + while not ldg.is_empty(): + + # sorting to have a deterministic order. + # prefer 'deque' over list, as popping elements off the queue would be + # O(1). + + loops_with_no_preds = sorted(ldg.get_loops_with_no_predecessors()) + + queue = deque([loops_with_no_preds[0]]) + for node in loops_with_no_preds[1:]: + queue.append(node) + + loops_to_be_fused: set[str] = set() + non_fusible_loops: set[str] = set() + while queue: + next_loop_in_queue = queue.popleft() + + if next_loop_in_queue in non_fusible_loops: + # had an non-fusible edge with an already schedule loop. + # Sorry 'next_loop_in_queue', until next time :'(. + continue + + if next_loop_in_queue in loops_to_be_fused: + # already fused, no need to fuse again ;) + continue + + if not (ldg.predecessors[next_loop_in_queue] <= loops_to_be_fused): + # this loop still needs some other loops to be scheduled + # before we can reach this. + # Bye bye 'next_loop_in_queue' :'( , see you when all your + # predecessors have been scheduled. + continue + + loops_to_be_fused.add(next_loop_in_queue) + + for succ in ldg.successors[next_loop_in_queue]: + if ldg.is_infusible.get((next_loop_in_queue, succ), False): + non_fusible_loops.add(succ) + else: + queue.append(succ) + + ldg = ldg.remove_nodes(loops_to_be_fused) + fused_chunks[name_gen(prefix)] = frozenset(loops_to_be_fused) + + assert fset_union(fused_chunks.values()) == candidates + assert sum(len(val) for val in fused_chunks.values()) == len(candidates) + + return fused_chunks + + +@final +class ReductionLoopInserter(RuleAwareIdentityMapper[[]]): + """ + Main mapper used by :func:`_add_reduction_loops_in_partial_loop_nest_tree`. + """ + + iname_to_tree_node_id: constantdict[InameStr, frozenset[str]] + + def __init__(self, + rule_mapping_context: SubstitutionRuleMappingContext, + tree: LoopNestTree + ): + super().__init__(rule_mapping_context) + self.tree = tree + from loopy.schedule.tools import ( + _get_iname_to_tree_node_id_from_partial_loop_nest_tree, + ) + + self.iname_to_tree_node_id = constantdict( + _get_iname_to_tree_node_id_from_partial_loop_nest_tree(tree) + ) + + @override + def map_reduction(self, + expr: Reduction, + expn_state: ExpansionState, + outer_redn_inames: frozenset[InameStr] = _EMPTY_INAME_SET + ): + redn_inames = frozenset(expr.inames) + iname_chain = ( + expn_state.instruction.within_inames | outer_redn_inames | redn_inames + ) + not_seen_inames = frozenset( + iname + for iname in iname_chain + if iname not in self.iname_to_tree_node_id + ) + seen_inames = iname_chain - not_seen_inames + + # {{{ verbatim copied from loopy/schedule/tools.py + + from loopy.schedule.tools import _add_inner_loops, separate_loop_nest + + all_nests = {self.iname_to_tree_node_id[iname] for iname in seen_inames} + + self.tree, outer_loop, inner_loop = separate_loop_nest( + self.tree, (all_nests | {frozenset()}), seen_inames + ) + if not_seen_inames: + # make '_not_seen_inames' nest inside the seen ones. + # example: if there is already a loop nesting "i,j,k" + # and the current iname chain is "i,j,l". Only way this is possible + # is if "l" is nested within "i,j"-loops. + self.tree = _add_inner_loops(self.tree, outer_loop, not_seen_inames) + + # {{{ update iname to node id + + for iname in outer_loop: + self.iname_to_tree_node_id = self.iname_to_tree_node_id.set( + iname, outer_loop + ) + + if inner_loop is not None: + for iname in inner_loop: + self.iname_to_tree_node_id = self.iname_to_tree_node_id.set( + iname, inner_loop + ) + + for iname in not_seen_inames: + self.iname_to_tree_node_id = self.iname_to_tree_node_id.set( + iname, not_seen_inames + ) + + # }}} + + # }}} + + assert not (outer_redn_inames & redn_inames) + return super().map_reduction( + expr, expn_state, outer_redn_inames=(outer_redn_inames | redn_inames) + ) + + +def _add_reduction_loops_in_partial_loop_nest_tree( + kernel: LoopKernel, tree: LoopNestTree +) -> LoopNestTree: + """ + Returns a partial loop nest tree with the loop nests corresponding to the + reduction inames added to *tree*. + """ + from loopy.symbolic import SubstitutionRuleMappingContext + + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) + reduction_loop_inserter = ReductionLoopInserter(rule_mapping_context, tree) + + def does_insn_have_reduce( + kernel: LoopKernel, + insn: InstructionBase, + stack: RuleStack): + return bool(insn.reduction_inames()) + + reduction_loop_inserter.map_kernel( + kernel, within=does_insn_have_reduce, map_tvs=False, map_args=False + ) + return reduction_loop_inserter.tree + + +@memoize_on_first_arg +def _get_partial_loop_nest_tree_for_fusion(kernel: LoopKernel): + from loopy.schedule.tools import get_partial_loop_nest_tree + + tree = get_partial_loop_nest_tree(kernel) + tree = _add_reduction_loops_in_partial_loop_nest_tree(kernel, tree) + return tree + + +def get_kennedy_unweighted_fusion_candidates( + kernel: LoopKernel, + candidates: frozenset[InameStr], + *, prefix: str = "ifused" + ) -> Mapping[InameStr, Collection[InameStr]]: + """ + Returns the fusion candidates mapping that could be fed to + :func:`rename_inames_in_batch` similar to Ken Kennedy's Unweighted + Loop-Fusion Algorithm. + + .. attribute:: prefix + + Prefix for the fused inames. + + .. note:: + + An error is raised if there exists a pair ``(i, j)`` such that both + ``i`` and ``j`` are present in *candidates*, and the schedluing + constraints in *kernel* enforces ``j`` to be nested within ``i``. + """ + from collections.abc import Collection + + from loopy.kernel.data import ConcurrentTag + from loopy.schedule.tools import ( + _get_iname_to_tree_node_id_from_partial_loop_nest_tree, + ) + + assert not isinstance(candidates, str) + assert isinstance(candidates, Collection) + assert isinstance(kernel, LoopKernel) + + candidates = frozenset(candidates) + vng = kernel.get_var_name_generator() + + # {{{ implementation scope / sanity checks + + # All of the candidates must be either "pure" reduction loops or + # pure-within_inames loops. + # Reason: otherwise _compute_isinfusible_via_access_map might result in + # spurious results. + # One option is to simply perform 'realize_reduction' before implementing + # this algorithm, but that seems like an unnecessary cost to pay. + if any(candidates & insn.reduction_inames() for insn in kernel.instructions): + if any(candidates & insn.within_inames for insn in kernel.instructions): + raise NotImplementedError( + "Some candidates are reduction" + " inames while some of them are not. Such" + " cases are not yet supported." + ) + + if not (candidates <= kernel.all_inames()): + raise ValueError( + "Loop fusion candidates are not part of kernel's inames." + f" Spurious candidates: {candidates - kernel.all_inames()}" + ) + + # }}} + + fused_chunks: dict[InameStr, frozenset[InameStr]] = {} + + # {{{ handle concurrent inames + + # filter out concurrent loops. + all_concurrent_tags: frozenset[ConcurrentTag] = fset_union( + kernel.inames[iname].tags_of_type(ConcurrentTag) for iname in candidates) + + concurrent_tag_to_inames: dict[ConcurrentTag, set[str]] = { + tag: set() for tag in all_concurrent_tags + } + + for iname in candidates: + if kernel.inames[iname].tags_of_type(ConcurrentTag): + # since ConcurrentTag is a UniqueTag there must be exactly one of + # it. + (tag,) = kernel.tags_of_type(ConcurrentTag) + concurrent_tag_to_inames[tag].add(iname) + + for inames in concurrent_tag_to_inames.values(): + fused_chunks[vng(prefix)] = frozenset(inames) + candidates = candidates - inames + + # }}} + + tree = _get_partial_loop_nest_tree_for_fusion(kernel) + + iname_to_tree_node_id = _get_iname_to_tree_node_id_from_partial_loop_nest_tree( + tree + ) + + # {{{ sanitary checks + + nest_tree_id_to_candidate: dict[InameStrSet, InameStr] = {} + + for iname in candidates: + loop_nest_tree_node_id = iname_to_tree_node_id[iname] + if loop_nest_tree_node_id not in nest_tree_id_to_candidate: + nest_tree_id_to_candidate[loop_nest_tree_node_id] = iname + else: + conflict_iname = nest_tree_id_to_candidate[loop_nest_tree_node_id] + raise LoopyError( + f"'{iname}' and '{conflict_iname}' " + "cannot fused be fused as they can be nested " + "within one another." + ) + + for iname in candidates: + outer_loops: frozenset[str] = fset_union( + tree.ancestors(iname_to_tree_node_id[iname])) + if outer_loops & candidates: + raise LoopyError( + f"Cannot fuse '{iname}' with" + f" '{outer_loops & candidates}' as they" + " maybe nesting within one another." + ) + + del nest_tree_id_to_candidate + + # }}} + + # just_outer_loop_nest: mapping from loop nest to the candidates they contain + just_outer_loop_nest: dict[InameStrSet, set[str]] = {} + + for iname in candidates: + assert tree.parent(iname_to_tree_node_id[iname]) is not None + just_outer_loop_nest.setdefault( + not_none(tree.parent(iname_to_tree_node_id[iname])), set() + ).add(iname) + + for outer_inames, inames in just_outer_loop_nest.items(): + fused_chunks.update( + _fuse_sequential_loops_with_outer_loops( + kernel, + frozenset(inames), + frozenset(outer_inames), + vng, + prefix, + ) + ) + + return fused_chunks + + +def rename_inames_in_batch( + kernel: LoopKernel, + batches: Mapping[InameStr, Collection[InameStr]] + ): + """ + Returns a copy of *kernel* with inames renamed according to *batches*. + + :arg kernel: An instance of :class:`loopy.LoopKernel`. + :arg batches: A mapping from ``new_iname`` to a collection of + inames that are to be renamed to ``new_iname``. + """ + from loopy.transform.iname import remove_unused_inames, rename_inames + + for new_iname, candidates in batches.items(): + kernel = cast("LoopKernel", rename_inames( + kernel, candidates, new_iname, + + # type-ignore because remove_newly_unused_inames is added by a + # decorator in a non-type-able way. + remove_newly_unused_inames=False # pyright: ignore[reportCallIssue] + )) + + return remove_unused_inames(kernel, fset_union(batches.values()), + ) + + +# vim: foldmethod=marker From 5a7738fe3793bb185af4ccb40c5b68a319cf054a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 31 Oct 2021 16:56:41 -0500 Subject: [PATCH 08/12] Test loop fusion implementation --- pyproject.toml | 7 +- test/test_loop_fusion.py | 493 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 497 insertions(+), 3 deletions(-) create mode 100644 test/test_loop_fusion.py diff --git a/pyproject.toml b/pyproject.toml index 6340946c2..fe2b676ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,10 +214,11 @@ exclude = [ [[tool.basedpyright.executionEnvironments]] root = "test" -reportUnknownArgumentType = "hint" -reportUnknownVariableType = "hint" +reportUnknownArgumentType = "none" +reportUnknownVariableType = "none" +reportUnknownMemberType = "hint" +reportUnknownParameterType = "none" reportMissingParameterType = "none" reportAttributeAccessIssue = "none" reportMissingImports = "none" reportArgumentType = "hint" -reportUnknownMemberType = "hint" diff --git a/test/test_loop_fusion.py b/test/test_loop_fusion.py new file mode 100644 index 000000000..100abaefc --- /dev/null +++ b/test/test_loop_fusion.py @@ -0,0 +1,493 @@ +__copyright__ = "Copyright (C) 2021-25 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import logging +import sys + +import numpy as np + +import pyopencl as cl + +import loopy as lp + + +logger = logging.getLogger(__name__) + +import faulthandler + +from pyopencl.tools import ( + pytest_generate_tests_for_pyopencl as pytest_generate_tests, +) + +from loopy.version import ( + LOOPY_USE_LANGUAGE_VERSION_2018_2, # noqa # pyright: ignore[reportUnusedImport] +) + + +__all__ = ["cl", "pytest_generate_tests"] # "cl.create_some_context" + + +faulthandler.enable() + + +def test_loop_fusion_vanilla(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i0, i1, j0, j1]: 0 <= i0, i1, j0, j1 < 10}", + """ + a[i0] = 1 + b[i1, j0] = 2 {id=write_b} + c[j1] = 3 {id=write_c} + """, + ) + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["j0", "j1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + assert len(ref_knl["loopy_kernel"].all_inames()) == 4 + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert ( + len( + knl["loopy_kernel"].id_to_insn["write_b"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames + ) + == 1 + ) + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_outer_iname_preventing_fusion(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i0, j0, j1]: 0 <= i0, j0, j1 < 10}", + """ + a[i0] = 1 + b[i0, j0] = 2 {id=write_b} + c[j1] = 3 {id=write_c} + """, + ) + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["j0", "j1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert ( + len( + knl["loopy_kernel"].id_to_insn["write_b"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames + ) + == 0 + ) + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_independent_deps(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[j0, j1]: 0 <= j0, j1 < 10}", + """ + a[j0] = 1 + b[j1] = 2 * a[j1] + """, + seq_dependencies=True, + ) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["j0", "j1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 1 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_constrained_by_outer_loop_deps(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[j0, j1]: 0 <= j0, j1 < 10}", + """ + a[j0] = 1 {id=write_a} + b = 2 {id=write_b} + c[j1] = 2 * a[j1] {id=write_c} + """, + seq_dependencies=True, + ) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["j0", "j1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert ( + len( + knl["loopy_kernel"].id_to_insn["write_a"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames + ) + == 0 + ) + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_carried_deps1(ctx_factory: cl.CtxFactory): + + ctx = ctx_factory() + knl = lp.make_kernel( + "{[i0, i1]: 1<=i0, i1<10}", + """ + x[i0] = i0 {id=first_write} + x[i1-1] = i1 ** 2 {id=second_write} + """, + seq_dependencies=True, + ) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["i0", "i1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 1 + assert ( + len( + knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames + ) + == 1 + ) + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_carried_deps2(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + knl = lp.make_kernel( + "{[i0, i1]: 1<=i0, i1<10}", + """ + x[i0-1] = i0 {id=first_write} + x[i1] = i1 ** 2 {id=second_write} + """, + seq_dependencies=True, + ) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["i0", "i1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert ( + len( + knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames + ) + == 0 + ) + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_indirection(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + rng = np.random.default_rng(42) + map_ = rng.permutation(10) + cq = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{[i0, i1]: 0<=i0, i1<10}", + """ + x[i0] = i0 {id=first_write} + x[map[i1]] = i1 ** 2 {id=second_write} + """, + seq_dependencies=True, + ) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl["loopy_kernel"], frozenset(["i0", "i1"]) + ) + + knl = knl.with_kernel( + lp.rename_inames_in_batch(knl["loopy_kernel"], fused_chunks) + ) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert ( + len( + knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames + ) + == 0 + ) + + _, (out1,) = ref_knl(cq, map=map_) + _, (out2,) = knl(cq, map=map_) + np.testing.assert_allclose(out1, out2) + + +def test_loop_fusion_with_induced_dependencies_from_sibling_nests( + ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + t_unit = lp.make_kernel( + "{[i0, j, i1, i2]: 0<=i0, j, i1, i2<10}", + """ + <> tmp0[i0] = i0 + <> tmp1[j] = tmp0[j] + <> tmp2[j] = j + out1[i1] = tmp2[i1] + out2[i2] = 2 * tmp1[i2] + """, + ) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["i0", "i1"])), + ) + t_unit = t_unit.with_kernel(knl) + + # 'i1', 'i2' should not be fused. If fused that would lead to an + # unshcedulable kernel. Making sure that the kernel 'runs' suffices that + # the transformation was successful. + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + +def test_loop_fusion_on_reduction_inames(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + t_unit = lp.make_kernel( + "{[i, j0, j1, j2]: 0<=i, j0, j1, j2<10}", + """ + y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1])) + y1[i] = sum(j0, sum([j2], 3*A[i, j0, j2])) + """, + [lp.GlobalArg("A", dtype=np.float64, shape=lp.auto), ...], + ) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j1", "j2"])), + ) + assert ( + knl.id_to_insn["insn"].reduction_inames() + == knl.id_to_insn["insn_0"].reduction_inames() + ) + + t_unit = t_unit.with_kernel(knl) + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + +def test_loop_fusion_on_reduction_inames_with_depth_mismatch( + ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + t_unit = lp.make_kernel( + "{[i, j0, j1, j2, j3]: 0<=i, j0, j1, j2, j3<10}", + """ + y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1])) + y1[i] = sum(j2, sum([j3], 3*A[i, j3, j2])) + """, + [lp.GlobalArg("A", dtype=np.float64, shape=lp.auto), ...], + ) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j1", "j3"])), + ) + + # cannot fuse 'j1', 'j3' because they are not nested within the same outer + # inames. + assert ( + knl.id_to_insn["insn"].reduction_inames() + != knl.id_to_insn["insn_0"].reduction_inames() + ) + + t_unit = t_unit.with_kernel(knl) + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + +def test_loop_fusion_on_outer_reduction_inames(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + t_unit = lp.make_kernel( + "{[i, j0, j1, j2, j3]: 0<=i, j0, j1, j2, j3<10}", + """ + y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1])) + y1[i] = sum(j2, sum([j3], 3*A[i, j3, j2])) + """, + [lp.GlobalArg("A", dtype=np.float64, shape=lp.auto), ...], + ) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j0", "j2"])), + ) + + assert ( + len( + knl.id_to_insn["insn"].reduction_inames() + & knl.id_to_insn["insn_0"].reduction_inames() + ) + == 1 + ) + + t_unit = t_unit.with_kernel(knl) + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + +def test_loop_fusion_reduction_inames_simple(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + t_unit = lp.make_kernel( + "{[i, j0, j1]: 0<=i, j0, j1<10}", + """ + y0[i] = sum(j0, 2*A[i, j0]) + y1[i] = sum(j1, 3*A[i, j1]) + """, + [lp.GlobalArg("A", dtype=np.float64, shape=lp.auto), ...], + ) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j0", "j1"])), + ) + + assert ( + knl.id_to_insn["insn"].reduction_inames() + == knl.id_to_insn["insn_0"].reduction_inames() + ) + + t_unit = t_unit.with_kernel(knl) + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + +def test_redn_loop_fusion_with_non_candidates_loops_in_nest(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + t_unit = lp.make_kernel( + "{[i, j1, j2, d]: 0<=i, j1, j2, d<10}", + """ + for i + for d + out1[i, d] = sum(j1, 2 * j1*i) + end + out2[i] = sum(j2, 2 * j2) + end + """, + seq_dependencies=True, + ) + ref_t_unit = t_unit + + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j1", "j2"])), + ) + + assert not ( + knl.id_to_insn["insn"].reduction_inames() + & knl.id_to_insn["insn_0"].reduction_inames() + ) + + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl)) + + +def test_reduction_loop_fusion_with_multiple_redn_in_same_insn( + ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + t_unit = lp.make_kernel( + "{[j1, j2]: 0<=j1, j2<10}", + """ + out = sum(j1, 2*j1) + sum(j2, 2*j2) + """, + seq_dependencies=True, + ) + ref_t_unit = t_unit + + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates(knl, frozenset(["j1", "j2"])), + ) + + assert len(knl.id_to_insn["insn"].reduction_inames()) == 1 + + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl)) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + + main([__file__]) + +# vim: fdm=marker From 4d868a391dc916f551dca1f6efb130a30fd2862a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 21:25:03 -0500 Subject: [PATCH 09/12] function_interface: don't insert kw keys into arg_id_to_descr --- loopy/kernel/function_interface.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index c43cf494a..05ddb1758 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -924,15 +924,16 @@ def with_descrs(self, for arg in subkernel.args: kw = arg.name if isinstance(arg, ArrayBase): - new_arg_id_to_descr[kw] = ( + new_arg_descriptor = ( ArrayArgDescriptor(shape=arg.shape, dim_tags=arg.dim_tags, address_space=arg.address_space)) else: assert isinstance(arg, ValueArg) - new_arg_id_to_descr[kw] = ValueArgDescriptor() + new_arg_descriptor = ValueArgDescriptor() - new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw] + # FIXME: Should decide what the canonical arg identifiers are + new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_descriptor # }}} From e9987770b01e3f0ce197811a42d39ee0aba909a6 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 24 Jun 2025 13:31:06 -0500 Subject: [PATCH 10/12] KernelArgument is Taggable --- loopy/kernel/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index bceea1c4d..f48cbe406 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -395,7 +395,7 @@ def stringify(cls, val: AddressSpace | type[auto]) -> str: # {{{ arguments -class KernelArgument(ImmutableRecord): +class KernelArgument(ImmutableRecord, Taggable): """Base class for all argument types. .. attribute:: name From f859f1489fa073038d8354602a702332164eccef Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 24 Jun 2025 13:31:28 -0500 Subject: [PATCH 11/12] Sort bpr config --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe2b676ed..17f2ad465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,11 +214,11 @@ exclude = [ [[tool.basedpyright.executionEnvironments]] root = "test" +reportArgumentType = "hint" +reportAttributeAccessIssue = "none" +reportMissingImports = "none" +reportMissingParameterType = "none" reportUnknownArgumentType = "none" -reportUnknownVariableType = "none" reportUnknownMemberType = "hint" reportUnknownParameterType = "none" -reportMissingParameterType = "none" -reportAttributeAccessIssue = "none" -reportMissingImports = "none" -reportArgumentType = "hint" +reportUnknownVariableType = "none" From 50bb1c09455b00f5f661268db831afbe7b4a3150 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 17 Jun 2025 21:27:04 -0500 Subject: [PATCH 12/12] Update baseline --- .basedpyright/baseline.json | 4726 ++--------------------------------- 1 file changed, 231 insertions(+), 4495 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 4e056eefd..58f2ec213 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2359,14 +2359,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 6, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -6969,110 +6961,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 17, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 17, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 27, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 17, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 17, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 27, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 67, - "endColumn": 73, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 44, - "endColumn": 61, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -7209,38 +7097,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 32, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 32, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 35, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 44, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -22145,22 +22001,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -22201,14 +22041,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -22225,14 +22057,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -24957,30 +24781,6 @@ } ], "./loopy/kernel/creation.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 28, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 38, - "endColumn": 39, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -31588,14 +31388,6 @@ "endColumn": 22, "lineCount": 1 } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } } ], "./loopy/kernel/data.py": [ @@ -34689,22 +34481,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 17, - "endColumn": 25, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -40197,22 +39973,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 19, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 47, - "endColumn": 58, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -43437,358 +43197,6 @@ "lineCount": 1 } }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 26, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 42, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 20, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 52, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 52, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 46, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 25, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 23, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 26, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 38, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 30, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 22, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 22, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnnecessaryComparison", "range": { @@ -43797,14 +43205,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 31, - "endColumn": 59, - "lineCount": 1 - } - }, { "code": "reportUnnecessaryComparison", "range": { @@ -43813,54 +43213,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 31, - "endColumn": 56, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 41, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 44, - "endColumn": 72, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 73, - "endColumn": 76, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -43869,78 +43221,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportMissingTypeArgument", - "range": { - "startColumn": 27, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 15, - "endColumn": 56, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 36, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -43997,102 +43277,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 47, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -44213,14 +43397,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 33, - "endColumn": 61, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -49059,14 +48235,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 28, - "endColumn": 62, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -49333,14 +48501,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 50, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -54777,22 +53937,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 58, - "endColumn": 62, - "lineCount": 1 - } - }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -54801,14 +53945,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -54816,22 +53952,6 @@ "endColumn": 45, "lineCount": 1 } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } } ], "./loopy/schedule/tree.py": [ @@ -66703,62 +65823,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 28, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 55, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 55, - "endColumn": 74, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -67071,30 +66135,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 60, - "endColumn": 64, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -67106,63 +66146,23 @@ { "code": "reportUnknownParameterType", "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 63, + "startColumn": 16, + "endColumn": 26, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 57, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 25, - "endColumn": 45, + "startColumn": 16, + "endColumn": 26, "lineCount": 1 } }, { - "code": "reportUnannotatedClassAttribute", + "code": "reportUnknownMemberType", "range": { - "startColumn": 13, + "startColumn": 8, "endColumn": 23, "lineCount": 1 } @@ -67171,135 +66171,39 @@ "code": "reportUnannotatedClassAttribute", "range": { "startColumn": 13, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 21, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 29, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 42, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 20, + "endColumn": 23, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 23, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 45, - "endColumn": 50, + "startColumn": 17, + "endColumn": 32, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 52, - "endColumn": 58, + "startColumn": 16, + "endColumn": 21, "lineCount": 1 } }, { "code": "reportMissingParameterType", "range": { - "startColumn": 52, - "endColumn": 58, + "startColumn": 16, + "endColumn": 21, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 25, - "endColumn": 45, + "startColumn": 8, + "endColumn": 18, "lineCount": 1 } }, @@ -67319,14 +66223,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -67359,6 +66255,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 19, + "endColumn": 29, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -90051,22 +88955,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 47, - "endColumn": 64, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -93123,22 +92011,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -94205,22 +93077,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 48, - "endColumn": 62, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -96023,22 +94879,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 17, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 52, - "endColumn": 63, - "lineCount": 2 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -96071,22 +94911,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 17, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 52, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -96583,22 +95407,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 53, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -96681,14 +95489,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -98843,54 +97643,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 27, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 34, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 42, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 48, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 48, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -101117,14 +99869,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -101205,14 +99949,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 53, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -102039,14 +100775,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -103394,18 +102122,10 @@ } }, { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 13, - "endColumn": 47, + "startColumn": 43, + "endColumn": 63, "lineCount": 1 } }, @@ -103601,6 +102321,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 8, + "endColumn": 24, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -103737,14 +102465,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 23, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -104353,22 +103073,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 28, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -105337,22 +104041,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 28, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -106633,22 +105321,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -106977,22 +105649,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -107513,22 +106169,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -109161,14 +107801,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -109177,14 +107809,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -109201,86 +107825,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 40, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 47, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 47, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -110811,22 +109355,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 46, - "endColumn": 71, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -111511,14 +110039,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -111879,14 +110399,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -112289,198 +110801,6 @@ } ], "./loopy/transform/parameter.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 33, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 40, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 20, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 45, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 45, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 54, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 66, - "endColumn": 75, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 31, - "lineCount": 2 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 28, - "lineCount": 4 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 53, - "endColumn": 73, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 47, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 33, - "endColumn": 46, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -112513,22 +110833,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -112537,14 +110841,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 28, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -112561,171 +110857,11 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 22, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 33, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 12, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 34, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 34, - "endColumn": 65, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 38, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 40, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 40, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 29, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { "startColumn": 46, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 53, - "endColumn": 59, + "endColumn": 54, "lineCount": 1 } } @@ -115219,14 +113355,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -126347,94 +124475,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 30, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 37, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 37, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 27, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -126715,22 +124755,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 46, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -127677,22 +125701,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -127917,22 +125925,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 53, - "endColumn": 77, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -128327,22 +126319,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -129335,14 +127311,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 31, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -129699,94 +127667,6 @@ } ], "./test/gnuma_loopy_transforms.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 39, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 48, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -129794,65 +127674,9 @@ "endColumn": 36, "lineCount": 1 } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 30, - "lineCount": 1 - } } ], "./test/library_for_test.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -129869,30 +127693,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -129909,22 +127709,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 43, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -129941,22 +127725,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -129973,22 +127741,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130005,14 +127757,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -130029,22 +127773,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 43, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -130061,22 +127789,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 13, - "endColumn": 16, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -130085,22 +127797,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -130146,133 +127842,13 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 12, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnusedFunction", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, + "startColumn": 12, + "endColumn": 24, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnusedFunction", "range": { "startColumn": 8, "endColumn": 17, @@ -130280,26 +127856,26 @@ } }, { - "code": "reportUnknownParameterType", + "code": "reportUnusedFunction", "range": { - "startColumn": 18, - "endColumn": 21, + "startColumn": 8, + "endColumn": 17, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 8, - "endColumn": 17, + "startColumn": 14, + "endColumn": 29, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 18, - "endColumn": 21, + "startColumn": 14, + "endColumn": 29, "lineCount": 1 } }, @@ -130314,16 +127890,16 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 16, - "endColumn": 28, + "startColumn": 12, + "endColumn": 24, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 10, - "endColumn": 27, + "startColumn": 12, + "endColumn": 24, "lineCount": 1 } }, @@ -130338,11 +127914,27 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 10, + "startColumn": 8, "endColumn": 27, "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 28, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 24, + "lineCount": 1 + } + }, { "code": "reportAny", "range": { @@ -130399,14 +127991,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130415,14 +127999,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130447,14 +128023,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130575,14 +128143,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnusedFunction", "range": { @@ -130591,14 +128151,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130607,14 +128159,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130623,14 +128167,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130663,22 +128199,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 13, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 14, - "endColumn": 15, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130817,6 +128337,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 17, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -130825,6 +128353,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 21, + "endColumn": 25, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -130841,6 +128377,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 17, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -130849,6 +128393,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 21, + "endColumn": 25, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -130937,46 +128489,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 15, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -130993,14 +128505,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -131009,14 +128513,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131059,14 +128555,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 56, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131075,14 +128563,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 63, - "endColumn": 69, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131091,14 +128571,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 63, - "endColumn": 69, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131163,14 +128635,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 77, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131315,22 +128779,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131339,14 +128787,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 55, - "endColumn": 61, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131459,14 +128899,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 58, - "endColumn": 64, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131475,14 +128907,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 74, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131691,14 +129115,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 76, - "endColumn": 82, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131787,14 +129203,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 62, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131819,14 +129227,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 41, - "endColumn": 47, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131843,14 +129243,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 63, - "endColumn": 69, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131859,22 +129251,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 54, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 61, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131971,14 +129347,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -131987,30 +129355,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 70, - "endColumn": 76, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132019,14 +129363,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 61, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -132173,62 +129509,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 36, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132237,14 +129517,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132253,22 +129525,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132277,38 +129533,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132317,22 +129541,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132389,30 +129597,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 25, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132447,14 +129631,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132657,14 +129833,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132803,14 +129971,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 63, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132827,14 +129987,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 61, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132851,14 +130003,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 64, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132875,14 +130019,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 64, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -132899,14 +130035,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 66, - "endColumn": 70, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133045,30 +130173,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 46, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133093,14 +130197,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 11, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -133109,54 +130205,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 56, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnusedVariable", "range": { @@ -133165,46 +130213,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 39, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 51, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnusedVariable", "range": { @@ -133213,46 +130221,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 39, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 59, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 65, - "endColumn": 71, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133277,38 +130245,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 34, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133533,14 +130469,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 53, - "endColumn": 59, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133573,14 +130501,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 61, - "endColumn": 66, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133669,14 +130589,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 70, - "endColumn": 76, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -133927,14 +130839,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134007,14 +130911,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134095,14 +130991,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134399,14 +131287,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 25, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134423,22 +131303,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134703,14 +131567,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134735,14 +131591,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134759,14 +131607,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -134783,14 +131623,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135137,14 +131969,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135481,14 +132305,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135505,14 +132321,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135681,14 +132489,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135705,14 +132505,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135745,14 +132537,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 29, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135833,14 +132617,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135865,14 +132641,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 49, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135881,14 +132649,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135905,14 +132665,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 26, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135945,14 +132697,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -135961,22 +132705,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136009,30 +132737,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 60, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136057,14 +132761,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 60, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136537,14 +133233,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136561,14 +133249,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136625,22 +133305,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136657,22 +133321,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 58, - "endColumn": 65, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136689,14 +133337,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136705,22 +133345,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 54, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 65, - "endColumn": 74, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136737,14 +133361,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136777,14 +133393,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -136969,22 +133577,6 @@ "lineCount": 2 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137001,38 +133593,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 60, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137233,22 +133793,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 61, - "endColumn": 75, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137329,22 +133873,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 17, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137481,38 +134009,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 41, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 52, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 42, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137577,14 +134073,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 55, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137891,14 +134379,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 20, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -137907,14 +134387,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138091,14 +134563,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -138115,22 +134579,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 58, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138195,30 +134643,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 36, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138235,22 +134659,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138293,14 +134701,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 17, - "lineCount": 1 - } - }, { "code": "reportUnusedFunction", "range": { @@ -138309,14 +134709,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 18, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnusedFunction", "range": { @@ -138325,14 +134717,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138349,14 +134733,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138407,30 +134783,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 56, - "endColumn": 68, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 70, - "endColumn": 72, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 74, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138471,14 +134823,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138849,70 +135193,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 20, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138921,14 +135201,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 63, - "endColumn": 67, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -138977,14 +135249,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 66, - "endColumn": 70, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139073,22 +135337,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 55, - "endColumn": 62, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 64, - "endColumn": 69, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -139243,30 +135491,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 53, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 56, - "endColumn": 62, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139283,30 +135507,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 36, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139355,30 +135555,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 49, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 56, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139395,22 +135571,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139419,14 +135579,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139451,22 +135603,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 83, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139483,38 +135619,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 53, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 59, - "endColumn": 64, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139539,30 +135643,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 65, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -139579,14 +135659,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -139610,14 +135682,6 @@ "endColumn": 24, "lineCount": 1 } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } } ], "./test/test_sem_reagan.py": [ @@ -139669,70 +135733,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -139753,121 +135753,113 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 14, - "endColumn": 41, + "endColumn": 27, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 8, + "startColumn": 14, "endColumn": 27, "lineCount": 1 } - } - ], - "./test/test_split_iname_slabs.py": [ - { - "code": "reportUnusedImport", - "range": { - "startColumn": 42, - "endColumn": 63, - "lineCount": 1 - } }, { - "code": "reportUnusedImport", + "code": "reportUnknownMemberType", "range": { - "startColumn": 26, - "endColumn": 59, + "startColumn": 14, + "endColumn": 27, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 24, - "endColumn": 36, + "startColumn": 14, + "endColumn": 27, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 4, - "endColumn": 9, + "startColumn": 14, + "endColumn": 29, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 10, - "endColumn": 17, + "startColumn": 14, + "endColumn": 29, "lineCount": 1 } }, { - "code": "reportAny", + "code": "reportUnknownMemberType", "range": { - "startColumn": 27, - "endColumn": 28, + "startColumn": 14, + "endColumn": 41, "lineCount": 1 } }, { - "code": "reportAny", + "code": "reportUnknownMemberType", "range": { - "startColumn": 30, - "endColumn": 48, + "startColumn": 8, + "endColumn": 27, "lineCount": 1 } - }, + } + ], + "./test/test_split_iname_slabs.py": [ { - "code": "reportUnknownParameterType", + "code": "reportUnusedImport", "range": { - "startColumn": 4, - "endColumn": 14, + "startColumn": 42, + "endColumn": 63, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnusedImport", "range": { - "startColumn": 15, - "endColumn": 22, + "startColumn": 26, + "endColumn": 59, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 26, - "endColumn": 39, + "startColumn": 24, + "endColumn": 36, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportAny", "range": { - "startColumn": 49, - "endColumn": 56, + "startColumn": 27, + "endColumn": 28, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportAny", "range": { - "startColumn": 58, - "endColumn": 63, + "startColumn": 30, + "endColumn": 48, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 65, - "endColumn": 75, + "startColumn": 26, + "endColumn": 39, "lineCount": 1 } }, @@ -141505,22 +137497,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 23, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -141927,34 +137903,10 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 14, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, "endColumn": 29, "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -141979,14 +137931,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -142163,134 +138107,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 49, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 9, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 23, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 11, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 71, - "endColumn": 77, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 9, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 31, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 49, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 57, - "endColumn": 62, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 9, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -142323,6 +138139,62 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 11, + "endColumn": 30, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 11, + "endColumn": 30, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 11, + "endColumn": 30, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 10, + "endColumn": 24, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 9, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 26, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 31, + "endColumn": 38, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -142340,10 +138212,42 @@ } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 27, - "endColumn": 32, + "startColumn": 10, + "endColumn": 24, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 23, + "endColumn": 28, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 43, + "endColumn": 57, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 9, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 9, + "endColumn": 21, "lineCount": 1 } }, @@ -142379,14 +138283,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportMissingTypeStubs", "range": { @@ -142419,14 +138315,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -142523,14 +138411,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 81, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -142741,22 +138621,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -142765,14 +138629,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -142781,14 +138637,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -142797,14 +138645,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 49, - "endColumn": 63, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -142813,22 +138653,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 31, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -143229,14 +139053,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -143285,14 +139101,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -143421,30 +139229,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 47, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -143765,14 +139549,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 27, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -144329,30 +140105,6 @@ } ], "./test/testlib.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -144369,22 +140121,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": {