Skip to content

Commit

Permalink
Bulk load cdk: add checkpointing test (#46749)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgao authored Oct 18, 2024
1 parent 8f2c6f9 commit 1cc7f2c
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@ class MockBasicFunctionalityIntegrationTest :
override fun testBasicWrite() {
super.testBasicWrite()
}

/** Delegates to the inherited mid-sync checkpointing test without changing it. */
@Test
override fun testMidSyncCheckpointingStreamState(): Unit =
    super.testMidSyncCheckpointingStreamState()
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import io.airbyte.cdk.load.command.DestinationConfigurationFactory
import io.micronaut.context.annotation.Factory
import jakarta.inject.Singleton

class MockDestinationConfiguration : DestinationConfiguration()
/**
 * [DestinationConfiguration] variant whose only customization is a much smaller
 * record batch size.
 */
class MockDestinationConfiguration : DestinationConfiguration() {
    // 10KB per batch; per the original author's note the inherited value is 200MB.
    override val recordBatchSizeBytes = 1024L * 10
}

/** Singleton-scoped [ConfigurationSpecification] subclass that declares no members of its own. */
@Singleton
class MockDestinationSpecification : ConfigurationSpecification()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ package io.airbyte.cdk.load.write

import io.airbyte.cdk.command.ConfigurationSpecification
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.data.ObjectTypeWithoutSchema
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationStreamComplete
import io.airbyte.cdk.load.message.StreamCheckpoint
import io.airbyte.cdk.load.test.util.DestinationCleaner
import io.airbyte.cdk.load.test.util.DestinationDataDumper
Expand All @@ -18,9 +22,15 @@ import io.airbyte.cdk.load.test.util.NameMapper
import io.airbyte.cdk.load.test.util.NoopExpectedRecordMapper
import io.airbyte.cdk.load.test.util.NoopNameMapper
import io.airbyte.cdk.load.test.util.OutputRecord
import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import io.airbyte.protocol.models.v0.AirbyteStateMessage
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertAll

Expand All @@ -42,7 +52,7 @@ abstract class BasicFunctionalityIntegrationTest(
DestinationStream(
DestinationStream.Descriptor(randomizedNamespace, "test_stream"),
Append,
ObjectTypeWithoutSchema,
ObjectType(linkedMapOf("id" to FieldType(IntegerType, nullable = true))),
generationId = 0,
minimumGenerationId = 0,
syncId = 42,
Expand Down Expand Up @@ -132,4 +142,167 @@ abstract class BasicFunctionalityIntegrationTest(
},
)
}

/**
 * Checks that the destination emits stream state while a sync is still in progress.
 *
 * Steps:
 * 1. Start a `write` process for a catalog with two streams (`test_stream1`, `test_stream2`).
 * 2. Send one record followed by one [StreamCheckpoint] to each stream.
 * 3. Keep sending records to `test_stream1` until the destination returns at least one
 *    STATE message.
 * 4. For every returned state message, assert its namespace, stream name, destination
 *    record count and state blob; when `verifyDataWriting` is set, also assert that the
 *    first record of that stream appears in the dumped records.
 * 5. Send [DestinationStreamComplete] for both streams and shut the destination down.
 *
 * The return type is declared as [Unit] explicitly so it does not depend on what the last
 * expression of the `runBlocking` block happens to return.
 */
@Test
open fun testMidSyncCheckpointingStreamState(): Unit =
    runBlocking(Dispatchers.IO) {
        // Both streams share the same schema and generation/sync ids; only the name differs.
        fun makeStream(name: String) =
            DestinationStream(
                DestinationStream.Descriptor(randomizedNamespace, name),
                Append,
                ObjectType(linkedMapOf("id" to FieldType(IntegerType, nullable = true))),
                generationId = 0,
                minimumGenerationId = 0,
                syncId = 42,
            )
        val destination =
            destinationProcessFactory.createDestinationProcess(
                "write",
                config,
                DestinationCatalog(
                        listOf(
                            makeStream("test_stream1"),
                            makeStream("test_stream2"),
                        )
                    )
                    .asProtocolObject(),
            )
        // Run the destination concurrently so messages can be fed to it from this coroutine.
        launch { destination.run() }

        // Send one record+state to each stream
        destination.sendMessages(
            DestinationRecord(
                    namespace = randomizedNamespace,
                    name = "test_stream1",
                    data = """{"id": 12}""",
                    emittedAtMs = 1234,
                )
                .asProtocolMessage(),
            StreamCheckpoint(
                    streamNamespace = randomizedNamespace,
                    streamName = "test_stream1",
                    blob = """{"foo": "bar1"}""",
                    sourceRecordCount = 1
                )
                .asProtocolMessage(),
            DestinationRecord(
                    namespace = randomizedNamespace,
                    name = "test_stream2",
                    data = """{"id": 34}""",
                    emittedAtMs = 1234,
                )
                .asProtocolMessage(),
            StreamCheckpoint(
                    streamNamespace = randomizedNamespace,
                    streamName = "test_stream2",
                    blob = """{"foo": "bar2"}""",
                    sourceRecordCount = 1
                )
                .asProtocolMessage()
        )

        // Send records to stream1 until we get a state message back.
        // Generally, we expect that the state message will belong to stream1.
        // NOTE: this loop has no upper bound; if the destination never emits a state
        // message, the test does not terminate on its own.
        val stateMessages: List<AirbyteStateMessage>
        while (true) {
            destination.sendMessage(
                DestinationRecord(
                        namespace = randomizedNamespace,
                        name = "test_stream1",
                        data = """{"id": 56}""",
                        emittedAtMs = 1234,
                    )
                    .asProtocolMessage()
            )
            // Extract the state payloads in a single pass over the returned messages.
            val returnedStates =
                destination
                    .readMessages()
                    .filter { it.type == AirbyteMessage.Type.STATE }
                    .map { it.state }
            if (returnedStates.isNotEmpty()) {
                stateMessages = returnedStates
                break
            }
        }

        // for each state message, verify that it's a valid state,
        // and that we actually wrote the data
        stateMessages.forEach { stateMessage ->
            val streamName = stateMessage.stream.streamDescriptor.name
            val streamNamespace = stateMessage.stream.streamDescriptor.namespace
            // basic state message checks - this is mostly just exercising the CDK itself,
            // but is cheap and easy to do.
            assertAll(
                { assertEquals(randomizedNamespace, streamNamespace) },
                {
                    assertTrue(
                        streamName == "test_stream1" || streamName == "test_stream2",
                        "Expected stream name to be test_stream1 or test_stream2, got $streamName"
                    )
                },
                {
                    assertEquals(
                        1.0,
                        stateMessage.destinationStats.recordCount,
                        "Expected destination stats to show 1 record"
                    )
                },
                {
                    // The state blob must be the one that was sent for that same stream.
                    when (streamName) {
                        "test_stream1" -> {
                            assertEquals(
                                Jsons.readTree("""{"foo": "bar1"}"""),
                                stateMessage.stream.streamState,
                            )
                        }
                        "test_stream2" -> {
                            assertEquals(
                                Jsons.readTree("""{"foo": "bar2"}"""),
                                stateMessage.stream.streamState
                            )
                        }
                        else ->
                            throw IllegalStateException("Unexpected stream name: $streamName")
                    }
                }
            )
            if (verifyDataWriting) {
                val records = dataDumper.dumpRecords(config, makeStream(streamName))
                // The id of the record that was sent ahead of this stream's checkpoint.
                val expectedId =
                    when (streamName) {
                        "test_stream1" -> 12
                        "test_stream2" -> 34
                        else ->
                            throw IllegalStateException("Unexpected stream name: $streamName")
                    }
                val expectedRecord =
                    recordMangler.mapRecord(
                        OutputRecord(
                            extractedAt = 1234,
                            generationId = 0,
                            data = mapOf("id" to expectedId),
                            airbyteMeta = OutputRecord.Meta(changes = listOf(), syncId = 42)
                        )
                    )

                // Only the data payload is compared; other record fields are not checked here.
                assertTrue("Expected the first record to be present in the dumped records.") {
                    records.any { actualRecord -> expectedRecord.data == actualRecord.data }
                }
            }
        }

        // Mark both streams complete before shutting the destination down.
        destination.sendMessages(
            DestinationStreamComplete(
                    DestinationStream.Descriptor(randomizedNamespace, "test_stream1"),
                    System.currentTimeMillis()
                )
                .asProtocolMessage(),
            DestinationStreamComplete(
                    DestinationStream.Descriptor(randomizedNamespace, "test_stream2"),
                    System.currentTimeMillis()
                )
                .asProtocolMessage()
        )
        destination.shutdown()
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,9 @@ class DevNullBasicFunctionalityIntegrationTest :
override fun testBasicWrite() {
super.testBasicWrite()
}

/** Delegates to the inherited mid-sync checkpointing test without changing it. */
@Test
override fun testMidSyncCheckpointingStreamState(): Unit =
    super.testMidSyncCheckpointingStreamState()
}

0 comments on commit 1cc7f2c

Please sign in to comment.