Skip to content

Commit

Permalink
[Voi-94] Support TCO (#33)
Browse files Browse the repository at this point in the history
* Support TCO

* Add test

* Set test workflow timeout

* Less fragile test
  • Loading branch information
drew-y authored Sep 4, 2024
1 parent 95514b3 commit 233e093
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [pull_request]
jobs:
test:
runs-on: ubuntu-latest

timeout-minutes: 5
steps:
- uses: actions/checkout@v4
- name: Use Node.js
Expand Down
11 changes: 9 additions & 2 deletions src/__tests__/compiler.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { e2eVoidText, gcVoidText } from "./fixtures/e2e-file.js";
import { e2eVoidText, gcVoidText, tcoText } from "./fixtures/e2e-file.js";
import { compile } from "../compiler.js";
import { describe, test } from "vitest";
import { describe, expect, test, vi } from "vitest";
import assert from "node:assert";
import { getWasmFn, getWasmInstance } from "../lib/wasm.js";
import * as rCallUtil from "../assembler/return-call.js";

describe("E2E Compiler Pipeline", () => {
test("Compiler can compile and run a basic void program", async (t) => {
Expand Down Expand Up @@ -35,4 +36,10 @@ describe("E2E Compiler Pipeline", () => {
t.expect(test5(), "test 5 returns correct value").toEqual(21);
t.expect(test6(), "test 6 returns correct value").toEqual(-1);
});

test("Compiler can do tco", async (t) => {
const spy = vi.spyOn(rCallUtil, "returnCall");
await compile(tcoText);
t.expect(spy).toHaveBeenCalledTimes(1);
});
});
11 changes: 11 additions & 0 deletions src/__tests__/fixtures/e2e-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,14 @@ pub fn test6()
let vec = Bitly { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)
`;

export const tcoText = `
use std::all
// Tail call fib
pub fn fib(n: i32, a: i32, b: i32) -> i32
if n == 0 then:
a
else:
fib(n - 1, b, a + b)
`;
45 changes: 32 additions & 13 deletions src/assembler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { HeapTypeRef } from "./lib/binaryen-gc/types.js";
import { getExprType } from "./semantics/resolution/get-expr-type.js";
import { Match, MatchCase } from "./syntax-objects/match.js";
import { initExtensionHelpers } from "./assembler/extension-helpers.js";
import { returnCall } from "./assembler/return-call.js";

export const assemble = (ast: Expr) => {
const mod = new binaryen.Module();
Expand All @@ -36,25 +37,28 @@ interface CompileExprOpts<T = Expr> {
expr: T;
mod: binaryen.Module;
extensionHelpers: ReturnType<typeof initExtensionHelpers>;
isReturnExpr?: boolean;
}

const compileExpression = (opts: CompileExprOpts): number => {
const { expr, mod } = opts;
if (expr.isCall()) return compileCall({ ...opts, expr });
const { expr, mod, isReturnExpr } = opts;
opts.isReturnExpr = false;
// These can take isReturnExpr
if (expr.isCall()) return compileCall({ ...opts, expr, isReturnExpr });
if (expr.isBlock()) return compileBlock({ ...opts, expr, isReturnExpr });
if (expr.isMatch()) return compileMatch({ ...opts, expr, isReturnExpr });
if (expr.isInt()) return mod.i32.const(expr.value);
if (expr.isFloat()) return mod.f32.const(expr.value);
if (expr.isIdentifier()) return compileIdentifier({ ...opts, expr });
if (expr.isFn()) return compileFunction({ ...opts, expr });
if (expr.isVariable()) return compileVariable({ ...opts, expr });
if (expr.isBlock()) return compileBlock({ ...opts, expr });
if (expr.isDeclaration()) return compileDeclaration({ ...opts, expr });
if (expr.isModule()) return compileModule({ ...opts, expr });
if (expr.isObjectLiteral()) return compileObjectLiteral({ ...opts, expr });
if (expr.isType()) return compileType({ ...opts, expr });
if (expr.isUse()) return mod.nop();
if (expr.isMacro()) return mod.nop();
if (expr.isMacroVariable()) return mod.nop();
if (expr.isMatch()) return compileMatch({ ...opts, expr });

if (expr.isBool()) {
return expr.value ? mod.i32.const(1) : mod.i32.const(0);
Expand Down Expand Up @@ -86,7 +90,13 @@ const compileModule = (opts: CompileExprOpts<VoidModule>) => {
const compileBlock = (opts: CompileExprOpts<Block>) => {
return opts.mod.block(
null,
opts.expr.body.toArray().map((expr) => compileExpression({ ...opts, expr }))
opts.expr.body.toArray().map((expr, index, array) => {
if (index === array.length - 1) {
return compileExpression({ ...opts, expr, isReturnExpr: true });
}

return compileExpression({ ...opts, expr, isReturnExpr: false });
})
);
};

Expand Down Expand Up @@ -146,7 +156,7 @@ const compileIdentifier = (opts: CompileExprOpts<Identifier>) => {
};

const compileCall = (opts: CompileExprOpts<Call>): number => {
const { expr, mod } = opts;
const { expr, mod, isReturnExpr } = opts;
if (expr.calls("quote")) return (expr.argAt(0) as Int).value; // TODO: This is an ugly hack to get constants that the compiler needs to know at compile time for ex bnr calls;
if (expr.calls("=")) return compileAssign(opts);
if (expr.calls("if")) return compileIf(opts);
Expand All @@ -167,13 +177,16 @@ const compileCall = (opts: CompileExprOpts<Call>): number => {

const args = expr.args
.toArray()
.map((expr) => compileExpression({ ...opts, expr }));
.map((expr) => compileExpression({ ...opts, expr, isReturnExpr: false }));

return mod.call(
expr.fn!.id,
args,
mapBinaryenType(opts, expr.fn!.returnType!)
);
const id = expr.fn!.id;
const returnType = mapBinaryenType(opts, expr.fn!.returnType!);

if (isReturnExpr && id === expr.parentFn?.id) {
return returnCall(mod, id, args, returnType);
}

return mod.call(id, args, returnType);
};

const compileObjectInit = (opts: CompileExprOpts<Call>) => {
Expand Down Expand Up @@ -250,7 +263,13 @@ const compileFunction = (opts: CompileExprOpts<Fn>): number => {
const { expr: fn, mod } = opts;
const parameterTypes = getFunctionParameterTypes(opts, fn);
const returnType = mapBinaryenType(opts, fn.getReturnType());
const body = compileExpression({ ...opts, expr: fn.body! });

const body = compileExpression({
...opts,
expr: fn.body!,
isReturnExpr: true,
});

const variableTypes = getFunctionVarTypes(opts, fn); // TODO: Vars should probably be registered with the function type rather than body (for consistency).

mod.addFunction(fn.id, parameterTypes, returnType, variableTypes, body);
Expand Down
11 changes: 11 additions & 0 deletions src/assembler/return-call.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import binaryen from "binaryen";
import { ExpressionRef, TypeRef } from "../lib/binaryen-gc/types.js";

export const returnCall = (
mod: binaryen.Module,
fnId: string,
args: ExpressionRef[],
returnType: TypeRef
) => {
return mod.return_call(fnId, args, returnType);
};

0 comments on commit 233e093

Please sign in to comment.