From 9233385ac604e6f3ebf53f8791db77b9d01144a9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 10 Oct 2022 23:44:10 -0500 Subject: [PATCH] add WithTag --- doc/tutorial.rst | 4 +-- loopy/statistics.py | 22 ++++++++++++---- loopy/symbolic.py | 46 ++++++++++++++++++++++++++++++++ test/test_statistics.py | 58 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 7 deletions(-) diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 8e65e4591..d93001faa 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1566,7 +1566,7 @@ information provided. Now we will count the operations: >>> op_map = lp.get_op_map(knl, subgroup_size=32) >>> print(op_map) - Op(np:dtype('float32'), add, subgroup, "stats_knl"): ... + Op(np:dtype('float32'), add, subgroup, "stats_knl", None): ... Each line of output will look roughly like:: @@ -1628,7 +1628,7 @@ together into keys containing only the specified fields: >>> op_map_dtype = op_map.group_by('dtype') >>> print(op_map_dtype) - Op(np:dtype('float32'), None, None): ... + Op(np:dtype('float32'), None, None, None): ... >>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32) ... ].eval_with_dict(param_dict) diff --git a/loopy/statistics.py b/loopy/statistics.py index bdcdb0878..31f48e431 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -636,10 +636,14 @@ class Op(ImmutableRecord): A :class:`str` representing the kernel name where the operation occurred. + .. attribute:: tags + + A :class:`frozenset` of tags to the operation. + """ def __init__(self, dtype=None, name=None, count_granularity=None, - kernel_name=None): + kernel_name=None, tags=None): if count_granularity not in CountGranularity.ALL+[None]: raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" @@ -651,15 +655,17 @@ def __init__(self, dtype=None, name=None, count_granularity=None, super().__init__(dtype=dtype, name=name, count_granularity=count_granularity, - kernel_name=kernel_name) + kernel_name=kernel_name, + tags=tags) def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness if self.kernel_name is not None: return (f"Op({self.dtype}, {self.name}, {self.count_granularity}," - f' "{self.kernel_name}")') + f' "{self.kernel_name}", {self.tags})') else: - return f"Op({self.dtype}, {self.name}, {self.count_granularity})" + return f"Op({self.dtype}, {self.name}, " + \ + f"{self.count_granularity}, {self.tags})" # }}} @@ -724,7 +730,7 @@ class MemAccess(ImmutableRecord): work-group executes on a single compute unit with all work-items within the work-group sharing local memory. A sub-group is an implementation-dependent grouping of work-items within a work-group, - analagous to an NVIDIA CUDA warp. + analogous to an NVIDIA CUDA warp. .. attribute:: kernel_name @@ -922,6 +928,12 @@ def map_constant(self, expr): map_tagged_variable = map_constant map_variable = map_constant + def map_with_tag(self, expr): + opmap = self.rec(expr.expr) + for op in opmap.count_map: + op.tags = expr.tags + return opmap + def map_call(self, expr): from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b6bd1d009..5343d8f75 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -114,6 +114,10 @@ # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin: + def map_with_tag(self, expr, *args, **kwargs): + new_expr = self.rec(expr.expr, *args, **kwargs) + return WithTag(expr.tags, new_expr) + def map_literal(self, expr, *args, **kwargs): return expr @@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr): class WalkMapperMixin: + def map_with_tag(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.expr, *args, **kwargs) + def map_literal(self, expr, *args, **kwargs): self.visit(expr, *args, **kwargs) @@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase): class CombineMapper(CombineMapperBase): + def map_with_tag(self, expr, *args, **kwargs): + return self.rec(expr.expr, *args, **kwargs) + def map_reduction(self, expr, *args, **kwargs): return self.rec(expr.expr, *args, **kwargs) @@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase, class StringifyMapper(StringifyMapperBase): + def map_with_tag(self, expr, *args): + from pymbolic.mapper.stringifier import PREC_NONE + return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)}" + def map_literal(self, expr, *args): return expr.s @@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs): def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() + def map_with_tag(self, expr, *args, **kwargs): + deps = self.rec(expr.expr, *args, **kwargs) + return deps + def map_sub_array_ref(self, expr, *args, **kwargs): deps = self.rec(expr.subscript, *args, **kwargs) return deps - set(expr.swept_inames) @@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None): mapper_method = intern("map_tagged_variable") +class WithTag(LoopyExpressionBase): + """ + Represents a frozenset of tags attached to an :attr:`expr`. + """ + + init_arg_names = ("tags", "expr") + + def __init__(self, tags, expr): + self.tags = tags + self.expr = expr + + def __getinitargs__(self): + return (self.tags, self.expr) + + def get_hash(self): + return hash((self.__class__, self.tags, self.expr)) + + def is_equal(self, other): + return (other.__class__ == self.__class__ + and other.tags == self.tags + and other.expr == self.expr) + + mapper_method = intern("map_with_tag") + + class Reduction(LoopyExpressionBase): """ Represents a reduction operation on :attr:`expr` across :attr:`inames`. diff --git a/test/test_statistics.py b/test/test_statistics.py index 4218067fa..504bfbf98 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1531,6 +1531,64 @@ def test_no_loop_ops(): assert f64_mul == 1 +from pytools.tag import Tag + + +class MyCostTag1(Tag): + pass + + +class MyCostTag2(Tag): + pass + + +class MyCostTagSum(Tag): + pass + + +def test_op_with_tag(): + from loopy.symbolic import WithTag + from pymbolic.primitives import Subscript, Variable, Sum + + n = 500 + + knl = lp.make_kernel( + "{[i]: 0<=i 1: exec(sys.argv[1])