diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 85ccea4915e..cb536c0e6e9 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -129,7 +129,15 @@ public ShuffleHandle registerShuffle( if (fallbackPolicyRunner.applyAllFallbackPolicy( lifecycleManager, dependency.partitioner().numPartitions())) { - logger.warn("Fallback to SortShuffleManager!"); + if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logger.error( + "DRA is enabled but we fallback to vanilla Spark SortShuffleManager for " + + "shuffle: {} due to fallback policy. It may cause block can not found when reducer " + + "task fetch data.", + shuffleId); + } else { + logger.warn("Fallback to vanilla Spark SortShuffleManager for shuffle: {}", shuffleId); + } sortShuffleIds.add(shuffleId); return sortShuffleManager().registerShuffle(shuffleId, dependency); } else { diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 5ce89190738..36ee314489b 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -97,11 +97,17 @@ class CelebornShuffleReader[K, C]( }).flatMap( serializerInstance.deserializeStream(_).asKeyValueIterator) + val iterWithUpdatedRecordsRead = + if (GlutenColumnarBatchSerdeHelper.isGlutenSerde(serializerInstance.getClass.getName)) { + GlutenColumnarBatchSerdeHelper.withUpdatedRecordsRead(recordIter, metrics) + } else { + recordIter.map { record => + 
metrics.incRecordsRead(1) + record + } + } val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map { record => - metrics.incRecordsRead(1) - record - }, + iterWithUpdatedRecordsRead, context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/GlutenColumnarBatchSerdeHelper.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/GlutenColumnarBatchSerdeHelper.scala new file mode 100644 index 00000000000..259bb954d21 --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/GlutenColumnarBatchSerdeHelper.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A helper class to be compatible with Gluten Celeborn. 
+ */ +object GlutenColumnarBatchSerdeHelper { + + def isGlutenSerde(serdeName: String): Boolean = { + // scalastyle:off + // see Gluten + // https://github.com/oap-project/gluten/blob/main/gluten-celeborn/src/main/scala/org/apache/spark/shuffle/CelebornColumnarBatchSerializer.scala + // scalastyle:on + "org.apache.spark.shuffle.CelebornColumnarBatchSerializer".equals(serdeName) + } + + def withUpdatedRecordsRead( + input: Iterator[(Any, Any)], + metrics: ShuffleReadMetricsReporter): Iterator[(Any, Any)] = { + input.map { record => + metrics.incRecordsRead(record._2.asInstanceOf[ColumnarBatch].numRows()) + record + } + } +} diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 637aada7fa5..d42219773c0 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -47,7 +47,7 @@ public class DfsPartitionReader implements PartitionReader { private static Logger logger = LoggerFactory.getLogger(DfsPartitionReader.class); PartitionLocation location; - private final int shuffleChunkSize; + private final long shuffleChunkSize; private final int fetchMaxReqsInFlight; private final LinkedBlockingQueue results; private final AtomicReference exception = new AtomicReference<>(); @@ -66,7 +66,7 @@ public DfsPartitionReader( int startMapIndex, int endMapIndex) throws IOException { - shuffleChunkSize = (int) conf.shuffleChunkSize(); + shuffleChunkSize = conf.dfsReadChunkSize(); fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index c56e5ae0ae0..1d165ae1dd4 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -151,6 +151,7 @@ message 
PbHeartbeatFromWorker { string requestId = 8; map userResourceConsumption = 9; map estimatedAppDiskUsage = 10; + bool highWorkload = 11; } message PbHeartbeatFromWorkerResponse { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index be762058e7e..c39e8c789b5 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -654,6 +654,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def workerPushMaxComponents: Int = get(WORKER_PUSH_COMPOSITEBUFFER_MAXCOMPONENTS) def workerFetchHeartbeatEnabled: Boolean = get(WORKER_FETCH_HEARTBEAT_ENABLED) def workerPartitionSplitEnabled: Boolean = get(WORKER_PARTITION_SPLIT_ENABLED) + def workerActiveConnectionMax: Option[Long] = get(WORKER_ACTIVE_CONNECTION_MAX) // ////////////////////////////////////////////////////// // Metrics System // @@ -804,6 +805,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def shuffleExpiredCheckIntervalMs: Long = get(SHUFFLE_EXPIRED_CHECK_INTERVAL) def shuffleManagerPort: Int = get(CLIENT_SHUFFLE_MANAGER_PORT) def shuffleChunkSize: Long = get(SHUFFLE_CHUNK_SIZE) + def dfsReadChunkSize: Long = get(CLIENT_FETCH_DFS_READ_CHUNK_SIZE) def shufflePartitionSplitMode: PartitionSplitMode = PartitionSplitMode.valueOf(get(SHUFFLE_PARTITION_SPLIT_MODE)) def shufflePartitionSplitThreshold: Long = get(SHUFFLE_PARTITION_SPLIT_THRESHOLD) @@ -1945,13 +1947,21 @@ object CelebornConf extends Logging { val SHUFFLE_CHUNK_SIZE: ConfigEntry[Long] = buildConf("celeborn.shuffle.chunk.size") - .categories("client", "worker") + .categories("worker") .version("0.2.0") .doc("Max chunk size of reducer's merged shuffle data. 
For example, if a reducer's " + "shuffle data is 128M and the data will need 16 fetch chunk requests to fetch.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("8m") + val CLIENT_FETCH_DFS_READ_CHUNK_SIZE: ConfigEntry[Long] = + buildConf("celeborn.client.fetch.dfsReadChunkSize") + .categories("client") + .version("0.3.1") + .doc("Max chunk size for DfsPartitionReader.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("8m") + val WORKER_PARTITION_SPLIT_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.worker.shuffle.partitionSplit.enabled") .withAlternative("celeborn.worker.partition.split.enabled") @@ -2685,6 +2695,16 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val WORKER_ACTIVE_CONNECTION_MAX: OptionalConfigEntry[Long] = + buildConf("celeborn.worker.activeConnection.max") + .categories("worker") + .doc("If the number of active connections on a worker exceeds this configuration value, " + + "the worker will be marked as high-load in the heartbeat report, " + + "and the master will not include that node in the response of RequestSlots.") + .version("0.3.1") + .longConf + .createOptional + val APPLICATION_HEARTBEAT_INTERVAL: ConfigEntry[Long] = buildConf("celeborn.client.application.heartbeatInterval") .withAlternative("celeborn.application.heartbeatInterval") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 458a7dbb863..733ca2d4a28 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -113,6 +113,7 @@ object ControlMessages extends Logging { userResourceConsumption: util.Map[UserIdentifier, ResourceConsumption], activeShuffleKeys: util.Set[String], estimatedAppDiskUsage: util.HashMap[String, java.lang.Long], + highWorkload: 
Boolean, override var requestId: String = ZERO_UUID) extends MasterRequestMessage case class HeartbeatFromWorkerResponse( @@ -446,6 +447,7 @@ object ControlMessages extends Logging { userResourceConsumption, activeShuffleKeys, estimatedAppDiskUsage, + highWorkload, requestId) => val pbDisks = disks.map(PbSerDeUtils.toPbDiskInfo).asJava val pbUserResourceConsumption = @@ -460,6 +462,7 @@ object ControlMessages extends Logging { .setReplicatePort(replicatePort) .addAllActiveShuffleKeys(activeShuffleKeys) .putAllEstimatedAppDiskUsage(estimatedAppDiskUsage) + .setHighWorkload(highWorkload) .setRequestId(requestId) .build().toByteArray new TransportMessage(MessageType.HEARTBEAT_FROM_WORKER, payload) @@ -824,6 +827,7 @@ object ControlMessages extends Logging { userResourceConsumption, activeShuffleKeys, estimatedAppDiskUsage, + pbHeartbeatFromWorker.getHighWorkload, pbHeartbeatFromWorker.getRequestId) case HEARTBEAT_FROM_WORKER_RESPONSE_VALUE => diff --git a/dev/dependencies.sh b/dev/dependencies.sh new file mode 100755 index 00000000000..90b94b5ab0c --- /dev/null +++ b/dev/dependencies.sh @@ -0,0 +1,216 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +set -ex + +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + +PWD=$(cd "$(dirname "$0")"/.. || exit; pwd) + +MVN="${PWD}/build/mvn" +SBT="${PWD}/build/sbt" + +SBT_ENABLED="false" +REPLACE="false" +CHECK="false" +MODULE="" + +DEP_PR="" +DEP="" + +function mvn_build_classpath() { + $MVN -P$MODULE clean install -DskipTests -am -pl $MVN_MODULES + $MVN -P$MODULE dependency:build-classpath -am -pl $MVN_MODULES | \ + grep -v "INFO\|WARN" | \ + # This will skip the first two lines + tail -n +3 | \ + tr ":" "\n" | \ + awk -F '/' '{ + artifact_id=$(NF-2); + version=$(NF-1); + jar_name=$NF; + classifier_start_index=length(artifact_id"-"version"-") + 1; + classifier_end_index=index(jar_name, ".jar") - 1; + classifier=substr(jar_name, classifier_start_index, classifier_end_index - classifier_start_index + 1); + print artifact_id"/"version"/"classifier"/"jar_name + }' | grep -v "celeborn" | sort -u >> "${DEP_PR}" +} + +function sbt_build_client_classpath() { + $SBT -P$MODULE "error; clean; export ${SBT_PROJECT}/Runtime/externalDependencyClasspath" | \ + tail -1 | \ + tr ":" "\n" | \ + awk -F '/' '{ + artifact_id=$(NF-2); + version=$(NF-1); + jar_name=$NF; + classifier_start_index=length(artifact_id"-"version"-") + 1; + classifier_end_index=index(jar_name, ".jar") - 1; + classifier=substr(jar_name, classifier_start_index, classifier_end_index - classifier_start_index + 1); + print artifact_id"/"version"/"classifier"/"jar_name + }' | sort -u >> "${DEP_PR}" +} + +function sbt_build_server_classpath() { + $SBT "error; clean; export externalDependencyClasspath" | \ + awk '/externalDependencyClasspath/ { found=1 } found { print }' | \ + awk 'NR % 2 == 0' | \ + # This will skip the last line + sed '$d' | \ + tr ":" "\n" | \ + awk -F '/' '{ + artifact_id=$(NF-2); + version=$(NF-1); + jar_name=$NF; + classifier_start_index=length(artifact_id"-"version"-") + 1; + 
classifier_end_index=index(jar_name, ".jar") - 1; + classifier=substr(jar_name, classifier_start_index, classifier_end_index - classifier_start_index + 1); + print artifact_id"/"version"/"classifier"/"jar_name + }' | sort -u >> "${DEP_PR}" +} + +function check_diff() { + set +e + the_diff=$(diff "${DEP}" "${DEP_PR}") + set -e + rm -rf "${DEP_PR}" + if [[ -n "${the_diff}" ]]; then + echo "Dependency List Changed Detected: " + echo "${the_diff}" + echo "To update the dependency file, run './dev/dependencies.sh --replace'." + exit 1 + fi +} + +function exit_with_usage() { + echo "Usage: $0 [--sbt | --mvn] [--replace] [--check] [--module MODULE_VALUE] [--help]" + exit 1 +} + +# Parse arguments +while (( "$#" )); do + case $1 in + --sbt) + SBT_ENABLED="true" + ;; + --mvn) + SBT_ENABLED="false" + ;; + --replace) + REPLACE="true" + ;; + --check) + CHECK="true" + ;; + --module) # Support for --module parameter + shift + MODULE="$1" + ;; + --help) + exit_with_usage + ;; + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; + esac + shift +done + +case "$MODULE" in + "spark-2.4") + MVN_MODULES="client-spark/spark-2" + SBT_PROJECT="celeborn-client-spark-2" + ;; + "spark-3.0" | "spark-3.1" | "spark-3.2" | "spark-3.3" | "spark-3.4") + MVN_MODULES="client-spark/spark-3" + SBT_PROJECT="celeborn-client-spark-3" + ;; + "flink-1.14") + MVN_MODULES="client-flink/flink-1.14" + SBT_PROJECT="celeborn-client-flink-1_14" + ;; + "flink-1.15") + MVN_MODULES="client-flink/flink-1.15" + SBT_PROJECT="celeborn-client-flink-1_15" + ;; + "flink-1.17") + MVN_MODULES="client-flink/flink-1.17" + SBT_PROJECT="celeborn-client-flink-1_17" + ;; + *) + MVN_MODULES="worker,master" + ;; +esac + + +if [ "$MODULE" = "server" ]; then + DEP="${PWD}/dev/deps/dependencies-server" + DEP_PR="${PWD}/dev/deps/dependencies-server.tmp" +else + DEP="${PWD}/dev/deps/dependencies-client-$MODULE" + 
DEP_PR="${PWD}/dev/deps/dependencies-client-$MODULE.tmp" +fi + +rm -rf "${DEP_PR}" +cat >"${DEP_PR}"< Personal Access Tokens for +# your own token management. +JIRA_ACCESS_TOKEN = os.environ.get("JIRA_ACCESS_TOKEN") # OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests # will be unauthenticated. You should only need to configure this if you find yourself regularly # exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at @@ -238,9 +243,12 @@ def cherry_pick(pr_num, merge_hash, default_branch): def resolve_jira_issue(merge_branches, comment, default_jira_id=""): - asf_jira = jira.client.JIRA( - {"server": JIRA_API_BASE}, basic_auth=(ASF_USERNAME, ASF_PASSWORD) - ) + jira_server = {"server": JIRA_API_BASE} + + if JIRA_ACCESS_TOKEN is not None: + asf_jira = jira.client.JIRA(jira_server, token_auth=JIRA_ACCESS_TOKEN) + else: + asf_jira = jira.client.JIRA(jira_server, basic_auth=(ASF_USERNAME, ASF_PASSWORD)) jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": @@ -385,7 +393,22 @@ def choose_jira_assignee(issue, asf_jira): except BaseException: # assume it's a user id, and try to assign (might fail, we just prompt again) assignee = asf_jira.user(raw_assignee) - assign_issue(asf_jira, issue.key, assignee.name) + try: + assign_issue(asf_jira, issue.key, assignee.name) + except Exception as e: + if ( + e.__class__.__name__ == "JIRAError" + and ("'%s' cannot be assigned" % assignee.name) + in getattr(e, "response").text + ): + continue_maybe( + "User '%s' cannot be assigned, add to contributors role and try again?" 
+ % assignee.name + ) + grant_contributor_role(assignee.name, asf_jira) + assign_issue(asf_jira, issue.key, assignee.name) + else: + raise e return assignee except KeyboardInterrupt: raise @@ -393,6 +416,11 @@ def choose_jira_assignee(issue, asf_jira): traceback.print_exc() print("Error assigning JIRA, try again (or leave blank and fix manually)") +def grant_contributor_role(user: str, asf_jira): + role = asf_jira.project_role("CELEBORN", 10010) + role.add_user(user) + print("Successfully added user '%s' to contributors role" % user) + def assign_issue(client, issue: int, assignee: str) -> bool: """ Assign an issue to a user, which is a shorthand for jira.client.JIRA.assign_issue. @@ -474,8 +502,9 @@ def main(): original_head = get_current_ref() # Check this up front to avoid failing the JIRA update at the very end - if not ASF_USERNAME or not ASF_PASSWORD: - continue_maybe("The env-vars ASF_USERNAME and/or ASF_PASSWORD are not set. Continue?") + if not JIRA_ACCESS_TOKEN and (not ASF_USERNAME or not ASF_PASSWORD): + msg = "The env-vars JIRA_ACCESS_TOKEN or ASF_USERNAME/ASF_PASSWORD are not set. Continue?" + continue_maybe(msg) branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = list(filter(lambda x: x.startswith("branch-"), [x["name"] for x in branches])) @@ -575,7 +604,7 @@ def main(): merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: - if ASF_USERNAME and ASF_PASSWORD: + if JIRA_ACCESS_TOKEN or (ASF_USERNAME and ASF_PASSWORD): continue_maybe("Would you like to update an associated JIRA?") jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % ( pr_num, @@ -584,7 +613,7 @@ def main(): ) resolve_jira_issues(title, merged_refs, jira_comment) else: - print("ASF_USERNAME and ASF_PASSWORD not set") + print("Neither JIRA_ACCESS_TOKEN nor ASF_USERNAME/ASF_PASSWORD are set.") print("Exiting without trying to close the associated JIRA.") else: print("Could not find jira-python library. 
Run 'sudo pip3 install jira' to install.") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index e61f3207d2d..79d2d4816b5 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -24,6 +24,7 @@ license: | | celeborn.client.commitFiles.ignoreExcludedWorker | false | When true, LifecycleManager will skip workers which are in the excluded list. | 0.3.0 | | celeborn.client.excludePeerWorkerOnFailure.enabled | true | When true, Celeborn will exclude partition's peer worker on failure when push data to replica failed. | 0.3.0 | | celeborn.client.excludedWorker.expireTimeout | 180s | Timeout time for LifecycleManager to clear reserved excluded worker. Default to be 1.5 * `celeborn.master.heartbeat.worker.timeout`to cover worker heartbeat timeout check period | 0.3.0 | +| celeborn.client.fetch.dfsReadChunkSize | 8m | Max chunk size for DfsPartitionReader. | 0.3.1 | | celeborn.client.fetch.excludeWorkerOnFailure.enabled | false | Whether to enable shuffle client-side fetch exclude workers on failure. | 0.3.0 | | celeborn.client.fetch.excludedWorker.expireTimeout | <value of celeborn.client.excludedWorker.expireTimeout> | ShuffleClient is a static object, it will be used in the whole lifecycle of Executor,We give a expire time for excluded workers to avoid a transient worker issues. | 0.3.0 | | celeborn.client.fetch.maxReqsInFlight | 3 | Amount of in-flight chunk fetch request. | 0.3.0 | @@ -98,6 +99,5 @@ license: | | celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold | 2147483647 | Celeborn will only accept shuffle of partition number lower than this configuration value. | 0.3.0 | | celeborn.client.spark.shuffle.writer | HASH | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. 
| 0.3.0 | | celeborn.master.endpoints | <localhost>:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `:[,:]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | -| celeborn.shuffle.chunk.size | 8m | Max chunk size of reducer's merged shuffle data. For example, if a reducer's shuffle data is 128M and the data will need 16 fetch chunk requests to fetch. | 0.2.0 | | celeborn.storage.hdfs.dir | <undefined> | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 | diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index 994838b3fef..8b05d5ad784 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -24,6 +24,7 @@ license: | | celeborn.shuffle.chunk.size | 8m | Max chunk size of reducer's merged shuffle data. For example, if a reducer's shuffle data is 128M and the data will need 16 fetch chunk requests to fetch. | 0.2.0 | | celeborn.storage.activeTypes | HDD,SSD | Enabled storage levels. Available options: HDD,SSD,HDFS. | 0.3.0 | | celeborn.storage.hdfs.dir | <undefined> | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 | +| celeborn.worker.activeConnection.max | <undefined> | If the number of active connections on a worker exceeds this configuration value, the worker will be marked as high-load in the heartbeat report, and the master will not include that node in the response of RequestSlots. | 0.3.1 | | celeborn.worker.bufferStream.threadsPerMountpoint | 8 | Threads count for read buffer per mount point. | 0.3.0 | | celeborn.worker.closeIdleConnections | false | Whether worker will close idle connections. | 0.2.0 | | celeborn.worker.commitFiles.threads | 32 | Thread number of worker to commit shuffle data files asynchronously. It's recommended to set at least `128` when `HDFS` is enabled in `celeborn.storage.activeTypes`. 
| 0.3.0 | diff --git a/docs/migration.md b/docs/migration.md index 96632b804e3..a76c343014f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -21,14 +21,18 @@ license: | # Migration Guide -## Upgrading from 0.3 to 0.4 +## Upgrading from 0.3.1 to 0.4 - Since 0.4.0, Celeborn won't be compatible with Celeborn client that versions below 0.3.0. Note that: It's strongly recommended to use the same version of Client and Celeborn Master/Worker in production. - Since 0.4.0, Celeborn won't support `org.apache.spark.shuffle.celeborn.RssShuffleManager`. -## Upgrading from 0.2 to 0.3 +## Upgrading from 0.3.0 to 0.3.1 + +- Since 0.3.1, Celeborn changed the default value of `celeborn.worker.directMemoryRatioToResume` from `0.5` to `0.7`. + +## Upgrading from 0.2 to 0.3.0 - Celeborn 0.2 Client is compatible with 0.3 Master/Server, it allows to upgrade Master/Worker first then Client. Note that: It's strongly recommended to use the same version of Client and Celeborn Master/Worker in production. diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java index 8df57405fa5..8d7b9376a7b 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java @@ -143,7 +143,8 @@ public void updateWorkerHeartbeatMeta( Map disks, Map userResourceConsumption, Map estimatedAppDiskUsage, - long time) { + long time, + boolean highWorkload) { WorkerInfo worker = new WorkerInfo( host, rpcPort, pushPort, fetchPort, replicatePort, disks, userResourceConsumption); @@ -161,10 +162,11 @@ public void updateWorkerHeartbeatMeta( } appDiskUsageMetric.update(estimatedAppDiskUsage); // If using HDFSONLY mode, workers with empty disks should not be put into excluded worker list. 
- if (!excludedWorkers.contains(worker) && (disks.isEmpty() && !conf.hasHDFSStorage())) { + if (!excludedWorkers.contains(worker) + && ((disks.isEmpty() && !conf.hasHDFSStorage()) || highWorkload)) { LOG.debug("Worker: {} num total slots is 0, add to excluded list", worker); excludedWorkers.add(worker); - } else if (availableSlots.get() > 0) { + } else if ((availableSlots.get() > 0 || conf.hasHDFSStorage()) && !highWorkload) { // only unblack if numSlots larger than 0 excludedWorkers.remove(worker); } diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java index 6c4c65a73db..a34cb445d53 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java @@ -55,6 +55,7 @@ void handleWorkerHeartbeat( Map userResourceConsumption, Map estimatedAppDiskUsage, long time, + boolean highWorkload, String requestId); void handleRegisterWorker( diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java index 3d12db8b405..15c0c6d6d7b 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java @@ -92,6 +92,7 @@ public void handleWorkerHeartbeat( Map userResourceConsumption, Map estimatedAppDiskUsage, long time, + boolean highWorkload, String requestId) { updateWorkerHeartbeatMeta( host, @@ -102,7 +103,8 @@ public void handleWorkerHeartbeat( disks, userResourceConsumption, estimatedAppDiskUsage, - time); + time, + highWorkload); } @Override diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java index f7a10013c03..181c6e4874b 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java @@ -197,6 +197,7 @@ public void handleWorkerHeartbeat( Map userResourceConsumption, Map estimatedAppDiskUsage, long time, + boolean highWorkload, String requestId) { try { ratisServer.submitRequest( @@ -215,6 +216,7 @@ public void handleWorkerHeartbeat( MetaUtil.toPbUserResourceConsumption(userResourceConsumption)) .putAllEstimatedAppDiskUsage(estimatedAppDiskUsage) .setTime(time) + .setHighWorkload(highWorkload) .build()) .build()); } catch (CelebornRuntimeException e) { diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java index d6ab8309cba..27ba6d8828b 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java @@ -163,6 +163,7 @@ public ResourceResponse handleWriteRequest(ResourceProtos.ResourceRequest reques estimatedAppDiskUsage.putAll( request.getWorkerHeartbeatRequest().getEstimatedAppDiskUsageMap()); replicatePort = request.getWorkerHeartbeatRequest().getReplicatePort(); + boolean highWorkload = request.getWorkerHeartbeatRequest().getHighWorkload(); LOG.debug( "Handle worker heartbeat for {} {} {} {} {} {} {}", host, @@ -182,7 +183,8 @@ public ResourceResponse handleWriteRequest(ResourceProtos.ResourceRequest reques diskInfos, userResourceConsumption, estimatedAppDiskUsage, - time); + time, + 
highWorkload); break; case RegisterWorker: diff --git a/master/src/main/proto/Resource.proto b/master/src/main/proto/Resource.proto index 8f6f62f258b..a6fb5d17327 100644 --- a/master/src/main/proto/Resource.proto +++ b/master/src/main/proto/Resource.proto @@ -121,6 +121,7 @@ message WorkerHeartbeatRequest { required int64 time = 7; map userResourceConsumption = 8; map estimatedAppDiskUsage = 9; + optional bool highWorkload = 10; } message RegisterWorkerRequest { diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index 466e088d165..c3dd2712555 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -334,6 +334,7 @@ private[celeborn] class Master( userResourceConsumption, activeShuffleKey, estimatedAppDiskUsage, + highWorkload, requestId) => logDebug(s"Received heartbeat from" + s" worker $host:$rpcPort:$pushPort:$fetchPort:$replicatePort with $disks.") @@ -350,6 +351,7 @@ private[celeborn] class Master( userResourceConsumption, activeShuffleKey, estimatedAppDiskUsage, + highWorkload, requestId)) case ReportWorkerUnavailable(failedWorkers: util.List[WorkerInfo], requestId: String) => @@ -432,6 +434,7 @@ private[celeborn] class Master( userResourceConsumption: util.Map[UserIdentifier, ResourceConsumption], activeShuffleKeys: util.Set[String], estimatedAppDiskUsage: util.HashMap[String, java.lang.Long], + highWorkload: Boolean, requestId: String): Unit = { val targetWorker = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort) val registered = workersSnapShot.asScala.contains(targetWorker) @@ -449,6 +452,7 @@ private[celeborn] class Master( userResourceConsumption, estimatedAppDiskUsage, System.currentTimeMillis(), + highWorkload, requestId) } diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java index aaae7861a87..2962ebafdb5 100644 --- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java +++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java @@ -506,6 +506,7 @@ public void testHandleWorkerHeartbeat() { userResourceConsumption1, new HashMap<>(), 1, + false, getNewReqeustId()); Assert.assertEquals(statusSystem.excludedWorkers.size(), 1); @@ -520,23 +521,40 @@ public void testHandleWorkerHeartbeat() { userResourceConsumption2, new HashMap<>(), 1, + false, getNewReqeustId()); Assert.assertEquals(statusSystem.excludedWorkers.size(), 2); statusSystem.handleWorkerHeartbeat( - HOSTNAME1, - RPCPORT1, - PUSHPORT1, - FETCHPORT1, + HOSTNAME3, + RPCPORT3, + PUSHPORT3, + FETCHPORT3, REPLICATEPORT3, - disks1, - userResourceConsumption1, + disks3, + userResourceConsumption3, new HashMap<>(), 1, + false, getNewReqeustId()); Assert.assertEquals(statusSystem.excludedWorkers.size(), 2); + + statusSystem.handleWorkerHeartbeat( + HOSTNAME3, + RPCPORT3, + PUSHPORT3, + FETCHPORT3, + REPLICATEPORT3, + disks3, + userResourceConsumption3, + new HashMap<>(), + 1, + true, + getNewReqeustId()); + + Assert.assertEquals(statusSystem.excludedWorkers.size(), 3); } @Test diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java index 1d0e2c17b76..6ae55d1620e 100644 --- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java +++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java @@ -735,6 +735,7 @@ public void testHandleWorkerHeartbeat() throws InterruptedException { userResourceConsumption1, new HashMap<>(), 1, + false, getNewReqeustId()); Thread.sleep(3000L); @@ -752,6 +753,7 @@ public void testHandleWorkerHeartbeat() throws InterruptedException { userResourceConsumption2, new HashMap<>(), 1, + false, getNewReqeustId()); Thread.sleep(3000L); @@ -770,6 +772,7 @@ public void testHandleWorkerHeartbeat() throws InterruptedException { userResourceConsumption1, new HashMap<>(), 1, + false, getNewReqeustId()); Thread.sleep(3000L); @@ -777,6 +780,24 @@ public void testHandleWorkerHeartbeat() throws InterruptedException { Assert.assertEquals(1, STATUSSYSTEM1.excludedWorkers.size()); Assert.assertEquals(1, STATUSSYSTEM2.excludedWorkers.size()); Assert.assertEquals(1, STATUSSYSTEM3.excludedWorkers.size()); + + statusSystem.handleWorkerHeartbeat( + HOSTNAME1, + RPCPORT1, + PUSHPORT1, + FETCHPORT1, + REPLICATEPORT1, + disks1, + userResourceConsumption1, + new HashMap<>(), + 1, + true, + getNewReqeustId()); + Thread.sleep(3000L); + Assert.assertEquals(2, statusSystem.excludedWorkers.size()); + Assert.assertEquals(2, STATUSSYSTEM1.excludedWorkers.size()); + Assert.assertEquals(2, STATUSSYSTEM2.excludedWorkers.size()); + Assert.assertEquals(2, STATUSSYSTEM3.excludedWorkers.size()); } @Before diff --git a/pom.xml b/pom.xml index 6e7880ca2e0..786517fefdc 100644 --- a/pom.xml +++ b/pom.xml @@ -268,6 +268,12 @@ org.apache.spark spark-core_${scala.binary.version} ${spark.version} + + + io.netty + * + + org.apache.spark @@ -286,16 +292,6 @@ ${spark.version} test-jar - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.databind.version} - com.google.guava guava @@ -705,6 +701,17 @@ maven-dependency-plugin ${maven.plugin.dependency.version} + + default-cli + + 
build-classpath + + + + runtime + + copy-module-dependencies @@ -878,26 +885,12 @@ tests/spark-it - 2.6.7 - 2.6.7.3 1.4.0 2.11.12 2.11 2.4.8 1.4.4-3 - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.databind.version} - - @@ -909,8 +902,6 @@ tests/spark-it - 2.10.0 - 2.10.0 1.7.1 2.12.10 2.12 @@ -929,8 +920,6 @@ tests/spark-it - 2.10.0 - 2.10.0 1.7.1 2.12.10 2.12 @@ -949,8 +938,6 @@ tests/spark-it - 2.12.3 - 2.12.3 1.7.1 2.12.15 2.12 @@ -968,8 +955,6 @@ tests/spark-it - 2.13.4 - 2.13.4.2 1.8.0 2.12.15 2.12 @@ -987,8 +972,6 @@ tests/spark-it - 2.14.2 - 2.14.2 1.8.0 2.12.17 2.12 diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index e6ab134c3e6..f63dfc745cd 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -62,7 +62,8 @@ object Dependencies { val protocVersion = "3.19.2" val protoVersion = "3.19.2" - val commonsCrypto = "org.apache.commons" % "commons-crypto" % commonsCryptoVersion + val commonsCrypto = "org.apache.commons" % "commons-crypto" % commonsCryptoVersion excludeAll( + ExclusionRule("net.java.dev.jna", "jna")) val commonsIo = "commons-io" % "commons-io" % commonsIoVersion val commonsLang3 = "org.apache.commons" % "commons-lang3" % commonsLang3Version val findbugsJsr305 = "com.google.code.findbugs" % "jsr305" % findbugsVersion @@ -72,7 +73,8 @@ object Dependencies { val ioDropwizardMetricsCore = "io.dropwizard.metrics" % "metrics-core" % metricsVersion val ioDropwizardMetricsGraphite = "io.dropwizard.metrics" % "metrics-graphite" % metricsVersion val ioDropwizardMetricsJvm = "io.dropwizard.metrics" % "metrics-jvm" % metricsVersion - val ioNetty = "io.netty" % "netty-all" % nettyVersion + val ioNetty = "io.netty" % "netty-all" % nettyVersion excludeAll( + ExclusionRule("io.netty", "netty-handler-ssl-ocsp")) val javaxServletApi = "javax.servlet" % "javax.servlet-api" % javaxServletVersion val leveldbJniAll = 
"org.fusesource.leveldbjni" % "leveldbjni-all" % leveldbJniVersion val log4j12Api = "org.apache.logging.log4j" % "log4j-1.2-api" % log4j2Version @@ -133,6 +135,10 @@ object CelebornCommonSettings { // -target cannot be passed as a parameter to javadoc. See https://github.com/sbt/sbt/issues/355 Compile / compile / javacOptions ++= Seq("-target", "1.8"), + + dependencyOverrides := Seq( + Dependencies.findbugsJsr305, + Dependencies.slf4jApi), // Make sure any tests in any project that uses Spark is configured for running well locally Test / javaOptions ++= Seq( diff --git a/project/build.properties b/project/build.properties index 41f6be16879..ddb8431706a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=1.9.3 +sbt.version=1.9.4 diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java index 7f01924968f..cb248dd9b8d 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java @@ -142,41 +142,7 @@ private MemoryManager(CelebornConf conf) { checkService.scheduleWithFixedDelay( () -> { try { - ServingState lastState = servingState; - servingState = currentServingState(); - if (lastState != servingState) { - logger.info("Serving state transformed from {} to {}", lastState, servingState); - if (servingState == ServingState.PUSH_PAUSED) { - pausePushDataCounter.increment(); - logger.info("Trigger action: PAUSE PUSH, RESUME REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); - memoryPressureListeners.forEach( - memoryPressureListener -> - 
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); - trimAllListeners(); - } else if (servingState == ServingState.PUSH_AND_REPLICATE_PAUSED) { - pausePushDataAndReplicateCounter.increment(); - logger.info("Trigger action: PAUSE PUSH and REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE)); - trimAllListeners(); - } else { - logger.info("Trigger action: RESUME PUSH and REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> memoryPressureListener.onResume("all")); - } - } else { - if (servingState != ServingState.NONE_PAUSED) { - logger.debug("Trigger action: TRIM"); - trimAllListeners(); - } - } + switchServingState(); } catch (Exception e) { logger.error("Memory tracker check error", e); } @@ -274,6 +240,59 @@ public ServingState currentServingState() { return isPaused ? 
ServingState.PUSH_PAUSED : ServingState.NONE_PAUSED; } + @VisibleForTesting + protected void switchServingState() { + ServingState lastState = servingState; + servingState = currentServingState(); + if (lastState == servingState) { + if (servingState != ServingState.NONE_PAUSED) { + logger.debug("Trigger action: TRIM"); + trimAllListeners(); + } + return; + } + logger.info("Serving state transformed from {} to {}", lastState, servingState); + switch (servingState) { + case PUSH_PAUSED: + pausePushDataCounter.increment(); + logger.info("Trigger action: PAUSE PUSH, RESUME REPLICATE"); + if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); + } else if (lastState == ServingState.NONE_PAUSED) { + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); + } + trimAllListeners(); + break; + case PUSH_AND_REPLICATE_PAUSED: + pausePushDataAndReplicateCounter.increment(); + logger.info("Trigger action: PAUSE PUSH and REPLICATE"); + if (lastState == ServingState.NONE_PAUSED) { + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); + } + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE)); + trimAllListeners(); + break; + case NONE_PAUSED: + logger.info("Trigger action: RESUME PUSH and REPLICATE"); + if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); + } + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE)); + } + } + public void trimAllListeners() { if 
(trimInProcess.compareAndSet(false, true)) { actionService.submit( @@ -410,7 +429,6 @@ public interface ReadBufferTargetChangeListener { void onChange(long newMemoryTarget); } - @VisibleForTesting public enum ServingState { NONE_PAUSED, PUSH_AND_REPLICATE_PAUSED, diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index f26357ded36..b4b0ddbace2 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -47,8 +47,10 @@ import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils, ShutdownHoo // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.server.common.{HttpService, Service} +import org.apache.celeborn.service.deploy.worker.WorkerSource.ACTIVE_CONNECTION_COUNT import org.apache.celeborn.service.deploy.worker.congestcontrol.CongestionController import org.apache.celeborn.service.deploy.worker.memory.{ChannelsLimiter, MemoryManager} +import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState import org.apache.celeborn.service.deploy.worker.storage.{MapPartitionFileWriter, PartitionFilesSorter, StorageManager} private[celeborn] class Worker( @@ -297,6 +299,16 @@ private[celeborn] class Worker( memoryManager.getAllocatedReadBuffers } + private def highWorkload: Boolean = { + (memoryManager.currentServingState, conf.workerActiveConnectionMax) match { + case (ServingState.PUSH_AND_REPLICATE_PAUSED, _) => true + case (ServingState.PUSH_PAUSED, _) => true + case (_, Some(activeConnectionMax)) => + workerSource.getCounterCount(ACTIVE_CONNECTION_COUNT) >= activeConnectionMax + case _ => false + } + } + private def heartbeatToMaster(): Unit = { val activeShuffleKeys = new JHashSet[String]() val estimatedAppDiskUsage 
= new JHashMap[String, JLong]() @@ -323,7 +335,8 @@ private[celeborn] class Worker( diskInfos, resourceConsumption, activeShuffleKeys, - estimatedAppDiskUsage), + estimatedAppDiskUsage, + highWorkload), classOf[HeartbeatFromWorkerResponse]) response.expiredShuffleKeys.asScala.foreach(shuffleKey => workerInfo.releaseSlots(shuffleKey)) cleanTaskQueue.put(response.expiredShuffleKeys) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/WorkerSource.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/WorkerSource.scala index edcebbb2e08..e1f247a0833 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/WorkerSource.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/WorkerSource.scala @@ -57,6 +57,10 @@ class WorkerSource(conf: CelebornConf) extends AbstractSource(conf, MetricsSyste addTimer(TAKE_BUFFER_TIME) addTimer(SORT_TIME) + def getCounterCount(metricsName: String): Long = { + val metricNameWithLabel = metricNameWithCustomizedLabels(metricsName, Map.empty) + namedCounters.get(metricNameWithLabel).counter.getCount + } // start cleaner thread startCleaner() } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala index f329b6f36cb..864ddffec20 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala @@ -17,10 +17,17 @@ package org.apache.celeborn.service.deploy.memory +import scala.concurrent.duration.DurationInt + +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.{interval, timeout} + import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf import 
org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE, WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE} +import org.apache.celeborn.common.protocol.TransportModuleConstants import org.apache.celeborn.service.deploy.worker.memory.MemoryManager +import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState class MemoryManagerSuite extends CelebornFunSuite { @@ -79,4 +86,84 @@ class MemoryManagerSuite extends CelebornFunSuite { MemoryManager.reset() } } + + test("[CELEBORN-882] Test MemoryManager check memory thread logic") { + val conf = new CelebornConf() + val memoryManager = MemoryManager.initialize(conf) + val maxDirectorMemory = memoryManager.maxDirectorMemory + val pushThreshold = + (conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue() + val replicateThreshold = + (conf.workerDirectMemoryRatioToPauseReplicate * maxDirectorMemory).longValue() + val memoryCounter = memoryManager.getSortMemoryCounter + + val pushListener = new MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE) + val replicateListener = + new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE) + memoryManager.registerMemoryListener(pushListener) + memoryManager.registerMemoryListener(replicateListener) + + // NONE PAUSED -> PAUSE PUSH + memoryCounter.set(pushThreshold + 1) + // the default check interval is 10ms; poll for up to 30s until the listener is triggered + eventually(timeout(30.second), interval(10.milliseconds)) { + assert(pushListener.isPause) + assert(!replicateListener.isPause) + } + + // PAUSE PUSH -> PAUSE PUSH AND REPLICATE + memoryCounter.set(replicateThreshold + 1); + eventually(timeout(30.second), interval(10.milliseconds)) { + assert(pushListener.isPause) + assert(replicateListener.isPause) + } + + // PAUSE PUSH AND REPLICATE -> PAUSE PUSH + memoryCounter.set(pushThreshold + 1); +
eventually(timeout(30.second), interval(10.milliseconds)) { + assert(pushListener.isPause) + assert(!replicateListener.isPause) + } + + // PAUSE PUSH -> NONE PAUSED + memoryCounter.set(0); + eventually(timeout(30.second), interval(10.milliseconds)) { + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + } + + // NONE PAUSED -> PAUSE PUSH AND REPLICATE + memoryCounter.set(replicateThreshold + 1); + eventually(timeout(30.second), interval(10.milliseconds)) { + assert(pushListener.isPause) + assert(replicateListener.isPause) + } + + // PAUSE PUSH AND REPLICATE -> NONE PAUSED + memoryCounter.set(0); + eventually(timeout(30.second), interval(10.milliseconds)) { + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + } + } + + class MockMemoryPressureListener( + val belongModuleName: String, + var isPause: Boolean = false) extends MemoryPressureListener { + override def onPause(moduleName: String): Unit = { + if (belongModuleName == moduleName) { + isPause = true + } + } + + override def onResume(moduleName: String): Unit = { + if (belongModuleName == moduleName) { + isPause = false + } + } + + override def onTrim(): Unit = { + // do nothing + } + } }