[FLINK-36067][runtime] Support optimize stream graph based on input info. #25790

Open · wants to merge 3 commits into master
@@ -63,6 +63,7 @@ public class IntermediateResult {
private final int numParallelProducers;

private final ExecutionPlanSchedulingContext executionPlanSchedulingContext;
private final boolean produceBroadcastResult;

private int partitionsAssigned;

@@ -102,6 +103,8 @@ public IntermediateResult(
this.shuffleDescriptorCache = new HashMap<>();

this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext);

this.produceBroadcastResult = intermediateDataSet.isBroadcast();
}

public boolean areAllConsumerVerticesCreated() {
@@ -207,6 +210,10 @@ public boolean isForward() {
return intermediateDataSet.isForward();
}

public boolean isEveryConsumerConsumeAllSubPartitions() {
return !produceBroadcastResult && intermediateDataSet.isBroadcast();
}

@davidradl (Contributor) commented on Dec 12, 2024:
I am confused by why this method is checking broadcasts - how does this relate to downstream consumers being able to consume the subpartitions?
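Reading this together with the new updateOutputPattern(...) further down and the consumeAllSubpartitions branch in computeConsumedSubpartitionRange, the flag appears to mark a result that was physically produced as a non-broadcast result and only later re-declared broadcast by the adaptive optimizer; in that case every consumer must read all produced subpartitions instead of the single subpartition of a genuine broadcast result. A minimal standalone model of that reading (illustrative only, class and field names are assumptions, not code from the PR):

// Illustrative model only; not Flink code.
final class BroadcastFlagModel {
    private final boolean produceBroadcastResult; // broadcast when the result was created
    private boolean currentlyBroadcast;           // may be flipped later, cf. updateOutputPattern(...)

    BroadcastFlagModel(boolean broadcastAtCreation) {
        this.produceBroadcastResult = broadcastAtCreation;
        this.currentlyBroadcast = broadcastAtCreation;
    }

    void redeclareAsBroadcast() {
        this.currentlyBroadcast = true;
    }

    boolean isEveryConsumerConsumeAllSubPartitions() {
        // only true when the result became broadcast after it was produced as a normal result
        return !produceBroadcastResult && currentlyBroadcast;
    }

    public static void main(String[] args) {
        BroadcastFlagModel model = new BroadcastFlagModel(false);
        model.redeclareAsBroadcast();
        System.out.println(model.isEveryConsumerConsumeAllSubPartitions()); // true
        System.out.println(new BroadcastFlagModel(true).isEveryConsumerConsumeAllSubPartitions()); // false
    }
}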

public int getConnectionIndex() {
return connectionIndex;
}
@@ -35,6 +35,14 @@ public interface IntermediateResultInfo {
*/
boolean isBroadcast();

/**
* Indicates whether every downstream consumer needs to consume all produced sub-partitions.
*
* @return true if every downstream consumer needs to consume all produced sub-partitions, false
* otherwise.
*/
boolean isEveryConsumerConsumeAllSubPartitions();

@davidradl (Contributor) commented on Dec 12, 2024:
I see this method seems to be related to broadcasts - it would be good to explain how.

/**
* Whether it is a pointwise result.
*
@@ -84,7 +84,8 @@ public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputI
parallelism,
input::getNumSubpartitions,
isDynamicGraph,
- input.isBroadcast()));
+ input.isBroadcast(),
+ input.isEveryConsumerConsumeAllSubPartitions()));
}
}

@@ -124,6 +125,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
1,
() -> numOfSubpartitionsRetriever.apply(start),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
@@ -145,6 +147,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
numConsumers,
() -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
@@ -165,14 +168,16 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
* @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
* @param isDynamicGraph whether is dynamic graph
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed {@link JobVertexInputInfo}
*/
static JobVertexInputInfo computeVertexInputInfoForAllToAll(
int sourceCount,
int targetCount,
Function<Integer, Integer> numOfSubpartitionsRetriever,
boolean isDynamicGraph,
- boolean isBroadcast) {
+ boolean isBroadcast,
+ boolean consumeAllSubpartitions) {
final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
for (int i = 0; i < targetCount; ++i) {
@@ -182,7 +187,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
targetCount,
() -> numOfSubpartitionsRetriever.apply(0),
isDynamicGraph,
- isBroadcast);
+ isBroadcast,
+ consumeAllSubpartitions);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
}
@@ -199,6 +205,7 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
* @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions
* @param isDynamicGraph whether is dynamic graph
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed subpartition range
*/
@VisibleForTesting
@@ -207,16 +214,21 @@ static IndexRange computeConsumedSubpartitionRange(
int numConsumers,
Supplier<Integer> numOfSubpartitionsSupplier,
boolean isDynamicGraph,
- boolean isBroadcast) {
+ boolean isBroadcast,
+ boolean consumeAllSubpartitions) {
int consumerIndex = consumerSubtaskIndex % numConsumers;
if (!isDynamicGraph) {
return new IndexRange(consumerIndex, consumerIndex);
} else {
int numSubpartitions = numOfSubpartitionsSupplier.get();
if (isBroadcast) {
- // broadcast results have only one subpartition, and be consumed multiple times.
- checkArgument(numSubpartitions == 1);
- return new IndexRange(0, 0);
+ if (consumeAllSubpartitions) {
+ return new IndexRange(0, numSubpartitions - 1);
+ } else {
+ // broadcast results have only one subpartition, and be consumed multiple times.
+ checkArgument(numSubpartitions == 1);
+ return new IndexRange(0, 0);
+ }
} else {
checkArgument(consumerIndex < numConsumers);
checkArgument(numConsumers <= numSubpartitions);
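To illustrate the new branch: with a dynamic graph, a broadcast result that physically has several subpartitions (because it was originally produced as a non-broadcast result) is now consumed in full by every subtask, while the old single-subpartition invariant is kept for results that were produced as broadcast. A minimal standalone sketch of that branch (example numbers are assumed, not taken from the PR):

// Standalone sketch of the broadcast branch; class and method names are illustrative.
final class BroadcastRangeSketch {
    /** Returns {start, end} of the consumed subpartition index range for a broadcast result. */
    static int[] broadcastRange(int numSubpartitions, boolean consumeAllSubpartitions) {
        if (consumeAllSubpartitions) {
            // the result was produced with real partitioning, so read all of it
            return new int[] {0, numSubpartitions - 1};
        }
        if (numSubpartitions != 1) {
            throw new IllegalArgumentException("a true broadcast result has exactly one subpartition");
        }
        return new int[] {0, 0};
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(broadcastRange(4, true)));  // [0, 3]
        System.out.println(java.util.Arrays.toString(broadcastRange(1, false))); // [0, 0]
    }
}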
@@ -246,6 +258,11 @@ public boolean isBroadcast() {
return intermediateResult.isBroadcast();
}

@Override
public boolean isEveryConsumerConsumeAllSubPartitions() {
return intermediateResult.isEveryConsumerConsumeAllSubPartitions();
}

@Override
public boolean isPointwise() {
return intermediateResult.getConsumingDistributionPattern()
@@ -134,6 +134,18 @@ public void configure(
}
}

public void updateOutputPattern(
DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) {
checkState(consumers.isEmpty(), "The output job edges have already been added.");
checkState(
numJobEdgesToCreate == 1,
"Modification is not allowed when the subscribing output is reused.");

this.distributionPattern = distributionPattern;
this.isBroadcast = isBroadcast;
this.isForward = isForward;
}

public void increaseNumJobEdgesToCreate() {
this.numJobEdgesToCreate++;
}
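As far as the diff shows, updateOutputPattern(...) lets the adaptive layer re-declare how an already-defined dataset will be consumed (distribution pattern, broadcast and forward flags), but only while that is still safe to do: the two checkState guards reject the change once a consuming JobEdge has been attached or when the dataset is shared by more than one output edge. The call site of this method is not part of this diff.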
@@ -44,11 +44,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;

AbstractBlockingResultInfo(
- IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
+ IntermediateDataSetID resultId,
+ int numOfPartitions,
+ int numOfSubpartitions,
+ Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
this.resultId = checkNotNull(resultId);
this.numOfPartitions = numOfPartitions;
this.numOfSubpartitions = numOfSubpartitions;
- this.subpartitionBytesByPartitionIndex = new HashMap<>();
+ this.subpartitionBytesByPartitionIndex = subpartitionBytesByPartitionIndex;
}

@Override
@@ -72,4 +75,9 @@ public void resetPartitionInfo(int partitionIndex) {
int getNumOfRecordedPartitions() {
return subpartitionBytesByPartitionIndex.size();
}

@Override
public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
return new HashMap<>(subpartitionBytesByPartitionIndex);
}
}
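Passing the already-recorded subpartitionBytesByPartitionIndex map through the constructor, together with the new getSubpartitionBytesByPartitionIndex() accessor (which returns a defensive copy), is what allows tryUpdateResultInfo in the scheduler below to rebuild a BlockingResultInfo of a different type without losing the partition byte sizes that finished producers have already reported.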
@@ -273,6 +273,16 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat

// 4. update json plan
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));

// 5. update the DistributionPattern of the upstream results consumed by the newly created
// JobVertex and aggregate subpartition bytes.

Contributor comment:
aggregate subpartition bytes. -> aggregate subpartition bytes when possible.
I would find it useful to describe in more detail what "aggregate subpartition bytes" means and why we can perform this optimization in some circumstances.

for (JobVertex newVertex : newVertices) {
for (JobEdge input : newVertex.getInputs()) {
tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern());
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
.ifPresent(this::maybeAggregateSubpartitionBytes);
}
}
}

@Override
@@ -482,15 +492,29 @@ private void updateResultPartitionBytesMetrics(
result.getId(),
(ignored, resultInfo) -> {
if (resultInfo == null) {
- resultInfo = createFromIntermediateResult(result);
+ resultInfo =
+ createFromIntermediateResult(result, new HashMap<>());
}
resultInfo.recordPartitionInfo(
partitionId.getPartitionNumber(), partitionBytes);
maybeAggregateSubpartitionBytes(resultInfo);
return resultInfo;
});
});
}

private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
IntermediateResult intermediateResult =
getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());

if (intermediateResult.areAllConsumerVerticesCreated()
&& intermediateResult.getConsumerVertices().stream()
.map(this::getExecutionJobVertex)
.allMatch(ExecutionJobVertex::isInitialized)) {
resultInfo.aggregateSubpartitionBytes();
}
}
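maybeAggregateSubpartitionBytes only collapses the per-subpartition statistics once every consumer vertex of the result has been created and initialized; presumably the fine-grained numbers are still needed while the parallelism and input infos of those consumers are being decided, and can be folded into coarser aggregates afterwards to reduce the bookkeeping held by the scheduler. This is an interpretation of the visible condition, not something the PR states explicitly.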

@Override
public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
List<ExecutionVertex> executionVertices =
@@ -657,6 +681,7 @@ public void initializeVerticesIfPossible() {
parallelismAndInputInfos.getJobVertexInputInfos(),
createTimestamp);
newlyInitializedJobVertices.add(jobVertex);
consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes);
}
}
}
@@ -909,21 +934,24 @@ private static void resetDynamicParallelism(Iterable<JobVertex> vertices) {
}
}

- private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
+ private static BlockingResultInfo createFromIntermediateResult(
+ IntermediateResult result, Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
checkArgument(result != null);
// Note that for dynamic graph, different partitions in the same result have the same number
// of subpartitions.
if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
return new PointwiseBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
- result.getPartitions()[0].getNumberOfSubpartitions());
+ result.getPartitions()[0].getNumberOfSubpartitions(),
+ subpartitionBytesByPartitionIndex);
} else {
return new AllToAllBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
result.getPartitions()[0].getNumberOfSubpartitions(),
- result.isBroadcast());
+ result.isBroadcast(),
+ subpartitionBytesByPartitionIndex);
}
}

@@ -937,6 +965,26 @@ SpeculativeExecutionHandler getSpeculativeExecutionHandler() {
return speculativeExecutionHandler;
}

private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) {

Contributor comment:
Could you add some comments detailing the algorithm and its benefits?

if (blockingResultInfos.containsKey(id)) {
BlockingResultInfo resultInfo = blockingResultInfos.get(id);
IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id);

if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise())
|| (targetPattern == DistributionPattern.POINTWISE
&& !resultInfo.isPointwise())) {

BlockingResultInfo newInfo =
createFromIntermediateResult(
result, resultInfo.getSubpartitionBytesByPartitionIndex());

blockingResultInfos.put(id, newInfo);
} else if (targetPattern == DistributionPattern.ALL_TO_ALL) {
((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast());
}
}
}
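Reading the body: when a newly added JobVertex consumes an existing blocking result with a distribution pattern different from the one the cached BlockingResultInfo was built for (pointwise vs. all-to-all), the info object is recreated as the matching type and the subpartition bytes recorded so far are carried over via getSubpartitionBytesByPartitionIndex(); if the pattern already matches and is ALL_TO_ALL, only the broadcast flag is refreshed from the IntermediateResult. The benefit, as far as one can tell from this diff, is that byte statistics reported by already-finished producers survive the stream-graph optimization instead of being discarded and re-collected.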

private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext {

private final FailoverStrategy restartStrategyOnResultConsumable =
@@ -21,6 +21,7 @@
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.graph.ExecutionPlan;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.util.DynamicCodeLoadingException;

import java.util.concurrent.Executor;

@@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory {
public static AdaptiveExecutionHandler create(
ExecutionPlan executionPlan,
ClassLoader userClassLoader,
- Executor serializationExecutor) {
+ Executor serializationExecutor)
+ throws DynamicCodeLoadingException {
if (executionPlan instanceof JobGraph) {
return new NonAdaptiveExecutionHandler((JobGraph) executionPlan);
} else {