Skip to content

Commit

Permalink
Bulk Load CDK: Process Records Unit Tests; Test cleanup (#45846)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Sep 25, 2024
1 parent 8ecebea commit 477bcc4
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 100 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.file

import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.nio.file.Path

@Singleton
@Requires(env = ["MockTempFileProvider"])
class MockTempFileProvider : TempFileProvider {
class MockLocalFile : LocalFile {
val writtenLines: MutableList<String> = mutableListOf()
var linesToRead: MutableList<String> = mutableListOf()
val writersCreated: MutableList<MockFileWriter> = mutableListOf()
val readersCreated: MutableList<MockFileReader> = mutableListOf()
var isDeleted: Boolean = false

class MockFileWriter(val file: MockLocalFile) : FileWriter {
var isClosed = false

override fun write(str: String) {
file.writtenLines.add(str)
}

override fun close() {
isClosed = true
}
}

class MockFileReader(val file: MockLocalFile) : FileReader {
var isClosed = false
var index = 0
override fun lines(): Sequence<String> {
return sequence {
while (index < file.linesToRead.size) {
yield(file.linesToRead[index])
index++
}
}
}

override fun close() {
isClosed = true
}
}

override fun toFileWriter(): FileWriter {
val writer = MockFileWriter(this)
writersCreated.add(writer)
return writer
}

override fun toFileReader(): FileReader {
val reader = MockFileReader(this)
readersCreated.add(reader)
return reader
}

override fun delete() {
isDeleted = true
}
}

override fun createTempFile(directory: Path, prefix: String, suffix: String): LocalFile {
return MockLocalFile()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.task

import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.write.StreamLoader

class MockTaskLauncher(override val taskRunner: TaskRunner) : DestinationTaskLauncher {
val spilledFiles = mutableListOf<BatchEnvelope<SpilledRawMessagesLocalFile>>()
val batchEnvelopes = mutableListOf<BatchEnvelope<*>>()

override suspend fun handleSetupComplete() {
throw NotImplementedError()
}

override suspend fun handleStreamOpen(streamLoader: StreamLoader) {
throw NotImplementedError()
}

override suspend fun handleNewSpilledFile(
stream: DestinationStream,
wrapped: BatchEnvelope<SpilledRawMessagesLocalFile>
) {
spilledFiles.add(wrapped)
}

override suspend fun handleNewBatch(streamLoader: StreamLoader, wrapped: BatchEnvelope<*>) {
batchEnvelopes.add(wrapped)
}

override suspend fun handleStreamClosed(stream: DestinationStream) {
throw NotImplementedError()
}

override suspend fun handleTeardownComplete() {
throw NotImplementedError()
}

override suspend fun start() {
throw NotImplementedError()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.task

import com.google.common.collect.Range
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1
import io.airbyte.cdk.data.IntegerValue
import io.airbyte.cdk.file.MockTempFileProvider
import io.airbyte.cdk.message.Batch
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.message.Deserializer
import io.airbyte.cdk.message.DestinationMessage
import io.airbyte.cdk.message.DestinationRecord
import io.airbyte.cdk.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.write.StreamLoader
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import jakarta.inject.Singleton
import java.nio.file.Path
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@MicronautTest(environments = ["ProcessRecordsTaskTest"])
class ProcessRecordsTaskTest {
@Inject lateinit var taskRunner: TaskRunner
@Inject lateinit var processRecordsTaskFactory: DefaultProcessRecordsTaskFactory

class MockBatch(
override val state: Batch.State,
val reportedByteSize: Long,
val recordCount: Long,
val pmChecksum: Long,
) : Batch

class MockStreamLoader : StreamLoader {
override val stream: DestinationStream = stream1

data class SumAndCount(val sum: Long = 0, val count: Long = 0)

override suspend fun processRecords(
records: Iterator<DestinationRecord>,
totalSizeBytes: Long
): Batch {
// Do a simple sum of the record values and count
// To demonstrate that the primed data was actually processed
val (sum, count) =
records.asSequence().fold(SumAndCount()) { acc, record ->
SumAndCount(acc.sum + (record.data as IntegerValue).value, acc.count + 1)
}
return MockBatch(
state = Batch.State.COMPLETE,
reportedByteSize = totalSizeBytes,
recordCount = count,
pmChecksum = sum
)
}
}

@Singleton
@Primary
@Requires(env = ["ProcessRecordsTaskTest"])
class MockDeserializer : Deserializer<DestinationMessage> {
override fun deserialize(serialized: String): DestinationMessage {
return DestinationRecord(
stream = stream1,
data = IntegerValue(serialized.toLong()),
emittedAtMs = 0L,
meta = null,
serialized = serialized,
)
}
}

@Test
fun testProcessRecordsTask() = runTest {
val byteSize = 999L
val recordCount = 1024L

val launcher = MockTaskLauncher(taskRunner)
val mockFile =
MockTempFileProvider()
.createTempFile(directory = Path.of("tmp/"), prefix = "test", suffix = ".json")
as MockTempFileProvider.MockLocalFile
val file =
SpilledRawMessagesLocalFile(
localFile = mockFile,
totalSizeBytes = byteSize,
)
val task =
processRecordsTaskFactory.make(
taskLauncher = launcher,
streamLoader = MockStreamLoader(),
fileEnvelope = BatchEnvelope(file, Range.closed(0, 1024))
)
mockFile.linesToRead = (0 until recordCount).map { "$it" }.toMutableList()

task.execute()

Assertions.assertEquals(1, launcher.batchEnvelopes.size)
val batch = launcher.batchEnvelopes[0].batch as MockBatch
Assertions.assertEquals(Batch.State.COMPLETE, batch.state)
Assertions.assertEquals(999, batch.reportedByteSize)
Assertions.assertEquals(recordCount, batch.recordCount)
Assertions.assertEquals((0 until recordCount).sum(), batch.pmChecksum)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,119 +9,31 @@ import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1
import io.airbyte.cdk.command.WriteConfiguration
import io.airbyte.cdk.data.NullValue
import io.airbyte.cdk.file.FileReader
import io.airbyte.cdk.file.FileWriter
import io.airbyte.cdk.file.LocalFile
import io.airbyte.cdk.file.TempFileProvider
import io.airbyte.cdk.message.BatchEnvelope
import io.airbyte.cdk.file.MockTempFileProvider
import io.airbyte.cdk.message.DestinationRecord
import io.airbyte.cdk.message.DestinationRecordWrapped
import io.airbyte.cdk.message.MessageQueueReader
import io.airbyte.cdk.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.message.StreamCompleteWrapped
import io.airbyte.cdk.message.StreamRecordWrapped
import io.airbyte.cdk.write.StreamLoader
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
import jakarta.inject.Singleton
import java.nio.file.Path
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@MicronautTest(environments = ["SpillToDiskTaskTest"])
@MicronautTest(environments = ["SpillToDiskTaskTest", "MockTempFileProvider"])
class SpillToDiskTaskTest {
@Inject lateinit var taskRunner: TaskRunner
@Inject lateinit var spillToDiskTaskFactory: DefaultSpillToDiskTaskFactory
@Inject lateinit var mockTempFileProvider: MockTempFileProvider

@Singleton
@Requires(env = ["SpillToDiskTaskTest"])
class MockTempFileProvider : TempFileProvider {
data class MockFile(
val path: Path,
val lines: MutableList<String> = mutableListOf(),
val isDeleted: Boolean = false,
)
private val tmpFileCounter = AtomicInteger(0)
val writerClosed = AtomicBoolean(false)
val mockFiles = mutableMapOf<String, MockFile>()
override fun createTempFile(directory: Path, prefix: String, suffix: String): LocalFile {
val path =
Path.of(
directory.toString(),
"/${prefix}-${tmpFileCounter.getAndIncrement()}${suffix}"
)
return object : LocalFile {
override fun toFileWriter(): FileWriter {
val mockFile = MockFile(path)
mockFiles[path.toString()] = mockFile

return object : FileWriter {
override fun write(str: String) {
mockFile.lines.add(str)
}

override fun close() {
writerClosed.set(true)
}
}
}

override fun toFileReader(): FileReader {
throw NotImplementedError()
}

override fun delete() {
throw NotImplementedError()
}
}
}
}

// TODO: Migrate this to a common mock.
class MockTaskLauncher(override val taskRunner: TaskRunner) : DestinationTaskLauncher {
val spilledFiles = mutableListOf<BatchEnvelope<SpilledRawMessagesLocalFile>>()

override suspend fun handleSetupComplete() {
throw NotImplementedError()
}

override suspend fun handleStreamOpen(streamLoader: StreamLoader) {
throw NotImplementedError()
}

override suspend fun handleNewSpilledFile(
stream: DestinationStream,
wrapped: BatchEnvelope<SpilledRawMessagesLocalFile>
) {
spilledFiles.add(wrapped)
}

override suspend fun handleNewBatch(streamLoader: StreamLoader, wrapped: BatchEnvelope<*>) {
throw NotImplementedError()
}

override suspend fun handleStreamClosed(stream: DestinationStream) {
throw NotImplementedError()
}

override suspend fun handleTeardownComplete() {
throw NotImplementedError()
}

override suspend fun start() {
throw NotImplementedError()
}
}

@Factory
@Requires(env = ["SpillToDiskTaskTest"])
class MockDestinationTaskLauncherFactory {
Expand Down Expand Up @@ -183,19 +95,21 @@ class SpillToDiskTaskTest {
Assertions.assertEquals(2, mockTaskLauncher.spilledFiles.size)
Assertions.assertEquals(1024, mockTaskLauncher.spilledFiles[0].batch.totalSizeBytes)
Assertions.assertEquals(512, mockTaskLauncher.spilledFiles[1].batch.totalSizeBytes)
Assertions.assertTrue(mockTempFileProvider.writerClosed.get())
Assertions.assertEquals(2, mockTempFileProvider.mockFiles.size)

val env1 = mockTaskLauncher.spilledFiles[0]
val env2 = mockTaskLauncher.spilledFiles[1]
Assertions.assertEquals(1024, env1.batch.totalSizeBytes)
Assertions.assertEquals(512, env2.batch.totalSizeBytes)

val file1 = env1.batch.localFile as MockTempFileProvider.MockLocalFile
val file2 = env2.batch.localFile as MockTempFileProvider.MockLocalFile
Assertions.assertTrue(file1.writersCreated[0].isClosed)
Assertions.assertTrue(file2.writersCreated[0].isClosed)

val expectedLinesFirst = (0 until 1024 / 8).flatMap { listOf("test$it", "\n") }
val expectedLinesSecond = (1024 / 8 until 1536 / 8).flatMap { listOf("test$it", "\n") }

Assertions.assertEquals(
expectedLinesFirst,
mockTempFileProvider.mockFiles.values.first().lines
)
Assertions.assertEquals(
expectedLinesSecond,
mockTempFileProvider.mockFiles.values.last().lines
)
Assertions.assertEquals(expectedLinesFirst, file1.writtenLines)
Assertions.assertEquals(expectedLinesSecond, file2.writtenLines)
}
}

0 comments on commit 477bcc4

Please sign in to comment.