diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskLauncher.kt index c554721536df..e42110e8d557 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskLauncher.kt @@ -4,8 +4,6 @@ package io.airbyte.cdk.task -import jakarta.inject.Singleton - interface Task { suspend fun execute() } @@ -19,13 +17,6 @@ interface TaskLauncher { suspend fun start() suspend fun stop() { - taskRunner.enqueue(Done()) - } -} - -@Singleton -class Done : Task { - override suspend fun execute() { - throw IllegalStateException("The Done() task cannot be executed") + taskRunner.close() } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskRunner.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskRunner.kt index bc7928b5d4df..5c93a1447085 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskRunner.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TaskRunner.kt @@ -8,8 +8,8 @@ import io.github.oshai.kotlinlogging.KotlinLogging import jakarta.inject.Singleton import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.launch -import kotlinx.coroutines.yield /** * A Task is a unit of work that can be executed concurrently. Even though we aren't scheduling @@ -22,6 +22,8 @@ import kotlinx.coroutines.yield */ @Singleton class TaskRunner { + val log = KotlinLogging.logger {} + private val queue = Channel(Channel.UNLIMITED) suspend fun enqueue(task: Task) { @@ -29,23 +31,15 @@ class TaskRunner { } suspend fun run() = coroutineScope { - val log = KotlinLogging.logger {} - - while (true) { - val task = queue.receive() - - if (task is Done) { - log.info { "Task queue received Done() task, exiting" } - return@coroutineScope - } - - /** Launch the task concurrently and update counters. */ + queue.consumeAsFlow().collect { task -> launch { log.info { "Executing task: $task" } task.execute() } - - yield() } } + + fun close() { + queue.close() + } } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TaskRunnerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TaskRunnerTest.kt new file mode 100644 index 000000000000..7bbafbf88b6a --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/TaskRunnerTest.kt @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.task + +import java.util.concurrent.atomic.AtomicBoolean +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +class TaskRunnerTest { + @Test + fun testTasksCompleteAfterClose() = runTest { + val task1Completed = AtomicBoolean(false) + val task2Completed = AtomicBoolean(false) + val task3Completed = AtomicBoolean(false) + + val innerTaskCompleted = AtomicBoolean(false) + val innerTaskEnqueueFailed = AtomicBoolean(false) + + val task1ReportingChannel = Channel() + val task2ReportingChannel = Channel() + val task3ReportingChannel = Channel() + + val task2BlockingChannel = Channel() + + // Make 3 tasks. + // - the first one should complete right away + // - the second one will block until we send a message to it + // - BUT the third one will not be blocked by the second + // - AND the second one should still run after we close the runner + // - BUT the second one tried to enqueue another after close, which throws + val runner = TaskRunner() + val task1 = + object : Task { + override suspend fun execute() { + task1Completed.set(true) + task1ReportingChannel.send(Unit) + } + } + val task2 = + object : Task { + override suspend fun execute() { + task2BlockingChannel.receive() + task2Completed.set(true) + try { + runner.enqueue( + object : Task { + override suspend fun execute() { + innerTaskCompleted.set(true) + } + } + ) + } catch (e: ClosedSendChannelException) { + innerTaskEnqueueFailed.set(true) + } + task2ReportingChannel.send(Unit) + } + } + val task3 = + object : Task { + override suspend fun execute() { + task3Completed.set(true) + task3ReportingChannel.send(Unit) + } + } + + runner.enqueue(task1) + runner.enqueue(task2) + runner.enqueue(task3) + + launch { runner.run() } + + task1ReportingChannel.receive() // wait for task1 to complete + Assertions.assertTrue(task1Completed.get()) + Assertions.assertFalse(task2Completed.get()) + + task3ReportingChannel.receive() // wait for task3 to complete + Assertions.assertTrue(task3Completed.get()) + Assertions.assertFalse(task2Completed.get()) + + runner.close() + task2BlockingChannel.send(Unit) + task2ReportingChannel.receive() // wait for task2 to complete + Assertions.assertTrue(task2Completed.get()) + + Assertions.assertTrue(innerTaskEnqueueFailed.get()) + Assertions.assertFalse(innerTaskCompleted.get()) + } +}