Commit
#51, opt MapStatus ser/dser.
Aaaaaaron committed Sep 20, 2019
1 parent cc4e1e1 commit 7b73196
Showing 3 changed files with 56 additions and 63 deletions.
64 changes: 32 additions & 32 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -75,7 +75,7 @@ private class ShuffleStatus(numPartitions: Int) {
* broadcast variable in order to keep it from being garbage collected and to allow for it to be
* explicitly destroyed later on when the ShuffleMapStage is garbage-collected.
*/
private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
private[this] var cachedSerializedBroadcast: Broadcast[Array[MapStatus]] = _

/**
* Counter tracking the number of partitions that have output. This is a performance optimization
@@ -323,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(

// The size at which we use Broadcast to send the map output statuses to the executors
private val minSizeForBroadcast =
conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
conf.get("spark.shuffle.mapOutput.minSizeForBroadcast", "500").toInt

/** Whether to compute locality preferences for reduce tasks */
private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
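With this patch the threshold is read with conf.get and, per the statuses.length check in serializeMapStatuses further down, compared against the number of map statuses rather than a serialized size in bytes as the old "512k" default implied. A minimal configuration sketch under that assumption (app name is illustrative):

import org.apache.spark.SparkConf

// Sketch only: the value is treated as a count of MapStatus entries, not bytes.
val conf = new SparkConf()
  .setAppName("map-status-broadcast-demo")
  // Broadcast map output statuses once a shuffle has at least 500 map outputs.
  .set("spark.shuffle.mapOutput.minSizeForBroadcast", "500")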
@@ -361,16 +361,6 @@ private[spark] class MapOutputTrackerMaster(
pool
}

// Make sure that we aren't going to exceed the max RPC message size by making sure
// we use broadcast to send large map output statuses.
if (minSizeForBroadcast > maxRpcMessageSize) {
val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " +
s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " +
"message that is too large."
logError(msg)
throw new IllegalArgumentException(msg)
}

def post(message: GetMapOutputMessage): Unit = {
mapOutputRequests.offer(message)
}
@@ -390,7 +380,7 @@ private[spark] class MapOutputTrackerMaster(
val context = data.context
val shuffleId = data.shuffleId
val hostPort = context.senderAddress.hostPort
logDebug("Handling request to send map output locations for shuffle " + shuffleId +
logInfo("Handling request to send map output locations for shuffle " + shuffleId +
" to " + hostPort)
val shuffleStatus = shuffleStatuses.get(shuffleId).head
context.reply(
@@ -425,6 +415,7 @@ private[spark] class MapOutputTrackerMaster(

/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
logInfo(s"Unregister MapOutput, shuffleId $shuffleId, mapId $mapId, bmAddress $BlockManagerId.")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeMapOutput(mapId, bmAddress)
@@ -436,6 +427,7 @@ private[spark] class MapOutputTrackerMaster(

/** Unregister all map output information of the given shuffle. */
def unregisterAllMapOutput(shuffleId: Int) {
logInfo(s"Unregister all MapOutput, shuffleId $shuffleId.")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
@@ -448,6 +440,7 @@ private[spark] class MapOutputTrackerMaster(

/** Unregister shuffle data */
def unregisterShuffle(shuffleId: Int) {
logInfo(s"Unregister Shuffle, shuffleId $shuffleId.")
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
shuffleStatus.invalidateSerializedMapOutputStatusCache()
}
@@ -458,6 +451,7 @@ private[spark] class MapOutputTrackerMaster(
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
logInfo(s"Remove Outputs on host $host.")
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
incrementEpoch()
}
@@ -468,6 +462,7 @@ private[spark] class MapOutputTrackerMaster(
* registered with this execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = {
logInfo(s"Remove Outputs on executor $execId.")
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
incrementEpoch()
}
@@ -791,34 +786,39 @@ private[spark] object MapOutputTracker extends Logging {
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
val out = new ByteArrayOutputStream
out.write(DIRECT)
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
Utils.tryWithSafeFinally {
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[MapStatus]]) = {

val out = new org.apache.commons.io.output.ByteArrayOutputStream
if (statuses.length < minBroadcastSize) {
out.reset()
out.write(DIRECT)
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
Utils.tryWithSafeFinally {
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
} {
objOut.close()
}
} {
objOut.close()
}
val arr = out.toByteArray
if (arr.length >= minBroadcastSize) {
val arr = out.toByteArray
logInfo("Direct serialize mapstatuses size = " + arr.length +
", statuses size = " + statuses.length)
(arr, null)
} else {
// Use broadcast instead.
// Important arr(0) is the tag == DIRECT, ignore that while deserializing !
val bcast = broadcastManager.newBroadcast(arr, isLocal, null)
val bcast = broadcastManager.newBroadcast(statuses, isLocal, null)
// toByteArray creates copy, so we can reuse out
out.reset()
out.write(BROADCAST)
val oos = new ObjectOutputStream(new GZIPOutputStream(out))
oos.writeObject(bcast)
oos.close()
val outArr = out.toByteArray
logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
logInfo("Broadcast mapstatuses size = " + outArr.length +
", statuses size = " + statuses.length)
(outArr, bcast)
} else {
(arr, null)
}
}
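The rework above keeps the one-byte DIRECT/BROADCAST tag followed by a GZIP-compressed, Java-serialized payload, but on the broadcast path it now broadcasts the MapStatus array itself instead of its serialized bytes. A standalone sketch of that framing follows; the object name and the tag values 0/1 are illustrative assumptions, since Spark keeps its DIRECT and BROADCAST constants private:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream

object StatusFraming {
  val DIRECT: Byte = 0     // assumed to mirror MapOutputTracker.DIRECT
  val BROADCAST: Byte = 1  // assumed to mirror MapOutputTracker.BROADCAST

  // One uncompressed tag byte, then a GZIP-compressed Java-serialized payload.
  // DIRECT frames carry the Array[MapStatus]; BROADCAST frames carry the Broadcast handle.
  def frame(tag: Byte, payload: AnyRef): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    out.write(tag)
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    try objOut.writeObject(payload) finally objOut.close()
    out.toByteArray
  }
}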

@@ -842,11 +842,11 @@ private[spark] object MapOutputTracker extends Logging {
case BROADCAST =>
// deserialize the Broadcast, pull .value array out of it, and then deserialize that
val bcast = deserializeObject(bytes, 1, bytes.length - 1).
asInstanceOf[Broadcast[Array[Byte]]]
asInstanceOf[Broadcast[Array[MapStatus]]]
logInfo("Broadcast mapstatuses size = " + bytes.length +
", actual size = " + bcast.value.length)
// Important - ignore the DIRECT tag ! Start from offset 1
deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
bcast.value
case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
}
}
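On the read side, deserializeMapStatuses still dispatches on the tag byte, but for BROADCAST frames the broadcast's value is now the MapStatus array itself, so the second deserialization pass disappears. A self-contained sketch of decoding the payload that follows the tag (tag handling stays with the caller; tag values as assumed above):

import java.io.{ByteArrayInputStream, ObjectInputStream}
import java.util.zip.GZIPInputStream

// Inflate and deserialize everything after the one-byte tag. For a DIRECT frame the
// result is the Array[MapStatus]; for a BROADCAST frame it is the Broadcast handle,
// whose .value now already holds the Array[MapStatus].
def readFramedPayload(bytes: Array[Byte]): AnyRef = {
  val in = new ObjectInputStream(
    new GZIPInputStream(new ByteArrayInputStream(bytes, 1, bytes.length - 1)))
  try in.readObject() finally in.close()
}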
44 changes: 23 additions & 21 deletions core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -223,27 +223,29 @@ class BlockManagerMasterEndpoint(
val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
val locations = blockLocations.get(blockId)
locations -= blockManagerId
// De-register the block if none of the block managers have it. Otherwise, if pro-active
// replication is enabled, and a block is either an RDD or a test block (the latter is used
// for unit testing), we send a message to a randomly chosen executor location to replicate
// the given block. Note that we ignore other block types (such as broadcast/shuffle blocks
// etc.) as replication doesn't make much sense in that context.
if (locations.size == 0) {
blockLocations.remove(blockId)
logWarning(s"No more replicas available for $blockId !")
} else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) {
// As a heuristic, assume single executor failure to find out the number of replicas that
// existed before failure
val maxReplicas = locations.size + 1
val i = (new Random(blockId.hashCode)).nextInt(locations.size)
val blockLocations = locations.toSeq
val candidateBMId = blockLocations(i)
blockManagerInfo.get(candidateBMId).foreach { bm =>
val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId)
val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas)
bm.slaveEndpoint.ask[Boolean](replicateMsg)
if (blockLocations.containsKey(blockId)) {
val locations = blockLocations.get(blockId)
locations -= blockManagerId
// De-register the block if none of the block managers have it. Otherwise, if pro-active
// replication is enabled, and a block is either an RDD or a test block (the latter is used
// for unit testing), we send a message to a randomly chosen executor location to replicate
// the given block. Note that we ignore other block types (such as broadcast/shuffle blocks
// etc.) as replication doesn't make much sense in that context.
if (locations.size == 0) {
blockLocations.remove(blockId)
logWarning(s"No more replicas available for $blockId !")
} else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) {
// As a heuristic, assume single executor failure to find out the number of replicas that
// existed before failure
val maxReplicas = locations.size + 1
val i = (new Random(blockId.hashCode)).nextInt(locations.size)
val blockLocations = locations.toSeq
val candidateBMId = blockLocations(i)
blockManagerInfo.get(candidateBMId).foreach { bm =>
val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId)
val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas)
bm.slaveEndpoint.ask[Boolean](replicateMsg)
}
}
}
}
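The added containsKey guard assumes blockLocations behaves like a Java HashMap, whose get returns null for unknown keys; without it, a block listed in a manager's info but absent from blockLocations would make the following statements throw a NullPointerException. A minimal illustration of the pattern, with plain String keys standing in for BlockId and BlockManagerId:

import java.util.{HashMap => JHashMap}
import scala.collection.mutable

val blockLocations = new JHashMap[String, mutable.HashSet[String]]()

def removeManagerFromBlock(blockId: String, managerId: String): Unit = {
  // Guard first: get() on a missing key returns null and the update below would NPE.
  if (blockLocations.containsKey(blockId)) {
    val locations = blockLocations.get(blockId)
    locations -= managerId
    if (locations.isEmpty) blockLocations.remove(blockId)
  }
}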
11 changes: 1 addition & 10 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -196,15 +196,6 @@ class MapOutputTrackerSuite extends SparkFunSuite {
rpcEnv.shutdown()
}

test("min broadcast size exceeds max RPC message size") {
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString)

intercept[IllegalArgumentException] { newTrackerMaster(newConf) }
}

test("getLocationsWithLargestOutputs with multiple outputs in same machine") {
val rpcEnv = createRpcEnv("test")
val tracker = newTrackerMaster()
@@ -244,7 +235,7 @@
val newConf = new SparkConf
newConf.set("spark.rpc.message.maxSize", "1")
newConf.set("spark.rpc.askTimeout", "1") // Fail fast
newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize
newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "99") // 10 KB << 1MB framesize

// needs TorrentBroadcast so need a SparkContext
withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
