Skip to content

Commit

Permalink
Merge pull request #1 from road21/basic_scala3_derivation_fix_collect…
Browse files Browse the repository at this point in the history
…_defauls

fix collect defaults macro
  • Loading branch information
goshacodes authored Apr 23, 2024
2 parents 9501a3c + f66ee03 commit ce26657
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
52 changes: 33 additions & 19 deletions modules/core/src/main/scala-3/tethys/derivation/Defaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,36 @@ private[derivation]
object DefaultsMacro:
import scala.quoted.*

def collect[T: Type](using quotes: Quotes): Expr[Map[Int, Any]] =
import quotes.reflect.*
val typeSymbol = TypeRepr.of[T].typeSymbol

val res = typeSymbol.caseFields.zipWithIndex.flatMap {
case (sym, idx) if sym.flags.is(Flags.HasDefault) =>
val defaultValueMethodSym =
typeSymbol.companionClass
.declaredMethod(s"$$lessinit$$greater$$default$$${idx + 1}")
.headOption
.getOrElse(report.errorAndAbort(s"Error while extracting default value for field '${sym.name}'"))

Some(Expr.ofTuple(Expr(idx) -> Ref(typeSymbol.companionModule).select(defaultValueMethodSym).asExprOf[Any]))
case _ =>
None
}

'{ Map(${ Varargs(res) }: _*) }

def collect[T: Type](using

Quotes
): Expr[Map[Int, Any]] =
import quotes.reflect._

val tpe = TypeRepr.of[T].typeSymbol
val terms = tpe.primaryConstructor.paramSymss.flatten
.filter(_.isValDef)
.zipWithIndex
.flatMap { case (field, idx) =>
val defaultMethodName = s"$$lessinit$$greater$$default$$${idx + 1}"
tpe.companionClass
.declaredMethod(defaultMethodName)
.headOption
.map { defaultMethod =>
val callDefault = {
val base = Ident(tpe.companionModule.termRef).select(defaultMethod)
val tParams = defaultMethod.paramSymss.headOption.filter(_.forall(_.isType))
tParams match
case Some(tParams) => TypeApply(base, tParams.map(TypeTree.ref))
case _ => base
}

defaultMethod.tree match {
case tree: DefDef => tree.rhs.getOrElse(callDefault)
case _ => callDefault
}
}
.map(x => Expr.ofTuple(Expr(idx) -> x.asExprOf[Any]))
}

'{ Map(${ Varargs(terms) }: _*) }
16 changes: 15 additions & 1 deletion modules/core/src/test/scala-3/tethys/DerivationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import tethys.commons.TokenNode.obj
import tethys.commons.{Token, TokenNode}
import tethys.readers.tokens.QueueIterator
import tethys.writers.tokens.SimpleTokenWriter.SimpleTokenWriterOps
import tethys.derivation.Defaults

class DerivationSpec extends AnyFlatSpec with Matchers {
def read[A: JsonReader](nodes: List[TokenNode]): A = {
Expand Down Expand Up @@ -113,5 +114,18 @@ class DerivationSpec extends AnyFlatSpec with Matchers {
}
}

it should "correctly read case classes with default parameters" in {
object Mod {
case class WithOpt(x: Int, y: Option[String] = Some("default")) derives JsonReader
}

read[Mod.WithOpt](obj("x" -> 5)) shouldBe Mod.WithOpt(5)
}

}
it should "correctly read case classes with default parameters and type arguments" in {
case class WithArg[A](x: Int, y: Option[A] = None) derives JsonReader

read[WithArg[Int]](obj("x" -> 5)) shouldBe WithArg[Int](5)
read[WithArg[String]](obj("x" -> 5, "y" -> "lool")) shouldBe WithArg[String](5, Some("lool"))
}
}

0 comments on commit ce26657

Please sign in to comment.