diff --git a/src/__tests__/compiler.test.ts b/src/__tests__/compiler.test.ts index b0f1e78..83db1f8 100644 --- a/src/__tests__/compiler.test.ts +++ b/src/__tests__/compiler.test.ts @@ -56,6 +56,7 @@ describe("E2E Compiler Pipeline", () => { 82, 3, 42, + 2, // IntersectionType tests ]); }); diff --git a/src/__tests__/fixtures/e2e-file.ts b/src/__tests__/fixtures/e2e-file.ts index b3fb2b8..4b7b978 100644 --- a/src/__tests__/fixtures/e2e-file.ts +++ b/src/__tests__/fixtures/e2e-file.ts @@ -209,6 +209,18 @@ pub fn test17() match(x) Some: x.value None: -1 + +obj Animal { age: i32 } +obj Insect extends Animal { age: i32, legs: i32 } +obj Mammal extends Animal { age: i32, legs: i32 } + +fn get_legs(a: Animal & { legs: i32 }) -> i32 + a.legs + +// Test intersection types +pub fn test18() -> i32 + let human = Mammal { age: 10, legs: 2 } + get_legs(human) `; export const tcoText = ` diff --git a/src/assembler.ts b/src/assembler.ts index 98494a3..1776408 100644 --- a/src/assembler.ts +++ b/src/assembler.ts @@ -11,6 +11,7 @@ import { DsArrayType, voidBaseObject, UnionType, + IntersectionType, } from "./syntax-objects/types.js"; import { Variable } from "./syntax-objects/variable.js"; import { Block } from "./syntax-objects/block.js"; @@ -110,6 +111,16 @@ const compileType = (opts: CompileExprOpts) => { return opts.mod.nop(); } + if (type.isUnionType()) { + buildUnionType(opts, type); + return opts.mod.nop(); + } + + if (type.isIntersectionType()) { + buildIntersectionType(opts, type); + return opts.mod.nop(); + } + return opts.mod.nop(); }; @@ -462,6 +473,7 @@ export const mapBinaryenType = ( if (type.isObjectType()) return buildObjectType(opts, type); if (type.isUnionType()) return buildUnionType(opts, type); if (type.isDsArrayType()) return buildDsArrayType(opts, type); + if (type.isIntersectionType()) return buildIntersectionType(opts, type); throw new Error(`Unsupported type ${type}`); }; @@ -488,6 +500,20 @@ const buildUnionType = (opts: MapBinTypeOpts, union: UnionType): TypeRef => { return typeRef; }; +const buildIntersectionType = ( + opts: MapBinTypeOpts, + inter: IntersectionType +): TypeRef => { + if (inter.hasAttribute("binaryenType")) { + return inter.getAttribute("binaryenType") as TypeRef; + } + + const typeRef = mapBinaryenType(opts, inter.nominalType!); + mapBinaryenType(opts, inter.structuralType!); + inter.setAttribute("binaryenType", typeRef); + return typeRef; +}; + // Marks the start of the fields in an object after RTT info fields const OBJECT_FIELDS_OFFSET = 2; @@ -565,9 +591,9 @@ const compileObjMemberAccess = (opts: CompileExprOpts) => { const obj = expr.exprArgAt(0); const member = expr.identifierArgAt(1); const objValue = compileExpression({ ...opts, expr: obj }); - const type = getExprType(obj) as ObjectType; + const type = getExprType(obj) as ObjectType | IntersectionType; - if (type.getAttribute("isStructural")) { + if (type.getAttribute("isStructural") || type.isIntersectionType()) { return opts.fieldLookupHelpers.getFieldValueByAccessor(opts); } diff --git a/src/assembler/field-lookup-helpers.ts b/src/assembler/field-lookup-helpers.ts index 32f3fe9..17deb90 100644 --- a/src/assembler/field-lookup-helpers.ts +++ b/src/assembler/field-lookup-helpers.ts @@ -13,7 +13,11 @@ import { callRef, refCast, } from "../lib/binaryen-gc/index.js"; -import { ObjectType, voidBaseObject } from "../syntax-objects/types.js"; +import { + IntersectionType, + ObjectType, + voidBaseObject, +} from "../syntax-objects/types.js"; import { murmurHash3 } from "../lib/murmur-hash.js"; import { compileExpression, @@ -145,9 +149,19 @@ export const initFieldLookupHelpers = (mod: binaryen.Module) => { const { expr, mod } = opts; const obj = expr.exprArgAt(0); const member = expr.identifierArgAt(1); - const objType = getExprType(obj) as ObjectType; + const objType = getExprType(obj) as ObjectType | IntersectionType; + + const field = objType.isIntersectionType() + ? objType.nominalType?.getField(member) ?? + objType.structuralType?.getField(member) + : objType.getField(member); + + if (!field) { + throw new Error( + `Field ${member.value} not found on object ${objType.id}` + ); + } - const field = objType.getField(member)!; const lookupTable = structGetFieldValue({ mod, fieldType: lookupTableType, diff --git a/src/semantics/check-types.ts b/src/semantics/check-types.ts index cceab69..23afaf9 100644 --- a/src/semantics/check-types.ts +++ b/src/semantics/check-types.ts @@ -17,10 +17,11 @@ import { TypeAlias, ObjectLiteral, UnionType, + IntersectionType, } from "../syntax-objects/index.js"; import { Match } from "../syntax-objects/match.js"; import { getExprType } from "./resolution/get-expr-type.js"; -import { typesAreEquivalent } from "./resolution/index.js"; +import { typesAreCompatible } from "./resolution/index.js"; export const checkTypes = (expr: Expr | undefined): Expr => { if (!expr) return nop(); @@ -37,6 +38,7 @@ export const checkTypes = (expr: Expr | undefined): Expr => { if (expr.isObjectLiteral()) return checkObjectLiteralType(expr); if (expr.isUnionType()) return checkUnionType(expr); if (expr.isMatch()) return checkMatch(expr); + if (expr.isIntersectionType()) return checkIntersectionType(expr); return expr; }; @@ -85,7 +87,7 @@ const checkObjectInit = (call: Call): Call => { checkTypes(literal); // Check to ensure literal structure is compatible with nominal structure - if (!typesAreEquivalent(literal.type, call.type, { structuralOnly: true })) { + if (!typesAreCompatible(literal.type, call.type, { structuralOnly: true })) { throw new Error(`Object literal type does not match expected type`); } @@ -109,7 +111,7 @@ export const checkAssign = (call: Call) => { const initType = getExprType(call.argAt(1)); - if (!typesAreEquivalent(variable.type, initType)) { + if (!typesAreCompatible(variable.type, initType)) { throw new Error(`${id} cannot be assigned to ${initType}`); } @@ -134,7 +136,7 @@ const checkIdentifier = (id: Identifier) => { export const checkIf = (call: Call) => { const cond = checkTypes(call.argAt(0)); const condType = getExprType(cond); - if (!condType || !typesAreEquivalent(condType, bool)) { + if (!condType || !typesAreCompatible(condType, bool)) { throw new Error( `If conditions must resolve to a boolean at ${cond.location}` ); @@ -153,7 +155,7 @@ export const checkIf = (call: Call) => { const elseType = getExprType(elseExpr); // Until unions are supported, throw an error when types don't match - if (!typesAreEquivalent(thenType, elseType)) { + if (!typesAreCompatible(thenType, elseType)) { throw new Error("If condition clauses do not return same type"); } @@ -214,7 +216,7 @@ const checkFnTypes = (fn: Fn): Fn => { if ( inferredReturnType && - !typesAreEquivalent(inferredReturnType, fn.returnType) + !typesAreCompatible(inferredReturnType, fn.returnType) ) { throw new Error( `Fn, ${fn.name}, return value type (${inferredReturnType?.name}) is not compatible with annotated return type (${fn.returnType?.name}) at ${fn.location}` @@ -269,7 +271,7 @@ const checkVarTypes = (variable: Variable): Variable => { if ( variable.annotatedType && - !typesAreEquivalent(variable.inferredType, variable.annotatedType) + !typesAreCompatible(variable.inferredType, variable.annotatedType) ) { throw new Error( `${variable.name} of type ${variable.type} is not assignable to ${variable.inferredType}` @@ -316,7 +318,7 @@ export function assertValidExtension( const validExtension = parent.fields.every((field) => { const match = child.fields.find((f) => f.name === field.name); - return match && typesAreEquivalent(field.type, match.type); + return match && typesAreCompatible(field.type, match.type); }); if (!validExtension) { @@ -375,6 +377,23 @@ const checkMatch = (match: Match) => { return checkObjectMatch(match); }; +const checkIntersectionType = (inter: IntersectionType) => { + checkTypeExpr(inter.nominalTypeExpr.value); + checkTypeExpr(inter.structuralTypeExpr.value); + + if (!inter.nominalType || !inter.structuralType) { + throw new Error(`Unable to resolve intersection type ${inter.location}`); + } + + if (!inter.structuralType.getAttribute("isStructural")) { + throw new Error( + `Structural type must be a structural type ${inter.structuralTypeExpr.value.location}` + ); + } + + return inter; +}; + const checkUnionMatch = (match: Match) => { const union = match.baseType as UnionType; @@ -391,7 +410,7 @@ const checkUnionMatch = (match: Match) => { ); } - if (!typesAreEquivalent(mCase.expr.type, match.type)) { + if (!typesAreCompatible(mCase.expr.type, match.type)) { throw new Error( `All cases must return the same type for now ${mCase.expr.location}` ); @@ -400,7 +419,7 @@ const checkUnionMatch = (match: Match) => { union.types.forEach((type) => { if ( - !match.cases.some((mCase) => typesAreEquivalent(mCase.matchType, type)) + !match.cases.some((mCase) => typesAreCompatible(mCase.matchType, type)) ) { throw new Error( `Match does not handle all possibilities of union ${match.location}` @@ -436,7 +455,7 @@ const checkObjectMatch = (match: Match) => { ); } - if (!typesAreEquivalent(mCase.expr.type, match.type)) { + if (!typesAreCompatible(mCase.expr.type, match.type)) { throw new Error( `All cases must return the same type for now ${mCase.expr.location}` ); diff --git a/src/semantics/init-entities.ts b/src/semantics/init-entities.ts index 247bf2d..cb13d92 100644 --- a/src/semantics/init-entities.ts +++ b/src/semantics/init-entities.ts @@ -14,6 +14,7 @@ import { DsArrayType, nop, UnionType, + IntersectionType, } from "../syntax-objects/index.js"; import { Match, MatchCase } from "../syntax-objects/match.js"; import { SemanticProcessor } from "./types.js"; @@ -67,6 +68,10 @@ export const initEntities: SemanticProcessor = (expr) => { return initPipedUnionType(expr); } + if (expr.calls("&")) { + return initIntersection(expr); + } + return initCall(expr); }; @@ -199,6 +204,22 @@ const initPipedUnionType = (union: List) => { }); }; +const initIntersection = (intersection: List): IntersectionType => { + const nominalObjectExpr = initTypeExprEntities(intersection.at(1)); + const structuralObjectExpr = initTypeExprEntities(intersection.at(2)); + + if (!nominalObjectExpr || !structuralObjectExpr) { + throw new Error("Invalid intersection type"); + } + + return new IntersectionType({ + ...intersection.metadata, + name: intersection.syntaxId.toString(), + nominalObjectExpr, + structuralObjectExpr, + }); +}; + const initVar = (varDef: List): Variable => { const isMutable = varDef.calls("define_mut"); const identifierExpr = varDef.at(1); @@ -332,6 +353,10 @@ const initTypeExprEntities = (type?: Expr): Expr | undefined => { return initPipedUnionType(type); } + if (type.calls("&")) { + return initIntersection(type); + } + return initCall(type); }; diff --git a/src/semantics/resolution/get-call-fn.ts b/src/semantics/resolution/get-call-fn.ts index 92b3f21..33438f4 100644 --- a/src/semantics/resolution/get-call-fn.ts +++ b/src/semantics/resolution/get-call-fn.ts @@ -1,6 +1,6 @@ import { Call, Expr, Fn } from "../../syntax-objects/index.js"; import { getExprType } from "./get-expr-type.js"; -import { typesAreEquivalent } from "./types-are-equivalent.js"; +import { typesAreCompatible } from "./types-are-compatible.js"; import { resolveFnTypes } from "./resolve-fn-type.js"; export const getCallFn = (call: Call): Fn | undefined => { @@ -80,7 +80,7 @@ const typeArgsMatch = (call: Call, candidate: Fn): boolean => ? candidate.appliedTypeArgs.every((t, i) => { const argType = getExprType(call.typeArgs?.at(i)); const appliedType = getExprType(t); - return typesAreEquivalent(argType, appliedType, { + return typesAreCompatible(argType, appliedType, { exactNominalMatch: true, }); }) @@ -94,7 +94,7 @@ const parametersMatch = (candidate: Fn, call: Call) => if (!argType) return false; const argLabel = getExprLabel(arg); const labelsMatch = p.label === argLabel; - return typesAreEquivalent(argType, p.type!) && labelsMatch; + return typesAreCompatible(argType, p.type!) && labelsMatch; }); const getExprLabel = (expr?: Expr): string | undefined => { diff --git a/src/semantics/resolution/index.ts b/src/semantics/resolution/index.ts index e05e5d5..07faf44 100644 --- a/src/semantics/resolution/index.ts +++ b/src/semantics/resolution/index.ts @@ -1,3 +1,3 @@ export { resolveTypes } from "./resolve-types.js"; -export { typesAreEquivalent } from "./types-are-equivalent.js"; +export { typesAreCompatible } from "./types-are-compatible.js"; export { resolveModulePath } from "./resolve-use.js"; diff --git a/src/semantics/resolution/resolve-call-types.ts b/src/semantics/resolution/resolve-call-types.ts index c0edd4a..92e2f1c 100644 --- a/src/semantics/resolution/resolve-call-types.ts +++ b/src/semantics/resolution/resolve-call-types.ts @@ -79,16 +79,35 @@ const getMemberAccessCall = (call: Call): Call | undefined => { const a1 = call.argAt(0); if (!a1) return; const a1Type = getExprType(a1); - if (!a1Type || !a1Type.isObjectType() || !a1Type.hasField(call.fnName)) { - return; + + if (a1Type && a1Type.isObjectType() && a1Type.hasField(call.fnName)) { + return new Call({ + ...call.metadata, + fnName: Identifier.from("member-access"), + args: new List({ value: [a1, call.fnName] }), + type: a1Type.getField(call.fnName)?.type, + }); + } + + if ( + a1Type && + a1Type.isIntersectionType() && + (a1Type.nominalType?.hasField(call.fnName) || + a1Type.structuralType?.hasField(call.fnName)) + ) { + const field = + a1Type.nominalType?.getField(call.fnName) ?? + a1Type.structuralType?.getField(call.fnName); + + return new Call({ + ...call.metadata, + fnName: Identifier.from("member-access"), + args: new List({ value: [a1, call.fnName] }), + type: field?.type, + }); } - return new Call({ - ...call.metadata, - fnName: Identifier.from("member-access"), - args: new List({ value: [a1, call.fnName] }), - type: a1Type.getField(call.fnName)?.type, - }); + return undefined; }; export const resolveIf = (call: Call) => { diff --git a/src/semantics/resolution/resolve-intersection.ts b/src/semantics/resolution/resolve-intersection.ts new file mode 100644 index 0000000..2f43181 --- /dev/null +++ b/src/semantics/resolution/resolve-intersection.ts @@ -0,0 +1,21 @@ +import { IntersectionType } from "../../syntax-objects/types.js"; +import { getExprType } from "./get-expr-type.js"; +import { resolveTypes } from "./resolve-types.js"; + +export const resolveIntersection = ( + inter: IntersectionType +): IntersectionType => { + inter.nominalTypeExpr.value = resolveTypes(inter.nominalTypeExpr.value); + inter.structuralTypeExpr.value = resolveTypes(inter.structuralTypeExpr.value); + + const nominalType = getExprType(inter.nominalTypeExpr.value); + const structuralType = getExprType(inter.structuralTypeExpr.value); + + // TODO Error if not correct type + inter.nominalType = nominalType?.isObjectType() ? nominalType : undefined; + inter.structuralType = structuralType?.isObjectType() + ? structuralType + : undefined; + + return inter; +}; diff --git a/src/semantics/resolution/resolve-match.ts b/src/semantics/resolution/resolve-match.ts index e5d0190..b4ec9ac 100644 --- a/src/semantics/resolution/resolve-match.ts +++ b/src/semantics/resolution/resolve-match.ts @@ -1,6 +1,7 @@ import { Block } from "../../syntax-objects/block.js"; import { Call, + IntersectionType, ObjectType, Parameter, Type, @@ -79,13 +80,13 @@ const resolveMatchReturnType = (match: Match): Type | undefined => { return firstType; } - let type: ObjectType | UnionType = firstType; + let type: ObjectType | IntersectionType | UnionType = firstType; for (const mCase of cases.slice(1)) { if (mCase.id === type.id) { continue; } - if (type.isObjectType() && mCase.isObjectType()) { + if (isObjectOrIntersection(mCase) && isObjectOrIntersection(type)) { const union = new UnionType({ name: `Union#match#(${match.syntaxId}`, }); @@ -94,7 +95,7 @@ const resolveMatchReturnType = (match: Match): Type | undefined => { continue; } - if (mCase.isObjectType() && type.isUnionType()) { + if (isObjectOrIntersection(mCase) && type.isUnionType()) { type.types.push(mCase); continue; } @@ -104,3 +105,9 @@ const resolveMatchReturnType = (match: Match): Type | undefined => { return type; }; + +const isObjectOrIntersection = ( + type: Type +): type is ObjectType | IntersectionType => { + return type.isObjectType() || type.isIntersectionType(); +}; diff --git a/src/semantics/resolution/resolve-object-type.ts b/src/semantics/resolution/resolve-object-type.ts index a90faea..1b86a84 100644 --- a/src/semantics/resolution/resolve-object-type.ts +++ b/src/semantics/resolution/resolve-object-type.ts @@ -9,7 +9,7 @@ import { import { getExprType } from "./get-expr-type.js"; import { implIsCompatible, resolveImpl } from "./resolve-impl.js"; import { resolveTypes } from "./resolve-types.js"; -import { typesAreEquivalent } from "./types-are-equivalent.js"; +import { typesAreCompatible } from "./types-are-compatible.js"; export const resolveObjectTypeTypes = ( obj: ObjectType, @@ -95,7 +95,7 @@ const typeArgsMatch = (call: Call, candidate: ObjectType): boolean => ? candidate.appliedTypeArgs.every((t, i) => { const argType = getExprType(call.typeArgs?.at(i)); const appliedType = getExprType(t); - return typesAreEquivalent(argType, appliedType, { + return typesAreCompatible(argType, appliedType, { exactNominalMatch: true, }); }) diff --git a/src/semantics/resolution/resolve-types.ts b/src/semantics/resolution/resolve-types.ts index cb31b1e..5d6c0ef 100644 --- a/src/semantics/resolution/resolve-types.ts +++ b/src/semantics/resolution/resolve-types.ts @@ -15,6 +15,7 @@ import { getExprType } from "./get-expr-type.js"; import { resolveCallTypes } from "./resolve-call-types.js"; import { resolveFnTypes } from "./resolve-fn-type.js"; import { resolveImpl } from "./resolve-impl.js"; +import { resolveIntersection } from "./resolve-intersection.js"; import { resolveMatch } from "./resolve-match.js"; import { resolveObjectTypeTypes } from "./resolve-object-type.js"; import { resolveUnion } from "./resolve-union.js"; @@ -43,6 +44,7 @@ export const resolveTypes = (expr: Expr | undefined): Expr => { if (expr.isMatch()) return resolveMatch(expr); if (expr.isImpl()) return resolveImpl(expr); if (expr.isUnionType()) return resolveUnion(expr); + if (expr.isIntersectionType()) return resolveIntersection(expr); return expr; }; diff --git a/src/semantics/resolution/types-are-equivalent.ts b/src/semantics/resolution/types-are-compatible.ts similarity index 70% rename from src/semantics/resolution/types-are-equivalent.ts rename to src/semantics/resolution/types-are-compatible.ts index 8e0bf42..5d91b2c 100644 --- a/src/semantics/resolution/types-are-equivalent.ts +++ b/src/semantics/resolution/types-are-compatible.ts @@ -1,6 +1,6 @@ import { Type } from "../../syntax-objects/index.js"; -export const typesAreEquivalent = ( +export const typesAreCompatible = ( /** A is the argument type, the type of the value being passed as b */ a?: Type, /** B is the parameter type, what a should be equivalent to */ @@ -26,7 +26,7 @@ export const typesAreEquivalent = ( if (structural) { return b.fields.every((field) => { const match = a.fields.find((f) => f.name === field.name); - return match && typesAreEquivalent(field.type, match.type); + return match && typesAreCompatible(field.type, match.type); }); } @@ -34,17 +34,22 @@ export const typesAreEquivalent = ( } if (a.isObjectType() && b.isUnionType()) { - return b.types.some((type) => typesAreEquivalent(a, type, opts)); + return b.types.some((type) => typesAreCompatible(a, type, opts)); } if (a.isUnionType() && b.isUnionType()) { return a.types.every((aType) => - b.types.some((bType) => typesAreEquivalent(aType, bType, opts)) + b.types.some((bType) => typesAreCompatible(aType, bType, opts)) ); } + if (a.isObjectType() && b.isIntersectionType()) { + if (!b.nominalType || !b.structuralType) return false; + return a.extends(b.nominalType) && typesAreCompatible(a, b.structuralType); + } + if (a.isDsArrayType() && b.isDsArrayType()) { - return typesAreEquivalent(a.elemType, b.elemType); + return typesAreCompatible(a.elemType, b.elemType); } return false; diff --git a/src/syntax-objects/syntax.ts b/src/syntax-objects/syntax.ts index 9ffcee2..e12ddc1 100644 --- a/src/syntax-objects/syntax.ts +++ b/src/syntax-objects/syntax.ts @@ -24,6 +24,7 @@ import type { TypeAlias, DsArrayType, UnionType, + IntersectionType, } from "./types.js"; import type { Variable } from "./variable.js"; import type { Whitespace } from "./whitespace.js"; @@ -203,6 +204,10 @@ export abstract class Syntax { return this.isType() && this.kindOfType === "union"; } + isIntersectionType(): this is IntersectionType { + return this.isType() && this.kindOfType === "intersection"; + } + isDsArrayType(): this is DsArrayType { return this.isType() && this.kindOfType === "ds-array"; } diff --git a/src/syntax-objects/types.ts b/src/syntax-objects/types.ts index 58a512b..b17834b 100644 --- a/src/syntax-objects/types.ts +++ b/src/syntax-objects/types.ts @@ -7,6 +7,7 @@ import { LexicalContext } from "./lib/lexical-context.js"; import { Implementation } from "./implementation.js"; import { ScopedEntity } from "./scoped-entity.js"; import { ChildList } from "./lib/child-list.js"; +import { Child } from "./lib/child.js"; export type Type = | PrimitiveType @@ -88,7 +89,7 @@ export class PrimitiveType extends BaseType { export class UnionType extends BaseType { readonly kindOfType = "union"; childTypeExprs: ChildList; - types: ObjectType[] = []; + types: (ObjectType | IntersectionType)[] = []; constructor(opts: NamedEntityOpts & { childTypeExprs?: Expr[] }) { super(opts); @@ -109,22 +110,39 @@ export class UnionType extends BaseType { export class IntersectionType extends BaseType { readonly kindOfType = "intersection"; - value: Type[]; + nominalTypeExpr: Child; + structuralTypeExpr: Child; + nominalType?: ObjectType; + structuralType?: ObjectType; - constructor(opts: NamedEntityOpts & { value: Type[] }) { + constructor( + opts: NamedEntityOpts & { + nominalObjectExpr: Expr; + structuralObjectExpr: Expr; + } + ) { super(opts); - this.value = opts.value; + this.nominalTypeExpr = new Child(opts.nominalObjectExpr, this); + this.structuralTypeExpr = new Child(opts.structuralObjectExpr, this); } clone(parent?: Expr): IntersectionType { return new IntersectionType({ ...super.getCloneOpts(parent), - value: this.value, + nominalObjectExpr: this.nominalTypeExpr.clone(), + structuralObjectExpr: this.structuralTypeExpr.clone(), }); } toJSON(): TypeJSON { - return ["type", ["intersection", ...this.value]]; + return [ + "type", + [ + "intersection", + this.nominalTypeExpr.value, + this.structuralTypeExpr.value, + ], + ]; } }