Skip to content

Commit

Permalink
[Voi-99] Very basic generic function support (#35)
Browse files Browse the repository at this point in the history
* BnrCall gc support

* Fix parsing bug: commas are not operators.

* Potential basic generic type resolution

* Semi working

* Fix typechecking bug

* Detect duplicate var defs

* Use type args to find a good fn match

* Forgot to add these changes to last commit

* IT WORKS!

* Add unit test for generics

* Cleanup
  • Loading branch information
drew-y authored Sep 7, 2024
1 parent 0463b01 commit 2166975
Show file tree
Hide file tree
Showing 31 changed files with 450 additions and 96 deletions.
20 changes: 16 additions & 4 deletions src/__tests__/compiler.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
e2eVoidText,
gcVoidText,
genericsText,
goodTypeInferenceText,
tcoText,
} from "./fixtures/e2e-file.js";
Expand Down Expand Up @@ -36,23 +37,34 @@ describe("E2E Compiler Pipeline", () => {
const test4 = getWasmFn("test4", instance);
const test5 = getWasmFn("test5", instance);
const test6 = getWasmFn("test6", instance);
const test7 = getWasmFn("test7", instance);
assert(test1, "Test1 exists");
assert(test2, "Test2 exists");
assert(test3, "Test3 exists");
assert(test4, "Test4 exists");
assert(test5, "Test3 exists");
assert(test6, "Test4 exists");
assert(test5, "Test5 exists");
assert(test6, "Test6 exists");
assert(test7, "Test7 exists");
t.expect(test1(), "test 1 returns correct value").toEqual(13);
t.expect(test2(), "test 2 returns correct value").toEqual(1);
t.expect(test3(), "test 3 returns correct value").toEqual(2);
t.expect(test4(), "test 4 returns correct value").toEqual(52);
t.expect(test5(), "test 5 returns correct value").toEqual(21);
t.expect(test6(), "test 6 returns correct value").toEqual(-1);
t.expect(test5(), "test 5 returns correct value").toEqual(52);
t.expect(test6(), "test 6 returns correct value").toEqual(21);
t.expect(test7(), "test 7 returns correct value").toEqual(-1);
});

test("Compiler can do tco", async (t) => {
const spy = vi.spyOn(rCallUtil, "returnCall");
await compile(tcoText);
t.expect(spy).toHaveBeenCalledTimes(1);
});

test("Generic fn compilation", async (t) => {
const mod = await compile(genericsText);
const instance = getWasmInstance(mod);
const main = getWasmFn("main", instance);
assert(main, "Main exists");
t.expect(main(), "main 1 returns correct value").toEqual(143);
});
});
24 changes: 22 additions & 2 deletions src/__tests__/fixtures/e2e-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,18 @@ pub fn test4()
let vec = Point { x: 52, y: 2, z: 21 }
vec.get_x()
// Test match type guard (Point case), should return 21
// Test match type guard (Pointy case), should return 52
pub fn test5()
let vec = Pointy { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)
// Test match type guard (Point case), should return 21
pub fn test6()
let vec = Point { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)
// Test match type guard (else case), should return -1
pub fn test6()
pub fn test7()
let vec = Bitly { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)
`;
Expand Down Expand Up @@ -114,3 +119,18 @@ fn fib_alias(n: i32, a: i64, b: i64) -> i64
pub fn main() -> i64
fib(10, 0i64, 1i64)
`;

export const genericsText = `
use std::all
type DSArrayI32 = DSArray<i32>
pub fn main()
let arr2 = ds_array_init<f64>(10)
arr2.set<f64>(0, 1.5)
arr2.get<f64>(0)
let arr: DSArrayI32 = ds_array_init<i32>(10)
arr.set<i32>(9, 143)
arr.get<i32>(9)
`;
70 changes: 54 additions & 16 deletions src/assembler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import { Expr } from "./syntax-objects/expr.js";
import { Fn } from "./syntax-objects/fn.js";
import { Identifier } from "./syntax-objects/identifier.js";
import { Int } from "./syntax-objects/int.js";
import { Type, Primitive, ObjectType } from "./syntax-objects/types.js";
import {
Type,
Primitive,
ObjectType,
DSArrayType,
} from "./syntax-objects/types.js";
import { Variable } from "./syntax-objects/variable.js";
import { Block } from "./syntax-objects/block.js";
import { Declaration } from "./syntax-objects/declaration.js";
Expand All @@ -17,7 +22,8 @@ import {
refCast,
structGetFieldValue,
} from "./lib/binaryen-gc/index.js";
import { HeapTypeRef } from "./lib/binaryen-gc/types.js";
import * as gc from "./lib/binaryen-gc/index.js";
import { TypeRef } from "./lib/binaryen-gc/types.js";
import { getExprType } from "./semantics/resolution/get-expr-type.js";
import { Match, MatchCase } from "./syntax-objects/match.js";
import { initExtensionHelpers } from "./assembler/extension-helpers.js";
Expand Down Expand Up @@ -264,14 +270,30 @@ const compileAssign = (opts: CompileExprOpts<Call>): number => {
const compileBnrCall = (opts: CompileExprOpts<Call>): number => {
const { expr } = opts;
const funcId = expr.labeledArgAt(0) as Identifier;
const argTypes = expr.labeledArgAt(1) as Call;
const namespace = argTypes.identifierArgAt(0).value;
const args = expr.labeledArgAt(3) as Call;
const func = (opts.mod as any)[namespace][funcId.value];
const namespace = (expr.labeledArgAt(1) as Identifier).value;
const args = expr.labeledArgAt(2) as Call;

const func =
namespace === "gc"
? (...args: unknown[]) => (gc as any)[funcId.value](opts.mod, ...args)
: (opts.mod as any)[namespace][funcId.value];

return func(
...(args.argArrayMap((expr: Expr) =>
compileExpression({ ...opts, expr })
) ?? [])
...(args.argArrayMap((expr: Expr) => {
if (expr?.isCall() && expr.calls("BnrType")) {
const type = getExprType(expr.typeArgs?.at(0));
if (!type) return opts.mod.nop();
return mapBinaryenType(opts, type);
}

if (expr?.isCall() && expr.calls("BnrConst")) {
const arg = expr.argAt(0);
if (!arg) return opts.mod.nop();
if ("value" in arg) return arg.value;
}

return compileExpression({ ...opts, expr });
}) ?? [])
);
};

Expand All @@ -287,6 +309,17 @@ const compileVariable = (opts: CompileExprOpts<Variable>): number => {

const compileFunction = (opts: CompileExprOpts<Fn>): number => {
const { expr: fn, mod } = opts;
if (fn.genericInstances) {
fn.genericInstances.forEach((instance) =>
compileFunction({ ...opts, expr: instance })
);
return mod.nop();
}

if (fn.typeParameters) {
return mod.nop();
}

const parameterTypes = getFunctionParameterTypes(opts, fn);
const returnType = mapBinaryenType(opts, fn.getReturnType());

Expand Down Expand Up @@ -378,20 +411,25 @@ const mapBinaryenType = (opts: CompileExprOpts, type: Type): binaryen.Type => {
if (isPrimitiveId(type, "i64")) return binaryen.i64;
if (isPrimitiveId(type, "f64")) return binaryen.f64;
if (isPrimitiveId(type, "void")) return binaryen.none;
if (type.isObjectType()) {
return type.binaryenType ? type.binaryenType : buildObjectType(opts, type);
}
if (type.isObjectType()) return buildObjectType(opts, type);
if (type.isDSArrayType()) return buildDSArrayType(opts, type);
throw new Error(`Unsupported type ${type}`);
};

const isPrimitiveId = (type: Type, id: Primitive) =>
type.isPrimitiveType() && type.name.value === id;

const buildDSArrayType = (opts: CompileExprOpts, type: DSArrayType) => {
if (type.binaryenType) return type.binaryenType;
const mod = opts.mod;
const elemType = mapBinaryenType(opts, type.elemType!);
type.binaryenType = gc.defineArrayType(mod, elemType, true, type.id);
return type.binaryenType;
};

/** TODO: Skip building types for object literals that are part of an initializer of an obj */
const buildObjectType = (
opts: CompileExprOpts,
obj: ObjectType
): HeapTypeRef => {
const buildObjectType = (opts: CompileExprOpts, obj: ObjectType): TypeRef => {
if (obj.binaryenType) return obj.binaryenType;
const mod = opts.mod;

const binaryenType = defineStructType(mod, {
Expand Down
8 changes: 8 additions & 0 deletions src/lib/binaryen-gc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ export const binaryenTypeToHeapType = (type: Type): HeapTypeRef => {
return bin._BinaryenTypeGetHeapType(type);
};

// So we can use the from compileBnrCall
export const modBinaryenTypeToHeapType = (
_mod: binaryen.Module,
type: Type
): HeapTypeRef => {
return bin._BinaryenTypeGetHeapType(type);
};

export const refCast = (
mod: binaryen.Module,
ref: ExpressionRef,
Expand Down
7 changes: 5 additions & 2 deletions src/lib/grammar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { Expr } from "../syntax-objects/expr.js";
import { Identifier } from "../syntax-objects/identifier.js";

export const isTerminator = (char: string) =>
isWhitespace(char) || isBracket(char) || isQuote(char) || isOpChar(char);
isWhitespace(char) ||
isBracket(char) ||
isQuote(char) ||
isOpChar(char) ||
char === ",";

export const isQuote = newTest(["'", '"', "`"]);

Expand All @@ -19,7 +23,6 @@ export const isOpChar = newTest([
":",
"?",
".",
",",
";",
"<",
">",
Expand Down
6 changes: 6 additions & 0 deletions src/parser/lexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ export const lexer = (chars: CharStream): Token => {
break;
}

if (!token.hasChars && char === ",") {
token.addChar(chars.consumeChar());
break;
}

if (!token.hasChars && isOpChar(char)) {
consumeOperator(chars, token);
break;
Expand Down Expand Up @@ -62,6 +67,7 @@ export const lexer = (chars: CharStream): Token => {

const consumeOperator = (chars: CharStream, token: Token) => {
while (isOpChar(chars.next)) {
if (token.value === ">" && chars.next === ">") break; // Ugly hack to support generics, means >> is not a valid operator
token.addChar(chars.consumeChar());
}
};
Expand Down
13 changes: 12 additions & 1 deletion src/semantics/check-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ const checkCallTypes = (call: Call): Call | ObjectLiteral => {
if (call.calls("=")) return checkAssign(call);
if (call.calls("member-access")) return call; // TODO
if (call.fn?.isObjectType()) return checkObjectInit(call);
call.args = call.args.map(checkTypes);

if (!call.fn) {
throw new Error(`Could not resolve fn ${call.fnName} at ${call.location}`);
Expand All @@ -63,6 +62,8 @@ const checkCallTypes = (call: Call): Call | ObjectLiteral => {
);
}

call.args = call.args.map(checkTypes);

return call;
};

Expand Down Expand Up @@ -188,6 +189,16 @@ const checkUse = (use: Use) => {
};

const checkFnTypes = (fn: Fn): Fn => {
if (fn.genericInstances) {
fn.genericInstances.forEach(checkFnTypes);
return fn;
}

// If the function has type parameters and not genericInstances, it isn't in use and wont be compiled.
if (fn.typeParameters) {
return fn;
}

checkParameters(fn.parameters);
checkTypes(fn.body);

Expand Down
Loading

0 comments on commit 2166975

Please sign in to comment.