Skip to content

Commit

Permalink
[FLINK-19059] Use a list + map state to reduce serdes
Browse files Browse the repository at this point in the history
  • Loading branch information
bvarghese1 committed Dec 10, 2024
1 parent d0b856b commit 2865a3f
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
Expand Down Expand Up @@ -366,6 +367,8 @@ private KeyedProcessFunction<RowData, RowData, RowData> createUnboundedOverProce
genAggsHandler,
flattenAccTypes);
case NON_TIME:
final int sortKeyIdx = orderKeys[0];
LogicalType sortKeyType = inputRowType.getTypeAt(sortKeyIdx);
final GeneratedRecordEqualiser generatedEqualiser =
new EqualiserCodeGenerator(inputRowType, ctx.classLoader())
.generateRecordEqualiser("FirstMatchingRowEqualiser");
Expand All @@ -378,14 +381,28 @@ private KeyedProcessFunction<RowData, RowData, RowData> createUnboundedOverProce
inputRowType,
SortUtil.getAscendingSortSpec(orderKeys));

final GeneratedRecordComparator keyGenRecordComparator =
ComparatorCodeGenerator.gen(
config,
ctx.classLoader(),
"KeySortComparator",
RowType.of(DataTypes.BIGINT().getLogicalType(), sortKeyType),
SortUtil.getAscendingSortSpec(orderKeys));

RowData.FieldGetter sortKeyFieldGetter =
RowData.createFieldGetter(sortKeyType, sortKeyIdx);

return new NonTimeUnboundedPrecedingFunction<>(
config.getStateRetentionTime(),
TableConfigUtils.getMaxIdleStateRetentionTime(config),
genAggsHandler,
generatedEqualiser,
genRecordComparator,
keyGenRecordComparator,
flattenAccTypes,
fieldTypes);
fieldTypes,
sortKeyFieldGetter,
sortKeyIdx);
default:
throw new TableException("Unsupported unbounded operation");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;

import static org.apache.flink.table.api.config.TableConfigOptions.LOCAL_TIME_ZONE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import org.apache.flink.FlinkVersion;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
import org.apache.flink.table.planner.plan.nodes.exec.common.OverAggregateTestPrograms;
import org.apache.flink.table.planner.plan.nodes.exec.testutils.RestoreTestBase;
import org.apache.flink.table.test.program.TableTestProgram;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ private void registerSinkObserver(
results.addAll(sinkTestStep.getExpectedAfterRestoreAsStrings());
}
List<String> expectedResults = getExpectedResults(sinkTestStep, tableName);
System.out.println(expectedResults);
final boolean shouldComplete =
CollectionUtils.isEqualCollection(expectedResults, results);
if (shouldComplete) {
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
package org.apache.flink.table.runtime.operators.over;

import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.RowData.FieldGetter;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
import org.apache.flink.table.runtime.functions.KeyedProcessFunctionWithCleanupState;
Expand All @@ -34,6 +40,7 @@
import org.apache.flink.table.runtime.generated.RecordEqualiser;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.types.RowKind;
import org.apache.flink.util.Collector;

Expand All @@ -53,12 +60,18 @@ public class NonTimeUnboundedPrecedingFunction<K>
private static final Logger LOG =
LoggerFactory.getLogger(NonTimeUnboundedPrecedingFunction.class);

private final int keyIdx;
private final int sortKeyIdx;
private final FieldGetter sortKeyFieldGetter;

private final GeneratedAggsHandleFunction genAggsHandler;
private final GeneratedRecordEqualiser generatedRecordEqualiser;
private final GeneratedRecordComparator generatedRecordComparator;
private final GeneratedRecordComparator keyGeneratedRecordComparator;

// The util to compare two rows based on the sort attribute.
private transient Comparator<RowData> sortKeyComparator;
private transient Comparator<RowData> newSortKeyComparator;
// The record equaliser used to equal RowData.
private transient RecordEqualiser equaliser;

Expand All @@ -71,6 +84,15 @@ public class NonTimeUnboundedPrecedingFunction<K>
// state to hold rows until state ttl expires
private transient ValueState<List<RowData>> inputState;

// state to hold the Long ID counter
private transient ValueState<Long> idState;

// state to hold a list of sorted keys with an artificial id
// The artificial id acts as the key in the valueMapState
private transient ValueState<List<RowData>> sortedKeyState;
// state to hold rows until state ttl expires
private transient MapState<Long, RowData> valueMapState;

protected transient AggsHandleFunction currFunction;
protected transient AggsHandleFunction prevFunction;

Expand All @@ -80,14 +102,21 @@ public NonTimeUnboundedPrecedingFunction(
GeneratedAggsHandleFunction genAggsHandler,
GeneratedRecordEqualiser genRecordEqualiser,
GeneratedRecordComparator genRecordComparator,
GeneratedRecordComparator keyGenRecordComparator,
LogicalType[] accTypes,
LogicalType[] inputFieldTypes) {
LogicalType[] inputFieldTypes,
FieldGetter sortKeyFieldGetter,
int sortKeyIdx) {
super(minRetentionTime, maxRetentionTime);
this.genAggsHandler = genAggsHandler;
this.generatedRecordEqualiser = genRecordEqualiser;
this.generatedRecordComparator = genRecordComparator;
this.keyGeneratedRecordComparator = keyGenRecordComparator;
this.accTypes = accTypes;
this.inputFieldTypes = inputFieldTypes;
this.sortKeyFieldGetter = sortKeyFieldGetter;
this.sortKeyIdx = sortKeyIdx;
this.keyIdx = 0;
}

@Override
Expand All @@ -102,14 +131,18 @@ public void open(OpenContext openContext) throws Exception {
// Initialize output record
output = new JoinedRowData();

// Intialize record equaliser
// Initialize record equaliser
equaliser =
generatedRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());

// Initialize sort comparator
sortKeyComparator =
generatedRecordComparator.newInstance(getRuntimeContext().getUserCodeClassLoader());

newSortKeyComparator =
keyGeneratedRecordComparator.newInstance(
getRuntimeContext().getUserCodeClassLoader());

// Initialize accumulator state
InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
ValueStateDescriptor<RowData> accStateDesc =
Expand All @@ -126,12 +159,34 @@ public void open(OpenContext openContext) throws Exception {
// Initialize state which maintains records in sorted(ASC) order
inputState = getRuntimeContext().getState(inputStateDescriptor);

ValueStateDescriptor<Long> idStateDescriptor =
new ValueStateDescriptor<Long>("idState", Long.class);
idState = getRuntimeContext().getState(idStateDescriptor);

LogicalType idType = DataTypes.BIGINT().getLogicalType();
LogicalType sortKeyType = inputType.toRowType().getTypeAt(sortKeyIdx);
RowType sortedRow = RowType.of(idType, sortKeyType);
LogicalType[] sortedRowTypes = sortedRow.getChildren().toArray(new LogicalType[0]);
InternalTypeInfo<RowData> sortedKeyRowType = InternalTypeInfo.ofFields(sortedRowTypes);
ListTypeInfo<RowData> sortedKeyTypeInfo = new ListTypeInfo<>(sortedKeyRowType);
// Initialize state which maintains a sorted list of pair(ID, sortKey)
ValueStateDescriptor<List<RowData>> sortedKeyStateDescriptor =
new ValueStateDescriptor<List<RowData>>("sortedKeyState", sortedKeyTypeInfo);
sortedKeyState = getRuntimeContext().getState(sortedKeyStateDescriptor);

MapStateDescriptor<Long, RowData> valueStateDescriptor =
new MapStateDescriptor<Long, RowData>("valueState", Types.LONG, inputType);
valueMapState = getRuntimeContext().getMapState(valueStateDescriptor);

initCleanupTimeState("NonTimeUnboundedPrecedingFunctionCleanupTime");
}

/**
* Puts an element from the input stream into state if it is not late. Registers a timer for the
* next watermark.
* Puts an element from the input stream into state. Emits the aggregated value for the newly
* inserted element. For append stream emits updates(UB, UA) for all elements which are present
* after the newly inserted element. For retract stream emits an UB for the element and
* thereafter emits updates(UB, UA) for all elements which are present after the retracted
* element.
*
* @param input The input value.
* @param ctx A {@link Context} that allows querying the timestamp of the element and getting
Expand Down Expand Up @@ -170,9 +225,11 @@ public void processElement(
RowKind rowKind = input.getRowKind();

if (rowKind == RowKind.INSERT || rowKind == RowKind.UPDATE_AFTER) {
insertIntoSortedList(input, out);
// insertIntoSortedList(input, out);
insertIntoSortedListOptimized(input, out);
} else if (rowKind == RowKind.DELETE || rowKind == RowKind.UPDATE_BEFORE) {
removeFromSortedList(input, out);
// removeFromSortedList(input, out);
removeFromSortedListOptimized(input, out);
}

// Reset acc state since we can have out of order inserts into the ordered list
Expand Down Expand Up @@ -236,6 +293,75 @@ private void insertIntoSortedList(RowData rowData, Collector<RowData> out) throw
}
}

/**
 * Inserts {@code input} into the sorted key list / value map state and emits the resulting
 * changes.
 *
 * <p>The sorted key list holds {@code (id, sortKey)} rows; the id is the key under which the
 * full row is stored in {@code valueMapState}. The list is scanned from the start: every entry
 * that does not compare greater than the input's key (per {@code newSortKeyComparator}) is
 * accumulated into both {@code currFunction} and {@code prevFunction}. The new entry is placed
 * in front of the first entry that compares greater, or at the end of the list if there is
 * none.
 *
 * <p>Emitted records, in order:
 *
 * <ol>
 *   <li>the input row joined with the aggregate including it, carrying the input's original
 *       row kind;
 *   <li>for every entry after the insertion point, an UPDATE_BEFORE with the aggregate
 *       without the new row followed by an UPDATE_AFTER with the aggregate including it.
 * </ol>
 *
 * <p>Note: the row kind of {@code input} is overwritten with {@link RowKind#INSERT} before it
 * is stored.
 *
 * @param input the row to insert
 * @param out the collector receiving the emitted records
 * @throws Exception if state access or the aggregate functions fail
 */
private void insertIntoSortedListOptimized(RowData input, Collector<RowData> out)
        throws Exception {
    List<RowData> sortedKeyList = sortedKeyState.value();
    if (sortedKeyList == null) {
        sortedKeyList = new ArrayList<>();
    }
    Long id = idState.value();
    if (id == null) {
        id = 0L;
    }

    RowKind origRowKind = input.getRowKind();
    input.setRowKind(RowKind.INSERT);

    // The sort key of the input and the probe row used for comparisons do not change while
    // scanning, so build them once instead of once per list entry. The probe's id (-1) is a
    // placeholder; the real id is assigned when the entry is added to the list.
    final Object inputSortKey = sortKeyFieldGetter.getFieldOrNull(input);
    final RowData inputKey = GenericRowData.of(-1L, inputSortKey);

    ListIterator<RowData> iterator = sortedKeyList.listIterator();
    while (iterator.hasNext()) {
        RowData curKey = iterator.next(); // (ID, sortKey)
        if (newSortKeyComparator.compare(curKey, inputKey) > 0) {
            // Step back so that the cursor sits in front of the first greater entry
            iterator.previous();
            break;
        }
        // Can also add the accKey to the sortedKeyList to avoid reading from the valueMapState
        RowData curRow = valueMapState.get(curKey.getLong(keyIdx));
        currFunction.accumulate(curRow);
        prevFunction.accumulate(curRow);
    }

    // The cursor is now either in front of the first greater entry or at the end of the list;
    // in both cases the new entry belongs exactly here.
    iterator.add(GenericRowData.of(id, inputSortKey));
    valueMapState.put(id, input);
    id++;

    // Only accumulate rowData with currFunction
    currFunction.accumulate(input);

    // Update sorted key state with the newly inserted row's key
    sortedKeyState.update(sortedKeyList);
    idState.update(id);

    // prepare output row
    output.setRowKind(origRowKind);
    output.replace(input, currFunction.getValue());
    out.collect(output);

    // Emit updated agg value for all records after newly inserted row
    while (iterator.hasNext()) {
        RowData curKey = iterator.next();
        RowData curValue = valueMapState.get(curKey.getLong(keyIdx));
        currFunction.accumulate(curValue);
        prevFunction.accumulate(curValue);
        // Generate UPDATE_BEFORE
        output.setRowKind(RowKind.UPDATE_BEFORE);
        output.replace(curValue, prevFunction.getValue());
        out.collect(output);
        // Generate UPDATE_AFTER
        output.setRowKind(RowKind.UPDATE_AFTER);
        output.replace(curValue, currFunction.getValue());
        out.collect(output);
    }
}

private void removeFromSortedList(RowData rowData, Collector<RowData> out) throws Exception {
boolean isRetracted = false;
rowData.setRowKind(RowKind.INSERT);
Expand Down Expand Up @@ -270,8 +396,48 @@ private void removeFromSortedList(RowData rowData, Collector<RowData> out) throw
inputState.update(rowList);
}

/**
 * Retracts {@code input} from the sorted key list / value map state and emits the resulting
 * changes.
 *
 * <p>The sorted key list is scanned from the start and every stored row is accumulated into
 * both {@code currFunction} and {@code prevFunction}. At the first stored row that the record
 * equaliser considers equal to {@code input}, an UPDATE_BEFORE carrying the aggregate up to and
 * including that row is emitted, the entry is removed from the list and from {@code
 * valueMapState}, and the row is retracted from {@code currFunction}. For every entry after
 * it, an UPDATE_BEFORE (aggregate including the retracted row) and an UPDATE_AFTER (aggregate
 * without it) are emitted.
 *
 * <p>If no rows are stored for the current key, the retraction is ignored and nothing is
 * emitted.
 *
 * <p>Note: the row kind of {@code input} is overwritten with {@link RowKind#INSERT} so that it
 * can be compared against the stored rows.
 *
 * @param input the row to retract
 * @param out the collector receiving the emitted records
 * @throws Exception if state access or the aggregate functions fail
 */
private void removeFromSortedListOptimized(RowData input, Collector<RowData> out)
        throws Exception {
    boolean isRetracted = false;
    input.setRowKind(RowKind.INSERT);
    List<RowData> sortedKeyList = sortedKeyState.value();
    if (sortedKeyList == null) {
        // Nothing has been stored for this key, so there is no row to retract. Without this
        // guard the listIterator() call below fails with a NullPointerException.
        return;
    }
    ListIterator<RowData> iterator = sortedKeyList.listIterator();

    while (iterator.hasNext()) {
        RowData curKey = iterator.next(); // (ID, sortKey)
        long curId = curKey.getLong(keyIdx);
        RowData curValue = valueMapState.get(curId);
        currFunction.accumulate(curValue);
        prevFunction.accumulate(curValue);
        if (isRetracted) {
            // Emit updated agg value for all records after retraction
            output.setRowKind(RowKind.UPDATE_BEFORE);
            output.replace(curValue, prevFunction.getValue());
            out.collect(output);

            output.setRowKind(RowKind.UPDATE_AFTER);
            output.replace(curValue, currFunction.getValue());
            out.collect(output);
        } else if (equaliser.equals(curValue, input)) {
            // Retract record
            output.setRowKind(RowKind.UPDATE_BEFORE);
            output.replace(input, currFunction.getValue());
            out.collect(output);
            iterator.remove();
            valueMapState.remove(curId);
            // From here on currFunction excludes the retracted row while prevFunction
            // still includes it.
            currFunction.retract(curValue);
            isRetracted = true;
        }
    }

    // Update sorted key state without the retracted row
    sortedKeyState.update(sortedKeyList);
}

@Override
public void close() throws Exception {
if (null != prevFunction) {
prevFunction.close();
}
if (null != currFunction) {
currFunction.close();
}
Expand Down
Loading

0 comments on commit 2865a3f

Please sign in to comment.