diff --git a/integrations/iron/src/main/scala/sttp/iron/codec/iron/TapirCodecIron.scala b/integrations/iron/src/main/scala/sttp/iron/codec/iron/TapirCodecIron.scala index 3c2b026435..bc2748b06b 100644 --- a/integrations/iron/src/main/scala/sttp/iron/codec/iron/TapirCodecIron.scala +++ b/integrations/iron/src/main/scala/sttp/iron/codec/iron/TapirCodecIron.scala @@ -130,27 +130,51 @@ trait TapirCodecIron extends DescriptionWitness with LowPriorityValidatorForPred } - inline given validatorForOr[N, Predicates](using mirror: UnionTypeMirror[Predicates]): ValidatorForPredicate[N, Predicates] = - new ValidatorForPredicate[N, Predicates] { - - val unionConstraint = new Constraint.UnionConstraint[N, Predicates] - val validatorsForPredicates: List[ValidatorForPredicate[N, Any]] = summonValidators[N, mirror.ElementTypes] + private inline def strictValues[N, V <: Tuple]: List[N] = { + inline erasedValue[V] match + case _: EmptyTuple => Nil + case _: (StrictEqual[t] *: ts) => + inline erasedValue[t] match + case e: N => e :: strictValues[N, ts] + case _ => Nil + case _ => Nil + } - override def validator: Validator[N] = Validator.any(validatorsForPredicates.map(_.validator): _*) + inline given validatorForOr[N, Predicates](using + mirror: UnionTypeMirror[Predicates] + ): ValidatorForPredicate[N, Predicates] = + val strictEqualsValues = strictValues[N, mirror.ElementTypes] + if (strictEqualsValues.length == mirror.size) + // All elements of union type were StrictEqual[_], so it's simply an enumeration + ValidatorForPredicate.fromPrimitiveValidator(Validator.enumeration[N](strictEqualsValues)) + else + new ValidatorForPredicate[N, Predicates] { + + val unionConstraint = new Constraint.UnionConstraint[N, Predicates] + val validatorsForPredicates: List[ValidatorForPredicate[N, Any]] = + if strictEqualsValues.isEmpty then summonValidators[N, mirror.ElementTypes] + else + // There were some strict equals at the beginning of union type - putting them into a Validator.enumeration and attaching the rest of the validators as a normal list + ValidatorForPredicate + .fromPrimitiveValidator(Validator.enumeration[N](strictEqualsValues)) :: summonValidators[N, mirror.ElementTypes].drop( + strictEqualsValues.length + ) - override def makeErrors(value: N, errorMessage: String): List[ValidationError[_]] = - if (!unionConstraint.test(value)) - List( - ValidationError[N]( - Validator.Custom(_ => - ValidationResult.Invalid(unionConstraint.message) // at this point the validator is already failed anyway - ), - value + override def validator: Validator[N] = Validator.any(validatorsForPredicates.map(_.validator): _*) + + override def makeErrors(value: N, errorMessage: String): List[ValidationError[_]] = + if (!unionConstraint.test(value)) + List( + ValidationError[N]( + Validator.Custom(_ => + ValidationResult.Invalid(unionConstraint.message) // at this point the validator is already failed anyway + ), + value + ) ) - ) - else Nil + else Nil - } + } inline given validatorForDescribedAnd[N, P](using id: IsDescription[P], @@ -158,12 +182,27 @@ trait TapirCodecIron extends DescriptionWitness with LowPriorityValidatorForPred ): ValidatorForPredicate[N, P] = validatorForAnd[N, id.Predicate].asInstanceOf[ValidatorForPredicate[N, P]] - inline given validatorForDescribedOr[N, P](using + inline given validatorForDescribedOr[N, P, Num](using id: IsDescription[P], - mirror: UnionTypeMirror[id.Predicate] + mirror: UnionTypeMirror[id.Predicate], + notGe: NotGiven[P =:= GreaterEqual[Num]], + notLe: NotGiven[P =:= LessEqual[Num]] ): ValidatorForPredicate[N, P] = validatorForOr[N, id.Predicate].asInstanceOf[ValidatorForPredicate[N, P]] + inline given validatorForDescribedOrGe[N: Numeric, P, Num <: N](using + id: IsDescription[P], + isGe: P =:= GreaterEqual[Num], + singleton: ValueOf[Num] + ): ValidatorForPredicate[N, P] = + validatorForGreaterEqual[N, Num].asInstanceOf[ValidatorForPredicate[N, P]] + + inline given validatorForDescribedOrLe[N: Numeric, P, Num <: N](using + id: IsDescription[P], + isLe: P =:= LessEqual[Num], + singleton: ValueOf[Num] + ): ValidatorForPredicate[N, P] = + validatorForLessEqual[N, Num].asInstanceOf[ValidatorForPredicate[N, P]] inline given validatorForDescribedPrimitive[N, P](using id: IsDescription[P], notUnion: NotGiven[UnionTypeMirror[id.Predicate]], diff --git a/integrations/iron/src/main/scala/sttp/iron/codec/iron/UnionTypeMirror.scala b/integrations/iron/src/main/scala/sttp/iron/codec/iron/UnionTypeMirror.scala index e634a8b5d2..5c4fc78b9d 100644 --- a/integrations/iron/src/main/scala/sttp/iron/codec/iron/UnionTypeMirror.scala +++ b/integrations/iron/src/main/scala/sttp/iron/codec/iron/UnionTypeMirror.scala @@ -6,11 +6,13 @@ import scala.quoted.* trait UnionTypeMirror[A] { - type ElementTypes <: Tuple + type ElementTypes <: NonEmptyTuple + // Number of elements in the union + def size: Int } - + // Building a class is more convenient to instantiate using macros -class UnionTypeMirrorImpl[A, T <: Tuple] extends UnionTypeMirror[A] { +class UnionTypeMirrorImpl[A, T <: NonEmptyTuple](val size: Int) extends UnionTypeMirror[A] { override type ElementTypes = T } @@ -31,15 +33,20 @@ object UnionTypeMirror { def concatTypes(left: TypeRepr, right: TypeRepr): TypeRepr = AppliedType(tplConcatType, List(left, right)) - def rec(tpe: TypeRepr): TypeRepr = + def rec(tpe: TypeRepr): (Int, TypeRepr) = tpe.dealias match { - case OrType(left, right) => concatTypes(rec(left), rec(right)) - case t => prependTypes(t, TypeRepr.of[EmptyTuple]) + case OrType(left, right) => + val (c1, rec1) = rec(left) + val (c2, rec2) = rec(right) + (c1 + c2, concatTypes(rec1, rec2)) + case t => (1, prependTypes(t, TypeRepr.of[EmptyTuple])) } - val tupled = + val (size, tupled) = TypeRepr.of[A].dealias match { - case or: OrType => rec(or).asType.asInstanceOf[Type[Elems]] - case tpe => report.errorAndAbort(s"${tpe.show} is not a union type") + case or: OrType => + val (s, r) = rec(or) + (s, r.asType.asInstanceOf[Type[Elems]]) + case tpe => report.errorAndAbort(s"${tpe.show} is not a union type") } type Elems @@ -65,7 +72,7 @@ object UnionTypeMirror { TypeTree.of[Elems] ) ), - Nil + List(Literal(IntConstant((size)))) ).asExprOf[UnionTypeMirror[A]] } } diff --git a/integrations/iron/src/test/scala-3/sttp/iron/codec/iron/TapirCodecIronTestScala3.scala b/integrations/iron/src/test/scala-3/sttp/iron/codec/iron/TapirCodecIronTestScala3.scala index 417cfe68fc..101b489950 100644 --- a/integrations/iron/src/test/scala-3/sttp/iron/codec/iron/TapirCodecIronTestScala3.scala +++ b/integrations/iron/src/test/scala-3/sttp/iron/codec/iron/TapirCodecIronTestScala3.scala @@ -142,15 +142,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { type LimitedInt = Int :| IntConstraint summon[Schema[LimitedInt]].validator should matchPattern { - case Validator.Mapped( - Validator.All( - List( - Validator.Any(List(Validator.Min(1, true), Validator.Enumeration(List(1), _, _))), - Validator.Any(List(Validator.Max(3, true), Validator.Enumeration(List(3), _, _))) - ) - ), - _ - ) => + case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, false))), _) => } } @@ -159,15 +151,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { type LimitedInt = Int :| IntConstraint summon[Schema[LimitedInt]].validator should matchPattern { - case Validator.Mapped( - Validator.All( - List( - Validator.Min(1, true), - Validator.Any(List(Validator.Max(3, true), Validator.Enumeration(List(3), _, _))) - ) - ), - _ - ) => + case Validator.Mapped(Validator.All(List(Validator.Min(1, true), Validator.Max(3, false))), _) => } } @@ -176,15 +160,7 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { type LimitedInt = Int :| IntConstraint summon[Schema[LimitedInt]].validator should matchPattern { - case Validator.Mapped( - Validator.All( - List( - Validator.Any(List(Validator.Min(1, true), Validator.Enumeration(List(1), _, _))), - Validator.Max(3, true) - ) - ), - _ - ) => + case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, true))), _) => } } @@ -196,6 +172,14 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { case Validator.Mapped(Validator.All(List(Validator.Min(1, true), Validator.Max(3, true))), _) => } } + "Generated validator for intersection of constraints" should "use tapir Validator.min(1, false) and Validator.max(3, false)" in { + type IntConstraint = GreaterEqual[1] & LessEqual[3] + type LimitedInt = Int :| IntConstraint + + summon[Schema[LimitedInt]].validator should matchPattern { + case Validator.Mapped(Validator.All(List(Validator.Min(1, false), Validator.Max(3, false))), _) => + } + } "Generated validator for union of constraints" should "use tapir Validator.min and Validator.max" in { type IntConstraint = Less[1] | Greater[3] @@ -206,6 +190,50 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { } } + "Generated validator for union of constraints" should "use tapir Validator.min and strict equality (enumeration)" in { + type IntConstraint = StrictEqual[3] | Greater[5] + type LimitedInt = Int :| IntConstraint + + summon[Schema[LimitedInt]].validator should matchPattern { + case Validator.Mapped(Validator.Any(List(Validator.Enumeration(List(3), _, _), Validator.Min(5, true))), _) => + } + } + + "Generated validator for union of constraints" should "put muiltiple StrictEquality into a single enum and follow with the rest of constrains" in { + type IntConstraint = StrictEqual[3] | StrictEqual[4] | StrictEqual[13] | GreaterEqual[23] + type LimitedInt = Int :| IntConstraint + + summon[Schema[LimitedInt]].validator should matchPattern { + case Validator.Mapped(Validator.Any(List(Validator.Enumeration(List(3, 4, 13), _, _), Validator.Min(23, false))), _) => + } + } + + "Generated validator for union of constraints" should "use tapir Validator.enumeration" in { + type IntConstraint = In[ + ( + 110354433, + 110354454, + 122483323 + ) + ] + type LimitedInt = Int :| IntConstraint + + summon[Schema[LimitedInt]].validator should matchPattern { + case Validator.Mapped( + Validator.Enumeration( + List( + 110354433, + 110354454, + 122483323 + ), + _, + _ + ), + _ + ) => + } + } + "Generated validator for described union" should "use tapir Validator.min and Validator.max" in { type IntConstraint = (Less[1] | Greater[3]) DescribedAs ("Should be included in less than 1 or more than 3") type LimitedInt = Int :| IntConstraint @@ -214,6 +242,25 @@ class TapirCodecIronTestScala3 extends AnyFlatSpec with Matchers { case Validator.Mapped(Validator.Any(List(Validator.Max(1, true), Validator.Min(3, true))), _) => } } + + "Generated validator for described union" should "work with strings" in { + type StrConstraint = (Match["[a-c]*"] | Match["[x-z]*"]) DescribedAs ("Some description") + type LimitedStr = String :| StrConstraint + + val identifierCodec = implicitly[PlainCodec[LimitedStr]] + identifierCodec.decode("aac") shouldBe DecodeResult.Value("aac") + identifierCodec.decode("yzx") shouldBe DecodeResult.Value("yzx") + identifierCodec.decode("aax") shouldBe a[DecodeResult.InvalidValue] + } + + "Generated validator for described single constraint" should "use tapir Validator.max" in { + type IntConstraint = (Less[1]) DescribedAs ("Should be included in less than 1 or more than 3") + type LimitedInt = Int :| IntConstraint + + summon[Schema[LimitedInt]].validator should matchPattern { + case Validator.Mapped(Validator.Max(1, true), _) => + } + } "Instances for opaque refined type" should "be correctly derived" in: summon[Schema[RefinedInt]]