Skip to content

Commit

Permalink
Rbroughan/dont fail fast on stream incomplete (#49455)
Browse files Browse the repository at this point in the history
  • Loading branch information
tryangul authored Dec 13, 2024
1 parent 526c159 commit bdaae73
Show file tree
Hide file tree
Showing 36 changed files with 235 additions and 263 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamIncompleteResult
import io.airbyte.cdk.load.state.StreamProcessingFailed
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
Expand Down Expand Up @@ -42,7 +42,7 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
override val state = Batch.State.PERSISTED
}

override suspend fun close(streamFailure: StreamIncompleteResult?) {
override suspend fun close(streamFailure: StreamProcessingFailed?) {
if (streamFailure == null) {
when (val importType = stream.importType) {
is Append -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ data class DestinationCatalog(val streams: List<DestinationStream> = emptyList()

fun asProtocolObject(): ConfiguredAirbyteCatalog =
ConfiguredAirbyteCatalog().withStreams(streams.map { it.asProtocolObject() })

/** Returns the number of streams held by this catalog. */
fun size(): Int {
    return streams.size
}
}

interface DestinationCatalogFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.airbyte.cdk.load.config

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.ReservationManager
Expand Down Expand Up @@ -51,7 +52,9 @@ class SyncBeanFactory {
fun fileAggregateQueue(
@Value("\${airbyte.resources.disk.bytes}") availableBytes: Long,
config: DestinationConfiguration,
catalog: DestinationCatalog
): MultiProducerChannel<FileAggregateMessage> {
val streamCount = catalog.size()
// total batches by disk capacity
val maxBatchesThatFitOnDisk = (availableBytes / config.recordBatchSizeBytes).toInt()
// account for batches in flight processing by the workers
Expand All @@ -64,6 +67,6 @@ class SyncBeanFactory {
val capacity = min(maxBatchesMinusUploadOverhead, idealDepth)
log.info { "Creating file aggregate queue with limit $capacity" }
val channel = Channel<FileAggregateMessage>(capacity)
return MultiProducerChannel(channel)
return MultiProducerChannel(streamCount.toLong(), channel)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,26 @@ interface Sized {
*/
sealed class DestinationStreamEvent : Sized

/**
 * Contains a record to be aggregated and processed.
 *
 * @property index position associated with this record (presumably its ordinal within the
 * stream — TODO confirm against the producer).
 * @property sizeBytes size attributed to this event, supplied by the caller.
 * @property record the [DestinationRecord] carried by this event.
 */
data class StreamRecordEvent(
val index: Long,
override val sizeBytes: Long,
val record: DestinationRecord
) : DestinationStreamEvent()

data class StreamCompleteEvent(
/**
 * Indicates the stream is in a terminal (complete or incomplete) state as signalled by upstream.
 *
 * The event itself does not say which of the two terminal states applies, and its [sizeBytes]
 * is fixed at zero.
 *
 * @property index position associated with the end of the stream (presumably the record count
 * at the time the end was observed — TODO confirm against the producer).
 */
data class StreamEndEvent(
val index: Long,
) : DestinationStreamEvent() {
override val sizeBytes: Long = 0L
}

/**
* Emitted to trigger evaluation of the conditional flush logic of a stream. The consumer may or may
* not decide to flush.
*/
data class StreamFlushEvent(
val tickedAtMs: Long,
) : DestinationStreamEvent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,29 @@
package io.airbyte.cdk.load.message

import io.github.oshai.kotlinlogging.KotlinLogging
import java.lang.IllegalStateException
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.channels.Channel

/**
* A channel designed for use with a dynamic amount of producers. Close will only close the
* A channel designed for use with a fixed number of producers. Close will be called on the
* underlying channel when there are no remaining registered producers.
*/
class MultiProducerChannel<T>(override val channel: Channel<T>) : ChannelMessageQueue<T>() {
class MultiProducerChannel<T>(
producerCount: Long,
override val channel: Channel<T>,
) : ChannelMessageQueue<T>() {
private val log = KotlinLogging.logger {}
private val producerCount = AtomicLong(0)
private val closed = AtomicBoolean(false)

fun registerProducer(): MultiProducerChannel<T> {
if (closed.get()) {
throw IllegalStateException("Attempted to register producer for closed channel.")
}

val count = producerCount.incrementAndGet()
log.info { "Registering producer (count=$count)" }
return this
}
private val initializedProducerCount = producerCount
private val producerCount = AtomicLong(producerCount)

override suspend fun close() {
val count = producerCount.decrementAndGet()
log.info { "Closing producer (count=$count)" }
log.info {
"Closing producer (active count=$count, initialized count: $initializedProducerCount)"
}
if (count == 0L) {
log.info { "Closing queue" }
log.info { "Closing underlying queue" }
channel.close()
closed.getAndSet(true)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@ import kotlinx.coroutines.CompletableDeferred

sealed interface StreamResult

sealed interface StreamIncompleteResult : StreamResult
data class StreamProcessingFailed(val streamException: Exception) : StreamResult

data class StreamFailed(val streamException: Exception) : StreamIncompleteResult

data class StreamKilled(val syncException: Exception) : StreamIncompleteResult

data object StreamSucceeded : StreamResult
data object StreamProcessingSucceeded : StreamResult

/** Manages the state of a single stream. */
interface StreamManager {
Expand All @@ -38,13 +34,17 @@ interface StreamManager {
fun recordCount(): Long

/**
* Mark the end-of-stream and return the record count. Expect this exactly once. Expect no
* further `countRecordIn`, and expect that [markSucceeded] or [markFailed] or [markKilled] will
* alway occur after this.
* Mark the end-of-stream, set the end of stream variant (complete or incomplete) and return the
* record count. Expect this exactly once. Expect no further `countRecordIn`, and expect that
* [markProcessingSucceeded] will always occur after this, while [markProcessingFailed] can
* occur before or after.
*/
fun markEndOfStream(): Long
fun markEndOfStream(receivedStreamCompleteMessage: Boolean): Long
fun endOfStreamRead(): Boolean

/** Whether we received a stream complete message for the managed stream. */
fun isComplete(): Boolean

/**
* Mark a checkpoint in the stream and return the current index and the number of records since
* the last one.
Expand Down Expand Up @@ -72,22 +72,23 @@ interface StreamManager {
*/
fun areRecordsPersistedUntil(index: Long): Boolean

/** Mark the stream as closed. This should only be called after all records have been read. */
fun markSucceeded()

/**
* Mark that the stream was killed due to failure elsewhere. Returns false if task was already
* complete.
* Indicates destination processing of the stream succeeded, regardless of complete/incomplete
* status. This should only be called after all records and end of stream messages have been
* read.
*/
fun markKilled(causedBy: Exception): Boolean
fun markProcessingSucceeded()

/** Mark that the stream itself failed. Return false if task was already complete */
fun markFailed(causedBy: Exception): Boolean
/**
* Indicates destination processing of the stream failed. Returns false if task was already
* complete.
*/
fun markProcessingFailed(causedBy: Exception): Boolean

/** Suspend until the stream completes, returning the result. */
suspend fun awaitStreamResult(): StreamResult

/** True if the stream has not yet been marked successful, failed, or killed. */
/** True if the stream processing has not yet been marked as successful or failed. */
fun isActive(): Boolean
}

Expand All @@ -105,6 +106,7 @@ class DefaultStreamManager(
private val lastCheckpoint = AtomicLong(0L)

private val markedEndOfStream = AtomicBoolean(false)
private val receivedComplete = AtomicBoolean(false)

private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()

Expand All @@ -124,10 +126,11 @@ class DefaultStreamManager(
return recordCount.get()
}

override fun markEndOfStream(): Long {
/**
 * Flags the stream as fully read, remembers whether upstream reported it as complete, and
 * returns the record count observed at that moment.
 *
 * @param receivedStreamCompleteMessage value stored as the stream's "complete" flag.
 * @throws IllegalStateException if end-of-stream had already been marked.
 */
override fun markEndOfStream(receivedStreamCompleteMessage: Boolean): Long {
    val alreadyMarked = markedEndOfStream.getAndSet(true)
    check(!alreadyMarked) { "Stream is closed for reading" }

    receivedComplete.set(receivedStreamCompleteMessage)
    return recordCount.get()
}
Expand All @@ -136,6 +139,10 @@ class DefaultStreamManager(
return markedEndOfStream.get()
}

/** Returns the current value of the stored "received stream complete" flag. */
override fun isComplete(): Boolean = receivedComplete.get()

override fun markCheckpoint(): Pair<Long, Long> {
val index = recordCount.get()
val lastCheckpoint = lastCheckpoint.getAndSet(index)
Expand Down Expand Up @@ -220,19 +227,15 @@ class DefaultStreamManager(
return isProcessingCompleteForState(index, Batch.State.PERSISTED)
}

override fun markSucceeded() {
override fun markProcessingSucceeded() {
if (!markedEndOfStream.get()) {
throw IllegalStateException("Stream is not closed for reading")
}
streamResult.complete(StreamSucceeded)
}

override fun markKilled(causedBy: Exception): Boolean {
return streamResult.complete(StreamKilled(causedBy))
streamResult.complete(StreamProcessingSucceeded)
}

override fun markFailed(causedBy: Exception): Boolean {
return streamResult.complete(StreamFailed(causedBy))
/**
 * Attempts to complete the stream result with a [StreamProcessingFailed] wrapping [causedBy].
 *
 * @return whatever `streamResult.complete` reports for this attempt.
 */
override fun markProcessingFailed(causedBy: Exception): Boolean {
    val failure = StreamProcessingFailed(causedBy)
    return streamResult.complete(failure)
}

override suspend fun awaitStreamResult(): StreamResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import kotlinx.coroutines.CompletableDeferred

sealed interface SyncResult
sealed interface DestinationResult

data object SyncSuccess : SyncResult
data object DestinationSuccess : DestinationResult

data class SyncFailure(
val syncFailure: Exception,
data class DestinationFailure(
val cause: Exception,
val streamResults: Map<DestinationStream.Descriptor, StreamResult>
) : SyncResult
) : DestinationResult

/** Manages the state of all streams in the destination. */
interface SyncManager {
Expand All @@ -35,18 +35,26 @@ interface SyncManager {
suspend fun getOrAwaitStreamLoader(stream: DestinationStream.Descriptor): StreamLoader
suspend fun getStreamLoaderOrNull(stream: DestinationStream.Descriptor): StreamLoader?

/** Suspend until all streams are complete. Returns false if any stream was failed/killed. */
suspend fun awaitAllStreamsCompletedSuccessfully(): Boolean
/**
* Suspend until all streams are processed successfully. Returns false if processing failed for
* any stream.
*/
suspend fun awaitAllStreamsProcessedSuccessfully(): Boolean

suspend fun markInputConsumed()
suspend fun markCheckpointsProcessed()
suspend fun markFailed(causedBy: Exception): SyncFailure
suspend fun markSucceeded()
suspend fun markDestinationFailed(causedBy: Exception): DestinationFailure
suspend fun markDestinationSucceeded()

/**
* Whether we received stream complete messages for all streams in the catalog from upstream.
*/
suspend fun allStreamsComplete(): Boolean

fun isActive(): Boolean

suspend fun awaitInputProcessingComplete(): Unit
suspend fun awaitSyncResult(): SyncResult
suspend fun awaitInputProcessingComplete()
suspend fun awaitDestinationResult(): DestinationResult
}

@SuppressFBWarnings(
Expand All @@ -56,7 +64,7 @@ interface SyncManager {
class DefaultSyncManager(
private val streamManagers: ConcurrentHashMap<DestinationStream.Descriptor, StreamManager>
) : SyncManager {
private val syncResult = CompletableDeferred<SyncResult>()
private val destinationResult = CompletableDeferred<DestinationResult>()
private val streamLoaders =
ConcurrentHashMap<DestinationStream.Descriptor, CompletableDeferred<Result<StreamLoader>>>()
private val inputConsumed = CompletableDeferred<Boolean>()
Expand Down Expand Up @@ -87,32 +95,38 @@ class DefaultSyncManager(
return streamLoaders[stream]?.await()?.getOrNull()
}

override suspend fun awaitAllStreamsCompletedSuccessfully(): Boolean {
return streamManagers.all { (_, manager) -> manager.awaitStreamResult() is StreamSucceeded }
/**
 * Awaits each stream manager's result in turn and stops at the first one that is not
 * [StreamProcessingSucceeded].
 *
 * @return true if every awaited result was [StreamProcessingSucceeded], false otherwise.
 */
override suspend fun awaitAllStreamsProcessedSuccessfully(): Boolean {
    for (manager in streamManagers.values) {
        if (manager.awaitStreamResult() !is StreamProcessingSucceeded) {
            return false
        }
    }
    return true
}

override suspend fun markFailed(causedBy: Exception): SyncFailure {
override suspend fun markDestinationFailed(causedBy: Exception): DestinationFailure {
val result =
SyncFailure(causedBy, streamManagers.mapValues { it.value.awaitStreamResult() })
syncResult.complete(result)
DestinationFailure(causedBy, streamManagers.mapValues { it.value.awaitStreamResult() })
destinationResult.complete(result)
return result
}

override suspend fun markSucceeded() {
override suspend fun markDestinationSucceeded() {
if (streamManagers.values.any { it.isActive() }) {
throw IllegalStateException(
"Cannot mark sync as succeeded until all streams are complete"
)
}
syncResult.complete(SyncSuccess)
destinationResult.complete(DestinationSuccess)
}

/** Returns true only when every registered stream manager reports `isComplete()`. */
override suspend fun allStreamsComplete(): Boolean =
    streamManagers.values.all { manager -> manager.isComplete() }

override fun isActive(): Boolean {
return syncResult.isActive
return destinationResult.isActive
}

override suspend fun awaitSyncResult(): SyncResult {
return syncResult.await()
/** Suspends until the destination result has been completed, then returns it. */
override suspend fun awaitDestinationResult(): DestinationResult = destinationResult.await()

override suspend fun awaitInputProcessingComplete() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class DefaultDestinationTaskLauncher(
// File transfer
@Value("\${airbyte.file-transfer.enabled}") private val fileTransferEnabled: Boolean,

// Input Comsumer requirements
// Input Consumer requirements
private val inputFlow: SizedInputFlow<Reserved<DestinationMessage>>,
private val recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ interface CloseStreamTask : ImplementorScope

/**
* Wraps [StreamLoader.close] and marks the stream as closed in the stream manager. Also starts the
* teardown task.
* teardown task. Called after the end of stream message (complete OR incomplete) has been received
* and all record messages have been processed.
*/
class DefaultCloseStreamTask(
private val syncManager: SyncManager,
Expand All @@ -27,7 +28,7 @@ class DefaultCloseStreamTask(
override suspend fun execute() {
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
streamLoader.close()
syncManager.getStreamManager(streamDescriptor).markSucceeded()
syncManager.getStreamManager(streamDescriptor).markProcessingSucceeded()
taskLauncher.handleStreamClosed(streamLoader.stream.descriptor)
}
}
Expand Down
Loading

0 comments on commit bdaae73

Please sign in to comment.