From 4d9e4584cd27c8f9a283d8c6a1f502dd31a23787 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 11 Nov 2024 17:03:25 -0600 Subject: [PATCH] fix (?) remaining mypy errors, somewhat sketchy in parts --- loopy/statistics.py | 137 ++++++++++++++++++++-------------- loopy/transform/precompute.py | 2 +- 2 files changed, 80 insertions(+), 59 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 9c14891cd..a7355d60a 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -31,7 +31,18 @@ from dataclasses import dataclass, replace from enum import Enum, auto as enum_auto from functools import cached_property, partial -from typing import Any, Callable, Generic, Iterable, Mapping, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Sequence, + Type, + TypeVar, + Union, + cast, +) from immutabledict import immutabledict @@ -39,6 +50,7 @@ import pymbolic.primitives as p from islpy import PwQPolynomial, dim_type from pymbolic.mapper import CombineMapper +from pymbolic.typing import ArithmeticExpressionT from pytools import memoize_method from pytools.tag import Tag @@ -46,8 +58,8 @@ from loopy.diagnostic import LoopyError, warn_with_kernel from loopy.kernel import LoopKernel from loopy.kernel.array import ArrayBase -from loopy.kernel.data import AddressSpace, InameImplementationTag, MultiAssignmentBase -from loopy.kernel.function_interface import CallableKernel, InKernelCallable +from loopy.kernel.data import AddressSpace, MultiAssignmentBase +from loopy.kernel.function_interface import CallableKernel from loopy.kernel.instruction import InstructionBase from loopy.symbolic import ( CoefficientCollector, @@ -56,9 +68,9 @@ TaggedExpression, flatten, ) -from loopy.translation_unit import TranslationUnit +from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit from loopy.types import LoopyType -from loopy.typing import Expression +from loopy.typing import Expression, ExpressionT, auto __doc__ = """ @@ -277,7 +289,7 @@ def __len__(self) -> int: return len(self.count_map) def get(self, - key: Countable, default: Optional[CountT] = None) -> Optional[CountT]: + key: Countable, default: CountT | None = None) -> CountT | None: return self.count_map.get(key, default) def items(self): @@ -290,7 +302,7 @@ def values(self): return self.count_map.values() def copy( - self, count_map: Optional[dict[Countable, CountT]] = None + self, count_map: dict[Countable, CountT] | None = None ) -> ToCountMap[CountT]: if count_map is None: count_map = self.count_map @@ -686,8 +698,8 @@ class Op: .. attribute:: count_granularity - A :class:`str` that specifies whether this operation should be counted - once per *work-item*, *sub-group*, or *work-group*. The granularities + A :class:`CountGranularity` that specifies whether this operation should be + counted once per *work-item*, *sub-group*, or *work-group*. The granularities allowed can be found in :class:`CountGranularity`, and may be accessed, e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance of computation executing on a single processor (think "thread"), a @@ -816,16 +828,16 @@ class MemAccess: A :class:`frozenset` of tags to the operation. """ - address_space: Optional[AddressSpace] = None - dtype: Optional[LoopyType] = None - lid_strides: Optional[Mapping[int, Expression]] = None - gid_strides: Optional[Mapping[int, Expression]] = None - read_write: Optional[AccessDirection] = None - variable: Optional[str] = None + address_space: AddressSpace | Type[auto] | None = None + dtype: LoopyType | None = None + lid_strides: Mapping[int, Expression] | None = None + gid_strides: Mapping[int, Expression] | None = None + read_write: AccessDirection | None = None + variable: str | None = None variable_tags: frozenset[Tag] = frozenset() - count_granularity: Optional[CountGranularity] = None - kernel_name: Optional[str] = None + count_granularity: CountGranularity | None = None + kernel_name: str | None = None tags: frozenset[Tag] = frozenset() def __post_init__(self): @@ -927,8 +939,8 @@ class Sync: A :class:`frozenset` of tags attached to the synchronization. """ - sync_kind: Optional[SynchronizationKind] = None - kernel_name: Optional[str] = None + sync_kind: SynchronizationKind | None = None + kernel_name: str | None = None tags: frozenset[Tag] = frozenset() def __post_init__(self): @@ -1044,7 +1056,7 @@ def map_reduction( % (type(self).__name__, type(expr).__name__)) def __call__( - self, expr, tags: Optional[frozenset[Tag]] = None + self, expr, tags: frozenset[Tag] | None = None ) -> ToCountPolynomialMap: if tags is None: tags = frozenset() @@ -1111,7 +1123,8 @@ def map_sum(self, expr: p.Sum, tags: frozenset[Tag]) -> ToCountPolynomialMap: ) + sum(self.rec(child, tags) for child in expr.children) def map_product( - self, expr: p.Product, tags: frozenset[Tag]) -> ToCountPolynomialMap: + self, expr: p.Product, tags: frozenset[Tag]) \ + -> ToCountMap[GuardedPwQPolynomial]: from pymbolic.primitives import is_zero assert expr.children return sum(self.new_poly_map({Op(dtype=self.type_inf(expr), @@ -1122,7 +1135,7 @@ def map_product( kernel_name=self.knl.name): self.one}) + self.rec(child, tags) for child in expr.children - if not is_zero(child + 1)) + \ + if not is_zero(cast(ArithmeticExpressionT, child) + 1)) + \ self.new_poly_map({Op(dtype=self.type_inf(expr), op_type=OpType.MUL, tags=tags, @@ -1247,8 +1260,8 @@ def map_floor_div(self, expr): # {{{ _get_lid_and_gid_strides def _get_lid_and_gid_strides( - knl: LoopKernel, array: ArrayBase, index: tuple[Expression, ...] - ) -> tuple[Mapping[int, Expression], Mapping[int, Expression]]: + knl: LoopKernel, array: ArrayBase, index: tuple[ExpressionT, ...] + ) -> tuple[Mapping[int, ExpressionT], Mapping[int, ExpressionT]]: # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies my_inames = get_dependencies(index) & knl.all_inames() @@ -1285,18 +1298,20 @@ def _get_lid_and_gid_strides( from loopy.symbolic import simplify_using_aff def get_iname_strides( - tag_to_iname_dict: Mapping[InameImplementationTag, str] - ) -> Mapping[InameImplementationTag, Expression]: + tag_to_iname_dict: Mapping[int, str] + ) -> Mapping[int, Expression]: tag_to_stride_dict = {} + from loopy.kernel.array import ArrayDimImplementationTag + if array.dim_tags is None: assert len(index) <= 1 - dim_tags = (None,) * len(index) + dim_tags: Sequence[ArrayDimImplementationTag | None] = (None,) * len(index) else: dim_tags = array.dim_tags for tag in tag_to_iname_dict: - total_iname_stride = 0 + total_iname_stride: Any = 0 # find total stride of this iname for each axis for idx, axis_tag in zip(index, dim_tags): # collect index coefficients @@ -1305,7 +1320,7 @@ def get_iname_strides( [tag_to_iname_dict[tag]])( simplify_using_aff(knl, idx)) except ExpressionNotAffineError: - total_iname_stride = None + total_iname_stride = 0 break # check if idx contains this iname @@ -1321,7 +1336,7 @@ def get_iname_strides( axis_tag_stride = axis_tag.stride if axis_tag_stride is lp.auto: - total_iname_stride = None + total_iname_stride = 0 break elif axis_tag is None: @@ -1363,7 +1378,7 @@ def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: def count_var_access(self, dtype: LoopyType, name: str, - index: Optional[tuple[Expression, ...]], + index: ExpressionT | None, tags: frozenset[Tag], var_tags: frozenset[Tag] = frozenset() ) -> ToCountPolynomialMap: @@ -1457,10 +1472,12 @@ def map_variable( def map_subscript( self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap: try: - var_tags = expr.aggregate.tags + var_tags = expr.aggregate.tags # type: ignore[union-attr] except AttributeError: var_tags = frozenset() + assert hasattr(expr.aggregate, "name") + return (self.count_var_access(self.type_inf(expr), expr.aggregate.name, expr.index, tags, var_tags) @@ -1713,7 +1730,7 @@ def count_inames_domain( def count_insn_runs( knl: LoopKernel, - callables_table: Mapping[str, InKernelCallable], + callables_table: ConcreteCallablesTable, insn: InstructionBase, count_redundant_work: bool, disregard_local_axes: bool = False) -> GuardedPwQPolynomial: @@ -1738,11 +1755,11 @@ def count_insn_runs( def _get_insn_count( knl: LoopKernel, - callables_table: Mapping[str, InKernelCallable], - insn_id: str, - subgroup_size: Optional[int], + callables_table: ConcreteCallablesTable, + insn_id: str | None, + subgroup_size: int | None, count_redundant_work: bool, - count_granularity: CountGranularity = CountGranularity.WORKITEM + count_granularity: CountGranularity | None = CountGranularity.WORKITEM ) -> GuardedPwQPolynomial: insn = knl.id_to_insn[insn_id] @@ -1813,10 +1830,10 @@ def _get_insn_count( def _get_op_map_for_single_kernel( knl: LoopKernel, - callables_table: Mapping[str, InKernelCallable], + callables_table: ConcreteCallablesTable, count_redundant_work: bool, count_within_subscripts: bool, - subgroup_size: int, within) -> ToCountPolynomialMap: + subgroup_size: int | None, within) -> ToCountMap[GuardedPwQPolynomial]: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -1828,7 +1845,7 @@ def _get_op_map_for_single_kernel( op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec, count_within_subscripts) - op_map = op_counter._new_zero_map() + op_map: ToCountMap[GuardedPwQPolynomial] = op_counter._new_zero_map() from loopy.kernel.instruction import ( Assignment, @@ -1843,6 +1860,7 @@ def _get_op_map_for_single_kernel( if isinstance(insn, (CallInstruction, Assignment)): ops = op_counter(insn.assignees) + op_counter(insn.expression) for key, val in ops.count_map.items(): + key = cast(Op, key) count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, key.count_granularity) @@ -1861,9 +1879,9 @@ def _get_op_map_for_single_kernel( def get_op_map( t_unit: TranslationUnit, *, count_redundant_work: bool = False, count_within_subscripts: bool = True, - subgroup_size: Optional[int] = None, - entrypoint: Optional[str] = None, - within: Any = None): + subgroup_size: int | None = None, + entrypoint: str | None = None, + within: Any = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of operations in a loopy kernel. @@ -1955,7 +1973,7 @@ def get_op_map( # {{{ subgroup size finding -def _find_subgroup_size_for_knl(knl): +def _find_subgroup_size_for_knl(knl: LoopKernel) -> int | None: from loopy.target.pyopencl import PyOpenCLTarget if isinstance(knl.target, PyOpenCLTarget) and knl.target.device is not None: from pyopencl.characterize import get_simd_group_size @@ -2013,9 +2031,9 @@ def _process_subgroup_size(knl, subgroup_size_requested): def _get_mem_access_map_for_single_kernel( knl: LoopKernel, - callables_table: Mapping[str, InKernelCallable], - count_redundant_work: bool, subgroup_size: Optional[int], - within: Any) -> ToCountPolynomialMap: + callables_table: ConcreteCallablesTable, + count_redundant_work: bool, subgroup_size: int | None, + within: Any) -> ToCountMap[GuardedPwQPolynomial]: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -2025,7 +2043,7 @@ def _get_mem_access_map_for_single_kernel( subgroup_size=subgroup_size) access_counter = MemAccessCounter(knl, callables_table, kernel_rec) - access_map = access_counter._new_zero_map() + access_map: ToCountMap[GuardedPwQPolynomial] = access_counter._new_zero_map() from loopy.kernel.instruction import ( Assignment, @@ -2047,6 +2065,7 @@ def _get_mem_access_map_for_single_kernel( ).with_set_attributes(read_write=AccessDirection.WRITE) for key, val in insn_access_map.count_map.items(): + key = cast(MemAccess, key) count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, key.count_granularity) @@ -2065,9 +2084,9 @@ def _get_mem_access_map_for_single_kernel( def get_mem_access_map( t_unit: TranslationUnit, *, count_redundant_work: bool = False, - subgroup_size: Optional[int] = None, - entrypoint: Optional[str] = None, - within: Any = None) -> ToCountPolynomialMap: + subgroup_size: int | None = None, + entrypoint: str | None = None, + within: Any = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of memory accesses in a loopy kernel. :arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be @@ -2184,8 +2203,8 @@ def get_mem_access_map( def _get_synchronization_map_for_single_kernel( knl: LoopKernel, - callables_table: Mapping[str, InKernelCallable], - subgroup_size: Optional[int] = None): + callables_table: ConcreteCallablesTable, + subgroup_size: int | None = None) -> ToCountMap[GuardedPwQPolynomial]: knl = lp.get_one_linearized_kernel(knl, callables_table) @@ -2203,10 +2222,12 @@ def _get_synchronization_map_for_single_kernel( subgroup_size=subgroup_size) sync_counter = CounterBase(knl, callables_table, kernel_rec) - sync_map = sync_counter._new_zero_map() + sync_map: ToCountMap[GuardedPwQPolynomial] = sync_counter._new_zero_map() iname_list = [] + assert knl.linearization is not None + for sched_item in knl.linearization: if isinstance(sched_item, EnterLoop): if sched_item.iname: # (if not empty) @@ -2246,8 +2267,8 @@ def _get_synchronization_map_for_single_kernel( def get_synchronization_map( t_unit: TranslationUnit, *, - subgroup_size: Optional[int] = None, - entrypoint: Optional[str] = None) -> ToCountPolynomialMap: + subgroup_size: int | None = None, + entrypoint: str | None = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of synchronization events each work-item encounters in a loopy kernel. @@ -2337,7 +2358,7 @@ def _gather_access_footprints_for_single_kernel( def gather_access_footprints( t_unit: TranslationUnit, *, ignore_uncountable: bool = False, - entrypoint: Optional[str] = None) -> Mapping[MemAccess, isl.Set]: + entrypoint: str | None = None) -> Mapping[MemAccess, isl.Set]: """Return a dictionary mapping ``(var_name, direction)`` to :class:`islpy.Set` instances capturing which indices of each the array *var_name* are read/written (where *direction* is either ``read`` or @@ -2409,7 +2430,7 @@ def gather_access_footprint_bytes( # FIXME: Only supporting a single kernel for now kernel = t_unit.default_entrypoint - result = {} + result: dict[Countable, GuardedPwQPolynomial] = {} for ma, var_fp in fp.items(): assert ma.variable var_descr = kernel.get_var_descriptor(ma.variable) diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index c2cd0a5ca..7a2c726b8 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -155,7 +155,7 @@ def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]: # {{{ gather rule invocations class RuleInvocationGatherer(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): + def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within) -> None: super().__init__(rule_mapping_context) from loopy.symbolic import SubstitutionRuleExpander