From ecc318302b67fe536d5d8921b5ba6514e90008f8 Mon Sep 17 00:00:00 2001 From: JunRuiLee Date: Thu, 12 Dec 2024 17:44:58 +0800 Subject: [PATCH] [FLINK-36067][runtime] Support optimize stream graph based on input info. --- .../executiongraph/IntermediateResult.java | 7 + .../IntermediateResultInfo.java | 8 ++ .../VertexInputInfoComputationUtils.java | 31 ++++- .../runtime/jobgraph/IntermediateDataSet.java | 12 ++ .../AbstractBlockingResultInfo.java | 12 +- .../adaptivebatch/AdaptiveBatchScheduler.java | 33 ++++- .../AdaptiveExecutionHandlerFactory.java | 4 +- .../AllToAllBlockingResultInfo.java | 42 +++++- .../adaptivebatch/BlockingResultInfo.java | 10 ++ .../DefaultAdaptiveExecutionHandler.java | 48 ++++++- ...VertexParallelismAndInputInfosDecider.java | 6 +- .../PointwiseBlockingResultInfo.java | 20 ++- .../api/graph/AdaptiveGraphManager.java | 7 +- .../api/graph/DefaultStreamGraphContext.java | 33 ++++- .../VertexInputInfoComputationUtilsTest.java | 90 +++++++++---- .../AllToAllBlockingResultInfoTest.java | 27 ++-- .../DefaultAdaptiveExecutionHandlerTest.java | 121 +++++++++++++++++- ...exParallelismAndInputInfosDeciderTest.java | 90 ++++++++++--- .../StreamGraphOptimizerTest.java | 1 - .../AdaptiveBatchSchedulerITCase.java | 112 ++++++++++++++++ 20 files changed, 628 insertions(+), 86 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java index f00539b53070d..68cbe8d09ab37 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java @@ -63,6 +63,7 @@ public class IntermediateResult { private final int numParallelProducers; private final ExecutionPlanSchedulingContext executionPlanSchedulingContext; + private final boolean produceBroadcastResult; private int partitionsAssigned; @@ -102,6 +103,8 @@ public IntermediateResult( this.shuffleDescriptorCache = new HashMap<>(); this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext); + + this.produceBroadcastResult = intermediateDataSet.isBroadcast(); } public boolean areAllConsumerVerticesCreated() { @@ -207,6 +210,10 @@ public boolean isForward() { return intermediateDataSet.isForward(); } + public boolean isEveryConsumerConsumeAllSubPartitions() { + return !produceBroadcastResult && intermediateDataSet.isBroadcast(); + } + public int getConnectionIndex() { return connectionIndex; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java index 26829893b5a3b..312321a9a4e8a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultInfo.java @@ -35,6 +35,14 @@ public interface IntermediateResultInfo { */ boolean isBroadcast(); + /** + * Indicates whether every downstream consumer needs to consume all produced sub-partitions. + * + * @return true if every downstream consumer needs to consume all produced sub-partitions, false + * otherwise. + */ + boolean isEveryConsumerConsumeAllSubPartitions(); + /** * Whether it is a pointwise result. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java index 680a0bb16347c..cef2b93fba36a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java @@ -84,7 +84,8 @@ public static Map computeVertexInputI parallelism, input::getNumSubpartitions, isDynamicGraph, - input.isBroadcast())); + input.isBroadcast(), + input.isEveryConsumerConsumeAllSubPartitions())); } } @@ -124,6 +125,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise( 1, () -> numOfSubpartitionsRetriever.apply(start), isDynamicGraph, + false, false); executionVertexInputInfos.add( new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange)); @@ -145,6 +147,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise( numConsumers, () -> numOfSubpartitionsRetriever.apply(finalPartitionNum), isDynamicGraph, + false, false); executionVertexInputInfos.add( new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); @@ -165,6 +168,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise( * @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions * @param isDynamicGraph whether is dynamic graph * @param isBroadcast whether the edge is broadcast + * @param consumeAllSubpartitions whether the edge should consume all subpartitions * @return the computed {@link JobVertexInputInfo} */ static JobVertexInputInfo computeVertexInputInfoForAllToAll( @@ -172,7 +176,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll( int targetCount, Function numOfSubpartitionsRetriever, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean consumeAllSubpartitions) { final List executionVertexInputInfos = new ArrayList<>(); IndexRange partitionRange = new IndexRange(0, sourceCount - 1); for (int i = 0; i < targetCount; ++i) { @@ -182,7 +187,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll( targetCount, () -> numOfSubpartitionsRetriever.apply(0), isDynamicGraph, - isBroadcast); + isBroadcast, + consumeAllSubpartitions); executionVertexInputInfos.add( new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange)); } @@ -199,6 +205,7 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll( * @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions * @param isDynamicGraph whether is dynamic graph * @param isBroadcast whether the edge is broadcast + * @param consumeAllSubpartitions whether the edge should consume all subpartitions * @return the computed subpartition range */ @VisibleForTesting @@ -207,16 +214,21 @@ static IndexRange computeConsumedSubpartitionRange( int numConsumers, Supplier numOfSubpartitionsSupplier, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean consumeAllSubpartitions) { int consumerIndex = consumerSubtaskIndex % numConsumers; if (!isDynamicGraph) { return new IndexRange(consumerIndex, consumerIndex); } else { int numSubpartitions = numOfSubpartitionsSupplier.get(); if (isBroadcast) { - // broadcast results have only one subpartition, and be consumed multiple times. - checkArgument(numSubpartitions == 1); - return new IndexRange(0, 0); + if (consumeAllSubpartitions) { + return new IndexRange(0, numSubpartitions - 1); + } else { + // broadcast results have only one subpartition, and be consumed multiple times. + checkArgument(numSubpartitions == 1); + return new IndexRange(0, 0); + } } else { checkArgument(consumerIndex < numConsumers); checkArgument(numConsumers <= numSubpartitions); @@ -246,6 +258,11 @@ public boolean isBroadcast() { return intermediateResult.isBroadcast(); } + @Override + public boolean isEveryConsumerConsumeAllSubPartitions() { + return intermediateResult.isEveryConsumerConsumeAllSubPartitions(); + } + @Override public boolean isPointwise() { return intermediateResult.getConsumingDistributionPattern() diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java index c5d1187d23039..ec73f25e2838e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java @@ -134,6 +134,18 @@ public void configure( } } + public void updateOutputPattern( + DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) { + checkState(consumers.isEmpty(), "The output job edges have already been added."); + checkState( + numJobEdgesToCreate == 1, + "Modification is not allowed when the subscribing output is reused."); + + this.distributionPattern = distributionPattern; + this.isBroadcast = isBroadcast; + this.isForward = isForward; + } + public void increaseNumJobEdgesToCreate() { this.numJobEdgesToCreate++; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java index 33147bcdc1601..515480a06afc0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java @@ -44,11 +44,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo { protected final Map subpartitionBytesByPartitionIndex; AbstractBlockingResultInfo( - IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + Map subpartitionBytesByPartitionIndex) { this.resultId = checkNotNull(resultId); this.numOfPartitions = numOfPartitions; this.numOfSubpartitions = numOfSubpartitions; - this.subpartitionBytesByPartitionIndex = new HashMap<>(); + this.subpartitionBytesByPartitionIndex = subpartitionBytesByPartitionIndex; } @Override @@ -72,4 +75,9 @@ public void resetPartitionInfo(int partitionIndex) { int getNumOfRecordedPartitions() { return subpartitionBytesByPartitionIndex.size(); } + + @Override + public Map getSubpartitionBytesByPartitionIndex() { + return new HashMap<>(subpartitionBytesByPartitionIndex); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java index d8708634ff75f..2d9723bc0d3a5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java @@ -274,9 +274,11 @@ public void onNewJobVerticesAdded(List newVertices, int pendingOperat // 4. update json plan getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph())); - // 5. try aggregate subpartition bytes + // 5. update the DistributionPattern of the upstream results consumed by the newly created + // JobVertex and aggregate subpartition bytes. for (JobVertex newVertex : newVertices) { for (JobEdge input : newVertex.getInputs()) { + tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern()); Optional.ofNullable(blockingResultInfos.get(input.getSourceId())) .ifPresent(this::maybeAggregateSubpartitionBytes); } @@ -932,7 +934,8 @@ private static void resetDynamicParallelism(Iterable vertices) { } } - private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) { + private static BlockingResultInfo createFromIntermediateResult( + IntermediateResult result, Map subpartitionBytesByPartitionIndex) { checkArgument(result != null); // Note that for dynamic graph, different partitions in the same result have the same number // of subpartitions. @@ -940,13 +943,15 @@ private static BlockingResultInfo createFromIntermediateResult(IntermediateResul return new PointwiseBlockingResultInfo( result.getId(), result.getNumberOfAssignedPartitions(), - result.getPartitions()[0].getNumberOfSubpartitions()); + result.getPartitions()[0].getNumberOfSubpartitions(), + subpartitionBytesByPartitionIndex); } else { return new AllToAllBlockingResultInfo( result.getId(), result.getNumberOfAssignedPartitions(), result.getPartitions()[0].getNumberOfSubpartitions(), - result.isBroadcast()); + result.isBroadcast(), + subpartitionBytesByPartitionIndex); } } @@ -960,6 +965,26 @@ SpeculativeExecutionHandler getSpeculativeExecutionHandler() { return speculativeExecutionHandler; } + private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) { + if (blockingResultInfos.containsKey(id)) { + BlockingResultInfo resultInfo = blockingResultInfos.get(id); + IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id); + + if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise()) + || (targetPattern == DistributionPattern.POINTWISE + && !resultInfo.isPointwise())) { + + BlockingResultInfo newInfo = + createFromIntermediateResult( + result, resultInfo.getSubpartitionBytesByPartitionIndex()); + + blockingResultInfos.put(id, newInfo); + } else if (targetPattern == DistributionPattern.ALL_TO_ALL) { + ((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast()); + } + } + } + private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext { private final FailoverStrategy restartStrategyOnResultConsumable = diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java index b6113012f00c6..2d7be76c729b3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveExecutionHandlerFactory.java @@ -21,6 +21,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.streaming.api.graph.ExecutionPlan; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.util.DynamicCodeLoadingException; import java.util.concurrent.Executor; @@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory { public static AdaptiveExecutionHandler create( ExecutionPlan executionPlan, ClassLoader userClassLoader, - Executor serializationExecutor) { + Executor serializationExecutor) + throws DynamicCodeLoadingException { if (executionPlan instanceof JobGraph) { return new NonAdaptiveExecutionHandler((JobGraph) executionPlan); } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java index b9320d77d0206..332007e112b92 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; @@ -26,7 +27,9 @@ import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -35,21 +38,35 @@ /** Information of All-To-All result. */ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { - private final boolean isBroadcast; + private boolean isBroadcast; + + private boolean everyConsumerConsumeAllSubPartitions; /** * Aggregated subpartition bytes, which aggregates the subpartition bytes with the same * subpartition index in different partitions. Note that We can aggregate them because they will * be consumed by the same downstream task. */ - @Nullable private List aggregatedSubpartitionBytes; + @Nullable protected List aggregatedSubpartitionBytes; + + @VisibleForTesting + AllToAllBlockingResultInfo( + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + boolean isBroadcast, + boolean everyConsumerConsumeAllSubPartitions) { + this(resultId, numOfPartitions, numOfSubpartitions, isBroadcast, new HashMap<>()); + this.everyConsumerConsumeAllSubPartitions = everyConsumerConsumeAllSubPartitions; + } AllToAllBlockingResultInfo( IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions, - boolean isBroadcast) { - super(resultId, numOfPartitions, numOfSubpartitions); + boolean isBroadcast, + Map subpartitionBytesByPartitionIndex) { + super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex); this.isBroadcast = isBroadcast; } @@ -58,6 +75,21 @@ public boolean isBroadcast() { return isBroadcast; } + @Override + public boolean isEveryConsumerConsumeAllSubPartitions() { + return everyConsumerConsumeAllSubPartitions; + } + + void setBroadcast(boolean broadcast) { + if (!this.isBroadcast && broadcast) { + everyConsumerConsumeAllSubPartitions = true; + } else if (this.isBroadcast && !broadcast) { + everyConsumerConsumeAllSubPartitions = false; + } + + isBroadcast = broadcast; + } + @Override public boolean isPointwise() { return false; @@ -83,7 +115,7 @@ public long getNumBytesProduced() { List bytes = Optional.ofNullable(aggregatedSubpartitionBytes) .orElse(getAggregatedSubpartitionBytesInternal()); - if (isBroadcast) { + if (isBroadcast && !everyConsumerConsumeAllSubPartitions) { return bytes.get(0); } else { return bytes.stream().reduce(0L, Long::sum); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java index 5b446e7cdc9c4..7dc229b235765 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java @@ -22,6 +22,8 @@ import org.apache.flink.runtime.executiongraph.IntermediateResultInfo; import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import java.util.Map; + /** * The blocking result info, which will be used to calculate the vertex parallelism and input infos. */ @@ -67,4 +69,12 @@ public interface BlockingResultInfo extends IntermediateResultInfo { /** Aggregates the subpartition bytes to reduce space usage. */ void aggregateSubpartitionBytes(); + + /** + * Gets subpartition bytes by partition index. + * + * @return a map with integer keys representing partition indices and long array values + * representing subpartition bytes. + */ + Map getSubpartitionBytesByPartitionIndex(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java index b365db8d0e07f..b488a48bfe12b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandler.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -27,13 +28,17 @@ import org.apache.flink.runtime.jobmaster.event.JobEvent; import org.apache.flink.streaming.api.graph.AdaptiveGraphManager; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.util.DynamicCodeLoadingException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; +import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkArgument; @@ -52,10 +57,16 @@ public class DefaultAdaptiveExecutionHandler implements AdaptiveExecutionHandler private final AdaptiveGraphManager adaptiveGraphManager; + private final StreamGraphOptimizer streamGraphOptimizer; + public DefaultAdaptiveExecutionHandler( - ClassLoader userClassloader, StreamGraph streamGraph, Executor serializationExecutor) { + ClassLoader userClassloader, StreamGraph streamGraph, Executor serializationExecutor) + throws DynamicCodeLoadingException { this.adaptiveGraphManager = new AdaptiveGraphManager(userClassloader, streamGraph, serializationExecutor); + + this.streamGraphOptimizer = + new StreamGraphOptimizer(streamGraph.getJobConfiguration(), userClassloader); } @Override @@ -66,6 +77,7 @@ public JobGraph getJobGraph() { @Override public void handleJobEvent(JobEvent jobEvent) { try { + tryOptimizeStreamGraph(jobEvent); tryUpdateJobGraph(jobEvent); } catch (Exception e) { log.error("Failed to handle job event {}.", jobEvent, e); @@ -73,6 +85,40 @@ public void handleJobEvent(JobEvent jobEvent) { } } + private void tryOptimizeStreamGraph(JobEvent jobEvent) throws Exception { + if (jobEvent instanceof ExecutionJobVertexFinishedEvent) { + ExecutionJobVertexFinishedEvent event = (ExecutionJobVertexFinishedEvent) jobEvent; + + JobVertexID vertexId = event.getVertexId(); + Map resultInfo = event.getResultInfo(); + Map> resultInfoMap = + resultInfo.entrySet().stream() + .collect( + Collectors.toMap( + entry -> + adaptiveGraphManager.getProducerStreamNodeId( + entry.getKey()), + entry -> + new ArrayList<>( + Collections.singletonList( + entry.getValue())), + (existing, replacement) -> { + existing.addAll(replacement); + return existing; + })); + + OperatorsFinished operatorsFinished = + new OperatorsFinished( + adaptiveGraphManager.getStreamNodeIdsByJobVertexId(vertexId), + resultInfoMap); + + streamGraphOptimizer.maybeOptimizeStreamGraph( + operatorsFinished, adaptiveGraphManager.getStreamGraphContext()); + } else { + throw new IllegalArgumentException("Unsupported job event " + jobEvent); + } + } + private void tryUpdateJobGraph(JobEvent jobEvent) throws Exception { if (jobEvent instanceof ExecutionJobVertexFinishedEvent) { ExecutionJobVertexFinishedEvent event = (ExecutionJobVertexFinishedEvent) jobEvent; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java index eb78b9cd7a303..7dbe116a98889 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java @@ -467,8 +467,12 @@ private static ParallelismAndInputInfos createParallelismAndInputInfos( List executionVertexInputInfos = new ArrayList<>(); for (int i = 0; i < subpartitionRanges.size(); ++i) { IndexRange subpartitionRange; - if (resultInfo.isBroadcast()) { + if (resultInfo.isBroadcast() + && !resultInfo.isEveryConsumerConsumeAllSubPartitions()) { subpartitionRange = new IndexRange(0, 0); + } else if (resultInfo.isBroadcast()) { + subpartitionRange = + new IndexRange(0, resultInfo.getNumSubpartitions(i) - 1); } else { subpartitionRange = subpartitionRanges.get(i); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java index 87b4a2a42cba2..7685ce78315eb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java @@ -18,18 +18,31 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import static org.apache.flink.util.Preconditions.checkState; /** Information of Pointwise result. */ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo { + + @VisibleForTesting PointwiseBlockingResultInfo( IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { - super(resultId, numOfPartitions, numOfSubpartitions); + this(resultId, numOfPartitions, numOfSubpartitions, new HashMap<>()); + } + + PointwiseBlockingResultInfo( + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + Map subpartitionBytesByPartitionIndex) { + super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex); } @Override @@ -37,6 +50,11 @@ public boolean isBroadcast() { return false; } + @Override + public boolean isEveryConsumerConsumeAllSubPartitions() { + return false; + } + @Override public boolean isPointwise() { return true; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java index bec248898b115..f683f586d4ebb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java @@ -107,6 +107,9 @@ public class AdaptiveGraphManager implements AdaptiveGraphGenerator { private final Map> intermediateDataSetIdToOutputEdgesMap; + private final Map consumerEdgeIdToIntermediateDataSetMap = + new HashMap<>(); + // Records the ids of stream nodes in the StreamNodeForwardGroup. // When stream edge's partitioner is modified to forward, we need get forward groups by source // and target node id. @@ -167,7 +170,8 @@ public AdaptiveGraphManager( streamGraph, steamNodeIdToForwardGroupMap, frozenNodeToStartNodeMap, - intermediateOutputsCaches); + intermediateOutputsCaches, + consumerEdgeIdToIntermediateDataSetMap); this.jobGraph = createAndInitializeJobGraph(streamGraph, streamGraph.getJobID()); @@ -382,6 +386,7 @@ private void setVertexNonChainedOutputsConfig( intermediateDataSetIdToOutputEdgesMap .computeIfAbsent(dataSet.getId(), ignored -> new ArrayList<>()) .add(edge); + consumerEdgeIdToIntermediateDataSetMap.put(edge.getEdgeId(), dataSet); // we cache the output here for downstream vertex to create jobEdge. intermediateOutputsCaches .computeIfAbsent(edge.getSourceId(), k -> new HashMap<>()) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java index 07d8631bf9235..324f3466e7cb9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/DefaultStreamGraphContext.java @@ -19,6 +19,8 @@ package org.apache.flink.streaming.api.graph; import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.forwardgroup.StreamNodeForwardGroup; import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; @@ -36,6 +38,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -69,16 +72,21 @@ public class DefaultStreamGraphContext implements StreamGraphContext { // as they reuse some attributes. private final Map> opIntermediateOutputsCaches; + private final Map consumerEdgeIdToIntermediateDataSetMap; + public DefaultStreamGraphContext( StreamGraph streamGraph, Map steamNodeIdToForwardGroupMap, Map frozenNodeToStartNodeMap, - Map> opIntermediateOutputsCaches) { + Map> opIntermediateOutputsCaches, + Map consumerEdgeIdToIntermediateDataSetMap) { this.streamGraph = checkNotNull(streamGraph); this.steamNodeIdToForwardGroupMap = checkNotNull(steamNodeIdToForwardGroupMap); this.frozenNodeToStartNodeMap = checkNotNull(frozenNodeToStartNodeMap); this.opIntermediateOutputsCaches = checkNotNull(opIntermediateOutputsCaches); this.immutableStreamGraph = new ImmutableStreamGraph(this.streamGraph); + this.consumerEdgeIdToIntermediateDataSetMap = + checkNotNull(consumerEdgeIdToIntermediateDataSetMap); } @Override @@ -188,9 +196,9 @@ private void modifyOutputPartitioner( tryConvertForwardPartitionerAndMergeForwardGroup(targetEdge); } - // The partitioner in NonChainedOutput derived from the consumer edge, so we need to ensure - // that any modifications to the partitioner of consumer edge are synchronized with - // NonChainedOutput. + // The partitioner in NonChainedOutput and IntermediateDataSet derived from the consumer + // edge, so we need to ensure that any modifications to the partitioner of consumer edge are + // synchronized with NonChainedOutput and IntermediateDataSet. Map opIntermediateOutputs = opIntermediateOutputsCaches.get(targetEdge.getSourceId()); NonChainedOutput output = @@ -198,6 +206,23 @@ private void modifyOutputPartitioner( if (output != null) { output.setPartitioner(targetEdge.getPartitioner()); } + + Optional.ofNullable(consumerEdgeIdToIntermediateDataSetMap.get(targetEdge.getEdgeId())) + .ifPresent( + dataSet -> { + DistributionPattern distributionPattern = + targetEdge.getPartitioner().isPointwise() + ? DistributionPattern.POINTWISE + : DistributionPattern.ALL_TO_ALL; + dataSet.updateOutputPattern( + distributionPattern, + targetEdge.getPartitioner().isBroadcast(), + targetEdge + .getPartitioner() + .getClass() + .equals(ForwardPartitioner.class)); + }); + LOG.info( "The original partitioner of the edge {} is: {} , requested change to: {} , and finally modified to: {}.", targetEdge, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java index e0f4d6e2fad6a..44c6c627e2a31 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtilsTest.java @@ -19,10 +19,13 @@ package org.apache.flink.runtime.executiongraph; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll; import static org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Test for {@link VertexInputInfoComputationUtils}. */ class VertexInputInfoComputationUtilsTest { @@ -57,34 +60,49 @@ void testComputeConsumedSubpartitionRange6to4() { assertThat(range4).isEqualTo(new IndexRange(4, 5)); } - @Test - void testComputeBroadcastConsumedSubpartitionRange() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, 1, true, true); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testComputeBroadcastConsumedSubpartitionRange(boolean consumeAllSubpartitions) { + final IndexRange range1 = + computeConsumedSubpartitionRange(0, 3, 1, true, true, consumeAllSubpartitions); assertThat(range1).isEqualTo(new IndexRange(0, 0)); - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, 1, true, true); + final IndexRange range2 = + computeConsumedSubpartitionRange(1, 3, 1, true, true, consumeAllSubpartitions); assertThat(range2).isEqualTo(new IndexRange(0, 0)); - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, 1, true, true); + final IndexRange range3 = + computeConsumedSubpartitionRange(2, 3, 1, true, true, consumeAllSubpartitions); assertThat(range3).isEqualTo(new IndexRange(0, 0)); + + if (consumeAllSubpartitions) { + final IndexRange range4 = computeConsumedSubpartitionRange(2, 3, 2, true, true, true); + assertThat(range4).isEqualTo(new IndexRange(0, 1)); + } else { + assertThatThrownBy( + () -> + computeConsumedSubpartitionRange( + 2, 3, 2, true, true, consumeAllSubpartitions)) + .isInstanceOf(IllegalArgumentException.class); + } } @Test void testComputeConsumedSubpartitionRangeForNonDynamicGraph() { - final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false); + final IndexRange range1 = computeConsumedSubpartitionRange(0, 3, -1, false, false, false); assertThat(range1).isEqualTo(new IndexRange(0, 0)); - final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false); + final IndexRange range2 = computeConsumedSubpartitionRange(1, 3, -1, false, false, false); assertThat(range2).isEqualTo(new IndexRange(1, 1)); - final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false); + final IndexRange range3 = computeConsumedSubpartitionRange(2, 3, -1, false, false, false); assertThat(range3).isEqualTo(new IndexRange(2, 2)); } @Test void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() { final JobVertexInputInfo nonBroadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false); + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, false, false); assertThat(nonBroadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), @@ -93,7 +111,7 @@ void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() { 2, new IndexRange(0, 1), new IndexRange(2, 2))); final JobVertexInputInfo broadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true); + computeVertexInputInfoForAllToAll(2, 3, ignored -> 3, false, true, false); assertThat(broadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), @@ -102,10 +120,11 @@ void testComputeVertexInputInfoForAllToAllWithNonDynamicGraph() { 2, new IndexRange(0, 1), new IndexRange(2, 2))); } - @Test - void testComputeVertexInputInfoForAllToAllWithDynamicGraph() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testComputeVertexInputInfoForAllToAllWithDynamicGraph(boolean consumeAllSubpartitions) { final JobVertexInputInfo nonBroadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, false); + computeVertexInputInfoForAllToAll(2, 3, ignored -> 10, true, false, false); assertThat(nonBroadcast.getExecutionVertexInputInfos()) .containsExactlyInAnyOrder( new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 2)), @@ -113,14 +132,31 @@ void testComputeVertexInputInfoForAllToAllWithDynamicGraph() { new ExecutionVertexInputInfo( 2, new IndexRange(0, 1), new IndexRange(6, 9))); - final JobVertexInputInfo broadcast = - computeVertexInputInfoForAllToAll(2, 3, ignored -> 1, true, true); - assertThat(broadcast.getExecutionVertexInputInfos()) - .containsExactlyInAnyOrder( - new ExecutionVertexInputInfo(0, new IndexRange(0, 1), new IndexRange(0, 0)), - new ExecutionVertexInputInfo(1, new IndexRange(0, 1), new IndexRange(0, 0)), - new ExecutionVertexInputInfo( - 2, new IndexRange(0, 1), new IndexRange(0, 0))); + if (consumeAllSubpartitions) { + final JobVertexInputInfo broadcast = + computeVertexInputInfoForAllToAll( + 2, 3, ignored -> 4, true, true, consumeAllSubpartitions); + assertThat(broadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo( + 0, new IndexRange(0, 1), new IndexRange(0, 3)), + new ExecutionVertexInputInfo( + 1, new IndexRange(0, 1), new IndexRange(0, 3)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(0, 3))); + } else { + final JobVertexInputInfo broadcast = + computeVertexInputInfoForAllToAll( + 2, 3, ignored -> 1, true, true, consumeAllSubpartitions); + assertThat(broadcast.getExecutionVertexInputInfos()) + .containsExactlyInAnyOrder( + new ExecutionVertexInputInfo( + 0, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo( + 1, new IndexRange(0, 1), new IndexRange(0, 0)), + new ExecutionVertexInputInfo( + 2, new IndexRange(0, 1), new IndexRange(0, 0))); + } } @Test @@ -150,7 +186,7 @@ void testComputeVertexInputInfoForPointwiseWithDynamicGraph() { private static IndexRange computeConsumedSubpartitionRange( int consumerIndex, int numConsumers, int numSubpartitions) { return computeConsumedSubpartitionRange( - consumerIndex, numConsumers, numSubpartitions, true, false); + consumerIndex, numConsumers, numSubpartitions, true, false, false); } private static IndexRange computeConsumedSubpartitionRange( @@ -158,8 +194,14 @@ private static IndexRange computeConsumedSubpartitionRange( int numConsumers, int numSubpartitions, boolean isDynamicGraph, - boolean isBroadcast) { + boolean isBroadcast, + boolean consumeAllSubpartitions) { return VertexInputInfoComputationUtils.computeConsumedSubpartitionRange( - consumerIndex, numConsumers, () -> numSubpartitions, isDynamicGraph, isBroadcast); + consumerIndex, + numConsumers, + () -> numSubpartitions, + isDynamicGraph, + isBroadcast, + consumeAllSubpartitions); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java index e298b4a065a29..768e76e482cbe 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java @@ -23,6 +23,8 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -32,18 +34,20 @@ class AllToAllBlockingResultInfoTest { @Test void testGetNumBytesProducedForNonBroadcast() { - testGetNumBytesProduced(false, 192L); + testGetNumBytesProduced(false, false, 192L); } - @Test - void testGetNumBytesProducedForBroadcast() { - testGetNumBytesProduced(true, 96L); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testGetNumBytesProducedForBroadcast(boolean consumeAllSubpartitions) { + testGetNumBytesProduced( + true, consumeAllSubpartitions, consumeAllSubpartitions ? 192L : 96L); } @Test void testGetNumBytesProducedWithIndexRange() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L})); @@ -57,7 +61,7 @@ void testGetNumBytesProducedWithIndexRange() { @Test void testGetAggregatedSubpartitionBytes() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L})); @@ -67,8 +71,9 @@ void testGetAggregatedSubpartitionBytes() { @Test void testGetBytesWithPartialPartitionInfos() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L})); + resultInfo.aggregateSubpartitionBytes(); assertThatThrownBy(resultInfo::getNumBytesProduced) .isInstanceOf(IllegalStateException.class); @@ -79,7 +84,7 @@ void testGetBytesWithPartialPartitionInfos() { @Test void testRecordPartitionInfoMultiTimes() { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false); + new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false, false); ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L}); ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L}); @@ -115,9 +120,11 @@ void testRecordPartitionInfoMultiTimes() { assertThat(resultInfo.getNumOfRecordedPartitions()).isZero(); } - private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) { + private void testGetNumBytesProduced( + boolean isBroadcast, boolean consumeAllSubpartitions, long expectedBytes) { AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast); + new AllToAllBlockingResultInfo( + new IntermediateDataSetID(), 2, 2, isBroadcast, consumeAllSubpartitions); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L})); resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L})); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java index 6049006d24569..946c558f517c6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultAdaptiveExecutionHandlerTest.java @@ -24,16 +24,27 @@ import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexFinishedEvent; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.DynamicCodeLoadingException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Random; +import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; @@ -47,7 +58,7 @@ class DefaultAdaptiveExecutionHandlerTest { TestingUtils.defaultExecutorExtension(); @Test - void testGetJobGraph() { + void testGetJobGraph() throws DynamicCodeLoadingException { JobGraph jobGraph = createAdaptiveExecutionHandler().getJobGraph(); assertThat(jobGraph).isNotNull(); @@ -56,7 +67,7 @@ void testGetJobGraph() { } @Test - void testHandleJobEvent() { + void testHandleJobEvent() throws DynamicCodeLoadingException { List newAddedJobVertices = new ArrayList<>(); AtomicInteger pendingOperators = new AtomicInteger(); @@ -95,7 +106,75 @@ void testHandleJobEvent() { } @Test - void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() { + void testOptimizeStreamGraph() throws DynamicCodeLoadingException { + StreamGraph streamGraph = createStreamGraph(); + Iterator streamNodeIterator = streamGraph.getStreamNodes().iterator(); + StreamNode source = streamNodeIterator.next(); + StreamNode map = streamNodeIterator.next(); + + assertThat(source.getOutEdges().get(0).getPartitioner()) + .isInstanceOf(ForwardPartitioner.class); + assertThat(map.getOutEdges().get(0).getPartitioner()) + .isInstanceOf(RescalePartitioner.class); + + streamGraph + .getJobConfiguration() + .set( + StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY, + Collections.singletonList( + TestingStreamGraphOptimizerStrategy.class.getName())); + TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add( + source.getOutEdges().get(0).getEdgeId()); + TestingStreamGraphOptimizerStrategy.convertToReBalanceEdgeIds.add( + map.getOutEdges().get(0).getEdgeId()); + + DefaultAdaptiveExecutionHandler handler = + createAdaptiveExecutionHandler( + (newVertices, pendingOperatorsCount) -> {}, streamGraph); + + JobGraph jobGraph = handler.getJobGraph(); + JobVertex sourceVertex = jobGraph.getVertices().iterator().next(); + + // notify Source node is finished + ExecutionJobVertexFinishedEvent event1 = + new ExecutionJobVertexFinishedEvent(sourceVertex.getID(), Collections.emptyMap()); + handler.handleJobEvent(event1); + + // verify that the source output edge is not updated because the original edge is forward. + assertThat(sourceVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1); + assertThat( + sourceVertex + .getProducedDataSets() + .get(0) + .getConsumers() + .get(0) + .getShipStrategyName()) + .isEqualToIgnoringCase("forward"); + + // notify Map node is finished + Iterator jobVertexIterator = jobGraph.getVertices().iterator(); + jobVertexIterator.next(); + JobVertex mapVertex = jobVertexIterator.next(); + + ExecutionJobVertexFinishedEvent event2 = + new ExecutionJobVertexFinishedEvent(mapVertex.getID(), Collections.emptyMap()); + handler.handleJobEvent(event2); + + // verify that the map output edge is updated to reBalance. + assertThat(mapVertex.getProducedDataSets().get(0).getConsumers()).hasSize(1); + assertThat( + mapVertex + .getProducedDataSets() + .get(0) + .getConsumers() + .get(0) + .getShipStrategyName()) + .isEqualToIgnoringCase("rebalance"); + } + + @Test + void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() + throws DynamicCodeLoadingException { StreamGraph streamGraph = createStreamGraph(); DefaultAdaptiveExecutionHandler handler = createAdaptiveExecutionHandler( @@ -123,7 +202,8 @@ void testGetInitialParallelismAndNotifyJobVertexParallelismDecided() { assertThat(handler.getInitialParallelism(map.getID())).isEqualTo(parallelism); } - private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler() { + private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler() + throws DynamicCodeLoadingException { return createAdaptiveExecutionHandler( (newVertices, pendingOperatorsCount) -> {}, createStreamGraph()); } @@ -159,7 +239,8 @@ private StreamGraph createStreamGraph() { * and a given {@link StreamGraph}. */ private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler( - JobGraphUpdateListener listener, StreamGraph streamGraph) { + JobGraphUpdateListener listener, StreamGraph streamGraph) + throws DynamicCodeLoadingException { DefaultAdaptiveExecutionHandler handler = new DefaultAdaptiveExecutionHandler( getClass().getClassLoader(), streamGraph, EXECUTOR_RESOURCE.getExecutor()); @@ -167,4 +248,34 @@ private DefaultAdaptiveExecutionHandler createAdaptiveExecutionHandler( return handler; } + + public static final class TestingStreamGraphOptimizerStrategy + implements StreamGraphOptimizationStrategy { + + private static final Set convertToReBalanceEdgeIds = new HashSet<>(); + + @Override + public boolean maybeOptimizeStreamGraph( + OperatorsFinished operatorsFinished, StreamGraphContext context) { + List finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds(); + List requestInfos = new ArrayList<>(); + for (Integer finishedStreamNodeId : finishedStreamNodeIds) { + for (ImmutableStreamEdge outEdge : + context.getStreamGraph() + .getStreamNode(finishedStreamNodeId) + .getOutEdges()) { + if (convertToReBalanceEdgeIds.contains(outEdge.getEdgeId())) { + StreamEdgeUpdateRequestInfo requestInfo = + new StreamEdgeUpdateRequestInfo( + outEdge.getEdgeId(), + outEdge.getSourceId(), + outEdge.getTargetId()); + requestInfo.outputPartitioner(new RebalancePartitioner<>()); + requestInfos.add(requestInfo); + } + } + } + return context.modifyStreamEdge(requestInfos); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java index d1b24d862f4cd..fc4519cd55363 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java @@ -32,11 +32,14 @@ import org.apache.flink.shaded.guava32.com.google.common.collect.Iterables; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -206,13 +209,19 @@ void testFallBackToEvenlyDistributeSubpartitions() { new IndexRange(8, 9))); } - @Test - void testAllEdgesAllToAllAndOneIsBroadcast() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAllEdgesAllToAllAndOneIsBroadcast(boolean consumeAllSubpartitions) { AllToAllBlockingResultInfo resultInfo1 = createAllToAllBlockingResultInfo( new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L}); - AllToAllBlockingResultInfo resultInfo2 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); + AllToAllBlockingResultInfo resultInfo2; + if (consumeAllSubpartitions) { + // create three subpartitions + resultInfo2 = createAllToAllBlockingResultInfo(new long[] {0L, 5L, 10L}, true, true); + } else { + resultInfo2 = createAllToAllBlockingResultInfo(new long[] {10L}, true, false); + } ParallelismAndInputInfos parallelismAndInputInfos = createDeciderAndDecideParallelismAndInputInfos( @@ -224,17 +233,38 @@ void testAllEdgesAllToAllAndOneIsBroadcast() { checkAllToAllJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()), Arrays.asList(new IndexRange(0, 4), new IndexRange(5, 8), new IndexRange(9, 9))); - checkAllToAllJobVertexInputInfo( - parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()), - Arrays.asList(new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0))); + + if (consumeAllSubpartitions) { + checkAllToAllJobVertexInputInfo( + parallelismAndInputInfos + .getJobVertexInputInfos() + .get(resultInfo2.getResultId()), + Arrays.asList( + new IndexRange(0, 2), new IndexRange(0, 2), new IndexRange(0, 2))); + } else { + checkAllToAllJobVertexInputInfo( + parallelismAndInputInfos + .getJobVertexInputInfos() + .get(resultInfo2.getResultId()), + Arrays.asList( + new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0))); + } } - @Test - void testAllEdgesBroadcast() { - AllToAllBlockingResultInfo resultInfo1 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); - AllToAllBlockingResultInfo resultInfo2 = - createAllToAllBlockingResultInfo(new long[] {10L}, true); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAllEdgesBroadcast(boolean consumeAllSubpartitions) { + AllToAllBlockingResultInfo resultInfo1; + AllToAllBlockingResultInfo resultInfo2; + if (consumeAllSubpartitions) { + // create three subpartitions + resultInfo1 = createAllToAllBlockingResultInfo(new long[] {0L, 5L, 10L}, true, true); + resultInfo2 = createAllToAllBlockingResultInfo(new long[] {0L, 5L, 10L}, true, true); + } else { + resultInfo1 = createAllToAllBlockingResultInfo(new long[] {10L}, true, false); + resultInfo2 = createAllToAllBlockingResultInfo(new long[] {10L}, true, false); + } + ParallelismAndInputInfos parallelismAndInputInfos = createDeciderAndDecideParallelismAndInputInfos( 1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2)); @@ -242,12 +272,17 @@ void testAllEdgesBroadcast() { assertThat(parallelismAndInputInfos.getParallelism()).isOne(); assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2); + List expectedSubpartitionRanges = + consumeAllSubpartitions + ? Collections.singletonList(new IndexRange(0, 2)) + : Collections.singletonList(new IndexRange(0, 0)); + checkAllToAllJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()), - Collections.singletonList(new IndexRange(0, 0))); + expectedSubpartitionRanges); checkAllToAllJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()), - Collections.singletonList(new IndexRange(0, 0))); + expectedSubpartitionRanges); } @Test @@ -359,7 +394,8 @@ void testEvenlyDistributeDataWithMaxSubpartitionLimitation() { long[] subpartitionBytes = new long[1024]; Arrays.fill(subpartitionBytes, 1L); AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 1024, 1024, false); + new AllToAllBlockingResultInfo( + new IntermediateDataSetID(), 1024, 1024, false, false); for (int i = 0; i < 1024; ++i) { resultInfo.recordPartitionInfo(i, new ResultPartitionBytes(subpartitionBytes)); } @@ -507,11 +543,13 @@ private static ParallelismAndInputInfos createDeciderAndDecideParallelismAndInpu private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo( long[] aggregatedSubpartitionBytes) { - return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false); + return createAllToAllBlockingResultInfo(aggregatedSubpartitionBytes, false, false); } private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo( - long[] aggregatedSubpartitionBytes, boolean isBroadcast) { + long[] aggregatedSubpartitionBytes, + boolean isBroadcast, + boolean consumeAllSubpartitions) { // For simplicity, we configure only one partition here, so the aggregatedSubpartitionBytes // is equivalent to the subpartition bytes of partition0 AllToAllBlockingResultInfo resultInfo = @@ -519,7 +557,8 @@ private AllToAllBlockingResultInfo createAllToAllBlockingResultInfo( new IntermediateDataSetID(), 1, aggregatedSubpartitionBytes.length, - isBroadcast); + isBroadcast, + consumeAllSubpartitions); resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(aggregatedSubpartitionBytes)); return resultInfo; } @@ -578,6 +617,11 @@ public boolean isBroadcast() { return isBroadcast; } + @Override + public boolean isEveryConsumerConsumeAllSubPartitions() { + return false; + } + @Override public boolean isPointwise() { return false; @@ -609,6 +653,14 @@ public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partiti @Override public void resetPartitionInfo(int partitionIndex) {} + + @Override + public void aggregateSubpartitionBytes() {} + + @Override + public Map getSubpartitionBytesByPartitionIndex() { + return Map.of(); + } } private static BlockingResultInfo createFromBroadcastResult(long producedBytes) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java index b0a0440269c0d..0792e6cc6d878 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/StreamGraphOptimizerTest.java @@ -23,7 +23,6 @@ import org.apache.flink.streaming.api.graph.util.ImmutableStreamGraph; import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java index 6654c1eb17e19..cb2d0d57a26e5 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java @@ -31,8 +31,18 @@ import org.apache.flink.configuration.RestOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler; +import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished; +import org.apache.flink.runtime.scheduler.adaptivebatch.StreamGraphOptimizationStrategy; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamGraphContext; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge; +import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,8 +50,10 @@ import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Function; import java.util.stream.Collectors; @@ -130,6 +142,72 @@ void testDifferentConsumerParallelism() throws Exception { env.execute(); } + @Test + void testAdaptiveOptimizeStreamGraph() throws Exception { + final Configuration configuration = createConfiguration(); + configuration.set( + StreamGraphOptimizationStrategy.STREAM_GRAPH_OPTIMIZATION_STRATEGY, + List.of(TestingStreamGraphOptimizerStrategy.class.getName())); + final StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(configuration); + env.setRuntimeMode(RuntimeExecutionMode.BATCH); + env.disableOperatorChaining(); + env.setParallelism(8); + + SingleOutputStreamOperator source1 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(SOURCE_PARALLELISM_1) + .name("source1"); + SingleOutputStreamOperator source2 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(SOURCE_PARALLELISM_2) + .name("source2"); + + source1.keyBy(i -> i % SOURCE_PARALLELISM_1) + .map(i -> i) + .name("map1") + .rebalance() + .union(source2) + .rebalance() + .map(new NumberCounter()) + .name("map2"); + + StreamGraph streamGraph = env.getStreamGraph(); + StreamNode sourceNode1 = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("source1")) + .findFirst() + .get(); + StreamNode mapNode1 = + streamGraph.getStreamNodes().stream() + .filter(node -> node.getOperatorName().contains("map1")) + .findFirst() + .get(); + + TestingStreamGraphOptimizerStrategy.convertToRescaleEdgeIds.add( + sourceNode1.getOutEdges().get(0).getEdgeId()); + TestingStreamGraphOptimizerStrategy.convertToBroadcastEdgeIds.add( + mapNode1.getOutEdges().get(0).getEdgeId()); + + env.execute(streamGraph); + + Map numberCountResultMap = + numberCountResults.stream() + .flatMap(map -> map.entrySet().stream()) + .collect( + Collectors.toMap( + Map.Entry::getKey, Map.Entry::getValue, Long::sum)); + + // Because the parallelism of map2 is automatically determined to be 2, the result will have + // three times the produced numbers. One part comes from source2, while the other two parts + // come from the broadcast results of source1. + Map expectedResult = + LongStream.range(0, NUMBERS_TO_PRODUCE) + .boxed() + .collect(Collectors.toMap(Function.identity(), i -> 3L)); + assertThat(numberCountResultMap).isEqualTo(expectedResult); + } + private void testSchedulingBase(Boolean useSourceParallelismInference) throws Exception { executeJob(useSourceParallelismInference); @@ -257,4 +335,38 @@ public int inferParallelism(Context context) { return expectedParallelism; } } + + public static final class TestingStreamGraphOptimizerStrategy + implements StreamGraphOptimizationStrategy { + + private static final Set convertToBroadcastEdgeIds = new HashSet<>(); + private static final Set convertToRescaleEdgeIds = new HashSet<>(); + + @Override + public boolean maybeOptimizeStreamGraph( + OperatorsFinished operatorsFinished, StreamGraphContext context) throws Exception { + List finishedStreamNodeIds = operatorsFinished.getFinishedStreamNodeIds(); + List requestInfos = new ArrayList<>(); + for (Integer finishedStreamNodeId : finishedStreamNodeIds) { + for (ImmutableStreamEdge outEdge : + context.getStreamGraph() + .getStreamNode(finishedStreamNodeId) + .getOutEdges()) { + StreamEdgeUpdateRequestInfo requestInfo = + new StreamEdgeUpdateRequestInfo( + outEdge.getEdgeId(), + outEdge.getSourceId(), + outEdge.getTargetId()); + if (convertToBroadcastEdgeIds.contains(outEdge.getEdgeId())) { + requestInfo.outputPartitioner(new BroadcastPartitioner<>()); + requestInfos.add(requestInfo); + } else if (convertToRescaleEdgeIds.contains(outEdge.getEdgeId())) { + requestInfo.outputPartitioner(new RescalePartitioner<>()); + requestInfos.add(requestInfo); + } + } + } + return context.modifyStreamEdge(requestInfos); + } + } }