diff --git a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StateManager.kt b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StateManager.kt index ee06cbe39e32..d051867c0f74 100644 --- a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StateManager.kt +++ b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StateManager.kt @@ -122,12 +122,28 @@ class StateManager( */ @Synchronized fun takeForCheckpoint(): StateForCheckpoint { - val stateForCheckpoint: StateForCheckpoint = - pendingStateValue?.let { Fresh(it, pendingNumRecords) } ?: Stale(currentStateValue) - currentStateValue = pendingStateValue + // Check if there is a pending state value or not. + // If not, then set() HASN'T been called since the last call to takeForCheckpoint(), + // because set() can only accept non-null state values. + // + // This means that there is nothing worth checkpointing for this particular feed. + // In that case, exit early with a Stale result and leave the current state value untouched. + val freshStateValue: OpaqueStateValue = + pendingStateValue ?: return Stale(currentStateValue) + // This point is reached in the case where there is a pending state value. + // This means that set() HAS been called since the last call to takeForCheckpoint(). + // + // Keep a copy of the total number of records registered in all calls to set() since the + // last call to takeForCheckpoint(); this number will be returned. + val freshNumRecords: Long = pendingNumRecords + // Promote the pending state value to be the new current state value. + currentStateValue = freshStateValue + // Reset the pending state, which will be overwritten by the next call to set(). pendingStateValue = null pendingNumRecords = 0L - return stateForCheckpoint + // Return a Fresh result with the latest state value and the total number of records seen + // since the last call to takeForCheckpoint().
+ return Fresh(freshStateValue, freshNumRecords) } } diff --git a/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StateManagerGlobalStatesTest.kt b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StateManagerGlobalStatesTest.kt index e1bb654f0852..15e6e5f30b4c 100644 --- a/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StateManagerGlobalStatesTest.kt +++ b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StateManagerGlobalStatesTest.kt @@ -97,6 +97,67 @@ class StateManagerGlobalStatesTest { Assertions.assertEquals(emptyList(), stateManager.checkpoint()) } + @Test + @Property(name = "airbyte.connector.catalog.resource", value = "fakesource/cdc-catalog.json") + @Property( + name = "airbyte.connector.state.json", + value = """{"type": "GLOBAL", "global": { "shared_state": { "cdc": "starting" } } }""", + ) + fun testInitialSyncColdStart() { + val streams: Streams = prelude() + // test current state: only the global feed has a state value before any work is done + Assertions.assertEquals( + Jsons.readTree("{ \"cdc\": \"starting\" }"), + stateManager.scoped(streams.global).current(), + ) + Assertions.assertNull(stateManager.scoped(streams.kv).current()) + Assertions.assertNull(stateManager.scoped(streams.events).current()) + Assertions.assertEquals(listOf(), handler.get()) + // update state manager with fake work results for the kv stream + stateManager.scoped(streams.kv).set(Jsons.readTree("{\"initial_sync\":\"ongoing\"}"), 123L) + // test checkpoint messages: the kv stream state and its 123 records are reported once + val checkpointOngoing: List = stateManager.checkpoint() + Assertions.assertEquals( + listOf( + """{ + |"type":"GLOBAL", + |"global":{"shared_state":{"cdc":"starting"}, + |"stream_states":[ + |{"stream_descriptor":{"name":"KV","namespace":"PUBLIC"}, + |"stream_state":{"initial_sync":"ongoing"}} + |]},"sourceStats":{"recordCount":123.0} + |} + """.trimMargin(), + ) + .map { Jsons.readTree(it) }, + checkpointOngoing.map { Jsons.valueToTree(it) }, + ) + Assertions.assertEquals(emptyList(), 
stateManager.checkpoint()) + // update state manager with more fake work results for the kv stream, in two steps + stateManager.scoped(streams.kv).set(Jsons.readTree("{\"initial_sync\":\"ongoing\"}"), 456L) + stateManager + .scoped(streams.kv) + .set(Jsons.readTree("{\"initial_sync\":\"completed\"}"), 789L) + // test checkpoint messages: last state wins and record counts add up (456 + 789 = 1245) + val checkpointCompleted: List = stateManager.checkpoint() + Assertions.assertEquals( + listOf( + """{ + |"type":"GLOBAL", + |"global":{"shared_state":{"cdc":"starting"}, + |"stream_states":[ + |{"stream_descriptor":{"name":"KV","namespace":"PUBLIC"}, + |"stream_state":{"initial_sync":"completed"}} + |]},"sourceStats":{"recordCount":1245.0} + |} + """.trimMargin(), + ) + .map { Jsons.readTree(it) }, + checkpointCompleted.map { Jsons.valueToTree(it) }, + ) + Assertions.assertEquals(emptyList(), stateManager.checkpoint()) + } + @Test @Property(name = "airbyte.connector.catalog.resource", value = "fakesource/cdc-catalog.json") @Property(