diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java index 15276d9b67ca1..308167b471ffc 100644 --- a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java +++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java @@ -54,8 +54,14 @@ class HybridSourceSplitEnumeratorTest { private HybridSourceSplit splitFromSource1; private void setupEnumeratorAndTriggerSourceSwitch() { + setupEnumeratorAndTriggerSourceSwitch( + HybridSource.builder(MOCK_SOURCE).addSource(MOCK_SOURCE).build()); + } + + private HybridSourceSplitEnumerator setupEnumeratorAndTriggerSourceSwitch( + HybridSource hybridSource) { context = new MockSplitEnumeratorContext<>(2); - source = HybridSource.builder(MOCK_SOURCE).addSource(MOCK_SOURCE).build(); + source = hybridSource; enumerator = (HybridSourceSplitEnumerator) source.createEnumerator(context); enumerator.start(); @@ -86,6 +92,7 @@ private void setupEnumeratorAndTriggerSourceSwitch() { assertThat(splitFromSource1.sourceIndex()).isEqualTo(1); enumerator.handleSourceEvent(SUBTASK1, new SourceReaderFinishedEvent(SUBTASK1)); assertThat(getCurrentSourceIndex(enumerator)).as("reader without assignment").isEqualTo(1); + return enumerator; } @Test @@ -232,6 +239,28 @@ void testRestoreEnumerator() throws Exception { .hasSize(1); } + @Test + public void testRestoreEnumeratorWithSwitchContextSource() throws Exception { + HybridSource hybridSource = + HybridSource.builder(MOCK_SOURCE) + .addSource( + switchContext -> { + assertThat(switchContext.getPreviousEnumerator()) + .describedAs( + "Previous enumerator is null, cannot derive start position for next source") + .isNotNull(); + return MOCK_SOURCE; + }, + MOCK_SOURCE.getBoundedness()) + .build(); + enumerator = setupEnumeratorAndTriggerSourceSwitch(hybridSource); + HybridSourceEnumeratorState enumeratorState = enumerator.snapshotState(0); + assertThat(enumeratorState.getCurrentSourceIndex()).isEqualTo(1); + enumerator = + (HybridSourceSplitEnumerator) source.restoreEnumerator(context, enumeratorState); + enumerator.start(); + } + @Test void testRestoreEnumeratorAfterFirstSourceWithoutRestoredSplits() throws Exception { setupEnumeratorAndTriggerSourceSwitch();