Prerelease test split processing
johnny-schmidt committed Dec 16, 2024
1 parent ccf187b commit 22fb53f
Showing 33 changed files with 951 additions and 119 deletions.
@@ -73,7 +73,8 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {

override suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long
totalSizeBytes: Long,
endOfStream: Boolean
): Batch {
return LocalBatch(records.asSequence().toList())
}
@@ -61,13 +61,14 @@ import java.nio.file.Path
* ```
*/
abstract class DestinationConfiguration : Configuration {
open val recordBatchSizeBytes: Long = 200L * 1024L * 1024L
open val recordBatchSizeBytes: Long = 20L * 1024L * 1024L
open val tmpFileDirectory: Path = Path.of("airbyte-cdk-load")

/** Memory queue settings */
open val maxMessageQueueMemoryUsageRatio: Double = 0.2 // 0 => No limit, 1.0 => 100% of JVM heap
open val estimatedRecordMemoryOverheadRatio: Double =
1.1 // 1.0 => No overhead, 2.0 => 100% overhead
open val messageQueuePerStreamLimit: Int = 100

/**
* If we have not flushed state checkpoints in this amount of time, make a best-effort attempt
@@ -85,7 +86,7 @@ abstract class DestinationConfiguration : Configuration {
open val gracefulCancellationTimeoutMs: Long = 60 * 1000L // 1 minute

open val numProcessRecordsWorkers: Int = 2
open val numProcessBatchWorkers: Int = 5
open val numProcessBatchWorkers: Int = 10
open val batchQueueDepth: Int = 10

/**
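The notable changes above are the smaller default spill size (20 MiB instead of 200 MiB), the new per-stream queue limit, and doubling the batch workers. For orientation, here is a minimal sketch of how a destination might override these knobs; the class name and chosen values are illustrative, not taken from this commit, and any abstract members the real base class requires are omitted.

import io.airbyte.cdk.load.command.DestinationConfiguration
import java.nio.file.Path

// Hypothetical destination configuration; only the tuning knobs touched by
// this commit are overridden.
class ExampleDestinationConfiguration : DestinationConfiguration() {
    // Spill ~20 MiB of raw records to disk before handing them to processRecords.
    override val recordBatchSizeBytes: Long = 20L * 1024L * 1024L
    override val tmpFileDirectory: Path = Path.of("airbyte-cdk-load")

    // New knob: cap the number of buffered messages per stream queue.
    override val messageQueuePerStreamLimit: Int = 100

    // Two workers build batches from spilled files; ten workers finish batches.
    override val numProcessRecordsWorkers: Int = 2
    override val numProcessBatchWorkers: Int = 10
    override val batchQueueDepth: Int = 10
}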
@@ -68,7 +68,7 @@ class SyncBeanFactory {
val capacity = min(maxBatchesMinusUploadOverhead, idealDepth)
log.info { "Creating file aggregate queue with limit $capacity" }
val channel = Channel<FileAggregateMessage>(capacity)
return MultiProducerChannel(streamCount.toLong(), channel)
return MultiProducerChannel(streamCount.toLong(), channel, "fileAggregateQueue")
}

@Singleton
@@ -77,6 +77,6 @@
config: DestinationConfiguration,
): MultiProducerChannel<BatchEnvelope<*>> {
val channel = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel, "batchQueue")
}
}
Expand Up @@ -20,6 +20,6 @@ class DefaultSpillFileProvider(val config: DestinationConfiguration) : SpillFile
override fun createTempFile(): Path {
val directory = config.tmpFileDirectory
Files.createDirectories(directory)
return Files.createTempFile(directory, "staged-raw-records", "jsonl")
return Files.createTempFile(directory, "staged-raw-records", ".jsonl")
}
}
@@ -5,12 +5,14 @@
package io.airbyte.cdk.load.message

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.Reserved
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import kotlinx.coroutines.channels.Channel

interface Sized {
val sizeBytes: Long
@@ -61,7 +63,8 @@ class DestinationStreamEventQueue : ChannelMessageQueue<Reserved<DestinationStre
*/
@Singleton
@Secondary
class DestinationStreamQueueSupplier(catalog: DestinationCatalog) :
class DestinationStreamQueueSupplier(catalog: DestinationCatalog
) :
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>> {
private val queues =
ConcurrentHashMap<DestinationStream.Descriptor, DestinationStreamEventQueue>()
@@ -15,6 +15,7 @@ import kotlinx.coroutines.channels.Channel
class MultiProducerChannel<T>(
producerCount: Long,
override val channel: Channel<T>,
private val name: String,
) : ChannelMessageQueue<T>() {
private val log = KotlinLogging.logger {}
private val initializedProducerCount = producerCount
@@ -23,7 +24,7 @@ class MultiProducerChannel<T>(
override suspend fun close() {
val count = producerCount.decrementAndGet()
log.info {
"Closing producer (active count=$count, initialized count: $initializedProducerCount)"
"Closing producer $name (active count=$count, initialized count: $initializedProducerCount)"
}
if (count == 0L) {
log.info { "Closing underlying queue" }
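The added name only changes the log line above; the producer-counting close semantics are unchanged. A small sketch of those semantics, using the package paths visible in this diff (the payload type and capacity are arbitrary):

import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Two producers share one named channel. Each close() decrements the count;
    // the underlying Channel is closed only when the count reaches zero.
    val batchQueue =
        MultiProducerChannel<BatchEnvelope<*>>(
            producerCount = 2L,
            channel = Channel(10),
            name = "batchQueue",
        )

    batchQueue.close() // logs: Closing producer batchQueue (active count=1, ...)
    batchQueue.close() // count hits 0 -> "Closing underlying queue"
}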
@@ -189,6 +189,7 @@ class DefaultDestinationTaskLauncher(
val setupTask = setupTaskFactory.make(this)
enqueue(setupTask)

// TODO: pluggable file transfer
if (!fileTransferEnabled) {
// Start a spill-to-disk task for each record stream
catalog.streams.forEach { stream ->
@@ -264,16 +265,12 @@
}

if (streamManager.isBatchProcessingComplete()) {
log.info {
"Batch $wrapped complete and batch processing complete: Starting close stream task for $stream"
}
log.info { "Batch processing complete: Starting close stream task for $stream" }

val task = closeStreamTaskFactory.make(this, stream)
enqueue(task)
} else {
log.info {
"Batch $wrapped complete, but batch processing not complete: nothing else to do."
}
log.info { "Batch processing not complete: nothing else to do." }
}
}
}
@@ -8,13 +8,13 @@ import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.write.StreamLoader
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton

interface ProcessBatchTask : ImplementorScope
interface ProcessBatchTask : KillableScope

/** Wraps @[StreamLoader.processBatch] and handles the resulting batch. */
class DefaultProcessBatchTask(
@@ -21,12 +21,14 @@ import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.lineSequence
import io.airbyte.cdk.load.util.use
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap
import kotlin.io.path.inputStream

interface ProcessRecordsTask : KillableScope
@@ -45,20 +47,25 @@ class DefaultProcessRecordsTask(
private val syncManager: SyncManager,
private val diskManager: ReservationManager,
private val inputQueue: MessageQueue<FileAggregateMessage>,
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTask {
private val log = KotlinLogging.logger {}
private val accumulators = ConcurrentHashMap<DestinationStream.Descriptor, BatchAccumulator>()
override suspend fun execute() {
outputQueue.use {
inputQueue.consume().collect { (streamDescriptor, file) ->
log.info { "Fetching stream loader for $streamDescriptor" }
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
val acc =
accumulators.getOrPut(streamDescriptor) {
streamLoader.createBatchAccumulator()
}
log.info { "Processing records from $file for stream $streamDescriptor" }
val batch =
try {
file.localFile.inputStream().use { inputStream ->
val records = inputStream.toRecordIterator()
val batch = streamLoader.processRecords(records, file.totalSizeBytes)
val batch = acc.processRecords(records, file.totalSizeBytes)
log.info { "Finished processing $file" }
batch
}
@@ -119,6 +126,7 @@ class DefaultProcessRecordsTaskFactory(
@Named("fileAggregateQueue") private val inputQueue: MessageQueue<FileAggregateMessage>,
@Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTaskFactory {

override fun make(
taskLauncher: DestinationTaskLauncher,
): ProcessRecordsTask {
@@ -12,32 +12,55 @@ import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed

/**
* Implementor interface. The framework calls open and close once per stream at the beginning and
* end of processing. The framework calls processRecords once per batch of records as batches of the
* configured size become available. (Specified in @
* [io.airbyte.cdk.command.WriteConfiguration.recordBatchSizeBytes])
* Implementor interface.
*
* [start] is called once before any records are processed.
*
* [processRecords] is called whenever a batch of records is available for processing, and only
* after [start] has returned successfully. The return value is a client-defined implementation of @
* [Batch] that the framework may pass to [processBatch] and/or [finalize]. (See @[Batch] for more
* details.)
* [processRecords] is called whenever a batch of records is available for processing (of the size
* configured in [io.airbyte.cdk.load.command.DestinationConfiguration.recordBatchSizeBytes]) and
* only after [start] has returned successfully. The return value is a client-defined implementation
* of @ [Batch] that the framework may pass to [processBatch]. (See @[Batch] for more details.)
*
* [processRecords] may be called concurrently by multiple workers, so it should be thread-safe if
* [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessRecordsWorkers] > 1. For a
* non-thread-safe alternative, use [createBatchAccumulator].
*
* [createBatchAccumulator] returns an optional new instance of a [BatchAccumulator] to use for
* record processing instead of this stream loader. By default, it returns a reference to the stream
* loader itself. Use this interface if you want each record processing worker to use a separate
* instance (with its own state, etc).
*
* [processBatch] is called once per incomplete batch returned by either [processRecords] or
* [processBatch] itself.
* [processBatch] itself. It must be thread-safe if
* [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessBatchWorkers] > 1. If
* [processRecords] never returns a non-[Batch.State.COMPLETE] batch, [processBatch] will never be
* called.
*
* [finalize] is called once after all records and batches have been processed successfully.
* NOTE: even if [processBatch] returns a not-[Batch.State.COMPLETE] batch, it will be called again.
* TODO: allow the client to specify subsequent processing stages instead.
*
* [close] is called once after all records have been processed, regardless of success or failure.
* If there are failed batches, they are passed in as an argument.
* [close] is called once after all records have been processed, regardless of success or failure,
* but only if [start] returned successfully. If any exception was thrown during processing, it is
* passed as an argument to [close].
*/
interface StreamLoader {
interface StreamLoader : BatchAccumulator {
val stream: DestinationStream

suspend fun start() {}
suspend fun processRecords(records: Iterator<DestinationRecord>, totalSizeBytes: Long): Batch
suspend fun createBatchAccumulator(): BatchAccumulator = this

suspend fun processFile(file: DestinationFile): Batch
suspend fun processBatch(batch: Batch): Batch = SimpleBatch(Batch.State.COMPLETE)
suspend fun close(streamFailure: StreamProcessingFailed? = null) {}
}

interface BatchAccumulator {
suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long,
endOfStream: Boolean = false
): Batch =
throw NotImplementedError(
"processRecords must be implemented if createBatchAccumulator is overridden"
)
}
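To make the new split concrete, here is a rough sketch of a loader that hands each record-processing worker its own accumulator. The class names, the chosen intermediate batch state, and the method bodies are illustrative only; the import paths follow the package layout visible elsewhere in this diff and should be treated as approximate.

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader

class ExampleStreamLoader(override val stream: DestinationStream) : StreamLoader {
    override suspend fun start() {
        // e.g. create the target table / staging location
    }

    // Each record-processing worker calls this once and reuses the instance,
    // so the accumulator itself does not need to be thread-safe.
    override suspend fun createBatchAccumulator(): BatchAccumulator = ExampleAccumulator()

    override suspend fun processFile(file: DestinationFile): Batch =
        SimpleBatch(Batch.State.COMPLETE) // file-transfer path, not exercised here

    override suspend fun processBatch(batch: Batch): Batch =
        SimpleBatch(Batch.State.COMPLETE) // e.g. load the staged batch, then mark complete

    override suspend fun close(streamFailure: StreamProcessingFailed?) {
        // cleanup, whether or not the stream failed
    }
}

private class ExampleAccumulator : BatchAccumulator {
    override suspend fun processRecords(
        records: Iterator<DestinationRecord>,
        totalSizeBytes: Long,
        endOfStream: Boolean
    ): Batch {
        records.forEach { /* stage it somewhere */ }
        // Returning a non-COMPLETE state (PERSISTED is assumed here) routes the
        // batch to processBatch on a batch worker.
        return SimpleBatch(Batch.State.PERSISTED)
    }
}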
@@ -19,6 +19,7 @@ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DefaultDestinationTaskLauncher
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.write
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
import io.mockk.coVerify
@@ -36,6 +37,7 @@ class ProcessRecordsTaskTest {
private lateinit var diskManager: ReservationManager
private lateinit var deserializer: Deserializer<DestinationMessage>
private lateinit var streamLoader: StreamLoader
private lateinit var batchAccumulator: BatchAccumulator
private lateinit var inputQueue: MessageQueue<FileAggregateMessage>
private lateinit var processRecordsTaskFactory: DefaultProcessRecordsTaskFactory
private lateinit var launcher: DefaultDestinationTaskLauncher
@@ -49,7 +51,9 @@ class ProcessRecordsTaskTest {
outputQueue = mockk(relaxed = true)
syncManager = mockk(relaxed = true)
streamLoader = mockk(relaxed = true)
batchAccumulator = mockk(relaxed = true)
coEvery { syncManager.getOrAwaitStreamLoader(any()) } returns streamLoader
coEvery { streamLoader.createBatchAccumulator() } returns batchAccumulator
launcher = mockk(relaxed = true)
deserializer = mockk(relaxed = true)
coEvery { deserializer.deserialize(any()) } answers
@@ -106,7 +110,7 @@
files.map { FileAggregateMessage(descriptor, it) }.asFlow()

// Process records returns batches in 3 states.
coEvery { streamLoader.processRecords(any(), any()) } answers
coEvery { batchAccumulator.processRecords(any(), any()) } answers
{
MockBatch(
groupId = null,
@@ -38,6 +38,16 @@ interface ObjectStorageClient<T : RemoteObject<*>> {
}

interface StreamingUpload<T : RemoteObject<*>> {
suspend fun uploadPart(part: ByteArray)
/**
* Uploads a part of the object. Each part must have a unique index. The parts do not need to be
* uploaded in order. The index is 1-based.
*/
suspend fun uploadPart(part: ByteArray, index: Int)

/**
* Completes a multipart upload. All parts must be uploaded before completing the upload, and
* there cannot be gaps in the indexes. Multiple calls will return the same object, but only the
* first call will have side effects.
*/
suspend fun complete(): T
}
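A short usage sketch for the new indexed contract. The helper below is an assumption for illustration, not part of this interface; only the kotlinx.coroutines imports are shown because this file's own package is not visible in the diff.

import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch

// Illustrative helper: parts may upload concurrently and out of order, but the
// indexes must be 1-based and gap-free before complete() is called.
suspend fun <T : RemoteObject<*>> uploadInParts(
    upload: StreamingUpload<T>,
    parts: List<ByteArray>,
): T {
    coroutineScope {
        parts.forEachIndexed { i, part ->
            launch { upload.uploadPart(part, index = i + 1) } // 1-based index
        }
    } // all uploadPart coroutines have finished here
    // complete() is safe to call more than once; only the first call finalizes.
    return upload.complete()
}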