Load CDK: S3V2 processes in processRecords, uploads in processBatch
johnny-schmidt committed Dec 14, 2024
1 parent cc06fc5 commit 2a676d6
Showing 20 changed files with 885 additions and 109 deletions.
@@ -71,7 +71,8 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {

override suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long
totalSizeBytes: Long,
endOfStream: Boolean
): Batch {
return LocalBatch(records.asSequence().toList())
}
@@ -4,6 +4,7 @@

package io.airbyte.cdk.load.config

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
@@ -16,7 +17,6 @@ import jakarta.inject.Named
import jakarta.inject.Singleton
import kotlin.math.min
import kotlinx.coroutines.channels.Channel
import io.airbyte.cdk.load.command.DestinationCatalog

/** Factory for instantiating beans necessary for the sync process. */
@Factory
@@ -5,13 +5,10 @@
package io.airbyte.cdk.load.message

import io.github.oshai.kotlinlogging.KotlinLogging
import java.lang.IllegalStateException
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.channels.Channel

/**
<<<<<<< HEAD
* A channel designed for use with a fixed amount of producers. Close will be called on the
* underlying channel, when there are no remaining registered producers.
*/
@@ -189,7 +189,12 @@ class DefaultDestinationTaskLauncher(
val setupTask = setupTaskFactory.make(this)
enqueue(setupTask)

// TODO: pluggable file transfer
if (!fileTransferEnabled) {
// TODO: Close the task queues as part of shutdown
// so that it is not necessary to initialize
// every task before enqueueing.

// Start a spill-to-disk task for each record stream
catalog.streams.forEach { stream ->
log.info { "Starting spill-to-disk task for $stream" }
@@ -210,18 +215,6 @@
}
}

repeat(config.numProcessRecordsWorkers) {
log.info { "Launching process records task $it" }
val task = processRecordsTaskFactory.make(this)
enqueue(task)
}

repeat(config.numProcessBatchWorkers) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
}

// Start flush task
log.info { "Starting timed file aggregate flush task " }
enqueue(flushTickTask)
@@ -5,7 +5,6 @@
package io.airbyte.cdk.load.task.implementor

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.Deserializer
import io.airbyte.cdk.load.message.DestinationMessage
@@ -22,12 +21,14 @@ import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.lineSequence
import io.airbyte.cdk.load.util.use
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap
import kotlin.io.path.inputStream

interface ProcessRecordsTask : KillableScope
@@ -46,20 +47,25 @@ class DefaultProcessRecordsTask(
private val syncManager: SyncManager,
private val diskManager: ReservationManager,
private val inputQueue: MessageQueue<FileAggregateMessage>,
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTask {
private val log = KotlinLogging.logger {}
private val accumulators = ConcurrentHashMap<DestinationStream.Descriptor, BatchAccumulator>()
override suspend fun execute() {
outputQueue.use {
inputQueue.consume().collect { (streamDescriptor, file) ->
log.info { "Fetching stream loader for $streamDescriptor" }
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
val acc =
accumulators.getOrPut(streamDescriptor) {
streamLoader.createBatchAccumulator()
}
log.info { "Processing records from $file for stream $streamDescriptor" }
val batch =
try {
file.localFile.inputStream().use { inputStream ->
val records = inputStream.toRecordIterator()
val batch = streamLoader.processRecords(records, file.totalSizeBytes)
val batch = acc.processRecords(records, file.totalSizeBytes)
log.info { "Finished processing $file" }
batch
}
@@ -120,6 +126,7 @@ class DefaultProcessRecordsTaskFactory(
@Named("fileAggregateQueue") private val inputQueue: MessageQueue<FileAggregateMessage>,
@Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessRecordsTaskFactory {

override fun make(
taskLauncher: DestinationTaskLauncher,
): ProcessRecordsTask {
@@ -12,32 +12,52 @@ import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed

/**
* Implementor interface. The framework calls open and close once per stream at the beginning and
* end of processing. The framework calls processRecords once per batch of records as batches of the
* configured size become available. (Specified in @
* [io.airbyte.cdk.command.WriteConfiguration.recordBatchSizeBytes])
* Implementor interface.
*
* [start] is called once before any records are processed.
*
* [processRecords] is called whenever a batch of records is available for processing, and only
* after [start] has returned successfully. The return value is a client-defined implementation of @
* [Batch] that the framework may pass to [processBatch] and/or [finalize]. (See @[Batch] for more
* details.)
* [processRecords] is called whenever a batch of records is available for processing (of the size
* configured in [io.airbyte.cdk.load.command.DestinationConfiguration.recordBatchSizeBytes]) and
* only after [start] has returned successfully. The return value is a client-defined implementation
* of @ [Batch] that the framework may pass to [processBatch]. (See @[Batch] for more details.)
*
* [processBatch] is called once per incomplete batch returned by either [processRecords] or
* [processBatch] itself.
* [processRecords] may be called concurrently by multiple workers, so it should be thread-safe if
* [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessRecordsWorkers] > 1. For a
* non-thread-safe alternative, use [createBatchAccumulator].
*
* [createBatchAccumulator] returns an optional new instance of a [BatchAccumulator] to use for
* record processing instead of this stream loader. By default, it returns a reference to the stream
* loader itself. Use this interface if you want each record processing worker to use a separate
* instance (with its own state, etc).
*
* [finalize] is called once after all records and batches have been processed successfully.
* [processBatch] is called once per incomplete batch returned by either [processRecords] or
* [processBatch] itself. It must be thread-safe if
* [io.airbyte.cdk.load.command.DestinationConfiguration.numProcessBatchWorkers] > 1. If
* [processRecords] never returns a non-[Batch.State.COMPLETE] batch, [processBatch] will never be
* called.
*
* [close] is called once after all records have been processed, regardless of success or failure.
* If there are failed batches, they are passed in as an argument.
* [close] is called once after all records have been processed, regardless of success or failure,
* but only if [start] returned successfully. If any exception was thrown during processing, it is
* passed as an argument to [close].
*/
interface StreamLoader {
interface StreamLoader : BatchAccumulator {
val stream: DestinationStream

suspend fun start() {}
suspend fun processRecords(records: Iterator<DestinationRecord>, totalSizeBytes: Long): Batch
suspend fun createBatchAccumulator(): BatchAccumulator = this

suspend fun processFile(file: DestinationFile): Batch
suspend fun processBatch(batch: Batch): Batch = SimpleBatch(Batch.State.COMPLETE)
suspend fun close(streamFailure: StreamProcessingFailed? = null) {}
}

interface BatchAccumulator {
suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long,
endOfStream: Boolean = false
): Batch =
throw NotImplementedError(
"processRecords must be implemented if createBatchAccumulator is overridden"
)
}
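
To make the new contract concrete, here is a minimal, hedged sketch of a standalone BatchAccumulator. It is not part of this commit: the import paths follow the ones visible elsewhere in this diff, and the counting body is purely illustrative.

import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.write.BatchAccumulator

class CountingAccumulator : BatchAccumulator {
    // Worker-local state: no locking is needed because each process-records
    // worker obtains its own instance via StreamLoader.createBatchAccumulator().
    private var recordCount = 0L

    override suspend fun processRecords(
        records: Iterator<DestinationRecord>,
        totalSizeBytes: Long,
        endOfStream: Boolean
    ): Batch {
        records.forEach { _ -> recordCount++ }
        // Returning a COMPLETE batch tells the framework that no further
        // processBatch call is needed for this aggregate; returning any other
        // state would route the batch to StreamLoader.processBatch.
        return SimpleBatch(Batch.State.COMPLETE)
    }
}
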
@@ -11,7 +11,6 @@ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
@@ -45,7 +44,6 @@ class MultiProducerChannelTest {

@Test
fun `subsequent calls to close are idempotent`() = runTest {

channel.close()
channel.close()
channel.close()
@@ -19,6 +19,7 @@ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DefaultDestinationTaskLauncher
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.util.write
import io.airbyte.cdk.load.write.BatchAccumulator
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
import io.mockk.coVerify
@@ -36,6 +37,7 @@ class ProcessRecordsTaskTest {
private lateinit var diskManager: ReservationManager
private lateinit var deserializer: Deserializer<DestinationMessage>
private lateinit var streamLoader: StreamLoader
private lateinit var batchAccumulator: BatchAccumulator
private lateinit var inputQueue: MessageQueue<FileAggregateMessage>
private lateinit var processRecordsTaskFactory: DefaultProcessRecordsTaskFactory
private lateinit var launcher: DefaultDestinationTaskLauncher
@@ -49,7 +51,9 @@
outputQueue = mockk(relaxed = true)
syncManager = mockk(relaxed = true)
streamLoader = mockk(relaxed = true)
batchAccumulator = mockk(relaxed = true)
coEvery { syncManager.getOrAwaitStreamLoader(any()) } returns streamLoader
coEvery { streamLoader.createBatchAccumulator() } returns batchAccumulator
launcher = mockk(relaxed = true)
deserializer = mockk(relaxed = true)
coEvery { deserializer.deserialize(any()) } answers
@@ -106,7 +110,7 @@
files.map { FileAggregateMessage(descriptor, it) }.asFlow()

// Process records returns batches in 3 states.
coEvery { streamLoader.processRecords(any(), any()) } answers
coEvery { batchAccumulator.processRecords(any(), any()) } answers
{
MockBatch(
groupId = null,
@@ -38,6 +38,16 @@ interface ObjectStorageClient<T : RemoteObject<*>> {
}

interface StreamingUpload<T : RemoteObject<*>> {
suspend fun uploadPart(part: ByteArray)
/**
* Uploads a part of the object. Each part must have a unique index. The parts do not need to be
* uploaded in order. The index is 1-based.
*/
suspend fun uploadPart(part: ByteArray, index: Int)

/**
* Completes a multipart upload. All parts must be uploaded before completing the upload, and
* there cannot be gaps in the indexes. Multiple calls will return the same object, but only the
* first call will have side effects.
*/
suspend fun complete(): T
}
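
As a hedged illustration of the documented contract (not code from this commit), a caller might drive a StreamingUpload as in the sketch below. It assumes it sits alongside the interface, so no imports are shown; the helper name and the two-part split are illustrative.

// Illustrative only: parts are indexed from 1 and may arrive out of order,
// and complete() returns the same remote object on repeated calls.
suspend fun <T : RemoteObject<*>> uploadTwoParts(
    upload: StreamingUpload<T>,
    first: ByteArray,
    second: ByteArray
): T {
    upload.uploadPart(second, 2) // out-of-order upload is allowed
    upload.uploadPart(first, 1)  // but the final index set must have no gaps
    return upload.complete()     // idempotent: only the first call has side effects
}
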
@@ -38,6 +38,7 @@ import org.apache.avro.Schema

interface ObjectStorageFormattingWriter : Closeable {
fun accept(record: DestinationRecord)
fun flush()
}

@Singleton
Expand Down Expand Up @@ -86,6 +87,10 @@ class JsonFormattingWriter(
outputStream.write("\n")
}

override fun flush() {
outputStream.flush()
}

override fun close() {
outputStream.close()
}
@@ -105,6 +110,10 @@ class CSVFormattingWriter(
)
}

override fun flush() {
printer.flush()
}

override fun close() {
printer.close()
}
@@ -134,6 +143,10 @@ class AvroFormattingWriter(
writer.write(withMeta.toAvroRecord(mappedSchema, avroSchema))
}

override fun flush() {
writer.flush()
}

override fun close() {
writer.close()
}
@@ -163,6 +176,10 @@ class ParquetFormattingWriter(
writer.write(withMeta.toAvroRecord(mappedSchema, avroSchema))
}

override fun flush() {
// Parquet writer does not support flushing
}

override fun close() {
writer.close()
}
@@ -197,14 +214,19 @@ class BufferedFormattingWriter<T : OutputStream>(
writer.accept(record)
}

fun takeBytes(): ByteArray {
fun takeBytes(): ByteArray? {
wrappingBuffer.flush()
if (buffer.size() == 0) {
return null
}

val bytes = buffer.toByteArray()
buffer.reset()
return bytes
}

fun finish(): ByteArray? {
writer.flush()
writer.close()
streamProcessor.partFinisher.invoke(wrappingBuffer)
return if (buffer.size() > 0) {
@@ -214,6 +236,11 @@
}
}

override fun flush() {
writer.flush()
wrappingBuffer.flush()
}

override fun close() {
writer.close()
}
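
Finally, a hedged sketch of how the buffered writer and a streaming upload could cooperate to produce parts. This is not the connector's actual S3V2 code path: it assumes same-package placement (so object-storage imports are omitted), and the helper name and per-record part cut are illustrative, since a real loader would cut parts on a size threshold.

// Illustrative only: format records into parts and hand each non-null part
// to a StreamingUpload, finishing with any trailing bytes.
suspend fun <T : RemoteObject<*>> formatAndUpload(
    records: Iterator<DestinationRecord>,
    writer: BufferedFormattingWriter<*>,
    upload: StreamingUpload<T>
): T {
    var nextIndex = 1
    records.forEach { record ->
        writer.accept(record)
        // takeBytes() flushes the wrapping buffer and returns null if nothing is buffered yet.
        writer.takeBytes()?.let { part -> upload.uploadPart(part, nextIndex++) }
    }
    // finish() finalizes the formatter and returns any remaining buffered bytes.
    writer.finish()?.let { tail -> upload.uploadPart(tail, nextIndex) }
    return upload.complete()
}
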