Skip to content

Commit

Permalink
Add file support to the bulk CDK (#49931)
Browse files Browse the repository at this point in the history
  • Loading branch information
benmoriceau authored Dec 24, 2024
1 parent 6fa2301 commit 671d4de
Show file tree
Hide file tree
Showing 33 changed files with 339 additions and 311 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
return LocalBatch(records.asSequence().toList())
}

override suspend fun processFile(file: DestinationFile): Batch {
return LocalFileBatch(file)
}

override suspend fun processBatch(batch: Batch): Batch {
return when (batch) {
is LocalBatch -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ abstract class DestinationConfiguration : Configuration {

open val numProcessRecordsWorkers: Int = 2
open val numProcessBatchWorkers: Int = 5
open val numProcessBatchWorkersForFileTransfer: Int = 3
open val batchQueueDepth: Int = 10

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Value
Expand Down Expand Up @@ -79,4 +80,13 @@ class SyncBeanFactory {
val channel = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel, "batchQueue")
}

@Singleton
@Named("fileMessageQueue")
fun fileMessageQueue(
config: DestinationConfiguration,
): MultiProducerChannel<FileTransferQueueMessage> {
val channel = Channel<FileTransferQueueMessage>(config.batchQueueDepth)
return MultiProducerChannel(1, channel, "fileMessageQueue")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.implementor.CloseStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailSyncTaskFactory
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.task.implementor.OpenStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessBatchTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessFileTaskFactory
Expand All @@ -36,6 +37,7 @@ import io.airbyte.cdk.load.util.setOnce
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import io.micronaut.context.annotation.Value
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.CancellationException
Expand All @@ -49,8 +51,6 @@ interface DestinationTaskLauncher : TaskLauncher {
suspend fun handleNewBatch(stream: DestinationStream.Descriptor, wrapped: BatchEnvelope<*>)
suspend fun handleStreamClosed(stream: DestinationStream.Descriptor)
suspend fun handleTeardownComplete(success: Boolean = true)
suspend fun handleFile(stream: DestinationStream.Descriptor, file: DestinationFile, index: Long)

suspend fun handleException(e: Exception)
suspend fun handleFailStreamComplete(stream: DestinationStream.Descriptor, e: Exception)
}
Expand Down Expand Up @@ -128,6 +128,7 @@ class DefaultDestinationTaskLauncher(
private val recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
@Named("fileMessageQueue") private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>
) : DestinationTaskLauncher {
private val log = KotlinLogging.logger {}

Expand Down Expand Up @@ -179,7 +180,8 @@ class DefaultDestinationTaskLauncher(
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
this,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
)
enqueue(inputConsumerTask)

Expand Down Expand Up @@ -208,6 +210,17 @@ class DefaultDestinationTaskLauncher(
val task = processBatchTaskFactory.make(this)
enqueue(task)
}
} else {
repeat(config.numProcessRecordsWorkers) {
log.info { "Launching process file task $it" }
enqueue(processFileTaskFactory.make(this))
}

repeat(config.numProcessBatchWorkersForFileTransfer) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
}
}

// Start flush task
Expand Down Expand Up @@ -283,14 +296,6 @@ class DefaultDestinationTaskLauncher(
}
}

override suspend fun handleFile(
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long
) {
enqueue(processFileTaskFactory.make(this, stream, file, index))
}

override suspend fun handleException(e: Exception) {
catalog.streams
.map { failStreamTaskFactory.make(this, e, it.descriptor) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
Expand All @@ -22,7 +23,7 @@ class DefaultProcessBatchTask(
private val batchQueue: MultiProducerChannel<BatchEnvelope<*>>,
private val taskLauncher: DestinationTaskLauncher
) : ProcessBatchTask {

val log = KotlinLogging.logger {}
override suspend fun execute() {
batchQueue.consume().collect { batchEnvelope ->
val streamLoader = syncManager.getOrAwaitStreamLoader(batchEnvelope.streamDescriptor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,73 @@

package io.airbyte.cdk.load.task.implementor

import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.util.use
import io.airbyte.cdk.load.write.FileBatchAccumulator
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap

interface ProcessFileTask : ImplementorScope

class DefaultProcessFileTask(
private val streamDescriptor: DestinationStream.Descriptor,
private val taskLauncher: DestinationTaskLauncher,
private val syncManager: SyncManager,
private val file: DestinationFile,
private val index: Long,
private val taskLauncher: DestinationTaskLauncher,
private val inputQueue: MessageQueue<FileTransferQueueMessage>,
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessFileTask {
val log = KotlinLogging.logger {}
private val accumulators =
ConcurrentHashMap<DestinationStream.Descriptor, FileBatchAccumulator>()

override suspend fun execute() {
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
outputQueue.use {
inputQueue.consume().collect { (streamDescriptor, file, index) ->
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)

val batch = streamLoader.processFile(file)
val acc =
accumulators.getOrPut(streamDescriptor) {
streamLoader.createFileBatchAccumulator(outputQueue)
}

val wrapped = BatchEnvelope(batch, Range.singleton(index), streamDescriptor)
taskLauncher.handleNewBatch(streamDescriptor, wrapped)
acc.processFilePart(file, index)
}
}
}
}

interface ProcessFileTaskFactory {
fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long,
): ProcessFileTask
}

@Singleton
@Secondary
class DefaultFileRecordsTaskFactory(
private val syncManager: SyncManager,
@Named("fileMessageQueue")
private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
@Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessFileTaskFactory {
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long,
): ProcessFileTask {
return DefaultProcessFileTask(stream, taskLauncher, syncManager, file, index)
return DefaultProcessFileTask(syncManager, taskLauncher, fileTransferQueue, outputQueue)
}
}

data class FileTransferQueueMessage(
val streamDescriptor: DestinationStream.Descriptor,
val file: DestinationFile,
val index: Long,
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.GlobalCheckpoint
import io.airbyte.cdk.load.message.GlobalCheckpointWrapped
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.message.SimpleBatch
Expand All @@ -33,9 +34,11 @@ import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.util.use
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton

interface InputConsumerTask : KillableScope
Expand All @@ -60,6 +63,8 @@ class DefaultInputConsumerTask(
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
private val syncManager: SyncManager,
private val destinationTaskLauncher: DestinationTaskLauncher,
@Named("fileMessageQueue")
private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
) : InputConsumerTask {
private val log = KotlinLogging.logger {}

Expand Down Expand Up @@ -96,15 +101,17 @@ class DefaultInputConsumerTask(
}
is DestinationFile -> {
val index = manager.countRecordIn()
destinationTaskLauncher.handleFile(stream, message, index)
// destinationTaskLauncher.handleFile(stream, message, index)
fileTransferQueue.publish(FileTransferQueueMessage(stream, message, index))
}
is DestinationFileStreamComplete -> {
reserved.release() // safe because multiple calls conflate
manager.markEndOfStream(true)
fileTransferQueue.close()
val envelope =
BatchEnvelope(
SimpleBatch(Batch.State.COMPLETE),
streamDescriptor = message.stream
streamDescriptor = message.stream,
)
destinationTaskLauncher.handleNewBatch(stream, envelope)
}
Expand Down Expand Up @@ -197,6 +204,7 @@ interface InputConsumerTaskFactory {
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>
): InputConsumerTask
}

Expand All @@ -211,14 +219,16 @@ class DefaultInputConsumerTaskFactory(private val syncManager: SyncManager) :
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
): InputConsumerTask {
return DefaultInputConsumerTask(
catalog,
inputFlow,
recordQueueSupplier,
checkpointQueue,
syncManager,
destinationTaskLauncher
destinationTaskLauncher,
fileTransferQueue,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package io.airbyte.cdk.load.write

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed

Expand Down Expand Up @@ -43,13 +45,15 @@ import io.airbyte.cdk.load.state.StreamProcessingFailed
* but only if [start] returned successfully. If any exception was thrown during processing, it is
* passed as an argument to [close].
*/
interface StreamLoader : BatchAccumulator {
interface StreamLoader : BatchAccumulator, FileBatchAccumulator {
val stream: DestinationStream

suspend fun start() {}
suspend fun createBatchAccumulator(): BatchAccumulator = this
suspend fun createFileBatchAccumulator(
outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
): FileBatchAccumulator = this

suspend fun processFile(file: DestinationFile): Batch
suspend fun processBatch(batch: Batch): Batch = SimpleBatch(Batch.State.COMPLETE)
suspend fun close(streamFailure: StreamProcessingFailed? = null) {}
}
Expand All @@ -64,3 +68,15 @@ interface BatchAccumulator {
"processRecords must be implemented if createBatchAccumulator is overridden"
)
}

interface FileBatchAccumulator {
/**
* This is an unusal way to process a message (the DestinationFile). The batch are pushed to the
* queue immediately instead of being return by the method, the main reason is that we nned to
* keep a single instance of a PartFactory for the whole file.
*/
suspend fun processFilePart(file: DestinationFile, index: Long): Unit =
throw NotImplementedError(
"processRecords must be implemented if createBatchAccumulator is overridden"
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import io.airbyte.cdk.load.task.implementor.FailStreamTask
import io.airbyte.cdk.load.task.implementor.FailStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailSyncTask
import io.airbyte.cdk.load.task.implementor.FailSyncTaskFactory
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.task.implementor.OpenStreamTask
import io.airbyte.cdk.load.task.implementor.OpenStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessBatchTaskFactory
Expand Down Expand Up @@ -153,7 +154,8 @@ class DestinationTaskLauncherTest<T : ScopedTask> {
MessageQueueSupplier<
DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
): InputConsumerTask {
return object : InputConsumerTask {
override suspend fun execute() {
Expand Down
Loading

0 comments on commit 671d4de

Please sign in to comment.