diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4286e01862337..b0285cc6557ea 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -75,7 +75,7 @@ private class ShuffleStatus(numPartitions: Int) { * broadcast variable in order to keep it from being garbage collected and to allow for it to be * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. */ - private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + private[this] var cachedSerializedBroadcast: Broadcast[Array[MapStatus]] = _ /** * Counter tracking the number of partitions that have output. This is a performance optimization @@ -323,7 +323,7 @@ private[spark] class MapOutputTrackerMaster( // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = - conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt + conf.get("spark.shuffle.mapOutput.minSizeForBroadcast", "500").toInt /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) @@ -361,16 +361,6 @@ private[spark] class MapOutputTrackerMaster( pool } - // Make sure that we aren't going to exceed the max RPC message size by making sure - // we use broadcast to send large map output statuses. - if (minSizeForBroadcast > maxRpcMessageSize) { - val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + - s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + - "message that is too large." - logError(msg) - throw new IllegalArgumentException(msg) - } - def post(message: GetMapOutputMessage): Unit = { mapOutputRequests.offer(message) } @@ -390,7 +380,7 @@ private[spark] class MapOutputTrackerMaster( val context = data.context val shuffleId = data.shuffleId val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + + logInfo("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) val shuffleStatus = shuffleStatuses.get(shuffleId).head context.reply( @@ -425,6 +415,7 @@ private[spark] class MapOutputTrackerMaster( /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + logInfo(s"Unregister MapOutput, shuffleId $shuffleId, mapId $mapId, bmAddress $BlockManagerId.") shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => shuffleStatus.removeMapOutput(mapId, bmAddress) @@ -436,6 +427,7 @@ private[spark] class MapOutputTrackerMaster( /** Unregister all map output information of the given shuffle. */ def unregisterAllMapOutput(shuffleId: Int) { + logInfo(s"Unregister all MapOutput, shuffleId $shuffleId.") shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => shuffleStatus.removeOutputsByFilter(x => true) @@ -448,6 +440,7 @@ private[spark] class MapOutputTrackerMaster( /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { + logInfo(s"Unregister Shuffle, shuffleId $shuffleId.") shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => shuffleStatus.invalidateSerializedMapOutputStatusCache() } @@ -458,6 +451,7 @@ private[spark] class MapOutputTrackerMaster( * outputs which are served by an external shuffle server (if one exists). */ def removeOutputsOnHost(host: String): Unit = { + logInfo(s"Remove Outputs on host $host.") shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) } incrementEpoch() } @@ -468,6 +462,7 @@ private[spark] class MapOutputTrackerMaster( * registered with this execId. */ def removeOutputsOnExecutor(execId: String): Unit = { + logInfo(s"Remove Outputs on executor $execId.") shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } incrementEpoch() } @@ -791,23 +786,29 @@ private[spark] object MapOutputTracker extends Logging { // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, - isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { - val out = new ByteArrayOutputStream - out.write(DIRECT) - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - Utils.tryWithSafeFinally { - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[MapStatus]]) = { + + val out = new org.apache.commons.io.output.ByteArrayOutputStream + if (statuses.length < minBroadcastSize) { + out.reset() + out.write(DIRECT) + val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - } { - objOut.close() - } - val arr = out.toByteArray - if (arr.length >= minBroadcastSize) { + val arr = out.toByteArray + logInfo("Direct serialize mapstatuses size = " + arr.length + + ", statuses size = " + statuses.length) + (arr, null) + } else { // Use broadcast instead. // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! - val bcast = broadcastManager.newBroadcast(arr, isLocal, null) + val bcast = broadcastManager.newBroadcast(statuses, isLocal, null) // toByteArray creates copy, so we can reuse out out.reset() out.write(BROADCAST) @@ -815,10 +816,9 @@ private[spark] object MapOutputTracker extends Logging { oos.writeObject(bcast) oos.close() val outArr = out.toByteArray - logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + logInfo("Broadcast mapstatuses size = " + outArr.length + + ", statuses size = " + statuses.length) (outArr, bcast) - } else { - (arr, null) } } @@ -842,11 +842,11 @@ private[spark] object MapOutputTracker extends Logging { case BROADCAST => // deserialize the Broadcast, pull .value array out of it, and then deserialize that val bcast = deserializeObject(bytes, 1, bytes.length - 1). - asInstanceOf[Broadcast[Array[Byte]]] + asInstanceOf[Broadcast[Array[MapStatus]]] logInfo("Broadcast mapstatuses size = " + bytes.length + ", actual size = " + bcast.value.length) // Important - ignore the DIRECT tag ! Start from offset 1 - deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + bcast.value case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index d77fbb2a81475..a1eba98c662d1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -223,27 +223,29 @@ class BlockManagerMasterEndpoint( val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - // De-register the block if none of the block managers have it. Otherwise, if pro-active - // replication is enabled, and a block is either an RDD or a test block (the latter is used - // for unit testing), we send a message to a randomly chosen executor location to replicate - // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks - // etc.) as replication doesn't make much sense in that context. - if (locations.size == 0) { - blockLocations.remove(blockId) - logWarning(s"No more replicas available for $blockId !") - } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) { - // As a heursitic, assume single executor failure to find out the number of replicas that - // existed before failure - val maxReplicas = locations.size + 1 - val i = (new Random(blockId.hashCode)).nextInt(locations.size) - val blockLocations = locations.toSeq - val candidateBMId = blockLocations(i) - blockManagerInfo.get(candidateBMId).foreach { bm => - val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) - val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) - bm.slaveEndpoint.ask[Boolean](replicateMsg) + if (blockLocations.containsKey(blockId)) { + val locations = blockLocations.get(blockId) + locations -= blockManagerId + // De-register the block if none of the block managers have it. Otherwise, if pro-active + // replication is enabled, and a block is either an RDD or a test block (the latter is used + // for unit testing), we send a message to a randomly chosen executor location to replicate + // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks + // etc.) as replication doesn't make much sense in that context. + if (locations.size == 0) { + blockLocations.remove(blockId) + logWarning(s"No more replicas available for $blockId !") + } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) { + // As a heursitic, assume single executor failure to find out the number of replicas that + // existed before failure + val maxReplicas = locations.size + 1 + val i = (new Random(blockId.hashCode)).nextInt(locations.size) + val blockLocations = locations.toSeq + val candidateBMId = blockLocations(i) + blockManagerInfo.get(candidateBMId).foreach { bm => + val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) + val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + bm.slaveEndpoint.ask[Boolean](replicateMsg) + } } } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..de36aefa410cd 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -196,15 +196,6 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("min broadcast size exceeds max RPC message size") { - val newConf = new SparkConf - newConf.set("spark.rpc.message.maxSize", "1") - newConf.set("spark.rpc.askTimeout", "1") // Fail fast - newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString) - - intercept[IllegalArgumentException] { newTrackerMaster(newConf) } - } - test("getLocationsWithLargestOutputs with multiple outputs in same machine") { val rpcEnv = createRpcEnv("test") val tracker = newTrackerMaster() @@ -244,7 +235,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast - newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "99") // 10 KB << 1MB framesize // needs TorrentBroadcast so need a SparkContext withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>