Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: refactor DPIA traversals #107

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
59 changes: 51 additions & 8 deletions macros/src/main/scala/shine/macros/Primitive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,62 @@ object Primitive {
class Impl(val c: blackbox.Context) {
import c.universe._

def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees)
def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)
def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)
def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees)

def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = {
annottees.map(_.tree) match {
case (cdef: ClassDef) :: Nil =>
c.Expr(transform(cdef))
case (cdef: ClassDef) :: (md: ModuleDef) :: Nil =>
c.Expr(q"{${transform(cdef)}; $md}")
case (cdef: ClassDef) :: Nil => c.Expr(transform(cdef))
case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => c.Expr(q"{${transform(cdef)}; $md}")
case _ => c.abort(c.enclosingPosition, "expected a class definition")
}
}

def makeLowerCaseName(s: String): String =
s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}"

def makeTraverseCall(v : Tree, name : TermName) : Tree => Option[Tree] = {
case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) |
Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)")
case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)")
case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.nat($name)")
case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.nat($name)")
case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)")
case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)")
case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)")
case Ident(TypeName("AddressSpace")) => Some(fq"${name} <- $v.addressSpace($name)")
// Phrase[ExpType]
case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)")
// Vector[Phrase[ExpType]]
case AppliedTypeTree((Ident(TypeName("Vector")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverseV($name.map($v.phrase(_)))")
case AppliedTypeTree((Ident(TypeName("Seq")),
List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverse($name.map($v.phrase(_)))")
case _ => None
}

def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = {
val v = q"v"
val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" }
val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" }
val forLoopBindings : List[Tree] = params.flatMap {
case ValDef(_, name, tpt, _) => makeTraverseCall(v, name)(tpt)
}
val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)"
else q"new $name(..$additionalParamNames)(..$paramNames)"
val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)"
else q"for (..${forLoopBindings}) yield $construct"

q"""
override def traverse[M[+_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = {
import util.monad._
implicit val monad: Monad[M] = implicitly($v.monad)
$forloop
}
"""
}

def makeVisitAndRebuild(name: TypeName,
additionalParams: List[ValDef],
params: List[ValDef]): Tree = {
Expand Down Expand Up @@ -151,7 +190,7 @@ object Primitive {
body: List[Tree],
parents: List[Tree])

def primitivesFromClassDef: ClassDef => ClassInfo = {
def getClassInfo: ClassDef => ClassInfo = {
case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " =>
ClassInfo(
name.asInstanceOf[c.TypeName],
Expand Down Expand Up @@ -187,12 +226,16 @@ object Primitive {
}

def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) =>
val traverseMissing =
body.collectFirst({ case DefDef(_, TermName("traverse"), _, _, _, _) => ()}).isEmpty
val visitAndRebuildMissing =
body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty
val xmlPrinterMissing =
body.collectFirst({ case DefDef(_, TermName("xmlPrinter"), _, _, _, _) => ()}).isEmpty

val generated = q"""
${if (traverseMissing) makeTraverse(name, additionalParams, params) else q""}

${if (visitAndRebuildMissing)
makeVisitAndRebuild(name, additionalParams, params)
else q""}
Expand Down
48 changes: 15 additions & 33 deletions src/main/scala/rise/core/traverse.scala
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
package rise.core

import scala.language.implicitConversions
import arithexpr.arithmetic.NamedVar
import util.monad
import rise.core.semantics._
import rise.core.types._
import scala.language.implicitConversions

object traverse {
trait Monad[M[_]] {
def return_[T] : T => M[T]
def bind[T,S] : M[T] => (T => M[S]) => M[S]
def traverse[A] : Seq[M[A]] => M[Seq[A]] =
_.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) =>
bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))})
}

implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new {
def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) )
def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f)
}
// Reexport util.monad.*
type Monad[M[+_]] = monad.Monad[M]
type Pure[+T] = monad.Pure[T]
implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc)
val PureMonad = monad.PureMonad
val OptionMonad = monad.OptionMonad

sealed trait VarType
case object Binding extends VarType
case object Reference extends VarType

trait Traversal[M[_]] {
trait Traversal[M[+_]] {
protected[this] implicit def monad : Monad[M]
def return_[T] : T => M[T] = monad.return_
def bind[T,S] : M[T] => (T => M[S]) => M[S] = monad.bind
Expand Down Expand Up @@ -55,7 +49,7 @@ object traverse {
def matrixLayout : MatrixLayout => M[MatrixLayout] = return_
def fragmentKind : FragmentKind => M[FragmentKind] = return_
def datatype : DataType => M[DataType] = {
case i: DataTypeIdentifier => return_(i.asInstanceOf[DataType])
case i: DataTypeIdentifier => return_(i)
case NatType => return_(NatType : DataType)
case s : ScalarType => return_(s : DataType)
case ArrayType(n, d) =>
Expand Down Expand Up @@ -86,14 +80,14 @@ object traverse {
}

def natToNat : NatToNat => M[NatToNat] = {
case i : NatToNatIdentifier => return_(i.asInstanceOf[NatToNat])
case i : NatToNatIdentifier => return_(i)
case NatToNatLambda(x, e) =>
for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- natDispatch(Reference)(e) }
yield NatToNatLambda(x1, e1)
}

def natToData : NatToData => M[NatToData] = {
case i : NatToDataIdentifier => return_(i.asInstanceOf[NatToData])
case i : NatToDataIdentifier => return_(i)
case NatToDataLambda(x, e) =>
for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e) }
yield NatToDataLambda(x1, e1)
Expand Down Expand Up @@ -192,27 +186,15 @@ object traverse {
}
}

trait ExprTraversal[M[_]] extends Traversal[M] {
trait ExprTraversal[M[+_]] extends Traversal[M] {
override def `type`[T <: Type] : T => M[T] = return_
}

case class Pure[T](unwrap : T)
implicit object PureMonad extends Monad[Pure] {
override def return_[T] : T => Pure[T] = t => Pure(t)
override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] =
v => f => v match { case Pure(v) => f(v) }
}

implicit object OptionMonad extends Monad[Option] {
def return_[T]: T => Option[T] = Some(_)
def bind[T, S]: Option[T] => (T => Option[S]) => Option[S] = v => v.flatMap
}

trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad }
trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure]

def apply(e : Expr, f : PureTraversal) : Expr = f.expr(e).unwrap
def apply[M[_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e)
def apply[M[+_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e)
def apply[T <: Type](t : T, f : PureTraversal) : T = f.`type`(t).unwrap
def apply[T <: Type, M[_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e)
def apply[T <: Type, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e)
}
55 changes: 30 additions & 25 deletions src/main/scala/shine/DPIA/Phrases/Phrase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import shine.DPIA.Semantics.OperationalSemantics.{IndexData, NatData}
import shine.DPIA.Types._
import shine.DPIA.Types.TypeCheck._
import shine.DPIA._
import shine.DPIA.Phrases.traverse._
import shine.DPIA.primitives.functional.NatAsIndex

sealed trait Phrase[T <: PhraseType] {
Expand Down Expand Up @@ -46,6 +47,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T])
extends Phrase[K `()->:` T] {
override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t)
override def toString: String = s"Λ(${x.name} : ${kn.get}). $body"
val kindName : KindName[K] = implicitly(kn)
}

object DepLambda {
Expand Down Expand Up @@ -137,7 +139,7 @@ object Phrase {
`for`: Phrase[T1],
in: Phrase[T2]): Phrase[T2] = {
var substCounter = 0
object Visitor extends VisitAndRebuild.Visitor {
object Visitor extends PureTraversal {
def renaming[X <: PhraseType](p: Phrase[X]): Phrase[X] = {
case class Renaming(idMap: Map[String, String]) extends VisitAndRebuild.Visitor {
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match {
Expand Down Expand Up @@ -167,33 +169,33 @@ object Phrase {
}
VisitAndRebuild(p, Renaming(Map()))
}
override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = {
p match {
case `for` =>
val newPh = if (substCounter == 0) ph else renaming(ph)
substCounter += 1
Stop(newPh.asInstanceOf[Phrase[T]])
case Natural(n) =>
val v = NatIdentifier(`for` match {
case Identifier(name, _) => name
case _ => throw new Exception("This should never happen")
})

ph.t match {
case ExpType(NatType, _) =>
Stop(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case ExpType(IndexType(_), _) =>
Stop(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case _ => Continue(p, this)
}
case _ => Continue(p, this)
}

// override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = {
override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = p => p match {
case `for` =>
val newPh = if (substCounter == 0) ph else renaming(ph)
substCounter += 1
return_(newPh.asInstanceOf[Phrase[T]])
case Natural(n) =>
val v = NatIdentifier(`for` match {
case Identifier(name, _) => name
case _ => throw new Exception("This should never happen")
})

ph.t match {
case ExpType(NatType, _) =>
return_(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case ExpType(IndexType(_), _) =>
return_(Natural(Nat.substitute(
Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]])
case _ => super.phrase(p)
}
case _ => super.phrase(p)
}
}

VisitAndRebuild(in, Visitor)
Visitor.phrase(in).unwrap
}

def substitute[T2 <: PhraseType](substitutionMap: Map[Phrase[_], Phrase[_]],
Expand Down Expand Up @@ -367,6 +369,9 @@ sealed trait Primitive[T <: PhraseType] extends Phrase[T] {
def xmlPrinter: xml.Elem =
throw new Exception("xmlPrinter should be implemented by a macro")

def traverse[M[+_]](f: Traversal[M]): M[Phrase[T]] =
throw new Exception("traverse should be implemented by a macro")

def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] =
throw new Exception("visitAndRebuild should be implemented by a macro")
}
Expand Down
Loading