Skip to content

Commit

Permalink
Limit number of errors returned (when variables are used)
Browse files Browse the repository at this point in the history
  • Loading branch information
bc-dima-pasieka committed Jul 17, 2023
1 parent e54bb95 commit 81914d3
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 123 deletions.
27 changes: 18 additions & 9 deletions modules/core/src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ case class Executor[Ctx, Root](
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit executionContext: ExecutionContext) {
def prepare[Input](
queryAst: ast.Document,
Expand All @@ -29,7 +30,7 @@ case class Executor[Ctx, Root](
variables: Input = emptyMapVars
)(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -49,7 +50,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -141,7 +144,7 @@ case class Executor[Ctx, Root](
um: InputUnmarshaller[Input],
scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
scheme.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -161,7 +164,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -324,7 +329,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
marshaller: ResultMarshaller,
Expand All @@ -338,7 +344,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.execute(queryAst, userContext, root, operationName, variables)

def prepare[Ctx, Root, Input](
Expand All @@ -354,7 +361,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] =
Expand All @@ -366,7 +374,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.prepare(queryAst, userContext, root, operationName, variables)

def getOperationRootType[Ctx, Root](
Expand Down
187 changes: 103 additions & 84 deletions modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -580,107 +580,126 @@ class ValueCoercionHelper[Ctx](
nodeLocation.toList ++ firstValue.toList
}

def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValue(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)
}
def isValidValue[In](inputType: InputType[_], input: Option[In], errorsLimit: Option[Int])(
implicit um: InputUnmarshaller[In]): Vector[Violation] = {

// keeping track of the number of errors
var errors = 0
def addViolation(violation: Violation): Vector[Violation] = {
errors += 1
Vector(violation)
}

val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))
def isValidValueRec(tpe: InputType[_], in: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] =
// early termination if errors limit is defined and the current number of violations exceeds the limit
if (errorsLimit.exists(_ <= errors)) Vector.empty
else
(tpe, in) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => addViolation(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValueRec(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
addViolation(
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)).head
}

fieldViolations ++ unknownFields
val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValueRec(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))

case (objTpe: InputObjectType[_], _) =>
Vector(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))
fieldViolations ++ unknownFields

case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}
case (objTpe: InputObjectType[_], _) =>
addViolation(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}

case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}
coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

coerced match {
case Left(violation) => Vector(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
}
case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}

case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}
coerced match {
case Left(violation) => addViolation(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}
}

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}

coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

case (enumT: EnumType[_], Some(value)) =>
Vector(EnumCoercionViolation)
case (enumT: EnumType[_], Some(value)) =>
addViolation(EnumCoercionViolation)

case _ =>
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
case _ =>
addViolation(GenericInvalidValueViolation(sourceMapper, Nil))
}

isValidValueRec(inputType, input)
}

def getVariableValue[In](
definition: ast.VariableDefinition,
tpe: InputType[_],
input: Option[In],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit
um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input)
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]],
errorsLimit: Option[Int]
)(implicit um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input, errorsLimit)

if (violations.isEmpty) {
val fieldPath = s"$$${definition.name}" :: Nil
Expand Down
68 changes: 46 additions & 22 deletions modules/core/src/main/scala/sangria/execution/ValueCollector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,66 @@ class ValueCollector[Ctx, Input](

def getVariableValues(
definitions: Vector[ast.VariableDefinition],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])
: Try[Map[String, VariableValue]] =
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]]
): Try[Map[String, VariableValue]] = getVariableValues(definitions, fromScalarMiddleware, None)

def getVariableValues(
definitions: Vector[ast.VariableDefinition],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]],
errorsLimit: Option[Int]
): Try[Map[String, VariableValue]] =
if (!um.isMapNode(inputVars))
Failure(
new ExecutionError(
s"Variables should be a map-like object, like JSON object. Got: ${um.render(inputVars)}",
exceptionHandler))
else {
val res =
definitions.foldLeft(Vector.empty[(String, Either[Vector[Violation], VariableValue])]) {
definitions.foldLeft(
(0, Vector.empty[(String, Either[Vector[Violation], VariableValue])])) {
case (acc, varDef) =>
val value = schema
.getInputType(varDef.tpe)
.map(
coercionHelper.getVariableValue(
val (accErrors, accResult) = acc

// early termination if errors limit is defined and the current number of violations exceeds the limit
if (errorsLimit.exists(_ <= accErrors)) acc
else {
val value = schema
.getInputType(varDef.tpe)
.map(coercionHelper.getVariableValue(
varDef,
_,
um.getRootMapValue(inputVars, varDef.name),
fromScalarMiddleware))
.getOrElse(
Left(
Vector(
UnknownVariableTypeViolation(
varDef.name,
QueryRenderer.render(varDef.tpe),
sourceMapper,
varDef.location.toList))))

value match {
case Right(Some(v)) => acc :+ (varDef.name -> Right(v))
case Right(None) => acc
case Left(violations) => acc :+ (varDef.name -> Left(violations))
fromScalarMiddleware,
// calculate the allowed number of errors to be returned (if any)
errorsLimit.map(_ - accErrors)
))
.getOrElse(
Left(
Vector(
UnknownVariableTypeViolation(
varDef.name,
QueryRenderer.render(varDef.tpe),
sourceMapper,
varDef.location.toList))))

value match {
case Right(Some(v)) => (accErrors, accResult :+ (varDef.name -> Right(v)))
case Right(None) => acc
case Left(violations) =>
// number of errors that is allowed to use (all if errors limit is not defined)
val errorsLeftToUse = errorsLimit.fold(violations.length) { limit =>
Math.min(violations.length, limit - accErrors)
}

(
accErrors + errorsLeftToUse,
accResult :+ (varDef.name -> Left(violations.take(errorsLeftToUse)))
)
}
}
}

val (errors, values) = res.partition(_._2.isLeft)
val (errors, values) = res._2.partition(_._2.isLeft)

if (errors.nonEmpty)
Failure(
Expand Down
Loading

0 comments on commit 81914d3

Please sign in to comment.