Load CDK: S3V2 processes in processRecords, uploads in processBatch
johnny-schmidt committed Dec 17, 2024
1 parent fd430b2 commit 3ec5951
Showing 47 changed files with 1,049 additions and 170 deletions.

@@ -8,5 +8,3 @@ airbyte:
   flush:
     rate-ms: 900000 # 15 minutes
     window-ms: 900000 # 15 minutes
-  destination:
-    record-batch-size: ${AIRBYTE_DESTINATION_RECORD_BATCH_SIZE:209715200}

@@ -73,7 +73,8 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
 
     override suspend fun processRecords(
         records: Iterator<DestinationRecord>,
-        totalSizeBytes: Long
+        totalSizeBytes: Long,
+        endOfStream: Boolean
     ): Batch {
         return LocalBatch(records.asSequence().toList())
     }

@@ -61,7 +61,8 @@ import java.nio.file.Path
  * ```
  */
 abstract class DestinationConfiguration : Configuration {
-    open val recordBatchSizeBytes: Long = 200L * 1024L * 1024L
+    open val recordBatchSizeBytes: Long = DEFAULT_RECORD_BATCH_SIZE_BYTES
+    open val processEmptyFiles: Boolean = false
     open val tmpFileDirectory: Path = Path.of("airbyte-cdk-load")
 
     /** Memory queue settings */
@@ -88,6 +89,10 @@ abstract class DestinationConfiguration : Configuration {
     open val numProcessBatchWorkers: Int = 5
     open val batchQueueDepth: Int = 10
 
+    companion object {
+        const val DEFAULT_RECORD_BATCH_SIZE_BYTES = 200L * 1024L * 1024L
+    }
+
     /**
     * Micronaut factory which glues [ConfigurationSpecificationSupplier] and
     * [DestinationConfigurationFactory] together to produce a [DestinationConfiguration] singleton.

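Both new knobs are open vals that a connector's configuration can override. A minimal sketch of such a subclass (the class name is hypothetical; the two vals and the companion constant are exactly the ones in the diff above):

    class MyDestinationConfiguration : DestinationConfiguration() {
        // Halve the 200 MiB default aggregate size for a memory-constrained target.
        override val recordBatchSizeBytes: Long = DEFAULT_RECORD_BATCH_SIZE_BYTES / 2

        // Opt in to the new behavior: empty spill files are forwarded to
        // processRecords at end-of-stream instead of being dropped.
        override val processEmptyFiles: Boolean = true
    }
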
@@ -68,7 +68,7 @@ class SyncBeanFactory {
         val capacity = min(maxBatchesMinusUploadOverhead, idealDepth)
         log.info { "Creating file aggregate queue with limit $capacity" }
         val channel = Channel<FileAggregateMessage>(capacity)
-        return MultiProducerChannel(streamCount.toLong(), channel)
+        return MultiProducerChannel(streamCount.toLong(), channel, "fileAggregateQueue")
     }
 
     @Singleton
@@ -77,6 +77,6 @@
         config: DestinationConfiguration,
     ): MultiProducerChannel<BatchEnvelope<*>> {
         val channel = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
-        return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel)
+        return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel, "batchQueue")
     }
 }

@@ -20,6 +20,6 @@ class DefaultSpillFileProvider(val config: DestinationConfiguration) : SpillFileProvider {
     override fun createTempFile(): Path {
         val directory = config.tmpFileDirectory
         Files.createDirectories(directory)
-        return Files.createTempFile(directory, "staged-raw-records", "jsonl")
+        return Files.createTempFile(directory, "staged-raw-records", ".jsonl")
     }
 }

@@ -93,11 +93,11 @@ data class BatchEnvelope<B : Batch>(
 ) {
     constructor(
         batch: B,
-        range: Range<Long>,
+        range: Range<Long>?,
         streamDescriptor: DestinationStream.Descriptor
     ) : this(
         batch = batch,
-        ranges = TreeRangeSet.create(listOf(range)),
+        ranges = range?.let { TreeRangeSet.create(listOf(range)) } ?: TreeRangeSet.create(),
         streamDescriptor = streamDescriptor
     )
 

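With the secondary constructor's range now nullable, an empty spill file can yield an envelope whose range set is simply empty rather than needing a sentinel range. An illustrative construction (assuming a DestinationStream named stream is in scope):

    val emptyEnvelope =
        BatchEnvelope(
            batch = SimpleBatch(Batch.State.COMPLETE),
            range = null, // empty file: no record indexes to account for
            streamDescriptor = stream.descriptor,
        )
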
@@ -15,6 +15,7 @@ import kotlinx.coroutines.channels.Channel
 class MultiProducerChannel<T>(
     producerCount: Long,
     override val channel: Channel<T>,
+    private val name: String,
 ) : ChannelMessageQueue<T>() {
     private val log = KotlinLogging.logger {}
     private val initializedProducerCount = producerCount
@@ -23,7 +24,7 @@ class MultiProducerChannel<T>(
     override suspend fun close() {
         val count = producerCount.decrementAndGet()
         log.info {
-            "Closing producer (active count=$count, initialized count: $initializedProducerCount)"
+            "Closing producer $name (active count=$count, initialized count: $initializedProducerCount)"
         }
         if (count == 0L) {
             log.info { "Closing underlying queue" }

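The added name only feeds the log line, but it makes the close-counting contract easier to trace per queue. A small usage sketch with illustrative values:

    val queue =
        MultiProducerChannel(
            2L, // two producers share the channel
            Channel<String>(10),
            "exampleQueue",
        )
    queue.close() // logs: Closing producer exampleQueue (active count=1, initialized count: 2)
    queue.close() // count reaches 0, so the underlying channel is closed
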
@@ -189,6 +189,7 @@ class DefaultDestinationTaskLauncher(
         val setupTask = setupTaskFactory.make(this)
         enqueue(setupTask)
 
+        // TODO: pluggable file transfer
         if (!fileTransferEnabled) {
             // Start a spill-to-disk task for each record stream
             catalog.streams.forEach { stream ->
@@ -264,16 +265,12 @@
         }
 
         if (streamManager.isBatchProcessingComplete()) {
-            log.info {
-                "Batch $wrapped complete and batch processing complete: Starting close stream task for $stream"
-            }
+            log.info { "Batch processing complete: Starting close stream task for $stream" }
 
             val task = closeStreamTaskFactory.make(this, stream)
             enqueue(task)
         } else {
-            log.info {
-                "Batch $wrapped complete, but batch processing not complete: nothing else to do."
-            }
+            log.info { "Batch processing not complete: nothing else to do." }
         }
     }
 }

@@ -8,13 +8,13 @@ import io.airbyte.cdk.load.message.BatchEnvelope
 import io.airbyte.cdk.load.message.MultiProducerChannel
 import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.DestinationTaskLauncher
-import io.airbyte.cdk.load.task.ImplementorScope
+import io.airbyte.cdk.load.task.KillableScope
 import io.airbyte.cdk.load.write.StreamLoader
 import io.micronaut.context.annotation.Secondary
 import jakarta.inject.Named
 import jakarta.inject.Singleton
 
-interface ProcessBatchTask : ImplementorScope
+interface ProcessBatchTask : KillableScope
 
 /** Wraps @[StreamLoader.processBatch] and handles the resulting batch. */
 class DefaultProcessBatchTask(

@@ -21,12 +21,14 @@ import io.airbyte.cdk.load.task.KillableScope
 import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
 import io.airbyte.cdk.load.util.lineSequence
 import io.airbyte.cdk.load.util.use
+import io.airbyte.cdk.load.write.BatchAccumulator
 import io.airbyte.cdk.load.write.StreamLoader
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Secondary
 import jakarta.inject.Named
 import jakarta.inject.Singleton
 import java.io.InputStream
+import java.util.concurrent.ConcurrentHashMap
 import kotlin.io.path.inputStream
 
 interface ProcessRecordsTask : KillableScope
@@ -45,20 +47,31 @@ class DefaultProcessRecordsTask(
     private val syncManager: SyncManager,
     private val diskManager: ReservationManager,
     private val inputQueue: MessageQueue<FileAggregateMessage>,
-    private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>
+    private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
 ) : ProcessRecordsTask {
     private val log = KotlinLogging.logger {}
+    private val accumulators = ConcurrentHashMap<DestinationStream.Descriptor, BatchAccumulator>()
     override suspend fun execute() {
         outputQueue.use {
             inputQueue.consume().collect { (streamDescriptor, file) ->
                 log.info { "Fetching stream loader for $streamDescriptor" }
                 val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
+                val acc =
+                    accumulators.getOrPut(streamDescriptor) {
+                        streamLoader.createBatchAccumulator()
+                    }
                 log.info { "Processing records from $file for stream $streamDescriptor" }
                 val batch =
                     try {
-                        file.localFile.inputStream().use { inputStream ->
-                            val records = inputStream.toRecordIterator()
-                            val batch = streamLoader.processRecords(records, file.totalSizeBytes)
+                        file.localFile.inputStream().use {
+                            val records =
+                                if (file.isEmpty) {
+                                    emptyList<DestinationRecord>().listIterator()
+                                } else {
+                                    it.toRecordIterator()
+                                }
+                            val batch =
+                                acc.processRecords(records, file.totalSizeBytes, file.endOfStream)
                             log.info { "Finished processing $file" }
                             batch
                         }
@@ -119,6 +132,7 @@ class DefaultProcessRecordsTaskFactory(
     @Named("fileAggregateQueue") private val inputQueue: MessageQueue<FileAggregateMessage>,
     @Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
 ) : ProcessRecordsTaskFactory {
+
     override fun make(
         taskLauncher: DestinationTaskLauncher,
     ): ProcessRecordsTask {

@@ -6,6 +6,7 @@ package io.airbyte.cdk.load.task.internal
 
 import com.google.common.collect.Range
 import com.google.common.collect.TreeRangeSet
+import io.airbyte.cdk.load.command.DestinationConfiguration
 import io.airbyte.cdk.load.command.DestinationStream
 import io.airbyte.cdk.load.file.SpillFileProvider
 import io.airbyte.cdk.load.message.Batch
@@ -54,7 +55,8 @@ class DefaultSpillToDiskTask(
     private val flushStrategy: FlushStrategy,
     val streamDescriptor: DestinationStream.Descriptor,
     private val diskManager: ReservationManager,
-    private val taskLauncher: DestinationTaskLauncher
+    private val taskLauncher: DestinationTaskLauncher,
+    private val processEmptyFiles: Boolean,
 ) : SpillToDiskTask {
     private val log = KotlinLogging.logger {}
 
@@ -124,7 +126,7 @@ class DefaultSpillToDiskTask(
         event: StreamEndEvent,
     ): FileAccumulator {
         val (spillFile, outputStream, timeWindow, range, sizeBytes) = acc
-        if (sizeBytes == 0L) {
+        if (sizeBytes == 0L && !processEmptyFiles) {
             log.info { "Skipping empty file $spillFile" }
             // Cleanup empty file
             spillFile.deleteExisting()
@@ -138,7 +140,12 @@ class DefaultSpillToDiskTask(
             )
             taskLauncher.handleNewBatch(streamDescriptor, empty)
         } else {
-            val nextRange = range.withNextAdjacentValue(event.index)
+            val nextRange =
+                if (sizeBytes == 0L) {
+                    null
+                } else {
+                    range.withNextAdjacentValue(event.index)
+                }
             val file =
                 SpilledRawMessagesLocalFile(
                     spillFile,
@@ -203,6 +210,7 @@ interface SpillToDiskTaskFactory {
 
 @Singleton
 class DefaultSpillToDiskTaskFactory(
+    private val config: DestinationConfiguration,
     private val fileAccFactory: FileAccumulatorFactory,
     private val queueSupplier:
         MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
@@ -224,6 +232,7 @@ class DefaultSpillToDiskTaskFactory(
                 stream,
                 diskManager,
                 taskLauncher,
+                config.processEmptyFiles,
             )
         }
     }
@@ -255,6 +264,9 @@ data class FileAccumulator(
 data class SpilledRawMessagesLocalFile(
     val localFile: Path,
     val totalSizeBytes: Long,
-    val indexRange: Range<Long>,
+    val indexRange: Range<Long>?,
     val endOfStream: Boolean = false
-)
+) {
+    val isEmpty
+        get() = totalSizeBytes == 0L
+}

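Downstream code can now recognize an empty end-of-stream file by its null index range or the new isEmpty property. With illustrative values:

    val file =
        SpilledRawMessagesLocalFile(
            localFile = Path.of("/tmp/airbyte-cdk-load/staged-raw-records42.jsonl"),
            totalSizeBytes = 0L,
            indexRange = null, // nothing was spilled
            endOfStream = true,
        )
    check(file.isEmpty) // true: totalSizeBytes == 0L
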
@@ -12,32 +12,55 @@ import io.airbyte.cdk.load.message.SimpleBatch
 import io.airbyte.cdk.load.state.StreamProcessingFailed
 
 /**
- * Implementor interface. The framework calls open and close once per stream at the beginning and
- * end of processing. The framework calls processRecords once per batch of records as batches of the
- * configured size become available. (Specified in @
- * [io.airbyte.cdk.command.WriteConfiguration.recordBatchSizeBytes])
+ * Implementor interface.
  *
  * [start] is called once before any records are processed.
  *
- * [processRecords] is called whenever a batch of records is available for processing, and only
- * after [start] has returned successfully. The return value is a client-defined implementation of @
- * [Batch] that the framework may pass to [processBatch] and/or [finalize]. (See @[Batch] for more
- * details.)
+ * [processRecords] is called whenever a batch of records is available for processing (of the size
+ * configured in [io.airbyte.cdk.load.command.DestinationConfiguration.recordBatchSizeBytes]) and
+ * only after [start] has returned successfully. The return value is a client-defined implementation
+ * of @ [Batch] that the framework may pass to [processBatch]. (See @[Batch] for more details.)
  *
+ * [processRecords] may be called concurrently by multiple workers, so it should be thread-safe if
+ * [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessRecordsWorkers] > 1. For a
+ * non-thread-safe alternative, use [createBatchAccumulator].
+ *
+ * [createBatchAccumulator] returns an optional new instance of a [BatchAccumulator] to use for
+ * record processing instead of this stream loader. By default, it returns a reference to the stream
+ * loader itself. Use this interface if you want each record processing worker to use a separate
+ * instance (with its own state, etc).
+ *
 * [processBatch] is called once per incomplete batch returned by either [processRecords] or
-* [processBatch] itself.
+* [processBatch] itself. It must be thread-safe if
+* [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessBatchWorkers] > 1. If
+* [processRecords] never returns a non-[Batch.State.COMPLETE] batch, [processBatch] will never be
+* called.
 *
-* [finalize] is called once after all records and batches have been processed successfully.
+* NOTE: even if [processBatch] returns a not-[Batch.State.COMPLETE] batch, it will be called again.
+* TODO: allow the client to specify subsequent processing stages instead.
 *
-* [close] is called once after all records have been processed, regardless of success or failure.
-* If there are failed batches, they are passed in as an argument.
+* [close] is called once after all records have been processed, regardless of success or failure,
+* but only if [start] returned successfully. If any exception was thrown during processing, it is
+* passed as an argument to [close].
 */
-interface StreamLoader {
+interface StreamLoader : BatchAccumulator {
     val stream: DestinationStream
 
     suspend fun start() {}
-    suspend fun processRecords(records: Iterator<DestinationRecord>, totalSizeBytes: Long): Batch
+    suspend fun createBatchAccumulator(): BatchAccumulator = this
+
     suspend fun processFile(file: DestinationFile): Batch
     suspend fun processBatch(batch: Batch): Batch = SimpleBatch(Batch.State.COMPLETE)
     suspend fun close(streamFailure: StreamProcessingFailed? = null) {}
 }
+
+interface BatchAccumulator {
+    suspend fun processRecords(
+        records: Iterator<DestinationRecord>,
+        totalSizeBytes: Long,
+        endOfStream: Boolean = false
+    ): Batch =
+        throw NotImplementedError(
+            "processRecords must be implemented if createBatchAccumulator is overridden"
+        )
+}

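Putting the new contract together: a loader can keep per-worker state by overriding createBatchAccumulator and returning a fresh accumulator, then do the thread-safe upload work in processBatch. A hedged sketch, not part of this commit: MyUploadBatch, MyBatchAccumulator, and MyStreamLoader are hypothetical, and it assumes Batch exposes an overridable state property with LOCAL among its non-COMPLETE states, as SimpleBatch(Batch.State.COMPLETE) and the mock's LocalBatch suggest.

    data class MyUploadBatch(
        val records: List<DestinationRecord>,
        override val state: Batch.State,
    ) : Batch

    class MyBatchAccumulator : BatchAccumulator {
        // Per-worker buffer: no locking needed, since each worker gets its own instance.
        private val buffer = mutableListOf<DestinationRecord>()

        override suspend fun processRecords(
            records: Iterator<DestinationRecord>,
            totalSizeBytes: Long,
            endOfStream: Boolean
        ): Batch {
            records.forEach(buffer::add)
            // Hand the staged records off and reset this worker's buffer.
            val staged = buffer.toList()
            buffer.clear()
            // Returning a non-COMPLETE state routes the batch on to processBatch.
            return MyUploadBatch(staged, Batch.State.LOCAL)
        }
    }

    class MyStreamLoader(override val stream: DestinationStream) : StreamLoader {
        override suspend fun createBatchAccumulator(): BatchAccumulator = MyBatchAccumulator()

        override suspend fun processFile(file: DestinationFile): Batch =
            SimpleBatch(Batch.State.COMPLETE) // file transfer is unused in this sketch

        override suspend fun processBatch(batch: Batch): Batch {
            // Upload (batch as MyUploadBatch).records here; this must be thread-safe
            // when numProcessBatchWorkers > 1.
            return SimpleBatch(Batch.State.COMPLETE)
        }
    }
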
@@ -23,7 +23,7 @@ class MultiProducerChannelTest {
 
     @BeforeEach
     fun setup() {
-        channel = MultiProducerChannel(size, wrapped)
+        channel = MultiProducerChannel(size, wrapped, "test")
     }
 
     @Test

@@ -19,6 +19,7 @@ import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.DefaultDestinationTaskLauncher
 import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
 import io.airbyte.cdk.load.util.write
+import io.airbyte.cdk.load.write.BatchAccumulator
 import io.airbyte.cdk.load.write.StreamLoader
 import io.mockk.coEvery
 import io.mockk.coVerify
@@ -36,6 +37,7 @@ class ProcessRecordsTaskTest {
     private lateinit var diskManager: ReservationManager
     private lateinit var deserializer: Deserializer<DestinationMessage>
     private lateinit var streamLoader: StreamLoader
+    private lateinit var batchAccumulator: BatchAccumulator
     private lateinit var inputQueue: MessageQueue<FileAggregateMessage>
     private lateinit var processRecordsTaskFactory: DefaultProcessRecordsTaskFactory
     private lateinit var launcher: DefaultDestinationTaskLauncher
@@ -49,7 +51,9 @@
         outputQueue = mockk(relaxed = true)
         syncManager = mockk(relaxed = true)
         streamLoader = mockk(relaxed = true)
+        batchAccumulator = mockk(relaxed = true)
         coEvery { syncManager.getOrAwaitStreamLoader(any()) } returns streamLoader
+        coEvery { streamLoader.createBatchAccumulator() } returns batchAccumulator
         launcher = mockk(relaxed = true)
         deserializer = mockk(relaxed = true)
         coEvery { deserializer.deserialize(any()) } answers
@@ -106,7 +110,7 @@
             files.map { FileAggregateMessage(descriptor, it) }.asFlow()
 
         // Process records returns batches in 3 states.
-        coEvery { streamLoader.processRecords(any(), any()) } answers
+        coEvery { batchAccumulator.processRecords(any(), any()) } answers
             {
                 MockBatch(
                     groupId = null,

@@ -83,6 +83,7 @@ class SpillToDiskTaskTest {
                 MockDestinationCatalogFactory.stream1.descriptor,
                 diskManager,
                 taskLauncher,
+                false,
             )
         }
 
@@ -183,6 +184,7 @@
         diskManager = ReservationManager(Fixtures.INITIAL_DISK_CAPACITY)
         spillToDiskTaskFactory =
             DefaultSpillToDiskTaskFactory(
+                MockDestinationConfiguration(),
                 fileAccumulatorFactory,
                 queueSupplier,
                 MockFlushStrategy(),

@@ -9,4 +9,4 @@ airbyte:
     rate-ms: 900000 # 15 minutes
     window-ms: 900000 # 15 minutes
   destination:
-    record-batch-size: 1 # 1 byte for testing; 1 record => 1 upload
+    record-batch-size-override: 1 # 1 byte for testing; 1 record => 1 upload