From 23e649eabe96dd6ab9c4b64f0e35a6ccc759886f Mon Sep 17 00:00:00 2001 From: Axel Hecht Date: Wed, 13 Feb 2019 13:13:45 +0100 Subject: [PATCH] Add Transformer API for in-place manipulation of AST (#99). Fixes #97 * Add perf test for something actually modifying the ast * Add Transformer API for in-place manipulation of AST, deprecate .traverse() * Align APIs of Visitor and Transformer --- fluent.syntax/fluent/syntax/ast.py | 71 ++++++++++++++++--- fluent.syntax/tests/syntax/test_equals.py | 6 ++ fluent.syntax/tests/syntax/test_visitor.py | 81 +++++++++++++++++++++- 3 files changed, 148 insertions(+), 10 deletions(-) diff --git a/fluent.syntax/fluent/syntax/ast.py b/fluent.syntax/fluent/syntax/ast.py index 60670c79..d31eeedf 100644 --- a/fluent.syntax/fluent/syntax/ast.py +++ b/fluent.syntax/fluent/syntax/ast.py @@ -13,14 +13,13 @@ class Visitor(object): The boolean value of the returned value determines if the visitor descends into the children of the given AST node. ''' - def visit(self, value): - if isinstance(value, BaseNode): - self.visit_node(value) - if isinstance(value, list): - for node in value: - self.visit(node) - - def visit_node(self, node): + def visit(self, node): + if isinstance(node, list): + for child in node: + self.visit(child) + return + if not isinstance(node, BaseNode): + return nodename = type(node).__name__ visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit) should_descend = visit(node) @@ -33,6 +32,41 @@ def generic_visit(self, node): return True +class Transformer(Visitor): + '''In-place AST Transformer pattern. + + Subclass this to create an in-place modified variant + of the given AST. + If you need to keep the original AST around, pass + a `node.clone()` to the transformer. + ''' + def visit(self, node): + if not isinstance(node, BaseNode): + return node + + nodename = type(node).__name__ + visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit) + return visit(node) + + def generic_visit(self, node): + for propname, propvalue in vars(node).items(): + if isinstance(propvalue, list): + new_vals = [] + for child in propvalue: + new_val = self.visit(child) + if new_val is not None: + new_vals.append(new_val) + # in-place manipulation + propvalue[:] = new_vals + elif isinstance(propvalue, BaseNode): + new_val = self.visit(propvalue) + if new_val is None: + delattr(node, propname) + else: + setattr(node, propname, new_val) + return node + + def to_json(value, fn=None): if isinstance(value, BaseNode): return value.to_json(fn) @@ -79,7 +113,9 @@ class BaseNode(object): """ def traverse(self, fun): - """Postorder-traverse this node and apply `fun` to all child nodes. + """DEPRECATED. Please use Visitor or Transformer. + + Postorder-traverse this node and apply `fun` to all child nodes. Traverse this node depth-first applying `fun` to subnodes and leaves. Children are processed before parents (postorder traversal). @@ -103,6 +139,23 @@ def visit(value): return fun(node) + def clone(self): + """Create a deep clone of the current node.""" + def visit(value): + """Clone node and its descendants.""" + if isinstance(value, BaseNode): + return value.clone() + if isinstance(value, list): + return [visit(child) for child in value] + if isinstance(value, tuple): + return tuple(visit(child) for child in value) + return value + + # Use all attributes found on the node as kwargs to the constructor. + return self.__class__( + **{name: visit(value) for name, value in vars(self).items()} + ) + def equals(self, other, ignored_fields=['span']): """Compare two nodes. diff --git a/fluent.syntax/tests/syntax/test_equals.py b/fluent.syntax/tests/syntax/test_equals.py index 823915f7..f70546ab 100644 --- a/fluent.syntax/tests/syntax/test_equals.py +++ b/fluent.syntax/tests/syntax/test_equals.py @@ -26,6 +26,7 @@ def test_same_simple_message(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) def test_same_selector_message(self): message1 = self.parse_ftl_entry("""\ @@ -41,6 +42,7 @@ def test_same_selector_message(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) def test_same_complex_placeable_message(self): message1 = self.parse_ftl_entry("""\ @@ -49,6 +51,7 @@ def test_same_complex_placeable_message(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) def test_same_message_with_attribute(self): message1 = self.parse_ftl_entry("""\ @@ -58,6 +61,7 @@ def test_same_message_with_attribute(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) def test_same_message_with_attributes(self): message1 = self.parse_ftl_entry("""\ @@ -68,6 +72,7 @@ def test_same_message_with_attributes(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) def test_same_junk(self): message1 = self.parse_ftl_entry("""\ @@ -76,6 +81,7 @@ def test_same_junk(self): self.assertTrue(message1.equals(message1)) self.assertTrue(message1.equals(message1.traverse(identity))) + self.assertTrue(message1.equals(message1.clone())) class TestOrderEquals(unittest.TestCase): diff --git a/fluent.syntax/tests/syntax/test_visitor.py b/fluent.syntax/tests/syntax/test_visitor.py index c9c313dc..285b67c0 100644 --- a/fluent.syntax/tests/syntax/test_visitor.py +++ b/fluent.syntax/tests/syntax/test_visitor.py @@ -50,6 +50,30 @@ def test_resource(self): ) +class TestTransformer(unittest.TestCase): + def test(self): + resource = FluentParser().parse(dedent_ftl('''\ + one = Message + two = Messages + three = Has a + .an = Message string in the Attribute + ''')) + prior_res_id = id(resource) + prior_msg_id = id(resource.body[1].value) + backup = resource.clone() + transformed = ReplaceTransformer('Message', 'Term').visit(resource) + self.assertEqual(prior_res_id, id(transformed)) + self.assertEqual( + prior_msg_id, + id(transformed.body[1].value) + ) + self.assertFalse(transformed.equals(backup)) + self.assertEqual( + transformed.body[1].value.elements[0].value, + 'Terms' + ) + + class WordCounter(object): def __init__(self): self.word_count = 0 @@ -70,6 +94,34 @@ def visit_TextElement(self, node): return False +class ReplaceText(object): + def __init__(self, before, after): + self.before = before + self.after = after + + def __call__(self, node): + """Perform find and replace on text values only""" + if type(node) == ast.TextElement: + node.value = node.value.replace(self.before, self.after) + return node + + +class ReplaceTransformer(ast.Transformer): + def __init__(self, before, after): + self.before = before + self.after = after + + def generic_visit(self, node): + if isinstance(node, (ast.Span, ast.Annotation)): + return node + return super(ReplaceTransformer, self).generic_visit(node) + + def visit_TextElement(self, node): + """Perform find and replace on text values only""" + node.value = node.value.replace(self.before, self.after) + return node + + class TestPerf(unittest.TestCase): def setUp(self): parser = FluentParser() @@ -89,6 +141,27 @@ def test_visitor(self): counter.visit(self.resource) self.assertEqual(counter.word_count, 277) + def test_edit_traverse(self): + edited = self.resource.traverse(ReplaceText('Tab', 'Reiter')) + self.assertEqual( + edited.body[4].attributes[0].value.elements[0].value, + 'New Reiter' + ) + + def test_edit_transform(self): + edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource) + self.assertEqual( + edited.body[4].attributes[0].value.elements[0].value, + 'New Reiter' + ) + + def test_edit_cloned(self): + edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource.clone()) + self.assertEqual( + edited.body[4].attributes[0].value.elements[0].value, + 'New Reiter' + ) + def gather_stats(method, repeat=10, number=50): t = timeit.Timer( @@ -107,7 +180,13 @@ def gather_stats(method, repeat=10, number=50): if __name__=='__main__': - for m in ('traverse', 'visitor'): + for m in ( + 'traverse', + 'visitor', + 'edit_traverse', + 'edit_transform', + 'edit_cloned', + ): results = gather_stats(m) try: import statistics