Skip to content

Commit

Permalink
Fixes for iron derivation for GreaterEqual, LessEqual, In (#3790)
Browse files Browse the repository at this point in the history
Co-authored-by: Pascal Mengelt <[email protected]>
  • Loading branch information
kciesielski and pme123 authored May 28, 2024
1 parent 5e86a8e commit 1033160
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,40 +130,79 @@ trait TapirCodecIron extends DescriptionWitness with LowPriorityValidatorForPred

}

inline given validatorForOr[N, Predicates](using mirror: UnionTypeMirror[Predicates]): ValidatorForPredicate[N, Predicates] =
new ValidatorForPredicate[N, Predicates] {

val unionConstraint = new Constraint.UnionConstraint[N, Predicates]
val validatorsForPredicates: List[ValidatorForPredicate[N, Any]] = summonValidators[N, mirror.ElementTypes]
private inline def strictValues[N, V <: Tuple]: List[N] = {
inline erasedValue[V] match
case _: EmptyTuple => Nil
case _: (StrictEqual[t] *: ts) =>
inline erasedValue[t] match
case e: N => e :: strictValues[N, ts]
case _ => Nil
case _ => Nil
}

override def validator: Validator[N] = Validator.any(validatorsForPredicates.map(_.validator): _*)
inline given validatorForOr[N, Predicates](using
mirror: UnionTypeMirror[Predicates]
): ValidatorForPredicate[N, Predicates] =
val strictEqualsValues = strictValues[N, mirror.ElementTypes]
if (strictEqualsValues.length == mirror.size)
// All elements of union type were StrictEqual[_], so it's simply an enumeration
ValidatorForPredicate.fromPrimitiveValidator(Validator.enumeration[N](strictEqualsValues))
else
new ValidatorForPredicate[N, Predicates] {

val unionConstraint = new Constraint.UnionConstraint[N, Predicates]
val validatorsForPredicates: List[ValidatorForPredicate[N, Any]] =
if strictEqualsValues.isEmpty then summonValidators[N, mirror.ElementTypes]
else
// There were some strict equals at the beginning of union type - putting them into a Validator.enumeration and attaching the rest of the validators as a normal list
ValidatorForPredicate
.fromPrimitiveValidator(Validator.enumeration[N](strictEqualsValues)) :: summonValidators[N, mirror.ElementTypes].drop(
strictEqualsValues.length
)

override def makeErrors(value: N, errorMessage: String): List[ValidationError[_]] =
if (!unionConstraint.test(value))
List(
ValidationError[N](
Validator.Custom(_ =>
ValidationResult.Invalid(unionConstraint.message) // at this point the validator is already failed anyway
),
value
override def validator: Validator[N] = Validator.any(validatorsForPredicates.map(_.validator): _*)

override def makeErrors(value: N, errorMessage: String): List[ValidationError[_]] =
if (!unionConstraint.test(value))
List(
ValidationError[N](
Validator.Custom(_ =>
ValidationResult.Invalid(unionConstraint.message) // at this point the validator is already failed anyway
),
value
)
)
)
else Nil
else Nil

}
}

inline given validatorForDescribedAnd[N, P](using
id: IsDescription[P],
mirror: IntersectionTypeMirror[id.Predicate]
): ValidatorForPredicate[N, P] =
validatorForAnd[N, id.Predicate].asInstanceOf[ValidatorForPredicate[N, P]]

inline given validatorForDescribedOr[N, P](using
inline given validatorForDescribedOr[N, P, Num](using
id: IsDescription[P],
mirror: UnionTypeMirror[id.Predicate]
mirror: UnionTypeMirror[id.Predicate],
notGe: NotGiven[P =:= GreaterEqual[Num]],
notLe: NotGiven[P =:= LessEqual[Num]]
): ValidatorForPredicate[N, P] =
validatorForOr[N, id.Predicate].asInstanceOf[ValidatorForPredicate[N, P]]

inline given validatorForDescribedOrGe[N: Numeric, P, Num <: N](using
id: IsDescription[P],
isGe: P =:= GreaterEqual[Num],
singleton: ValueOf[Num]
): ValidatorForPredicate[N, P] =
validatorForGreaterEqual[N, Num].asInstanceOf[ValidatorForPredicate[N, P]]

inline given validatorForDescribedOrLe[N: Numeric, P, Num <: N](using
id: IsDescription[P],
isLe: P =:= LessEqual[Num],
singleton: ValueOf[Num]
): ValidatorForPredicate[N, P] =
validatorForLessEqual[N, Num].asInstanceOf[ValidatorForPredicate[N, P]]
inline given validatorForDescribedPrimitive[N, P](using
id: IsDescription[P],
notUnion: NotGiven[UnionTypeMirror[id.Predicate]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import scala.quoted.*

trait UnionTypeMirror[A] {

type ElementTypes <: Tuple
type ElementTypes <: NonEmptyTuple
// Number of elements in the union
def size: Int
}

// Building a class is more convenient to instantiate using macros
class UnionTypeMirrorImpl[A, T <: Tuple] extends UnionTypeMirror[A] {
class UnionTypeMirrorImpl[A, T <: NonEmptyTuple](val size: Int) extends UnionTypeMirror[A] {

override type ElementTypes = T
}
Expand All @@ -31,15 +33,20 @@ object UnionTypeMirror {
def concatTypes(left: TypeRepr, right: TypeRepr): TypeRepr =
AppliedType(tplConcatType, List(left, right))

def rec(tpe: TypeRepr): TypeRepr =
def rec(tpe: TypeRepr): (Int, TypeRepr) =
tpe.dealias match {
case OrType(left, right) => concatTypes(rec(left), rec(right))
case t => prependTypes(t, TypeRepr.of[EmptyTuple])
case OrType(left, right) =>
val (c1, rec1) = rec(left)
val (c2, rec2) = rec(right)
(c1 + c2, concatTypes(rec1, rec2))
case t => (1, prependTypes(t, TypeRepr.of[EmptyTuple]))
}
val tupled =
val (size, tupled) =
TypeRepr.of[A].dealias match {
case or: OrType => rec(or).asType.asInstanceOf[Type[Elems]]
case tpe => report.errorAndAbort(s"${tpe.show} is not a union type")
case or: OrType =>
val (s, r) = rec(or)
(s, r.asType.asInstanceOf[Type[Elems]])
case tpe => report.errorAndAbort(s"${tpe.show} is not a union type")
}

type Elems
Expand All @@ -65,7 +72,7 @@ object UnionTypeMirror {
TypeTree.of[Elems]
)
),
Nil
List(Literal(IntConstant((size))))
).asExprOf[UnionTypeMirror[A]]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(
Validator.All(
List(
Validator.Any(List(Validator.Min(1, true), Validator.Enumeration(List(1), _, _))),
Validator.Any(List(Validator.Max(3, true), Validator.Enumeration(List(3), _, _)))
)
),
_
) =>
case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, false))), _) =>
}
}

Expand All @@ -159,15 +151,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(
Validator.All(
List(
Validator.Min(1, true),
Validator.Any(List(Validator.Max(3, true), Validator.Enumeration(List(3), _, _)))
)
),
_
) =>
case Validator.Mapped(Validator.All(List(Validator.Min(1, true), Validator.Max(3, false))), _) =>
}
}

Expand All @@ -176,15 +160,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(
Validator.All(
List(
Validator.Any(List(Validator.Min(1, true), Validator.Enumeration(List(1), _, _))),
Validator.Max(3, true)
)
),
_
) =>
case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, true))), _) =>
}
}

Expand All @@ -196,6 +172,14 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
case Validator.Mapped(Validator.All(List(Validator.Min(1, true), Validator.Max(3, true))), _) =>
}
}
"Generated validator for intersection of constraints" should "use tapir Validator.min(1, false) and Validator.max(3, false)" in {
type IntConstraint = GreaterEqual[1] & LessEqual[3]
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, false))), _) =>
}
}

"Generated validator for union of constraints" should "use tapir Validator.min and Validator.max" in {
type IntConstraint = Less[1] | Greater[3]
Expand All @@ -206,6 +190,50 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
}
}

"Generated validator for union of constraints" should "use tapir Validator.min and strict equality (enumeration)" in {
type IntConstraint = StrictEqual[3] | Greater[5]
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(Validator.Any(List(Validator.Enumeration(List(3), _, _), Validator.Min(5, true))), _) =>
}
}

"Generated validator for union of constraints" should "put muiltiple StrictEquality into a single enum and follow with the rest of constrains" in {
type IntConstraint = StrictEqual[3] | StrictEqual[4] | StrictEqual[13] | GreaterEqual[23]
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(Validator.Any(List(Validator.Enumeration(List(3, 4, 13), _, _), Validator.Min(23, false))), _) =>
}
}

"Generated validator for union of constraints" should "use tapir Validator.enumeration" in {
type IntConstraint = In[
(
110354433,
110354454,
122483323
)
]
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(
Validator.Enumeration(
List(
110354433,
110354454,
122483323
),
_,
_
),
_
) =>
}
}

"Generated validator for described union" should "use tapir Validator.min and Validator.max" in {
type IntConstraint = (Less[1] | Greater[3]) DescribedAs ("Should be included in less than 1 or more than 3")
type LimitedInt = Int :| IntConstraint
Expand All @@ -214,6 +242,25 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers {
case Validator.Mapped(Validator.Any(List(Validator.Max(1, true), Validator.Min(3, true))), _) =>
}
}

"Generated validator for described union" should "work with strings" in {
type StrConstraint = (Match["[a-c]*"] | Match["[x-z]*"]) DescribedAs ("Some description")
type LimitedStr = String :| StrConstraint

val identifierCodec = implicitly[PlainCodec[LimitedStr]]
identifierCodec.decode("aac") shouldBe DecodeResult.Value("aac")
identifierCodec.decode("yzx") shouldBe DecodeResult.Value("yzx")
identifierCodec.decode("aax") shouldBe a[DecodeResult.InvalidValue]
}

"Generated validator for described single constraint" should "use tapir Validator.max" in {
type IntConstraint = (Less[1]) DescribedAs ("Should be included in less than 1 or more than 3")
type LimitedInt = Int :| IntConstraint

summon[Schema[LimitedInt]].validator should matchPattern {
case Validator.Mapped(Validator.Max(1, true), _) =>
}
}

"Instances for opaque refined type" should "be correctly derived" in:
summon[Schema[RefinedInt]]
Expand Down

0 comments on commit 1033160

Please sign in to comment.