Skip to content

Commit

Permalink
Fix out of order incremental results (#556)
Browse files Browse the repository at this point in the history
* Fix out of order incremental results

* Fix another flaky test
  • Loading branch information
gnawf authored Jun 26, 2024
1 parent 4e51086 commit d649da7
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import graphql.incremental.DelayedIncrementalPartialResult
import graphql.nadel.engine.NadelIncrementalResultSupport.OutstandingJobCounter.OutstandingJobHandle
import graphql.nadel.engine.util.copy
import graphql.nadel.util.getLogger
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
Expand All @@ -14,12 +15,11 @@ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

/**
* todo: we do not handle the case where defer jobs finish before [onInitialResultComplete]
*/
class NadelIncrementalResultSupport internal constructor(
private val delayedResultsChannel: Channel<DelayedIncrementalPartialResult> = Channel(
capacity = 100,
Expand All @@ -33,12 +33,21 @@ class NadelIncrementalResultSupport internal constructor(
private val log = getLogger<NadelIncrementalResultSupport>()
}

private val channelMutex = Mutex()

/**
* The root [Job] to run the defer and stream work etc on.
*/
private val coroutineJob = SupervisorJob()
private val coroutineScope = CoroutineScope(coroutineJob + Dispatchers.Default)

/**
* Temporary _kind of_ hack to wait for the initial result to complete before kicking off other jobs.
*
* Doesn't really handle a defer job kicking off more deferrals, but we'll cross that bridge later.
*/
private val initialCompletionLock = CompletableDeferred<Unit>()

/**
* A single [Flow] that can only be collected from once.
*/
Expand All @@ -55,17 +64,17 @@ class NadelIncrementalResultSupport internal constructor(

fun defer(task: suspend CoroutineScope.() -> DelayedIncrementalPartialResult): Job {
return launch { outstandingJobHandle ->
val hasNext: Boolean
val result = try {
task()
} finally {
hasNext = outstandingJobHandle.decrementAndGetJobCount() > 0
}
val result = task()
initialCompletionLock.await()

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
channelMutex.withLock {
val hasNext = outstandingJobHandle.decrementAndGetJobCount() > 0

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
}
}
}

Expand All @@ -78,17 +87,21 @@ class NadelIncrementalResultSupport internal constructor(
return launch { outstandingJobHandle ->
serviceResults
.collect { result ->
// Here we'll stipulate that the last element of the Flow sets hasNext=false
val hasNext = if (result.hasNext()) {
true
} else {
outstandingJobHandle.decrementAndGetJobCount() > 0
initialCompletionLock.await()

channelMutex.withLock {
// Here we'll stipulate that the last element of the Flow sets hasNext=false
val hasNext = if (result.hasNext()) {
true
} else {
outstandingJobHandle.decrementAndGetJobCount() > 0
}

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
}

delayedResultsChannel.send(
// Copy of result but with the correct hasNext according to the info we know
quickCopy(result, hasNext),
)
}
}
}
Expand All @@ -107,7 +120,11 @@ class NadelIncrementalResultSupport internal constructor(
}

fun onInitialResultComplete() {
// This signals the end for the job; not immediately, but as soon as the child jobs are all done
coroutineJob.complete()

// Unblocks work to yield results to the channel
initialCompletionLock.complete(Unit)
}

fun close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@ import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.lastOrNull
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.flow.withIndex
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withTimeoutOrNull
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import kotlin.test.assertFalse
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.milliseconds

class NadelIncrementalResultSupportTest {
@Test
Expand Down Expand Up @@ -86,16 +91,61 @@ class NadelIncrementalResultSupportTest {

// Then
firstLock.unlock()
secondLock.unlock()
thirdLock.unlock()

val results = channel.consumeAsFlow().toList()
val results = channel
.consumeAsFlow()
.withIndex()
.onEach { (index, _) ->
when (index) {
0 -> secondLock.unlock()
1 -> thirdLock.unlock()
2 -> {} // Do nothing
else -> throw IllegalArgumentException("Test does not expect this many elements")
}
}
.map { (_, value) -> value }
.toList()

assertTrue(results.dropLast(n = 1).all { it.hasNext() })
val lastResult = results.last()
assertTrue((lastResult.incremental?.single() as DeferPayload).getData<String>() == "Bye world")
assertFalse(lastResult.hasNext())
}

@Test
fun `does not send anything before onInitialResultComplete is invoked`() = runTest {
val channel = Channel<DelayedIncrementalPartialResult>(UNLIMITED)

val subject = NadelIncrementalResultSupport(channel)
val lock = CompletableDeferred<Boolean>()

// When
subject.defer {
DelayedIncrementalPartialResultImpl.newIncrementalExecutionResult()
.incrementalItems(emptyList())
.hasNext(true)
.extensions(mapOf("hello" to "world"))
.build()
.also {
lock.complete(true)
}
}

// Then
lock.join()

// Nothing comes out
val timeoutResult = withTimeoutOrNull(100.milliseconds) {
channel.receive()
}
assertTrue(timeoutResult == null)
assertTrue(channel.isEmpty)

// We receive the result once we invoke this
subject.onInitialResultComplete()
assertTrue(channel.receive().extensions == mapOf("hello" to "world"))
}

@Test
fun `hasNext is true if last job launches more jobs`() = runTest {
val channel = Channel<DelayedIncrementalPartialResult>(UNLIMITED)
Expand All @@ -113,21 +163,30 @@ class NadelIncrementalResultSupportTest {

DelayedIncrementalPartialResultImpl.newIncrementalExecutionResult()
.incrementalItems(emptyList())
.extensions(mapOf("id" to 2))
.hasNext(true)
.build()
}

DelayedIncrementalPartialResultImpl.newIncrementalExecutionResult()
.incrementalItems(emptyList())
.extensions(mapOf("id" to 1))
.hasNext(false)
.build()
}

// Then
subject.onInitialResultComplete()
firstLock.complete(true)

val item = channel.receive()
assertTrue(item.hasNext())
val first = channel.receive()
assertTrue(first.hasNext())
assertTrue(first.extensions == mapOf("id" to 1))

secondLock.complete(true)
val second = channel.receive()
assertFalse(second.hasNext())
assertTrue(second.extensions == mapOf("id" to 2))
}

@Test
Expand Down Expand Up @@ -167,6 +226,8 @@ class NadelIncrementalResultSupportTest {
}

// Then
subject.onInitialResultComplete()

firstLock.complete(true)
val firstItem = channel.receive()
assertTrue(firstItem.incremental?.isEmpty() == true)
Expand Down

0 comments on commit d649da7

Please sign in to comment.