Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TagCountMapper #326

Merged
merged 27 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
68eb741
add TagCountMapper
matthiasdiener May 23, 2022
7949b19
clarify doc
matthiasdiener May 23, 2022
4a578d8
small code cleanups
matthiasdiener May 23, 2022
c6a1781
Merge branch 'main' into count-tags
matthiasdiener May 23, 2022
9a79943
lint fixes
matthiasdiener May 23, 2022
80dbca5
use a CombineMapper
matthiasdiener May 24, 2022
de25c6a
Merge branch 'main' into count-tags
matthiasdiener May 24, 2022
9ca88bc
simplify sum
matthiasdiener May 24, 2022
fbdfd6e
add another test
matthiasdiener May 24, 2022
29281d9
set cache to zero
matthiasdiener May 24, 2022
d01e839
Merge branch 'main' into count-tags
matthiasdiener May 25, 2022
33d7d1e
remove normalize_outputs
matthiasdiener May 25, 2022
68f159e
fix tests
matthiasdiener May 25, 2022
e50ca77
Merge branch 'main' into count-tags
matthiasdiener Jun 1, 2022
dcc2a21
materialize_with_mpms: log number of materialized nodes (#327)
matthiasdiener Jun 6, 2022
0fefda9
Merge branch 'main' into count-tags
matthiasdiener Jun 20, 2022
fb7c6be
add foldmethod
matthiasdiener Jun 20, 2022
33c1611
simplify rec()
matthiasdiener Jun 20, 2022
651539b
Merge branch 'main' into count-tags
matthiasdiener Jun 23, 2022
b0f0b96
Merge branch 'main' into count-tags
matthiasdiener Aug 1, 2022
f8e909a
Merge branch 'main' into count-tags
matthiasdiener Oct 5, 2022
18498dd
Merge branch 'main' into count-tags
matthiasdiener Jan 24, 2025
3d98036
merge errors
matthiasdiener Jan 24, 2025
2fcf93f
Merge branch 'main' into count-tags
matthiasdiener Jan 30, 2025
aad73a0
fix errors
matthiasdiener Jan 30, 2025
9f9f320
fix doc
matthiasdiener Jan 30, 2025
a37a037
Merge branch 'main' into count-tags
matthiasdiener Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Never

from orderedsets import FrozenOrderedSet
from typing_extensions import Self

import pytools
from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper

Expand All @@ -49,11 +49,11 @@
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, Mapper, P


if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Iterable, Mapping

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.loopy import LoopyCall
Expand All @@ -74,6 +74,9 @@
.. autofunction:: get_num_call_sites

.. autoclass:: DirectPredecessorsGetter

.. autoclass:: TagCountMapper
.. autofunction:: get_num_tags_of_type
"""


Expand Down Expand Up @@ -594,6 +597,54 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:
# }}}


# {{{ TagCountMapper

class TagCountMapper(CombineMapper[int, Never]):
"""
Returns the number of nodes in a DAG that are tagged with all the tags in *tags*.
"""

def __init__(self, tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> None:
super().__init__()
if isinstance(tags, pytools.tag.Tag):
tags = frozenset((tags,))
elif not isinstance(tags, frozenset):
tags = frozenset(tags)
self._tags = tags

def combine(self, *args: int) -> int:
return sum(args)

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> int:
key = self._cache.get_key(expr, *args, **kwargs)
try:
return self._cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
s = super().rec(expr, *args, **kwargs)
if isinstance(expr, Array) and self._tags <= expr.tags:
result = 1 + s
else:
result = 0 + s

self._cache.add((expr, args, kwargs),
0,
key=key)
return result


def get_num_tags_of_type(
outputs: Array | DictOfNamedArrays,
tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> int:
"""Returns the number of nodes in DAG *outputs* that are tagged with
all the tags in *tags*."""

tcm = TagCountMapper(tags)

return tcm(outputs)

# }}}
matthiasdiener marked this conversation as resolved.
Show resolved Hide resolved


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder):
Expand Down
12 changes: 10 additions & 2 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,13 +1809,21 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
====== ======== =======

"""
from pytato.analysis import get_nusers
from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers
materializer = MPMSMaterializer(get_nusers(expr))
new_data = {}
for name, ary in expr.items():
new_data[name] = materializer(ary.expr).expr

return DictOfNamedArrays(new_data, tags=expr.tags)
res = DictOfNamedArrays(new_data, tags=expr.tags)

from pytato import DEBUG_ENABLED
if DEBUG_ENABLED:
transform_logger.info("materialize_with_mpms: materialized "
f"{get_num_tags_of_type(res, ImplStored())} out of "
f"{get_num_nodes(res)} nodes")

return res

# }}}

Expand Down
37 changes: 37 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,43 @@ def test_adv_indexing_into_zero_long_axes():
# }}}


def test_tagcountmapper():
from testlib import RandomDAGContext, make_random_dag

from pytools.tag import Tag

from pytato.analysis import get_num_nodes, get_num_tags_of_type

class NonExistentTag(Tag):
pass

class ExistentTag(Tag):
pass

seed = 199
axis_len = 3

rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
axis_len=axis_len, use_numpy=False)

out = make_random_dag(rdagc_pt).tagged(ExistentTag())

dag = pt.make_dict_of_named_arrays({"out": out})

# get_num_nodes() returns an extra DictOfNamedArrays node
assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)

assert get_num_tags_of_type(dag, NonExistentTag()) == 0
assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1
assert get_num_tags_of_type(dag,
frozenset((ExistentTag(), NonExistentTag()))) == 0

a = pt.make_data_wrapper(np.arange(27))
dag = a+a+a+a+a+a+a+a

assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)


def test_expand_dims_input_validate():
a = pt.make_placeholder("x", (10, 4))

Expand Down
Loading