Skip to content

Commit

Permalink
Merge pull request #303 from dos65/default_fields_fix
Browse files Browse the repository at this point in the history
fix: support default values for case class fields
  • Loading branch information
dos65 authored Jun 27, 2024
2 parents efadf64 + 946f7e1 commit 839ad8d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ trait CaseClassUtils extends LoggingUtils {
import c.universe._

case class CaseClassDefinition(tpe: Type, fields: List[CaseClassField])
case class CaseClassField(name: String, tpe: Type)
case class CaseClassField(name: String, tpe: Type, defaultValue: Option[Tree])

def caseClassDefinition[A: WeakTypeTag]: CaseClassDefinition = caseClassDefinition(weakTypeOf[A])

def caseClassDefinition(tpe: Type): CaseClassDefinition = {
val ctor = getConstructor(tpe)
CaseClassDefinition(
tpe = tpe,
fields = ctor.paramLists.head.map(constructorParameterToCaseClassField(tpe))
fields = ctor.paramLists.head.zipWithIndex.map{ case (sym, idx) => constructorParameterToCaseClassField(tpe)(idx, sym) }
)
}

Expand All @@ -39,13 +39,21 @@ trait CaseClassUtils extends LoggingUtils {
}
}

private def constructorParameterToCaseClassField(tpe: Type)(param: Symbol): CaseClassField = {
private def constructorParameterToCaseClassField(tpe: Type)(idx: Int, param: Symbol): CaseClassField = {
val possibleRealType = tpe.decls.collectFirst {
case s if s.name == param.name => s.typeSignatureIn(tpe).finalResultType
}

CaseClassField(
name = param.name.decodedName.toString,
tpe = possibleRealType.getOrElse(param.typeSignatureIn(tpe))
tpe = possibleRealType.getOrElse(param.typeSignatureIn(tpe)),
defaultValue =
if (param.asTerm.isParamWithDefault) {
val methodName = TermName(s"apply$$default$$${idx + 1}")
val select = q"${tpe.companion.typeSymbol.asClass.module}.$methodName"
Some(select)
} else
None
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ trait ReaderDerivation
tpe: Type,
jsonName: String,
value: TermName,
isInitialized: TermName) extends ReaderField
isInitialized: TermName,
defaultValue: Option[Tree]) extends ReaderField

private case class ExtractedField(name: String,
tpe: Type,
Expand Down Expand Up @@ -80,7 +81,8 @@ trait ReaderDerivation
tpe = field.tpe,
jsonName = field.name,
value = TermName(c.freshName(field.name + "Value")),
isInitialized = TermName(c.freshName(field.name + "Init"))
isInitialized = TermName(c.freshName(field.name + "Init")),
defaultValue = field.defaultValue
)
})

Expand Down Expand Up @@ -275,17 +277,20 @@ trait ReaderDerivation
}

private def allocateVariables(readerFields: List[ReaderField], typeDefaultValues: List[(Type, TermName)]): List[Tree] = {
val possibleValues: List[(TermName, Type)] = readerFields.flatMap {
val possibleValues: List[(TermName, Type, Option[Tree])] = readerFields.flatMap {
case f: SimpleField =>
List(f.value -> f.tpe)
List((f.value, f.tpe, f.defaultValue))
case f: ExtractedField =>
(f.value, f.tpe) :: f.args.map(arg => arg.value -> arg.field.tpe)
(f.value, f.tpe, None) :: f.args.map(arg => (arg.value, arg.field.tpe, None))
case f: FromExtractedReader =>
(f.value, f.tpe) :: f.args.map(arg => arg.value -> arg.field.tpe)
((f.value, f.tpe, None)) :: f.args.map(arg => (arg.value, arg.field.tpe, None))
}

val (_, values) = possibleValues.foldLeft(List[TermName](), List[Tree]()) {
case ((allocated, trees), (value, tpe)) if !allocated.contains(value) =>
case ((allocated, trees), (value, tpe, Some(defaultTree))) =>
val tree = q"var $value: $tpe = $defaultTree"
(value :: allocated, tree :: trees)
case ((allocated, trees), (value, tpe, defaultTreeOpt)) if !allocated.contains(value) =>
val tree = q"var $value: $tpe = ${typeDefaultValues.find(_._1 =:= tpe).get._2}"
(value :: allocated, tree :: trees)

Expand All @@ -295,14 +300,14 @@ trait ReaderDerivation
val inits = readerFields
.flatMap {
case f: SimpleField =>
List(f.isInitialized)
List((f.isInitialized, f.defaultValue.isDefined))
case f: ExtractedField =>
f.isInitialized :: f.args.map(_.isInitialized)
(f.isInitialized, false) :: f.args.map(a => (a.isInitialized, false))
case f: FromExtractedReader =>
f.isInitialized :: f.args.map(_.isInitialized)
(f.isInitialized, false) :: f.args.map(a => (a.isInitialized, false))
}
.distinct
.map(term => q"var $term: Boolean = false")
.map{ case (term, initialized) => q"var $term: Boolean = $initialized"}

val tempIterators = readerFields.collect {
case f: FromExtractedReader =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ trait ReaderDerivation extends ReaderBuilderCommons {
}
it.nextToken()

$defaultValuesExpr.foreach { case (name, tpeName, defaultValue) =>
readFields.getOrElseUpdate(name, MutableMap(tpeName -> defaultValue))
}

$possiblyNotInitializedExpr.foreach { case (name, tpeName, defaultValue) =>
readFields.getOrElseUpdate(name, MutableMap(tpeName -> defaultValue))
}
Expand Down Expand Up @@ -232,9 +236,6 @@ trait ReaderDerivation extends ReaderBuilderCommons {
Expr.block(res, '{ () })
}

$defaultValuesExpr.foreach { case (name, defaultValue) =>
resultFields.getOrElseUpdate(name, defaultValue)
}

val notReadAfterExtractingFields: Set[String] =
Set.from(${ Varargs(classFields.map(field => Expr(field.name))) }) -- resultFields.keySet
Expand Down Expand Up @@ -377,10 +378,10 @@ trait ReaderDerivation extends ReaderBuilderCommons {
(readersExpr, fieldsWithoutReadersExpr)
}

private def allocateDefaultValuesFromDefinition[T: Type]: Expr[Map[String, Any]] = {
private def allocateDefaultValuesFromDefinition[T: Type]: Expr[List[(String, String, Any)]] = {
val tpe = TypeRepr.of[T]

val res = tpe.typeSymbol.caseFields.flatMap {
val res = tpe.typeSymbol.caseFields.collect {
case sym if sym.flags.is(Flags.HasDefault) =>
val comp = sym.owner.companionClass
val mod = Ref(sym.owner.companionModule)
Expand All @@ -397,11 +398,14 @@ trait ReaderDerivation extends ReaderBuilderCommons {
)

val defaultValueTerm = mod.select(defaultValueMethodSym)
Some(Expr.ofTuple(Expr(sym.name) -> defaultValueTerm.asExprOf[Any]))
case _ => None
val appliedTypes = if tpe.typeArgs.nonEmpty then defaultValueTerm.appliedToTypes(tpe.typeArgs) else defaultValueTerm
Expr.ofTuple(
Expr(sym.name),
Expr(tpe.memberType(sym).getDealiasFullName),
appliedTypes.asExprOf[Any]
)
}

'{ Map(${ Varargs(res) }: _*) }
Expr.ofList(res)
}

private def allocateTypeReadersInfos(readerFields: List[ReaderField]): List[(TypeRepr, Term)] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,23 @@ class SemiautoReaderDerivationTest extends AnyFlatSpec with Matchers {
))
} should have message "Illegal json at '[ROOT]': unexpected field 'not_id_param', expected one of 'some_param', 'id_param', 'simple'"
}

it should "derive reader for class with default params" in {
implicit val reader: JsonReader[DefaultField[Int]] = jsonReader[DefaultField[Int]]

read[DefaultField[Int]](obj(
"value" -> 1,
"default" -> false
)) shouldBe DefaultField[Int](
value = 1,
default = false
)

read[DefaultField[Int]](obj(
"value" -> 1,
)) shouldBe DefaultField[Int](
value = 1,
default = true
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ package object derivation {
case class SeqMaster4(a: Seq[Int])

case class CamelCaseNames(someParam: Int, IDParam: Int, simple: Int)

case class DefaultField[T](value: T, default: Boolean = true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,23 @@ class SemiautoReaderDerivationTest extends AnyFlatSpec with Matchers {
token(ParametrizedEnum.TWO.toString)
) shouldBe ParametrizedEnum.TWO
}

it should "derive reader for class with default params" in {
implicit val reader: JsonReader[DefaultField[Int]] = jsonReader[DefaultField[Int]]

read[DefaultField[Int]](obj(
"value" -> 1,
"default" -> false
)) shouldBe DefaultField[Int](
value = 1,
default = false
)

read[DefaultField[Int]](obj(
"value" -> 1
)) shouldBe DefaultField[Int](
value = 1,
default = true
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ package object derivation {
case ONE extends ParametrizedEnum(1)
case TWO extends ParametrizedEnum(2)
}

case class DefaultField[T](value: T, default: Boolean = true)
}

0 comments on commit 839ad8d

Please sign in to comment.