Skip to content

Commit

Permalink
Defer payloads in groups (#562)
Browse files Browse the repository at this point in the history
* Add POC to merge defer Payloads into correct groups

* Emit accumulated results when completed

* Refactor

* Update tests

* Add extra test

* Update comments

* Idk

* Fix tests

* Add accumulator tests

* Add doco
  • Loading branch information
gnawf authored Aug 8, 2024
1 parent 576327b commit 79b230c
Show file tree
Hide file tree
Showing 13 changed files with 2,109 additions and 76 deletions.
10 changes: 5 additions & 5 deletions lib/src/main/java/graphql/nadel/NextgenEngine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal class NextgenEngine(
val operationParseOptions = baseParseOptions
.deferSupport(executionHints.deferSupport.invoke())

val query = timer.time(step = RootStep.ExecutableOperationParsing) {
val operation = timer.time(step = RootStep.ExecutableOperationParsing) {
createExecutableNormalizedOperationWithRawVariables(
querySchema,
queryDocument,
Expand All @@ -161,11 +161,11 @@ internal class NextgenEngine(
)
}

val incrementalResultSupport = NadelIncrementalResultSupport()
val incrementalResultSupport = NadelIncrementalResultSupport(operation)
val resultTracker = NadelResultTracker()
val executionContext = NadelExecutionContext(
executionInput,
query,
operation,
executionHooks,
executionHints,
instrumentationState,
Expand All @@ -175,15 +175,15 @@ internal class NextgenEngine(
)

val beginExecuteContext = instrumentation.beginExecute(
query,
operation,
queryDocument,
executionInput,
engineSchema,
instrumentationState,
)

val result: ExecutionResult = try {
val fields = fieldToService.getServicesForTopLevelFields(query, executionHints)
val fields = fieldToService.getServicesForTopLevelFields(operation, executionHints)
val results = coroutineScope {
fields
.map { (field, service) ->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package graphql.nadel.engine

import graphql.GraphQLError
import graphql.incremental.DeferPayload
import graphql.incremental.DelayedIncrementalPartialResult
import graphql.incremental.DelayedIncrementalPartialResultImpl
import graphql.nadel.engine.transform.query.NadelQueryPath
import graphql.nadel.engine.util.MutableJsonMap
import graphql.normalized.ExecutableNormalizedField
import graphql.normalized.ExecutableNormalizedOperation
import graphql.normalized.incremental.NormalizedDeferredExecution

/**
* This class helps to return defer payloads in the correct groupings.
* This can become an issue if part of the defer payload is executed by Nadel, and
* another part is executed by an underlying service.
*
* e.g.
*
* ```
* query {
* me {
* ... @defer {
* name # Executed by underlying service
* manager { # Executed by Nadel hydration
* name
* }
* }
* }
* }
* ```
*
* In this case we will receive two [DeferPayload]s, one from the underlying service
* and one from Nadel itself.
*
* Thing is, because of how the `@defer` was applied, these results need to be returned together.
*
* This class accumulates the multiple [DeferPayload]s until they are complete i.e. in the above
* example `name` and `manager` are both present. Only then is the [DeferPayload] sent back to
* the user.
*/
class NadelIncrementalResultAccumulator(
private val operation: ExecutableNormalizedOperation,
) {
data class DeferAccumulatorKey(
val incrementalPayloadPath: List<Any>,
val deferExecution: NormalizedDeferredExecution,
)

data class DeferAccumulator(
val data: MutableJsonMap,
val errors: MutableList<GraphQLError>,
)

private val deferAccumulators = mutableMapOf<DeferAccumulatorKey, DeferAccumulator>()

private val queryPathToExecutions: Map<NadelQueryPath, List<NormalizedDeferredExecution>> = operation.walkTopDown()
.filter {
it.deferredExecutions.isNotEmpty()
}
.groupBy(
keySelector = {
NadelQueryPath(it.parent?.listOfResultKeys ?: emptyList())
},
valueTransform = {
it.deferredExecutions
},
)
.mapValues { (_, values) ->
values.flatten()
}

/**
* todo: this doesn't account for type conditions
*/
private val deferExecutionToFields: Map<NormalizedDeferredExecution, List<ExecutableNormalizedField>> =
operation.walkTopDown()
.filter {
it.deferredExecutions.isNotEmpty()
}
.flatMap { field ->
field.deferredExecutions
.map { deferExecution ->
deferExecution to field
}
}
.groupBy(
keySelector = { (deferExecution) ->
deferExecution
},
valueTransform = { (_, field) ->
field
},
)
.filterValues {
it.isNotEmpty()
}
.mapValues { (_, fields) ->
val topLevel = fields.minOf {
it.level
}
fields.filter {
it.level == topLevel
}
}

fun getIncrementalPartialResult(hasNext: Boolean): DelayedIncrementalPartialResult? {
val readyAccumulators = deferAccumulators
.filter {
// i.e. complete
it.value.data.size == deferExecutionToFields[it.key.deferExecution]!!.size
}
.onEach {
deferAccumulators.remove(it.key)
}

if (readyAccumulators.isEmpty()) {
return null
}

val payloadsToEmit = readyAccumulators
.map { (key, accumulator) ->
DeferPayload.newDeferredItem()
.data(accumulator.data)
.errors(accumulator.errors)
.path(key.incrementalPayloadPath)
.label(key.deferExecution.label)
.build()
}

// todo: handle extensions
return DelayedIncrementalPartialResultImpl.newIncrementalExecutionResult()
.incrementalItems(payloadsToEmit)
.hasNext(hasNext)
.build()
}

fun accumulate(result: DelayedIncrementalPartialResult) {
result.incremental
?.forEach { payload ->
when (payload) {
is DeferPayload -> {
accumulate(payload)
}
}
}
}

private fun accumulate(payload: DeferPayload) {
val data = payload.getData<Map<String, Any?>?>()
?: return

val queryPath = NadelQueryPath.fromResultPath(payload.path)
val deferredExecutions = queryPathToExecutions[queryPath]
?: return

deferredExecutions
.asSequence()
.filter {
payload.label == it.label
}
.forEachIndexed { index, deferExecution ->
val accumulatorKey = DeferAccumulatorKey(
incrementalPayloadPath = payload.path,
deferExecution = deferExecution,
)

val deferAccumulator = deferAccumulators.computeIfAbsent(accumulatorKey) {
DeferAccumulator(
data = mutableMapOf(),
errors = mutableListOf(),
)
}

deferExecutionToFields[deferExecution]!!.forEach { field ->
if (field.resultKey in data) {
deferAccumulator.data[field.resultKey] = data[field.resultKey]
}
}

// todo: there's no good way to determine which defer execution a payload belongs to
if (index == 0) {
deferAccumulator.errors.addAll(payload.errors ?: emptyList())
}
}
}
}

/**
* Similar to [java.io.File.walkTopDown] but for ENFs.
*/
private fun ExecutableNormalizedOperation.walkTopDown(): Sequence<ExecutableNormalizedField> {
return topLevelFields
.asSequence()
.flatMap {
it.walkTopDown()
}
}

/**
* Similar to [java.io.File.walkTopDown] but for ENFs.
*/
private fun ExecutableNormalizedField.walkTopDown(): Sequence<ExecutableNormalizedField> {
return sequenceOf(this) + children.asSequence()
.flatMap {
it.walkTopDown()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import graphql.incremental.DelayedIncrementalPartialResult
import graphql.nadel.engine.NadelIncrementalResultSupport.OutstandingJobCounter.OutstandingJobHandle
import graphql.nadel.engine.util.copy
import graphql.nadel.util.getLogger
import graphql.normalized.ExecutableNormalizedOperation
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
Expand All @@ -21,19 +22,32 @@ import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

class NadelIncrementalResultSupport internal constructor(
private val delayedResultsChannel: Channel<DelayedIncrementalPartialResult> = Channel(
capacity = 100,
onBufferOverflow = BufferOverflow.DROP_LATEST,
onUndeliveredElement = {
log.error("Dropping incremental result because of buffer overflow")
},
),
private val accumulator: NadelIncrementalResultAccumulator,
private val delayedResultsChannel: Channel<DelayedIncrementalPartialResult> = makeDefaultChannel(),
) {
internal constructor(
operation: ExecutableNormalizedOperation,
delayedResultsChannel: Channel<DelayedIncrementalPartialResult> = makeDefaultChannel(),
) : this(
accumulator = NadelIncrementalResultAccumulator(
operation = operation,
),
delayedResultsChannel = delayedResultsChannel,
)

companion object {
private val log = getLogger<NadelIncrementalResultSupport>()

private fun makeDefaultChannel(): Channel<DelayedIncrementalPartialResult> = Channel(
capacity = 100,
onBufferOverflow = BufferOverflow.DROP_LATEST,
onUndeliveredElement = {
log.error("Dropping incremental result because of buffer overflow")
},
)
}

private val channelMutex = Mutex()
private val operationMutex = Mutex()

/**
* The root [Job] to run the defer and stream work etc on.
Expand Down Expand Up @@ -67,13 +81,15 @@ class NadelIncrementalResultSupport internal constructor(
val result = task()
initialCompletionLock.await()

channelMutex.withLock {
operationMutex.withLock {
accumulator.accumulate(result)

val hasNext = outstandingJobHandle.decrementAndGetJobCount() > 0

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
val next = accumulator.getIncrementalPartialResult(hasNext)
if (next != null) {
delayedResultsChannel.send(next)
}
}
}
}
Expand All @@ -89,18 +105,20 @@ class NadelIncrementalResultSupport internal constructor(
.collect { result ->
initialCompletionLock.await()

channelMutex.withLock {
operationMutex.withLock {
accumulator.accumulate(result)

// Here we'll stipulate that the last element of the Flow sets hasNext=false
val hasNext = if (result.hasNext()) {
true
} else {
outstandingJobHandle.decrementAndGetJobCount() > 0
}

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
val next = accumulator.getIncrementalPartialResult(hasNext)
if (next != null) {
delayedResultsChannel.send(next)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,9 @@ data class NadelQueryPath(val segments: List<String>) {

companion object {
val root = NadelQueryPath(emptyList())

fun fromResultPath(path: List<Any>): NadelQueryPath {
return NadelQueryPath(path.filterIsInstance<String>())
}
}
}
Loading

0 comments on commit 79b230c

Please sign in to comment.