Skip to content

Commit

Permalink
add TagCountMapper (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Jan 30, 2025
1 parent b73a556 commit 06b4295
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 5 deletions.
57 changes: 54 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Never

from orderedsets import FrozenOrderedSet
from typing_extensions import Self

import pytools
from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper

Expand All @@ -49,11 +49,11 @@
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, Mapper, P


if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Iterable, Mapping

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.loopy import LoopyCall
Expand All @@ -74,6 +74,9 @@
.. autofunction:: get_num_call_sites
.. autoclass:: DirectPredecessorsGetter
.. autoclass:: TagCountMapper
.. autofunction:: get_num_tags_of_type
"""


Expand Down Expand Up @@ -594,6 +597,54 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:
# }}}


# {{{ TagCountMapper

class TagCountMapper(CombineMapper[int, Never]):
"""
Returns the number of nodes in a DAG that are tagged with all the tags in *tags*.
"""

def __init__(self, tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> None:
super().__init__()
if isinstance(tags, pytools.tag.Tag):
tags = frozenset((tags,))
elif not isinstance(tags, frozenset):
tags = frozenset(tags)
self._tags = tags

def combine(self, *args: int) -> int:
return sum(args)

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> int:
key = self._cache.get_key(expr, *args, **kwargs)
try:
return self._cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
s = super().rec(expr, *args, **kwargs)
if isinstance(expr, Array) and self._tags <= expr.tags:
result = 1 + s
else:
result = 0 + s

self._cache.add((expr, args, kwargs),
0,
key=key)
return result


def get_num_tags_of_type(
outputs: Array | DictOfNamedArrays,
tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> int:
"""Returns the number of nodes in DAG *outputs* that are tagged with
all the tags in *tags*."""

tcm = TagCountMapper(tags)

return tcm(outputs)

# }}}


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder):
Expand Down
12 changes: 10 additions & 2 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,13 +1809,21 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
====== ======== =======
"""
from pytato.analysis import get_nusers
from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers
materializer = MPMSMaterializer(get_nusers(expr))
new_data = {}
for name, ary in expr.items():
new_data[name] = materializer(ary.expr).expr

return DictOfNamedArrays(new_data, tags=expr.tags)
res = DictOfNamedArrays(new_data, tags=expr.tags)

from pytato import DEBUG_ENABLED
if DEBUG_ENABLED:
transform_logger.info("materialize_with_mpms: materialized "
f"{get_num_tags_of_type(res, ImplStored())} out of "
f"{get_num_nodes(res)} nodes")

return res

# }}}

Expand Down
37 changes: 37 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,43 @@ def test_adv_indexing_into_zero_long_axes():
# }}}


def test_tagcountmapper():
from testlib import RandomDAGContext, make_random_dag

from pytools.tag import Tag

from pytato.analysis import get_num_nodes, get_num_tags_of_type

class NonExistentTag(Tag):
pass

class ExistentTag(Tag):
pass

seed = 199
axis_len = 3

rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
axis_len=axis_len, use_numpy=False)

out = make_random_dag(rdagc_pt).tagged(ExistentTag())

dag = pt.make_dict_of_named_arrays({"out": out})

# get_num_nodes() returns an extra DictOfNamedArrays node
assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)

assert get_num_tags_of_type(dag, NonExistentTag()) == 0
assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1
assert get_num_tags_of_type(dag,
frozenset((ExistentTag(), NonExistentTag()))) == 0

a = pt.make_data_wrapper(np.arange(27))
dag = a+a+a+a+a+a+a+a

assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)


def test_expand_dims_input_validate():
a = pt.make_placeholder("x", (10, 4))

Expand Down

0 comments on commit 06b4295

Please sign in to comment.