Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run function lifting and named form refs #32

Merged
merged 5 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ lazy val guinep = projectMatrix
.in(file("guinep"))
.settings(commonSettings)
.settings(
name := "GUInep"
name := "GUInep",
libraryDependencies ++= Seq(
"com.softwaremill.quicklens" %%% "quicklens" % "1.9.7"
)
)
.jvmPlatform(scalaVersions = List(scala3))

Expand Down
104 changes: 89 additions & 15 deletions guinep/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package guinep

import guinep.model.*
import scala.quoted.*
import scala.collection.mutable
import com.softwaremill.quicklens.*

private[guinep] object macros {
inline def funInfos(inline fs: Any): Seq[Fun] =
Expand Down Expand Up @@ -33,7 +35,13 @@ private[guinep] object macros {
extension (t: Term)
private def select(s: Term): Term = Select(t, s.symbol)
private def select(s: String): Term =
t.select(t.tpe.typeSymbol.methodMember(s).head)
t.select(
t.tpe
.typeSymbol
.methodMember(s)
.headOption.
getOrElse(report.errorAndAbort(s"PANIC: No member $s in term ${t.show} with type ${t.tpe.show}"))
)

extension (s: Symbol)
private def prettyName: String =
Expand Down Expand Up @@ -93,7 +101,28 @@ private[guinep] object macros {
val isEnumCaseNonClassDef = typeSymbol.flags.is(Flags.Enum) && typeSymbol.flags.is(Flags.Case) && !typeSymbol.isClassDef
isModule || isEnumCaseNonClassDef

private def functionFormElementFromTree(paramName: String, paramType: TypeRepr): FormElement = paramType match {
private case class FormConstrContext(constructedTpes: mutable.Map[String, Option[FormElement]], referencedTpes: mutable.Set[String])
private def formConstrCtx(using FormConstrContext) = summon[FormConstrContext]

extension (tpe: TypeRepr)
private def namedRef: String = tpe match
case ntpe: NamedType => ntpe.typeSymbol.fullName
case AppliedType(tpe, args) => s"${tpe.namedRef}[${args.map(_.namedRef).mkString(", ")}]"
case AnnotatedType(tpe, _) => tpe.namedRef
case _ => tpe.show

private def functionFormElementFromTreeWithCaching(paramName: String, paramTpe: TypeRepr)(using FormConstrContext): FormElement =
formConstrCtx.constructedTpes.get(paramTpe.namedRef) match
case Some(_) =>
formConstrCtx.referencedTpes.add(paramTpe.namedRef)
FormElement.NamedRef(paramName, paramTpe.namedRef)
case _ =>
formConstrCtx.constructedTpes.update(paramTpe.namedRef, None)
val formElement = functionFormElementFromTree(paramName, paramTpe)
formConstrCtx.constructedTpes.update(paramTpe.namedRef, Some(formElement.modify(_.name).setTo("value")))
formElement

private def functionFormElementFromTree(paramName: String, paramType: TypeRepr)(using FormConstrContext): FormElement = paramType match {
case ntpe: NamedType if ntpe.name == "String" => FormElement.TextInput(paramName)
case ntpe: NamedType if ntpe.name == "Int" => FormElement.NumberInput(paramName)
case ntpe: NamedType if ntpe.name == "Boolean" => FormElement.CheckboxInput(paramName)
Expand All @@ -104,7 +133,7 @@ private[guinep] object macros {
FormElement.FieldSet(
paramName,
fields.map { valdef =>
functionFormElementFromTree(
functionFormElementFromTreeWithCaching(
valdef.name,
valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs).stripAnnots
)
Expand All @@ -113,17 +142,29 @@ private[guinep] object macros {
case ntpe if isSumTpe(ntpe) =>
val classSymbol = ntpe.typeSymbol
val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots)
val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTree("value", t))
val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTreeWithCaching("value", t))
val options = classSymbol.children.map(_.prettyName).zip(childrenFormElements)
FormElement.Dropdown(paramName, options)
case _ =>
unsupportedFunctionParamType(paramType)
}

private def functionFormElementsImpl(f: Expr[Any]): Expr[Seq[FormElement]] =
Expr.ofSeq(
functionParams(f).map { case ValDef(name, tpt, _) => functionFormElementFromTree(name, tpt.tpe) } .map(Expr(_))
)
private def formImpl(f: Expr[Any]): Expr[Form] =
given FormConstrContext = FormConstrContext(mutable.Map.empty, mutable.Set.empty)
val inputs = functionParams(f)
.map {
case ValDef(name, tpt, _) =>
functionFormElementFromTreeWithCaching(name, tpt.tpe)
}
val usedFormDecls =
formConstrCtx.constructedTpes
.toList.filter( (ref, formElement) => formConstrCtx.referencedTpes.contains(ref) )
.collect {
case (ref, Some(formElement)) => ref -> formElement
}
.toMap
val form = Form(inputs, usedFormDecls)
Expr(form)

private def appliedChild(childSym: Symbol, parentSym: Symbol, parentArgs: List[TypeRepr]): TypeRepr = childSym.tree match {
case classDef @ ClassDef(_, _, parents, _, _) =>
Expand All @@ -149,7 +190,35 @@ private[guinep] object macros {
childSym.typeRef
}

private def constructArg(paramTpe: TypeRepr, param: Term): Term = {
private case class ConstrEntry(definition: Option[Statement], ref: Term)
private case class ConstrContext(constrMap: mutable.Map[String, ConstrEntry])
private def constrCtx(using ConstrContext) = summon[ConstrContext]

private def constructArgWithCaching(paramTpe: TypeRepr, param: Term)(using ConstrContext): Term =
constrCtx.constrMap.get(paramTpe.namedRef) match
case Some(ConstrEntry(_, ref)) =>
ref.appliedTo(param)
case None =>
val ConstrEntry(_, ref) = constructFunction(paramTpe)
ref.appliedTo(param)

private def constructFunction(paramTpe: TypeRepr)(using ConstrContext): ConstrEntry =
val defdefSymbol =
Symbol.newMethod(
Symbol.spliceOwner,
s"constrFunFor${paramTpe.namedRef}",
MethodType(List("inputs"))(_ => List(TypeRepr.of[Any]), _ => paramTpe)
)
constrCtx.constrMap.update(paramTpe.namedRef, ConstrEntry(None, Ref(defdefSymbol)))
val defdef = DefDef(defdefSymbol, {
case List(List(param: Term)) =>
Some(constructArg(paramTpe, param))
})
val constrEntry = ConstrEntry(Some(defdef), Ref(defdefSymbol))
val newMap = constrCtx.constrMap.update(paramTpe.namedRef, constrEntry)
constrEntry

private def constructArg(paramTpe: TypeRepr, param: Term)(using ConstrContext): Term = {
paramTpe match {
case ntpe: NamedType if ntpe.name == "String" => param.select("asInstanceOf").appliedToType(ntpe)
case ntpe: NamedType if ntpe.name == "Int" => param.select("asInstanceOf").appliedToType(ntpe)
Expand All @@ -166,7 +235,7 @@ private[guinep] object macros {
val args = fields.collect { case field: ValDef =>
val fieldName = field.name
val fieldValue = paramValue.select("apply").appliedTo(Literal(StringConstant(fieldName)))
constructArg(
constructArgWithCaching(
field.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs),
fieldValue
)
Expand All @@ -186,7 +255,7 @@ private[guinep] object macros {
val childName = Literal(StringConstant(child.prettyName))
If(
paramName.select("equals").appliedTo(childName),
constructArg(childAppliedTpe, paramValue),
constructArgWithCaching(childAppliedTpe, paramValue),
acc
)
}
Expand All @@ -199,14 +268,15 @@ private[guinep] object macros {
private def functionRunImpl(f: Expr[Any]): Expr[List[Any] => String] = f.asTerm match {
case l@Lambda(params, _) =>
/* (params: List[Any]) => l.apply(constructArg(params(0)), constructArg(params(1)), ...) */
Lambda(
given ConstrContext = ConstrContext(mutable.Map.empty)
val resLambda = Lambda(
Symbol.spliceOwner,
MethodType(List("inputs"))(_ => List(TypeRepr.of[List[Any]]), _ => TypeRepr.of[String]),
{ case (sym, List(params: Term)) =>
val args = functionParams(f).zipWithIndex.map { case (valdef, i) =>
val paramTpe = valdef.tpt.tpe
val param = params.select("apply").appliedTo(Literal(IntConstant(i)))
constructArg(paramTpe, param)
constructArgWithCaching(paramTpe, param)
}.toList
val aply = l.select("apply")
val res =
Expand All @@ -216,6 +286,10 @@ private[guinep] object macros {
aply.appliedToArgs(args)
res.select("toString").appliedToNone
}
)
Block(
constrCtx.constrMap.toList.flatMap(_._2.definition),
resLambda
).asExprOf[List[Any] => String]
case i@Ident(_) =>
Lambda(
Expand Down Expand Up @@ -249,9 +323,9 @@ private[guinep] object macros {

def funInfoImpl(f: Expr[Any]): Expr[Fun] = {
val name = functionNameImpl(f)
val params = functionFormElementsImpl(f)
val form = formImpl(f)
val run = functionRunImpl(f)
'{ Fun($name, $params, $run) }
'{ Fun($name, $form, $run) }
}
}
}
53 changes: 50 additions & 3 deletions guinep/src/main/scala/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,26 @@ package guinep
import scala.quoted.*

private[guinep] object model {
case class Fun(name: String, inputs: Seq[FormElement], run: List[Any] => String)
case class Fun(name: String, form: Form, run: List[Any] => String)

case class Form(inputs: Seq[FormElement], namedFormElements: Map[String, FormElement]) {
def formElementsJSONRepr =
val elems = this.inputs.map(_.toJSONRepr).mkString(",")
s"[$elems]"
def namedFormElementsJSONRepr: String =
val entries = this.namedFormElements.toList.map { (name, formElement) =>
s""""$name": ${formElement.toJSONRepr}"""
}
.mkString(",")
s"{$entries}"
}
object Form:
given ToExpr[Form] with
def apply(form: Form)(using Quotes): Expr[Form] = form match
case Form(inputs, namedFormElements) =>
'{ Form(${Expr(inputs)}, ${Expr(namedFormElements)}) }

enum FormElement(val name: String):
case FieldSet(override val name: String, elements: List[FormElement]) extends FormElement(name)
case TextInput(override val name: String) extends FormElement(name)
case NumberInput(override val name: String) extends FormElement(name)
case CheckboxInput(override val name: String) extends FormElement(name)
Expand All @@ -15,6 +31,20 @@ private[guinep] object model {
case DateInput(override val name: String) extends FormElement(name)
case EmailInput(override val name: String) extends FormElement(name)
case PasswordInput(override val name: String) extends FormElement(name)
case FieldSet(override val name: String, elements: List[FormElement]) extends FormElement(name)
case NamedRef(override val name: String, ref: String) extends FormElement(name)

def constrOrd: Int = this match
case TextInput(_) => 0
case NumberInput(_) => 1
case CheckboxInput(_) => 2
case Dropdown(_, _) => 3
case TextArea(_, _, _) => 4
case DateInput(_) => 5
case EmailInput(_) => 6
case PasswordInput(_) => 7
case FieldSet(_, _) => 8
case NamedRef(_, _) => 9

def toJSONRepr: String = this match
case FormElement.FieldSet(name, elements) =>
Expand All @@ -26,7 +56,8 @@ private[guinep] object model {
case FormElement.CheckboxInput(name) =>
s"""{ "name": '$name', "type": 'checkbox' }"""
case FormElement.Dropdown(name, options) =>
s"""{ "name": '$name', "type": 'dropdown', "options": [${options.map { case (k, v) => s"""{"name": "$k", "value": ${v.toJSONRepr}}""" }.mkString(",")}] }"""
// TODO(kπ) this sortBy isn't 100% sure to be working (the only requirement is for the first constructor to not be recursive; this is a graph problem, sorta)
s"""{ "name": '$name', "type": 'dropdown', "options": [${options.sortBy(_._2).map { case (k, v) => s"""{"name": "$k", "value": ${v.toJSONRepr}}""" }.mkString(",")}] }"""
case FormElement.TextArea(name, rows, cols) =>
s"""{ "name": '$name', "type": 'textarea', "rows": ${rows.getOrElse("")}, "cols": ${cols.getOrElse("")} }"""
case FormElement.DateInput(name) =>
Expand All @@ -35,6 +66,8 @@ private[guinep] object model {
s"""{ "name": '$name', "type": 'email' }"""
case FormElement.PasswordInput(name) =>
s"""{ "name": '$name', "type": 'password' }"""
case FormElement.NamedRef(name, ref) =>
s"""{ "name": '$name', "ref": '$ref', "type": 'namedref' }"""

object FormElement:
given ToExpr[FormElement] with
Expand All @@ -57,4 +90,18 @@ private[guinep] object model {
'{ FormElement.EmailInput(${Expr(name)}) }
case FormElement.PasswordInput(name) =>
'{ FormElement.PasswordInput(${Expr(name)}) }
case FormElement.NamedRef(name, ref) =>
'{ FormElement.NamedRef(${Expr(name)}, ${Expr(ref)}) }

given Ordering[FormElement] = new Ordering[FormElement] {
def compare(x: FormElement, y: FormElement): Int =
if x.constrOrd < y.constrOrd then -1
else if x.constrOrd > y.constrOrd then 1
else (x, y) match
case (FormElement.FieldSet(_, elems1), FormElement.FieldSet(_, elems2)) =>
elems1.size - elems2.size
case (FormElement.Dropdown(_, opts1), FormElement.Dropdown(_, opts2)) =>
opts1.size - opts2.size
case _ => 0
}
}
25 changes: 23 additions & 2 deletions testcases/src/main/scala/main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,28 @@ def printsWeirdGADT(g: WeirdGADT[String]): String = g match
case SomeValue(value) => s"SomeValue($value)"
case SomeOtherValue(value, value2) => s"SomeOtherValue($value, $value2)"

// This loops forever
def concatAll(elems: List[String]): String =
elems.mkString

enum IntTree:
case Leaf
case Node(left: IntTree, value: Int, right: IntTree)

def isInTree(elem: Int, tree: IntTree): Boolean = tree match
case IntTree.Leaf => false
case IntTree.Node(left, value, right) =>
value == elem || isInTree(elem, left) || isInTree(elem, right)

// Can't be handled right now
extension (elem: Int)
def isInTreeExt(tree: IntTree): Boolean = tree match
case IntTree.Leaf => false
case IntTree.Node(left, value, right) =>
value == elem || elem.isInTreeExt(left) || elem.isInTreeExt(right)

def addManyParamLists(a: Int)(b: Int): Int =
a + b

@main
def run: Unit =
guinep.web(
Expand All @@ -86,6 +104,9 @@ def run: Unit =
nameWithPossiblePrefix1,
roll20,
roll6(),
concatAll,
isInTree,
// isInTreeExt
// addManyParamLists
// printsWeirdGADT
// concatAll
)
Loading
Loading