Skip to content

Commit

Permalink
[FLINK-36067][runtime] Manually trigger aggregate all-to-all result partition info when all consumers created and initialized.
Browse files Browse the repository at this point in the history
  • Loading branch information
JunRuiLee committed Dec 12, 2024
1 parent 221f712 commit 02cdbf3
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat

// 4. update json plan
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));

// 5. try to aggregate subpartition bytes
for (JobVertex newVertex : newVertices) {
for (JobEdge input : newVertex.getInputs()) {
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
.ifPresent(this::maybeAggregateSubpartitionBytes);
}
}
}

@Override
Expand Down Expand Up @@ -482,15 +490,29 @@ private void updateResultPartitionBytesMetrics(
result.getId(),
(ignored, resultInfo) -> {
if (resultInfo == null) {
resultInfo = createFromIntermediateResult(result);
resultInfo =
createFromIntermediateResult(result, new HashMap<>());
}
resultInfo.recordPartitionInfo(
partitionId.getPartitionNumber(), partitionBytes);
maybeAggregateSubpartitionBytes(resultInfo);
return resultInfo;
});
});
}

/**
 * Asks the given result info to aggregate its subpartition bytes, but only once every consumer
 * vertex of the corresponding intermediate result has been created and each of their execution
 * job vertices reports itself as initialized. Otherwise this method does nothing.
 *
 * @param resultInfo the blocking result info whose subpartition bytes may be aggregated
 */
private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
    final IntermediateResult producedResult =
            getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());

    // Consumers that do not exist yet cannot be checked for initialization.
    if (!producedResult.areAllConsumerVerticesCreated()) {
        return;
    }

    final boolean allConsumersInitialized =
            producedResult.getConsumerVertices().stream()
                    .map(this::getExecutionJobVertex)
                    .allMatch(ExecutionJobVertex::isInitialized);
    if (allConsumersInitialized) {
        resultInfo.aggregateSubpartitionBytes();
    }
}

@Override
public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
List<ExecutionVertex> executionVertices =
Expand Down Expand Up @@ -657,6 +679,7 @@ public void initializeVerticesIfPossible() {
parallelismAndInputInfos.getJobVertexInputInfos(),
createTimestamp);
newlyInitializedJobVertices.add(jobVertex);
consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.flink.util.Preconditions.checkState;
Expand Down Expand Up @@ -74,18 +75,28 @@ public int getNumSubpartitions(int partitionIndex) {

/**
 * Returns the number of bytes produced by this result.
 *
 * <p>The value is computed from the aggregated subpartition bytes when they are available, and
 * otherwise from the per-partition records, which must then cover all partitions.
 *
 * @return the first subpartition's bytes if this result is a broadcast, otherwise the sum over
 *     all subpartitions
 * @throws IllegalStateException if the bytes have not been aggregated and not all partitions
 *     have been recorded yet (NOTE(review): assumes {@code checkState} is Flink's
 *     {@code Preconditions.checkState} — confirm)
 */
@Override
public long getNumBytesProduced() {
    checkState(
            aggregatedSubpartitionBytes != null
                    || subpartitionBytesByPartitionIndex.size() == numOfPartitions,
            "Not all partition infos are ready");

    // Only fall back to summing the raw per-partition records when no aggregated value exists;
    // an eager Optional.orElse(...) would recompute the sums on every call.
    final List<Long> bytes =
            aggregatedSubpartitionBytes != null
                    ? aggregatedSubpartitionBytes
                    : getAggregatedSubpartitionBytesInternal();
    if (isBroadcast) {
        return bytes.get(0);
    } else {
        return bytes.stream().reduce(0L, Long::sum);
    }
}

@Override
public long getNumBytesProduced(
IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
List<Long> bytes =
Optional.ofNullable(aggregatedSubpartitionBytes)
.orElse(getAggregatedSubpartitionBytesInternal());

checkState(
partitionIndexRange.getStartIndex() == 0
&& partitionIndexRange.getEndIndex() == numOfPartitions - 1,
Expand All @@ -96,7 +107,7 @@ public long getNumBytesProduced(
"Subpartition index %s is out of range.",
subpartitionIndexRange.getEndIndex());

return aggregatedSubpartitionBytes
return bytes
.subList(
subpartitionIndexRange.getStartIndex(),
subpartitionIndexRange.getEndIndex() + 1)
Expand All @@ -106,31 +117,34 @@ public long getNumBytesProduced(

/**
 * Records the bytes produced by one partition, unless the subpartition bytes have already been
 * aggregated, in which case the call is ignored.
 *
 * @param partitionIndex the index of the partition that produced the bytes
 * @param partitionBytes the bytes produced by that partition
 */
@Override
public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
    // Keeping only the aggregated value is acceptable for ALL_TO_ALL edges because the
    // distribution of source splits does not affect the distribution of data consumed by
    // downstream tasks (hashing or rebalancing; rare cases such as custom partitioners are
    // not considered here). Once aggregated, the raw per-partition info is no longer tracked.
    if (aggregatedSubpartitionBytes != null) {
        return;
    }
    super.recordPartitionInfo(partitionIndex, partitionBytes);
}

if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
long[] aggregatedBytes = new long[numOfSubpartitions];
subpartitionBytesByPartitionIndex
.values()
.forEach(
subpartitionBytes -> {
checkState(subpartitionBytes.length == numOfSubpartitions);
for (int i = 0; i < subpartitionBytes.length; ++i) {
aggregatedBytes[i] += subpartitionBytes[i];
}
});
this.aggregatedSubpartitionBytes =
Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
this.subpartitionBytesByPartitionIndex.clear();
}
/**
 * Replaces the per-partition subpartition bytes with their aggregated form and clears the raw
 * per-partition records. Does nothing unless the number of recorded partitions equals the
 * number of partitions of this result.
 */
@Override
public void aggregateSubpartitionBytes() {
    final boolean allPartitionsRecorded =
            subpartitionBytesByPartitionIndex.size() == numOfPartitions;
    if (!allPartitionsRecorded) {
        return;
    }

    // Compute the aggregate first; it reads the raw records that are cleared right after.
    aggregatedSubpartitionBytes = getAggregatedSubpartitionBytesInternal();
    subpartitionBytesByPartitionIndex.clear();
}

/**
 * Sums the currently recorded per-partition subpartition bytes, position by position.
 *
 * @return a list with one entry per subpartition, holding the total bytes over all recorded
 *     partitions for that subpartition
 */
protected List<Long> getAggregatedSubpartitionBytesInternal() {
    final long[] totals = new long[numOfSubpartitions];

    subpartitionBytesByPartitionIndex
            .values()
            .forEach(
                    bytesOfPartition -> {
                        // Every partition must report exactly one value per subpartition.
                        checkState(bytesOfPartition.length == numOfSubpartitions);
                        for (int idx = 0; idx < numOfSubpartitions; idx++) {
                            totals[idx] += bytesOfPartition[idx];
                        }
                    });

    final List<Long> boxedTotals = Arrays.stream(totals).boxed().collect(Collectors.toList());
    return boxedTotals;
}

@Override
public void resetPartitionInfo(int partitionIndex) {
if (aggregatedSubpartitionBytes == null) {
Expand All @@ -139,7 +153,14 @@ public void resetPartitionInfo(int partitionIndex) {
}

/**
 * Returns the bytes of each subpartition, summed over all partitions.
 *
 * <p>If the bytes have already been aggregated, the stored aggregate is returned; otherwise the
 * aggregate is computed from the per-partition records, which must then cover all partitions.
 *
 * @return an unmodifiable list with one entry per subpartition
 * @throws IllegalStateException if the bytes have not been aggregated and not all partitions
 *     have been recorded yet (NOTE(review): assumes {@code checkState} is Flink's
 *     {@code Preconditions.checkState} — confirm)
 */
public List<Long> getAggregatedSubpartitionBytes() {
    checkState(
            aggregatedSubpartitionBytes != null
                    || subpartitionBytesByPartitionIndex.size() == numOfPartitions,
            "Not all partition infos are ready");

    final List<Long> bytes =
            aggregatedSubpartitionBytes != null
                    ? aggregatedSubpartitionBytes
                    : getAggregatedSubpartitionBytesInternal();
    // Both paths hand out a read-only view so callers see the same contract either way.
    return Collections.unmodifiableList(bytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,7 @@ public interface BlockingResultInfo extends IntermediateResultInfo {
* @param partitionIndex the intermediate result partition index
*/
void resetPartitionInfo(int partitionIndex);

/**
 * Aggregates the subpartition bytes to reduce space usage.
 *
 * <p>Implementations decide whether aggregation applies to them and may treat this call as a
 * no-op.
 */
void aggregateSubpartitionBytes();
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,9 @@ public long getNumBytesProduced(
}
return inputBytes;
}

/** Intentionally a no-op: the subpartition bytes of a point-wise result are not aggregated. */
@Override
public void aggregateSubpartitionBytes() {
// do nothing because pointWise result should not be aggregated
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ void testRecordPartitionInfoMultiTimes() {
// The result info should be (partitionBytes2 + partitionBytes3)
assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
// The raw info should not be cleared yet
assertThat(resultInfo.getNumOfRecordedPartitions()).isGreaterThan(0);
resultInfo.aggregateSubpartitionBytes();
// The raw info should be cleared
assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();

Expand Down

0 comments on commit 02cdbf3

Please sign in to comment.