diff --git a/cps/defers.nim b/cps/defers.nim index b2b1aed9..87475c17 100644 --- a/cps/defers.nim +++ b/cps/defers.nim @@ -1,96 +1,106 @@ -import cps/normalizedast +import cps/[normalizedast, rewrites] import std/macros except newStmtList template isNotNil*(x: untyped): bool = not(isNil(x)) -func hasDefer*(n: NormNode): bool = - ## Return whether there is a `defer` within the given node - ## that might cause it to be rewritten. - case n.kind - of nnkDefer: - true - of nnkStmtList, nnkStmtListExpr: +proc findTree(n: NormNode, traverseKinds: set[NimNodeKind], + cond: proc(n: NormNode): bool): NormNode = + ## Find the first node in the AST tree `n` satisfying `cond`. + ## + ## :traverseKinds: + ## The AST node kinds to traverse. + if cond(n): + result = n + elif n.kind in traverseKinds: for child in n.items: - if child.hasDefer: - return true - false - else: - false + result = findTree(child, traverseKinds, cond) + if result.isNotNil: + break -proc rewriteDefer*(n: NormNode): NormNode = - ## Rewrite the AST of `n` so that all `defer` nodes are - ## transformed into try-finally +proc splitStmtList(n, splitNode: NormNode): seq[NormNode] = + ## Split the StmtList `n` at `splitNode` to up to two splits. + ## + ## The same StmtList hierarchy will be shared on both splits. + if n == splitNode: + discard "The node to be split upon should not be in result" - # TODO: This could be made simpler + elif n.kind in {nnkStmtList, nnkStmtListExpr}: + result.add: copyNimNode(n) + for idx, child in n.pairs: + template listTail(): seq[NormNode] = + ## The remaining nodes in this list, excluding the current node + n[idx + 1 .. ^1] - proc splitDefer(n: NormNode): tuple[b, d, a: NormNode] = - ## Cut the AST into three parts: - ## - b: all nodes before the defer that could affect `n` - ## - d: the defer node itself - ## - a: nodes that follow and are affected by the defer - ## - ## If there are no defers in the AST, all nodes are left as-is in - ## `b`. - case n.kind - of nnkDefer: - # it's just a defer node; return it as such - result = (nil, n, nil) - of nnkStmtList, nnkStmtListExpr: - var d, b, a: NormNode - # Make a copy of our node to the part before defer - b = copyNimNode n - # The rest of the split stays in a new node of the same kind - a = NormNode newNimNode(n.kind, n) + let childSplits = splitStmtList(n[idx], splitNode) + if childSplits.len > 0: + # Merge the first split + result[0].add: childSplits[0] - # Look for the defer in the child nodes - for idx, child in n.pairs: - if child.hasDefer: - var xb, xa: NormNode - (xb, d, xa) = splitDefer child - if xb.isNotNil: - b.add xb - # Add nodes coming after the defer to the list of affected nodes - if xa.isNotNil: - a.add xa - if idx < n.len - 1: - a.add n[idx + 1 .. ^1] - # We are done here + # The inner StmtList has two splits + if childSplits.len > 1: + # Construct the other split + result.add: copyNimNode(n) + # Add the inner split + result[^1].add childSplits[1] + # Add the remaining nodes of this list + result[^1].add listTail() + # Done break - # If there's no defer in the child node, add as-is - b.add child - result = (b, d, a) - else: - # there's no defer, so yield the input as unaffected - result = (n, nil, nil) + else: + # There are no splits, thus this is the split node + # + # Construct the other split with the remaining nodes in this list + result.add: + copyNimNode(n).add: + listTail() + # Done + break + + else: + # If it's not a StmtList, just return as is + result.add n + +proc rewriteDefer*(n: NormNode): NormNode = + ## Rewrite the AST of `n` so that all `defer` nodes are + ## transformed into try-finally + proc rewriter(n: NormNode): NormNode = + let deferNode = + findTree(n, {nnkStmtList, nnkStmtListExpr}) do (n: NormNode) -> bool: + n.kind == nnkDefer - if n.hasDefer and n.kind != nnkDefer: - let (before, deferNode, affected) = splitDefer n - result = before + if deferNode.isNotNil: + let + splits = splitStmtList(n, deferNode) + # Construct a finally node with lineinfo of the defer node + finallyNode = newNimNode(nnkFinally, deferNode).add: + # Use the defer body as the finally body + deferNode.last - # Construct the try-finally statement - let tryStmt = newNimNode(nnkTryStmt) + if splits.len > 0: + # Add the first split, or "nodes before defer" + result = splits[0] - if affected.isNotNil: - # Wrap the affected body with the try statement - tryStmt.add affected - else: - # If this doesn't exist, use an empty StmtList - tryStmt.add newStmtList() - # Convert the defer node into a finally node - tryStmt.add: - newNimNode(nnkFinally, deferNode).add: - deferNode[0] + # Construct a try-finally with the remainder and add it to the end + result.add: + # Create a new try statement with the lineinfo of the second split + newNimNode(nnkTryStmt, splits[1]).add( + # Put the second split, or "nodes after defer", as the try body + splits[1], + finallyNode + ) + else: + # There are no splits, thus this is a defer without a container + # + # Construct a naked try-finally for it. + result = NormNode: + newNimNode(nnkTryStmt, deferNode).add( + # Use an empty statement list for the body + newNimNode(nnkStmtList, deferNode), + finallyNode + ) - result.add tryStmt - # Run the transform on the result to cover any - # nodes nested within this node - result = rewriteDefer result + # Also rewrite the result to eliminate all defers in it + result = rewriteDefer(result) - else: - # This node doesn't have any `defer` that will cause it to be rewritten - result = copyNimNode n - # Process its children instead - for child in n.items: - result.add: - rewriteDefer child + result = filter(n, rewriter)