Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Voi-99] Very basic generic function support #35

Merged
merged 11 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/__tests__/compiler.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
e2eVoidText,
gcVoidText,
genericsText,
goodTypeInferenceText,
tcoText,
} from "./fixtures/e2e-file.js";
Expand Down Expand Up @@ -36,23 +37,34 @@ describe("E2E Compiler Pipeline", () => {
const test4 = getWasmFn("test4", instance);
const test5 = getWasmFn("test5", instance);
const test6 = getWasmFn("test6", instance);
const test7 = getWasmFn("test7", instance);
assert(test1, "Test1 exists");
assert(test2, "Test2 exists");
assert(test3, "Test3 exists");
assert(test4, "Test4 exists");
assert(test5, "Test3 exists");
assert(test6, "Test4 exists");
assert(test5, "Test5 exists");
assert(test6, "Test6 exists");
assert(test7, "Test7 exists");
t.expect(test1(), "test 1 returns correct value").toEqual(13);
t.expect(test2(), "test 2 returns correct value").toEqual(1);
t.expect(test3(), "test 3 returns correct value").toEqual(2);
t.expect(test4(), "test 4 returns correct value").toEqual(52);
t.expect(test5(), "test 5 returns correct value").toEqual(21);
t.expect(test6(), "test 6 returns correct value").toEqual(-1);
t.expect(test5(), "test 5 returns correct value").toEqual(52);
t.expect(test6(), "test 6 returns correct value").toEqual(21);
t.expect(test7(), "test 7 returns correct value").toEqual(-1);
});

test("Compiler can do tco", async (t) => {
const spy = vi.spyOn(rCallUtil, "returnCall");
await compile(tcoText);
t.expect(spy).toHaveBeenCalledTimes(1);
});

test("Generic fn compilation", async (t) => {
const mod = await compile(genericsText);
const instance = getWasmInstance(mod);
const main = getWasmFn("main", instance);
assert(main, "Main exists");
t.expect(main(), "main 1 returns correct value").toEqual(143);
});
});
24 changes: 22 additions & 2 deletions src/__tests__/fixtures/e2e-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,18 @@ pub fn test4()
let vec = Point { x: 52, y: 2, z: 21 }
vec.get_x()

// Test match type guard (Point case), should return 21
// Test match type guard (Pointy case), should return 52
pub fn test5()
let vec = Pointy { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)

// Test match type guard (Point case), should return 21
pub fn test6()
let vec = Point { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)

// Test match type guard (else case), should return -1
pub fn test6()
pub fn test7()
let vec = Bitly { x: 52, y: 2, z: 21 }
get_num_from_vec_sub_obj(vec)
`;
Expand Down Expand Up @@ -114,3 +119,18 @@ fn fib_alias(n: i32, a: i64, b: i64) -> i64
pub fn main() -> i64
fib(10, 0i64, 1i64)
`;

export const genericsText = `
use std::all

type DSArrayI32 = DSArray<i32>

pub fn main()
let arr2 = ds_array_init<f64>(10)
arr2.set<f64>(0, 1.5)
arr2.get<f64>(0)

let arr: DSArrayI32 = ds_array_init<i32>(10)
arr.set<i32>(9, 143)
arr.get<i32>(9)
`;
70 changes: 54 additions & 16 deletions src/assembler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import { Expr } from "./syntax-objects/expr.js";
import { Fn } from "./syntax-objects/fn.js";
import { Identifier } from "./syntax-objects/identifier.js";
import { Int } from "./syntax-objects/int.js";
import { Type, Primitive, ObjectType } from "./syntax-objects/types.js";
import {
Type,
Primitive,
ObjectType,
DSArrayType,
} from "./syntax-objects/types.js";
import { Variable } from "./syntax-objects/variable.js";
import { Block } from "./syntax-objects/block.js";
import { Declaration } from "./syntax-objects/declaration.js";
Expand All @@ -17,7 +22,8 @@ import {
refCast,
structGetFieldValue,
} from "./lib/binaryen-gc/index.js";
import { HeapTypeRef } from "./lib/binaryen-gc/types.js";
import * as gc from "./lib/binaryen-gc/index.js";
import { TypeRef } from "./lib/binaryen-gc/types.js";
import { getExprType } from "./semantics/resolution/get-expr-type.js";
import { Match, MatchCase } from "./syntax-objects/match.js";
import { initExtensionHelpers } from "./assembler/extension-helpers.js";
Expand Down Expand Up @@ -264,14 +270,30 @@ const compileAssign = (opts: CompileExprOpts<Call>): number => {
const compileBnrCall = (opts: CompileExprOpts<Call>): number => {
const { expr } = opts;
const funcId = expr.labeledArgAt(0) as Identifier;
const argTypes = expr.labeledArgAt(1) as Call;
const namespace = argTypes.identifierArgAt(0).value;
const args = expr.labeledArgAt(3) as Call;
const func = (opts.mod as any)[namespace][funcId.value];
const namespace = (expr.labeledArgAt(1) as Identifier).value;
const args = expr.labeledArgAt(2) as Call;

const func =
namespace === "gc"
? (...args: unknown[]) => (gc as any)[funcId.value](opts.mod, ...args)
: (opts.mod as any)[namespace][funcId.value];

return func(
...(args.argArrayMap((expr: Expr) =>
compileExpression({ ...opts, expr })
) ?? [])
...(args.argArrayMap((expr: Expr) => {
if (expr?.isCall() && expr.calls("BnrType")) {
const type = getExprType(expr.typeArgs?.at(0));
if (!type) return opts.mod.nop();
return mapBinaryenType(opts, type);
}

if (expr?.isCall() && expr.calls("BnrConst")) {
const arg = expr.argAt(0);
if (!arg) return opts.mod.nop();
if ("value" in arg) return arg.value;
}

return compileExpression({ ...opts, expr });
}) ?? [])
);
};

Expand All @@ -287,6 +309,17 @@ const compileVariable = (opts: CompileExprOpts<Variable>): number => {

const compileFunction = (opts: CompileExprOpts<Fn>): number => {
const { expr: fn, mod } = opts;
if (fn.genericInstances) {
fn.genericInstances.forEach((instance) =>
compileFunction({ ...opts, expr: instance })
);
return mod.nop();
}

if (fn.typeParameters) {
return mod.nop();
}

const parameterTypes = getFunctionParameterTypes(opts, fn);
const returnType = mapBinaryenType(opts, fn.getReturnType());

Expand Down Expand Up @@ -378,20 +411,25 @@ const mapBinaryenType = (opts: CompileExprOpts, type: Type): binaryen.Type => {
if (isPrimitiveId(type, "i64")) return binaryen.i64;
if (isPrimitiveId(type, "f64")) return binaryen.f64;
if (isPrimitiveId(type, "void")) return binaryen.none;
if (type.isObjectType()) {
return type.binaryenType ? type.binaryenType : buildObjectType(opts, type);
}
if (type.isObjectType()) return buildObjectType(opts, type);
if (type.isDSArrayType()) return buildDSArrayType(opts, type);
throw new Error(`Unsupported type ${type}`);
};

const isPrimitiveId = (type: Type, id: Primitive) =>
type.isPrimitiveType() && type.name.value === id;

const buildDSArrayType = (opts: CompileExprOpts, type: DSArrayType) => {
if (type.binaryenType) return type.binaryenType;
const mod = opts.mod;
const elemType = mapBinaryenType(opts, type.elemType!);
type.binaryenType = gc.defineArrayType(mod, elemType, true, type.id);
return type.binaryenType;
};

/** TODO: Skip building types for object literals that are part of an initializer of an obj */
const buildObjectType = (
opts: CompileExprOpts,
obj: ObjectType
): HeapTypeRef => {
const buildObjectType = (opts: CompileExprOpts, obj: ObjectType): TypeRef => {
if (obj.binaryenType) return obj.binaryenType;
const mod = opts.mod;

const binaryenType = defineStructType(mod, {
Expand Down
8 changes: 8 additions & 0 deletions src/lib/binaryen-gc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ export const binaryenTypeToHeapType = (type: Type): HeapTypeRef => {
return bin._BinaryenTypeGetHeapType(type);
};

// So we can use the from compileBnrCall
export const modBinaryenTypeToHeapType = (
_mod: binaryen.Module,
type: Type
): HeapTypeRef => {
return bin._BinaryenTypeGetHeapType(type);
};

export const refCast = (
mod: binaryen.Module,
ref: ExpressionRef,
Expand Down
7 changes: 5 additions & 2 deletions src/lib/grammar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { Expr } from "../syntax-objects/expr.js";
import { Identifier } from "../syntax-objects/identifier.js";

export const isTerminator = (char: string) =>
isWhitespace(char) || isBracket(char) || isQuote(char) || isOpChar(char);
isWhitespace(char) ||
isBracket(char) ||
isQuote(char) ||
isOpChar(char) ||
char === ",";

export const isQuote = newTest(["'", '"', "`"]);

Expand All @@ -19,7 +23,6 @@ export const isOpChar = newTest([
":",
"?",
".",
",",
";",
"<",
">",
Expand Down
6 changes: 6 additions & 0 deletions src/parser/lexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ export const lexer = (chars: CharStream): Token => {
break;
}

if (!token.hasChars && char === ",") {
token.addChar(chars.consumeChar());
break;
}

if (!token.hasChars && isOpChar(char)) {
consumeOperator(chars, token);
break;
Expand Down Expand Up @@ -62,6 +67,7 @@ export const lexer = (chars: CharStream): Token => {

const consumeOperator = (chars: CharStream, token: Token) => {
while (isOpChar(chars.next)) {
if (token.value === ">" && chars.next === ">") break; // Ugly hack to support generics, means >> is not a valid operator
token.addChar(chars.consumeChar());
}
};
Expand Down
13 changes: 12 additions & 1 deletion src/semantics/check-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ const checkCallTypes = (call: Call): Call | ObjectLiteral => {
if (call.calls("=")) return checkAssign(call);
if (call.calls("member-access")) return call; // TODO
if (call.fn?.isObjectType()) return checkObjectInit(call);
call.args = call.args.map(checkTypes);

if (!call.fn) {
throw new Error(`Could not resolve fn ${call.fnName} at ${call.location}`);
Expand All @@ -63,6 +62,8 @@ const checkCallTypes = (call: Call): Call | ObjectLiteral => {
);
}

call.args = call.args.map(checkTypes);

return call;
};

Expand Down Expand Up @@ -188,6 +189,16 @@ const checkUse = (use: Use) => {
};

const checkFnTypes = (fn: Fn): Fn => {
if (fn.genericInstances) {
fn.genericInstances.forEach(checkFnTypes);
return fn;
}

// If the function has type parameters and not genericInstances, it isn't in use and wont be compiled.
if (fn.typeParameters) {
return fn;
}

checkParameters(fn.parameters);
checkTypes(fn.body);

Expand Down
Loading