Skip to content

Commit

Permalink
Better union exclusion (#455)
Browse files Browse the repository at this point in the history
* Better union exclusion

* Fix tests

* Ignore applied directives

New behavior introduced in graphql-java/graphql-java#2562

* Fix counter

* Add tests
  • Loading branch information
gnawf authored Jun 21, 2023
1 parent 5246e9f commit b115c98
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 14 deletions.
13 changes: 13 additions & 0 deletions lib/src/main/java/graphql/nadel/engine/util/CollectionUtil.kt
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,16 @@ fun <T> sequenceOfNulls(size: Int): Sequence<T?> {
}
}
}

/**
* Similar to [Sequence.all] but it requires at least [min] matching elements to pass.
*/
fun <T> Sequence<T>.all(min: Int, predicate: (T) -> Boolean): Boolean {
var count = 0
for (element in this) {
if (!predicate(element)) return false
count++
}

return count >= min
}
38 changes: 27 additions & 11 deletions lib/src/main/java/graphql/nadel/validation/NadelTypeValidation.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package graphql.nadel.validation

import graphql.Scalars.GraphQLID
import graphql.Scalars.GraphQLString
import graphql.language.ObjectTypeDefinition
import graphql.language.UnionTypeDefinition
import graphql.nadel.Service
import graphql.nadel.engine.util.AnyNamedNode
import graphql.nadel.engine.util.all
import graphql.nadel.engine.util.isExtensionDef
import graphql.nadel.engine.util.isList
import graphql.nadel.engine.util.isNonNull
Expand All @@ -14,6 +15,7 @@ import graphql.nadel.engine.util.operationTypes
import graphql.nadel.engine.util.unwrapAll
import graphql.nadel.engine.util.unwrapNonNull
import graphql.nadel.engine.util.unwrapOne
import graphql.nadel.schema.NadelDirectives
import graphql.nadel.schema.NadelDirectives.hydratedDirectiveDefinition
import graphql.nadel.validation.NadelSchemaValidationError.DuplicatedUnderlyingType
import graphql.nadel.validation.NadelSchemaValidationError.IncompatibleFieldOutputType
Expand All @@ -24,6 +26,7 @@ import graphql.nadel.validation.util.NadelSchemaUtil.getUnderlyingName
import graphql.nadel.validation.util.NadelSchemaUtil.getUnderlyingType
import graphql.nadel.validation.util.getReachableTypeNames
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLFieldsContainer
import graphql.schema.GraphQLImplementingType
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLNamedOutputType
Expand Down Expand Up @@ -185,8 +188,8 @@ internal class NadelTypeValidation(
service: Service,
): Pair<List<NadelServiceSchemaElement>, List<NadelSchemaValidationError>> {
val errors = mutableListOf<NadelSchemaValidationError>()
val polymorphicHydrationUnions = getPolymorphicHydrationUnions(service)
val namesUsed = getTypeNamesUsed(service, externalTypes = polymorphicHydrationUnions)
val hydrationUnions = getHydrationUnions(service)
val namesUsed = getTypeNamesUsed(service, externalTypes = hydrationUnions)

fun addMissingUnderlyingTypeError(overallType: GraphQLNamedType) {
errors.add(MissingUnderlyingType(service, overallType))
Expand Down Expand Up @@ -229,16 +232,30 @@ internal class NadelTypeValidation(
} to errors
}

private fun getPolymorphicHydrationUnions(service: Service): Set<GraphQLUnionType> {
private fun getHydrationUnions(service: Service): Set<GraphQLUnionType> {
return service.definitionRegistry
.definitions
.asSequence()
.filterIsInstance<ObjectTypeDefinition>()
.flatMap { it.fieldDefinitions }
.filter { it.getDirectives(hydratedDirectiveDefinition.name).size > 1 }
.map { it.type.unwrapAll() }
.map { overallSchema.getType(it.name) }
.filterIsInstance<GraphQLUnionType>()
.filterIsInstance<UnionTypeDefinition>()
.filter { union ->
// Check that ALL fields that output the union are annotated with @hydrated
overallSchema.typeMap
.values
.asSequence()
.filterIsInstance<GraphQLFieldsContainer>()
.flatMap {
it.fieldDefinitions
}
.filter {
it.type.unwrapAll().name == union.name
}
.all(min = 1) {
it.hasAppliedDirective(hydratedDirectiveDefinition.name)
}
}
.map {
overallSchema.typeMap[it.name] as GraphQLUnionType
}
.toSet()
}

Expand Down Expand Up @@ -267,7 +284,6 @@ internal class NadelTypeValidation(
}
.toSet() - namesToIgnore


// If it can be reached by using your service, you must own it to return it!
val referencedTypes = getReachableTypeNames(overallSchema, service, definitionNames)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import graphql.nadel.engine.util.unwrapAll
import graphql.nadel.validation.util.NadelCombinedTypeUtil.getFieldsThatServiceContributed
import graphql.nadel.validation.util.NadelCombinedTypeUtil.isCombinedType
import graphql.nadel.validation.util.NadelSchemaUtil.hasHydration
import graphql.schema.GraphQLAppliedDirective
import graphql.schema.GraphQLAppliedDirectiveArgument
import graphql.schema.GraphQLArgument
import graphql.schema.GraphQLCompositeType
import graphql.schema.GraphQLDirective
import graphql.schema.GraphQLDirectiveContainer
import graphql.schema.GraphQLEnumType
import graphql.schema.GraphQLEnumValueDefinition
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLFieldsContainer
import graphql.schema.GraphQLInputFieldsContainer
Expand Down Expand Up @@ -223,6 +226,15 @@ internal fun getReachableTypeNames(
add(node.name)
return CONTINUE
}

override fun visitGraphQLAppliedDirective(
node: GraphQLAppliedDirective,
context: TraverserContext<GraphQLSchemaElement>,
): TraversalControl {
// Don't look into applied directives. Could be a shared directive.
// As long as the schema compiled then we don't care.
return ABORT
}
}

SchemaTraverser { element ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ import graphql.nadel.validation.NadelSchemaValidationError.MissingUnderlyingInpu
import graphql.nadel.validation.NadelSchemaValidationError.MissingUnderlyingType
import graphql.nadel.validation.util.assertSingleOfType
import io.kotest.core.spec.style.DescribeSpec
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
import kotlin.time.ExperimentalTime

@OptIn(ExperimentalTime::class)
class NadelTypeValidationTest : DescribeSpec({
describe("validate") {
it("passes if types are valid") {
Expand Down Expand Up @@ -171,6 +168,89 @@ class NadelTypeValidationTest : DescribeSpec({
assert(errors.map { it.message }.isNotEmpty())
}

it("allows synthetic union if exclusively used for @hydrated fields") {
val fixture = NadelValidationTestFixture(
overallSchema = mapOf(
"test" to """
type Query {
echo: Echo
}
type Echo {
world: World
}
type World {
hello: Something @hydrated(
service: "test"
field: "echo.world"
arguments: []
)
}
union Something = Echo | World
""".trimIndent(),
),
underlyingSchema = mapOf(
"test" to """
type Query {
echo: Echo
}
type Echo {
world: World
}
type World {
hello: String
}
""".trimIndent(),
),
)

val errors = validate(fixture)
assert(errors.map { it.message }.isEmpty())
}

it("prohibits synthetic union if not exclusively used for @hydrated fields") {
val fixture = NadelValidationTestFixture(
overallSchema = mapOf(
"test" to """
type Query {
echo: Echo
}
type Echo {
world: World
test: Something
}
type World {
hello: Something @hydrated(
service: "test"
field: "echo.world"
arguments: []
)
}
union Something = Echo | World
""".trimIndent(),
),
underlyingSchema = mapOf(
"test" to """
type Query {
echo: Echo
}
type Echo {
world: World
}
type World {
hello: String
}
""".trimIndent(),
),
)

val errors = validate(fixture)
assert(errors.map { it.message }.isNotEmpty())

val error = errors.assertSingleOfType<MissingUnderlyingType>()
assert(error.service.name == "test")
assert(error.overallType.name == "Something")
}

it("tracks visited types to avoid stack overflow").config(timeout = 1.seconds) {
val fixture = NadelValidationTestFixture(
overallSchema = mapOf(
Expand Down

0 comments on commit b115c98

Please sign in to comment.