Shuffle consolidation #669

Open
wants to merge 8 commits into master
34 changes: 15 additions & 19 deletions core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -18,33 +18,29 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
val blockManager = SparkEnv.get.blockManager

val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
val (mapLocations, blockSizes) = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)

logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))

val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}

val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = blockSizes.map { case (bm, groups) =>
val blockIds = groups.flatMap { case (groupId, segments) =>
segments.zipWithIndex.map(x => ("shuffle_%d_%d_%d_%d".format(shuffleId, groupId, reduceId, x._2), x._1))
}
(bm, blockIds)
}

def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
val blockId = blockPair._1
val blockOption = blockPair._2
def unpackBlock(blockTuple: (BlockManagerId, String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
val (address, blockId, blockOption) = blockTuple
blockOption match {
case Some(block) => {
block.asInstanceOf[Iterator[(K, V)]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case regex(shufId, _, _, _) =>
throw new FetchFailedException(address, shufId.toInt, -1, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
@@ -53,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
}
}

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val blockFetcherItr = blockManager.getMultiple(
blocksByAddress, shuffleId, reduceId, mapLocations, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)

CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
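For reference, a minimal standalone sketch (not part of this patch) of how the fetcher derives the new four-part block IDs shuffle_<shuffleId>_<groupId>_<reduceId>_<segment> from the per-server (groupId, segment sizes) pairs that getServerStatuses now returns. The object and method names below are invented for illustration.

object BlockIdSketch {
  // Mirrors the blocksByAddress construction in BlockStoreShuffleFetcher above,
  // for a single server.
  def blockIdsFor(shuffleId: Int, reduceId: Int,
      sizesForOneServer: Seq[(Int, Seq[Long])]): Seq[(String, Long)] = {
    sizesForOneServer.flatMap { case (groupId, segmentSizes) =>
      // One block per segment of the consolidated group file, for this reduce bucket.
      segmentSizes.zipWithIndex.map { case (size, segment) =>
        ("shuffle_%d_%d_%d_%d".format(shuffleId, groupId, reduceId, segment), size)
      }
    }
  }

  def main(args: Array[String]) {
    // Hypothetical sizes: group 0 has two segments, group 1 has one.
    val sizes = Seq(0 -> Seq(1024L, 2048L), 1 -> Seq(512L))
    blockIdsFor(3, 7, sizes).foreach(println)
    // (shuffle_3_0_7_0,1024), (shuffle_3_0_7_1,2048), (shuffle_3_1_7_0,512)
  }
}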
192 changes: 160 additions & 32 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -43,7 +43,8 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _

private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
private val mapStatuses = new TimeStampedHashMap[Int, Array[MapOutputLocation]]
private val shuffleBlockSizes = new TimeStampedHashMap[Int, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -76,36 +77,36 @@
}

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapOutputLocation](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}

def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
shuffleBlockSizes.put(shuffleId, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]())
}

def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
statuses: Array[MapOutputLocation],
sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray],
changeGeneration: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
mapStatuses.put(shuffleId, Array[MapOutputLocation]() ++ statuses)
shuffleBlockSizes.put(shuffleId, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]() ++ sizes)

if (changeGeneration) {
incrementGeneration()
}
}

def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
val array = mapStatuses.get(shuffleId).orNull
val sizes = shuffleBlockSizes.get(shuffleId).orNull
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
if (sizes != null) {
sizes.remove(bmAddress)
}
}
incrementGeneration()
} else {
@@ -116,12 +117,21 @@
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]

// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
/**
* Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle.
* Returns an array of MapOutputLocation for the given reduceId, one per ShuffleMapTask,
* together with the sizes of all segments of the shuffle bucket, in the form of
* Seq(BlockManagerId, Seq(groupId, size array for all the segments in the bucket)).
*/
def getServerStatuses(shuffleId: Int, reduceId: Int):
(Array[MapOutputLocation], Seq[(BlockManagerId, Seq[(Int, Seq[Long])])]) = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
val sizes = shuffleBlockSizes.get(shuffleId).orNull

if (statuses == null || sizes == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
var fetchedStatuses: Array[MapOutputLocation] = null
var fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -151,19 +161,26 @@
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
val tuple = deserializeStatuses(fetchedBytes)
fetchedStatuses = tuple._1
fetchedSizes = tuple._2
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
shuffleBlockSizes.put(shuffleId, fetchedSizes)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
if (fetchedStatuses != null) {
if (fetchedStatuses != null && fetchedSizes != null) {
logDebug("ShufCon - getServerStatuses for shuffle " + shuffleId + ": " +
fetchedResultStr(fetchedStatuses, fetchedSizes))

fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
return MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId,
fetchedStatuses, fetchedSizes)
}
}
else {
@@ -172,19 +189,33 @@
}
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
return MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId, statuses, sizes)
}
}
}

private def fetchedResultStr(fetchedStatuses: Array[MapOutputLocation],
fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]) = {
var str = "(fetchedStatuses="
fetchedStatuses.zipWithIndex.foreach { s =>
str += (if (s._2 != 0) ", " else "") + "map[" + s._2 + "]=" + s._1.debugString
}
str += "), fetchedSizes=("
fetchedSizes.foreach { s => str += "(" + s._1 + ", " + s._2.debugString + ") "}
str += ")"
str
}

private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
shuffleBlockSizes.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}

def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
shuffleBlockSizes.clear()
metadataCleaner.cancel()
trackerActor = null
}
@@ -219,7 +250,8 @@
}

def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var statuses: Array[MapOutputLocation] = null
var sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
@@ -231,12 +263,13 @@
return bytes
case None =>
statuses = mapStatuses(shuffleId)
sizes = shuffleBlockSizes.get(shuffleId).orNull
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
val bytes = serializeStatuses((statuses, sizes))
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
@@ -250,47 +283,64 @@
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(tuple: (Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray])): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
val statuses = tuple._1
statuses.synchronized {
objOut.writeObject(statuses)
objOut.writeObject(tuple)
}
objOut.close()
out.toByteArray
}

// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
def deserializeStatuses(bytes: Array[Byte]) = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all nulls from status - not sure why they are occurring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
asInstanceOf[Tuple2[Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]]] // .filter( _ != null )
}
}
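For reference, a minimal standalone sketch (not part of this patch) of the GZIP-plus-Java-serialization round trip that serializeStatuses and deserializeStatuses above perform. The generic helper object and its names are invented for illustration; the real methods additionally synchronize on the statuses array while writing the (Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]) tuple.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object StatusCodecSketch {
  def serialize[T <: Serializable](value: T): Array[Byte] = {
    val out = new ByteArrayOutputStream
    // Map output locations compress well under GZIP: many entries share a hostname.
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(value)
    objOut.close()   // closing flushes the object stream and finishes the GZIP trailer
    out.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    try objIn.readObject().asInstanceOf[T] finally objIn.close()
  }

  def main(args: Array[String]) {
    val bytes = serialize(("locations placeholder", Map(1 -> Seq(10L, 20L))))
    println(deserialize[(String, Map[Int, Seq[Long]])](bytes))
  }
}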

private[spark] object MapOutputTracker {
private[spark] object MapOutputTracker extends Logging {
private val LOG_BASE = 1.1

// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
private def convertMapStatuses(
private def convertShuffleBlockSizes(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
statuses: Array[MapOutputLocation],
sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]):
(Array[MapOutputLocation], Seq[(BlockManagerId, Seq[(Int, Seq[Long])])]) = {
assert (statuses != null)
assert (sizes != null)

statuses.map {
status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
}
}

val segments = sizes.toSeq.map { case (bmId, groups) =>
if (groups == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing output blocks for shuffle " + shuffleId + " on " + bmId))
} else {
(status.location, decompressSize(status.compressedSizes(reduceId)))
val seq =
for (i <- 0 until groups.groupNum if groups(i) != null)
yield (i, groups(i).bucketSizes(reduceId).map(decompressSize(_)).toSeq)
(bmId, seq)
}
}

(statuses, segments)
}

/**
Expand Down Expand Up @@ -319,3 +369,81 @@ private[spark] object MapOutputTracker {
}
}
}

private[spark] class MapOutputLocation(val location: BlockManagerId, val sequence: Int)
extends Serializable {
def this (status: MapStatus) = this (status.location, status.sequence)
def debugString = "MapOutputLocation(location=" + location + ", sequence=" + sequence +")"

override def equals(that: Any) = that match {
case loc: MapOutputLocation =>
location == loc.location && sequence == loc.sequence
case _ =>
false
}

}

private[spark] class GroupBucketSizes(var sequence: Int, var bucketSizes: Array[Array[Byte]])
extends Serializable {
def this(status: MapStatus) = this(status.sequence, status.compressedSizes)
def debugString = {
var str = "GroupBucketSizes(sequence=" + sequence + ", "
bucketSizes.zipWithIndex.foreach { s =>
str += (if (s._2 != 0) ", " else "") + "bucket[" + s._2 + "]=("
s._1.zipWithIndex.foreach{ x =>
str += (if (x._2 != 0) ", " else "") + x._1
}
str += ")"
}
str += ")"
str
}
}

private[spark] class ShuffleBlockGroupSizeArray extends Serializable {
var groupNum = 0
private var groupSizeArray = Array.fill[GroupBucketSizes](32)(null)

def apply(idx: Int) = if (idx >= groupSizeArray.length) null else groupSizeArray(idx)

def update(idx: Int, elem: GroupBucketSizes) {
if (idx >= groupSizeArray.length) {
var newLen = groupSizeArray.length * 2
while (idx >= newLen)
newLen = newLen * 2

val newArray = Array.fill[GroupBucketSizes](newLen)(null)
scala.compat.Platform.arraycopy(groupSizeArray, 0, newArray, 0, groupNum)
groupSizeArray = newArray
}

if (idx >= groupNum)
groupNum = idx + 1

groupSizeArray(idx) = elem
}

def +=(elem: GroupBucketSizes) {
this(groupNum) = elem
}

def debugString = {
var str = "ShuffleBlockGroupSizeArray("
for (i <- 0 until groupNum) {
str += (if (i != 0) ", " else "") + "group_" + i + "=" + (if (groupSizeArray(i) == null) "null" else groupSizeArray(i).debugString)
}
str + ")"
}
}

private[spark] object ShuffleBlockGroupSizeArray {
def apply(xs: GroupBucketSizes*) = {
val sizes = new ShuffleBlockGroupSizeArray()
xs.foreach { x =>
sizes += x
}
sizes
}
}
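A small usage sketch (not part of this patch) of the new size bookkeeping, assuming the GroupBucketSizes and ShuffleBlockGroupSizeArray classes added above are on the classpath; the values are made up. It shows the sparse, self-growing indexing: the backing array starts at capacity 32 and doubles as needed, groupNum tracks the highest written index plus one, and slots that were never written read back as null.

object GroupSizeArraySketch {
  def main(args: Array[String]) {
    val sizes = new ShuffleBlockGroupSizeArray()
    // Append group 0: a single reduce bucket whose compressed segment sizes are 1 and 2.
    sizes += new GroupBucketSizes(0, Array(Array[Byte](1, 2)))
    // Writing index 40 grows the backing array from 32 to 64 slots.
    sizes(40) = new GroupBucketSizes(1, Array(Array[Byte](3)))
    println(sizes.groupNum)       // 41: highest written index + 1
    println(sizes(5) == null)     // true: never-written slots stay null
    println(sizes(40).debugString)
  }
}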
