[FLINK-36067][runtime] Support optimize stream graph based on input info.
JunRuiLee committed Dec 12, 2024
1 parent 02cdbf3 commit 9b2cc2a
Showing 20 changed files with 628 additions and 86 deletions.
@@ -63,6 +63,7 @@ public class IntermediateResult {
private final int numParallelProducers;

private final ExecutionPlanSchedulingContext executionPlanSchedulingContext;
private final boolean produceBroadcastResult;

private int partitionsAssigned;

@@ -102,6 +103,8 @@ public IntermediateResult(
this.shuffleDescriptorCache = new HashMap<>();

this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext);

this.produceBroadcastResult = intermediateDataSet.isBroadcast();
}

public boolean areAllConsumerVerticesCreated() {
@@ -207,6 +210,10 @@ public boolean isForward() {
return intermediateDataSet.isForward();
}

public boolean isEveryConsumerConsumeAllSubPartitions() {
return !produceBroadcastResult && intermediateDataSet.isBroadcast();
}

public int getConnectionIndex() {
return connectionIndex;
}
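
A note on the interplay of the two flags above: produceBroadcastResult snapshots the data set's broadcast flag at construction time, while intermediateDataSet.isBroadcast() reflects later rewrites (see updateOutputPattern further down), so isEveryConsumerConsumeAllSubPartitions() is true exactly when a result that was produced as non-broadcast is later consumed as broadcast. A minimal, self-contained sketch of that behavior (the class shapes are illustrative; only the two flags mirror the diff):

// Sketch: a result created before its edge is rewritten to broadcast.
public class BroadcastFlagSketch {
    static class DataSet {
        private boolean broadcast;
        boolean isBroadcast() { return broadcast; }
        void markBroadcast() { broadcast = true; }
    }

    static class Result {
        private final DataSet dataSet;
        private final boolean produceBroadcastResult; // frozen at creation

        Result(DataSet dataSet) {
            this.dataSet = dataSet;
            this.produceBroadcastResult = dataSet.isBroadcast();
        }

        boolean isEveryConsumerConsumeAllSubPartitions() {
            // True only for the "became broadcast later" case: the data is
            // spread over all subpartitions, so each consumer reads them all.
            return !produceBroadcastResult && dataSet.isBroadcast();
        }
    }

    public static void main(String[] args) {
        DataSet ds = new DataSet();
        Result result = new Result(ds);
        System.out.println(result.isEveryConsumerConsumeAllSubPartitions()); // false
        ds.markBroadcast(); // edge rewritten to broadcast after creation
        System.out.println(result.isEveryConsumerConsumeAllSubPartitions()); // true
    }
}
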
@@ -35,6 +35,14 @@ public interface IntermediateResultInfo {
*/
boolean isBroadcast();

/**
* Indicates whether every downstream consumer needs to consume all produced sub-partitions.
*
* @return true if every downstream consumer needs to consume all produced sub-partitions, false
* otherwise.
*/
boolean isEveryConsumerConsumeAllSubPartitions();

/**
* Whether it is a pointwise result.
*
@@ -84,7 +84,8 @@ public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputI
parallelism,
input::getNumSubpartitions,
isDynamicGraph,
input.isBroadcast()));
input.isBroadcast(),
input.isEveryConsumerConsumeAllSubPartitions()));
}
}

@@ -124,6 +125,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
1,
() -> numOfSubpartitionsRetriever.apply(start),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
@@ -145,6 +147,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
numConsumers,
() -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
@@ -165,14 +168,16 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
* @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
* @param isDynamicGraph whether the graph is dynamic
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed {@link JobVertexInputInfo}
*/
static JobVertexInputInfo computeVertexInputInfoForAllToAll(
int sourceCount,
int targetCount,
Function<Integer, Integer> numOfSubpartitionsRetriever,
boolean isDynamicGraph,
boolean isBroadcast) {
boolean isBroadcast,
boolean consumeAllSubpartitions) {
final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
for (int i = 0; i < targetCount; ++i) {
@@ -182,7 +187,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
targetCount,
() -> numOfSubpartitionsRetriever.apply(0),
isDynamicGraph,
isBroadcast);
isBroadcast,
consumeAllSubpartitions);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
}
@@ -199,6 +205,7 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
* @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions
* @param isDynamicGraph whether the graph is dynamic
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed subpartition range
*/
@VisibleForTesting
@@ -207,16 +214,21 @@ static IndexRange computeConsumedSubpartitionRange(
int numConsumers,
Supplier<Integer> numOfSubpartitionsSupplier,
boolean isDynamicGraph,
boolean isBroadcast) {
boolean isBroadcast,
boolean consumeAllSubpartitions) {
int consumerIndex = consumerSubtaskIndex % numConsumers;
if (!isDynamicGraph) {
return new IndexRange(consumerIndex, consumerIndex);
} else {
int numSubpartitions = numOfSubpartitionsSupplier.get();
if (isBroadcast) {
// broadcast results have only one subpartition, and will be consumed multiple times.
checkArgument(numSubpartitions == 1);
return new IndexRange(0, 0);
if (consumeAllSubpartitions) {
return new IndexRange(0, numSubpartitions - 1);
} else {
// broadcast results have only one subpartition, and will be consumed multiple times.
checkArgument(numSubpartitions == 1);
return new IndexRange(0, 0);
}
} else {
checkArgument(consumerIndex < numConsumers);
checkArgument(numConsumers <= numSubpartitions);
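
To make the new consumeAllSubpartitions branch concrete, here is a standalone, simplified re-implementation with a worked example. The branch structure follows the diff; IndexRange is replaced by an int[] pair, and the final non-broadcast branch is condensed to an even split:

import java.util.function.Supplier;

public class SubpartitionRangeSketch {
    /** Simplified version of computeConsumedSubpartitionRange; returns {start, end}. */
    static int[] consumedRange(
            int consumerSubtaskIndex,
            int numConsumers,
            Supplier<Integer> numOfSubpartitions,
            boolean isDynamicGraph,
            boolean isBroadcast,
            boolean consumeAllSubpartitions) {
        int consumerIndex = consumerSubtaskIndex % numConsumers;
        if (!isDynamicGraph) {
            // Static graph: subpartition index equals consumer index.
            return new int[] {consumerIndex, consumerIndex};
        }
        int n = numOfSubpartitions.get();
        if (isBroadcast) {
            if (consumeAllSubpartitions) {
                // Rewritten-to-broadcast result: the data sits in all n
                // subpartitions, so every consumer reads the full range.
                return new int[] {0, n - 1};
            }
            // Native broadcast result: a single subpartition, read by all.
            return new int[] {0, 0};
        }
        // Non-broadcast dynamic case (condensed): split subpartitions evenly.
        return new int[] {
            consumerIndex * n / numConsumers,
            (consumerIndex + 1) * n / numConsumers - 1
        };
    }

    public static void main(String[] args) {
        // 4 subpartitions, edge rewritten to broadcast: consumer 1 of 3
        // reads subpartitions 0..3 instead of only subpartition 0.
        int[] r = consumedRange(1, 3, () -> 4, true, true, true);
        System.out.println(r[0] + ".." + r[1]); // prints 0..3
    }
}
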
@@ -246,6 +258,11 @@ public boolean isBroadcast() {
return intermediateResult.isBroadcast();
}

@Override
public boolean isEveryConsumerConsumeAllSubPartitions() {
return intermediateResult.isEveryConsumerConsumeAllSubPartitions();
}

@Override
public boolean isPointwise() {
return intermediateResult.getConsumingDistributionPattern()
@@ -134,6 +134,18 @@ public void configure(
}
}

public void updateOutputPattern(
DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) {
checkState(consumers.isEmpty(), "The output job edges have already been added.");
checkState(
numJobEdgesToCreate == 1,
"Modification is not allowed when the subscribing output is reused.");

this.distributionPattern = distributionPattern;
this.isBroadcast = isBroadcast;
this.isForward = isForward;
}
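
The two checkState guards define when a rewrite is legal: before any consumer JobEdge has been attached, and only while a single output edge is planned for this data set. A hypothetical call site (not part of this diff; dataSet and the surrounding context are assumptions) that flips an edge to broadcast could look like:

import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;

// Hypothetical caller, for illustration only: rewrite the edge while the
// data set still has no consumers, using the method added in this diff.
class UpdateOutputPatternUsage {
    static void rewriteToBroadcast(IntermediateDataSet dataSet) {
        dataSet.updateOutputPattern(
                DistributionPattern.ALL_TO_ALL,
                /* isBroadcast */ true,
                /* isForward */ false);
        // Once a consumer JobEdge has been attached, the same call would fail
        // the first checkState: "The output job edges have already been added."
    }
}
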

public void increaseNumJobEdgesToCreate() {
this.numJobEdgesToCreate++;
}
@@ -44,11 +44,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;

AbstractBlockingResultInfo(
IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
IntermediateDataSetID resultId,
int numOfPartitions,
int numOfSubpartitions,
Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
this.resultId = checkNotNull(resultId);
this.numOfPartitions = numOfPartitions;
this.numOfSubpartitions = numOfSubpartitions;
this.subpartitionBytesByPartitionIndex = new HashMap<>();
this.subpartitionBytesByPartitionIndex = subpartitionBytesByPartitionIndex;
}

@Override
@@ -72,4 +75,9 @@ public void resetPartitionInfo(int partitionIndex) {
int getNumOfRecordedPartitions() {
return subpartitionBytesByPartitionIndex.size();
}

@Override
public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
return new HashMap<>(subpartitionBytesByPartitionIndex);
}
}
@@ -274,9 +274,11 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat
// 4. update json plan
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));

// 5. try aggregate subpartition bytes
// 5. update the DistributionPattern of the upstream results consumed by the newly created
// JobVertex and aggregate subpartition bytes.
for (JobVertex newVertex : newVertices) {
for (JobEdge input : newVertex.getInputs()) {
tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern());
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
.ifPresent(this::maybeAggregateSubpartitionBytes);
}
@@ -932,21 +934,24 @@ private static void resetDynamicParallelism(Iterable<JobVertex> vertices) {
}
}

private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
private static BlockingResultInfo createFromIntermediateResult(
IntermediateResult result, Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
checkArgument(result != null);
// Note that for a dynamic graph, different partitions in the same result have the same
// number of subpartitions.
if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
return new PointwiseBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
result.getPartitions()[0].getNumberOfSubpartitions());
result.getPartitions()[0].getNumberOfSubpartitions(),
subpartitionBytesByPartitionIndex);
} else {
return new AllToAllBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
result.getPartitions()[0].getNumberOfSubpartitions(),
result.isBroadcast());
result.isBroadcast(),
subpartitionBytesByPartitionIndex);
}
}

@@ -960,6 +965,26 @@ SpeculativeExecutionHandler getSpeculativeExecutionHandler() {
return speculativeExecutionHandler;
}

private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) {
if (blockingResultInfos.containsKey(id)) {
BlockingResultInfo resultInfo = blockingResultInfos.get(id);
IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id);

if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise())
|| (targetPattern == DistributionPattern.POINTWISE
&& !resultInfo.isPointwise())) {

BlockingResultInfo newInfo =
createFromIntermediateResult(
result, resultInfo.getSubpartitionBytesByPartitionIndex());

blockingResultInfos.put(id, newInfo);
} else if (targetPattern == DistributionPattern.ALL_TO_ALL) {
((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast());
}
}
}
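
In effect, tryUpdateResultInfo keeps the cached blockingResultInfos consistent with edges whose distribution pattern was rewritten after the info was first created: if the pattern shape changed (pointwise vs. all-to-all), the info is rebuilt with the already-recorded bytes carried over through getSubpartitionBytesByPartitionIndex(); if the pattern stayed all-to-all, only the broadcast flag is pushed into the existing info (see AllToAllBlockingResultInfo.setBroadcast below). A condensed sketch of that decision, with simplified stand-in types:

import java.util.HashMap;
import java.util.Map;

public class ResultInfoUpdateSketch {
    enum Pattern { ALL_TO_ALL, POINTWISE }

    static class Info {
        final boolean pointwise;
        boolean broadcast;
        final Map<Integer, long[]> recordedBytes;

        Info(boolean pointwise, boolean broadcast, Map<Integer, long[]> bytes) {
            this.pointwise = pointwise;
            this.broadcast = broadcast;
            this.recordedBytes = bytes;
        }
    }

    /** Mirrors the three cases of tryUpdateResultInfo in condensed form. */
    static Info update(Info info, Pattern target, boolean resultIsBroadcast) {
        boolean shapeChanged =
                (target == Pattern.ALL_TO_ALL && info.pointwise)
                        || (target == Pattern.POINTWISE && !info.pointwise);
        if (shapeChanged) {
            // Rebuild under the new pattern, keeping the recorded bytes.
            return new Info(
                    target == Pattern.POINTWISE, resultIsBroadcast, info.recordedBytes);
        } else if (target == Pattern.ALL_TO_ALL) {
            // Same shape: only the broadcast flag may have changed.
            info.broadcast = resultIsBroadcast;
        }
        return info; // POINTWISE -> POINTWISE: nothing to do.
    }

    public static void main(String[] args) {
        Info info = new Info(true, false, new HashMap<Integer, long[]>());
        info = update(info, Pattern.ALL_TO_ALL, true);
        System.out.println(info.pointwise + " " + info.broadcast); // false true
    }
}
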

private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext {

private final FailoverStrategy restartStrategyOnResultConsumable =
@@ -21,6 +21,7 @@
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.graph.ExecutionPlan;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.util.DynamicCodeLoadingException;

import java.util.concurrent.Executor;

@@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory {
public static AdaptiveExecutionHandler create(
ExecutionPlan executionPlan,
ClassLoader userClassLoader,
Executor serializationExecutor) {
Executor serializationExecutor)
throws DynamicCodeLoadingException {
if (executionPlan instanceof JobGraph) {
return new NonAdaptiveExecutionHandler((JobGraph) executionPlan);
} else {
@@ -18,6 +18,7 @@

package org.apache.flink.runtime.scheduler.adaptivebatch;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -26,7 +27,9 @@

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

@@ -35,21 +38,35 @@
/** Information of All-To-All result. */
public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {

private final boolean isBroadcast;
private boolean isBroadcast;

private boolean everyConsumerConsumeAllSubPartitions;

/**
* Aggregated subpartition bytes, which aggregates the subpartition bytes with the same
* subpartition index in different partitions. Note that we can aggregate them because they will
* be consumed by the same downstream task.
*/
@Nullable private List<Long> aggregatedSubpartitionBytes;
@Nullable protected List<Long> aggregatedSubpartitionBytes;

@VisibleForTesting
AllToAllBlockingResultInfo(
IntermediateDataSetID resultId,
int numOfPartitions,
int numOfSubpartitions,
boolean isBroadcast,
boolean everyConsumerConsumeAllSubPartitions) {
this(resultId, numOfPartitions, numOfSubpartitions, isBroadcast, new HashMap<>());
this.everyConsumerConsumeAllSubPartitions = everyConsumerConsumeAllSubPartitions;
}

AllToAllBlockingResultInfo(
IntermediateDataSetID resultId,
int numOfPartitions,
int numOfSubpartitions,
boolean isBroadcast) {
super(resultId, numOfPartitions, numOfSubpartitions);
boolean isBroadcast,
Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
super(resultId, numOfPartitions, numOfSubpartitions, subpartitionBytesByPartitionIndex);
this.isBroadcast = isBroadcast;
}

@@ -58,6 +75,21 @@ public boolean isBroadcast() {
return isBroadcast;
}

@Override
public boolean isEveryConsumerConsumeAllSubPartitions() {
return everyConsumerConsumeAllSubPartitions;
}

void setBroadcast(boolean broadcast) {
if (!this.isBroadcast && broadcast) {
everyConsumerConsumeAllSubPartitions = true;
} else if (this.isBroadcast && !broadcast) {
everyConsumerConsumeAllSubPartitions = false;
}

isBroadcast = broadcast;
}
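
The two branches form a small state machine: flipping a non-broadcast result to broadcast raises the consume-all flag (the data is already spread over all subpartitions), and flipping back lowers it; a result that was broadcast from the start never sets the flag, matching IntermediateResult.isEveryConsumerConsumeAllSubPartitions() above. A tiny self-contained check (only setBroadcast mirrors the diff; the rest is scaffolding):

public class SetBroadcastSketch {
    private boolean isBroadcast;
    private boolean everyConsumerConsumeAllSubPartitions;

    void setBroadcast(boolean broadcast) {
        if (!this.isBroadcast && broadcast) {
            // non-broadcast -> broadcast: consumers must read all subpartitions
            everyConsumerConsumeAllSubPartitions = true;
        } else if (this.isBroadcast && !broadcast) {
            // broadcast -> non-broadcast: back to normal range assignment
            everyConsumerConsumeAllSubPartitions = false;
        }
        isBroadcast = broadcast;
    }

    public static void main(String[] args) {
        SetBroadcastSketch info = new SetBroadcastSketch();
        info.setBroadcast(true);
        System.out.println(info.everyConsumerConsumeAllSubPartitions); // true
        info.setBroadcast(false);
        System.out.println(info.everyConsumerConsumeAllSubPartitions); // false
    }
}
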

@Override
public boolean isPointwise() {
return false;
@@ -83,7 +115,7 @@ public long getNumBytesProduced() {
List<Long> bytes =
Optional.ofNullable(aggregatedSubpartitionBytes)
.orElse(getAggregatedSubpartitionBytesInternal());
if (isBroadcast) {
if (isBroadcast && !everyConsumerConsumeAllSubPartitions) {
return bytes.get(0);
} else {
return bytes.stream().reduce(0L, Long::sum);
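
A worked example of the changed accounting, under the diff's semantics: take a result with two subpartitions whose aggregated sizes across all partitions are [60, 40]. For a native broadcast result this situation cannot arise (the checkArgument above enforces a single subpartition), and bytes.get(0) is already the size of one logical copy of the data. For a result rewritten to broadcast, however, the data is still spread over both subpartitions, so one logical copy is the sum 60 + 40 = 100; returning bytes.get(0) here would undercount the produced bytes.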
@@ -22,6 +22,8 @@
import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;

import java.util.Map;

/**
* The blocking result info, which will be used to calculate the vertex parallelism and input infos.
*/
@@ -67,4 +69,12 @@ public interface BlockingResultInfo extends IntermediateResultInfo {

/** Aggregates the subpartition bytes to reduce space usage. */
void aggregateSubpartitionBytes();

/**
* Gets subpartition bytes by partition index.
*
* @return a map with integer keys representing partition indices and long array values
* representing subpartition bytes.
*/
Map<Integer, long[]> getSubpartitionBytesByPartitionIndex();
}