Unit tests for streams manager (#45090)
johnny-schmidt authored Sep 3, 2024
1 parent 5bf11d9 · commit af58faa
Showing 7 changed files with 412 additions and 71 deletions.
@@ -25,8 +25,12 @@ data class DestinationCatalog(
}
}

interface DestinationCatalogFactory {
fun make(): DestinationCatalog
}

@Factory
class DestinationCatalogFactory(
class DefaultDestinationCatalogFactory(
private val catalog: ConfiguredAirbyteCatalog,
private val streamFactory: DestinationStreamFactory
) {
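The hunk above extracts a DestinationCatalogFactory interface and renames the concrete Micronaut factory to DefaultDestinationCatalogFactory, so that test environments can swap in their own catalog (see MockCatalogFactory further down). Below is a minimal sketch of depending on the interface rather than the concrete class; CatalogConsumer is hypothetical and not part of this commit.

class CatalogConsumer(private val catalogFactory: DestinationCatalogFactory) {
    // Which implementation gets injected (default vs. mock) is decided by the active
    // Micronaut environment, not by this class.
    fun streamCount(): Int = catalogFactory.make().streams.size
}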
@@ -43,13 +43,12 @@ class DestinationMessageQueueWriter(
/* If the input message represents a record. */
is DestinationRecordMessage -> {
val manager = streamsManager.getManager(message.stream)
val index = manager.countRecordIn(sizeBytes)
when (message) {
/* If a data record */
is DestinationRecord -> {
val wrapped =
StreamRecordWrapped(
index = index,
index = manager.countRecordIn(),
sizeBytes = sizeBytes,
record = message
)
@@ -58,7 +57,7 @@

/* If an end-of-stream marker. */
is DestinationStreamComplete -> {
val wrapped = StreamCompleteWrapped(index)
val wrapped = StreamCompleteWrapped(index = manager.countEndOfStream())
messageQueue.getChannel(message.stream).send(wrapped)
}
}
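With this change the queue writer no longer derives both indices from a single countRecordIn(sizeBytes) call: each data record takes its index from countRecordIn(), and the end-of-stream marker takes its index from countEndOfStream(). The stand-in below is illustrative only (not the real StreamManager) and shows the invariant the writer now relies on: the end-of-stream index equals the number of records counted.

import java.util.concurrent.atomic.AtomicLong

// Stand-in counter for illustration.
class CountingSketch {
    private val recordCount = AtomicLong(0)
    fun countRecordIn(): Long = recordCount.getAndIncrement() // index of this record
    fun countEndOfStream(): Long = recordCount.get()          // == total records seen
}

fun main() {
    val counter = CountingSketch()
    val indexes = List(3) { counter.countRecordIn() }  // [0, 1, 2]
    val endOfStreamIndex = counter.countEndOfStream()  // 3
    check(endOfStreamIndex == indexes.size.toLong())
    println("record indexes=$indexes, end-of-stream index=$endOfStreamIndex")
}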
@@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.coroutines.channels.Channel

/** Manages the state of all streams in the destination. */
interface StreamsManager {
/** Get the manager for the given stream. Throws an exception if the stream is not found. */
fun getManager(stream: DestinationStream): StreamManager
suspend fun awaitAllStreamsComplete()

/** Suspend until all streams are closed. */
suspend fun awaitAllStreamsClosed()
}

class DefaultStreamsManager(
@@ -33,68 +35,98 @@ class DefaultStreamsManager(
return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream")
}

override suspend fun awaitAllStreamsComplete() {
override suspend fun awaitAllStreamsClosed() {
streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() }
}
}

/** Manages the state of a single stream. */
interface StreamManager {
fun countRecordIn(sizeBytes: Long): Long
/** Count an incoming record and return the record's *index*. */
fun countRecordIn(): Long

/**
* Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and
* expect that `markClosed` will always occur after this.
*/
fun countEndOfStream(): Long

/**
* Mark a checkpoint in the stream and return the current index and the number of records since
* the last one.
*
* NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be
* synchronized.
*/
fun markCheckpoint(): Pair<Long, Long>

/** Record that the given batch's state has been reached for the associated range(s). */
fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>)

/**
* True if all are true:
* all records have been seen (i.e., we've counted an end-of-stream)
* * a [Batch.State.COMPLETE] batch range has been seen covering every record
*
* Does NOT require that the stream be closed.
*/
fun isBatchProcessingComplete(): Boolean

/**
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
* implicitly true if they have all reached [Batch.State.COMPLETE].
*/
fun areRecordsPersistedUntil(index: Long): Boolean

/** Mark the stream as closed. This should only be called after all records have been read. */
fun markClosed()

/** True if the stream has been marked as closed. */
fun streamIsClosed(): Boolean

/** Suspend until the stream is closed. */
suspend fun awaitStreamClosed()
}

/**
* Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which
* that state has been reached.
*
* TODO: Log a detailed report of the stream status on a regular cadence.
*/
class DefaultStreamManager(
val stream: DestinationStream,
) : StreamManager {
private val log = KotlinLogging.logger {}

data class StreamStatus(
val recordCount: AtomicLong = AtomicLong(0),
val totalBytes: AtomicLong = AtomicLong(0),
val enqueuedSize: AtomicLong = AtomicLong(0),
val lastCheckpoint: AtomicLong = AtomicLong(0L),
val closedLatch: CountDownLatch = CountDownLatch(1),
)
private val recordCount = AtomicLong(0)
private val lastCheckpoint = AtomicLong(0L)
private val readIsClosed = AtomicBoolean(false)
private val streamIsClosed = AtomicBoolean(false)
private val closedLock = Channel<Unit>()

private val streamStatus: StreamStatus = StreamStatus()
private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()

init {
Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() }
}

override fun countRecordIn(sizeBytes: Long): Long {
val index = streamStatus.recordCount.getAndIncrement()
streamStatus.totalBytes.addAndGet(sizeBytes)
streamStatus.enqueuedSize.addAndGet(sizeBytes)
return index
override fun countRecordIn(): Long {
if (readIsClosed.get()) {
throw IllegalStateException("Stream is closed for reading")
}

return recordCount.getAndIncrement()
}

override fun countEndOfStream(): Long {
if (readIsClosed.getAndSet(true)) {
throw IllegalStateException("Stream is closed for reading")
}

return recordCount.get()
}

/**
* Mark a checkpoint in the stream and return the current index and the number of records since
* the last one.
*/
override fun markCheckpoint(): Pair<Long, Long> {
val index = streamStatus.recordCount.get()
val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index)
val index = recordCount.get()
val lastCheckpoint = lastCheckpoint.getAndSet(index)
return Pair(index, index - lastCheckpoint)
}

/** Record that the given batch's state has been reached for the associated range(s). */
override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {
val stateRanges =
rangesState[batch.batch.state]
@@ -112,37 +144,44 @@ class DefaultStreamManager(
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" }
}

/** True if all records in [0, index] have reached the given state. */
/** True if all records in `[0, index)` have reached the given state. */
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {

val completeRanges = rangesState[state]!!
return completeRanges.encloses(Range.closedOpen(0L, index))
}

/** True if all records have associated [Batch.State.COMPLETE] batches. */
override fun isBatchProcessingComplete(): Boolean {
return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE)
/* If the stream hasn't been fully read, it can't be done. */
if (!readIsClosed.get()) {
return false
}

return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE)
}

/**
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
* implicitly true if they have all reached [Batch.State.COMPLETE].
*/
override fun areRecordsPersistedUntil(index: Long): Boolean {
return isProcessingCompleteForState(index, Batch.State.PERSISTED) ||
isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted
}

override fun markClosed() {
streamStatus.closedLatch.countDown()
if (!readIsClosed.get()) {
throw IllegalStateException("Stream must be fully read before it can be closed")
}

if (streamIsClosed.compareAndSet(false, true)) {
closedLock.trySend(Unit)
}
}

override fun streamIsClosed(): Boolean {
return streamStatus.closedLatch.count == 0L
return streamIsClosed.get()
}

override suspend fun awaitStreamClosed() {
withContext(Dispatchers.IO) { streamStatus.closedLatch.await() }
if (!streamIsClosed.get()) {
closedLock.receive()
}
}
}

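Taken together, the new StreamManager contract is: countRecordIn() once per record, countEndOfStream() exactly once, then markClosed() followed by awaitStreamClosed(). Below is a short sketch of that lifecycle against DefaultStreamManager, using only calls shown in this diff; batch-state updates via updateBatchState are omitted because constructing a BatchEnvelope is outside this changeset.

import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val stream = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
    val manager = DefaultStreamManager(stream)

    // Each record is assigned a monotonically increasing index: 0, 1, 2.
    repeat(3) { manager.countRecordIn() }

    // A checkpoint returns the current index and the record count since the last checkpoint.
    val (checkpointIndex, recordsSinceLast) = manager.markCheckpoint() // (3, 3)

    // End-of-stream freezes the count; any further countRecordIn() throws.
    val totalRecords = manager.countEndOfStream() // 3

    // Closing is only legal once end-of-stream has been counted.
    manager.markClosed()
    manager.awaitStreamClosed()

    println("checkpoint=$checkpointIndex sinceLast=$recordsSinceLast total=$totalRecords closed=${manager.streamIsClosed()}")
}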
@@ -34,7 +34,7 @@ class TeardownTask(
}

/** Ensure we don't run until all streams have completed */
streamsManager.awaitAllStreamsComplete()
streamsManager.awaitAllStreamsClosed()

destination.teardown()
taskLauncher.stop()
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.command

import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import jakarta.inject.Named
import jakarta.inject.Singleton

@Factory
@Replaces(factory = DestinationCatalogFactory::class)
@Requires(env = ["test"])
class MockCatalogFactory : DestinationCatalogFactory {
companion object {
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
}

@Singleton
@Named("mockCatalog")
override fun make(): DestinationCatalog {
return DestinationCatalog(streams = listOf(stream1, stream2))
}
}
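The mock factory is gated to the test environment, @Replaces the production factory, and publishes the catalog as a named bean. A hypothetical test (not part of this commit) would pick it up through the same qualifier; the @field:Named use-site target on the Kotlin property is an assumption here, not something shown in the diff.

import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import jakarta.inject.Named
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

@MicronautTest(environments = ["test"])
class MockCatalogWiringTest {
    // The qualifier matches the name given on MockCatalogFactory.make().
    @Inject @field:Named("mockCatalog") lateinit var catalog: DestinationCatalog

    @Test
    fun `mock catalog exposes both test streams`() {
        assertEquals(
            listOf(MockCatalogFactory.stream1, MockCatalogFactory.stream2),
            catalog.streams
        )
    }
}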
@@ -8,17 +8,17 @@ import com.google.common.collect.Range
import com.google.common.collect.RangeSet
import com.google.common.collect.TreeRangeSet
import io.airbyte.cdk.command.DestinationCatalog
import io.airbyte.cdk.command.DestinationCatalogFactory
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2
import io.airbyte.cdk.message.Batch
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.message.MessageConverter
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Prototype
import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.function.Consumer
import java.util.stream.Stream
@@ -29,25 +29,10 @@ import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource

@MicronautTest
@MicronautTest(environments = ["StateManagerTest"])
class StateManagerTest {
@Inject lateinit var stateManager: TestStateManager

companion object {
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
}

@Factory
@Replaces(factory = DestinationCatalogFactory::class)
class MockCatalogFactory {
@Singleton
@Requires(env = ["test"])
fun make(): DestinationCatalog {
return DestinationCatalog(streams = listOf(stream1, stream2))
}
}

/**
* Test state messages.
*
@@ -95,7 +80,11 @@ class StateManagerTest {
class MockStreamManager : StreamManager {
var persistedRanges: RangeSet<Long> = TreeRangeSet.create()

override fun countRecordIn(sizeBytes: Long): Long {
override fun countRecordIn(): Long {
throw NotImplementedError()
}

override fun countEndOfStream(): Long {
throw NotImplementedError()
}

@@ -129,7 +118,8 @@
}

@Prototype
class MockStreamsManager(catalog: DestinationCatalog) : StreamsManager {
@Requires(env = ["StateManagerTest"])
class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager {
private val mockManagers = catalog.streams.associateWith { MockStreamManager() }

fun addPersistedRanges(stream: DestinationStream, ranges: List<Range<Long>>) {
@@ -141,14 +131,14 @@ class StateManagerTest {
?: throw IllegalArgumentException("Stream not found: $stream")
}

override suspend fun awaitAllStreamsComplete() {
override suspend fun awaitAllStreamsClosed() {
throw NotImplementedError()
}
}

@Prototype
class TestStateManager(
override val catalog: DestinationCatalog,
@Named("mockCatalog") override val catalog: DestinationCatalog,
override val streamsManager: MockStreamsManager,
override val outputFactory: MessageConverter<MockStateIn, MockStateOut>,
override val outputConsumer: MockOutputConsumer
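Both DefaultStreamManager and the test's MockStreamManager reduce "is every record persisted/complete?" to a Guava RangeSet coverage check over half-open record-index intervals. The snippet below is an illustrative refresher on that check, mirroring the Range.closedOpen usage above.

import com.google.common.collect.Range
import com.google.common.collect.TreeRangeSet

fun main() {
    val persisted = TreeRangeSet.create<Long>()
    persisted.add(Range.closedOpen(0L, 5L))   // records 0..4 persisted
    persisted.add(Range.closedOpen(5L, 10L))  // records 5..9 persisted; adjacent ranges coalesce

    // Every record in [0, 10) is covered, so state up to index 10 may be flushed.
    check(persisted.encloses(Range.closedOpen(0L, 10L)))
    // The range [0, 12) is not yet fully covered.
    check(!persisted.encloses(Range.closedOpen(0L, 12L)))
    println("persisted ranges: $persisted")
}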