@@ -519,19 +519,54 @@ extends Importer:
519519 case ((pss, ctx), ps) =>
520520 val (qs, newCtx) = params(ps)(using ctx)
521521 (pss :+ ParamList (ParamListFlags .empty, qs), newCtx)
522+ // * Elaborate signature
523+ val st = td.signature.orElse(newSignatureTrees.get(id.name))
524+ val s = st.map(term(_)(using newCtx))
522525 val b = rhs.map(term(_)(using newCtx))
523526 val r = FlowSymbol (s " ‹result of ${sym}› " , nextUid)
524- val tdf = TermDefinition (owner, k, sym, pss,
525- td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
526- TermDefFlags (isModMember))
527+ val tdf = TermDefinition (owner, k, sym, pss, s, b, r,
528+ TermDefFlags .empty.copy(isModMember = isModMember))
527529 sym.defn = S (tdf)
530+
531+ // the return type of the function
532+ val result = td.head match
533+ case InfixApp (_, Keyword .`:`, rhs) => S (term(rhs)(using newCtx))
534+ case _ => N
535+
536+ // indicates if the function really returns a module
537+ val em = b.fold(false )(ModuleChecker .evalsToModule)
538+ // indicates if the function marks its result as "module"
539+ val mm = st match
540+ case Some (TypeDef (Mod , _, N , N )) => true
541+ case _ => false
542+
543+ // checks rules regarding module methods
544+ s match
545+ case N if em => raise :
546+ ErrorReport :
547+ msg " Function returning module values must have explicit return types. " ->
548+ td.head.toLoc :: Nil
549+ case S (t) if em && ModuleChecker .isTypeParam(t) => raise :
550+ ErrorReport :
551+ msg " Function returning module values must have concrete return types. " ->
552+ td.head.toLoc :: Nil
553+ case S (_) if em && ! mm => raise :
554+ ErrorReport :
555+ msg " The return type of functions returning module values must be prefixed with module keyword. " ->
556+ td.head.toLoc :: Nil
557+ case S (_) if mm && ! isModMember => raise :
558+ ErrorReport :
559+ msg " Only module methods may return module values. " ->
560+ td.head.toLoc :: Nil
561+ case _ => ()
562+
528563 tdf
529564 go(sts, tdf :: acc)
530565 case L (d) =>
531566 raise(d)
532567 go(sts, acc)
533568 case (td @ TypeDef (k, head, extension, body)) :: sts =>
534- assert((k is Als ) || (k is Cls ) || (k is Mod ), k)
569+ assert((k is Als ) || (k is Cls ) || (k is Mod ) || (k is Obj ) , k)
535570 val nme = td.name match
536571 case R (id) => id
537572 case L (d) =>
@@ -583,7 +618,7 @@ extends Importer:
583618 semantics.TypeDef (alsSym, tps, extension.map(term), N )
584619 alsSym.defn = S (d)
585620 d
586- case Mod =>
621+ case k : ( Mod . type | Obj . type ) =>
587622 val clsSym = td.symbol.asInstanceOf [ModuleSymbol ] // TODO: improve `asInstanceOf`
588623 val owner = ctx.outer
589624 newCtx.nest(S (clsSym)).givenIn:
@@ -594,7 +629,7 @@ extends Importer:
594629 // case S(t) => block(t :: Nil)
595630 case S (t) => ???
596631 case N => (new Term .Blk (Nil , Term .Lit (UnitLit (true ))), ctx)
597- ModuleDef (owner, clsSym, tps, ps, ObjBody (bod))
632+ ModuleDef (owner, clsSym, tps, ps, k, ObjBody (bod))
598633 clsSym.defn = S (cd)
599634 cd
600635 case Cls =>
@@ -618,7 +653,6 @@ extends Importer:
618653 // TODO: pass abstract to `go`
619654 go(body :: sts, acc)
620655 case Modified (Keyword .`declare`, absLoc, body) :: sts =>
621- ???
622656 // TODO: pass declare to `go`
623657 go(body :: sts, acc)
624658 case (result : Tree ) :: Nil =>
@@ -723,6 +757,23 @@ extends Importer:
723757 .filter(_.isInstanceOf [VarSymbol ])
724758 .flatMap(_.asInstanceOf [VarSymbol ].decl)
725759 .fold(false )(_.isInstanceOf [TyParam ])
760+
761+ /** Checks if a term evaluates to a module value. */
762+ def evalsToModule (t : Term ): Bool =
763+ def isModule (t : Tree ): Bool = t match
764+ case TypeDef (Mod , _, _, _) => true
765+ case _ => false
766+ def returnsModule (t : TermDef ): Bool = t.signature match
767+ case S (TypeDef (Mod , _, N , N )) => true
768+ case _ => false
769+ t match
770+ case Term .Blk (_, res) => evalsToModule(res)
771+ case Term .App (lhs, rhs) => lhs.symbol match
772+ case S (sym : BlockMemberSymbol ) => sym.trmTree.fold(false )(returnsModule)
773+ case _ => false
774+ case t => t.symbol match
775+ case S (sym : BlockMemberSymbol ) => sym.modTree.fold(false )(isModule)
776+ case _ => false
726777
727778 class VarianceTraverser (var changed : Bool = true ) extends Traverser :
728779 override def traverseType (pol : Pol )(trm : Term ): Unit = trm match
0 commit comments