From 233e093719a83a7535c9a03edc15c224fcdf83ac Mon Sep 17 00:00:00 2001 From: Drew Youngwerth Date: Wed, 4 Sep 2024 09:29:23 -0700 Subject: [PATCH] [Voi-94] Support TCO (#33) * Support TCO * Add test * Set test workflow timeout * Less fragile test --- .github/workflows/pr.yml | 2 +- src/__tests__/compiler.test.ts | 11 ++++++-- src/__tests__/fixtures/e2e-file.ts | 11 ++++++++ src/assembler.ts | 45 +++++++++++++++++++++--------- src/assembler/return-call.ts | 11 ++++++++ 5 files changed, 64 insertions(+), 16 deletions(-) create mode 100644 src/assembler/return-call.ts diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 037620df..789cd742 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -5,7 +5,7 @@ on: [pull_request] jobs: test: runs-on: ubuntu-latest - + timeout-minutes: 5 steps: - uses: actions/checkout@v4 - name: Use Node.js diff --git a/src/__tests__/compiler.test.ts b/src/__tests__/compiler.test.ts index dd600ec3..9cd0dbfd 100644 --- a/src/__tests__/compiler.test.ts +++ b/src/__tests__/compiler.test.ts @@ -1,8 +1,9 @@ -import { e2eVoidText, gcVoidText } from "./fixtures/e2e-file.js"; +import { e2eVoidText, gcVoidText, tcoText } from "./fixtures/e2e-file.js"; import { compile } from "../compiler.js"; -import { describe, test } from "vitest"; +import { describe, expect, test, vi } from "vitest"; import assert from "node:assert"; import { getWasmFn, getWasmInstance } from "../lib/wasm.js"; +import * as rCallUtil from "../assembler/return-call.js"; describe("E2E Compiler Pipeline", () => { test("Compiler can compile and run a basic void program", async (t) => { @@ -35,4 +36,10 @@ describe("E2E Compiler Pipeline", () => { t.expect(test5(), "test 5 returns correct value").toEqual(21); t.expect(test6(), "test 6 returns correct value").toEqual(-1); }); + + test("Compiler can do tco", async (t) => { + const spy = vi.spyOn(rCallUtil, "returnCall"); + await compile(tcoText); + t.expect(spy).toHaveBeenCalledTimes(1); + }); }); diff --git a/src/__tests__/fixtures/e2e-file.ts b/src/__tests__/fixtures/e2e-file.ts index 2fc0a517..d8ee53a0 100644 --- a/src/__tests__/fixtures/e2e-file.ts +++ b/src/__tests__/fixtures/e2e-file.ts @@ -86,3 +86,14 @@ pub fn test6() let vec = Bitly { x: 52, y: 2, z: 21 } get_num_from_vec_sub_obj(vec) `; + +export const tcoText = ` +use std::all + +// Tail call fib +pub fn fib(n: i32, a: i32, b: i32) -> i32 + if n == 0 then: + a + else: + fib(n - 1, b, a + b) +`; diff --git a/src/assembler.ts b/src/assembler.ts index 2cc57630..91b530c2 100644 --- a/src/assembler.ts +++ b/src/assembler.ts @@ -21,6 +21,7 @@ import { HeapTypeRef } from "./lib/binaryen-gc/types.js"; import { getExprType } from "./semantics/resolution/get-expr-type.js"; import { Match, MatchCase } from "./syntax-objects/match.js"; import { initExtensionHelpers } from "./assembler/extension-helpers.js"; +import { returnCall } from "./assembler/return-call.js"; export const assemble = (ast: Expr) => { const mod = new binaryen.Module(); @@ -36,17 +37,21 @@ interface CompileExprOpts { expr: T; mod: binaryen.Module; extensionHelpers: ReturnType; + isReturnExpr?: boolean; } const compileExpression = (opts: CompileExprOpts): number => { - const { expr, mod } = opts; - if (expr.isCall()) return compileCall({ ...opts, expr }); + const { expr, mod, isReturnExpr } = opts; + opts.isReturnExpr = false; + // These can take isReturnExpr + if (expr.isCall()) return compileCall({ ...opts, expr, isReturnExpr }); + if (expr.isBlock()) return compileBlock({ ...opts, expr, isReturnExpr }); + if (expr.isMatch()) return compileMatch({ ...opts, expr, isReturnExpr }); if (expr.isInt()) return mod.i32.const(expr.value); if (expr.isFloat()) return mod.f32.const(expr.value); if (expr.isIdentifier()) return compileIdentifier({ ...opts, expr }); if (expr.isFn()) return compileFunction({ ...opts, expr }); if (expr.isVariable()) return compileVariable({ ...opts, expr }); - if (expr.isBlock()) return compileBlock({ ...opts, expr }); if (expr.isDeclaration()) return compileDeclaration({ ...opts, expr }); if (expr.isModule()) return compileModule({ ...opts, expr }); if (expr.isObjectLiteral()) return compileObjectLiteral({ ...opts, expr }); @@ -54,7 +59,6 @@ const compileExpression = (opts: CompileExprOpts): number => { if (expr.isUse()) return mod.nop(); if (expr.isMacro()) return mod.nop(); if (expr.isMacroVariable()) return mod.nop(); - if (expr.isMatch()) return compileMatch({ ...opts, expr }); if (expr.isBool()) { return expr.value ? mod.i32.const(1) : mod.i32.const(0); @@ -86,7 +90,13 @@ const compileModule = (opts: CompileExprOpts) => { const compileBlock = (opts: CompileExprOpts) => { return opts.mod.block( null, - opts.expr.body.toArray().map((expr) => compileExpression({ ...opts, expr })) + opts.expr.body.toArray().map((expr, index, array) => { + if (index === array.length - 1) { + return compileExpression({ ...opts, expr, isReturnExpr: true }); + } + + return compileExpression({ ...opts, expr, isReturnExpr: false }); + }) ); }; @@ -146,7 +156,7 @@ const compileIdentifier = (opts: CompileExprOpts) => { }; const compileCall = (opts: CompileExprOpts): number => { - const { expr, mod } = opts; + const { expr, mod, isReturnExpr } = opts; if (expr.calls("quote")) return (expr.argAt(0) as Int).value; // TODO: This is an ugly hack to get constants that the compiler needs to know at compile time for ex bnr calls; if (expr.calls("=")) return compileAssign(opts); if (expr.calls("if")) return compileIf(opts); @@ -167,13 +177,16 @@ const compileCall = (opts: CompileExprOpts): number => { const args = expr.args .toArray() - .map((expr) => compileExpression({ ...opts, expr })); + .map((expr) => compileExpression({ ...opts, expr, isReturnExpr: false })); - return mod.call( - expr.fn!.id, - args, - mapBinaryenType(opts, expr.fn!.returnType!) - ); + const id = expr.fn!.id; + const returnType = mapBinaryenType(opts, expr.fn!.returnType!); + + if (isReturnExpr && id === expr.parentFn?.id) { + return returnCall(mod, id, args, returnType); + } + + return mod.call(id, args, returnType); }; const compileObjectInit = (opts: CompileExprOpts) => { @@ -250,7 +263,13 @@ const compileFunction = (opts: CompileExprOpts): number => { const { expr: fn, mod } = opts; const parameterTypes = getFunctionParameterTypes(opts, fn); const returnType = mapBinaryenType(opts, fn.getReturnType()); - const body = compileExpression({ ...opts, expr: fn.body! }); + + const body = compileExpression({ + ...opts, + expr: fn.body!, + isReturnExpr: true, + }); + const variableTypes = getFunctionVarTypes(opts, fn); // TODO: Vars should probably be registered with the function type rather than body (for consistency). mod.addFunction(fn.id, parameterTypes, returnType, variableTypes, body); diff --git a/src/assembler/return-call.ts b/src/assembler/return-call.ts new file mode 100644 index 00000000..820402a7 --- /dev/null +++ b/src/assembler/return-call.ts @@ -0,0 +1,11 @@ +import binaryen from "binaryen"; +import { ExpressionRef, TypeRef } from "../lib/binaryen-gc/types.js"; + +export const returnCall = ( + mod: binaryen.Module, + fnId: string, + args: ExpressionRef[], + returnType: TypeRef +) => { + return mod.return_call(fnId, args, returnType); +};