Skip to content

Commit

Permalink
Merge branch 'master' into jonathan/destination-mssql-v2-skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
jdpgrailsdev authored Dec 17, 2024
2 parents 06804c7 + 6d8a3a2 commit 7a0bcfc
Show file tree
Hide file tree
Showing 69 changed files with 2,141 additions and 1,010 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.airbyte.cdk.load.mock_integration_test

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationStream
Expand All @@ -16,9 +17,11 @@ import io.airbyte.cdk.load.state.StreamProcessingFailed
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import java.time.Instant
import java.util.UUID
import javax.inject.Singleton
import kotlinx.coroutines.delay

@Singleton
class MockDestinationWriter : DestinationWriter {
Expand All @@ -27,7 +30,10 @@ class MockDestinationWriter : DestinationWriter {
}
}

@SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION", justification = "Kotlin async continuation")
class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
private val log = KotlinLogging.logger {}

/** Base class for the batch types defined by this mock loader; supplies a null [groupId]. */
abstract class MockBatch : Batch {
override val groupId: String? = null
}
Expand All @@ -38,9 +44,6 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
/** Batch wrapping a single [DestinationFile]; its state is always [Batch.State.LOCAL]. */
data class LocalFileBatch(val file: DestinationFile) : MockBatch() {
override val state = Batch.State.LOCAL
}
/** Batch carrying a list of [DestinationRecord]s; its state is always [Batch.State.PERSISTED]. */
data class PersistedBatch(val records: List<DestinationRecord>) : MockBatch() {
override val state = Batch.State.PERSISTED
}

override suspend fun close(streamFailure: StreamProcessingFailed?) {
if (streamFailure == null) {
Expand Down Expand Up @@ -82,6 +85,7 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
override suspend fun processBatch(batch: Batch): Batch {
return when (batch) {
is LocalBatch -> {
log.info { "Persisting ${batch.records.size} records for ${stream.descriptor}" }
batch.records.forEach {
val filename = getFilename(it.stream, staging = true)
val record =
Expand All @@ -99,9 +103,14 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
// blind insert into the staging area. We'll dedupe on commit.
MockDestinationBackend.insert(filename, record)
}
PersistedBatch(batch.records)
// HACK: This destination is too fast and causes a race
// condition between consuming and flushing state messages
// that causes the test to fail. This would not be an issue
// in a real sync, because we would always either get more
// data or an end-of-stream that would force a final flush.
delay(100L)
SimpleBatch(state = Batch.State.COMPLETE)
}
is PersistedBatch -> SimpleBatch(state = Batch.State.COMPLETE)
else -> throw IllegalStateException("Unexpected batch type: $batch")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ abstract class DestinationConfiguration : Configuration {
open val gracefulCancellationTimeoutMs: Long = 60 * 1000L // 1 minute

// Also passed as the first argument of the "batchQueue" MultiProducerChannel
// (presumably its producer count — TODO confirm against MultiProducerChannel).
open val numProcessRecordsWorkers: Int = 2
// Number of process-batch tasks the task launcher creates and enqueues.
open val numProcessBatchWorkers: Int = 5
// Capacity of the channel backing the "batchQueue" bean.
open val batchQueueDepth: Int = 10

/**
* Micronaut factory which glues [ConfigurationSpecificationSupplier] and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package io.airbyte.cdk.load.config

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
Expand Down Expand Up @@ -69,4 +70,13 @@ class SyncBeanFactory {
val channel = Channel<FileAggregateMessage>(capacity)
return MultiProducerChannel(streamCount.toLong(), channel)
}

/**
 * Provides the singleton queue of [BatchEnvelope]s registered under the name "batchQueue".
 *
 * The backing channel's capacity is [DestinationConfiguration.batchQueueDepth]. The first
 * argument given to [MultiProducerChannel] is
 * [DestinationConfiguration.numProcessRecordsWorkers] (presumably the producer count — TODO
 * confirm against MultiProducerChannel).
 */
@Singleton
@Named("batchQueue")
fun batchQueue(
    config: DestinationConfiguration,
): MultiProducerChannel<BatchEnvelope<*>> {
    val queue = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
    val producerCount = config.numProcessRecordsWorkers.toLong()
    return MultiProducerChannel(producerCount, queue)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ import java.time.temporal.ChronoUnit
*/
class TimeStringToInteger : AirbyteValueIdentityMapper() {
companion object {
private val DATE_TIME_FORMATTER: DateTimeFormatter =
val DATE_TIME_FORMATTER: DateTimeFormatter =
DateTimeFormatter.ofPattern(
"[yyyy][yy]['-']['/']['.'][' '][MMM][MM][M]['-']['/']['.'][' '][dd][d][[' '][G]][[' ']['T']HH:mm[':'ss[.][SSSSSS][SSSSS][SSSS][SSS][' '][z][zzz][Z][O][x][XXX][XX][X][[' '][G]]]]"
)
private val TIME_FORMATTER: DateTimeFormatter =
val TIME_FORMATTER: DateTimeFormatter =
DateTimeFormatter.ofPattern(
"HH:mm[':'ss[.][SSSSSS][SSSSS][SSSS][SSS][' '][z][zzz][Z][O][x][XXX][XX][X]]"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.airbyte.cdk.load.message
import com.google.common.collect.Range
import com.google.common.collect.RangeSet
import com.google.common.collect.TreeRangeSet
import io.airbyte.cdk.load.command.DestinationStream

/**
* Represents an accumulated batch of records in some stage of processing.
Expand Down Expand Up @@ -66,6 +67,13 @@ interface Batch {
}

val state: State

/**
 * Whether there is further processing to do for this batch.
 *
 * False once the batch has reached [State.COMPLETE]. Also false for any batch that belongs to a
 * group (non-null [groupId]): such a batch has its state updated by the next batch in the group
 * that advances.
 */
val requiresProcessing: Boolean
    get() {
        if (state == State.COMPLETE) return false
        return groupId == null
    }
}

/** Simple batch: use if you need no other metadata for processing. */
Expand All @@ -80,14 +88,20 @@ data class SimpleBatch(
*/
data class BatchEnvelope<B : Batch>(
val batch: B,
val ranges: RangeSet<Long> = TreeRangeSet.create()
val ranges: RangeSet<Long> = TreeRangeSet.create(),
val streamDescriptor: DestinationStream.Descriptor
) {
constructor(
batch: B,
range: Range<Long>
) : this(batch = batch, ranges = TreeRangeSet.create(listOf(range)))
range: Range<Long>,
streamDescriptor: DestinationStream.Descriptor
) : this(
batch = batch,
ranges = TreeRangeSet.create(listOf(range)),
streamDescriptor = streamDescriptor
)

fun <C : Batch> withBatch(newBatch: C): BatchEnvelope<C> {
return BatchEnvelope(newBatch, ranges)
return BatchEnvelope(newBatch, ranges, streamDescriptor)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
Expand Down Expand Up @@ -153,6 +152,10 @@ class DefaultDestinationTaskLauncher(
handleException(e)
}
}

/** Identifies this wrapper by the task it wraps. */
override fun toString(): String = "TaskWrapper($innerTask)"
}

inner class NoopWrapper(
Expand Down Expand Up @@ -199,6 +202,12 @@ class DefaultDestinationTaskLauncher(
val task = processRecordsTaskFactory.make(this)
enqueue(task)
}

repeat(config.numProcessBatchWorkers) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
}
}

// Start flush task
Expand All @@ -224,8 +233,8 @@ class DefaultDestinationTaskLauncher(
override suspend fun handleSetupComplete() {
catalog.streams.forEach {
log.info { "Starting open stream task for $it" }
val openStreamTask = openStreamTaskFactory.make(this, it)
enqueue(openStreamTask)
val task = openStreamTaskFactory.make(this, it)
enqueue(task)
}
}

Expand All @@ -248,17 +257,13 @@ class DefaultDestinationTaskLauncher(
streamManager.updateBatchState(wrapped)

if (wrapped.batch.isPersisted()) {
enqueue(flushCheckpointsTaskFactory.make())
}

if (wrapped.batch.state != Batch.State.COMPLETE) {
log.info {
"Batch not complete: Starting process batch task for ${stream}, batch $wrapped"
"Batch $wrapped is persisted: Starting flush checkpoints task for $stream"
}
enqueue(flushCheckpointsTaskFactory.make())
}

val task = processBatchTaskFactory.make(this, stream, wrapped)
enqueue(task)
} else if (streamManager.isBatchProcessingComplete()) {
if (streamManager.isBatchProcessingComplete()) {
log.info {
"Batch $wrapped complete and batch processing complete: Starting close stream task for $stream"
}
Expand Down Expand Up @@ -291,12 +296,9 @@ class DefaultDestinationTaskLauncher(
}

override suspend fun handleException(e: Exception) {
catalog.streams.forEach {
enqueue(
failStreamTaskFactory.make(this, e, it.descriptor),
withExceptionHandling = false
)
}
catalog.streams
.map { failStreamTaskFactory.make(this, e, it.descriptor) }
.forEach { enqueue(it, withExceptionHandling = false) }
}

override suspend fun handleFailStreamComplete(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,50 @@

package io.airbyte.cdk.load.task.implementor

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.write.StreamLoader
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton

/** Task type for batch processing; [DefaultProcessBatchTask] is the implementation in this file. */
interface ProcessBatchTask : ImplementorScope

/** Wraps [StreamLoader.processBatch] and handles the resulting batch. */
class DefaultProcessBatchTask(
private val syncManager: SyncManager,
private val batchEnvelope: BatchEnvelope<*>,
private val streamDescriptor: DestinationStream.Descriptor,
private val batchQueue: MultiProducerChannel<BatchEnvelope<*>>,
private val taskLauncher: DestinationTaskLauncher
) : ProcessBatchTask {

override suspend fun execute() {
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
val nextBatch = streamLoader.processBatch(batchEnvelope.batch)
val nextWrapped = batchEnvelope.withBatch(nextBatch)
taskLauncher.handleNewBatch(streamDescriptor, nextWrapped)
batchQueue.consume().collect { batchEnvelope ->
val streamLoader = syncManager.getOrAwaitStreamLoader(batchEnvelope.streamDescriptor)
val nextBatch = streamLoader.processBatch(batchEnvelope.batch)
val nextWrapped = batchEnvelope.withBatch(nextBatch)
taskLauncher.handleNewBatch(nextWrapped.streamDescriptor, nextWrapped)
}
}
}

interface ProcessBatchTaskFactory {
fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
batchEnvelope: BatchEnvelope<*>
): ProcessBatchTask
}

@Singleton
@Secondary
class DefaultProcessBatchTaskFactory(private val syncManager: SyncManager) :
ProcessBatchTaskFactory {
class DefaultProcessBatchTaskFactory(
private val syncManager: SyncManager,
@Named("batchQueue") private val batchQueue: MultiProducerChannel<BatchEnvelope<*>>
) : ProcessBatchTaskFactory {
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
batchEnvelope: BatchEnvelope<*>
): ProcessBatchTask {
return DefaultProcessBatchTask(syncManager, batchEnvelope, stream, taskLauncher)
return DefaultProcessBatchTask(syncManager, batchQueue, taskLauncher)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DefaultProcessFileTask(

val batch = streamLoader.processFile(file)

val wrapped = BatchEnvelope(batch, Range.singleton(index))
val wrapped = BatchEnvelope(batch, Range.singleton(index), streamDescriptor)
taskLauncher.handleNewBatch(streamDescriptor, wrapped)
}
}
Expand Down
Loading

0 comments on commit 7a0bcfc

Please sign in to comment.