Skip to content

Commit

Permalink
Merge pull request #1034 from bc-dima-pasieka/limit-errors-july-2023
Browse files Browse the repository at this point in the history
Limit number of errors returned (when variables are used)
  • Loading branch information
yanns authored Sep 7, 2023
2 parents 98e5895 + 9e60707 commit 80d7e09
Show file tree
Hide file tree
Showing 15 changed files with 295 additions and 169 deletions.
21 changes: 20 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,27 @@ lazy val core = project
description := "Scala GraphQL implementation",
mimaPreviousArtifacts := Set("org.sangria-graphql" %% "sangria-core" % "4.0.0"),
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.apply"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.copy"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.execute"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.prepare"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.RuleBasedQueryValidator.this"),
"sangria.execution.QueryReducerExecutor.reduceQueryWithoutVariables"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.ValueCoercionHelper.isValidValue"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.ValueCoercionHelper.getVariableValue"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.batch.BatchExecutor.executeBatch"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.schema.ResolverBasedAstSchemaBuilder.validateSchema"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.QueryValidator.validateQuery"),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"sangria.validation.QueryValidator.validateQuery"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.RuleBasedQueryValidator.validateQuery"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.ValidationContext.this")
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import sangria.validation.{QueryValidator, RuleBasedQueryValidator, Violation}
@State(Scope.Thread)
class OverlappingFieldsCanBeMergedBenchmark {

val validator: QueryValidator = RuleBasedQueryValidator(
val validator: QueryValidator = new RuleBasedQueryValidator(
List(new rules.OverlappingFieldsCanBeMerged))

val schema: Schema[_, _] =
Expand Down Expand Up @@ -98,7 +98,7 @@ class OverlappingFieldsCanBeMergedBenchmark {
bh.consume(doValidate(validator, deepAbstractConcrete))

private def doValidate(validator: QueryValidator, document: Document): Vector[Violation] = {
val result = validator.validateQuery(schema, document)
val result = validator.validateQuery(schema, document, None)
require(result.isEmpty)
result
}
Expand Down
27 changes: 18 additions & 9 deletions modules/core/src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ case class Executor[Ctx, Root](
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit executionContext: ExecutionContext) {
def prepare[Input](
queryAst: ast.Document,
Expand All @@ -29,7 +30,7 @@ case class Executor[Ctx, Root](
variables: Input = emptyMapVars
)(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -49,7 +50,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -141,7 +144,7 @@ case class Executor[Ctx, Root](
um: InputUnmarshaller[Input],
scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
scheme.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -161,7 +164,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -324,7 +329,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
marshaller: ResultMarshaller,
Expand All @@ -338,7 +344,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.execute(queryAst, userContext, root, operationName, variables)

def prepare[Ctx, Root, Input](
Expand All @@ -354,7 +361,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] =
Expand All @@ -366,7 +374,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.prepare(queryAst, userContext, root, operationName, variables)

def getOperationRootType[Ctx, Root](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ object QueryReducerExecutor {
queryValidator: QueryValidator = QueryValidator.default,
exceptionHandler: ExceptionHandler = ExceptionHandler.empty,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil
middleware: List[Middleware[Ctx]] = Nil,
errorsLimit: Option[Int] = None
)(implicit executionContext: ExecutionContext): Future[(Ctx, TimeMeasurement)] = {
val violations = queryValidator.validateQuery(schema, queryAst)
val violations = queryValidator.validateQuery(schema, queryAst, errorsLimit)

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
Expand Down
189 changes: 105 additions & 84 deletions modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -580,107 +580,128 @@ class ValueCoercionHelper[Ctx](
nodeLocation.toList ++ firstValue.toList
}

def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValue(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)
}
private def isValidValue[In](
inputType: InputType[_],
input: Option[In],
errorsLimit: Option[Int])(implicit um: InputUnmarshaller[In]): Vector[Violation] = {

val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))
// keeping track of the number of errors
var errors = 0
def addViolation(violation: Violation): Vector[Violation] = {
errors += 1
Vector(violation)
}

fieldViolations ++ unknownFields
def isValidValueRec(tpe: InputType[_], in: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] =
// early termination if errors limit is defined and the current number of violations exceeds the limit
if (errorsLimit.exists(_ <= errors)) Vector.empty
else
(tpe, in) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => addViolation(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValueRec(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
addViolation(
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)).head
}

case (objTpe: InputObjectType[_], _) =>
Vector(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))
val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValueRec(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))

case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}
fieldViolations ++ unknownFields

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
case (objTpe: InputObjectType[_], _) =>
addViolation(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))

case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}
case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}

coerced match {
case Left(violation) => Vector(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
}
coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}
case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
coerced match {
case Left(violation) => addViolation(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}
}

case (enumT: EnumType[_], Some(value)) =>
Vector(EnumCoercionViolation)
case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}

coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

case (enumT: EnumType[_], Some(value)) =>
addViolation(EnumCoercionViolation)

case _ =>
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
case _ =>
addViolation(GenericInvalidValueViolation(sourceMapper, Nil))
}

isValidValueRec(inputType, input)
}

def getVariableValue[In](
definition: ast.VariableDefinition,
tpe: InputType[_],
input: Option[In],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit
um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input)
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]],
errorsLimit: Option[Int]
)(implicit um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input, errorsLimit)

if (violations.isEmpty) {
val fieldPath = s"$$${definition.name}" :: Nil
Expand Down
Loading

0 comments on commit 80d7e09

Please sign in to comment.