Load CDK: S3V2 processes in processRecords, uploads in processBatch (#…
johnny-schmidt authored Dec 18, 2024
1 parent 8ba7ecb commit 6118419
Showing 56 changed files with 1,188 additions and 205 deletions.
@@ -9,4 +9,4 @@ airbyte:
rate-ms: 900000 # 15 minutes
window-ms: 900000 # 15 minutes
destination:
record-batch-size: ${AIRBYTE_DESTINATION_RECORD_BATCH_SIZE:209715200}
record-batch-size-override: ${AIRBYTE_DESTINATION_RECORD_BATCH_SIZE_OVERRIDE:null}
@@ -73,7 +73,8 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {

override suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long
totalSizeBytes: Long,
endOfStream: Boolean
): Batch {
return LocalBatch(records.asSequence().toList())
}
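The signature change above threads an `endOfStream` flag through `processRecords`. A minimal standalone sketch (illustrative types, not CDK code) of why a processor needs it: a buffering processor can only flush its final, undersized buffer once it is told the stream has ended.

```kotlin
// Standalone sketch, not CDK code: shows why the extra endOfStream flag matters.
// The processor only emits a "complete" batch when it has buffered enough bytes
// or the stream has ended; without the flag a final partial buffer would never flush.
data class SimpleBatch(val records: List<String>, val complete: Boolean)

class BufferingProcessor(private val flushThresholdBytes: Long) {
    private val buffer = mutableListOf<String>()
    private var bufferedBytes = 0L

    fun processRecords(
        records: Iterator<String>,
        totalSizeBytes: Long,
        endOfStream: Boolean
    ): SimpleBatch {
        records.forEach { buffer.add(it) }
        bufferedBytes += totalSizeBytes
        // Flush when enough data has accumulated, or when the stream is done.
        return if (bufferedBytes >= flushThresholdBytes || endOfStream) {
            val batch = SimpleBatch(buffer.toList(), complete = true)
            buffer.clear()
            bufferedBytes = 0
            batch
        } else {
            SimpleBatch(emptyList(), complete = false)
        }
    }
}
```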
@@ -61,7 +61,8 @@ import java.nio.file.Path
* ```
*/
abstract class DestinationConfiguration : Configuration {
open val recordBatchSizeBytes: Long = 200L * 1024L * 1024L
open val recordBatchSizeBytes: Long = DEFAULT_RECORD_BATCH_SIZE_BYTES
open val processEmptyFiles: Boolean = false
open val tmpFileDirectory: Path = Path.of("airbyte-cdk-load")

/** Memory queue settings */
@@ -88,6 +89,10 @@ abstract class DestinationConfiguration : Configuration {
open val numProcessBatchWorkers: Int = 5
open val batchQueueDepth: Int = 10

companion object {
const val DEFAULT_RECORD_BATCH_SIZE_BYTES = 200L * 1024L * 1024L
}

/**
* Micronaut factory which glues [ConfigurationSpecificationSupplier] and
* [DestinationConfigurationFactory] together to produce a [DestinationConfiguration] singleton.
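The default batch size now lives on a companion constant instead of being inlined in the `open val`, which pairs with the `record-batch-size-override: null` config change above. A standalone sketch of the pattern (the classes below are illustrative, not the CDK's): a nullable override falls back to the shared default.

```kotlin
// Illustrative sketch, not CDK code: the shared default is a companion constant,
// connectors override the open val, and a null override falls back to the default.
abstract class ConfigSketch {
    open val recordBatchSizeBytes: Long = DEFAULT_RECORD_BATCH_SIZE_BYTES

    companion object {
        const val DEFAULT_RECORD_BATCH_SIZE_BYTES = 200L * 1024L * 1024L
    }
}

class MyDestinationConfig(batchSizeOverride: Long? = null) : ConfigSketch() {
    override val recordBatchSizeBytes: Long =
        batchSizeOverride ?: ConfigSketch.DEFAULT_RECORD_BATCH_SIZE_BYTES
}
```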
@@ -68,7 +68,7 @@ class SyncBeanFactory {
val capacity = min(maxBatchesMinusUploadOverhead, idealDepth)
log.info { "Creating file aggregate queue with limit $capacity" }
val channel = Channel<FileAggregateMessage>(capacity)
return MultiProducerChannel(streamCount.toLong(), channel)
return MultiProducerChannel(streamCount.toLong(), channel, "fileAggregateQueue")
}

@Singleton
@@ -77,6 +77,6 @@
config: DestinationConfiguration,
): MultiProducerChannel<BatchEnvelope<*>> {
val channel = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel, "batchQueue")
}
}
@@ -20,6 +20,6 @@ class DefaultSpillFileProvider(val config: DestinationConfiguration) : SpillFile
override fun createTempFile(): Path {
val directory = config.tmpFileDirectory
Files.createDirectories(directory)
return Files.createTempFile(directory, "staged-raw-records", "jsonl")
return Files.createTempFile(directory, "staged-raw-records", ".jsonl")
}
}
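The one-character fix above matters because `Files.createTempFile` appends the suffix verbatim, so without the dot the spill files had no real extension. A quick check:

```kotlin
// Files.createTempFile(dir, prefix, suffix) concatenates prefix + random + suffix,
// so "jsonl" produces names like staged-raw-records8273645jsonl with no extension,
// while ".jsonl" yields staged-raw-records8273645.jsonl.
import java.nio.file.Files
import java.nio.file.Path

fun main() {
    val dir: Path = Files.createTempDirectory("demo")
    val noDot = Files.createTempFile(dir, "staged-raw-records", "jsonl")
    val withDot = Files.createTempFile(dir, "staged-raw-records", ".jsonl")
    println(noDot.fileName)   // e.g. staged-raw-records8273645jsonl
    println(withDot.fileName) // e.g. staged-raw-records8273645.jsonl
}
```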
@@ -54,6 +54,7 @@ interface Batch {
val groupId: String?

enum class State {
PROCESSED,
LOCAL,
PERSISTED,
COMPLETE
@@ -93,11 +94,11 @@ data class BatchEnvelope<B : Batch>(
) {
constructor(
batch: B,
range: Range<Long>,
range: Range<Long>?,
streamDescriptor: DestinationStream.Descriptor
) : this(
batch = batch,
ranges = TreeRangeSet.create(listOf(range)),
ranges = range?.let { TreeRangeSet.create(listOf(range)) } ?: TreeRangeSet.create(),
streamDescriptor = streamDescriptor
)
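The convenience constructor now accepts a nullable range; a null range yields an empty range set instead of failing. A small sketch of that mapping using the same Guava types:

```kotlin
// Sketch of the nullable-range handling added to the BatchEnvelope convenience
// constructor: a null range now maps to an empty RangeSet rather than being rejected.
import com.google.common.collect.Range
import com.google.common.collect.TreeRangeSet

fun rangesFor(range: Range<Long>?): TreeRangeSet<Long> =
    range?.let { TreeRangeSet.create(listOf(it)) } ?: TreeRangeSet.create()

fun main() {
    println(rangesFor(Range.closed(0L, 9L))) // a set containing [0..9]
    println(rangesFor(null))                 // an empty range set
}
```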

@@ -15,6 +15,7 @@ import kotlinx.coroutines.channels.Channel
class MultiProducerChannel<T>(
producerCount: Long,
override val channel: Channel<T>,
private val name: String,
) : ChannelMessageQueue<T>() {
private val log = KotlinLogging.logger {}
private val initializedProducerCount = producerCount
@@ -23,7 +24,7 @@ class MultiProducerChannel<T>(
override suspend fun close() {
val count = producerCount.decrementAndGet()
log.info {
"Closing producer (active count=$count, initialized count: $initializedProducerCount)"
"Closing producer $name (active count=$count, initialized count: $initializedProducerCount)"
}
if (count == 0L) {
log.info { "Closing underlying queue" }
@@ -150,18 +150,13 @@ class DefaultStreamManager(
}

override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {

rangesState[batch.batch.state]
?: throw IllegalArgumentException("Invalid batch state: ${batch.batch.state}")

// If the batch is part of a group, update all ranges associated with its groupId
// to the most advanced state. Otherwise, just use the ranges provided.
val cachedRangesMaybe = batch.batch.groupId?.let { cachedRangesById[batch.batch.groupId] }

log.info {
"Updating state for stream ${stream.descriptor} with batch $batch using cached ranges $cachedRangesMaybe"
}

val stateToSet =
cachedRangesMaybe?.state?.let { maxOf(it, batch.batch.state) } ?: batch.batch.state
val rangesToUpdate = TreeRangeSet.create(batch.ranges)
@@ -178,24 +173,37 @@
rangesToUpdate.asRanges().map { it.span(Range.singleton(it.upperEndpoint() + 1)) }

when (stateToSet) {
Batch.State.PERSISTED -> {
rangesState[Batch.State.PERSISTED]?.addAll(expanded)
}
Batch.State.COMPLETE -> {
// A COMPLETED state implies PERSISTED, so also mark PERSISTED.
rangesState[Batch.State.PERSISTED]?.addAll(expanded)
rangesState[Batch.State.COMPLETE]?.addAll(expanded)
}
else -> Unit
}

log.info {
"Updated ranges for ${stream.descriptor}[${batch.batch.state}]: $expanded. PERSISTED is also updated on COMPLETE."
else -> {
// For all other states, just mark the state.
rangesState[stateToSet]?.addAll(expanded)
}
}

batch.batch.groupId?.also {
cachedRangesById[it] = CachedRanges(stateToSet, rangesToUpdate)
}

log.info {
val groupLineMaybe =
if (cachedRangesMaybe != null) {
"\n (from group: ${cachedRangesMaybe.state}->${cachedRangesMaybe.ranges})\n"
} else {
""
}
""" For stream ${stream.descriptor.namespace}.${stream.descriptor.name}
From batch ${batch.batch.state}->${batch.ranges} (groupId ${batch.batch.groupId})$groupLineMaybe
Added $stateToSet->$rangesToUpdate to ${stream.descriptor.namespace}.${stream.descriptor.name}
PROCESSED: ${rangesState[Batch.State.PROCESSED]}
LOCAL: ${rangesState[Batch.State.LOCAL]}
PERSISTED: ${rangesState[Batch.State.PERSISTED]}
COMPLETE: ${rangesState[Batch.State.COMPLETE]}
""".trimIndent()
}
}

/** True if all records in `[0, index)` have reached the given state. */
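The rewritten `updateBatchState` records the more advanced of the cached group state and the incoming batch state via `maxOf`, which relies on the `State` enum's declaration order (`PROCESSED` < `LOCAL` < `PERSISTED` < `COMPLETE`); `COMPLETE` additionally marks the `PERSISTED` ranges. A standalone sketch of that precedence rule:

```kotlin
// Standalone sketch of the state-precedence rule: maxOf on an enum compares by
// declaration order, so states declared later count as "more advanced".
enum class State { PROCESSED, LOCAL, PERSISTED, COMPLETE }

fun main() {
    val cachedGroupState = State.PERSISTED
    val incomingBatchState = State.LOCAL
    // The group keeps the most advanced state seen so far.
    val stateToSet = maxOf(cachedGroupState, incomingBatchState)
    println(stateToSet) // PERSISTED

    // COMPLETE implies PERSISTED, so both range sets are updated in that case.
    val markPersistedToo = stateToSet == State.COMPLETE
    println(markPersistedToo) // false
}
```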
@@ -189,6 +189,7 @@ class DefaultDestinationTaskLauncher(
val setupTask = setupTaskFactory.make(this)
enqueue(setupTask)

// TODO: pluggable file transfer
if (!fileTransferEnabled) {
// Start a spill-to-disk task for each record stream
catalog.streams.forEach { stream ->
@@ -264,16 +265,12 @@
}

if (streamManager.isBatchProcessingComplete()) {
log.info {
"Batch $wrapped complete and batch processing complete: Starting close stream task for $stream"
}
log.info { "Batch processing complete: Starting close stream task for $stream" }

val task = closeStreamTaskFactory.make(this, stream)
enqueue(task)
} else {
log.info {
"Batch $wrapped complete, but batch processing not complete: nothing else to do."
}
log.info { "Batch processing not complete: nothing else to do." }
}
}
}
@@ -8,13 +8,13 @@ import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.write.StreamLoader
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton

interface ProcessBatchTask : ImplementorScope
interface ProcessBatchTask : KillableScope

/** Wraps @[StreamLoader.processBatch] and handles the resulting batch. */
class DefaultProcessBatchTask(
@@ -4,7 +4,10 @@

package io.airbyte.cdk.load.task.implementor

import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.Deserializer
import io.airbyte.cdk.load.message.DestinationMessage
@@ -21,12 +24,14 @@ import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.lineSequence
import io.airbyte.cdk.load.util.use
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap
import kotlin.io.path.inputStream

interface ProcessRecordsTask : KillableScope
@@ -40,25 +45,37 @@ interface ProcessRecordsTask : KillableScope
* moved to the task launcher.
*/
class DefaultProcessRecordsTask(
private val config: DestinationConfiguration,
private val taskLauncher: DestinationTaskLauncher,
private val deserializer: Deserializer<DestinationMessage>,
private val syncManager: SyncManager,
private val diskManager: ReservationManager,
private val inputQueue: MessageQueue<FileAggregateMessage>,
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTask {
private val log = KotlinLogging.logger {}
private val accumulators = ConcurrentHashMap<DestinationStream.Descriptor, BatchAccumulator>()
override suspend fun execute() {
outputQueue.use {
inputQueue.consume().collect { (streamDescriptor, file) ->
log.info { "Fetching stream loader for $streamDescriptor" }
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
val acc =
accumulators.getOrPut(streamDescriptor) {
streamLoader.createBatchAccumulator()
}
log.info { "Processing records from $file for stream $streamDescriptor" }
val batch =
try {
file.localFile.inputStream().use { inputStream ->
val records = inputStream.toRecordIterator()
val batch = streamLoader.processRecords(records, file.totalSizeBytes)
file.localFile.inputStream().use {
val records =
if (file.isEmpty) {
emptyList<DestinationRecord>().listIterator()
} else {
it.toRecordIterator()
}
val batch =
acc.processRecords(records, file.totalSizeBytes, file.endOfStream)
log.info { "Finished processing $file" }
batch
}
@@ -67,19 +84,35 @@ class DefaultProcessRecordsTask(
file.localFile.toFile().delete()
diskManager.release(file.totalSizeBytes)
}

val wrapped = BatchEnvelope(batch, file.indexRange, streamDescriptor)
log.info { "Updating batch $wrapped for $streamDescriptor" }
taskLauncher.handleNewBatch(streamDescriptor, wrapped)
if (batch.requiresProcessing) {
outputQueue.publish(wrapped)
} else {
log.info { "Batch $wrapped requires no further processing." }
handleBatch(streamDescriptor, batch, file.indexRange)
}
if (config.processEmptyFiles) {
// TODO: Get rid of the need to handle empty files please
log.info { "Forcing finalization of all accumulators." }
accumulators.forEach { (streamDescriptor, acc) ->
val finalBatch =
acc.processRecords(emptyList<DestinationRecord>().listIterator(), 0, true)
handleBatch(streamDescriptor, finalBatch, null)
}
}
}
}
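`DefaultProcessRecordsTask` now keeps one `BatchAccumulator` per stream, created lazily with `getOrPut`, and, when `processEmptyFiles` is set, forces a final empty `endOfStream` call through each of them so nothing stays buffered. A standalone sketch of that per-stream accumulation pattern (hypothetical types, not the CDK interfaces):

```kotlin
// Standalone sketch (hypothetical types) of the per-stream accumulator pattern:
// an accumulator is created once per stream descriptor and reused for every file
// aggregate of that stream, then force-finalized at the end of input.
import java.util.concurrent.ConcurrentHashMap

data class StreamKey(val namespace: String?, val name: String)

class Accumulator {
    private val buffered = mutableListOf<String>()
    fun processRecords(records: List<String>, endOfStream: Boolean): List<String>? {
        buffered += records
        // Only emit a batch once the stream (or the whole input) is finished.
        return if (endOfStream) buffered.toList().also { buffered.clear() } else null
    }
}

class PerStreamAccumulators {
    private val accumulators = ConcurrentHashMap<StreamKey, Accumulator>()

    fun accumulate(stream: StreamKey, records: List<String>, endOfStream: Boolean): List<String>? =
        accumulators.getOrPut(stream) { Accumulator() }.processRecords(records, endOfStream)

    // Mirrors the "force finalization" step: push an empty end-of-stream call
    // through every accumulator so any remaining buffered records are emitted.
    fun finalizeAll(): Map<StreamKey, List<String>> =
        accumulators.mapValues { (_, acc) ->
            acc.processRecords(emptyList(), endOfStream = true) ?: emptyList()
        }
}
```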

private suspend fun handleBatch(
streamDescriptor: DestinationStream.Descriptor,
batch: Batch,
indexRange: Range<Long>?
) {
val wrapped = BatchEnvelope(batch, indexRange, streamDescriptor)
taskLauncher.handleNewBatch(streamDescriptor, wrapped)
log.info { "Updating batch $wrapped for $streamDescriptor" }
if (batch.requiresProcessing) {
outputQueue.publish(wrapped)
} else {
log.info { "Batch $wrapped requires no further processing." }
}
}

private fun InputStream.toRecordIterator(): Iterator<DestinationRecord> {
return lineSequence()
.map {
@@ -113,16 +146,19 @@ data class FileAggregateMessage(
@Singleton
@Secondary
class DefaultProcessRecordsTaskFactory(
private val config: DestinationConfiguration,
private val deserializer: Deserializer<DestinationMessage>,
private val syncManager: SyncManager,
@Named("diskManager") private val diskManager: ReservationManager,
@Named("fileAggregateQueue") private val inputQueue: MessageQueue<FileAggregateMessage>,
@Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTaskFactory {

override fun make(
taskLauncher: DestinationTaskLauncher,
): ProcessRecordsTask {
return DefaultProcessRecordsTask(
config,
taskLauncher,
deserializer,
syncManager,
@@ -84,12 +84,14 @@ class DefaultInputConsumerTask(
is DestinationRecordStreamComplete -> {
reserved.release() // safe because multiple calls conflate
val wrapped = StreamEndEvent(index = manager.markEndOfStream(true))
log.info { "Read COMPLETE for stream $stream" }
recordQueue.publish(reserved.replace(wrapped))
recordQueue.close()
}
is DestinationRecordStreamIncomplete -> {
reserved.release() // safe because multiple calls conflate
val wrapped = StreamEndEvent(index = manager.markEndOfStream(false))
log.info { "Read INCOMPLETE for stream $stream" }
recordQueue.publish(reserved.replace(wrapped))
recordQueue.close()
}