Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[eve]: Preserve annex in custom visitors #1874

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
return None


def _preserve_annex(
node: concepts.Node, new_node: concepts.Node, preserved_annex_attrs: tuple[str, ...]
) -> None:
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
old_annex = node.annex
new_annex_dict = new_node.annex.__dict__
for key in preserved_annex_attrs:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
# Note: The annex value of the new node might not be equal
# (in the sense that an equality comparison returns false),
# but in the context of the pass, they are equivalent.
# Therefore, we don't assert equality here.
new_annex_dict[key] = value


class NodeTranslator(NodeVisitor):
"""Special `NodeVisitor` to translate nodes and trees.

Expand Down Expand Up @@ -158,13 +173,8 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
if (new_child := self.visit(child, **kwargs)) is not NOTHING
}
)
if self.PRESERVED_ANNEX_ATTRS and (old_annex := getattr(node, "__node_annex__", None)):
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
new_annex_dict = new_node.annex.__dict__
for key in self.PRESERVED_ANNEX_ATTRS:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
assert key not in new_annex_dict
new_annex_dict[key] = value
if self.PRESERVED_ANNEX_ATTRS and getattr(node, "__node_annex__", None):
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node

Expand All @@ -189,3 +199,16 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
)

return copy.deepcopy(node, memo=memo)

def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
new_node = super().visit(node, **kwargs)

if (
isinstance(node, concepts.Node)
and isinstance(new_node, concepts.Node)
and self.PRESERVED_ANNEX_ATTRS
and getattr(node, "__node_annex__", None)
):
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node
8 changes: 6 additions & 2 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
domain_utils,
ir_makers as im,
)
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms import constant_folding, trace_shifts
from gt4py.next.utils import flatten_nested_tuple, tree_map


Expand Down Expand Up @@ -436,8 +436,12 @@ def _infer_stmt(
**kwargs: Unpack[InferenceOptions],
):
if isinstance(stmt, itir.SetAt):
# constant fold once otherwise constant folding after domain inference might create (syntactic) differences
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this change really belong to this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially checked that the new annex attribute is equal to the old one. That would be a desirable property of the IR and this change here is a step in that direction. This also motivated the comment above

# Note: The annex value of the new node might not be equal
# (in the sense that an equality comparison returns false),
# but in the context of the pass, they are equivalent.
# Therefore, we don't assert equality here.

However there are annex values that are equivalent, but not equal e.g. the domain annex with a value of
(domain, domain) (tuple of domains) is equivalent to domain. I therefore gave up on the assert, but kept this change here. Does that make sense?

# between the domain stored in IR and in the annex
domain = constant_folding.ConstantFolding.apply(stmt.domain)

transformed_call, _ = infer_expr(
stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs
stmt.expr, domain_utils.SymbolicDomain.from_expr(domain), **kwargs
)

return itir.SetAt(
Expand Down
17 changes: 17 additions & 0 deletions tests/eve_tests/unit_tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,20 @@ class SampleTranslator(eve.NodeTranslator):
assert translated_node.annex.foo == 1
assert translated_node.annex.bar is None
assert not hasattr(translated_node.annex, "baz")


def test_annex_preservation_translated_node(compound_node: eve.Node):
compound_node.annex.foo = 1
compound_node.annex.baz = 2

class SampleTranslator(eve.NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("foo",)

def visit_Node(self, node: eve.Node):
# just return an empty node, we care about the annex only anyway
return eve.Node()

translated_node = SampleTranslator().visit(compound_node)

assert translated_node.annex.foo == 1
assert not hasattr(translated_node.annex, "baz")