@@ -235,28 +235,31 @@ extends Importer:
235235 val lt = term(lhs)
236236 val rt = term(rhs)
237237
238- // Check module parameters
238+ // Check if module arguments match module parameters
239239 val args = rt match
240240 case Term .Tup (fields) => S (fields)
241241 case _ => N
242- val argsModFlags = args
243- .map(_.map(_.flags.mod))
244- val paramsModFlags = lt.symbol match
245- case S (sym : BlockMemberSymbol ) => sym.defn match
246- case S (defn : TermDefinition ) => defn.params.lift(0 )
247- .map(_.params.map(_.flags.mod))
248- case _ => argsModFlags.map(_.map(_ => false ))
249- case _ => argsModFlags.map(_.map(_ => false ))
242+ val params = lt.symbol
243+ .filter(_.isInstanceOf [BlockMemberSymbol ])
244+ .flatMap(_.asInstanceOf [BlockMemberSymbol ].trmTree)
245+ .filter(_.isInstanceOf [TermDef ])
246+ .flatMap(_.asInstanceOf [TermDef ].paramLists.headOption)
250247 for
251- (argLists, amfs, pmfs) <- (args lazyZip argsModFlags lazyZip paramsModFlags)
252- (a, amf, pmf) <- (argLists lazyZip amfs lazyZip pmfs)
253- if amf && ! pmf
248+ (args, params) <- (args zip params)
249+ (arg, param) <- (args zip params.fields)
254250 do
255- log(s " ${a.value}" )
256- raise(ErrorReport (
257- msg " Module values can only be passed to module parameters. " -> a.toLoc
258- :: Nil ,
259- ))
251+ val argMod = arg.flags.mod
252+ val paramMod = param match
253+ case Tree .TypeDef (Mod , _, _, _) => true
254+ case _ => false
255+ if argMod && ! paramMod then
256+ raise :
257+ ErrorReport :
258+ msg " Module values can only be passed to module parameters. " -> arg.toLoc :: Nil
259+ if ! argMod && paramMod then
260+ raise :
261+ ErrorReport :
262+ msg " Module parameters can only receive module values. " -> arg.toLoc :: Nil
260263
261264 Term .App (lt, rt)(tree, sym)
262265 case Sel (pre, nme) =>
@@ -354,7 +357,7 @@ extends Importer:
354357 case _ =>
355358 val t = term(tree)
356359 t.symbol.flatMap(_.asMod) match
357- case S (_) => Fld (FldFlags .module , t, N )
360+ case S (_) => Fld (FldFlags .empty.copy(mod = true ) , t, N )
358361 case N => Fld (FldFlags .empty, t, N )
359362
360363 def unit : Term .Lit = Term .Lit (UnitLit (true ))
@@ -652,24 +655,10 @@ extends Importer:
652655 case id : Ident =>
653656 Param (FldFlags .empty, fieldOrVarSym(ParamBind , id), N ) :: Nil
654657 case InfixApp (lhs : Ident , Keyword .`:`, rhs) =>
655- // return S(moduleParam) if t represents a module parameter
656- // return N otherwise
657- def moduleParam (t : Term ): Opt [Param ] = t match
658- case s : hkmc2.semantics.Term .Sel => s.symbol
659- .flatMap(_.asMod)
660- .map(_ => Param (FldFlags .module, fieldOrVarSym(ParamBind , lhs), S (s)))
661- case hkmc2.semantics.Term .TyApp (s : hkmc2.semantics.Term .Sel , _) => s.symbol
662- .flatMap(_.asMod)
663- .map(_ => Param (FldFlags .module, fieldOrVarSym(ParamBind , lhs), S (s)))
664- case _ => N
665-
666- val t = term(rhs)
667- moduleParam(t) match
668- case S (p) =>
669- p :: Nil
670- case N =>
671- Param (FldFlags .empty, fieldOrVarSym(ParamBind , lhs), S (term(rhs))) :: Nil
658+ Param (FldFlags .empty, fieldOrVarSym(ParamBind , lhs), S (term(rhs))) :: Nil
672659 case App (Ident (" ," ), list) => params(list)._1
660+ case TypeDef (Mod , inner, _, _) => param(inner)
661+ .map(p => p.copy(flags = p.flags.copy(mod = true )))
673662 case TermDef (ImmutVal , inner, _) => param(inner)
674663
675664 def params (t : Tree ): Ctxl [(Ls [Param ], Ctx )] = t match
0 commit comments