Skip to content

Commit

Permalink
Add support for @oneOf inputs (#1846)
Browse files Browse the repository at this point in the history
* Add `ArgBuilder` derivation for `oneOff` inputs

* Add rendering and introspection of `oneOf` inputs

* Add tests for executing queries with `oneOf` inputs

* Don't introspect `isOneOf` in `IntrospectionClient`

* Add value type implementation

* Rollback irrelevant changes

* Fix Scala 3 derivation

* Fix merge errors

* PR comments

* fmt

* Add schema & input validations

* Remove `nullable` methods and add schema validation tests

* PR comments

* fmt

* Allow OneOf inputs to have a single field

* Fix merging errors and add mima exclusions

* Fix merging errors

* Disable mima

* Reuse `hasAnnotation` macro

* Change `parentTypeName` to `parentType` on `__InputValue`

* Remove `isOneOf` argument from `makeInputObject`

* Fix mima

* Micro-optimize validation

* Reimplement handling of OneOf inputs via a PartialFunction

* Fix Scala 2.12
  • Loading branch information
kyri-petrou authored Jun 18, 2024
1 parent c76e2f5 commit b56673b
Show file tree
Hide file tree
Showing 27 changed files with 850 additions and 136 deletions.
8 changes: 7 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,13 @@ lazy val enableMimaSettingsJVM =
mimaFailOnProblem := enforceMimaCompatibility,
mimaPreviousArtifacts := previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet,
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[IncompatibleMethTypeProblem]("caliban.execution.Executor#ReducedStepExecutor.makeQuery")
ProblemFilters.exclude[IncompatibleMethTypeProblem]("caliban.execution.Executor#ReducedStepExecutor.makeQuery"),
ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.parsing.adt.Type.$init$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.introspection.adt.__Type.*"),
ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.introspection.adt.__InputValue.*"),
ProblemFilters.exclude[FinalMethodProblem]("caliban.parsing.adt.Type*"),
ProblemFilters.exclude[MissingTypesProblem]("caliban.introspection.adt.__Type$"),
ProblemFilters.exclude[MissingTypesProblem]("caliban.introspection.adt.__InputValue$")
)
)

Expand Down
69 changes: 57 additions & 12 deletions core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package caliban.schema
import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.Value._
import caliban.schema.Annotations.GQLDefault
import caliban.schema.Annotations.GQLName
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput }
import magnolia1._

import scala.collection.compat._
import scala.language.experimental.macros

trait CommonArgBuilderDerivation {
Expand All @@ -25,19 +25,44 @@ trait CommonArgBuilderDerivation {
override def map[A, B](from: EitherExecutionError[A])(fn: A => B): EitherExecutionError[B] = from.map(fn)
}

def join[T](ctx: CaseClass[ArgBuilder, T]): ArgBuilder[T] =
(input: InputValue) =>
def join[T](ctx: CaseClass[ArgBuilder, T]): ArgBuilder[T] = new ArgBuilder[T] {

private val params = {
val arr = Array.ofDim[(String, EitherExecutionError[Any])](ctx.parameters.length)
ctx.parameters.zipWithIndex.foreach { case (p, i) =>
val label = p.annotations.collectFirst { case GQLName(name) => name }.getOrElse(p.label)
val default = p.typeclass.buildMissing(p.annotations.collectFirst { case GQLDefault(v) => v })
arr(i) = (label, default)
}
arr
}

private val required = params.collect { case (label, default) if default.isLeft => label }

override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, T]] = {
case InputValue.ObjectValue(fields) if required.forall(fields.contains) => fromFields(fields)
}

def build(input: InputValue): Either[ExecutionError, T] =
input match {
case InputValue.ObjectValue(fields) => fromFields(fields)
case _ => Left(ExecutionError("expected an input object"))
}

private[this] def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, T] =
ctx.constructMonadic { p =>
input match {
case InputValue.ObjectValue(fields) =>
val label = p.annotations.collectFirst { case GQLName(name) => name }.getOrElse(p.label)
val default = p.annotations.collectFirst { case GQLDefault(v) => v }
fields.get(label).fold(p.typeclass.buildMissing(default))(p.typeclass.build)
case value => p.typeclass.build(value)
}
val idx = p.index
val (label, default) = params(idx)
val field = fields.getOrElse(label, null)
if (field ne null) p.typeclass.build(field) else default
}
}

def split[T](ctx: SealedTrait[ArgBuilder, T]): ArgBuilder[T] = input =>
def split[T](ctx: SealedTrait[ArgBuilder, T]): ArgBuilder[T] =
if (ctx.annotations.contains(GQLOneOfInput())) makeOneOfBuilder(ctx)
else makeSumBuilder(ctx)

private def makeSumBuilder[T](ctx: SealedTrait[ArgBuilder, T]): ArgBuilder[T] = input =>
(input match {
case EnumValue(value) => Some(value)
case StringValue(value) => Some(value)
Expand All @@ -53,6 +78,26 @@ trait CommonArgBuilderDerivation {
}
case None => Left(ExecutionError(s"Can't build a trait from input $input"))
}

private def makeOneOfBuilder[A](ctx: SealedTrait[ArgBuilder, A]): ArgBuilder[A] = new ArgBuilder[A] {

private def inputError(input: InputValue) =
ExecutionError(s"Invalid oneOf input $input for trait ${ctx.typeName.short}")

override val partial: PartialFunction[InputValue, Either[ExecutionError, A]] = {
val xs = ctx.subtypes.map(_.typeclass).toList.asInstanceOf[List[ArgBuilder[A]]]

val checkSize: PartialFunction[InputValue, Either[ExecutionError, A]] = {
case InputValue.ObjectValue(f) if f.size != 1 =>
Left(ExecutionError("Exactly one key must be specified for oneOf inputs"))
}
xs.foldLeft(checkSize)(_ orElse _.partial)
}

def build(input: InputValue): Either[ExecutionError, A] =
partial.applyOrElse(input, (in: InputValue) => Left(inputError(in)))
}

}

trait ArgBuilderDerivation extends CommonArgBuilderDerivation {
Expand Down
32 changes: 24 additions & 8 deletions core/src/main/scala-2/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package caliban.schema

import caliban.CalibanError.ValidationError
import caliban.Value._
import caliban.introspection.adt._
import caliban.parsing.adt.Directive
import caliban.parsing.adt.{ Directive, Directives }
import caliban.schema.Annotations._
import caliban.schema.Types._
import magnolia1._
Expand Down Expand Up @@ -65,8 +66,8 @@ trait CommonSchemaDerivation[R] {
if (_isValueType) {
if (isScalarValueType(ctx)) makeScalar(getName(ctx), getDescription(ctx))
else ctx.parameters.head.typeclass.toType_(isInput, isSubscription)
} else if (isInput)
makeInputObject(
} else if (isInput) {
lazy val tpe: __Type = makeInputObject(
Some(ctx.annotations.collectFirst { case GQLInputName(suffix) => suffix }
.getOrElse(customizeInputTypeName(getName(ctx)))),
getDescription(ctx),
Expand All @@ -81,14 +82,16 @@ trait CommonSchemaDerivation[R] {
p.annotations.collectFirst { case GQLDefault(v) => v },
p.annotations.collectFirst { case GQLDeprecated(_) => () }.isDefined,
p.annotations.collectFirst { case GQLDeprecated(reason) => reason },
Some(p.annotations.collect { case GQLDirective(dir) => dir }.toList).filter(_.nonEmpty)
Some(p.annotations.collect { case GQLDirective(dir) => dir }.toList).filter(_.nonEmpty),
() => Some(tpe)
)
)
.toList,
Some(ctx.typeName.full),
Some(getDirectives(ctx))
)
else
tpe
} else
makeObject(
Some(getName(ctx)),
getDescription(ctx),
Expand Down Expand Up @@ -172,11 +175,13 @@ trait CommonSchemaDerivation[R] {
case _ => false
}

if (isEnum && subtypes.nonEmpty && !isInterface && !isUnion)
val isOneOfInput = ctx.annotations.contains(GQLOneOfInput())

if (isEnum && subtypes.nonEmpty && !isInterface && !isUnion && !isOneOfInput) {
makeEnum(
Some(getName(ctx)),
getDescription(ctx),
subtypes.collect { case (__Type(_, Some(name), description, _, _, _, _, _, _, _, _, _), annotations) =>
subtypes.collect { case (__Type(_, Some(name), description, _, _, _, _, _, _, _, _, _, _), annotations) =>
__EnumValue(
name,
description,
Expand All @@ -188,7 +193,18 @@ trait CommonSchemaDerivation[R] {
Some(ctx.typeName.full),
Some(getDirectives(ctx.annotations))
)
else if (!isInterface) {
} else if (isOneOfInput && isInput) {
makeInputObject(
Some(ctx.annotations.collectFirst { case GQLInputName(suffix) => suffix }
.getOrElse(customizeInputTypeName(getName(ctx)))),
getDescription(ctx),
ctx.subtypes.toList.flatMap { p =>
p.typeclass.toType_(isInput = true).allInputFields.map(_.nullable)
},
Some(ctx.typeName.full),
Some(List(Directive(Directives.OneOf)))
)
} else if (!isInterface) {
containsEmptyUnionObjects = emptyUnionObjectIdxs.contains(true)
makeUnion(
Some(getName(ctx)),
Expand Down
93 changes: 70 additions & 23 deletions core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package caliban.schema

import caliban.CalibanError.ExecutionError
import caliban.Value.*
import caliban.schema.Annotations.{ GQLDefault, GQLName }
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput }
import caliban.schema.macros.Macros
import caliban.{ CalibanError, InputValue }
import magnolia1.Macro as MagnoliaMacro
Expand Down Expand Up @@ -53,10 +53,16 @@ trait CommonArgBuilderDerivation {
inline def derived[A]: ArgBuilder[A] =
inline summonInline[Mirror.Of[A]] match {
case m: Mirror.SumOf[A] =>
makeSumArgBuilder[A](
recurseSum[A, m.MirroredElemLabels, m.MirroredElemTypes](),
constValue[m.MirroredLabel]
)
inline if (Macros.hasAnnotation[A, GQLOneOfInput]) {
makeOneOfBuilder[A](
recurseSum[A, m.MirroredElemLabels, m.MirroredElemTypes](),
constValue[m.MirroredLabel]
)
} else
makeSumArgBuilder[A](
recurseSum[A, m.MirroredElemLabels, m.MirroredElemTypes](),
constValue[m.MirroredLabel]
)

case m: Mirror.ProductOf[A] =>
makeProductArgBuilder(
Expand All @@ -68,11 +74,11 @@ trait CommonArgBuilderDerivation {
private def makeSumArgBuilder[A](
_subTypes: => List[(String, List[Any], ArgBuilder[Any])],
traitLabel: String
) = new ArgBuilder[A] {
): ArgBuilder[A] = new ArgBuilder[A] {
private lazy val subTypes = _subTypes
private val emptyInput = InputValue.ObjectValue(Map.empty)

def build(input: InputValue): Either[ExecutionError, A] =
final def build(input: InputValue): Either[ExecutionError, A] =
input.match {
case EnumValue(value) => Right(value)
case StringValue(value) => Right(value)
Expand All @@ -91,28 +97,69 @@ trait CommonArgBuilderDerivation {
}
}

private def makeOneOfBuilder[A](
_subTypes: => List[(String, List[Any], ArgBuilder[Any])],
traitLabel: String
): ArgBuilder[A] = new ArgBuilder[A] {

override val partial: PartialFunction[InputValue, Either[ExecutionError, A]] = {
val xs = _subTypes.map(_._3).asInstanceOf[List[ArgBuilder[A]]]

val checkSize: PartialFunction[InputValue, Either[ExecutionError, A]] = {
case InputValue.ObjectValue(f) if f.size != 1 =>
Left(ExecutionError("Exactly one key must be specified for oneOf inputs"))
}
xs.foldLeft(checkSize)(_ orElse _.partial)
}

def build(input: InputValue): Either[ExecutionError, A] =
partial.applyOrElse(input, (in: InputValue) => Left(inputError(in)))

private def inputError(input: InputValue) =
ExecutionError(s"Invalid oneOf input $input for trait $traitLabel")
}

private def makeProductArgBuilder[A](
_fields: => List[(String, ArgBuilder[Any])],
annotations: Map[String, List[Any]]
)(fromProduct: Product => A) = new ArgBuilder[A] {
private lazy val fields = _fields
)(fromProduct: Product => A): ArgBuilder[A] = new ArgBuilder[A] {

private val params = Array.from(_fields.map { (label, builder) =>
val labelList = annotations.get(label)
val default = builder.buildMissing(labelList.flatMap(_.collectFirst { case GQLDefault(v) => v }))
val finalLabel = labelList.flatMap(_.collectFirst { case GQLName(name) => name }).getOrElse(label)
(finalLabel, default, builder)
})

private val required = params.collect { case (label, default, _) if default.isLeft => label }

override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, A]] = {
case InputValue.ObjectValue(fields) if required.forall(fields.contains) => fromFields(fields)
}

def build(input: InputValue): Either[ExecutionError, A] =
fields.view.map { (label, builder) =>
input match {
case InputValue.ObjectValue(fields) =>
val labelList = annotations.get(label)
def default = labelList.flatMap(_.collectFirst { case GQLDefault(v) => v })
val finalLabel = labelList.flatMap(_.collectFirst { case GQLName(name) => name }).getOrElse(label)
fields.get(finalLabel).fold(builder.buildMissing(default))(builder.build)
case value => builder.build(value)
}
}.foldLeft[Either[ExecutionError, Tuple]](Right(EmptyTuple)) { case (acc, item) =>
item match {
case Right(value) => acc.map(_ :* value)
case Left(e) => Left(e)
input match {
case InputValue.ObjectValue(fields) => fromFields(fields)
case _ => Left(ExecutionError("expected an input object"))
}

private def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, A] = {
var i = 0
val l = params.length
var acc: Tuple = EmptyTuple
while (i < l) {
val (label, default, builder) = params(i)
val field = fields.getOrElse(label, null)
val value = if (field ne null) builder.build(field) else default
value match {
case Right(v) => acc :*= v
case e @ Left(_) => return e.asInstanceOf[Either[ExecutionError, A]]
}
}.map(fromProduct)
i += 1
}
Right(fromProduct(acc))
}

}
}

Expand Down
Loading

0 comments on commit b56673b

Please sign in to comment.