Skip to content

Commit

Permalink
rewrite defer transform (#267)
Browse files Browse the repository at this point in the history
Refactor the transform from:

  1. Check if StmtList has `defer`
  2. Recursively collect "before" and "after" as well as the defer
     itself
  3. Generate try-finally

to:

  1. Find `defer` in the StmtList
  2. Split StmtList into two at the location of the found defer
  3. Generate try-finally

This generalizes the helpers of the transform, making them reusable for
other transforms.
  • Loading branch information
alaviss authored Apr 11, 2022
1 parent 8e34538 commit 6cee9ba
Showing 1 changed file with 90 additions and 80 deletions.
170 changes: 90 additions & 80 deletions cps/defers.nim
Original file line number Diff line number Diff line change
@@ -1,96 +1,106 @@
import cps/normalizedast
import cps/[normalizedast, rewrites]
import std/macros except newStmtList

template isNotNil*(x: untyped): bool = not(isNil(x))

func hasDefer*(n: NormNode): bool =
## Return whether there is a `defer` within the given node
## that might cause it to be rewritten.
case n.kind
of nnkDefer:
true
of nnkStmtList, nnkStmtListExpr:
proc findTree(n: NormNode, traverseKinds: set[NimNodeKind],
cond: proc(n: NormNode): bool): NormNode =
## Find the first node in the AST tree `n` satisfying `cond`.
##
## :traverseKinds:
## The AST node kinds to traverse.
if cond(n):
result = n
elif n.kind in traverseKinds:
for child in n.items:
if child.hasDefer:
return true
false
else:
false
result = findTree(child, traverseKinds, cond)
if result.isNotNil:
break

proc rewriteDefer*(n: NormNode): NormNode =
## Rewrite the AST of `n` so that all `defer` nodes are
## transformed into try-finally
proc splitStmtList(n, splitNode: NormNode): seq[NormNode] =
## Split the StmtList `n` at `splitNode` to up to two splits.
##
## The same StmtList hierarchy will be shared on both splits.
if n == splitNode:
discard "The node to be split upon should not be in result"

# TODO: This could be made simpler
elif n.kind in {nnkStmtList, nnkStmtListExpr}:
result.add: copyNimNode(n)
for idx, child in n.pairs:
template listTail(): seq[NormNode] =
## The remaining nodes in this list, excluding the current node
n[idx + 1 .. ^1]

proc splitDefer(n: NormNode): tuple[b, d, a: NormNode] =
## Cut the AST into three parts:
## - b: all nodes before the defer that could affect `n`
## - d: the defer node itself
## - a: nodes that follow and are affected by the defer
##
## If there are no defers in the AST, all nodes are left as-is in
## `b`.
case n.kind
of nnkDefer:
# it's just a defer node; return it as such
result = (nil, n, nil)
of nnkStmtList, nnkStmtListExpr:
var d, b, a: NormNode
# Make a copy of our node to the part before defer
b = copyNimNode n
# The rest of the split stays in a new node of the same kind
a = NormNode newNimNode(n.kind, n)
let childSplits = splitStmtList(n[idx], splitNode)
if childSplits.len > 0:
# Merge the first split
result[0].add: childSplits[0]

# Look for the defer in the child nodes
for idx, child in n.pairs:
if child.hasDefer:
var xb, xa: NormNode
(xb, d, xa) = splitDefer child
if xb.isNotNil:
b.add xb
# Add nodes coming after the defer to the list of affected nodes
if xa.isNotNil:
a.add xa
if idx < n.len - 1:
a.add n[idx + 1 .. ^1]
# We are done here
# The inner StmtList has two splits
if childSplits.len > 1:
# Construct the other split
result.add: copyNimNode(n)
# Add the inner split
result[^1].add childSplits[1]
# Add the remaining nodes of this list
result[^1].add listTail()
# Done
break

# If there's no defer in the child node, add as-is
b.add child
result = (b, d, a)
else:
# there's no defer, so yield the input as unaffected
result = (n, nil, nil)
else:
# There are no splits, thus this is the split node
#
# Construct the other split with the remaining nodes in this list
result.add:
copyNimNode(n).add:
listTail()
# Done
break

else:
# If it's not a StmtList, just return as is
result.add n

proc rewriteDefer*(n: NormNode): NormNode =
## Rewrite the AST of `n` so that all `defer` nodes are
## transformed into try-finally
proc rewriter(n: NormNode): NormNode =
let deferNode =
findTree(n, {nnkStmtList, nnkStmtListExpr}) do (n: NormNode) -> bool:
n.kind == nnkDefer

if n.hasDefer and n.kind != nnkDefer:
let (before, deferNode, affected) = splitDefer n
result = before
if deferNode.isNotNil:
let
splits = splitStmtList(n, deferNode)
# Construct a finally node with lineinfo of the defer node
finallyNode = newNimNode(nnkFinally, deferNode).add:
# Use the defer body as the finally body
deferNode.last

# Construct the try-finally statement
let tryStmt = newNimNode(nnkTryStmt)
if splits.len > 0:
# Add the first split, or "nodes before defer"
result = splits[0]

if affected.isNotNil:
# Wrap the affected body with the try statement
tryStmt.add affected
else:
# If this doesn't exist, use an empty StmtList
tryStmt.add newStmtList()
# Convert the defer node into a finally node
tryStmt.add:
newNimNode(nnkFinally, deferNode).add:
deferNode[0]
# Construct a try-finally with the remainder and add it to the end
result.add:
# Create a new try statement with the lineinfo of the second split
newNimNode(nnkTryStmt, splits[1]).add(
# Put the second split, or "nodes after defer", as the try body
splits[1],
finallyNode
)
else:
# There are no splits, thus this is a defer without a container
#
# Construct a naked try-finally for it.
result = NormNode:
newNimNode(nnkTryStmt, deferNode).add(
# Use an empty statement list for the body
newNimNode(nnkStmtList, deferNode),
finallyNode
)

result.add tryStmt
# Run the transform on the result to cover any
# nodes nested within this node
result = rewriteDefer result
# Also rewrite the result to eliminate all defers in it
result = rewriteDefer(result)

else:
# This node doesn't have any `defer` that will cause it to be rewritten
result = copyNimNode n
# Process its children instead
for child in n.items:
result.add:
rewriteDefer child
result = filter(n, rewriter)

0 comments on commit 6cee9ba

Please sign in to comment.