Skip to content

Commit

Permalink
Implement closures
Browse files Browse the repository at this point in the history
  • Loading branch information
pmatos authored Jun 15, 2023
1 parent e2bb4f5 commit 050ba79
Show file tree
Hide file tree
Showing 13 changed files with 335 additions and 28 deletions.
42 changes: 42 additions & 0 deletions src/ASTRuntime.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "ASTRuntime.h"

#include "AnalysisFreeVars.h"

#include <utility>

using namespace ast;

Closure::Closure(const Lambda &Lbd, const std::vector<Environment> &Envs)
: ClonableNode(ASTNodeKind::AST_Closure),
L(std::unique_ptr<Lambda>(static_cast<Lambda *>(Lbd.clone()))) {

// To create a closure we need to:

// 1. Find the free variables in the lambda.
AnalysisFreeVars AFV;
L->accept(AFV);
auto const &FreeVars = AFV.getResult();

// 2. Find in the current environment, the values of the free variables
// and save them.
for (auto const &Var : FreeVars) {
for (auto const &E : llvm::reverse(Envs)) {
auto const &Val = E.lookup(Var);
if (Val) {
Env.add(Var, std::unique_ptr<ValueNode>(Val->clone()));
break;
}
}
}
}

Closure::Closure(const Closure &Other)
: ClonableNode(ASTNodeKind::AST_Closure),
L(std::unique_ptr<Lambda>(static_cast<Lambda *>(Other.L->clone()))) {
for (auto const &E : Other.Env) {
Env.add(E.first, std::unique_ptr<ValueNode>(E.second->clone()));
}
}

void Closure::dump() const {}
void Closure::write() const {}
134 changes: 134 additions & 0 deletions src/AnalysisFreeVars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#include "AnalysisFreeVars.h"

void AnalysisFreeVars::visit(ast::Identifier const &Id) {
// If the identifier is not in the environment, then it is a free variable.

for (auto const &Var : llvm::reverse(Vars)) {
if (Var.count(Id) == 0) {
Result.insert(Id);
}
}
}

void AnalysisFreeVars::visit(ast::Integer const &Int) {
// Integers do not have free variables.
// Nothing to do.
}

void AnalysisFreeVars::visit(ast::Linklet const &Linklet) {
llvm::errs() << "Free variable analysis only applies to expressions.\n";
}

void AnalysisFreeVars::visit(ast::DefineValues const &DV) {
llvm::errs() << "Free variable analysis only applies to expressions.\n";
}

void AnalysisFreeVars::visit(ast::Values const &V) {
// Need to check for free variable in each expression of the Values
// expression.
for (auto const &Expr : V.getExprs()) {
Expr->accept(*this);
}
}

void AnalysisFreeVars::visit(ast::Void const &Vd) {
// Void expressions have no free variables.
// Nothing to do.
}

void AnalysisFreeVars::visit(ast::Lambda const &L) {
const ast::Formal &F = L.getFormals();
std::set<ast::Identifier> FormalVars;

if (F.getType() == ast::Formal::Type::Identifier) {
auto IF = static_cast<const ast::IdentifierFormal &>(F);
FormalVars.insert(IF.getIdentifier());
} else if (F.getType() == ast::Formal::Type::List) {
auto LF = static_cast<const ast::ListFormal &>(F);
for (auto const &Id : LF.getIds()) {
FormalVars.insert(Id);
}
} else if (F.getType() == ast::Formal::Type::ListRest) {
auto LRF = static_cast<const ast::ListRestFormal &>(F);
for (auto const &Id : LRF.getIds()) {
FormalVars.insert(Id);
}
FormalVars.insert(LRF.getRestFormal());
}

// Save the current environment.
Vars.push_back(FormalVars);

// Check for free variables in the body of the lambda.
L.getBody().accept(*this);

// Restore the environment.
Vars.pop_back();
}

void AnalysisFreeVars::visit(ast::Closure const &L) {
// Closures by definition do not have free variables.
// Nothing to do.
}

void AnalysisFreeVars::visit(ast::Begin const &B) {
// Iterate through all the begin expressions and check for free variables.
for (auto const &Expr : B.getBody()) {
Expr->accept(*this);
}
}

void AnalysisFreeVars::visit(ast::List const &L) {
// Iterate through all the List expressions and check for free variables.
for (auto const &Expr : L.values()) {
Expr->accept(*this);
}
}

void AnalysisFreeVars::visit(ast::Application const &A) {
// Iterate through all the Application expressions and check for free
// variables.
for (auto const &Expr : A.getExprs()) {
Expr->accept(*this);
}
}

void AnalysisFreeVars::visit(ast::SetBang const &SB) {
// Check for free variables on the right hand side expression of SetBang
// expression.
SB.getExpr().accept(*this);
}

void AnalysisFreeVars::visit(ast::IfCond const &If) {
// Check for free variables on the condition expression of IfCond expression.
If.getCond().accept(*this);
// Check for free variables on the consequent expression of IfCond expression.
If.getThen().accept(*this);
// Check for free variables on the alternative expression of IfCond
// expression.
If.getElse().accept(*this);
}

void AnalysisFreeVars::visit(ast::BooleanLiteral const &Bool) {
// Boolean literals have no free variables.
// Nothing to do.
}

void AnalysisFreeVars::visit(ast::LetValues const &LV) {
std::set<ast::Identifier> LVVars;
for (size_t Idx = 0; Idx < LV.bindingCount(); Idx++)
for (auto const &Var : LV.getBindingIds(Idx))
LVVars.insert(Var);

Vars.push_back(LVVars);

for (size_t Idx = 0; Idx < LV.bodyCount(); Idx++)
LV.getBodyExpr(Idx).accept(*this);

Vars.pop_back();
}

void AnalysisFreeVars::visit(ast::RuntimeFunction const &LV) {
// Runtime Functions have no free variables.
// Nothing to do.
}
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ add_llvm_executable(norac
main.cpp
environment.cpp
ast.cpp
ASTRuntime.cpp
AnalysisFreeVars.cpp
idpool.cpp
Parse.cpp
Lex.cpp
Expand Down
36 changes: 36 additions & 0 deletions src/include/ASTRuntime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include "ast.h"
#include "environment.h"

#include <memory>

namespace ast {
//
// This file includes the structures that are used in addition to
// those in ast.h during runtime interpretation.
//
// The simplest example is the Closure.

// A Closure is a runtime manifestation of a Lambda.
class Closure : public ClonableNode<Closure, ValueNode> {
public:
Closure(const Lambda &Lbd, const std::vector<Environment> &Envs);
Closure(const Closure &Other);

static bool classof(const ASTNode *N) {
return N->getKind() == ASTNodeKind::AST_Closure;
}

void dump() const override;
void write() const override;

const Lambda &getLambda() const { return *L; }
const Environment &getEnvironment() const { return Env; }

private:
std::unique_ptr<Lambda> L;
Environment Env;
};

}; // namespace ast
1 change: 1 addition & 0 deletions src/include/ASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ASTVisitor {
virtual void visit(ast::Values const &V) = 0;
virtual void visit(ast::Void const &Vd) = 0;
virtual void visit(ast::Lambda const &L) = 0;
virtual void visit(ast::Closure const &L) = 0;
virtual void visit(ast::Begin const &B) = 0;
virtual void visit(ast::List const &L) = 0;
virtual void visit(ast::Application const &A) = 0;
Expand Down
42 changes: 42 additions & 0 deletions src/include/AnalysisFreeVars.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include "ASTVisitor.h"
#include "ast.h"

#include <llvm/ADT/STLExtras.h>

#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <vector>

// File implementing free variable analysis for expressions.

class AnalysisFreeVars : public ASTVisitor {
public:
virtual void visit(ast::Identifier const &Id) override;
virtual void visit(ast::Integer const &Int) override;
virtual void visit(ast::Linklet const &Linklet) override;
virtual void visit(ast::DefineValues const &DV) override;
virtual void visit(ast::Values const &V) override;
virtual void visit(ast::Void const &Vd) override;
virtual void visit(ast::Lambda const &L) override;
virtual void visit(ast::Closure const &L) override;
virtual void visit(ast::Begin const &B) override;
virtual void visit(ast::List const &L) override;
virtual void visit(ast::Application const &A) override;
virtual void visit(ast::SetBang const &SB) override;
virtual void visit(ast::IfCond const &If) override;
virtual void visit(ast::BooleanLiteral const &Bool) override;
virtual void visit(ast::LetValues const &LV) override;
virtual void visit(ast::RuntimeFunction const &LV) override;

// Get the current saved result.
std::set<ast::Identifier> getResult() const { return Result; };

private:
std::set<ast::Identifier> Result; /// List of free variables.
llvm::SmallVector<std::set<ast::Identifier>>
Vars; /// Environment map for identifiers.
};
30 changes: 24 additions & 6 deletions src/include/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <gmp.h>
#include <iostream>
#include <memory>
#include <ranges>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -42,6 +43,7 @@ class ASTNode {
AST_BooleanLiteral,
AST_Integer,
AST_Lambda,
AST_Closure, // result of evaluating a Lambda expression
AST_List,
AST_Values,
AST_Void,
Expand Down Expand Up @@ -206,6 +208,7 @@ class Linklet : public ClonableNode<Linklet, ASTNode> {
class Application : public ClonableNode<Application, ExprNode> {
public:
Application() : ClonableNode(ASTNodeKind::AST_Application) {}
Application &operator=(Application &&) = delete;
Application(const Application &);
Application(Application &&) = default;
Application &operator=(const Application &) = delete;
Expand All @@ -221,8 +224,12 @@ class Application : public ClonableNode<Application, ExprNode> {
return N->getKind() == ASTNodeKind::AST_Application;
}

const llvm::SmallVector<std::unique_ptr<ExprNode>> &getExprs() const {
return Exprs;
}

private:
std::vector<std::unique_ptr<ExprNode>> Exprs;
llvm::SmallVector<std::unique_ptr<ExprNode>> Exprs;
};

// AST Node representing a begin or begin0 expression.
Expand All @@ -234,7 +241,8 @@ class Begin : public ClonableNode<Begin, ExprNode> {
Begin &operator=(const Begin &B) = delete;
~Begin() = default;

[[nodiscard]] const std::vector<std::unique_ptr<ExprNode>> &getBody() const {
[[nodiscard]] const llvm::SmallVector<std::unique_ptr<ExprNode>> &
getBody() const {
return Body;
}
[[nodiscard]] size_t bodyCount() const { return Body.size(); }
Expand All @@ -249,7 +257,7 @@ class Begin : public ClonableNode<Begin, ExprNode> {
}

private:
std::vector<std::unique_ptr<ExprNode>> Body;
llvm::SmallVector<std::unique_ptr<ExprNode>> Body;
bool Zero = false;
};

Expand Down Expand Up @@ -495,6 +503,8 @@ class Lambda : public ClonableNode<Lambda, ValueNode> {
void dump() const override;
void write() const override;

llvm::SmallVector<Identifier> findFreeVariables() const;

static bool classof(const ASTNode *N) {
return N->getKind() == ASTNodeKind::AST_Lambda;
}
Expand All @@ -507,6 +517,8 @@ class Lambda : public ClonableNode<Lambda, ValueNode> {
class LetValues : public ClonableNode<LetValues, ExprNode> {
public:
LetValues() : ClonableNode(ASTNodeKind::AST_LetValues) {}
LetValues &operator=(const LetValues &) = delete;
LetValues &operator=(LetValues &&) = delete;
LetValues(const LetValues &DV);
LetValues(LetValues &&DV) = default;
~LetValues() = default;
Expand Down Expand Up @@ -575,8 +587,10 @@ class List : public ClonableNode<List, ValueNode> {
return N->getKind() == ASTNodeKind::AST_List;
}

auto const &values() const { return Values; }

private:
std::vector<std::unique_ptr<ast::ValueNode>> Values;
llvm::SmallVector<std::unique_ptr<ast::ValueNode>> Values;
};

class SetBang : public ClonableNode<SetBang, ExprNode> {
Expand Down Expand Up @@ -671,8 +685,12 @@ class RuntimeFunction : public ValueNode {
virtual std::unique_ptr<ast::ValueNode>
operator()(const std::vector<const ast::ValueNode *> &Args) const = 0;

void dump() const override { std::cerr << "#<runtime:" << getName() << ">"; }
void write() const override { std::cout << "#<runtime:" << getName() << ">"; }
void dump() const override {
llvm::errs() << "#<runtime:" << getName() << ">";
}
void write() const override {
llvm::outs() << "#<runtime:" << getName() << ">";
}

static bool classof(const ASTNode *N) {
return N->getKind() == ASTNodeKind::AST_RuntimeFunction;
Expand Down
1 change: 1 addition & 0 deletions src/include/ast_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class DefineValues;
class Values;
class Void;
class Lambda;
class Closure;
class Begin;
class List;
class Application;
Expand Down
4 changes: 4 additions & 0 deletions src/include/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class Environment {
// Lookup an identifier in the environment.
std::unique_ptr<ast::ValueNode> lookup(ast::Identifier const &Id) const;

// Implement range style access to the Env map.
auto begin() const { return Env.begin(); }
auto end() const { return Env.end(); }

private:
// Environment map for identifiers.
std::map<ast::Identifier, std::shared_ptr<ast::ValueNode>> Env;
Expand Down
Loading

0 comments on commit 050ba79

Please sign in to comment.