diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 4aeb42428..1dd43e7e2 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1574,7 +1574,7 @@ information provided. Now we will count the operations: >>> op_map = lp.get_op_map(knl, subgroup_size=32) >>> print(op_map) - Op(np:dtype('float32'), add, subgroup, "stats_knl"): ... + Op(np:dtype('float32'), OpType.ADD, CountGranularity.SUBGROUP, "stats_knl", frozenset()): ... Each line of output will look roughly like:: @@ -1599,13 +1599,13 @@ One way to evaluate these polynomials is with :meth:`islpy.PwQPolynomial.eval_wi .. doctest:: >>> param_dict = {'n': 256, 'm': 256, 'l': 8} - >>> from loopy.statistics import CountGranularity as CG - >>> f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) - >>> f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) - >>> f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) - >>> f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) - >>> f64mul = op_map[lp.Op(np.float64, 'mul', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) - >>> i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> from loopy.statistics import CountGranularity as CG, OpType, AddressSpace, AccessDirection, SynchronizationKind + >>> f32add = op_map[lp.Op(np.float32, OpType.ADD, CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> f32div = op_map[lp.Op(np.float32, OpType.DIV, CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> f32mul = op_map[lp.Op(np.float32, OpType.MUL, CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> f64add = op_map[lp.Op(np.float64, OpType.ADD, CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> f64mul = op_map[lp.Op(np.float64, OpType.MUL, CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) + >>> i32add = op_map[lp.Op(np.int32, OpType.ADD, 
CG.SUBGROUP, "stats_knl")].eval_with_dict(param_dict) >>> print("%i\n%i\n%i\n%i\n%i\n%i" % ... (f32add, f32div, f32mul, f64add, f64mul, i32add)) 524288 @@ -1636,7 +1636,7 @@ together into keys containing only the specified fields: >>> op_map_dtype = op_map.group_by('dtype') >>> print(op_map_dtype) - Op(np:dtype('float32'), None, None): ... + Op(np:dtype('float32'), None, None, frozenset()): ... >>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32) ... ].eval_with_dict(param_dict) >>> print(f32op_count) @@ -1661,7 +1661,7 @@ we'll continue using the kernel from the previous example: >>> mem_map = lp.get_mem_access_map(knl, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl'): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {}, {}, AccessDirection.READ, a, frozenset(), CountGranularity.SUBGROUP, 'stats_knl', frozenset()): ... Each line of output will look roughly like:: @@ -1703,17 +1703,17 @@ We can evaluate these polynomials using :meth:`islpy.PwQPolynomial.eval_with_dic .. doctest:: - >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, {}, {}, 'load', 'g', - ... variable_tags=None, count_granularity=CG.SUBGROUP, kernel_name="stats_knl") + >>> f64ld_g = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, {}, {}, AccessDirection.READ, 'g', + ... variable_tags=frozenset(), count_granularity=CG.SUBGROUP, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, {}, {}, 'store', 'e', - ... variable_tags=None, count_granularity=CG.SUBGROUP, kernel_name="stats_knl") + >>> f64st_e = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, {}, {}, AccessDirection.WRITE, 'e', + ... variable_tags=frozenset(), count_granularity=CG.SUBGROUP, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, {}, {}, 'load', 'a', - ... 
variable_tags=None, count_granularity=CG.SUBGROUP, kernel_name="stats_knl") + >>> f32ld_a = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {}, {}, AccessDirection.READ, 'a', + ... variable_tags=frozenset(), count_granularity=CG.SUBGROUP, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, {}, {}, 'store', 'c', - ... variable_tags=None, count_granularity=CG.SUBGROUP, kernel_name="stats_knl") + >>> f32st_c = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {}, {}, AccessDirection.WRITE, 'c', + ... variable_tags=frozenset(), count_granularity=CG.SUBGROUP, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" % ... (f32ld_a, f32st_c, f64ld_g, f64st_e)) @@ -1731,15 +1731,15 @@ using :func:`loopy.ToCountMap.to_bytes` and :func:`loopy.ToCountMap.group_by`: >>> bytes_map = mem_map.to_bytes() >>> print(bytes_map) - MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl'): ... - >>> global_ld_st_bytes = bytes_map.filter_by(mtype=['global'] - ... ).group_by('direction') + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {}, {}, AccessDirection.READ, a, frozenset(), CountGranularity.SUBGROUP, 'stats_knl', frozenset()): ... + >>> global_ld_st_bytes = bytes_map.filter_by(address_space=[AddressSpace.GLOBAL] + ... ).group_by('read_write') >>> print(global_ld_st_bytes) - MemAccess(None, None, None, None, load, None, None, None, None): ... - MemAccess(None, None, None, None, store, None, None, None, None): ... - >>> loaded = global_ld_st_bytes[lp.MemAccess(direction='load') + MemAccess(None, None, None, None, AccessDirection.READ, None, frozenset(), None, None, frozenset()): ... + MemAccess(None, None, None, None, AccessDirection.WRITE, None, frozenset(), None, None, frozenset()): ... + >>> loaded = global_ld_st_bytes[lp.MemAccess(read_write=AccessDirection.READ) ... 
].eval_with_dict(param_dict) - >>> stored = global_ld_st_bytes[lp.MemAccess(direction='store') + >>> stored = global_ld_st_bytes[lp.MemAccess(read_write=AccessDirection.WRITE) ... ].eval_with_dict(param_dict) >>> print("bytes loaded: %s\nbytes stored: %s" % (loaded, stored)) bytes loaded: 7340032 @@ -1772,12 +1772,12 @@ this time. ... outer_tag="l.1", inner_tag="l.0") >>> mem_map = lp.get_mem_access_map(knl_consec, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, a, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, b, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, store, c, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, g, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, h, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, store, e, None, workitem, 'stats_knl'): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 1, 1: 128}, {}, AccessDirection.READ, a, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 1, 1: 128}, {}, AccessDirection.READ, b, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 1, 1: 128}, {}, AccessDirection.WRITE, c, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 1, 1: 128}, {}, AccessDirection.READ, g, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 1, 1: 128}, {}, AccessDirection.READ, h, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... 
+ MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 1, 1: 128}, {}, AccessDirection.WRITE, e, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... With this parallelization, consecutive work-items will access consecutive array elements in memory. The polynomials are a bit more complicated now due to the @@ -1785,17 +1785,17 @@ parallelization, but when we evaluate them, we see that the total number of array accesses has not changed: .. doctest:: - >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, {0: 1, 1: 128}, {}, 'load', 'g', + >>> f64ld_g = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float64), {0: 1, 1: 128}, {}, AccessDirection.READ, 'g', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, {0: 1, 1: 128}, {}, 'store', 'e', + >>> f64st_e = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, {0: 1, 1: 128}, {}, AccessDirection.WRITE, 'e', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, {0: 1, 1: 128}, {}, 'load', 'a', + >>> f32ld_a = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {0: 1, 1: 128}, {}, AccessDirection.READ, 'a', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, {0: 1, 1: 128}, {}, 'store', 'c', + >>> f32st_c = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {0: 1, 1: 128}, {}, AccessDirection.WRITE, 'c', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" % @@ -1816,12 +1816,12 @@ we'll switch the inner and outer tags in our parallelization of the kernel: ... 
outer_tag="l.0", inner_tag="l.1") >>> mem_map = lp.get_mem_access_map(knl_nonconsec, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, a, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, b, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, store, c, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, g, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, h, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, store, e, None, workitem, 'stats_knl'): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 128, 1: 1}, {}, AccessDirection.READ, a, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 128, 1: 1}, {}, AccessDirection.READ, b, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float32'), {0: 128, 1: 1}, {}, AccessDirection.WRITE, c, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 128, 1: 1}, {}, AccessDirection.READ, g, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 128, 1: 1}, {}, AccessDirection.READ, h, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... + MemAccess(AddressSpace.GLOBAL, np:dtype('float64'), {0: 128, 1: 1}, {}, AccessDirection.WRITE, e, frozenset(), CountGranularity.WORKITEM, 'stats_knl', frozenset()): ... With this parallelization, consecutive work-items will access *nonconsecutive* array elements in memory. The total number of array accesses still has not @@ -1829,16 +1829,16 @@ changed: .. 
doctest:: - >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, {0: 128, 1: 1}, {}, 'load', 'g', + >>> f64ld_g = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, {0: 128, 1: 1}, {}, AccessDirection.READ, 'g', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, {0: 128, 1: 1}, {}, 'store', 'e', + >>> f64st_e = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, {0: 128, 1: 1}, {}, AccessDirection.WRITE, 'e', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, {0: 128, 1: 1}, {}, 'load', 'a', + >>> f32ld_a = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {0: 128, 1: 1}, {}, AccessDirection.READ, 'a', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) - >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, {0: 128, 1: 1}, {}, 'store', 'c', + >>> f32st_c = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, {0: 128, 1: 1}, {}, AccessDirection.WRITE, 'c', ... variable_tags=None, count_granularity=CG.WORKITEM, kernel_name="stats_knl") ... ].eval_with_dict(param_dict) >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" % @@ -1873,13 +1873,13 @@ kernel from the previous example: >>> sync_map = lp.get_synchronization_map(knl) >>> print(sync_map) - Sync(kernel_launch, stats_knl): [l, m, n] -> { 1 } + Sync(SynchronizationKind.KERNEL_LAUNCH, 'stats_knl', frozenset()): [l, m, n] -> { 1 } We can evaluate this polynomial using :meth:`islpy.PwQPolynomial.eval_with_dict`: .. 
doctest:: - >>> launch_count = sync_map[lp.Sync("kernel_launch", "stats_knl")].eval_with_dict(param_dict) + >>> launch_count = sync_map[lp.Sync(sync_kind=SynchronizationKind.KERNEL_LAUNCH, kernel_name="stats_knl")].eval_with_dict(param_dict) >>> print("Kernel launch count: %s" % launch_count) Kernel launch count: 1 @@ -1930,8 +1930,8 @@ count the barriers using :func:`loopy.get_synchronization_map`: >>> sync_map = lp.get_synchronization_map(knl) >>> print(sync_map) - Sync(barrier_local, loopy_kernel): { 1000 } - Sync(kernel_launch, loopy_kernel): { 1 } + Sync(SynchronizationKind.BARRIER_LOCAL, 'loopy_kernel', frozenset()): { 1000 } + Sync(SynchronizationKind.KERNEL_LAUNCH, 'loopy_kernel', frozenset()): { 1 } Based on the kernel code printed above, we would expect each work-item to encounter 50x10x2 barriers, which matches the result from diff --git a/loopy/__init__.py b/loopy/__init__.py index 07f06a021..ef9868e6f 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -87,10 +87,13 @@ linearize, ) from loopy.statistics import ( + AccessDirection, CountGranularity, MemAccess, Op, + OpType, Sync, + SynchronizationKind, ToCountMap, ToCountPolynomialMap, gather_access_footprint_bytes, @@ -99,7 +102,13 @@ get_op_map, get_synchronization_map, ) -from loopy.symbolic import LinearSubscript, Reduction, TaggedVariable, TypeCast +from loopy.symbolic import ( + LinearSubscript, + Reduction, + TaggedExpression, + TaggedVariable, + TypeCast, +) from loopy.target import ASTBuilderBase, TargetBase from loopy.target.c import ( CFamilyTarget, @@ -214,6 +223,7 @@ "MOST_RECENT_LANGUAGE_VERSION", "VERSION", "ASTBuilderBase", + "AccessDirection", "AddressSpace", "ArrayArg", "Assignment", @@ -257,6 +267,7 @@ "NoOpInstruction", "NumpyType", "Op", + "OpType", "OpenCLTarget", "Optional", "Options", @@ -267,6 +278,8 @@ "ScalarCallable", "SubstitutionRule", "Sync", + "SynchronizationKind", + "TaggedExpression", "TaggedVariable", "TargetBase", "TemporaryVariable", diff --git 
a/loopy/kernel/data.py b/loopy/kernel/data.py index 5d1de0e5d..668e6a07d 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -27,7 +27,7 @@ from dataclasses import dataclass, replace -from enum import IntEnum +from enum import Enum, IntEnum from sys import intern from typing import ( Any, @@ -377,6 +377,10 @@ def stringify(cls, val: Union["AddressSpace", Type[auto]]) -> str: else: raise ValueError("unexpected value of AddressSpace") + # IntEnum.__repr__ only prints the integer values + __repr__ = Enum.__repr__ + __str__ = Enum.__str__ + # }}} diff --git a/loopy/statistics.py b/loopy/statistics.py index 99b163f80..5284dda2a 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2015 James Stevens Copyright (C) 2018 Kaushik Kulkarni @@ -25,19 +28,49 @@ THE SOFTWARE. """ +from dataclasses import dataclass, replace +from enum import Enum, auto as enum_auto from functools import cached_property, partial +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Sequence, + Type, + TypeVar, + Union, + cast, +) + +from immutabledict import immutabledict import islpy as isl -from islpy import dim_type +import pymbolic.primitives as p +from islpy import PwQPolynomial, dim_type from pymbolic.mapper import CombineMapper -from pytools import ImmutableRecord, memoize_method +from pymbolic.typing import ArithmeticExpressionT +from pytools import memoize_method +from pytools.tag import Tag import loopy as lp from loopy.diagnostic import LoopyError, warn_with_kernel -from loopy.kernel.data import AddressSpace, MultiAssignmentBase, TemporaryVariable +from loopy.kernel import LoopKernel +from loopy.kernel.array import ArrayBase +from loopy.kernel.data import AddressSpace, MultiAssignmentBase from loopy.kernel.function_interface import CallableKernel -from loopy.symbolic import CoefficientCollector, flatten -from loopy.translation_unit import TranslationUnit 
+from loopy.kernel.instruction import InstructionBase +from loopy.symbolic import ( + CoefficientCollector, + Reduction, + SubArrayRef, + TaggedExpression, + flatten, +) +from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit +from loopy.types import LoopyType +from loopy.typing import Expression, ExpressionT, auto __doc__ = """ @@ -47,8 +80,12 @@ .. autoclass:: ToCountMap .. autoclass:: ToCountPolynomialMap .. autoclass:: CountGranularity +.. autoclass:: OpType .. autoclass:: Op +.. autoclass:: AccessDirection .. autoclass:: MemAccess +.. autoclass:: SynchronizationKind +.. autoclass:: Sync .. autofunction:: get_op_map .. autofunction:: get_mem_access_map @@ -61,6 +98,10 @@ .. autoclass:: GuardedPwQPolynomial +.. class:: CountT + + An arithmetic type that can be used in :class:`ToCountMap`. + .. currentmodule:: loopy """ @@ -78,28 +119,29 @@ # - Test for the subkernel functionality need to be written -def get_kernel_parameter_space(kernel): +def get_kernel_parameter_space(kernel: LoopKernel) -> isl.Space: return isl.Space.create_from_names(kernel.isl_context, set=[], params=sorted(kernel.outer_params())).params() -def get_kernel_zero_pwqpolynomial(kernel): +def get_kernel_zero_pwqpolynomial(kernel: LoopKernel) -> PwQPolynomial: space = get_kernel_parameter_space(kernel) space = space.insert_dims(dim_type.out, 0, 1) - return isl.PwQPolynomial.zero(space) + return PwQPolynomial.zero(space) # {{{ GuardedPwQPolynomial -def _get_param_tuple(obj): +def _get_param_tuple(obj) -> tuple[str, ...]: return tuple( obj.get_dim_name(dim_type.param, i) for i in range(obj.dim(dim_type.param))) class GuardedPwQPolynomial: - def __init__(self, pwqpolynomial, valid_domain): - assert isinstance(pwqpolynomial, isl.PwQPolynomial) + def __init__(self, + pwqpolynomial: PwQPolynomial, valid_domain: isl.Set) -> None: + assert isinstance(pwqpolynomial, PwQPolynomial) self.pwqpolynomial = pwqpolynomial self.valid_domain = valid_domain @@ -165,7 +207,11 @@ def 
__repr__(self): # {{{ ToCountMap -class ToCountMap: +Countable = Union["Op", "MemAccess", "Sync"] +CountT = TypeVar("CountT", int, GuardedPwQPolynomial) + + +class ToCountMap(Generic[CountT]): """A map from work descriptors like :class:`Op` and :class:`MemAccess` to any arithmetic type. @@ -189,7 +235,9 @@ class ToCountMap: """ - def __init__(self, count_map=None): + count_map: dict[Countable, CountT] + + def __init__(self, count_map: dict[Countable, CountT] | None = None) -> None: if count_map is None: count_map = {} @@ -198,13 +246,13 @@ def __init__(self, count_map=None): def _zero(self): return 0 - def __add__(self, other): + def __add__(self, other: ToCountMap[CountT]) -> ToCountMap[CountT]: result = self.count_map.copy() for k, v in other.count_map.items(): result[k] = self.count_map.get(k, 0) + v return self.copy(count_map=result) - def __radd__(self, other): + def __radd__(self, other: Union[int, ToCountMap[CountT]]) -> ToCountMap[CountT]: if other != 0: raise ValueError("ToCountMap: Attempted to add ToCountMap " "to {} {}. 
ToCountMap may only be added to " @@ -213,7 +261,7 @@ def __radd__(self, other): return self - def __mul__(self, other): + def __mul__(self, other: GuardedPwQPolynomial) -> ToCountMap[CountT]: if isinstance(other, GuardedPwQPolynomial): return self.copy({ index: other*value @@ -225,22 +273,23 @@ def __mul__(self, other): __rmul__ = __mul__ - def __getitem__(self, index): + def __getitem__(self, index: Countable) -> CountT: return self.count_map[index] - def __repr__(self): + def __repr__(self) -> str: return repr(self.count_map) - def __str__(self): + def __str__(self) -> str: return "\n".join( f"{k}: {v}" for k, v in sorted(self.count_map.items(), key=lambda k: str(k))) - def __len__(self): + def __len__(self) -> int: return len(self.count_map) - def get(self, key, default=None): + def get(self, + key: Countable, default: CountT | None = None) -> CountT | None: return self.count_map.get(key, default) def items(self): @@ -252,18 +301,20 @@ def keys(self): def values(self): return self.count_map.values() - def copy(self, count_map=None): + def copy( + self, count_map: dict[Countable, CountT] | None = None + ) -> ToCountMap[CountT]: if count_map is None: count_map = self.count_map return type(self)(count_map=count_map) - def with_set_attributes(self, **kwargs): + def with_set_attributes(self, **kwargs) -> ToCountMap: return self.copy(count_map={ - key.copy(**kwargs): val + replace(key, **kwargs): val for key, val in self.count_map.items()}) - def filter_by(self, **kwargs): + def filter_by(self, **kwargs) -> ToCountMap: """Remove items without specified key fields. :arg kwargs: Keyword arguments matching fields in the keys of the @@ -308,7 +359,8 @@ class _Sentinel: return self.copy(count_map=new_count_map) - def filter_by_func(self, func): + def filter_by_func( + self, func: Callable[[Countable], bool]) -> ToCountMap[CountT]: """Keep items that pass a test. 
:arg func: A function that takes a map key a parameter and returns a @@ -341,7 +393,7 @@ def filter_func(key): return self.copy(count_map=new_count_map) - def group_by(self, *args): + def group_by(self, *args) -> ToCountMap[CountT]: """Group map items together, distinguishing by only the key fields passed in args. @@ -387,7 +439,7 @@ def group_by(self, *args): """ - new_count_map = {} + new_count_map: dict[Countable, CountT] = {} # make sure all item keys have same type if self.count_map: @@ -408,7 +460,7 @@ def group_by(self, *args): return self.copy(count_map=new_count_map) - def to_bytes(self): + def to_bytes(self) -> ToCountMap[CountT]: """Convert counts to bytes using data type in map key. :return: A :class:`ToCountMap` mapping each original key to an @@ -442,11 +494,11 @@ def to_bytes(self): new_count_map = {} for key, val in self.count_map.items(): - new_count_map[key] = int(key.dtype.itemsize) * val + new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr] # noqa: E501 return self.copy(new_count_map) - def sum(self): + def sum(self) -> CountT: """:return: A sum of the values of the dictionary.""" total = self._zero() @@ -461,14 +513,18 @@ def sum(self): # {{{ ToCountPolynomialMap -class ToCountPolynomialMap(ToCountMap): +class ToCountPolynomialMap(ToCountMap[GuardedPwQPolynomial]): """Maps any type of key to a :class:`islpy.PwQPolynomial` or a :class:`~loopy.statistics.GuardedPwQPolynomial`. .. 
automethod:: eval_and_sum """ - def __init__(self, space, count_map=None): + def __init__( + self, + space: isl.Space, + count_map: dict[Countable, GuardedPwQPolynomial] + ) -> None: if not isinstance(space, isl.Space): raise TypeError( "first argument to ToCountPolynomialMap must be " @@ -491,7 +547,7 @@ def __init__(self, space, count_map=None): super().__init__(count_map) - def _zero(self): + def _zero(self) -> isl.PwQPolynomial: space = self.space.insert_dims(dim_type.out, 0, 1) return isl.PwQPolynomial.zero(space) @@ -504,7 +560,7 @@ def copy(self, count_map=None, space=None): return type(self)(space, count_map) - def eval_and_sum(self, params=None): + def eval_and_sum(self, params: Mapping[str, int] | None = None) -> int: """Add all counts and evaluate with provided parameter dict *params* :return: An :class:`int` containing the sum of all counts @@ -550,9 +606,12 @@ def subst_into_guarded_pwqpolynomial(new_space, guarded_poly, subst_dict): return GuardedPwQPolynomial(poly, valid_domain) -def subst_into_to_count_map(space, tcm, subst_dict): +def subst_into_to_count_map( + space: isl.Space, + tcm: ToCountPolynomialMap, + subst_dict: Mapping[str, PwQPolynomial]) -> ToCountPolynomialMap: from loopy.isl_helpers import subst_into_pwqpolynomial - new_count_map = {} + new_count_map: dict[Countable, GuardedPwQPolynomial] = {} for key, value in tcm.count_map.items(): if isinstance(value, GuardedPwQPolynomial): new_count_map[key] = subst_into_guarded_pwqpolynomial( @@ -574,38 +633,58 @@ def subst_into_to_count_map(space, tcm, subst_dict): # {{{ CountGranularity -class CountGranularity: - """Strings specifying whether an operation should be counted once per +class CountGranularity(Enum): + """Specify whether an operation should be counted once per *work-item*, *sub-group*, or *work-group*. .. attribute:: WORKITEM - A :class:`str` that specifies that an operation should be counted - once per *work-item*. 
+ Specifies that an operation should be counted once per *work-item*. .. attribute:: SUBGROUP - A :class:`str` that specifies that an operation should be counted - once per *sub-group*. + Specifies that an operation should be counted once per *sub-group*. .. attribute:: WORKGROUP - A :class:`str` that specifies that an operation should be counted - once per *work-group*. + Specifies that an operation should be counted once per *work-group*. """ - WORKITEM = "workitem" - SUBGROUP = "subgroup" - WORKGROUP = "workgroup" - ALL = [WORKITEM, SUBGROUP, WORKGROUP] + WORKITEM = 0 + SUBGROUP = 1 + WORKGROUP = 2 # }}} # {{{ Op descriptor -class Op(ImmutableRecord): +class OpType(Enum): + """ + Specify the type of an (arithmetic) operation. + + .. attribute:: ADD + .. attribute:: MUL + .. attribute:: DIV + .. attribute:: POW + .. attribute:: SHIFT + .. attribute:: BITWISE + .. attribute:: MAXMIN + .. attribute:: SPECIAL_FUNC + """ + ADD = enum_auto() + MUL = enum_auto() + DIV = enum_auto() + POW = enum_auto() + SHIFT = enum_auto() + BITWISE = enum_auto() + MAXMIN = enum_auto() + SPECIAL_FUNC = enum_auto() + + +@dataclass(frozen=True, eq=True) +class Op: """A descriptor for a type of arithmetic operation. .. attribute:: dtype @@ -613,15 +692,14 @@ class Op(ImmutableRecord): A :class:`loopy.types.LoopyType` or :class:`numpy.dtype` that specifies the data type operated on. - .. attribute:: name + .. attribute:: op_type - A :class:`str` that specifies the kind of arithmetic operation as - *add*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc. + A :class:`OpType`. .. attribute:: count_granularity - A :class:`str` that specifies whether this operation should be counted - once per *work-item*, *sub-group*, or *work-group*. The granularities + A :class:`CountGranularity` that specifies whether this operation should be + counted once per *work-item*, *sub-group*, or *work-group*. 
The granularities allowed can be found in :class:`CountGranularity`, and may be accessed, e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance of computation executing on a single processor (think "thread"), a @@ -635,43 +713,60 @@ class Op(ImmutableRecord): A :class:`str` representing the kernel name where the operation occurred. - """ + .. attribute:: tags - def __init__(self, dtype=None, name=None, count_granularity=None, - kernel_name=None): - if count_granularity not in CountGranularity.ALL+[None]: - raise ValueError("Op.__init__: count_granularity '%s' is " - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + A :class:`frozenset` of tags to the operation. - if dtype is not None: + """ + dtype: LoopyType | None = None + op_type: OpType | None = None + count_granularity: CountGranularity | None = None + kernel_name: str | None = None + tags: frozenset[Tag] = frozenset() + + def __post_init__(self): + if self.dtype is not None: from loopy.types import to_loopy_type - dtype = to_loopy_type(dtype) + object.__setattr__(self, "dtype", to_loopy_type(self.dtype)) - super().__init__(dtype=dtype, name=name, - count_granularity=count_granularity, - kernel_name=kernel_name) + assert isinstance(self.op_type, (OpType, type(None))) + + if not isinstance(self.count_granularity, (CountGranularity, type(None))): + raise ValueError( + f"unexpected count_granularity: '{self.count_granularity}'") def __repr__(self): - # Record.__repr__ overridden for consistent ordering and conciseness if self.kernel_name is not None: - return (f"Op({self.dtype}, {self.name}, {self.count_granularity}," - f' "{self.kernel_name}")') + return (f"Op({self.dtype}, {self.op_type}, {self.count_granularity}," + f' "{self.kernel_name}", {self.tags})') else: - return f"Op({self.dtype}, {self.name}, {self.count_granularity})" + return f"Op({self.dtype}, {self.op_type}, " + \ + f"{self.count_granularity}, {self.tags})" # }}} # {{{ MemAccess 
descriptor -class MemAccess(ImmutableRecord): +class AccessDirection(Enum): + """ + Specify the direction of a memory access. + + .. attribute:: READ + .. attribute:: WRITE + """ + READ = 0 + WRITE = 1 + + +@dataclass(frozen=True, eq=True) +class MemAccess: """A descriptor for a type of memory access. - .. attribute:: mtype + .. attribute:: address_space - A :class:`str` that specifies the memory type accessed as **global** - or **local** + A :class:`AddressSpace` that specifies the memory type accessed as **global** + or **local**. .. attribute:: dtype @@ -696,10 +791,9 @@ class MemAccess(ImmutableRecord): specifies global strides for each global id in the memory access index. global ids not found will not be present in ``gid_strides.keys()``. - .. attribute:: direction + .. attribute:: read_write - A :class:`str` that specifies the direction of memory access as - **load** or **store**. + An :class:`AccessDirection` or *None*. .. attribute:: variable @@ -714,8 +808,8 @@ class MemAccess(ImmutableRecord): .. attribute:: count_granularity - A :class:`str` that specifies whether this operation should be counted - once per *work-item*, *sub-group*, or *work-group*. The granularities + A :class:`CountGranularity` that specifies whether this operation should be + counted once per *work-item*, *sub-group*, or *work-group*. The granularities allowed can be found in :class:`CountGranularity`, and may be accessed, e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance of computation executing on a single processor (think "thread"), a @@ -728,75 +822,135 @@ class MemAccess(ImmutableRecord): .. attribute:: kernel_name A :class:`str` representing the kernel name where the operation occurred. + + .. attribute:: tags + + A :class:`frozenset` of tags to the operation. 
""" - def __init__(self, mtype=None, dtype=None, lid_strides=None, gid_strides=None, - direction=None, variable=None, - *, variable_tags=None, - count_granularity=None, kernel_name=None): + address_space: AddressSpace | Type[auto] | None = None + dtype: LoopyType | None = None + lid_strides: Mapping[int, Expression] | None = None + gid_strides: Mapping[int, Expression] | None = None + read_write: AccessDirection | None = None + variable: str | None = None - if count_granularity not in CountGranularity.ALL+[None]: - raise ValueError("Op.__init__: count_granularity '%s' is " - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + variable_tags: frozenset[Tag] = frozenset() + count_granularity: CountGranularity | None = None + kernel_name: str | None = None + tags: frozenset[Tag] = frozenset() - if variable_tags is None: - variable_tags = frozenset() + def __post_init__(self): + assert isinstance(self.address_space, (AddressSpace, type(None))) - if dtype is not None: + if self.dtype is not None: from loopy.types import to_loopy_type - dtype = to_loopy_type(dtype) + object.__setattr__(self, "dtype", to_loopy_type(self.dtype)) - ImmutableRecord.__init__(self, mtype=mtype, dtype=dtype, - lid_strides=lid_strides, - gid_strides=gid_strides, direction=direction, - variable=variable, variable_tags=variable_tags, - count_granularity=count_granularity, - kernel_name=kernel_name) + if isinstance(self.lid_strides, dict): + object.__setattr__(self, "lid_strides", immutabledict(self.lid_strides)) + + if isinstance(self.gid_strides, dict): + object.__setattr__(self, "gid_strides", immutabledict(self.gid_strides)) + + if self.variable_tags is None: + object.__setattr__(self, "variable_tags", frozenset()) + + if not isinstance(self.count_granularity, (CountGranularity, type(None))): + raise ValueError( + f"unexpected count_granularity: '{self.count_granularity}'") - def __hash__(self): - # dicts in gid_strides and lid_strides aren't 
natively hashable - return hash(repr(self)) + @property + def mtype(self) -> str: + from warnings import warn + warn("MemAccess.mtype is deprecated and will stop working in 2024. " + "Use MemAccess.address_space instead.", + DeprecationWarning, stacklevel=2) + + if self.address_space == AddressSpace.GLOBAL: + return "global" + elif self.address_space == AddressSpace.LOCAL: + return "local" + else: + raise ValueError(f"unexpected address_space: '{self.address_space}'") + + @property + def direction(self) -> str: + from warnings import warn + warn("MemAccess.access_direction is deprecated " + "and will stop working in 2024. " + "Use MemAccess.read_write instead.", + DeprecationWarning, stacklevel=2) + + if self.read_write == AccessDirection.READ: + return "read" + elif self.read_write == AccessDirection.WRITE: + return "write" + else: + raise ValueError(f"unexpected read_write: '{self.read_write}'") def __repr__(self): - # Record.__repr__ overridden for consistent ordering and conciseness - return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {})".format( - self.mtype, + # dataclasses.__repr__ overridden for consistent ordering and conciseness + return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format( + self.address_space, self.dtype, None if self.lid_strides is None else dict( sorted(self.lid_strides.items())), None if self.gid_strides is None else dict( sorted(self.gid_strides.items())), - self.direction, + self.read_write, self.variable, - "None" if not self.variable_tags else str(self.variable_tags), + str(self.variable_tags), self.count_granularity, - repr(self.kernel_name)) + repr(self.kernel_name), + self.tags) # }}} # {{{ Sync descriptor -class Sync(ImmutableRecord): +class SynchronizationKind(Enum): + """Specify the kind of synchronization. + + .. attribute:: BARRIER_GLOBAL + .. attribute:: BARRIER_LOCAL + .. 
attribute:: KERNEL_LAUNCH + """ + + BARRIER_GLOBAL = 0 + BARRIER_LOCAL = 1 + KERNEL_LAUNCH = 2 + + +@dataclass(frozen=True, eq=True) +class Sync: """A descriptor for a type of synchronization. - .. attribute:: kind + .. attribute:: sync_kind - A string describing the synchronization kind, e.g. ``"barrier_global"`` or - ``"barrier_local"`` or ``"kernel_launch"``. + A :class:`SynchronizationKind` or *None*. .. attribute:: kernel_name A :class:`str` representing the kernel name where the operation occurred. + + .. attribute:: tags + + A :class:`frozenset` of tags attached to the synchronization. """ + sync_kind: SynchronizationKind | None = None + kernel_name: str | None = None + tags: frozenset[Tag] = frozenset() - def __init__(self, kind=None, kernel_name=None): - super().__init__(kind=kind, kernel_name=kernel_name) + def __post_init__(self): + if not isinstance(self.sync_kind, (SynchronizationKind, type(None))): + raise ValueError(f"unexpected sync_kind: '{self.sync_kind}'") def __repr__(self): - # Record.__repr__ overridden for consistent ordering and conciseness - return f"Sync({self.kind}, {self.kernel_name})" + # Overridden for conciseness + return "Sync({}, {}, {})".format( + self.sync_kind, repr(self.kernel_name), self.tags) # }}} @@ -804,7 +958,7 @@ def __repr__(self): # {{{ CounterBase class CounterBase(CombineMapper): - def __init__(self, knl, callables_table, kernel_rec): + def __init__(self, knl: LoopKernel, callables_table, kernel_rec) -> None: self.knl = knl self.callables_table = callables_table self.kernel_rec = kernel_rec @@ -815,22 +969,29 @@ def __init__(self, knl, callables_table, kernel_rec): self.one = self.zero + 1 @cached_property - def param_space(self): + def param_space(self) -> isl.Space: return get_kernel_parameter_space(self.knl) - def new_poly_map(self, count_map): + def new_poly_map(self, count_map) -> ToCountPolynomialMap: return ToCountPolynomialMap(self.param_space, count_map) - def new_zero_poly_map(self): + def 
_new_zero_map(self) -> ToCountPolynomialMap: return self.new_poly_map({}) - def combine(self, values): - return sum(values) + def combine(self, values: Iterable[ToCountPolynomialMap]) -> ToCountPolynomialMap: + return sum(values, self._new_zero_map()) + + def map_tagged_expression( + self, expr: TaggedExpression, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: + return self.rec(expr.expr, expr.tags) - def map_constant(self, expr): - return self.new_zero_poly_map() + def map_constant( + self, expr: object, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: + return self._new_zero_map() - def map_call(self, expr): + def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -853,41 +1014,34 @@ def map_call(self, expr): return subst_into_to_count_map( self.param_space, sub_result, subst_dict) \ - + self.rec(expr.parameters) + + self.rec(expr.parameters, tags) else: raise NotImplementedError() - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs( + self, expr: p.CallWithKwargs, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: # See https://github.com/inducer/loopy/pull/323 raise NotImplementedError - def map_sum(self, expr): - if expr.children: - return sum(self.rec(child) for child in expr.children) - else: - return self.new_zero_poly_map() + def map_comparison( + self, expr: p.Comparison, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: + return self.rec(expr.left, tags) + self.rec(expr.right, tags) - map_product = map_sum - - def map_comparison(self, expr): - return self.rec(expr.left)+self.rec(expr.right) - - def map_if(self, expr): + def map_if( + self, expr: p.If, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches", "%s counting sum of if-expression branches." 
% type(self).__name__) - return self.rec(expr.condition) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.condition, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_if_positive(self, expr): - warn_with_kernel(self.knl, "summing_if_branches", - "%s counting sum of if-expression branches." - % type(self).__name__) - return self.rec(expr.criterion) + self.rec(expr.then) \ - + self.rec(expr.else_) - - def map_common_subexpression(self, expr): + def map_common_subexpression( + self, expr: p.CommonSubexpression, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) @@ -895,36 +1049,40 @@ def map_common_subexpression(self, expr): map_derivative = map_common_subexpression map_slice = map_common_subexpression - def map_reduction(self, expr): + def map_reduction( + self, expr: Reduction, tags: frozenset[Tag]) -> ToCountPolynomialMap: # preprocessing should have removed these raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) + def __call__( + self, expr, tags: frozenset[Tag] | None = None + ) -> ToCountPolynomialMap: + if tags is None: + tags = frozenset() + return self.rec(expr, tags=tags) + # }}} # {{{ ExpressionOpCounter class ExpressionOpCounter(CounterBase): - def __init__(self, knl, callables_table, kernel_rec, - count_within_subscripts=True): - super().__init__( - knl, callables_table, kernel_rec) + def __init__(self, knl: LoopKernel, callables_table, kernel_rec, + count_within_subscripts: bool = True): + super().__init__(knl, callables_table, kernel_rec) self.count_within_subscripts = count_within_subscripts arithmetic_count_granularity = CountGranularity.SUBGROUP - def combine(self, values): - return sum(values) - - def map_constant(self, expr): - return self.new_zero_poly_map() + def map_constant(self, expr: Any, tags: frozenset[Tag]) -> 
ToCountPolynomialMap: + return self._new_zero_map() map_tagged_variable = map_constant map_variable = map_constant map_nan = map_constant - def map_call(self, expr): + def map_call(self, expr: p.Call, tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -933,137 +1091,153 @@ def map_call(self, expr): if not isinstance(clbl, CallableKernel): return self.new_poly_map( {Op(dtype=self.type_inf(expr), - name="func:"+clbl.name, + op_type=OpType.SPECIAL_FUNC, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one} - ) + self.rec(expr.parameters) + ) + self.rec(expr.parameters, tags) else: - return super().map_call(expr) + return super().map_call(expr, tags) - def map_subscript(self, expr): + def map_subscript( + self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap: if self.count_within_subscripts: - return self.rec(expr.index) + return self.rec(expr.index, tags) else: - return self.new_zero_poly_map() + return self._new_zero_map() - def map_sub_array_ref(self, expr): + def map_sub_array_ref( + self, expr: SubArrayRef, tags: frozenset[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free - return self.new_zero_poly_map() + return self._new_zero_map() - def map_sum(self, expr): + def map_sum(self, expr: p.Sum, tags: frozenset[Tag]) -> ToCountPolynomialMap: assert expr.children return self.new_poly_map( {Op(dtype=self.type_inf(expr), - name="add", + op_type=OpType.ADD, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.zero + (len(expr.children)-1)} - ) + sum(self.rec(child) for child in expr.children) + ) + sum(self.rec(child, tags) for child in expr.children) - def map_product(self, expr): + def map_product( + self, expr: p.Product, tags: frozenset[Tag]) \ + -> ToCountMap[GuardedPwQPolynomial]: from 
pymbolic.primitives import is_zero assert expr.children return sum(self.new_poly_map({Op(dtype=self.type_inf(expr), - name="mul", + op_type=OpType.MUL, + tags=tags, count_granularity=( self.arithmetic_count_granularity), kernel_name=self.knl.name): self.one}) - + self.rec(child) + + self.rec(child, tags) for child in expr.children - if not is_zero(child + 1)) + \ + if not is_zero(cast(ArithmeticExpressionT, child) + 1)) + \ self.new_poly_map({Op(dtype=self.type_inf(expr), - name="mul", + op_type=OpType.MUL, + tags=tags, count_granularity=( self.arithmetic_count_granularity), kernel_name=self.knl.name): -self.one}) - def map_quotient(self, expr, *args): + def map_quotient( + self, expr: p.QuotientBase, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="div", + op_type=OpType.DIV, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.numerator) \ - + self.rec(expr.denominator) + + self.rec(expr.numerator, tags) \ + + self.rec(expr.denominator, tags) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr): + def map_power(self, expr: p.Power, tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="pow", + op_type=OpType.POW, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.base) \ - + self.rec(expr.exponent) + + self.rec(expr.base, tags) \ + + self.rec(expr.exponent, tags) - def map_left_shift(self, expr): + def map_left_shift( + self, expr: Union[p.LeftShift, p.RightShift], tags: frozenset[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="shift", + op_type=OpType.SHIFT, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.shiftee) \ - + self.rec(expr.shift) + 
+ self.rec(expr.shiftee, tags) \ + + self.rec(expr.shift, tags) map_right_shift = map_left_shift - def map_bitwise_not(self, expr): + def map_bitwise_not( + self, expr: p.BitwiseNot, tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="bw", + op_type=OpType.BITWISE, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.child) + + self.rec(expr.child, tags) - def map_bitwise_or(self, expr): + def map_bitwise_or( + self, expr: Union[p.BitwiseOr, p.BitwiseAnd, p.BitwiseXor], + tags: frozenset[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="bw", + op_type=OpType.BITWISE, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.zero + (len(expr.children)-1)}) \ - + sum(self.rec(child) for child in expr.children) + + sum(self.rec(child, tags) for child in expr.children) map_bitwise_xor = map_bitwise_or map_bitwise_and = map_bitwise_or - def map_if(self, expr): + def map_if(self, expr: p.If, tags: frozenset[Tag]) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches_ops", "ExpressionOpCounter counting ops as sum of " "if-statement branches.") - return self.rec(expr.condition) + self.rec(expr.then) \ - + self.rec(expr.else_) - - def map_if_positive(self, expr): - warn_with_kernel(self.knl, "summing_ifpos_branches_ops", - "ExpressionOpCounter counting ops as sum of " - "if_pos-statement branches.") - return self.rec(expr.criterion) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.condition, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_min(self, expr): + def map_min( + self, expr: Union[p. 
Min, p.Max], tags: frozenset[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="maxmin", + op_type=OpType.MAXMIN, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): len(expr.children)-1}) \ - + sum(self.rec(child) for child in expr.children) + + sum(self.rec(child, tags) for child in expr.children) map_max = map_min - def map_common_subexpression(self, expr): + def map_common_subexpression(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "common_subexpression, " "map_common_subexpression not implemented.") - def map_substitution(self, expr): + def map_substitution(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "substitution, " "map_substitution not implemented.") - def map_derivative(self, expr): + def map_derivative(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "derivative, " "map_derivative not implemented.") - def map_slice(self, expr): + def map_slice(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered slice, " "map_slice not implemented.") @@ -1085,7 +1259,9 @@ def map_floor_div(self, expr): # {{{ _get_lid_and_gid_strides -def _get_lid_and_gid_strides(knl, array, index): +def _get_lid_and_gid_strides( + knl: LoopKernel, array: ArrayBase, index: tuple[ExpressionT, ...] 
+ ) -> tuple[Mapping[int, ExpressionT], Mapping[int, ExpressionT]]: # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies my_inames = get_dependencies(index) & knl.all_inames() @@ -1121,17 +1297,21 @@ def _get_lid_and_gid_strides(knl, array, index): from loopy.kernel.array import FixedStrideArrayDimTag from loopy.symbolic import simplify_using_aff - def get_iname_strides(tag_to_iname_dict): + def get_iname_strides( + tag_to_iname_dict: Mapping[int, str] + ) -> Mapping[int, Expression]: tag_to_stride_dict = {} + from loopy.kernel.array import ArrayDimImplementationTag + if array.dim_tags is None: assert len(index) <= 1 - dim_tags = (None,) * len(index) + dim_tags: Sequence[ArrayDimImplementationTag | None] = (None,) * len(index) else: dim_tags = array.dim_tags for tag in tag_to_iname_dict: - total_iname_stride = 0 + total_iname_stride: Any = 0 # find total stride of this iname for each axis for idx, axis_tag in zip(index, dim_tags): # collect index coefficients @@ -1140,7 +1320,7 @@ def get_iname_strides(tag_to_iname_dict): [tag_to_iname_dict[tag]])( simplify_using_aff(knl, idx)) except ExpressionNotAffineError: - total_iname_stride = None + total_iname_stride = 0 break # check if idx contains this iname @@ -1156,7 +1336,7 @@ def get_iname_strides(tag_to_iname_dict): axis_tag_stride = axis_tag.stride if axis_tag_stride is lp.auto: - total_iname_stride = None + total_iname_stride = 0 break elif axis_tag is None: @@ -1176,180 +1356,155 @@ def get_iname_strides(tag_to_iname_dict): # }}} -# {{{ MemAccessCounterBase +# {{{ MemAccessCounter -class MemAccessCounterBase(CounterBase): - def map_sub_array_ref(self, expr): +class MemAccessCounter(CounterBase): + def map_sub_array_ref( + self, expr: SubArrayRef, tags: frozenset[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free - return self.new_zero_poly_map() + return self._new_zero_map() - def map_call(self, expr): + def map_call(self, expr: p.Call, 
tags: frozenset[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] from loopy.kernel.function_interface import CallableKernel if not isinstance(clbl, CallableKernel): - return self.rec(expr.parameters) + return self.rec(expr.parameters, tags) else: - return super().map_call(expr) - -# }}} - - -# {{{ LocalMemAccessCounter - -class LocalMemAccessCounter(MemAccessCounterBase): - local_mem_count_granularity = CountGranularity.SUBGROUP - - def count_var_access(self, dtype, name, index): - count_map = {} - if name in self.knl.temporary_variables: - array = self.knl.temporary_variables[name] - if isinstance(array, TemporaryVariable) and ( + return super().map_call(expr, tags) + + def count_var_access(self, + dtype: LoopyType, + name: str, + index: ExpressionT | None, + tags: frozenset[Tag], + var_tags: frozenset[Tag] = frozenset() + ) -> ToCountPolynomialMap: + from loopy.kernel.data import TemporaryVariable + array = self.knl.get_var_descriptor(name) + + if isinstance(array, TemporaryVariable) and ( array.address_space == AddressSpace.LOCAL): - if index is None: - # no subscript - count_map[MemAccess( - mtype="local", - dtype=dtype, - count_granularity=self.local_mem_count_granularity, - kernel_name=self.knl.name)] = self.one - return self.new_poly_map(count_map) - - array = self.knl.temporary_variables[name] - - # could be tuple or scalar index - index_tuple = index - if not isinstance(index_tuple, tuple): - index_tuple = (index_tuple,) - - lid_strides, gid_strides = _get_lid_and_gid_strides( - self.knl, array, index_tuple) - - count_map[MemAccess( - mtype="local", - dtype=dtype, - lid_strides=dict(sorted(lid_strides.items())), - gid_strides=dict(sorted(gid_strides.items())), - variable=name, - count_granularity=self.local_mem_count_granularity, - kernel_name=self.knl.name)] = self.one - - return self.new_poly_map(count_map) - - def 
map_variable(self, expr): - return self.count_var_access( - self.type_inf(expr), expr.name, None) - - map_tagged_variable = map_variable - - def map_subscript(self, expr): - return (self.count_var_access(self.type_inf(expr), - expr.aggregate.name, - expr.index) - + self.rec(expr.index)) - -# }}} - + # local memory access + local_mem_count_granularity = CountGranularity.SUBGROUP -# {{{ GlobalMemAccessCounter - -class GlobalMemAccessCounter(MemAccessCounterBase): - def map_variable(self, expr): - name = expr.name - - if name in self.knl.arg_dict: - array = self.knl.arg_dict[name] - else: - # this is a temporary variable - # FIXME temporary variable could have global address space - return self.new_zero_poly_map() - - if not isinstance(array, lp.ArrayArg): - # this array is not in global memory - return self.new_zero_poly_map() - - return self.new_poly_map({MemAccess(mtype="global", - dtype=self.type_inf(expr), lid_strides={}, - gid_strides={}, variable=name, + if index is None: + return self.new_poly_map({MemAccess( + address_space=AddressSpace.LOCAL, + dtype=dtype, + tags=tags, + count_granularity=local_mem_count_granularity, + kernel_name=self.knl.name): self.one}) + + # could be tuple or scalar index + index_tuple = index + if not isinstance(index_tuple, tuple): + index_tuple = (index_tuple,) + + lid_strides, gid_strides = _get_lid_and_gid_strides( + self.knl, array, index_tuple) + + return self.new_poly_map({MemAccess( + address_space=array.address_space, + dtype=dtype, + tags=tags, + lid_strides=immutabledict(lid_strides), + gid_strides=immutabledict(gid_strides), + variable=name, + count_granularity=local_mem_count_granularity, + kernel_name=self.knl.name): self.one}) + + elif (isinstance(array, TemporaryVariable) and ( + array.address_space == AddressSpace.GLOBAL)) or ( + isinstance(array, lp.ArrayArg)): + if index is None: + return self.new_poly_map({MemAccess( + address_space=AddressSpace.GLOBAL, + dtype=dtype, + lid_strides=immutabledict({}), + 
gid_strides=immutabledict({}), + variable=name, + tags=tags, count_granularity=CountGranularity.WORKITEM, - kernel_name=self.knl.name): self.one} - ) + self.rec(expr.index) - - def map_subscript(self, expr): - name = expr.aggregate.name - try: - var_tags = expr.aggregate.tags - except AttributeError: - var_tags = frozenset() - - is_global_temp = False - if name in self.knl.arg_dict: - array = self.knl.arg_dict[name] - elif name in self.knl.temporary_variables: - # This a temporary, but might have global address space - from loopy.kernel.data import AddressSpace - array = self.knl.temporary_variables[name] - if array.address_space != AddressSpace.GLOBAL: - # This temporary does not have global address space - return self.rec(expr.index) - # This temporary has global address space - is_global_temp = True - else: - # This temporary does not have global address space - return self.rec(expr.index) - - if (not is_global_temp) and not isinstance(array, lp.ArrayArg): - # This array is not in global memory - return self.rec(expr.index) - - index_tuple = expr.index # could be tuple or scalar index - if not isinstance(index_tuple, tuple): - index_tuple = (index_tuple,) + kernel_name=self.knl.name): self.one}) - lid_strides, gid_strides = _get_lid_and_gid_strides( - self.knl, array, index_tuple) + # could be tuple or scalar index + index_tuple = index + if not isinstance(index_tuple, tuple): + index_tuple = (index_tuple,) - global_access_count_granularity = CountGranularity.SUBGROUP - - # Account for broadcasts once per subgroup - count_granularity = CountGranularity.WORKITEM if ( + lid_strides, gid_strides = _get_lid_and_gid_strides( + self.knl, array, index_tuple) + # Account for broadcasts once per subgroup + count_granularity = CountGranularity.WORKITEM if ( # if the stride in lid.0 is known 0 in lid_strides and # it is nonzero lid_strides[0] != 0 - ) else global_access_count_granularity + ) else CountGranularity.SUBGROUP - return self.new_poly_map({MemAccess( - 
mtype="global", - dtype=self.type_inf(expr), - lid_strides=dict(sorted(lid_strides.items())), - gid_strides=dict(sorted(gid_strides.items())), + return self.new_poly_map({MemAccess( + address_space=AddressSpace.GLOBAL, + dtype=dtype, + lid_strides=immutabledict(lid_strides), + gid_strides=immutabledict(gid_strides), variable=name, + tags=tags, variable_tags=var_tags, count_granularity=count_granularity, kernel_name=self.knl.name, ): self.one} - ) + self.rec(expr.index_tuple) + ) + else: + return self._new_zero_map() + + def map_variable( + self, expr: p.Variable, tags: frozenset[Tag] + ) -> ToCountPolynomialMap: + return self.count_var_access( + self.type_inf(expr), expr.name, None, tags) + + map_tagged_variable = map_variable + + def map_subscript( + self, expr: p.Subscript, tags: frozenset[Tag]) -> ToCountPolynomialMap: + try: + var_tags = expr.aggregate.tags # type: ignore[union-attr] + except AttributeError: + var_tags = frozenset() + + assert hasattr(expr.aggregate, "name") + + return (self.count_var_access(self.type_inf(expr), + expr.aggregate.name, + expr.index, tags, var_tags) + + self.rec(expr.index, tags)) # }}} # {{{ AccessFootprintGatherer +FootprintsT = dict[str, isl.Set] + + class AccessFootprintGatherer(CombineMapper): - def __init__(self, kernel, domain, ignore_uncountable=False): + def __init__(self, + kernel: LoopKernel, + domain: isl.BasicSet, + ignore_uncountable: bool = False) -> None: self.kernel = kernel self.domain = domain self.ignore_uncountable = ignore_uncountable @staticmethod - def combine(values): + def combine(values: Iterable[FootprintsT]) -> FootprintsT: assert values - def merge_dicts(a, b): + def merge_dicts(a: FootprintsT, b: FootprintsT) -> FootprintsT: result = a.copy() for var_name, footprint in b.items(): @@ -1363,13 +1518,13 @@ def merge_dicts(a, b): from functools import reduce return reduce(merge_dicts, values) - def map_constant(self, expr): + def map_constant(self, expr: p.Any) -> FootprintsT: return {} - def 
map_variable(self, expr): + def map_variable(self, expr: p.Variable) -> FootprintsT: return {} - def map_subscript(self, expr): + def map_subscript(self, expr: p.Subscript) -> FootprintsT: subscript = expr.index if not isinstance(subscript, tuple): @@ -1406,13 +1561,15 @@ def map_subscript(self, expr): # {{{ count -def add_assumptions_guard(kernel, pwqpolynomial): +def add_assumptions_guard( + kernel: LoopKernel, pwqpolynomial: isl.PwQPolynomial + ) -> GuardedPwQPolynomial: return GuardedPwQPolynomial( pwqpolynomial, kernel.assumptions.align_params(pwqpolynomial.space)) -def count(kernel, set, space=None): +def count(kernel, set: isl.Set, space=None) -> GuardedPwQPolynomial: if isinstance(kernel, TranslationUnit): kernel_names = [i for i, clbl in kernel.callables_table.items() if isinstance(clbl, CallableKernel)] @@ -1518,7 +1675,9 @@ def count(kernel, set, space=None): return add_assumptions_guard(kernel, total_count) -def get_unused_hw_axes_factor(knl, callables_table, insn, disregard_local_axes): +def get_unused_hw_axes_factor( + knl: LoopKernel, callables_table, insn: InstructionBase, + disregard_local_axes: bool) -> GuardedPwQPolynomial: # FIXME: Multi-kernel support gsize, lsize = knl.get_grid_size_upper_bounds(callables_table) @@ -1557,7 +1716,8 @@ def mult_grid_factor(used_axes, sizes): return add_assumptions_guard(knl, result) -def count_inames_domain(knl, inames): +def count_inames_domain( + knl: LoopKernel, inames: frozenset[str]) -> GuardedPwQPolynomial: space = get_kernel_parameter_space(knl) if not inames: return add_assumptions_guard(knl, @@ -1568,8 +1728,12 @@ def count_inames_domain(knl, inames): return count(knl, domain, space=space) -def count_insn_runs(knl, callables_table, insn, count_redundant_work, - disregard_local_axes=False): +def count_insn_runs( + knl: LoopKernel, + callables_table: ConcreteCallablesTable, + insn: InstructionBase, + count_redundant_work: bool, + disregard_local_axes: bool = False) -> GuardedPwQPolynomial: insn_inames = 
insn.within_inames @@ -1589,8 +1753,14 @@ def count_insn_runs(knl, callables_table, insn, count_redundant_work, return c -def _get_insn_count(knl, callables_table, insn_id, subgroup_size, - count_redundant_work, count_granularity=CountGranularity.WORKITEM): +def _get_insn_count( + knl: LoopKernel, + callables_table: ConcreteCallablesTable, + insn_id: str | None, + subgroup_size: int | None, + count_redundant_work: bool, + count_granularity: CountGranularity | None = CountGranularity.WORKITEM + ) -> GuardedPwQPolynomial: insn = knl.id_to_insn[insn_id] if count_granularity is None: @@ -1650,18 +1820,20 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size, else: # this should not happen since this is enforced in Op/MemAccess - raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + raise ValueError("get_insn_count: count_granularity " + f"'{count_granularity}' is not allowed.") # }}} # {{{ get_op_map -def _get_op_map_for_single_kernel(knl, callables_table, - count_redundant_work, - count_within_subscripts, subgroup_size, within): +def _get_op_map_for_single_kernel( + knl: LoopKernel, + callables_table: ConcreteCallablesTable, + count_redundant_work: bool, + count_within_subscripts: bool, + subgroup_size: int | None, within) -> ToCountMap[GuardedPwQPolynomial]: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -1673,7 +1845,7 @@ def _get_op_map_for_single_kernel(knl, callables_table, op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec, count_within_subscripts) - op_map = op_counter.new_zero_poly_map() + op_map: ToCountMap[GuardedPwQPolynomial] = op_counter._new_zero_map() from loopy.kernel.instruction import ( Assignment, @@ -1688,6 +1860,7 @@ def _get_op_map_for_single_kernel(knl, callables_table, if isinstance(insn, (CallInstruction, Assignment)): ops = op_counter(insn.assignees) + op_counter(insn.expression) for key, 
val in ops.count_map.items(): + key = cast(Op, key) count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, key.count_granularity) @@ -1703,9 +1876,12 @@ def _get_op_map_for_single_kernel(knl, callables_table, return op_map -def get_op_map(program, count_redundant_work=False, - count_within_subscripts=True, subgroup_size=None, - entrypoint=None, within=None): +def get_op_map( + t_unit: TranslationUnit, *, count_redundant_work: bool = False, + count_within_subscripts: bool = True, + subgroup_size: int | None = None, + entrypoint: str | None = None, + within: Any = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of operations in a loopy kernel. @@ -1765,25 +1941,28 @@ def get_op_map(program, count_redundant_work=False, """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import infer_unknown_types, preprocess_program - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) from loopy.match import parse_match within = parse_match(within) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. 
- program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) + + kernel = t_unit[entrypoint] + assert isinstance(kernel, LoopKernel) return _get_op_map_for_single_kernel( - program[entrypoint], program.callables_table, + kernel, t_unit.callables_table, count_redundant_work=count_redundant_work, count_within_subscripts=count_within_subscripts, subgroup_size=subgroup_size, @@ -1792,9 +1971,9 @@ def get_op_map(program, count_redundant_work=False, # }}} -# {{{ subgoup size finding +# {{{ subgroup size finding -def _find_subgroup_size_for_knl(knl): +def _find_subgroup_size_for_knl(knl: LoopKernel) -> int | None: from loopy.target.pyopencl import PyOpenCLTarget if isinstance(knl.target, PyOpenCLTarget) and knl.target.device is not None: from pyopencl.characterize import get_simd_group_size @@ -1850,8 +2029,11 @@ def _process_subgroup_size(knl, subgroup_size_requested): # {{{ get_mem_access_map -def _get_mem_access_map_for_single_kernel(knl, callables_table, - count_redundant_work, subgroup_size, within): +def _get_mem_access_map_for_single_kernel( + knl: LoopKernel, + callables_table: ConcreteCallablesTable, + count_redundant_work: bool, subgroup_size: int | None, + within: Any) -> ToCountMap[GuardedPwQPolynomial]: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -1860,11 +2042,8 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, count_redundant_work=count_redundant_work, subgroup_size=subgroup_size) - access_counter_g = GlobalMemAccessCounter( - knl, callables_table, kernel_rec) - access_counter_l = LocalMemAccessCounter( - knl, callables_table, kernel_rec) - access_map = access_counter_g.new_zero_poly_map() + access_counter = MemAccessCounter(knl, callables_table, kernel_rec) + access_map: ToCountMap[GuardedPwQPolynomial] = access_counter._new_zero_map() from loopy.kernel.instruction import ( Assignment, @@ -1878,16 +2057,15 @@ def _get_mem_access_map_for_single_kernel(knl, 
callables_table, if within(knl, insn): if isinstance(insn, (CallInstruction, Assignment)): insn_access_map = ( - access_counter_g(insn.expression) - + access_counter_l(insn.expression) - ).with_set_attributes(direction="load") + access_counter(insn.expression) + ).with_set_attributes(read_write=AccessDirection.READ) for assignee in insn.assignees: insn_access_map = insn_access_map + ( - access_counter_g(assignee) - + access_counter_l(assignee) - ).with_set_attributes(direction="store") + access_counter(assignee) + ).with_set_attributes(read_write=AccessDirection.WRITE) for key, val in insn_access_map.count_map.items(): + key = cast(MemAccess, key) count = _get_insn_count(knl, callables_table, insn.id, subgroup_size, count_redundant_work, key.count_granularity) @@ -1904,9 +2082,11 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, return access_map -def get_mem_access_map(program, count_redundant_work=False, - subgroup_size=None, entrypoint=None, - within=None): +def get_mem_access_map( + t_unit: TranslationUnit, *, count_redundant_work: bool = False, + subgroup_size: int | None = None, + entrypoint: str | None = None, + within: Any = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of memory accesses in a loopy kernel. 
:arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be @@ -1992,26 +2172,26 @@ def get_mem_access_map(program, count_redundant_work=False, """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import infer_unknown_types, preprocess_program - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) from loopy.match import parse_match within = parse_match(within) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. - program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) return _get_mem_access_map_for_single_kernel( - program[entrypoint], program.callables_table, + t_unit[entrypoint], t_unit.callables_table, count_redundant_work=count_redundant_work, subgroup_size=subgroup_size, within=within) @@ -2021,8 +2201,10 @@ def get_mem_access_map(program, count_redundant_work=False, # {{{ get_synchronization_map -def _get_synchronization_map_for_single_kernel(knl, callables_table, - subgroup_size=None): +def _get_synchronization_map_for_single_kernel( + knl: LoopKernel, + callables_table: ConcreteCallablesTable, + subgroup_size: int | None = None) -> ToCountMap[GuardedPwQPolynomial]: knl = lp.get_one_linearized_kernel(knl, callables_table) @@ -2040,10 +2222,12 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, subgroup_size=subgroup_size) sync_counter = CounterBase(knl, callables_table, kernel_rec) - sync_map = sync_counter.new_zero_poly_map() + sync_map: ToCountMap[GuardedPwQPolynomial] = sync_counter._new_zero_map() iname_list = [] + assert knl.linearization is not None + for sched_item in 
knl.linearization: if isinstance(sched_item, EnterLoop): if sched_item.iname: # (if not empty) @@ -2053,9 +2237,14 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, iname_list.pop() elif isinstance(sched_item, Barrier): + if sched_item.synchronization_kind == "local": + sync_kind = SynchronizationKind.BARRIER_LOCAL + else: + sync_kind = SynchronizationKind.BARRIER_GLOBAL + sync_map = sync_map + ToCountMap( {Sync( - "barrier_%s" % sched_item.synchronization_kind, + sync_kind, knl.name): count_inames_domain(knl, frozenset(iname_list))}) elif isinstance(sched_item, RunInstruction): @@ -2063,7 +2252,7 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, elif isinstance(sched_item, CallKernel): sync_map = sync_map + ToCountMap( - {Sync("kernel_launch", knl.name): + {Sync(SynchronizationKind.KERNEL_LAUNCH, knl.name): count_inames_domain(knl, frozenset(iname_list))}) elif isinstance(sched_item, ReturnFromKernel): @@ -2076,7 +2265,10 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, return sync_map -def get_synchronization_map(program, subgroup_size=None, entrypoint=None): +def get_synchronization_map( + t_unit: TranslationUnit, *, + subgroup_size: int | None = None, + entrypoint: str | None = None) -> ToCountMap[GuardedPwQPolynomial]: """Count the number of synchronization events each work-item encounters in a loopy kernel. 
@@ -2113,21 +2305,21 @@ def get_synchronization_map(program, subgroup_size=None, entrypoint=None): """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import infer_unknown_types, preprocess_program - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. - program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) return _get_synchronization_map_for_single_kernel( - program[entrypoint], program.callables_table, + t_unit[entrypoint], t_unit.callables_table, subgroup_size=subgroup_size) # }}} @@ -2135,7 +2327,9 @@ def get_synchronization_map(program, subgroup_size=None, entrypoint=None): # {{{ gather_access_footprints -def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): +def _gather_access_footprints_for_single_kernel( + kernel: LoopKernel, ignore_uncountable: bool + ) -> tuple[FootprintsT, FootprintsT]: write_footprints = [] read_footprints = [] @@ -2157,10 +2351,14 @@ def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): write_footprints.append(afg(insn.assignees)) read_footprints.append(afg(insn.expression)) - return write_footprints, read_footprints + return ( + AccessFootprintGatherer.combine(write_footprints), + AccessFootprintGatherer.combine(read_footprints)) -def gather_access_footprints(program, ignore_uncountable=False, entrypoint=None): +def gather_access_footprints( + t_unit: TranslationUnit, *, ignore_uncountable: bool = False, + entrypoint: str | None = None) -> Mapping[MemAccess, isl.Set]: """Return a dictionary 
mapping ``(var_name, direction)`` to :class:`islpy.Set` instances capturing which indices of each the array *var_name* are read/written (where *direction* is either ``read`` or @@ -2172,48 +2370,48 @@ def gather_access_footprints(program, ignore_uncountable=False, entrypoint=None) """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints - # FIMXE: works only for one callable kernel till now. + # FIXME: works only for one callable kernel till now. if len([in_knl_callable for in_knl_callable in - program.callables_table.values() if isinstance(in_knl_callable, + t_unit.callables_table.values() if isinstance(in_knl_callable, CallableKernel)]) != 1: - raise NotImplementedError("Currently only supported for program with " - "only one CallableKernel.") + raise NotImplementedError("Currently only supported for " + "translation unit with only one CallableKernel.") from loopy.preprocess import infer_unknown_types, preprocess_program - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. 
- program = infer_unknown_types(program, expect_completion=True) - - write_footprints = [] - read_footprints = [] + t_unit = infer_unknown_types(t_unit, expect_completion=True) + kernel = t_unit[entrypoint] + assert isinstance(kernel, LoopKernel) write_footprints, read_footprints = _gather_access_footprints_for_single_kernel( - program[entrypoint], ignore_uncountable) - - write_footprints = AccessFootprintGatherer.combine(write_footprints) - read_footprints = AccessFootprintGatherer.combine(read_footprints) + kernel, ignore_uncountable) result = {} for vname, footprint in write_footprints.items(): - result[(vname, "write")] = footprint + result[MemAccess(variable=vname, read_write=AccessDirection.WRITE)] \ + = footprint for vname, footprint in read_footprints.items(): - result[(vname, "read")] = footprint + result[MemAccess(variable=vname, read_write=AccessDirection.READ)] \ + = footprint return result -def gather_access_footprint_bytes(program, ignore_uncountable=False): +def gather_access_footprint_bytes( + t_unit: TranslationUnit, *, ignore_uncountable: bool = False + ) -> ToCountPolynomialMap: """Return a dictionary mapping ``(var_name, direction)`` to :class:`islpy.PwQPolynomial` instances capturing the number of bytes are read/written (where *direction* is either ``read`` or ``write`` on array @@ -2224,30 +2422,25 @@ def gather_access_footprint_bytes(program, ignore_uncountable=False): nonlinear indices) """ - from loopy.preprocess import infer_unknown_types, preprocess_program - kernel = infer_unknown_types(program, expect_completion=True) + from loopy.preprocess import infer_unknown_types + t_unit = infer_unknown_types(t_unit, expect_completion=True) - from loopy.kernel import KernelState - if kernel.state < KernelState.PREPROCESSED: - kernel = preprocess_program(program) + fp = gather_access_footprints(t_unit, ignore_uncountable=ignore_uncountable) - result = {} - fp = gather_access_footprints(kernel, - ignore_uncountable=ignore_uncountable) + # FIXME: 
Only supporting a single kernel for now + kernel = t_unit.default_entrypoint - for key, var_fp in fp.items(): - vname, direction = key - - var_descr = kernel.get_var_descriptor(vname) + result: dict[Countable, GuardedPwQPolynomial] = {} + for ma, var_fp in fp.items(): + assert ma.variable + var_descr = kernel.get_var_descriptor(ma.variable) + assert var_descr.dtype bytes_transferred = ( int(var_descr.dtype.numpy_dtype.itemsize) * count(kernel, var_fp)) - if key in result: - result[key] += bytes_transferred - else: - result[key] = bytes_transferred + result[ma] = add_assumptions_guard(kernel, bytes_transferred) - return result + return ToCountPolynomialMap(get_kernel_parameter_space(kernel), result) # }}} diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ad502e1a5..c595e8392 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -101,6 +101,8 @@ .. autoclass:: TypeCast .. autoclass:: TaggedVariable +.. autoclass:: TaggedExpression + .. autoclass:: Reduction .. autoclass:: LinearSubscript @@ -130,6 +132,10 @@ # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin(Mapper[ExpressionT, P]): + def map_tagged_expression(self, expr: TaggedExpression, *args, **kwargs): + new_expr = self.rec(expr.expr, *args, **kwargs) + return TaggedExpression(expr.tags, new_expr) + def map_literal(self, expr: Literal, *args, **kwargs): return expr @@ -233,6 +239,12 @@ def map_common_subexpression_uncached(self, expr): class WalkMapperMixin: + def map_tagged_expression(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.expr, *args, **kwargs) + def map_literal(self, expr, *args, **kwargs): self.visit(expr, *args, **kwargs) @@ -299,6 +311,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase): class CombineMapper(CombineMapperBase): + def map_tagged_expression(self, expr, *args, **kwargs): + return self.rec(expr.expr, *args, **kwargs) + def map_reduction(self, expr, *args, **kwargs): return 
self.rec(expr.expr, *args, **kwargs) @@ -329,6 +344,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase, class StringifyMapper(StringifyMapperBase[[]]): + def map_tagged_expression(self, expr, *args): + from pymbolic.mapper.stringifier import PREC_NONE + return f"TaggedExpression({expr.tags}, {self.rec(expr.expr, PREC_NONE)})" + def map_literal(self, expr, *args): return expr.s @@ -438,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs): def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() + def map_tagged_expression(self, expr, *args, **kwargs): + deps = self.rec(expr.expr, *args, **kwargs) + return deps + def map_sub_array_ref(self, expr, *args, **kwargs): deps = self.rec(expr.subscript, *args, **kwargs) return deps - set(expr.swept_inames) @@ -680,6 +703,27 @@ def copy(self, *, name=None, tags=None): return TaggedVariable(name, tags) +@p.expr_dataclass() +class TaggedExpression(LoopyExpressionBase): + """ + Represents a frozenset of tags attached to an :attr:`expr`. + + .. attribute:: tags + + A :class:`frozenset` of subclasses of :class:`pytools.tag.Tag` used to + provide metadata on this expression. + + .. attribute:: expr + + An expression to which :attr:`tags` are attached. 
+ """ + + init_arg_names = ("tags", "expr") + + tags: frozenset[Tag] + expr: ExpressionT + + @p.expr_dataclass(init=False) class Reduction(LoopyExpressionBase): """ diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 4afdfcef7..ed68bb36e 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -456,7 +456,7 @@ def __call__(self, *args, **kwargs): return pex(*args, **kwargs) - def __str__(self): + def __str__(self) -> str: # FIXME: do a topological sort by the call graph return "\n".join( diff --git a/loopy/types.py b/loopy/types.py index b43026bdb..223b59cc5 100644 --- a/loopy/types.py +++ b/loopy/types.py @@ -95,6 +95,7 @@ def __init__(self, dtype: np.dtype): if dtype == object: # noqa: E721 raise TypeError("loopy does not directly support object arrays") + # Normalize due to https://stackoverflow.com/questions/35293672/why-do-these-dtypes-compare-equal-but-hash-different self.dtype = np.dtype(dtype) def __hash__(self) -> int: diff --git a/test/test_statistics.py b/test/test_statistics.py index 6665f6c76..0547a77f1 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -31,7 +31,13 @@ from pytools import div_ceil import loopy as lp -from loopy.statistics import CountGranularity as CG +from loopy.statistics import ( + AccessDirection, + AddressSpace, + CountGranularity as CG, + OpType, + SynchronizationKind, +) from loopy.types import to_loopy_type from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa @@ -64,15 +70,15 @@ def test_op_counter_basic(): m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - f32add = op_map[lp.Op(np.float32, "add", CG.SUBGROUP, "basic")].eval_with_dict( + f32add = op_map[lp.Op(np.float32, OpType.ADD, CG.SUBGROUP, "basic")].eval_with_dict( params) - f32mul = op_map[lp.Op(np.float32, "mul", CG.SUBGROUP, "basic")].eval_with_dict( + f32mul = op_map[lp.Op(np.float32, OpType.MUL, CG.SUBGROUP, "basic")].eval_with_dict( params) - f32div = op_map[lp.Op(np.float32, "div", 
CG.SUBGROUP, "basic")].eval_with_dict( + f32div = op_map[lp.Op(np.float32, OpType.DIV, CG.SUBGROUP, "basic")].eval_with_dict( params) - f64mul = op_map[lp.Op(np.dtype(np.float64), "mul", CG.SUBGROUP, "basic") + f64mul = op_map[lp.Op(np.dtype(np.float64), OpType.MUL, CG.SUBGROUP, "basic") ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), "add", CG.SUBGROUP, "basic") + i32add = op_map[lp.Op(np.dtype(np.int32), OpType.ADD, CG.SUBGROUP, "basic") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32add == f32mul == f32div == n*m*ell*n_subgroups @@ -99,9 +105,9 @@ def test_op_counter_reduction(): m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - f32add = op_map[lp.Op(np.float32, "add", CG.SUBGROUP, + f32add = op_map[lp.Op(np.float32, OpType.ADD, CG.SUBGROUP, "matmul_serial")].eval_with_dict(params) - f32mul = op_map[lp.Op(np.dtype(np.float32), "mul", CG.SUBGROUP, + f32mul = op_map[lp.Op(np.dtype(np.float32), OpType.MUL, CG.SUBGROUP, "matmul_serial")].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32add == f32mul == n*m*ell*n_subgroups @@ -135,13 +141,13 @@ def test_op_counter_logic(): m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - f32mul = op_map[lp.Op(np.float32, "mul", CG.SUBGROUP, "logic")].eval_with_dict( + f32mul = op_map[lp.Op(np.float32, OpType.MUL, CG.SUBGROUP, "logic")].eval_with_dict( params) - f64add = op_map[lp.Op(np.float64, "add", CG.SUBGROUP, "logic")].eval_with_dict( + f64add = op_map[lp.Op(np.float64, OpType.ADD, CG.SUBGROUP, "logic")].eval_with_dict( params) - f64div = op_map[lp.Op(np.dtype(np.float64), "div", CG.SUBGROUP, "logic") + f64div = op_map[lp.Op(np.dtype(np.float64), OpType.DIV, CG.SUBGROUP, "logic") ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), "add", CG.SUBGROUP, "logic") + i32add = op_map[lp.Op(np.dtype(np.int32), OpType.ADD, CG.SUBGROUP, "logic") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32mul == n*m*n_subgroups @@ 
-175,27 +181,28 @@ def test_op_counter_special_ops(): m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - f32mul = op_map[lp.Op(np.float32, "mul", CG.SUBGROUP, + f32mul = op_map[lp.Op(np.float32, OpType.MUL, CG.SUBGROUP, "special_ops")].eval_with_dict(params) - f32div = op_map[lp.Op(np.float32, "div", CG.SUBGROUP, + f32div = op_map[lp.Op(np.float32, OpType.DIV, CG.SUBGROUP, "special_ops")].eval_with_dict(params) - f32add = op_map[lp.Op(np.float32, "add", CG.SUBGROUP, + f32add = op_map[lp.Op(np.float32, OpType.ADD, CG.SUBGROUP, "special_ops")].eval_with_dict(params) - f64pow = op_map[lp.Op(np.float64, "pow", CG.SUBGROUP, + f64pow = op_map[lp.Op(np.float64, OpType.POW, CG.SUBGROUP, "special_ops")].eval_with_dict(params) - f64add = op_map[lp.Op(np.dtype(np.float64), "add", CG.SUBGROUP, "special_ops") + f64add = op_map[lp.Op(np.dtype(np.float64), OpType.ADD, CG.SUBGROUP, "special_ops") ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), "add", CG.SUBGROUP, "special_ops") + i32add = op_map[lp.Op(np.dtype(np.int32), OpType.ADD, CG.SUBGROUP, "special_ops") ].eval_with_dict(params) - f64rsq = op_map[lp.Op(np.dtype(np.float64), "func:rsqrt", CG.SUBGROUP, + # FIXME: OpType.SPECIAL_FUNC doesn't make much sense here + f64rsq = op_map[lp.Op(np.dtype(np.float64), OpType.SPECIAL_FUNC, CG.SUBGROUP, "special_ops")].eval_with_dict(params) - f64sin = op_map[lp.Op(np.dtype(np.float64), "func:sin", CG.SUBGROUP, + f64sin = op_map[lp.Op(np.dtype(np.float64), OpType.SPECIAL_FUNC, CG.SUBGROUP, "special_ops")].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert f32div == 2*n*m*ell*n_subgroups assert f32mul == f32add == n*m*ell*n_subgroups assert f64add == 3*n*m*n_subgroups - assert f64pow == i32add == f64rsq == f64sin == n*m*n_subgroups + assert f64pow == i32add == f64rsq/2 == f64sin/2 == n*m*n_subgroups def test_op_counter_bitwise(): @@ -227,22 +234,22 @@ def test_op_counter_bitwise(): params = {"n": n, "m": m, "ell": ell} print(op_map) i32add = op_map[ 
- lp.Op(np.int32, "add", CG.SUBGROUP, "bitwise") + lp.Op(np.int32, OpType.ADD, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) i32bw = op_map[ - lp.Op(np.int32, "bw", CG.SUBGROUP, "bitwise") + lp.Op(np.int32, OpType.BITWISE, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) i64bw = op_map[ - lp.Op(np.dtype(np.int64), "bw", CG.SUBGROUP, "bitwise") + lp.Op(np.dtype(np.int64), OpType.BITWISE, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) i64mul = op_map[ - lp.Op(np.dtype(np.int64), "mul", CG.SUBGROUP, "bitwise") + lp.Op(np.dtype(np.int64), OpType.MUL, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) i64add = op_map[ - lp.Op(np.dtype(np.int64), "add", CG.SUBGROUP, "bitwise") + lp.Op(np.dtype(np.int64), OpType.ADD, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) i64shift = op_map[ - lp.Op(np.dtype(np.int64), "shift", CG.SUBGROUP, "bitwise") + lp.Op(np.dtype(np.int64), OpType.SHIFT, CG.SUBGROUP, "bitwise") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups assert i32add == n*m*ell*n_subgroups @@ -277,7 +284,7 @@ def test_op_counter_triangular_domain(): knl, subgroup_size=SGS, count_redundant_work=True - )[lp.Op(np.float64, "mul", CG.SUBGROUP, "bitwise")] + )[lp.Op(np.float64, OpType.MUL, CG.SUBGROUP, "bitwise")] value_dict = dict(m=13, n=200) flops = op_map.eval_with_dict(value_dict) @@ -320,27 +327,27 @@ def test_mem_access_counter_basic(): subgroups_per_group = div_ceil(group_size, SGS) n_subgroups = n_workgroups*subgroups_per_group - f32l = mem_map[lp.MemAccess("global", np.float32, + f32l = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) - f32l += mem_map[lp.MemAccess("global", np.float32, + f32l += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="b", + read_write=AccessDirection.READ, 
variable="b", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) - f64l = mem_map[lp.MemAccess("global", np.float64, + f64l = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) - f64l += mem_map[lp.MemAccess("global", np.float64, + f64l += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) @@ -349,15 +356,15 @@ def test_mem_access_counter_basic(): assert f32l == (3*n*m*ell)*n_subgroups assert f64l == (2*n*m)*n_subgroups - f32s = mem_map[lp.MemAccess("global", np.dtype(np.float32), + f32s = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={}, gid_strides={}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) - f64s = mem_map[lp.MemAccess("global", np.dtype(np.float64), + f64s = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float64), lid_strides={}, gid_strides={}, - direction="store", variable="e", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.SUBGROUP, kernel_name="basic") ].eval_with_dict(params) @@ -390,15 +397,15 @@ def test_mem_access_counter_reduction(): subgroups_per_group = div_ceil(group_size, SGS) n_subgroups = n_workgroups*subgroups_per_group - f32l = mem_map[lp.MemAccess("global", np.float32, + f32l = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.SUBGROUP, kernel_name="matmul") ].eval_with_dict(params) - f32l += 
mem_map[lp.MemAccess("global", np.float32, + f32l += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.SUBGROUP, kernel_name="matmul") ].eval_with_dict(params) @@ -406,9 +413,9 @@ def test_mem_access_counter_reduction(): # uniform: (count-per-sub-group)*n_subgroups assert f32l == (2*n*m*ell)*n_subgroups - f32s = mem_map[lp.MemAccess("global", np.dtype(np.float32), + f32s = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={}, gid_strides={}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.SUBGROUP, kernel_name="matmul") ].eval_with_dict(params) @@ -416,9 +423,11 @@ def test_mem_access_counter_reduction(): # uniform: (count-per-sub-group)*n_subgroups assert f32s == (n*ell)*n_subgroups - ld_bytes = mem_map.filter_by(mtype=["global"], direction=["load"] + ld_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.READ] ).to_bytes().eval_and_sum(params) - st_bytes = mem_map.filter_by(mtype=["global"], direction=["store"] + st_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.WRITE] ).to_bytes().eval_and_sum(params) assert ld_bytes == 4*f32l assert st_bytes == 4*f32s @@ -451,16 +460,16 @@ def test_mem_access_counter_logic(): subgroups_per_group = div_ceil(group_size, SGS) n_subgroups = n_workgroups*subgroups_per_group - reduced_map = mem_map.group_by("mtype", "dtype", "direction") + reduced_map = mem_map.group_by("address_space", "dtype", "read_write") - f32_g_l = reduced_map[lp.MemAccess("global", to_loopy_type(np.float32), - direction="load") + f32_g_l = reduced_map[lp.MemAccess(AddressSpace.GLOBAL, to_loopy_type(np.float32), + read_write=AccessDirection.READ) ].eval_with_dict(params) - f64_g_l = reduced_map[lp.MemAccess("global", to_loopy_type(np.float64), - 
direction="load") + f64_g_l = reduced_map[lp.MemAccess(AddressSpace.GLOBAL, to_loopy_type(np.float64), + read_write=AccessDirection.READ) ].eval_with_dict(params) - f64_g_s = reduced_map[lp.MemAccess("global", to_loopy_type(np.float64), - direction="store") + f64_g_s = reduced_map[lp.MemAccess(AddressSpace.GLOBAL, to_loopy_type(np.float64), + read_write=AccessDirection.WRITE) ].eval_with_dict(params) # uniform: (count-per-sub-group)*n_subgroups @@ -496,27 +505,27 @@ def test_mem_access_counter_special_ops(): subgroups_per_group = div_ceil(group_size, SGS) n_subgroups = n_workgroups*subgroups_per_group - f32 = mem_map[lp.MemAccess("global", np.float32, + f32 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) - f32 += mem_map[lp.MemAccess("global", np.float32, + f32 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) - f64 = mem_map[lp.MemAccess("global", np.dtype(np.float64), + f64 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float64), lid_strides={}, gid_strides={}, - direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) - f64 += mem_map[lp.MemAccess("global", np.dtype(np.float64), + f64 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float64), lid_strides={}, gid_strides={}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) @@ -525,15 +534,15 @@ def test_mem_access_counter_special_ops(): assert f32 == (2*n*m*ell)*n_subgroups 
assert f64 == (2*n*m)*n_subgroups - f32 = mem_map[lp.MemAccess("global", np.float32, + f32 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) - f64 = mem_map[lp.MemAccess("global", np.float64, + f64 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="store", variable="e", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.SUBGROUP, kernel_name="special_ops") ].eval_with_dict(params) @@ -542,8 +551,9 @@ def test_mem_access_counter_special_ops(): assert f32 == (n*m*ell)*n_subgroups assert f64 == (n*m)*n_subgroups - filtered_map = mem_map.filter_by(direction=["load"], variable=["a", "g"], - count_granularity=CG.SUBGROUP) + filtered_map = mem_map.filter_by(read_write=[AccessDirection.READ], + variable=["a", "g"], + count_granularity=[CG.SUBGROUP]) tot = filtered_map.eval_and_sum(params) # uniform: (count-per-sub-group)*n_subgroups @@ -579,27 +589,27 @@ def test_mem_access_counter_bitwise(): subgroups_per_group = div_ceil(group_size, SGS) n_subgroups = n_workgroups*subgroups_per_group - i32 = mem_map[lp.MemAccess("global", np.int32, + i32 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.int32, lid_strides={}, gid_strides={}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) - i32 += mem_map[lp.MemAccess("global", np.int32, + i32 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.int32, lid_strides={}, gid_strides={}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) - i32 += mem_map[lp.MemAccess("global", np.int32, + i32 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.int32, 
lid_strides={}, gid_strides={}, - direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) - i32 += mem_map[lp.MemAccess("global", np.dtype(np.int32), + i32 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.int32), lid_strides={}, gid_strides={}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) @@ -607,15 +617,15 @@ def test_mem_access_counter_bitwise(): # uniform: (count-per-sub-group)*n_subgroups assert i32 == (4*n*m+2*n*m*ell)*n_subgroups - i32 = mem_map[lp.MemAccess("global", np.int32, + i32 = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.int32, lid_strides={}, gid_strides={}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) - i32 += mem_map[lp.MemAccess("global", np.int32, + i32 += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.int32, lid_strides={}, gid_strides={}, - direction="store", variable="e", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.SUBGROUP, kernel_name="bitwise") ].eval_with_dict(params) @@ -656,36 +666,36 @@ def test_mem_access_counter_mixed(): mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=SGS) - f64uniform = mem_map[lp.MemAccess("global", np.float64, + f64uniform = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.SUBGROUP, kernel_name="mixed") ].eval_with_dict(params) - f64uniform += mem_map[lp.MemAccess("global", np.float64, + f64uniform += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", 
count_granularity=CG.SUBGROUP, kernel_name="mixed") ].eval_with_dict(params) - f32uniform = mem_map[lp.MemAccess("global", np.float32, + f32uniform = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={}, gid_strides={}, - direction="load", variable="x", + read_write=AccessDirection.READ, variable="x", count_granularity=CG.SUBGROUP, kernel_name="mixed") ].eval_with_dict(params) - f32nonconsec = mem_map[lp.MemAccess("global", np.dtype(np.float32), + f32nonconsec = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*group_size_0}, - direction="load", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.WORKITEM, kernel_name="mixed") ].eval_with_dict(params) - f32nonconsec += mem_map[lp.MemAccess("global", np.dtype(np.float32), + f32nonconsec += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*group_size_0}, - direction="load", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.WORKITEM, kernel_name="mixed") @@ -712,16 +722,16 @@ def test_mem_access_counter_mixed(): else: assert f32nonconsec == 3*n*m*ell - f64uniform = mem_map[lp.MemAccess("global", np.float64, + f64uniform = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={}, gid_strides={}, - direction="store", variable="e", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.SUBGROUP, kernel_name="mixed") ].eval_with_dict(params) - f32nonconsec = mem_map[lp.MemAccess("global", np.float32, + f32nonconsec = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*group_size_0}, - direction="store", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.WORKITEM, kernel_name="mixed") @@ -762,54 +772,54 @@ def test_mem_access_counter_nonconsec(): m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - 
f64nonconsec = mem_map[lp.MemAccess("global", np.float64, + f64nonconsec = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*lsize0}, - direction="load", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) - f64nonconsec += mem_map[lp.MemAccess("global", np.float64, + f64nonconsec += mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*lsize0}, - direction="load", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess( - "global", np.dtype(np.float32), + AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")*Variable("ell")}, gid_strides={0: Variable("m")*Variable("ell")*lsize0}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f32nonconsec += mem_map[lp.MemAccess( - "global", np.dtype(np.float32), + AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")*Variable("ell")}, gid_strides={0: Variable("m")*Variable("ell")*lsize0}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) assert f64nonconsec == 2*n*m assert f32nonconsec == 3*n*m*ell - f64nonconsec = mem_map[lp.MemAccess("global", np.float64, + f64nonconsec = mem_map[lp.MemAccess(AddressSpace.GLOBAL, np.float64, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*lsize0}, - direction="store", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess( - "global", np.float32, + AddressSpace.GLOBAL, np.float32, 
lid_strides={0: Variable("m")*Variable("ell")}, gid_strides={0: Variable("m")*Variable("ell")*lsize0}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) @@ -819,39 +829,39 @@ def test_mem_access_counter_nonconsec(): mem_map64 = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=64) f64nonconsec = mem_map64[lp.MemAccess( - "global", + AddressSpace.GLOBAL, np.float64, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*lsize0}, - direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f64nonconsec += mem_map64[lp.MemAccess( - "global", + AddressSpace.GLOBAL, np.float64, lid_strides={0: Variable("m")}, gid_strides={0: Variable("m")*lsize0}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f32nonconsec = mem_map64[lp.MemAccess( - "global", + AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")*Variable("ell")}, gid_strides={0: Variable("m")*Variable("ell")*lsize0}, - direction="load", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.WORKITEM, kernel_name="non_consec") ].eval_with_dict(params) f32nonconsec += mem_map64[lp.MemAccess( - "global", + AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: Variable("m")*Variable("ell")}, gid_strides={0: Variable("m")*Variable("ell")*lsize0}, - direction="load", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.WORKITEM, kernel_name="non_consec") @@ -883,32 +893,32 @@ def test_mem_access_counter_consec(): params = {"n": n, "m": m, "ell": ell} f64consec = mem_map[lp.MemAccess( - "global", np.float64, + AddressSpace.GLOBAL, np.float64, lid_strides={0: 1}, gid_strides={0: Variable("m")}, - 
direction="load", variable="g", + read_write=AccessDirection.READ, variable="g", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) f64consec += mem_map[lp.MemAccess( - "global", np.float64, + AddressSpace.GLOBAL, np.float64, lid_strides={0: 1}, gid_strides={0: Variable("m")}, - direction="load", variable="h", + read_write=AccessDirection.READ, variable="h", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) f32consec = mem_map[lp.MemAccess( - "global", np.float32, + AddressSpace.GLOBAL, np.float32, lid_strides={0: 1}, gid_strides={0: Variable("m")*Variable("ell"), 1: Variable("m")}, - direction="load", variable="a", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) f32consec += mem_map[lp.MemAccess( - "global", np.dtype(np.float32), + AddressSpace.GLOBAL, np.dtype(np.float32), lid_strides={0: 1}, gid_strides={0: Variable("m")*Variable("ell"), 1: Variable("m")}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) @@ -916,17 +926,17 @@ def test_mem_access_counter_consec(): assert f32consec == 3*n*m*ell f64consec = mem_map[lp.MemAccess( - "global", np.float64, + AddressSpace.GLOBAL, np.float64, lid_strides={0: 1}, gid_strides={0: Variable("m")}, - direction="store", variable="e", + read_write=AccessDirection.WRITE, variable="e", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) f32consec = mem_map[lp.MemAccess( - "global", np.float32, + AddressSpace.GLOBAL, np.float32, lid_strides={0: 1}, gid_strides={0: Variable("m")*Variable("ell"), 1: Variable("m")}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", count_granularity=CG.WORKITEM, kernel_name="consec") ].eval_with_dict(params) @@ -958,7 +968,7 @@ def test_mem_access_counter_global_temps(): # Count global accesses 
global_accesses = mem_map.filter_by( - mtype=["global"]).sum().eval_with_dict(params) + address_space=[AddressSpace.GLOBAL]).sum().eval_with_dict(params) assert global_accesses == n*m @@ -1008,7 +1018,8 @@ def test_barrier_counter_nobarriers(): ell = 128 params = {"n": n, "m": m, "ell": ell} assert len(sync_map) == 1 - assert sync_map.filter_by(kind="kernel_launch").eval_and_sum(params) == 1 + assert sync_map.filter_by( + sync_kind=[SynchronizationKind.KERNEL_LAUNCH]).eval_and_sum(params) == 1 def test_barrier_counter_barriers(): @@ -1028,12 +1039,13 @@ def test_barrier_counter_barriers(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.int32)) knl = lp.split_iname(knl, "k", 128, inner_tag="l.0") sync_map = lp.get_synchronization_map(knl) - print(sync_map) + print(f"{sync_map=}") n = 512 m = 256 ell = 128 params = {"n": n, "m": m, "ell": ell} - barrier_count = sync_map.filter_by(kind="barrier_local").eval_and_sum(params) + barrier_count = sync_map.filter_by( + sync_kind=[SynchronizationKind.BARRIER_LOCAL]).eval_and_sum(params) assert barrier_count == 50*10*2 @@ -1048,7 +1060,8 @@ def test_barrier_count_single(): knl = lp.tag_inames(knl, {"i": "l.0"}) sync_map = lp.get_synchronization_map(knl) print(sync_map) - barrier_count = sync_map.filter_by(kind="barrier_local").eval_and_sum() + barrier_count = sync_map.filter_by( + sync_kind=[SynchronizationKind.BARRIER_LOCAL]).eval_and_sum() assert barrier_count == 1 @@ -1078,21 +1091,23 @@ def test_all_counters_parallel_matmul(): sync_map = lp.get_synchronization_map(knl) assert len(sync_map) == 2 - assert sync_map.filter_by(kind="kernel_launch").eval_and_sum(params) == 1 - assert sync_map.filter_by(kind="barrier_local").eval_and_sum(params) == 2*m/bsize + assert sync_map.filter_by( + sync_kind=[SynchronizationKind.KERNEL_LAUNCH]).eval_and_sum(params) == 1 + assert sync_map.filter_by( + sync_kind=[SynchronizationKind.BARRIER_LOCAL]).eval_and_sum(params) == 2*m/bsize op_map = lp.get_op_map(knl, subgroup_size=SGS, 
count_redundant_work=True) f32mul = op_map[ - lp.Op(np.float32, "mul", CG.SUBGROUP, "matmul") + lp.Op(np.float32, OpType.MUL, CG.SUBGROUP, "matmul") ].eval_with_dict(params) f32add = op_map[ - lp.Op(np.float32, "add", CG.SUBGROUP, "matmul") + lp.Op(np.float32, OpType.ADD, CG.SUBGROUP, "matmul") ].eval_with_dict(params) i32ops = op_map[ - lp.Op(np.int32, "add", CG.SUBGROUP, "matmul") + lp.Op(np.int32, OpType.ADD, CG.SUBGROUP, "matmul") ].eval_with_dict(params) i32ops += op_map[ - lp.Op(np.dtype(np.int32), "mul", CG.SUBGROUP, "matmul") + lp.Op(np.dtype(np.int32), OpType.MUL, CG.SUBGROUP, "matmul") ].eval_with_dict(params) # (count-per-sub-group)*n_subgroups @@ -1101,17 +1116,17 @@ def test_all_counters_parallel_matmul(): mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=SGS) - f32s1lb = mem_access_map[lp.MemAccess("global", np.float32, + f32s1lb = mem_access_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: 1, 1: Variable("ell")}, gid_strides={1: bsize}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", count_granularity=CG.WORKITEM, kernel_name="matmul") ].eval_with_dict(params) - f32s1la = mem_access_map[lp.MemAccess("global", np.float32, + f32s1la = mem_access_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: 1, 1: Variable("m")}, gid_strides={0: Variable("m")*bsize}, - direction="load", + read_write=AccessDirection.READ, variable="a", count_granularity=CG.WORKITEM, kernel_name="matmul") ].eval_with_dict(params) @@ -1119,10 +1134,10 @@ def test_all_counters_parallel_matmul(): assert f32s1lb == n*m*ell/bsize assert f32s1la == n*m*ell/bsize - f32coal = mem_access_map[lp.MemAccess("global", np.float32, + f32coal = mem_access_map[lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: 1, 1: Variable("ell")}, gid_strides={0: Variable("ell")*bsize, 1: bsize}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", 
count_granularity=CG.WORKITEM, kernel_name="matmul") ].eval_with_dict(params) @@ -1131,23 +1146,23 @@ def test_all_counters_parallel_matmul(): local_mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=SGS).filter_by(mtype=["local"]) + subgroup_size=SGS).filter_by(address_space=[AddressSpace.LOCAL]) - local_mem_l = local_mem_map.filter_by(direction=["load"] + local_mem_l = local_mem_map.filter_by(read_write=[AccessDirection.READ] ).eval_and_sum(params) # (count-per-sub-group)*n_subgroups assert local_mem_l == m*2*n_subgroups - local_mem_l_a = local_mem_map[lp.MemAccess("local", np.dtype(np.float32), - direction="load", + local_mem_l_a = local_mem_map[lp.MemAccess(AddressSpace.LOCAL, np.dtype(np.float32), + read_write=AccessDirection.READ, lid_strides={1: 16}, gid_strides={}, variable="a_fetch", count_granularity=CG.SUBGROUP, kernel_name="matmul") ].eval_with_dict(params) - local_mem_l_b = local_mem_map[lp.MemAccess("local", np.dtype(np.float32), - direction="load", + local_mem_l_b = local_mem_map[lp.MemAccess(AddressSpace.LOCAL, np.dtype(np.float32), + read_write=AccessDirection.READ, lid_strides={0: 1}, gid_strides={}, variable="b_fetch", @@ -1158,7 +1173,7 @@ def test_all_counters_parallel_matmul(): # (count-per-sub-group)*n_subgroups assert local_mem_l_a == local_mem_l_b == m*n_subgroups - local_mem_s = local_mem_map.filter_by(direction=["store"] + local_mem_s = local_mem_map.filter_by(read_write=[AccessDirection.WRITE] ).eval_and_sum(params) # (count-per-sub-group)*n_subgroups @@ -1236,19 +1251,19 @@ def test_mem_access_tagged_variables(): subgroup_size=SGS) f32s1lb = mem_access_map[ - lp.MemAccess("global", np.float32, + lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: 1}, gid_strides={1: bsize}, - direction="load", variable="b", + read_write=AccessDirection.READ, variable="b", variable_tags=frozenset([lp.LegacyStringInstructionTag("mmbload")]), count_granularity=CG.WORKITEM, kernel_name="matmul") 
].eval_with_dict(params) f32s1la = mem_access_map[ - lp.MemAccess("global", np.float32, + lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={1: Variable("m")}, gid_strides={0: Variable("m")*bsize}, - direction="load", + read_write=AccessDirection.READ, variable="a", variable_tags=frozenset([lp.LegacyStringInstructionTag("mmaload")]), count_granularity=CG.SUBGROUP, @@ -1261,10 +1276,10 @@ def test_mem_access_tagged_variables(): assert f32s1la == m*n_subgroups f32coal = mem_access_map[ - lp.MemAccess("global", np.float32, + lp.MemAccess(AddressSpace.GLOBAL, np.float32, lid_strides={0: 1, 1: Variable("ell")}, gid_strides={0: Variable("ell")*bsize, 1: bsize}, - direction="store", variable="c", + read_write=AccessDirection.WRITE, variable="c", variable_tags=frozenset([lp.LegacyStringInstructionTag("mmresult")]), count_granularity=CG.WORKITEM, kernel_name="matmul") @@ -1333,24 +1348,27 @@ def test_summations_and_filters(): mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=SGS) - loads_a = mem_map.filter_by(direction=["load"], variable=["a"], + loads_a = mem_map.filter_by(read_write=[AccessDirection.READ], variable=["a"], count_granularity=[CG.SUBGROUP] ).eval_and_sum(params) # uniform: (count-per-sub-group)*n_subgroups assert loads_a == (2*n*m*ell)*n_subgroups - global_stores = mem_map.filter_by(mtype=["global"], direction=["store"], + global_stores = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.WRITE], count_granularity=[CG.SUBGROUP] ).eval_and_sum(params) # uniform: (count-per-sub-group)*n_subgroups assert global_stores == (n*m*ell + n*m)*n_subgroups - ld_bytes = mem_map.filter_by(mtype=["global"], direction=["load"], + ld_bytes = mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.READ], count_granularity=[CG.SUBGROUP] ).to_bytes().eval_and_sum(params) - st_bytes = mem_map.filter_by(mtype=["global"], direction=["store"], + st_bytes = 
mem_map.filter_by(address_space=[AddressSpace.GLOBAL], + read_write=[AccessDirection.WRITE], count_granularity=[CG.SUBGROUP] ).to_bytes().eval_and_sum(params) @@ -1359,10 +1377,14 @@ def test_summations_and_filters(): assert st_bytes == (4*n*m*ell + 8*n*m)*n_subgroups # ignore stride and variable names in this map - reduced_map = mem_map.group_by("mtype", "dtype", "direction") - f32lall = reduced_map[lp.MemAccess("global", np.float32, direction="load") + reduced_map = mem_map.group_by("address_space", "dtype", "read_write") + f32lall = reduced_map[lp.MemAccess(address_space=AddressSpace.GLOBAL, + dtype=np.float32, + read_write=AccessDirection.READ) ].eval_with_dict(params) - f64lall = reduced_map[lp.MemAccess("global", np.float64, direction="load") + f64lall = reduced_map[lp.MemAccess(address_space=AddressSpace.GLOBAL, + dtype=np.float64, + read_write=AccessDirection.READ) ].eval_with_dict(params) # uniform: (count-per-sub-group)*n_subgroups @@ -1382,7 +1404,7 @@ def test_summations_and_filters(): assert f64 == n*m assert i32 == n*m*2 - addsub_all = op_map.filter_by(name=["add", "sub"]).eval_and_sum(params) + addsub_all = op_map.filter_by(op_type=[OpType.ADD]).eval_and_sum(params) f32ops_all = op_map.filter_by(dtype=[np.float32]).eval_and_sum(params) assert addsub_all == n*m*ell + n*m*2 assert f32ops_all == n*m*ell*3 @@ -1390,16 +1412,16 @@ def test_summations_and_filters(): non_field = op_map.filter_by(xxx=[np.float32]).eval_and_sum(params) assert non_field == 0 - ops_nodtype = op_map.group_by("name") + ops_nodtype = op_map.group_by("op_type") ops_noname = op_map.group_by("dtype") - mul_all = ops_nodtype[lp.Op(name="mul")].eval_with_dict(params) + mul_all = ops_nodtype[lp.Op(op_type=OpType.MUL)].eval_with_dict(params) f64ops_all = ops_noname[lp.Op(dtype=np.float64)].eval_with_dict(params) assert mul_all == n*m*ell + n*m assert f64ops_all == n*m def func_filter(key): return key.lid_strides == {} and key.dtype == to_loopy_type(np.float64) and \ - key.direction == 
"load" + key.read_write == AccessDirection.READ f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params) # uniform: (count-per-sub-group)*n_subgroups @@ -1423,7 +1445,7 @@ def test_strided_footprint(): knl = lp.split_iname(knl, "i_inner", bx, outer_tag="unr", inner_tag="l.0") footprints = lp.gather_access_footprints(knl) - x_l_foot = footprints[("x", "read")] + x_l_foot = footprints[lp.MemAccess(variable="x", read_write=AccessDirection.READ)] from loopy.statistics import count num = count(knl, x_l_foot).eval_with_dict(param_dict) @@ -1453,7 +1475,7 @@ def test_stats_on_callable_kernel(): op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True, count_within_subscripts=True) - f64_add = op_map.filter_by(name="add").eval_and_sum({}) + f64_add = op_map.filter_by(op_type=[OpType.ADD]).eval_and_sum({}) assert f64_add == 400 @@ -1479,7 +1501,7 @@ def test_stats_on_callable_kernel_within_loop(): op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True, count_within_subscripts=True) - f64_add = op_map.filter_by(name="add").eval_and_sum({}) + f64_add = op_map.filter_by(op_type=[OpType.ADD]).eval_and_sum({}) assert f64_add == 8000 @@ -1507,7 +1529,7 @@ def test_callable_kernel_with_substitution(): op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True, count_within_subscripts=True) - f64_add = op_map.filter_by(name="add").eval_and_sum({}) + f64_add = op_map.filter_by(op_type=[OpType.ADD]).eval_and_sum({}) assert f64_add == 8000 @@ -1525,12 +1547,85 @@ def test_no_loop_ops(): op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True, count_within_subscripts=True) - f64_add = op_map.filter_by(name="add").eval_and_sum({}) - f64_mul = op_map.filter_by(name="mul").eval_and_sum({}) + f64_add = op_map.filter_by(op_type=[OpType.ADD]).eval_and_sum({}) + f64_mul = op_map.filter_by(op_type=[OpType.MUL]).eval_and_sum({}) assert f64_add == 3 assert f64_mul == 1 +from pytools.tag import Tag + + +class 
MyCostTag1(Tag): + pass + + +class MyCostTag2(Tag): + pass + + +class MyCostTagSum(Tag): + pass + + +def test_op_taggedexpression(): + from pymbolic.primitives import Subscript, Sum, Variable + + from loopy.symbolic import TaggedExpression + + n = 500 + + knl = lp.make_kernel( + "{[i]: 0<=i