Skip to content

Commit

Permalink
[flink] Extract SplitListState as public utility class
Browse files Browse the repository at this point in the history
  • Loading branch information
yunfengzhou-hub committed Dec 12, 2024
1 parent 237850d commit c558e9c
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.paimon.flink.source.AbstractNonCoordinatedSource;
import org.apache.paimon.flink.source.AbstractNonCoordinatedSourceReader;
import org.apache.paimon.flink.source.SimpleSourceSplit;
import org.apache.paimon.flink.source.SplitListState;

import org.apache.flink.api.connector.source.Boundedness;
import org.apache.flink.api.connector.source.ReaderOutput;
Expand All @@ -29,7 +30,6 @@
import org.apache.flink.core.io.InputStatus;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
Expand Down Expand Up @@ -66,6 +66,8 @@ private static class Reader extends AbstractNonCoordinatedSourceReader<TestCdcEv
private final int totalSubtasks;

private final LinkedList<TestCdcEvent> events;
private final SplitListState<Integer> remainingEventsCount =
new SplitListState<>("events", x -> Integer.toString(x), Integer::parseInt);

private final int numRecordsPerCheckpoint;
private final AtomicInteger recordsThisCheckpoint;
Expand Down Expand Up @@ -104,17 +106,18 @@ public InputStatus pollNext(ReaderOutput<TestCdcEvent> readerOutput) throws Exce
@Override
public List<SimpleSourceSplit> snapshotState(long l) {
recordsThisCheckpoint.set(0);
return Collections.singletonList(
new SimpleSourceSplit(Integer.toString(events.size())));
remainingEventsCount.clear();
remainingEventsCount.add(events.size());
return remainingEventsCount.snapshotState();
}

@Override
public void addSplits(List<SimpleSourceSplit> list) {
int count =
list.stream()
.map(x -> Integer.parseInt(x.value()))
.reduce(Integer::sum)
.orElse(0);
remainingEventsCount.restoreState(list);
int count = 0;
for (int c : remainingEventsCount.get()) {
count += c;
}
while (events.size() > count) {
events.poll();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public SimpleSourceSplit(String value) {
this(UUID.randomUUID().toString(), value);
}

SimpleSourceSplit(String splitId, String value) {
public SimpleSourceSplit(String splitId, String value) {
this.splitId = splitId;
this.value = value;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.flink.source;

import org.apache.paimon.utils.Preconditions;

import org.apache.flink.api.common.state.ListState;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Utility class to provide {@link ListState}-like experience for sources that use {@link
 * SimpleSourceSplit}.
 *
 * <p>The values are kept in an in-memory list. {@link #snapshotState()} turns each value into a
 * {@link SimpleSourceSplit} whose string payload is the concatenation of
 *
 * <pre>{@code <identifier length><identifier><serialized value>}</pre>
 *
 * and {@link #restoreState(List)} only picks up splits whose payload starts with this state's
 * {@code <identifier length><identifier>} prefix. Several states with different identifiers can
 * therefore share one list of splits. Because an identifier may not start with a digit, the
 * leading length is unambiguous, so e.g. the states {@code "ab"} (prefix {@code "2ab"}) and
 * {@code "abc"} (prefix {@code "3abc"}) never match each other's splits.
 *
 * @param <T> type of the values held by this state
 */
public class SplitListState<T> implements ListState<T> {
    /** Identifier length followed by the identifier; prepended to every serialized value. */
    private final String splitPrefix;

    private final List<T> values;
    private final Function<T, String> serializer;
    private final Function<String, T> deserializer;

    /**
     * Creates an empty state.
     *
     * @param identifier name distinguishing this state from other states stored in the same split
     *     list; must be non-null, non-empty and must not start with a digit (checked through
     *     {@link Preconditions#checkArgument})
     * @param serializer converts a value into the string stored in a split
     * @param deserializer converts the string stored in a split back into a value
     */
    public SplitListState(
            String identifier, Function<T, String> serializer, Function<String, T> deserializer) {
        // Guard first: charAt(0) below would otherwise fail with an unrelated
        // StringIndexOutOfBoundsException / NullPointerException.
        Preconditions.checkArgument(
                identifier != null && !identifier.isEmpty(),
                "Identifier should not be null or empty.");
        Preconditions.checkArgument(
                !Character.isDigit(identifier.charAt(0)),
                String.format("Identifier %s should not start with digits.", identifier));
        this.splitPrefix = identifier.length() + identifier;
        this.serializer = serializer;
        this.deserializer = deserializer;
        this.values = new ArrayList<>();
    }

    /** Appends a single value to the state. */
    @Override
    public void add(T value) {
        values.add(value);
    }

    /**
     * Returns the current values.
     *
     * @return a new list holding the current values; modifying it does not affect this state
     */
    @Override
    public List<T> get() {
        return new ArrayList<>(values);
    }

    /** Replaces the current values with the given ones. */
    @Override
    public void update(List<T> values) {
        this.values.clear();
        this.values.addAll(values);
    }

    /** Appends all given values to the state. */
    @Override
    public void addAll(List<T> values) throws Exception {
        this.values.addAll(values);
    }

    /** Removes all values from the state. */
    @Override
    public void clear() {
        values.clear();
    }

    /**
     * Converts the current values into splits, one split per value.
     *
     * @return a list with one {@link SimpleSourceSplit} per value, each carrying this state's
     *     prefix followed by the serialized value
     */
    public List<SimpleSourceSplit> snapshotState() {
        return values.stream()
                .map(x -> new SimpleSourceSplit(splitPrefix + serializer.apply(x)))
                .collect(Collectors.toList());
    }

    /**
     * Replaces the current values with those found in the given splits.
     *
     * <p>Splits whose payload does not start with this state's prefix are ignored, so the list
     * may also contain splits written by other states.
     *
     * @param splits splits to restore from
     */
    public void restoreState(List<SimpleSourceSplit> splits) {
        values.clear();
        splits.stream()
                .map(SimpleSourceSplit::value)
                .filter(x -> x.startsWith(splitPrefix))
                .map(x -> x.substring(splitPrefix.length()))
                .map(this.deserializer)
                .forEach(values::add);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.paimon.flink.source.AbstractNonCoordinatedSource;
import org.apache.paimon.flink.source.AbstractNonCoordinatedSourceReader;
import org.apache.paimon.flink.source.SimpleSourceSplit;
import org.apache.paimon.flink.source.SplitListState;
import org.apache.paimon.flink.utils.JavaTypeInfo;
import org.apache.paimon.table.BucketMode;
import org.apache.paimon.table.sink.ChannelComputer;
Expand Down Expand Up @@ -52,8 +53,6 @@
import java.util.NavigableMap;
import java.util.OptionalLong;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.apache.paimon.table.BucketMode.BUCKET_UNAWARE;

Expand Down Expand Up @@ -111,10 +110,10 @@ private class Reader extends AbstractNonCoordinatedSourceReader<Split> {
private static final String NEXT_SNAPSHOT_STATE = "NSS";

private final StreamTableScan scan = readBuilder.newStreamScan();
private final SplitState<Long> checkpointState =
new SplitState<>(CHECKPOINT_STATE, x -> Long.toString(x), Long::parseLong);
private final SplitState<Tuple2<Long, Long>> nextSnapshotState =
new SplitState<>(
private final SplitListState<Long> checkpointState =
new SplitListState<>(CHECKPOINT_STATE, x -> Long.toString(x), Long::parseLong);
private final SplitListState<Tuple2<Long, Long>> nextSnapshotState =
new SplitListState<>(
NEXT_SNAPSHOT_STATE,
x -> x.f0 + ":" + x.f1,
x ->
Expand Down Expand Up @@ -203,56 +202,6 @@ public InputStatus pollNext(ReaderOutput<Split> readerOutput) throws Exception {
}
}

/**
 * Keeps a list of values that is persisted as {@link SimpleSourceSplit}s. Every split written
 * by this state starts with the state's identifier, which is how the state recognizes its own
 * splits again on restore.
 */
private static class SplitState<T> {
    private final String identifier;
    private final List<T> values;
    private final Function<T, String> serializer;
    private final Function<String, T> deserializer;

    private SplitState(
            String identifier,
            Function<T, String> serializer,
            Function<String, T> deserializer) {
        this.values = new ArrayList<>();
        this.identifier = identifier;
        this.serializer = serializer;
        this.deserializer = deserializer;
    }

    /** Appends one value. */
    private void add(T value) {
        values.add(value);
    }

    /** Returns a copy of the current values. */
    private List<T> get() {
        return new ArrayList<>(values);
    }

    /** Replaces the current values with the given ones. */
    private void update(List<T> values) {
        this.values.clear();
        this.values.addAll(values);
    }

    /** Drops all values. */
    private void clear() {
        values.clear();
    }

    /** Produces one split per value: the identifier followed by the serialized value. */
    private List<SimpleSourceSplit> snapshotState() {
        List<SimpleSourceSplit> snapshot = new ArrayList<>();
        for (T element : values) {
            snapshot.add(new SimpleSourceSplit(identifier + serializer.apply(element)));
        }
        return snapshot;
    }

    /** Rebuilds the values from the splits that start with this state's identifier. */
    private void restoreState(List<SimpleSourceSplit> splits) {
        values.clear();
        for (SimpleSourceSplit split : splits) {
            String payload = split.value();
            if (!payload.startsWith(identifier)) {
                // Written by another state; not ours to restore.
                continue;
            }
            values.add(deserializer.apply(payload.substring(identifier.length())));
        }
    }
}

public static DataStream<RowData> buildSource(
StreamExecutionEnvironment env,
String name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.paimon.flink.source.AbstractNonCoordinatedSource;
import org.apache.paimon.flink.source.AbstractNonCoordinatedSourceReader;
import org.apache.paimon.flink.source.SimpleSourceSplit;
import org.apache.paimon.flink.source.SplitListState;
import org.apache.paimon.utils.Preconditions;

import org.apache.flink.api.connector.source.Boundedness;
Expand All @@ -30,9 +31,8 @@
import org.apache.flink.core.io.InputStatus;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

import java.util.Collections;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
* A stream source that: 1) emits a list of elements without allowing checkpoints, 2) then waits for
Expand Down Expand Up @@ -71,6 +71,9 @@ private static class Reader<T> extends AbstractNonCoordinatedSourceReader<T> {

private final boolean emitOnce;

private final SplitListState<Integer> checkpointedState =
new SplitListState<>("emit-times", x -> Integer.toString(x), Integer::parseInt);

private int numTimesEmitted = 0;

private int numCheckpointsComplete;
Expand Down Expand Up @@ -117,17 +120,18 @@ public synchronized InputStatus pollNext(ReaderOutput<T> readerOutput) {

@Override
public void addSplits(List<SimpleSourceSplit> list) {
List<Integer> retrievedStates =
list.stream()
.map(x -> Integer.parseInt(x.value()))
.collect(Collectors.toList());
checkpointedState.restoreState(list);
List<Integer> retrievedStates = new ArrayList<>();
for (Integer entry : this.checkpointedState.get()) {
retrievedStates.add(entry);
}

// given that the parallelism of the function is 1, we can only have 1 state
Preconditions.checkArgument(
retrievedStates.size() == 1,
getClass().getSimpleName() + " retrieved invalid state.");

numTimesEmitted = retrievedStates.get(0);
this.numTimesEmitted = retrievedStates.get(0);
Preconditions.checkArgument(
numTimesEmitted <= 2,
getClass().getSimpleName()
Expand All @@ -137,8 +141,9 @@ public void addSplits(List<SimpleSourceSplit> list) {

@Override
public List<SimpleSourceSplit> snapshotState(long l) {
return Collections.singletonList(
new SimpleSourceSplit(Integer.toString(numTimesEmitted)));
this.checkpointedState.clear();
this.checkpointedState.add(this.numTimesEmitted);
return this.checkpointedState.snapshotState();
}

@Override
Expand Down

0 comments on commit c558e9c

Please sign in to comment.