Skip to content

Commit

Permalink
[FLINK-36881][table] Introduce GroupTableAggFunction in GroupTableAgg…
Browse files Browse the repository at this point in the history
…regate with Async State API
  • Loading branch information
Au-Miner committed Dec 13, 2024
1 parent 529f640 commit 6eb3463
Show file tree
Hide file tree
Showing 9 changed files with 547 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,19 @@ public void drainStateRequests() {
public void finish() throws Exception {
super.finish();
if (isAsyncStateProcessingEnabled()) {
asyncExecutionController.drainInflightRecords(0);
if (asyncExecutionController != null) {
asyncExecutionController.drainInflightRecords(0);
}
}
}

@Override
public void close() throws Exception {
super.close();
if (isAsyncStateProcessingEnabled()) {
asyncExecutionController.close();
if (asyncExecutionController != null) {
asyncExecutionController.close();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators.asyncprocessing;

import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.asyncprocessing.operators.AbstractAsyncStateUdfStreamOperator;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.SimpleTimerService;
import org.apache.flink.streaming.api.TimeDomain;
import org.apache.flink.streaming.api.TimerService;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.OutputTag;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/**
* A {@link StreamOperator} for executing {@link KeyedProcessFunction KeyedProcessFunctions}.
*
* <p>This class is nearly identical with {@link KeyedProcessOperator}, but extending from {@link
* AbstractAsyncStateUdfStreamOperator} to integrate with asynchronous state access. Another
* difference is this class is internal.
*/
@Internal
public class AsyncStateKeyedProcessOperator<K, IN, OUT>
extends AbstractAsyncStateUdfStreamOperator<OUT, KeyedProcessFunction<K, IN, OUT>>
implements OneInputStreamOperator<IN, OUT>, Triggerable<K, VoidNamespace> {

private static final long serialVersionUID = 1L;

private transient TimestampedCollector<OUT> collector;

private transient ContextImpl context;

private transient OnTimerContextImpl onTimerContext;

public AsyncStateKeyedProcessOperator(KeyedProcessFunction<K, IN, OUT> function) {
super(function);
}

@Override
public void open() throws Exception {
super.open();
collector = new TimestampedCollector<>(output);

InternalTimerService<VoidNamespace> internalTimerService =
getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this);

TimerService timerService = new SimpleTimerService(internalTimerService);

context = new ContextImpl(userFunction, timerService);
onTimerContext = new OnTimerContextImpl(userFunction, timerService);
}

@Override
public void onEventTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
collector.setAbsoluteTimestamp(timer.getTimestamp());
invokeUserFunction(TimeDomain.EVENT_TIME, timer);
}

@Override
public void onProcessingTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
collector.eraseTimestamp();
invokeUserFunction(TimeDomain.PROCESSING_TIME, timer);
}

@Override
public void processElement(StreamRecord<IN> element) throws Exception {
collector.setTimestamp(element);
context.element = element;
userFunction.processElement(element.getValue(), context, collector);
context.element = null;
}

private void invokeUserFunction(TimeDomain timeDomain, InternalTimer<K, VoidNamespace> timer)
throws Exception {
onTimerContext.timeDomain = timeDomain;
onTimerContext.timer = timer;
userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector);
onTimerContext.timeDomain = null;
onTimerContext.timer = null;
}

private class ContextImpl extends KeyedProcessFunction<K, IN, OUT>.Context {

private final TimerService timerService;

private StreamRecord<IN> element;

ContextImpl(KeyedProcessFunction<K, IN, OUT> function, TimerService timerService) {
function.super();
this.timerService = checkNotNull(timerService);
}

@Override
public Long timestamp() {
checkState(element != null);

if (element.hasTimestamp()) {
return element.getTimestamp();
} else {
return null;
}
}

@Override
public TimerService timerService() {
return timerService;
}

@Override
public <X> void output(OutputTag<X> outputTag, X value) {
if (outputTag == null) {
throw new IllegalArgumentException("OutputTag must not be null.");
}

output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp()));
}

@Override
@SuppressWarnings("unchecked")
public K getCurrentKey() {
return (K) AsyncStateKeyedProcessOperator.this.getCurrentKey();
}
}

private class OnTimerContextImpl extends KeyedProcessFunction<K, IN, OUT>.OnTimerContext {

private final TimerService timerService;

private TimeDomain timeDomain;

private InternalTimer<K, VoidNamespace> timer;

OnTimerContextImpl(KeyedProcessFunction<K, IN, OUT> function, TimerService timerService) {
function.super();
this.timerService = checkNotNull(timerService);
}

@Override
public Long timestamp() {
checkState(timer != null);
return timer.getTimestamp();
}

@Override
public TimerService timerService() {
return timerService;
}

@Override
public <X> void output(OutputTag<X> outputTag, X value) {
if (outputTag == null) {
throw new IllegalArgumentException("OutputTag must not be null.");
}

output.collect(outputTag, new StreamRecord<>(value, timer.getTimestamp()));
}

@Override
public TimeDomain timeDomain() {
checkState(timeDomain != null);
return timeDomain;
}

@Override
public K getCurrentKey() {
return timer.getKey();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.asyncprocessing.AsyncStateKeyedProcessOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
Expand All @@ -42,6 +43,7 @@
import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction;
import org.apache.flink.table.runtime.operators.aggregate.asyncprocessing.AsyncStateGroupTableAggFunction;
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
Expand Down Expand Up @@ -153,16 +155,30 @@ protected Transformation<RowData> translateToPlanInternal(
.map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType)
.toArray(LogicalType[]::new);
final int inputCountIndex = aggInfoList.getIndexOfCountStar();
final GroupTableAggFunction aggFunction =
new GroupTableAggFunction(
aggsHandler,
accTypes,
inputCountIndex,
generateUpdateBefore,
generator.isIncrementalUpdate(),
config.getStateRetentionTime());
final OneInputStreamOperator<RowData, RowData> operator =
new KeyedProcessOperator<>(aggFunction);
final boolean enableAsyncState = AggregateUtil.enableAsyncState(config, aggInfoList);

final OneInputStreamOperator<RowData, RowData> operator;
if (enableAsyncState) {
final AsyncStateGroupTableAggFunction aggFunction =
new AsyncStateGroupTableAggFunction(
aggsHandler,
accTypes,
inputCountIndex,
generateUpdateBefore,
generator.isIncrementalUpdate(),
config.getStateRetentionTime());
operator = new AsyncStateKeyedProcessOperator<>(aggFunction);
} else {
final GroupTableAggFunction aggFunction =
new GroupTableAggFunction(
aggsHandler,
accTypes,
inputCountIndex,
generateUpdateBefore,
generator.isIncrementalUpdate(),
config.getStateRetentionTime());
operator = new KeyedProcessOperator<>(aggFunction);
}

// partitioned aggregation
final OneInputTransformation<RowData, RowData> transform =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
*/
package org.apache.flink.table.planner.plan.utils

import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.table.api.TableException
import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.expressions._
import org.apache.flink.table.expressions.ExpressionUtils.extractValue
import org.apache.flink.table.functions._
Expand Down Expand Up @@ -1175,4 +1177,19 @@ object AggregateUtil extends Enumeration {
})
.exists(_.getKind == FunctionKind.TABLE_AGGREGATE)
}

def enableAsyncState(config: ReadableConfig, aggInfoList: AggregateInfoList): Boolean = {
// Currently, we do not support async state with agg functions that include DataView.
val containsDataViewInAggInfo =
aggInfoList.aggInfos.toStream.stream().anyMatch(agg => !agg.viewSpecs.isEmpty)

val containsDataViewInDistinctInfo =
aggInfoList.distinctInfos.toStream
.stream()
.anyMatch(distinct => distinct.dataViewSpec.isDefined)

config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED) &&
!containsDataViewInAggInfo &&
!containsDataViewInDistinctInfo
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,41 @@ import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
import org.apache.flink.table.api.{EnvironmentSettings, _}
import org.apache.flink.table.api.bridge.scala.{dataStreamConversions, tableConversions}
import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.runtime.utils.StreamingEnvUtil
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
import org.apache.flink.table.planner.utils.{Top3WithMapView, Top3WithRetractInput}
import org.apache.flink.table.runtime.typeutils.RowDataSerializer
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.{deleteRecord, insertRecord}
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension
import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters}
import org.apache.flink.types.Row

import org.junit.jupiter.api.{BeforeEach, TestTemplate}
import org.junit.jupiter.api.extension.ExtendWith

import java.lang.{Integer => JInt}
import java.time.Duration
import java.util.{Collection => JCollection}
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConversions._
import scala.collection.mutable

@ExtendWith(Array(classOf[ParameterizedTestExtension]))
class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode) {
class TableAggregateHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
extends HarnessTestBase(mode) {

@BeforeEach
override def before(): Unit = {
super.before()
val setting = EnvironmentSettings.newInstance().inStreamingMode().build()
this.tEnv = StreamTableEnvironmentImpl.create(env, setting)
tEnv.getConfig.set(
ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED,
Boolean.box(enableAsyncState))
}

val data = new mutable.MutableList[(Int, Int)]
Expand Down Expand Up @@ -183,3 +190,15 @@ class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(
testHarness.close()
}
}

object TableAggregateHarnessTest {

@Parameters(name = "StateBackend={0}, EnableAsyncState={1}")
def parameters(): JCollection[Array[java.lang.Object]] = {
Seq[Array[AnyRef]](
Array(HEAP_BACKEND, Boolean.box(false)),
Array(HEAP_BACKEND, Boolean.box(true)),
Array(ROCKSDB_BACKEND, Boolean.box(false))
)
}
}
Loading

0 comments on commit 6eb3463

Please sign in to comment.