Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test on stream status #38298

Closed
wants to merge 8 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ object AirbyteTraceMessageUtility {
)
}

private fun makeStreamStatusTraceAirbyteMessage(
fun makeStreamStatusTraceAirbyteMessage(
airbyteStreamStatusHolder: AirbyteStreamStatusHolder
): AirbyteMessage {
return makeAirbyteMessageFromTraceMessage(airbyteStreamStatusHolder.toTraceMessage())
Original file line number Diff line number Diff line change
@@ -106,6 +106,32 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {

protected abstract fun assertExpectedStateMessages(stateMessages: List<AirbyteStateMessage>)

/**
 * Asserts that the message at position [idx] of [allMessages] is a TRACE message whose stream
 * status is non-null and equal to [expectedStreamStatus].
 *
 * The function is `open`, so subclasses may replace this check.
 */
protected open fun assertStreamStatusTraceMessageIndex(
    idx: Int,
    allMessages: List<AirbyteMessage>,
    expectedStreamStatus: AirbyteStreamStatusTraceMessage
) {
    val candidate = allMessages[idx]
    Assertions.assertEquals(AirbyteMessage.Type.TRACE, candidate.type)

    // Read the status once and run both the presence and the equality check against it.
    val actualStreamStatus = candidate.trace.streamStatus
    Assertions.assertNotNull(actualStreamStatus)
    Assertions.assertEquals(expectedStreamStatus, actualStreamStatus)
}

/**
 * Builds the stream status trace message used as the expected value in stream-status
 * assertions: a descriptor made of [namespace] and [streamName], carrying [status].
 */
private fun createAirbteStreanStatusTraceMessage(
    namespace: String,
    streamName: String,
    status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus
): AirbyteStreamStatusTraceMessage {
    val descriptor = StreamDescriptor().withNamespace(namespace).withName(streamName)
    return AirbyteStreamStatusTraceMessage().withStreamDescriptor(descriptor).withStatus(status)
}

/**
 * Hook for asserting on the state messages of a full refresh read. The default implementation
 * performs no checks; the function is `open` so subclasses can supply their own.
 */
protected open fun assertExpectedStateMessagesForFullRefresh(
    stateMessages: List<AirbyteStateMessage>
) = Unit
@@ -309,6 +335,16 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
.collect(Collectors.toList())
}

/**
 * Returns the trace payload of every TRACE-typed message in [messages], preserving the original
 * message order.
 */
protected fun extractTraceMessages(
    messages: List<AirbyteMessage>
): MutableList<io.airbyte.protocol.models.v0.AirbyteTraceMessage>? {
    val traceTyped = messages.filter { message -> message.type == AirbyteMessage.Type.TRACE }
    return traceTyped.map { message -> message.trace }.toMutableList()
}

protected fun assertExpectedRecords(
expectedRecords: Set<JsonNode>,
actualRecords: Set<AirbyteRecordMessage>
@@ -380,6 +416,25 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val recordMessages = extractRecordMessages(actualRecords)
val stateMessages = extractStateMessages(actualRecords)

assertStreamStatusTraceMessageIndex(
0,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords.size - 1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

Assertions.assertNotNull(targetPosition)
recordMessages.forEach(
Consumer { record: AirbyteRecordMessage ->
@@ -441,6 +496,25 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val stateMessages1 = extractStateMessages(actualRecords1)
assertExpectedStateMessages(stateMessages1)

assertStreamStatusTraceMessageIndex(
0,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords1.size - 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

updateCommand(MODELS_STREAM_NAME, COL_MODEL, updatedModel, COL_ID, 11)
waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1)

@@ -614,6 +688,29 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
val read1 = source().read(config()!!, configuredCatalog, null)
val actualRecords1 = AutoCloseableIterators.toListAndClose(read1)

// The first message will be the start of the incremental stream.
// The last message will be the end of the full refresh stream.
// The index at which the incremental stream starts depends on whether the connector
// supports resumable full refresh.
assertStreamStatusTraceMessageIndex(
0,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords1.size - 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

val recordMessages1 = extractRecordMessages(actualRecords1)
val stateMessages1 = extractStateMessages(actualRecords1)
val names = HashSet(STREAM_NAMES)
@@ -637,6 +734,25 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
modelsSchema(),
)

assertStreamStatusTraceMessageIndex(
MODEL_RECORDS_2.size,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
assertStreamStatusTraceMessageIndex(
MODEL_RECORDS_2.size + 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)

val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1]))
val read2 = source().read(config()!!, configuredCatalog, state)
val actualRecords2 = AutoCloseableIterators.toListAndClose(read2)
@@ -659,6 +775,27 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
stateMessages1,
MODEL_RECORDS.size.toLong() + MODEL_RECORDS_2.size.toLong()
)

// Expect state and record message from MODEL_RECORDS_2.
assertStreamStatusTraceMessageIndex(
2 * MODEL_RECORDS_2.size,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME_2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
assertStreamStatusTraceMessageIndex(
2 * MODEL_RECORDS_2.size + 1,
actualRecords1,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)

assertExpectedRecords(
Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream())
.collect(Collectors.toSet()),
@@ -839,6 +976,25 @@ abstract class CdcSourceTest<S : Source, T : TestDatabase<*, T, *>> {
assertExpectedRecords(emptySet(), recordMessages)
assertExpectedStateMessagesForNoData(stateMessages)
assertExpectedStateMessageCountMatches(stateMessages, 0)

assertStreamStatusTraceMessageIndex(
0,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualRecords.size - 1,
actualRecords,
createAirbteStreanStatusTraceMessage(
modelsSchema(),
MODELS_STREAM_NAME,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
}

protected open fun assertExpectedStateMessagesForNoData(
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ import io.airbyte.cdk.testutils.TestDatabase
import io.airbyte.commons.json.Jsons
import io.airbyte.commons.resources.MoreResources
import io.airbyte.commons.util.MoreIterators
import io.airbyte.protocol.models.AirbyteStreamStatusTraceMessage
import io.airbyte.protocol.models.Field
import io.airbyte.protocol.models.JsonSchemaType
import io.airbyte.protocol.models.v0.*
@@ -184,6 +185,33 @@ abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
// timeout.
}

/**
 * Asserts that the message at position [idx] of [allMessages] is a TRACE message whose stream
 * status is non-null and equal to [expectedStreamStatus].
 *
 * The function is `open`, so subclasses may replace this check.
 */
protected open fun assertStreamStatusTraceMessageIndex(
    idx: Int,
    allMessages: List<AirbyteMessage>,
    expectedStreamStatus: AirbyteStreamStatusTraceMessage
) {
    val actualMessage = allMessages[idx]
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes the failure
    // message read "expected: TRACE but was: <actual type>" rather than the reverse.
    Assertions.assertEquals(AirbyteMessage.Type.TRACE, actualMessage.type)

    val actualStreamStatus = actualMessage.trace.streamStatus
    Assertions.assertNotNull(actualStreamStatus)
    Assertions.assertEquals(expectedStreamStatus, actualStreamStatus)
}

/**
 * Builds the stream status trace message used as the expected value in stream-status
 * assertions: a descriptor made of [namespace] and [streamName], carrying [status].
 *
 * NOTE(review): the descriptor is the non-`v0` `io.airbyte.protocol.models.StreamDescriptor`,
 * and this file appears to import the non-`v0` `AirbyteStreamStatusTraceMessage` explicitly
 * while the other protocol models come from `v0.*`. If the messages under test carry `v0`
 * stream statuses, an equality assertion against this value may never succeed. TODO confirm
 * which package is intended.
 */
private fun createAirbteStreanStatusTraceMessage(
    namespace: String,
    streamName: String,
    status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus
): AirbyteStreamStatusTraceMessage {
    val descriptor =
        io.airbyte.protocol.models.StreamDescriptor().withNamespace(namespace).withName(streamName)
    return AirbyteStreamStatusTraceMessage().withStreamDescriptor(descriptor).withStatus(status)
}

@AfterEach
fun tearDown() {
testdb!!.close()
@@ -417,6 +445,25 @@ abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
)
val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null))

assertStreamStatusTraceMessageIndex(
0,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualMessages.size - 1,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

setEmittedAtToNull(actualMessages)

val expectedMessages = airbyteMessagesReadOneColumn
@@ -509,6 +556,44 @@ abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
expectedMessages.addAll(getAirbyteMessagesSecondSync(streamName2))

val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null))

assertStreamStatusTraceMessageIndex(
0,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualMessages.size - 5,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualMessages.size - 6,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)
assertStreamStatusTraceMessageIndex(
actualMessages.size - 1,
actualMessages,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName2,
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

val actualRecordMessages = filterRecords(actualMessages)

setEmittedAtToNull(actualMessages)
@@ -736,6 +821,25 @@ abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
)
)

assertStreamStatusTraceMessageIndex(
0,
actualMessagesFirstSync,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualMessagesFirstSync.size - 1,
actualMessagesFirstSync,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

val stateAfterFirstSyncOptional =
actualMessagesFirstSync
.stream()
@@ -754,6 +858,25 @@ abstract class JdbcSourceAcceptanceTest<S : Source, T : TestDatabase<*, T, *>> {
)
)

assertStreamStatusTraceMessageIndex(
0,
actualMessagesSecondSync,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED
)
)
assertStreamStatusTraceMessageIndex(
actualMessagesSecondSync.size - 1,
actualMessagesSecondSync,
createAirbteStreanStatusTraceMessage(
defaultNamespace,
streamName(),
AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE
)
)

Assertions.assertEquals(
2,
actualMessagesSecondSync
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import io.airbyte.protocol.models.v0.AirbyteStream;
import io.airbyte.protocol.models.v0.AirbyteStreamState;
import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage.AirbyteStreamStatus;
import io.airbyte.protocol.models.v0.AirbyteTraceMessage;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
@@ -277,6 +278,8 @@ void testNewStreamAddedToExistingCDCSync() throws Exception {
validateStateMessages(stateMessages);
validateAllStreamsComplete(stateMessages, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
validateAllStreamsStatuses(messages, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
assertFalse(lastStateMessage.getGlobal().getStreamStates().stream().anyMatch(
createStateStreamFilter(new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName))));

@@ -297,6 +300,10 @@ void testNewStreamAddedToExistingCDCSync() throws Exception {
validateAllStreamsComplete(stateMessages2, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName)));

validateAllStreamsStatuses(messages2, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName)));
}

@Test
@@ -349,6 +356,8 @@ void testInsertUpdateDeleteIncrementalSync() throws Exception {
validateStateMessages(stateMessages);
validateAllStreamsComplete(stateMessages, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
validateAllStreamsStatuses(messages, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));

final var result = mongoClient.getDatabase(databaseName).getCollection(collectionName).insertOne(createDocument(1));
final var insertedId = result.getInsertedId();
@@ -367,6 +376,8 @@ void testInsertUpdateDeleteIncrementalSync() throws Exception {
validateStateMessages(stateMessages2);
validateAllStreamsComplete(stateMessages2, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
validateAllStreamsStatuses(messages2, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));

final var idFilter = new Document(DOCUMENT_ID_FIELD, insertedId);
mongoClient.getDatabase(databaseName).getCollection(collectionName).updateOne(idFilter, Updates.combine(Updates.set("newField", "new")));
@@ -385,6 +396,8 @@ void testInsertUpdateDeleteIncrementalSync() throws Exception {
validateStateMessages(stateMessages3);
validateAllStreamsComplete(stateMessages3, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
validateAllStreamsStatuses(messages3, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));

mongoClient.getDatabase(databaseName).getCollection(collectionName).deleteOne(idFilter);

@@ -399,9 +412,10 @@ void testInsertUpdateDeleteIncrementalSync() throws Exception {
validateCdcEventRecordData(recordMessages4.get(0), insertedId, true);

validateStateMessages(stateMessages4);
validateAllStreamsComplete(stateMessages3, List.of(
validateAllStreamsComplete(stateMessages4, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));
validateAllStreamsStatuses(messages4, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName)));

}

@Test
@@ -435,6 +449,11 @@ void testCDCStreamCheckpointingWithMultipleStreams() throws Exception {
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));

validateAllStreamsStatuses(messages, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));

// Start a second sync from somewhere in the middle of stream 2
final List<AirbyteMessage> messages2 = runRead(configuredCatalog, Jsons.jsonNode(List.of(stateMessages.get(recordCount + 50))));
final List<AirbyteRecordMessage> recordMessages2 = filterRecords(messages2);
@@ -466,6 +485,10 @@ void testCDCStreamCheckpointingWithMultipleStreams() throws Exception {
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));
validateAllStreamsStatuses(messages2, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));

// Insert more data for one stream
insertData(databaseName, otherCollection1Name, otherCollection1Count);
@@ -489,6 +512,10 @@ void testCDCStreamCheckpointingWithMultipleStreams() throws Exception {
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));
validateAllStreamsStatuses(messages3, List.of(
new StreamDescriptor().withName(collectionName).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection1Name).withNamespace(databaseName),
new StreamDescriptor().withName(otherCollection2Name).withNamespace(databaseName)));
}

@Test
@@ -627,6 +654,21 @@ private void validateAllStreamsComplete(final List<AirbyteStateMessage> stateMes
});
}

/**
 * Verifies that, for every stream in {@code completedStreams}, {@code allMessages} contains
 * exactly two stream status trace messages for that stream: a STARTED status followed by a
 * COMPLETE status.
 *
 * @param allMessages every message emitted by the read under test, in emission order
 * @param completedStreams descriptors of the streams expected to have started and completed
 */
private void validateAllStreamsStatuses(final List<AirbyteMessage> allMessages, final List<StreamDescriptor> completedStreams) {
  completedStreams.forEach(stream -> {
    final List<AirbyteMessage> streamStatusMessages = allMessages.stream()
        .filter(airbyteMessage -> airbyteMessage.getType() == Type.TRACE
            // A TRACE message without a stream status cannot belong to this check; skip it
            // instead of dereferencing null.
            && airbyteMessage.getTrace().getStreamStatus() != null
            // Compare against the stream currently being validated. Comparing against the whole
            // completedStreams list can never be true, so the filter matched nothing.
            && stream.equals(airbyteMessage.getTrace().getStreamStatus().getStreamDescriptor()))
        .collect(Collectors.toList());

    assertTrue(streamStatusMessages.size() == 2);
    assertTrue(streamStatusMessages.get(0).getTrace().getStreamStatus().getStatus() == AirbyteStreamStatus.STARTED);
    assertTrue(streamStatusMessages.get(1).getTrace().getStreamStatus().getStatus() == AirbyteStreamStatus.COMPLETE);
  });
}

private Optional<AirbyteStreamState> getStreamState(final AirbyteStateMessage stateMessage, final StreamDescriptor streamDescriptor) {
return stateMessage.getGlobal().getStreamStates().stream().filter(createStateStreamFilter(streamDescriptor)).findFirst();
}
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ plugins {
airbyteJavaConnector {
cdkVersionRequired = '0.33.1'
features = ['db-sources']
useLocalCdk = false
useLocalCdk = true
}

java {