Skip to content

Commit edbc565

Browse files
Enforce rules on functions returning modules
1 parent 064ba0e commit edbc565

File tree

14 files changed

+118
-22
lines changed

14 files changed

+118
-22
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class JSBuilder extends CodeBuilder:
187187
} + ")""""
188188
}; }"""
189189
} #} # }"
190-
if clsDefn.kind is syntax.Mod then
190+
if (clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj) then
191191
val clsTmp = summon[Scope].allocateName(new semantics.TempSymbol(0/*TODO rm this useless param*/, N, sym.nme+"$"+"class"))
192192
clsDefn.owner match
193193
case S(owner) =>

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -519,19 +519,54 @@ extends Importer:
519519
case ((pss, ctx), ps) =>
520520
val (qs, newCtx) = params(ps)(using ctx)
521521
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
522+
// * Elaborate signature
523+
val st = td.signature.orElse(newSignatureTrees.get(id.name))
524+
val s = st.map(term(_)(using newCtx))
522525
val b = rhs.map(term(_)(using newCtx))
523526
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
524-
val tdf = TermDefinition(owner, k, sym, pss,
525-
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
526-
TermDefFlags(isModMember))
527+
val tdf = TermDefinition(owner, k, sym, pss, s, b, r,
528+
TermDefFlags.empty.copy(isModMember = isModMember))
527529
sym.defn = S(tdf)
530+
531+
// the return type of the function
532+
val result = td.head match
533+
case InfixApp(_, Keyword.`:`, rhs) => S(term(rhs)(using newCtx))
534+
case _ => N
535+
536+
// indicates if the function really returns a module
537+
val em = b.fold(false)(ModuleChecker.evalsToModule)
538+
// indicates if the function marks its result as "module"
539+
val mm = st match
540+
case Some(TypeDef(Mod, _, N, N)) => true
541+
case _ => false
542+
543+
// checks rules regarding module methods
544+
s match
545+
case N if em => raise:
546+
ErrorReport:
547+
msg"Function returning module values must have explicit return types." ->
548+
td.head.toLoc :: Nil
549+
case S(t) if em && ModuleChecker.isTypeParam(t) => raise:
550+
ErrorReport:
551+
msg"Function returning module values must have concrete return types." ->
552+
td.head.toLoc :: Nil
553+
case S(_) if em && !mm => raise:
554+
ErrorReport:
555+
msg"The return type of functions returning module values must be prefixed with module keyword." ->
556+
td.head.toLoc :: Nil
557+
case S(_) if mm && !isModMember => raise:
558+
ErrorReport:
559+
msg"Only module methods may return module values." ->
560+
td.head.toLoc :: Nil
561+
case _ => ()
562+
528563
tdf
529564
go(sts, tdf :: acc)
530565
case L(d) =>
531566
raise(d)
532567
go(sts, acc)
533568
case (td @ TypeDef(k, head, extension, body)) :: sts =>
534-
assert((k is Als) || (k is Cls) || (k is Mod), k)
569+
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj), k)
535570
val nme = td.name match
536571
case R(id) => id
537572
case L(d) =>
@@ -583,7 +618,7 @@ extends Importer:
583618
semantics.TypeDef(alsSym, tps, extension.map(term), N)
584619
alsSym.defn = S(d)
585620
d
586-
case Mod =>
621+
case k: (Mod.type | Obj.type) =>
587622
val clsSym = td.symbol.asInstanceOf[ModuleSymbol] // TODO: improve `asInstanceOf`
588623
val owner = ctx.outer
589624
newCtx.nest(S(clsSym)).givenIn:
@@ -594,7 +629,7 @@ extends Importer:
594629
// case S(t) => block(t :: Nil)
595630
case S(t) => ???
596631
case N => (new Term.Blk(Nil, Term.Lit(UnitLit(true))), ctx)
597-
ModuleDef(owner, clsSym, tps, ps, ObjBody(bod))
632+
ModuleDef(owner, clsSym, tps, ps, k, ObjBody(bod))
598633
clsSym.defn = S(cd)
599634
cd
600635
case Cls =>
@@ -618,7 +653,6 @@ extends Importer:
618653
// TODO: pass abstract to `go`
619654
go(body :: sts, acc)
620655
case Modified(Keyword.`declare`, absLoc, body) :: sts =>
621-
???
622656
// TODO: pass declare to `go`
623657
go(body :: sts, acc)
624658
case (result: Tree) :: Nil =>
@@ -723,6 +757,23 @@ extends Importer:
723757
.filter(_.isInstanceOf[VarSymbol])
724758
.flatMap(_.asInstanceOf[VarSymbol].decl)
725759
.fold(false)(_.isInstanceOf[TyParam])
760+
761+
/** Checks if a term evaluates to a module value. */
762+
def evalsToModule(t: Term): Bool =
763+
def isModule(t: Tree): Bool = t match
764+
case TypeDef(Mod, _, _, _) => true
765+
case _ => false
766+
def returnsModule(t: TermDef): Bool = t.signature match
767+
case S(TypeDef(Mod, _, N, N)) => true
768+
case _ => false
769+
t match
770+
case Term.Blk(_, res) => evalsToModule(res)
771+
case Term.App(lhs, rhs) => lhs.symbol match
772+
case S(sym: BlockMemberSymbol) => sym.trmTree.fold(false)(returnsModule)
773+
case _ => false
774+
case t => t.symbol match
775+
case S(sym: BlockMemberSymbol) => sym.modTree.fold(false)(isModule)
776+
case _ => false
726777

727778
class VarianceTraverser(var changed: Bool = true) extends Traverser:
728779
override def traverseType(pol: Pol)(trm: Term): Unit = trm match

hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree]) extends MemberSymbol[
9696
def clsTree: Opt[Tree.TypeDef] = trees.collectFirst:
9797
case t: Tree.TypeDef if t.k is Cls => t
9898
def modTree: Opt[Tree.TypeDef] = trees.collectFirst:
99-
case t: Tree.TypeDef if t.k is Mod => t
99+
case t: Tree.TypeDef if (t.k is Mod) || (t.k is Obj) => t
100100
def alsTree: Opt[Tree.TypeDef] = trees.collectFirst:
101101
case t: Tree.TypeDef if t.k is Als => t
102102
def trmTree: Opt[Tree.TermDef] = trees.collectFirst:

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,14 @@ sealed abstract class ClassLikeDef extends TypeLikeDef:
236236
val body: ObjBody
237237

238238

239-
case class ModuleDef(owner: Opt[InnerSymbol], sym: ModuleSymbol, tparams: Ls[TyParam], paramsOpt: Opt[Ls[Param]], body: ObjBody) extends ClassLikeDef with Companion:
240-
self =>
241-
val kind: ClsLikeKind = Mod
239+
case class ModuleDef(
240+
owner: Opt[InnerSymbol],
241+
sym: ModuleSymbol,
242+
tparams: Ls[TyParam],
243+
paramsOpt: Opt[Ls[Param]],
244+
kind: ClsLikeKind,
245+
body: ObjBody,
246+
) extends ClassLikeDef with Companion
242247

243248

244249
sealed abstract class ClassDef extends ClassLikeDef:

hkmc2/shared/src/main/scala/hkmc2/syntax/Keyword.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ object Keyword:
9494
val `new` = Keyword("new", N, curPrec) // TODO: check the prec
9595
// val `namespace` = Keyword("namespace", N, N)
9696
val `module` = Keyword("module", N, curPrec)
97+
val `object` = Keyword("object", N, curPrec)
9798
val `open` = Keyword("open", N, curPrec)
9899
val `type` = Keyword("type", N, N)
99100
val `where` = Keyword("where", N, N)

hkmc2/shared/src/main/scala/hkmc2/syntax/ParseRule.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ object ParseRule:
271271
Kw(`class`)(typeDeclBody(Cls)),
272272
Kw(`trait`)(typeDeclBody(Trt)),
273273
Kw(`module`)(typeDeclBody(Mod)),
274+
Kw(`object`)(typeDeclBody(Obj)),
274275
Kw(`open`):
275276
ParseRule("'open' keyword")(
276277
exprOrBlk(ParseRule("'open' declaration")(End(()))){

hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ case object Trt extends TypeDefKind("trait") with ObjDefKind
203203
case object Mxn extends TypeDefKind("mixin")
204204
case object Als extends TypeDefKind("type alias")
205205
case object Mod extends TypeDefKind("module") with ClsLikeKind
206+
case object Obj extends TypeDefKind("object") with ClsLikeKind
206207

207208

208209

@@ -275,7 +276,7 @@ trait TypeDefImpl extends TypeOrTermDef:
275276

276277
lazy val symbol = k match
277278
case Cls => semantics.ClassSymbol(this, name.getOrElse(Ident("<error>")))
278-
case Mod => semantics.ModuleSymbol(this, name.getOrElse(Ident("<error>")))
279+
case Mod | Obj => semantics.ModuleSymbol(this, name.getOrElse(Ident("<error>")))
279280
case Als => semantics.TypeAliasSymbol(name.getOrElse(Ident("<error>")))
280281
case Trt | Mxn => ???
281282

hkmc2/shared/src/test/mlscript-compile/Option.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ open Predef
99
module Option with ...
1010

1111
class Some(value)
12-
module None
12+
object None
1313

1414
fun isDefined(x) = if x is
1515
Some then true

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,39 @@ fun f2[T](module m: T)
2020
//│ ║ l.18: fun f2[T](module m: T)
2121
//│ ╙── ^^^^
2222

23+
:e
24+
module N with {
25+
fun f3() = M
26+
}
27+
//│ ╔══[ERROR] Function returning module values must have explicit return types.
28+
//│ ║ l.25: fun f3() = M
29+
//│ ╙── ^^^^
30+
31+
:e
32+
module N with {
33+
fun f4[T](): module T = M
34+
}
35+
//│ ╔══[ERROR] Function returning module values must have explicit return types.
36+
//│ ║ l.33: fun f4[T](): module T = M
37+
//│ ╙── ^^^^^^^^^^^^^^^^^
38+
39+
:e
40+
module N with {
41+
fun f5(): M = M
42+
}
43+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
44+
//│ ║ l.41: fun f5(): M = M
45+
//│ ╙── ^^^^^^^
46+
47+
:e
48+
fun f7(): module M
49+
//│ ╔══[ERROR] Only module methods may return module values.
50+
//│ ║ l.48: fun f7(): module M
51+
//│ ╙── ^^^^^^^^^^^^^^
52+
2353

2454
fun ok1(module m: M)
55+
56+
module N with {
57+
fun ok2(): module M = M
58+
}

hkmc2/shared/src/test/mlscript/bbml/bbBorrowing.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ letreg of r =>
6161
123
6262
if next(it) > 0 then () => 0 else () => clear(b)
6363
k()
64-
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@102),Lit(BoolLit(true)),Else(Lam(List(),Lit(IntLit(0))))),Else(Lam(List(),App(Sel(Ref(globalThis:block#5),Ident(clear)),Tup(List(Fld(‹›,Ref(b@91),None))))))) (of class hkmc2.semantics.Split$Cons)
64+
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@123),Lit(BoolLit(true)),Else(Lam(List(),Lit(IntLit(0))))),Else(Lam(List(),App(Sel(Ref(globalThis:block#5),Ident(clear)),Tup(List(Fld(‹›,Ref(b@112),None))))))) (of class hkmc2.semantics.Split$Cons)
6565

6666
:e
6767
letreg of r =>

0 commit comments

Comments
 (0)