Skip to content

Commit

Permalink
Add Transformer API for in-place manipulation of AST (#99). Fixes #97
Browse files Browse the repository at this point in the history
* Add perf test for something actually modifying the ast
* Add Transformer API for in-place manipulation of AST, deprecate .traverse()
* Align APIs of Visitor and Transformer
  • Loading branch information
Pike authored Feb 13, 2019
1 parent f3f9053 commit 23e649e
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 10 deletions.
71 changes: 62 additions & 9 deletions fluent.syntax/fluent/syntax/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ class Visitor(object):
The boolean value of the returned value determines if the visitor
descends into the children of the given AST node.
'''
def visit(self, value):
if isinstance(value, BaseNode):
self.visit_node(value)
if isinstance(value, list):
for node in value:
self.visit(node)

def visit_node(self, node):
def visit(self, node):
if isinstance(node, list):
for child in node:
self.visit(child)
return
if not isinstance(node, BaseNode):
return
nodename = type(node).__name__
visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
should_descend = visit(node)
Expand All @@ -33,6 +32,41 @@ def generic_visit(self, node):
return True


class Transformer(Visitor):
'''In-place AST Transformer pattern.
Subclass this to create an in-place modified variant
of the given AST.
If you need to keep the original AST around, pass
a `node.clone()` to the transformer.
'''
def visit(self, node):
if not isinstance(node, BaseNode):
return node

nodename = type(node).__name__
visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
return visit(node)

def generic_visit(self, node):
for propname, propvalue in vars(node).items():
if isinstance(propvalue, list):
new_vals = []
for child in propvalue:
new_val = self.visit(child)
if new_val is not None:
new_vals.append(new_val)
# in-place manipulation
propvalue[:] = new_vals
elif isinstance(propvalue, BaseNode):
new_val = self.visit(propvalue)
if new_val is None:
delattr(node, propname)
else:
setattr(node, propname, new_val)
return node


def to_json(value, fn=None):
if isinstance(value, BaseNode):
return value.to_json(fn)
Expand Down Expand Up @@ -79,7 +113,9 @@ class BaseNode(object):
"""

def traverse(self, fun):
"""Postorder-traverse this node and apply `fun` to all child nodes.
"""DEPRECATED. Please use Visitor or Transformer.
Postorder-traverse this node and apply `fun` to all child nodes.
Traverse this node depth-first applying `fun` to subnodes and leaves.
Children are processed before parents (postorder traversal).
Expand All @@ -103,6 +139,23 @@ def visit(value):

return fun(node)

def clone(self):
"""Create a deep clone of the current node."""
def visit(value):
"""Clone node and its descendants."""
if isinstance(value, BaseNode):
return value.clone()
if isinstance(value, list):
return [visit(child) for child in value]
if isinstance(value, tuple):
return tuple(visit(child) for child in value)
return value

# Use all attributes found on the node as kwargs to the constructor.
return self.__class__(
**{name: visit(value) for name, value in vars(self).items()}
)

def equals(self, other, ignored_fields=['span']):
"""Compare two nodes.
Expand Down
6 changes: 6 additions & 0 deletions fluent.syntax/tests/syntax/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_same_simple_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_selector_message(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -41,6 +42,7 @@ def test_same_selector_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_complex_placeable_message(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -49,6 +51,7 @@ def test_same_complex_placeable_message(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_message_with_attribute(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -58,6 +61,7 @@ def test_same_message_with_attribute(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_message_with_attributes(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -68,6 +72,7 @@ def test_same_message_with_attributes(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))

def test_same_junk(self):
message1 = self.parse_ftl_entry("""\
Expand All @@ -76,6 +81,7 @@ def test_same_junk(self):

self.assertTrue(message1.equals(message1))
self.assertTrue(message1.equals(message1.traverse(identity)))
self.assertTrue(message1.equals(message1.clone()))


class TestOrderEquals(unittest.TestCase):
Expand Down
81 changes: 80 additions & 1 deletion fluent.syntax/tests/syntax/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ def test_resource(self):
)


class TestTransformer(unittest.TestCase):
def test(self):
resource = FluentParser().parse(dedent_ftl('''\
one = Message
two = Messages
three = Has a
.an = Message string in the Attribute
'''))
prior_res_id = id(resource)
prior_msg_id = id(resource.body[1].value)
backup = resource.clone()
transformed = ReplaceTransformer('Message', 'Term').visit(resource)
self.assertEqual(prior_res_id, id(transformed))
self.assertEqual(
prior_msg_id,
id(transformed.body[1].value)
)
self.assertFalse(transformed.equals(backup))
self.assertEqual(
transformed.body[1].value.elements[0].value,
'Terms'
)


class WordCounter(object):
def __init__(self):
self.word_count = 0
Expand All @@ -70,6 +94,34 @@ def visit_TextElement(self, node):
return False


class ReplaceText(object):
def __init__(self, before, after):
self.before = before
self.after = after

def __call__(self, node):
"""Perform find and replace on text values only"""
if type(node) == ast.TextElement:
node.value = node.value.replace(self.before, self.after)
return node


class ReplaceTransformer(ast.Transformer):
def __init__(self, before, after):
self.before = before
self.after = after

def generic_visit(self, node):
if isinstance(node, (ast.Span, ast.Annotation)):
return node
return super(ReplaceTransformer, self).generic_visit(node)

def visit_TextElement(self, node):
"""Perform find and replace on text values only"""
node.value = node.value.replace(self.before, self.after)
return node


class TestPerf(unittest.TestCase):
def setUp(self):
parser = FluentParser()
Expand All @@ -89,6 +141,27 @@ def test_visitor(self):
counter.visit(self.resource)
self.assertEqual(counter.word_count, 277)

def test_edit_traverse(self):
edited = self.resource.traverse(ReplaceText('Tab', 'Reiter'))
self.assertEqual(
edited.body[4].attributes[0].value.elements[0].value,
'New Reiter'
)

def test_edit_transform(self):
edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource)
self.assertEqual(
edited.body[4].attributes[0].value.elements[0].value,
'New Reiter'
)

def test_edit_cloned(self):
edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource.clone())
self.assertEqual(
edited.body[4].attributes[0].value.elements[0].value,
'New Reiter'
)


def gather_stats(method, repeat=10, number=50):
t = timeit.Timer(
Expand All @@ -107,7 +180,13 @@ def gather_stats(method, repeat=10, number=50):


if __name__=='__main__':
for m in ('traverse', 'visitor'):
for m in (
'traverse',
'visitor',
'edit_traverse',
'edit_transform',
'edit_cloned',
):
results = gather_stats(m)
try:
import statistics
Expand Down

0 comments on commit 23e649e

Please sign in to comment.