diff --git a/airbyte-cdk/bulk/core/base/src/main/resources/application.yaml b/airbyte-cdk/bulk/core/base/src/main/resources/application.yaml index dfe35e3a85fb..e6160c5260e0 100644 --- a/airbyte-cdk/bulk/core/base/src/main/resources/application.yaml +++ b/airbyte-cdk/bulk/core/base/src/main/resources/application.yaml @@ -5,3 +5,6 @@ airbyte: resources: disk: bytes: ${CONNECTOR_STORAGE_LIMIT_BYTES:5368709120} # 5GB + flush: + rate-ms: 900000 # 15 minutes + window-ms: 900000 # 15 minutes diff --git a/airbyte-cdk/bulk/core/load/build.gradle b/airbyte-cdk/bulk/core/load/build.gradle index 06836bda3532..7d459a62b3db 100644 --- a/airbyte-cdk/bulk/core/load/build.gradle +++ b/airbyte-cdk/bulk/core/load/build.gradle @@ -41,11 +41,16 @@ def integrationTestTask = tasks.register('integrationTest', Test) { maxParallelForks = project.test.maxParallelForks maxHeapSize = project.test.maxHeapSize } + // These tests are lightweight enough to run on every PR. tasks.named('check').configure { dependsOn integrationTest } +test { + systemProperties(["mockk.junit.extension.requireParallelTesting":"true"]) +} + configurations { integrationTestImplementation.extendsFrom testImplementation integrationTestRuntimeOnly.extendsFrom testRuntimeOnly diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/DestinationMessageQueues.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/DestinationMessageQueues.kt index f8cf4e37feda..457799c1c9d0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/DestinationMessageQueues.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/DestinationMessageQueues.kt @@ -17,61 +17,54 @@ interface Sized { } /** - * Wrapper for record messages published to the message queue, containing metadata like index and - * size. + * Wrapper message for stream events published to the stream specific queues, containing metadata + * like index and size. * * In a future where we deserialize only the info necessary for routing, this could include a dumb * container for the serialized, and deserialization could be deferred until the spooled records * were recovered from disk. */ -sealed class DestinationRecordWrapped : Sized +sealed class DestinationStreamEvent : Sized -sealed class DestinationFileWrapped : Sized - -data class StreamRecordWrapped( +data class StreamRecordEvent( val index: Long, override val sizeBytes: Long, val record: DestinationRecord -) : DestinationRecordWrapped() +) : DestinationStreamEvent() -data class StreamFileWrapped( +data class StreamCompleteEvent( val index: Long, - override val sizeBytes: Long, - val file: DestinationFile -) : DestinationFileWrapped() - -data class StreamRecordCompleteWrapped( - val index: Long, -) : DestinationRecordWrapped() { +) : DestinationStreamEvent() { override val sizeBytes: Long = 0L } -data class StreamFileCompleteWrapped( - val index: Long, -) : DestinationFileWrapped() { +data class StreamFlushEvent( + val tickedAtMs: Long, +) : DestinationStreamEvent() { override val sizeBytes: Long = 0L } -class DestinationRecordQueue : ChannelMessageQueue>() +class DestinationStreamEventQueue : ChannelMessageQueue>() /** - * A supplier of message queues to which ([ReservationManager.reserve]'d) @ - * [DestinationRecordWrapped] messages can be published on a @ [DestinationStream] key. The queues - * themselves do not manage memory. + * A supplier of message queues to which ([ReservationManager.reserve]'d) @ [DestinationStreamEvent] + * messages can be published on a @ [DestinationStream] key. The queues themselves do not manage + * memory. */ @Singleton @Secondary -class DestinationRecordQueueSupplier(catalog: DestinationCatalog) : - MessageQueueSupplier> { - private val queues = ConcurrentHashMap() +class DestinationStreamQueueSupplier(catalog: DestinationCatalog) : + MessageQueueSupplier> { + private val queues = + ConcurrentHashMap() init { - catalog.streams.forEach { queues[it.descriptor] = DestinationRecordQueue() } + catalog.streams.forEach { queues[it.descriptor] = DestinationStreamEventQueue() } } - override fun get(key: DestinationStream.Descriptor): DestinationRecordQueue { + override fun get(key: DestinationStream.Descriptor): DestinationStreamEventQueue { return queues[key] - ?: throw IllegalArgumentException("Reading from non-existent record stream: $key") + ?: throw IllegalArgumentException("Reading from non-existent stream: $key") } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/ReservationManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/ReservationManager.kt index ddfeb69a8f0d..3cc82c78f310 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/ReservationManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/ReservationManager.kt @@ -14,8 +14,8 @@ import kotlinx.coroutines.sync.withLock /** Releasable reservation of memory. */ class Reserved( - private val parentManager: ReservationManager, - val bytesReserved: Long, + private val parentManager: ReservationManager? = null, + val bytesReserved: Long = 0, val value: T, ) : CloseableCoroutine { private var released = AtomicBoolean(false) @@ -24,7 +24,7 @@ class Reserved( if (!released.compareAndSet(false, true)) { return } - parentManager.release(bytesReserved) + parentManager?.release(bytesReserved) } fun replace(value: U): Reserved = Reserved(parentManager, bytesReserved, value) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/TimeWindowTrigger.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/TimeWindowTrigger.kt new file mode 100644 index 000000000000..ca2693a8a355 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/TimeWindowTrigger.kt @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.state + +import java.time.Clock + +/* + * Simple time-windowing strategy for bucketing partial aggregates. + * + * Works off time relative to the injected @param clock. Generally this is the processing time domain. + */ +data class TimeWindowTrigger( + private val clock: Clock, + private val windowWidthMs: Long, +) { + private var openedAtMs: Long? = null + + /* + * Sets window open timestamp for computing completeness. Idempotent. Mutative. + */ + fun open(): Long { + if (openedAtMs == null) { + openedAtMs = clock.millis() + } + return openedAtMs!! + } + + /* + * Returns whether window is complete relative to configured @param windowWidthMs. Non-mutative. + */ + fun isComplete(): Boolean { + return openedAtMs?.let { ts -> (clock.millis() - ts) >= windowWidthMs } ?: false + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt index 349e9d257d4c..b70da5daf690 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt @@ -12,7 +12,7 @@ import io.airbyte.cdk.load.message.BatchEnvelope import io.airbyte.cdk.load.message.CheckpointMessageWrapped import io.airbyte.cdk.load.message.DestinationFile import io.airbyte.cdk.load.message.DestinationMessage -import io.airbyte.cdk.load.message.DestinationRecordWrapped +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.QueueWriter import io.airbyte.cdk.load.state.Reserved @@ -25,6 +25,7 @@ import io.airbyte.cdk.load.task.implementor.ProcessRecordsTaskFactory import io.airbyte.cdk.load.task.implementor.SetupTaskFactory import io.airbyte.cdk.load.task.implementor.TeardownTaskFactory import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory +import io.airbyte.cdk.load.task.internal.FlushTickTask import io.airbyte.cdk.load.task.internal.InputConsumerTaskFactory import io.airbyte.cdk.load.task.internal.SizedInputFlow import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory @@ -98,6 +99,7 @@ class DefaultDestinationTaskLauncher( // Internal Tasks private val inputConsumerTaskFactory: InputConsumerTaskFactory, private val spillToDiskTaskFactory: SpillToDiskTaskFactory, + private val flushTickTask: FlushTickTask, // Implementor Tasks private val setupTaskFactory: SetupTaskFactory, @@ -110,7 +112,7 @@ class DefaultDestinationTaskLauncher( // Checkpoint Tasks private val flushCheckpointsTaskFactory: FlushCheckpointsTaskFactory, - private val timedFlushTask: TimedForcedCheckpointFlushTask, + private val timedCheckpointFlushTask: TimedForcedCheckpointFlushTask, private val updateCheckpointsTask: UpdateCheckpointsTask, // Exception handling @@ -120,7 +122,7 @@ class DefaultDestinationTaskLauncher( // Input Comsumer requirements private val inputFlow: SizedInputFlow>, private val recordQueueSupplier: - MessageQueueSupplier>, + MessageQueueSupplier>, private val checkpointQueue: QueueWriter>, ) : DestinationTaskLauncher { private val log = KotlinLogging.logger {} @@ -163,9 +165,13 @@ class DefaultDestinationTaskLauncher( } } + // Start flush task + log.info { "Starting timed file aggregate flush task " } + enqueue(flushTickTask) + // Start the checkpoint management tasks - log.info { "Starting timed flush task" } - enqueue(timedFlushTask) + log.info { "Starting timed checkpoint flush task" } + enqueue(timedCheckpointFlushTask) log.info { "Starting checkpoint update task" } enqueue(updateCheckpointsTask) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt new file mode 100644 index 000000000000..6c4a6caf83fe --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.task.internal + +import com.google.common.annotations.VisibleForTesting +import io.airbyte.cdk.load.command.DestinationCatalog +import io.airbyte.cdk.load.command.DestinationStream +import io.airbyte.cdk.load.file.TimeProvider +import io.airbyte.cdk.load.message.DestinationStreamEvent +import io.airbyte.cdk.load.message.MessageQueueSupplier +import io.airbyte.cdk.load.message.StreamFlushEvent +import io.airbyte.cdk.load.state.Reserved +import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.SyncLevel +import io.micronaut.context.annotation.Secondary +import io.micronaut.context.annotation.Value +import jakarta.inject.Singleton +import java.time.Clock + +@Singleton +@Secondary +class FlushTickTask( + @Value("\${airbyte.flush.rate-ms}") private val tickIntervalMs: Long, + private val clock: Clock, + private val coroutineTimeUtils: TimeProvider, + private val catalog: DestinationCatalog, + private val recordQueueSupplier: + MessageQueueSupplier>, +) : SyncLevel, KillableScope { + override suspend fun execute() { + while (true) { + waitAndPublishFlushTick() + } + } + + @VisibleForTesting + suspend fun waitAndPublishFlushTick() { + coroutineTimeUtils.delay(tickIntervalMs) + + catalog.streams.forEach { + val queue = recordQueueSupplier.get(it.descriptor) + queue.publish(Reserved(value = StreamFlushEvent(clock.millis()))) + } + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt index 8a11ae1f8a14..a94787b215eb 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt @@ -16,16 +16,16 @@ import io.airbyte.cdk.load.message.DestinationMessage import io.airbyte.cdk.load.message.DestinationRecord import io.airbyte.cdk.load.message.DestinationRecordStreamComplete import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete -import io.airbyte.cdk.load.message.DestinationRecordWrapped import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.GlobalCheckpoint import io.airbyte.cdk.load.message.GlobalCheckpointWrapped import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.QueueWriter import io.airbyte.cdk.load.message.StreamCheckpoint import io.airbyte.cdk.load.message.StreamCheckpointWrapped -import io.airbyte.cdk.load.message.StreamRecordCompleteWrapped -import io.airbyte.cdk.load.message.StreamRecordWrapped +import io.airbyte.cdk.load.message.StreamCompleteEvent +import io.airbyte.cdk.load.message.StreamRecordEvent import io.airbyte.cdk.load.message.Undefined import io.airbyte.cdk.load.state.Reserved import io.airbyte.cdk.load.state.SyncManager @@ -55,7 +55,7 @@ class DefaultInputConsumerTask( private val catalog: DestinationCatalog, private val inputFlow: SizedInputFlow>, private val recordQueueSupplier: - MessageQueueSupplier>, + MessageQueueSupplier>, private val checkpointQueue: QueueWriter>, private val syncManager: SyncManager, private val destinationTaskLauncher: DestinationTaskLauncher, @@ -72,7 +72,7 @@ class DefaultInputConsumerTask( when (val message = reserved.value) { is DestinationRecord -> { val wrapped = - StreamRecordWrapped( + StreamRecordEvent( index = manager.countRecordIn(), sizeBytes = sizeBytes, record = message @@ -81,7 +81,7 @@ class DefaultInputConsumerTask( } is DestinationRecordStreamComplete -> { reserved.release() // safe because multiple calls conflate - val wrapped = StreamRecordCompleteWrapped(index = manager.markEndOfStream()) + val wrapped = StreamCompleteEvent(index = manager.markEndOfStream()) recordQueue.publish(reserved.replace(wrapped)) recordQueue.close() } @@ -179,7 +179,7 @@ interface InputConsumerTaskFactory { catalog: DestinationCatalog, inputFlow: SizedInputFlow>, recordQueueSupplier: - MessageQueueSupplier>, + MessageQueueSupplier>, checkpointQueue: QueueWriter>, destinationTaskLauncher: DestinationTaskLauncher, ): InputConsumerTask @@ -193,7 +193,7 @@ class DefaultInputConsumerTaskFactory(private val syncManager: SyncManager) : catalog: DestinationCatalog, inputFlow: SizedInputFlow>, recordQueueSupplier: - MessageQueueSupplier>, + MessageQueueSupplier>, checkpointQueue: QueueWriter>, destinationTaskLauncher: DestinationTaskLauncher, ): InputConsumerTask { diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt index 534cd58e758a..dd2f6f0804c0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt @@ -7,14 +7,16 @@ package io.airbyte.cdk.load.task.internal import com.google.common.collect.Range import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.file.SpillFileProvider -import io.airbyte.cdk.load.message.DestinationRecordWrapped +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.QueueReader -import io.airbyte.cdk.load.message.StreamRecordCompleteWrapped -import io.airbyte.cdk.load.message.StreamRecordWrapped +import io.airbyte.cdk.load.message.StreamCompleteEvent +import io.airbyte.cdk.load.message.StreamFlushEvent +import io.airbyte.cdk.load.message.StreamRecordEvent import io.airbyte.cdk.load.state.FlushStrategy import io.airbyte.cdk.load.state.ReservationManager import io.airbyte.cdk.load.state.Reserved +import io.airbyte.cdk.load.state.TimeWindowTrigger import io.airbyte.cdk.load.task.DestinationTaskLauncher import io.airbyte.cdk.load.task.InternalScope import io.airbyte.cdk.load.task.StreamLevel @@ -23,9 +25,11 @@ import io.airbyte.cdk.load.util.use import io.airbyte.cdk.load.util.withNextAdjacentValue import io.airbyte.cdk.load.util.write import io.github.oshai.kotlinlogging.KotlinLogging +import io.micronaut.context.annotation.Value import jakarta.inject.Named import jakarta.inject.Singleton import java.nio.file.Path +import java.time.Clock import kotlin.io.path.outputStream import kotlinx.coroutines.flow.last import kotlinx.coroutines.flow.runningFold @@ -40,11 +44,12 @@ interface SpillToDiskTask : StreamLevel, InternalScope */ class DefaultSpillToDiskTask( private val spillFileProvider: SpillFileProvider, - private val queue: QueueReader>, + private val queue: QueueReader>, private val flushStrategy: FlushStrategy, override val streamDescriptor: DestinationStream.Descriptor, private val launcher: DestinationTaskLauncher, private val diskManager: ReservationManager, + private val timeWindow: TimeWindowTrigger, ) : SpillToDiskTask { private val log = KotlinLogging.logger {} @@ -64,7 +69,11 @@ class DefaultSpillToDiskTask( .runningFold(ReadResult()) { (range, sizeBytes, _), reserved -> reserved.use { when (val wrapped = it.value) { - is StreamRecordWrapped -> { + is StreamRecordEvent -> { + // once we have received a record for the stream, consider the + // aggregate opened. + timeWindow.open() + // reserve enough room for the record diskManager.reserve(wrapped.sizeBytes) @@ -87,10 +96,14 @@ class DefaultSpillToDiskTask( forceFlush = forceFlush ) } - is StreamRecordCompleteWrapped -> { + is StreamCompleteEvent -> { val nextRange = range.withNextAdjacentValue(wrapped.index) ReadResult(nextRange, sizeBytes, hasReadEndOfStream = true) } + is StreamFlushEvent -> { + val forceFlush = timeWindow.isComplete() + ReadResult(range, sizeBytes, forceFlush = forceFlush) + } } } } @@ -125,14 +138,18 @@ interface SpillToDiskTaskFactory { class DefaultSpillToDiskTaskFactory( private val spillFileProvider: SpillFileProvider, private val queueSupplier: - MessageQueueSupplier>, + MessageQueueSupplier>, private val flushStrategy: FlushStrategy, @Named("diskManager") private val diskManager: ReservationManager, + private val clock: Clock, + @Value("\${airbyte.flush.window-ms}") private val windowWidthMs: Long, ) : SpillToDiskTaskFactory { override fun make( taskLauncher: DestinationTaskLauncher, stream: DestinationStream.Descriptor ): SpillToDiskTask { + val timeWindow = TimeWindowTrigger(clock, windowWidthMs) + return DefaultSpillToDiskTask( spillFileProvider, queueSupplier.get(stream), @@ -140,6 +157,7 @@ class DefaultSpillToDiskTaskFactory( stream, taskLauncher, diskManager, + timeWindow, ) } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/state/TimeWindowTriggerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/state/TimeWindowTriggerTest.kt new file mode 100644 index 000000000000..86c07486b0bc --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/state/TimeWindowTriggerTest.kt @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.state + +import io.airbyte.cdk.load.state.TimeWindowTriggerTest.Fixtures.TIME_WINDOW_WIDTH_MS +import io.mockk.every +import io.mockk.impl.annotations.MockK +import io.mockk.junit5.MockKExtension +import java.time.Clock +import java.util.stream.Stream +import kotlin.test.assertEquals +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource + +@ExtendWith(MockKExtension::class) +class TimeWindowTriggerTest { + @MockK lateinit var clock: Clock + + private lateinit var timeWindow: TimeWindowTrigger + + @BeforeEach + fun setup() { + timeWindow = TimeWindowTrigger(clock, TIME_WINDOW_WIDTH_MS) + } + + @Test + fun `open is idempotent`() { + val initialOpenedAt = 1000L + + every { clock.millis() } returns initialOpenedAt + val openedAt1 = timeWindow.open() + + assertEquals(initialOpenedAt, openedAt1) + + every { clock.millis() } returns initialOpenedAt + 1 + val openedAt2 = timeWindow.open() + + assertEquals(initialOpenedAt, openedAt2) + } + + @Test + fun `isComplete returns false if window not opened`() { + every { clock.millis() } returns TIME_WINDOW_WIDTH_MS + assertEquals(false, timeWindow.isComplete()) + + every { clock.millis() } returns TIME_WINDOW_WIDTH_MS + 1 + assertEquals(false, timeWindow.isComplete()) + + every { clock.millis() } returns TIME_WINDOW_WIDTH_MS + 60000 + assertEquals(false, timeWindow.isComplete()) + } + + @ParameterizedTest + @MethodSource("windowWidthMatrix") + fun `isComplete calculates time window based on configured width`(windowWidthMs: Long) { + every { clock.millis() } returns 0 + + timeWindow = TimeWindowTrigger(clock, windowWidthMs) + + val openedAt = timeWindow.open() + assertEquals(0, openedAt) + + every { clock.millis() } returns windowWidthMs - 1 + assertEquals(false, timeWindow.isComplete()) + + every { clock.millis() } returns windowWidthMs - 124 + assertEquals(false, timeWindow.isComplete()) + + every { clock.millis() } returns windowWidthMs + assertEquals(true, timeWindow.isComplete()) + + every { clock.millis() } returns windowWidthMs + 1 + assertEquals(true, timeWindow.isComplete()) + + every { clock.millis() } returns windowWidthMs + 60000 + assertEquals(true, timeWindow.isComplete()) + } + + object Fixtures { + const val TIME_WINDOW_WIDTH_MS = 60000L + } + + companion object { + @JvmStatic + private fun windowWidthMatrix(): Stream { + return Stream.of( + Arguments.of(100L), + Arguments.of(10000L), + Arguments.of(900000L), + Arguments.of(900001L), + ) + } + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt index 47a2167eca1d..9ae9af677852 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt @@ -13,7 +13,7 @@ import io.airbyte.cdk.load.message.Batch import io.airbyte.cdk.load.message.BatchEnvelope import io.airbyte.cdk.load.message.CheckpointMessageWrapped import io.airbyte.cdk.load.message.DestinationMessage -import io.airbyte.cdk.load.message.DestinationRecordWrapped +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.MessageQueue import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.QueueWriter @@ -40,6 +40,7 @@ import io.airbyte.cdk.load.task.implementor.TeardownTaskFactory import io.airbyte.cdk.load.task.internal.DefaultSpillToDiskTaskFactory import io.airbyte.cdk.load.task.internal.FlushCheckpointsTask import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory +import io.airbyte.cdk.load.task.internal.FlushTickTask import io.airbyte.cdk.load.task.internal.InputConsumerTask import io.airbyte.cdk.load.task.internal.InputConsumerTaskFactory import io.airbyte.cdk.load.task.internal.SizedInputFlow @@ -61,7 +62,6 @@ import kotlinx.coroutines.channels.toList import kotlinx.coroutines.delay import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test @@ -96,6 +96,12 @@ class DestinationTaskLauncherTest where T : LeveledTask, T : ScopedTask { @Inject lateinit var inputFlow: MockInputFlow @Inject lateinit var queueWriter: MockQueueWriter @Inject lateinit var messageQueueSupplier: MockMessageQueueSupplier + @Inject lateinit var flushTickTask: FlushTickTask + + @Singleton + @Primary + @Requires(env = ["DestinationTaskLauncherTest"]) + fun flushTickTask(): FlushTickTask = mockk(relaxed = true) @Singleton @Primary @@ -119,10 +125,10 @@ class DestinationTaskLauncherTest where T : LeveledTask, T : ScopedTask { @Primary @Requires(env = ["DestinationTaskLauncherTest"]) class MockMessageQueueSupplier : - MessageQueueSupplier> { + MessageQueueSupplier> { override fun get( key: DestinationStream.Descriptor - ): MessageQueue> { + ): MessageQueue> { return mockk() } } @@ -138,7 +144,7 @@ class DestinationTaskLauncherTest where T : LeveledTask, T : ScopedTask { inputFlow: SizedInputFlow>, recordQueueSupplier: MessageQueueSupplier< - DestinationStream.Descriptor, Reserved>, + DestinationStream.Descriptor, Reserved>, checkpointQueue: QueueWriter>, destinationTaskLauncher: DestinationTaskLauncher ): InputConsumerTask { @@ -340,7 +346,7 @@ class DestinationTaskLauncherTest where T : LeveledTask, T : ScopedTask { } override suspend fun withExceptionHandling(task: T): WrappedTask { - runBlocking { wrappedTasks.send(task) } + wrappedTasks.send(task) val innerTask = object : InternalScope { override suspend fun execute() { diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt index 3d2ab2492b1c..ee86d3bbba7d 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt @@ -8,7 +8,7 @@ import io.airbyte.cdk.load.command.DestinationCatalog import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.message.CheckpointMessageWrapped import io.airbyte.cdk.load.message.DestinationMessage -import io.airbyte.cdk.load.message.DestinationRecordWrapped +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.QueueWriter import io.airbyte.cdk.load.state.Reserved @@ -22,6 +22,7 @@ import io.airbyte.cdk.load.task.implementor.ProcessRecordsTaskFactory import io.airbyte.cdk.load.task.implementor.SetupTaskFactory import io.airbyte.cdk.load.task.implementor.TeardownTaskFactory import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory +import io.airbyte.cdk.load.task.internal.FlushTickTask import io.airbyte.cdk.load.task.internal.InputConsumerTaskFactory import io.airbyte.cdk.load.task.internal.SizedInputFlow import io.airbyte.cdk.load.task.internal.SpillToDiskTask @@ -46,6 +47,7 @@ class DestinationTaskLauncherUTest { // Internal Tasks private val inputConsumerTaskFactory: InputConsumerTaskFactory = mockk(relaxed = true) private val spillToDiskTaskFactory: SpillToDiskTaskFactory = mockk(relaxed = true) + private val flushTickTask: FlushTickTask = mockk(relaxed = true) // Implementor Tasks private val setupTaskFactory: SetupTaskFactory = mockk(relaxed = true) @@ -68,7 +70,7 @@ class DestinationTaskLauncherUTest { // Input Comsumer requirements private val inputFlow: SizedInputFlow> = mockk(relaxed = true) private val recordQueueSupplier: - MessageQueueSupplier> = + MessageQueueSupplier> = mockk(relaxed = true) private val checkpointQueue: QueueWriter> = mockk(relaxed = true) @@ -81,6 +83,7 @@ class DestinationTaskLauncherUTest { syncManager, inputConsumerTaskFactory, spillToDiskTaskFactory, + flushTickTask, setupTaskFactory, openStreamTaskFactory, processRecordsTaskFactory, diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTaskTest.kt new file mode 100644 index 000000000000..31b24d2baf54 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTaskTest.kt @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.task.internal + +import io.airbyte.cdk.load.command.Append +import io.airbyte.cdk.load.command.DestinationCatalog +import io.airbyte.cdk.load.command.DestinationStream +import io.airbyte.cdk.load.data.FieldType +import io.airbyte.cdk.load.data.IntegerType +import io.airbyte.cdk.load.data.ObjectType +import io.airbyte.cdk.load.data.StringType +import io.airbyte.cdk.load.file.TimeProvider +import io.airbyte.cdk.load.message.DestinationStreamEvent +import io.airbyte.cdk.load.message.MessageQueue +import io.airbyte.cdk.load.message.MessageQueueSupplier +import io.airbyte.cdk.load.message.StreamFlushEvent +import io.airbyte.cdk.load.state.Reserved +import io.mockk.coVerify +import io.mockk.every +import io.mockk.impl.annotations.MockK +import io.mockk.junit5.MockKExtension +import io.mockk.mockk +import io.mockk.slot +import java.time.Clock +import java.util.stream.Stream +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource + +@ExtendWith(MockKExtension::class) +class FlushTickTaskTest { + @MockK(relaxed = true) lateinit var clock: Clock + @MockK(relaxed = true) lateinit var coroutineTimeUtils: TimeProvider + @MockK(relaxed = true) lateinit var catalog: DestinationCatalog + @MockK(relaxed = true) + lateinit var recordQueueSupplier: + MessageQueueSupplier> + + private val tickIntervalMs = 60000L // 1 min + + private lateinit var task: FlushTickTask + + @BeforeEach + fun setup() { + task = + FlushTickTask( + tickIntervalMs, + clock, + coroutineTimeUtils, + catalog, + recordQueueSupplier, + ) + } + + @Test + fun `waits for the configured amount of time`() = runTest { + task.waitAndPublishFlushTick() + + coVerify { coroutineTimeUtils.delay(tickIntervalMs) } + } + + @ParameterizedTest + @MethodSource("streamMatrix") + fun `publishes a flush message for each stream in the catalog`( + streams: List + ) = runTest { + every { catalog.streams } returns streams + val queues = + streams.associateWith { + mockk>>(relaxed = true) + } + + streams.forEach { + every { recordQueueSupplier.get(eq(it.descriptor)) } returns queues[it]!! + } + + task.waitAndPublishFlushTick() + + streams.forEach { + val msgSlot = slot>() + coVerify { queues[it]!!.publish(capture(msgSlot)) } + assert(msgSlot.captured.value is StreamFlushEvent) + } + } + + companion object { + @JvmStatic + fun streamMatrix(): Stream { + return Stream.of( + Arguments.of(listOf(Fixtures.stream1)), + Arguments.of(listOf(Fixtures.stream1, Fixtures.stream2)), + Arguments.of(listOf(Fixtures.stream1, Fixtures.stream3)), + Arguments.of(listOf(Fixtures.stream2, Fixtures.stream3)), + Arguments.of(listOf(Fixtures.stream1, Fixtures.stream2, Fixtures.stream3)), + ) + } + } + + object Fixtures { + val stream1 = + DestinationStream( + DestinationStream.Descriptor("test", "stream1"), + importType = Append, + schema = + ObjectType( + properties = + linkedMapOf( + "id" to FieldType(type = IntegerType, nullable = true), + "name" to FieldType(type = StringType, nullable = true), + ), + ), + generationId = 1, + minimumGenerationId = 0, + syncId = 42, + ) + val stream2 = + DestinationStream( + DestinationStream.Descriptor("test", "stream2"), + importType = Append, + schema = + ObjectType( + properties = + linkedMapOf( + "id" to FieldType(type = IntegerType, nullable = true), + "name" to FieldType(type = StringType, nullable = true), + ), + ), + generationId = 3, + minimumGenerationId = 0, + syncId = 42, + ) + val stream3 = + DestinationStream( + DestinationStream.Descriptor(null, "stream3"), + importType = Append, + schema = + ObjectType( + properties = + linkedMapOf( + "id" to FieldType(type = IntegerType, nullable = true), + "name" to FieldType(type = StringType, nullable = true), + ), + ), + generationId = 9, + minimumGenerationId = 0, + syncId = 42, + ) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTaskTest.kt index ed24b5100345..0c056ee8f5ab 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTaskTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTaskTest.kt @@ -4,33 +4,23 @@ package io.airbyte.cdk.load.task.internal -import com.fasterxml.jackson.databind.node.JsonNodeFactory import io.airbyte.cdk.load.command.DestinationConfiguration import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.MockDestinationCatalogFactory -import io.airbyte.cdk.load.data.NullValue -import io.airbyte.cdk.load.message.CheckpointMessage import io.airbyte.cdk.load.message.CheckpointMessageWrapped -import io.airbyte.cdk.load.message.DestinationFile -import io.airbyte.cdk.load.message.DestinationFileStreamComplete -import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete import io.airbyte.cdk.load.message.DestinationMessage -import io.airbyte.cdk.load.message.DestinationRecord -import io.airbyte.cdk.load.message.DestinationRecordStreamComplete -import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete -import io.airbyte.cdk.load.message.DestinationRecordWrapped -import io.airbyte.cdk.load.message.GlobalCheckpoint +import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.GlobalCheckpointWrapped import io.airbyte.cdk.load.message.MessageQueue import io.airbyte.cdk.load.message.MessageQueueSupplier -import io.airbyte.cdk.load.message.StreamCheckpoint import io.airbyte.cdk.load.message.StreamCheckpointWrapped -import io.airbyte.cdk.load.message.StreamRecordCompleteWrapped -import io.airbyte.cdk.load.message.StreamRecordWrapped +import io.airbyte.cdk.load.message.StreamCompleteEvent +import io.airbyte.cdk.load.message.StreamRecordEvent import io.airbyte.cdk.load.state.ReservationManager import io.airbyte.cdk.load.state.Reserved import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.test.util.CoroutineTestUtils +import io.airbyte.cdk.load.test.util.StubDestinationMessageFactory import io.airbyte.cdk.load.util.takeUntilInclusive import io.micronaut.context.annotation.Primary import io.micronaut.context.annotation.Requires @@ -62,7 +52,7 @@ class InputConsumerTaskTest { @Inject lateinit var taskFactory: InputConsumerTaskFactory @Inject lateinit var recordQueueSupplier: - MessageQueueSupplier> + MessageQueueSupplier> @Inject lateinit var checkpointQueue: MessageQueue> @Inject lateinit var syncManager: SyncManager @Inject lateinit var mockInputFlow: MockInputFlow @@ -93,65 +83,6 @@ class InputConsumerTaskTest { } } - private fun makeRecord(stream: DestinationStream, record: String): DestinationRecord { - return DestinationRecord( - stream = stream.descriptor, - data = NullValue, - emittedAtMs = 0, - meta = null, - serialized = record - ) - } - - private val nullFileMessage = DestinationFile.AirbyteRecordMessageFile() - - private fun makeFile(stream: DestinationStream, record: String): DestinationFile { - return DestinationFile( - stream = stream.descriptor, - emittedAtMs = 0, - serialized = record, - fileMessage = nullFileMessage, - ) - } - - private fun makeStreamComplete(stream: DestinationStream): DestinationRecordStreamComplete { - return DestinationRecordStreamComplete(stream = stream.descriptor, emittedAtMs = 0) - } - - private fun makeFileStreamComplete(stream: DestinationStream): DestinationFileStreamComplete { - return DestinationFileStreamComplete(stream = stream.descriptor, emittedAtMs = 0) - } - - private fun makeStreamIncomplete(stream: DestinationStream): DestinationRecordStreamIncomplete { - return DestinationRecordStreamIncomplete(stream = stream.descriptor, emittedAtMs = 0) - } - - private fun makeFileStreamIncomplete( - stream: DestinationStream - ): DestinationFileStreamIncomplete { - return DestinationFileStreamIncomplete(stream = stream.descriptor, emittedAtMs = 0) - } - - private fun makeStreamState(stream: DestinationStream, recordCount: Long): CheckpointMessage { - return StreamCheckpoint( - checkpoint = - CheckpointMessage.Checkpoint( - stream.descriptor, - JsonNodeFactory.instance.objectNode() - ), - sourceStats = CheckpointMessage.Stats(recordCount), - ) - } - - private fun makeGlobalState(recordCount: Long): CheckpointMessage { - return GlobalCheckpoint( - state = JsonNodeFactory.instance.objectNode(), - sourceStats = CheckpointMessage.Stats(recordCount), - checkpoints = emptyList(), - additionalProperties = emptyMap(), - ) - } - @Test fun testSendRecords() = runTest { val queue1 = recordQueueSupplier.get(MockDestinationCatalogFactory.stream1.descriptor) @@ -164,12 +95,19 @@ class InputConsumerTaskTest { (0 until 10).forEach { mockInputFlow.addMessage( - makeRecord(MockDestinationCatalogFactory.stream1, "test${it}"), + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream1, + "test${it}" + ), it * 2L ) } - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream1)) - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream2)) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream1) + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream2) + ) val task = taskFactory.make( @@ -184,33 +122,34 @@ class InputConsumerTaskTest { val messages1 = queue1 .consume() - .takeUntilInclusive { - (it.value as StreamRecordWrapped).record.serialized == "test9" - } + .takeUntilInclusive { (it.value as StreamRecordEvent).record.serialized == "test9" } .toList() Assertions.assertEquals(10, messages1.size) val expectedRecords = (0 until 10).map { - StreamRecordWrapped( + StreamRecordEvent( it.toLong(), it * 2L, - makeRecord(MockDestinationCatalogFactory.stream1, "test${it}") + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream1, + "test${it}" + ) ) } - val streamComplete1: Reserved = + val streamComplete1: Reserved = queue1.consume().take(1).toList().first() - val streamComplete2: Reserved = + val streamComplete2: Reserved = queue2.consume().take(1).toList().first() Assertions.assertEquals(expectedRecords, messages1.map { it.value }) Assertions.assertEquals(expectedRecords.map { _ -> 1L }, messages1.map { it.bytesReserved }) - Assertions.assertEquals(StreamRecordCompleteWrapped(10), streamComplete1.value) + Assertions.assertEquals(StreamCompleteEvent(10), streamComplete1.value) Assertions.assertEquals(1, streamComplete1.bytesReserved) Assertions.assertEquals(10L, manager1.recordCount()) - Assertions.assertEquals(emptyList(), queue1.consume().toList()) - Assertions.assertEquals(StreamRecordCompleteWrapped(0), streamComplete2.value) - Assertions.assertEquals(emptyList(), queue2.consume().toList()) + Assertions.assertEquals(emptyList(), queue1.consume().toList()) + Assertions.assertEquals(StreamCompleteEvent(0), streamComplete2.value) + Assertions.assertEquals(emptyList(), queue2.consume().toList()) Assertions.assertEquals(0L, manager2.recordCount()) mockInputFlow.stop() } @@ -227,14 +166,26 @@ class InputConsumerTaskTest { (0 until 10).forEach { _ -> mockInputFlow.addMessage( - makeRecord(MockDestinationCatalogFactory.stream1, "whatever"), + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream1, + "whatever" + ), 0L ) } - mockInputFlow.addMessage(makeRecord(MockDestinationCatalogFactory.stream2, "test"), 1L) - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream1), 0L) - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream2), 0L) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeRecord(MockDestinationCatalogFactory.stream2, "test"), + 1L + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream1), + 0L + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream2), + 0L + ) val task = taskFactory.make( mockCatalogFactory.make(), @@ -249,12 +200,15 @@ class InputConsumerTaskTest { queue2.close() Assertions.assertEquals( listOf( - StreamRecordWrapped( + StreamRecordEvent( 0, 1L, - makeRecord(MockDestinationCatalogFactory.stream2, "test") + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream2, + "test" + ) ), - StreamRecordCompleteWrapped(1) + StreamCompleteEvent(1) ), queue2.consume().toList().map { it.value } ) @@ -266,7 +220,7 @@ class InputConsumerTaskTest { queue1.close() val messages1 = queue1.consume().toList() Assertions.assertEquals(11, messages1.size) - Assertions.assertEquals(messages1[10].value, StreamRecordCompleteWrapped(10)) + Assertions.assertEquals(messages1[10].value, StreamCompleteEvent(10)) Assertions.assertEquals( mockInputFlow.initialMemory - 11, mockInputFlow.memoryManager.remainingCapacityBytes, @@ -300,15 +254,27 @@ class InputConsumerTaskTest { ) launch { task.execute() } batches.forEach { (stream, count, expectedCount) -> - repeat(count) { mockInputFlow.addMessage(makeRecord(stream, "test"), 1L) } - mockInputFlow.addMessage(makeStreamState(stream, count.toLong()), 0L) + repeat(count) { + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeRecord(stream, "test"), + 1L + ) + } + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamState(stream, count.toLong()), + 0L + ) val state = checkpointQueue.consume().take(1).toList().first().value as StreamCheckpointWrapped Assertions.assertEquals(expectedCount, state.index) Assertions.assertEquals(count.toLong(), state.checkpoint.destinationStats?.recordCount) } - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream1)) - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream2)) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream1) + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream2) + ) mockInputFlow.stop() } @@ -347,11 +313,17 @@ class InputConsumerTaskTest { when (event) { is AddRecords -> { repeat(event.count) { - mockInputFlow.addMessage(makeRecord(event.stream, "test"), 1L) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeRecord(event.stream, "test"), + 1L + ) } } is SendState -> { - mockInputFlow.addMessage(makeGlobalState(event.expectedStream1Count), 0L) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeGlobalState(event.expectedStream1Count), + 0L + ) val state = checkpointQueue.consume().take(1).toList().first().value as GlobalCheckpointWrapped @@ -372,15 +344,27 @@ class InputConsumerTaskTest { } } } - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream1)) - mockInputFlow.addMessage(makeStreamComplete(MockDestinationCatalogFactory.stream2)) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream1) + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamComplete(MockDestinationCatalogFactory.stream2) + ) mockInputFlow.stop() } @Test fun testStreamIncompleteThrows() = runTest { - mockInputFlow.addMessage(makeRecord(MockDestinationCatalogFactory.stream1, "test"), 1L) - mockInputFlow.addMessage(makeStreamIncomplete(MockDestinationCatalogFactory.stream1), 0L) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeRecord(MockDestinationCatalogFactory.stream1, "test"), + 1L + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeStreamIncomplete( + MockDestinationCatalogFactory.stream1 + ), + 0L + ) val task = taskFactory.make( mockCatalogFactory.make(), @@ -395,9 +379,14 @@ class InputConsumerTaskTest { @Test fun testFileStreamIncompleteThrows() = runTest { - mockInputFlow.addMessage(makeFile(MockDestinationCatalogFactory.stream1, "test"), 1L) mockInputFlow.addMessage( - makeFileStreamIncomplete(MockDestinationCatalogFactory.stream1), + StubDestinationMessageFactory.makeFile(MockDestinationCatalogFactory.stream1, "test"), + 1L + ) + mockInputFlow.addMessage( + StubDestinationMessageFactory.makeFileStreamIncomplete( + MockDestinationCatalogFactory.stream1 + ), 0L ) val task = diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTaskTest.kt index 23856b65d812..30241313200a 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTaskTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTaskTest.kt @@ -7,147 +7,274 @@ package io.airbyte.cdk.load.task.internal import com.google.common.collect.Range import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.command.MockDestinationCatalogFactory +import io.airbyte.cdk.load.command.MockDestinationConfiguration import io.airbyte.cdk.load.data.NullValue +import io.airbyte.cdk.load.file.DefaultSpillFileProvider import io.airbyte.cdk.load.file.SpillFileProvider import io.airbyte.cdk.load.message.DestinationRecord -import io.airbyte.cdk.load.message.DestinationRecordWrapped +import io.airbyte.cdk.load.message.DestinationStreamEvent +import io.airbyte.cdk.load.message.DestinationStreamEventQueue +import io.airbyte.cdk.load.message.DestinationStreamQueueSupplier import io.airbyte.cdk.load.message.MessageQueueSupplier -import io.airbyte.cdk.load.message.StreamRecordCompleteWrapped -import io.airbyte.cdk.load.message.StreamRecordWrapped +import io.airbyte.cdk.load.message.StreamCompleteEvent +import io.airbyte.cdk.load.message.StreamFlushEvent +import io.airbyte.cdk.load.message.StreamRecordEvent import io.airbyte.cdk.load.state.FlushStrategy import io.airbyte.cdk.load.state.ReservationManager import io.airbyte.cdk.load.state.Reserved +import io.airbyte.cdk.load.state.TimeWindowTrigger +import io.airbyte.cdk.load.task.DestinationTaskLauncher import io.airbyte.cdk.load.task.MockTaskLauncher +import io.airbyte.cdk.load.test.util.StubDestinationMessageFactory import io.airbyte.cdk.load.util.lineSequence -import io.micronaut.test.extensions.junit5.annotation.MicronautTest -import jakarta.inject.Inject +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.impl.annotations.MockK +import io.mockk.junit5.MockKExtension +import io.mockk.mockk +import java.time.Clock import kotlin.io.path.inputStream import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith -@MicronautTest( - environments = - [ - "MockDestinationConfiguration", - "MockDestinationCatalog", - ] -) class SpillToDiskTaskTest { - private lateinit var memoryManager: ReservationManager - private lateinit var diskManager: ReservationManager - private lateinit var spillToDiskTaskFactory: DefaultSpillToDiskTaskFactory - @Inject - lateinit var queueSupplier: - MessageQueueSupplier> - @Inject lateinit var spillFileProvider: SpillFileProvider - - @BeforeEach - fun setup() { - memoryManager = ReservationManager(Fixtures.INITIAL_MEMORY_CAPACITY) - diskManager = ReservationManager(Fixtures.INITIAL_DISK_CAPACITY) - spillToDiskTaskFactory = - DefaultSpillToDiskTaskFactory( - spillFileProvider, - queueSupplier, - MockFlushStrategy(), - diskManager, - ) - } + /** Validates task delegates to dependencies as expected. Does not test dependency behavior. */ + @Nested + @ExtendWith(MockKExtension::class) + inner class UnitTests { + @MockK(relaxed = true) lateinit var spillFileProvider: SpillFileProvider + + @MockK(relaxed = true) lateinit var flushStrategy: FlushStrategy + + @MockK(relaxed = true) lateinit var taskLauncher: DestinationTaskLauncher + + @MockK(relaxed = true) lateinit var timeWindow: TimeWindowTrigger + + @MockK(relaxed = true) lateinit var diskManager: ReservationManager + + private lateinit var inputQueue: DestinationStreamEventQueue - class MockFlushStrategy : FlushStrategy { - override suspend fun shouldFlush( - stream: DestinationStream.Descriptor, - rangeRead: Range, - bytesProcessed: Long - ): Boolean { - return bytesProcessed >= 1024 + private lateinit var task: DefaultSpillToDiskTask + + @BeforeEach + fun setup() { + inputQueue = DestinationStreamEventQueue() + task = + DefaultSpillToDiskTask( + spillFileProvider, + inputQueue, + flushStrategy, + MockDestinationCatalogFactory.stream1.descriptor, + taskLauncher, + diskManager, + timeWindow, + ) } - } - private suspend fun primeMessageQueue(): Long { - val queue = queueSupplier.get(MockDestinationCatalogFactory.stream1.descriptor) - val maxRecords = ((1024 * 1.5) / 8).toLong() - var recordsWritten = 0L - while (recordsWritten < maxRecords) { - val index = recordsWritten++ - queue.publish( - memoryManager.reserve( - Fixtures.MEMORY_RESERVATION_SIZE_BYTES, - StreamRecordWrapped( - index = index, - sizeBytes = Fixtures.SERIALIZED_SIZE_BYTES, - record = - DestinationRecord( - stream = MockDestinationCatalogFactory.stream1.descriptor, - data = NullValue, - emittedAtMs = 0, - meta = null, - serialized = "test${index}" - ) + @Test + fun `publishes 'spilled file' aggregates according to flush strategy on stream record`() = + runTest { + val recordMsg = + StreamRecordEvent( + 3L, + 2L, + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream1, + "test 3", + ), ) + // flush strategy returns true, so we flush + coEvery { flushStrategy.shouldFlush(any(), any(), any()) } returns true + inputQueue.publish(Reserved(value = recordMsg)) + + task.execute() + coVerify(exactly = 1) { taskLauncher.handleNewSpilledFile(any(), any()) } + } + + @Test + fun `publishes 'spilled file' aggregates on stream complete event`() = runTest { + val completeMsg = StreamCompleteEvent(0L) + inputQueue.publish(Reserved(value = completeMsg)) + + task.execute() + coVerify(exactly = 1) { taskLauncher.handleNewSpilledFile(any(), any()) } + } + + @Test + fun `publishes 'spilled file' aggregates according to time window on stream flush event`() = + runTest { + // flush strategy returns false, so it won't flush + coEvery { flushStrategy.shouldFlush(any(), any(), any()) } returns false + every { timeWindow.isComplete() } returns true + + val flushMsg = StreamFlushEvent(101L) + val recordMsg = + StreamRecordEvent( + 3L, + 2L, + StubDestinationMessageFactory.makeRecord( + MockDestinationCatalogFactory.stream1, + "test 3", + ), + ) + + // must publish 1 record message so range isn't empty + inputQueue.publish(Reserved(value = recordMsg)) + inputQueue.publish(Reserved(value = flushMsg)) + + task.execute() + coVerify(exactly = 1) { taskLauncher.handleNewSpilledFile(any(), any()) } + } + } + + /** + * Validates end to end behaviors including those of dependencies. Also exercises the factory. + */ + @Nested + inner class EndToEndTests { + private lateinit var memoryManager: ReservationManager + private lateinit var diskManager: ReservationManager + private lateinit var spillToDiskTaskFactory: DefaultSpillToDiskTaskFactory + private lateinit var taskLauncher: MockTaskLauncher + private val clock: Clock = mockk(relaxed = true) + private val flushWindowMs = 60000L + + private lateinit var queueSupplier: + MessageQueueSupplier> + private lateinit var spillFileProvider: SpillFileProvider + + @BeforeEach + fun setup() { + spillFileProvider = DefaultSpillFileProvider(MockDestinationConfiguration()) + queueSupplier = + DestinationStreamQueueSupplier( + MockDestinationCatalogFactory().make(), ) + taskLauncher = MockTaskLauncher() + memoryManager = ReservationManager(Fixtures.INITIAL_MEMORY_CAPACITY) + diskManager = ReservationManager(Fixtures.INITIAL_DISK_CAPACITY) + spillToDiskTaskFactory = + DefaultSpillToDiskTaskFactory( + spillFileProvider, + queueSupplier, + MockFlushStrategy(), + diskManager, + clock, + flushWindowMs, + ) + } + + @Test + fun `writes aggregates to files and manages disk and memory reservations`() = runTest { + val messageCount = primeMessageQueue() + val bytesReservedMemory = Fixtures.MEMORY_RESERVATION_SIZE_BYTES * messageCount + val bytesReservedDisk = Fixtures.SERIALIZED_SIZE_BYTES * messageCount + + // memory manager has reserved bytes for messages + Assertions.assertEquals( + Fixtures.INITIAL_MEMORY_CAPACITY - bytesReservedMemory, + memoryManager.remainingCapacityBytes, + ) + // disk manager has not reserved any bytes + Assertions.assertEquals( + Fixtures.INITIAL_DISK_CAPACITY, + diskManager.remainingCapacityBytes, + ) + + spillToDiskTaskFactory + .make(taskLauncher, MockDestinationCatalogFactory.stream1.descriptor) + .execute() + Assertions.assertEquals(1, taskLauncher.spilledFiles.size) + spillToDiskTaskFactory + .make(taskLauncher, MockDestinationCatalogFactory.stream1.descriptor) + .execute() + Assertions.assertEquals(2, taskLauncher.spilledFiles.size) + + Assertions.assertEquals(1024, taskLauncher.spilledFiles[0].totalSizeBytes) + Assertions.assertEquals(512, taskLauncher.spilledFiles[1].totalSizeBytes) + + val spilled1 = taskLauncher.spilledFiles[0] + val spilled2 = taskLauncher.spilledFiles[1] + Assertions.assertEquals(1024, spilled1.totalSizeBytes) + Assertions.assertEquals(512, spilled2.totalSizeBytes) + + val file1 = spilled1.localFile + val file2 = spilled2.localFile + + val expectedLinesFirst = (0 until 1024 / 8).flatMap { listOf("test$it") } + val expectedLinesSecond = (1024 / 8 until 1536 / 8).flatMap { listOf("test$it") } + + Assertions.assertEquals( + expectedLinesFirst, + file1.inputStream().lineSequence().toList(), + ) + Assertions.assertEquals( + expectedLinesSecond, + file2.inputStream().lineSequence().toList(), ) + + // we have released all memory reservations + Assertions.assertEquals( + Fixtures.INITIAL_MEMORY_CAPACITY, + memoryManager.remainingCapacityBytes, + ) + // we now have equivalent disk reservations + Assertions.assertEquals( + Fixtures.INITIAL_DISK_CAPACITY - bytesReservedDisk, + diskManager.remainingCapacityBytes, + ) + + file1.toFile().delete() + file2.toFile().delete() + } + + inner class MockFlushStrategy : FlushStrategy { + override suspend fun shouldFlush( + stream: DestinationStream.Descriptor, + rangeRead: Range, + bytesProcessed: Long + ): Boolean { + return bytesProcessed >= 1024 + } } - queue.publish(memoryManager.reserve(0L, StreamRecordCompleteWrapped(index = maxRecords))) - return recordsWritten - } - @Test - fun testSpillToDiskTask() = runTest { - val messageCount = primeMessageQueue() - val bytesReservedMemory = Fixtures.MEMORY_RESERVATION_SIZE_BYTES * messageCount - val bytesReservedDisk = Fixtures.SERIALIZED_SIZE_BYTES * messageCount - - // memory manager has reserved bytes for messages - Assertions.assertEquals( - Fixtures.INITIAL_MEMORY_CAPACITY - bytesReservedMemory, - memoryManager.remainingCapacityBytes - ) - // disk manager has not reserved any bytes - Assertions.assertEquals(Fixtures.INITIAL_DISK_CAPACITY, diskManager.remainingCapacityBytes) - - val mockTaskLauncher = MockTaskLauncher() - spillToDiskTaskFactory - .make(mockTaskLauncher, MockDestinationCatalogFactory.stream1.descriptor) - .execute() - Assertions.assertEquals(1, mockTaskLauncher.spilledFiles.size) - spillToDiskTaskFactory - .make(mockTaskLauncher, MockDestinationCatalogFactory.stream1.descriptor) - .execute() - Assertions.assertEquals(2, mockTaskLauncher.spilledFiles.size) - - Assertions.assertEquals(1024, mockTaskLauncher.spilledFiles[0].totalSizeBytes) - Assertions.assertEquals(512, mockTaskLauncher.spilledFiles[1].totalSizeBytes) - - val spilled1 = mockTaskLauncher.spilledFiles[0] - val spilled2 = mockTaskLauncher.spilledFiles[1] - Assertions.assertEquals(1024, spilled1.totalSizeBytes) - Assertions.assertEquals(512, spilled2.totalSizeBytes) - - val file1 = spilled1.localFile - val file2 = spilled2.localFile - - val expectedLinesFirst = (0 until 1024 / 8).flatMap { listOf("test$it") } - val expectedLinesSecond = (1024 / 8 until 1536 / 8).flatMap { listOf("test$it") } - - Assertions.assertEquals(expectedLinesFirst, file1.inputStream().lineSequence().toList()) - Assertions.assertEquals(expectedLinesSecond, file2.inputStream().lineSequence().toList()) - - // we have released all memory reservations - Assertions.assertEquals( - Fixtures.INITIAL_MEMORY_CAPACITY, - memoryManager.remainingCapacityBytes - ) - // we now have equivalent disk reservations - Assertions.assertEquals( - Fixtures.INITIAL_DISK_CAPACITY - bytesReservedDisk, - diskManager.remainingCapacityBytes - ) - - file1.toFile().delete() - file2.toFile().delete() + private suspend fun primeMessageQueue(): Long { + val queue = queueSupplier.get(MockDestinationCatalogFactory.stream1.descriptor) + val maxRecords = ((1024 * 1.5) / 8).toLong() + var recordsWritten = 0L + while (recordsWritten < maxRecords) { + val index = recordsWritten++ + queue.publish( + memoryManager.reserve( + Fixtures.MEMORY_RESERVATION_SIZE_BYTES, + StreamRecordEvent( + index = index, + sizeBytes = Fixtures.SERIALIZED_SIZE_BYTES, + record = + DestinationRecord( + stream = MockDestinationCatalogFactory.stream1.descriptor, + data = NullValue, + emittedAtMs = 0, + meta = null, + serialized = "test${index}", + ), + ), + ), + ) + } + queue.publish( + memoryManager.reserve( + 0L, + StreamCompleteEvent(index = maxRecords), + ), + ) + return recordsWritten + } } object Fixtures { diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/test/util/StubDestinationMessageFactory.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/test/util/StubDestinationMessageFactory.kt new file mode 100644 index 000000000000..a166566bae48 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/test/util/StubDestinationMessageFactory.kt @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.test.util + +import com.fasterxml.jackson.databind.node.JsonNodeFactory +import io.airbyte.cdk.load.command.DestinationStream +import io.airbyte.cdk.load.data.NullValue +import io.airbyte.cdk.load.message.CheckpointMessage +import io.airbyte.cdk.load.message.DestinationFile +import io.airbyte.cdk.load.message.DestinationFileStreamComplete +import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete +import io.airbyte.cdk.load.message.DestinationRecord +import io.airbyte.cdk.load.message.DestinationRecordStreamComplete +import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete +import io.airbyte.cdk.load.message.GlobalCheckpoint +import io.airbyte.cdk.load.message.StreamCheckpoint + +/* + * Shared factory methods for making stub destination messages for testing. + */ +object StubDestinationMessageFactory { + fun makeRecord(stream: DestinationStream, record: String): DestinationRecord { + return DestinationRecord( + stream = stream.descriptor, + data = NullValue, + emittedAtMs = 0, + meta = null, + serialized = record + ) + } + + fun makeFile(stream: DestinationStream, record: String): DestinationFile { + return DestinationFile( + stream = stream.descriptor, + emittedAtMs = 0, + serialized = record, + fileMessage = nullFileMessage, + ) + } + + fun makeStreamComplete(stream: DestinationStream): DestinationRecordStreamComplete { + return DestinationRecordStreamComplete(stream = stream.descriptor, emittedAtMs = 0) + } + + fun makeFileStreamComplete(stream: DestinationStream): DestinationFileStreamComplete { + return DestinationFileStreamComplete(stream = stream.descriptor, emittedAtMs = 0) + } + + fun makeStreamIncomplete(stream: DestinationStream): DestinationRecordStreamIncomplete { + return DestinationRecordStreamIncomplete(stream = stream.descriptor, emittedAtMs = 0) + } + + fun makeFileStreamIncomplete(stream: DestinationStream): DestinationFileStreamIncomplete { + return DestinationFileStreamIncomplete(stream = stream.descriptor, emittedAtMs = 0) + } + + fun makeStreamState(stream: DestinationStream, recordCount: Long): CheckpointMessage { + return StreamCheckpoint( + checkpoint = + CheckpointMessage.Checkpoint( + stream.descriptor, + JsonNodeFactory.instance.objectNode() + ), + sourceStats = CheckpointMessage.Stats(recordCount), + ) + } + + fun makeGlobalState(recordCount: Long): CheckpointMessage { + return GlobalCheckpoint( + state = JsonNodeFactory.instance.objectNode(), + sourceStats = CheckpointMessage.Stats(recordCount), + checkpoints = emptyList(), + additionalProperties = emptyMap(), + ) + } + + private val nullFileMessage = DestinationFile.AirbyteRecordMessageFile() +}