Skip to content

Commit

Permalink
Add zipWithOther & repartition
Browse files Browse the repository at this point in the history
Adds support for zipping an arbitrary RDD with an IndexedRDD's entire index.
This is a useful primitive for implementing a distributed broadcast join.
  • Loading branch information
Nick White committed May 25, 2016
1 parent ac4ac78 commit c73d3ee
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package edu.berkeley.cs.amplab.spark.indexedrdd

import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.storage.StorageLevel

import edu.berkeley.cs.amplab.spark.indexedrdd.impl._
Expand All @@ -39,7 +41,11 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
private val partitionsRDD: RDD[IndexedRDDPartition[K, V]])
extends RDD[(K, V)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {

require(partitionsRDD.partitioner.isDefined)
require(
partitionsRDD.partitioner.isDefined,
s"$partitionsRDD needs a partitioner")

type PartitionType = IndexedRDDPartition[K, V]

override val partitioner = partitionsRDD.partitioner

Expand Down Expand Up @@ -67,9 +73,13 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
partitionsRDD.map(_.size).reduce(_ + _)
}

/**
 * Fetches the indexed partition stored in the given partition of the parent RDD.
 * Only the first element produced by the parent's iterator is read.
 */
private[indexedrdd] def indexedPartition(part: Partition, context: TaskContext): PartitionType = {
  val parentIter = firstParent[PartitionType].iterator(part, context)
  parentIter.next()
}

/** Provides the `RDD[(K, V)]` equivalent output. */
override def compute(part: Partition, context: TaskContext): Iterator[(K, V)] = {
firstParent[IndexedRDDPartition[K, V]].iterator(part, context).next.iterator
indexedPartition(part, context).iterator
}

/** Gets the value corresponding to the specified key, if any. */
Expand All @@ -81,7 +91,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
val partitions = ksByPartition.keys.toSeq
// TODO: avoid sending all keys to all partitions by creating and zipping an RDD of keys
val results: Array[Array[(K, V)]] = context.runJob(partitionsRDD,
(context: TaskContext, partIter: Iterator[IndexedRDDPartition[K, V]]) => {
(context: TaskContext, partIter: Iterator[PartitionType]) => {
if (partIter.hasNext && ksByPartition.contains(context.partitionId)) {
val part = partIter.next()
val ksForPartition = ksByPartition.get(context.partitionId).get
Expand Down Expand Up @@ -146,7 +156,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

/** Applies a function to each partition of this IndexedRDD. */
private def mapIndexedRDDPartitions[K2: ClassTag, V2: ClassTag](
f: IndexedRDDPartition[K, V] => IndexedRDDPartition[K2, V2]): IndexedRDD[K2, V2] = {
f: PartitionType => IndexedRDDPartition[K2, V2]): IndexedRDD[K2, V2] = {
val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true)
new IndexedRDD(newPartitionsRDD)
}
Expand Down Expand Up @@ -196,6 +206,54 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
this.zipPartitionsWithOther(other)(new OtherDiffZipper)
}

/**
* Repartition the RDD, maintaining the index.
*
* Whole `IndexedRDDPartition` objects are shuffled rather than individual entries. When more
* than one of them lands in the same output partition they are wrapped in a `LazyPartition`
* instead of being merged here.
*
* NOTE(review): the returned RDD inherits the `HashPartitioner` used for the shuffle, which
* hashes the synthetic Int shuffle keys rather than keys of type K. For `numPartitions` > 1,
* confirm that key lookups relying on `partitioner` still find the right partition.
*
* @param numPartitions the number of partitions of the result; it is passed to
*                      `Random.nextInt` and `HashPartitioner`, so presumably must be
*                      positive -- TODO confirm.
* @param ord not referenced in this method body.
* @return a new IndexedRDD built from the shuffled partitions.
*/
override def repartition
(numPartitions: Int)
(implicit ord: Ordering[(K, V)] = null): IndexedRDD[K ,V] = {

// we don't expect collisions here (as the key universes in each partition
// are disjoint), so this doesn't lose data:
// The left value wins when present; if both sides are None, `rhs.get` throws.
def reduce(key: K, lhs: Option[V], rhs: Option[V]): V = lhs.getOrElse(rhs.get)

// from RDD#coalesce
// Tags each indexed partition with an Int shuffle key. The starting offset comes from a
// Random seeded with the source partition index, so it is the same on every evaluation.
def distributePartition(index: Int, items: Iterator[PartitionType]) = {
val position = new Random(index).nextInt(numPartitions)
items.zipWithIndex.map { case (t, i) =>
(position + i, t)
}
}

// Shuffle the tagged partitions by their Int key, then drop the key. Zero or one partition
// per output slot is passed through untouched; several are combined lazily.
val partitions = new ShuffledRDD[Int, PartitionType, PartitionType](
partitionsRDD.mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)).
mapPartitions({ it =>
val parts = it.map(_._2).toSeq
if (parts.size <= 1)
parts.iterator
else
Iterator[PartitionType](new LazyPartition(parts, reduce))
}, true)

new IndexedRDD(partitions)
}



/**
 * Zips the whole of this RDD with the partitions of `other`. The zipping function `f` takes
 * a value of `other`'s element type together with a [[ReadableIndex]] giving read access to
 * our index. If this RDD has more than one partition it is first reduced to a single
 * partition with `repartition(1)`. This method is a useful primitive for implementing a
 * broadcast join.
 */
def zipWithOther[V2: ClassTag, W: ClassTag]
    (other: RDD[V2], preservesPartitioning: Boolean = false)
    (f: (V2, ReadableIndex[K, V]) => W): RDD[W] = {
  // Only shuffle when there really is more than one partition to collapse.
  val collapsed: IndexedRDD[K, V] =
    if (partitions.size > 1) repartition(1) else this
  new ZippedIndexRDD(sparkContext, collapsed, other, f, preservesPartitioning)
}

/**
* Joins `this` with `other`, running `f` on the values of all keys in both sets. Note that for
* efficiency `other` must be an IndexedRDD, not just a pair RDD. Use [[aggregateUsingIndex]] to
Expand Down Expand Up @@ -277,40 +335,40 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
// compiler bug related to specialization.

private type ZipPartitionsFunction[V2, V3] =
Function2[Iterator[IndexedRDDPartition[K, V]], Iterator[IndexedRDDPartition[K, V2]],
Function2[Iterator[PartitionType], Iterator[IndexedRDDPartition[K, V2]],
Iterator[IndexedRDDPartition[K, V3]]]

private type OtherZipPartitionsFunction[V2, V3] =
Function2[Iterator[IndexedRDDPartition[K, V]], Iterator[(K, V2)],
Function2[Iterator[PartitionType], Iterator[(K, V2)],
Iterator[IndexedRDDPartition[K, V3]]]

private class MultiputZipper[U](z: (K, U) => V, f: (K, V, U) => V)
extends OtherZipPartitionsFunction[U, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, U)])
: Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, U)])
: Iterator[PartitionType] = {
val thisPart = thisIter.next()
Iterator(thisPart.multiput(otherIter, z, f))
}
}

private class DeleteZipper extends OtherZipPartitionsFunction[Unit, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, Unit)])
: Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, Unit)])
: Iterator[PartitionType] = {
val thisPart = thisIter.next()
Iterator(thisPart.delete(otherIter.map(_._1)))
}
}

private class DiffZipper extends ZipPartitionsFunction[V, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, V]]): Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[PartitionType]): Iterator[PartitionType] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
Iterator(thisPart.diff(otherPart))
}
}

private class OtherDiffZipper extends OtherZipPartitionsFunction[V, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V)]): Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V)]): Iterator[PartitionType] = {
val thisPart = thisIter.next()
Iterator(thisPart.diff(otherIter))
}
Expand All @@ -319,7 +377,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
private class FullOuterJoinZipper[V2: ClassTag, W: ClassTag](f: (K, Option[V], Option[V2]) => W)
extends ZipPartitionsFunction[V2, W] with Serializable {
def apply(
thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, V2]])
thisIter: Iterator[PartitionType], otherIter: Iterator[IndexedRDDPartition[K, V2]])
: Iterator[IndexedRDDPartition[K, W]] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
Expand All @@ -330,8 +388,8 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
private class LazyFullOuterJoinZipper(f: (K, Option[V], Option[V]) => V)
extends ZipPartitionsFunction[V, V] with Serializable {
def apply(
thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, V]])
: Iterator[IndexedRDDPartition[K, V]] = {
thisIter: Iterator[PartitionType], otherIter: Iterator[PartitionType])
: Iterator[PartitionType] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
(thisPart, otherPart) match {
Expand All @@ -350,7 +408,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
private class OtherFullOuterJoinZipper[V2: ClassTag, W: ClassTag](f: (K, Option[V], Option[V2]) => W)
extends OtherZipPartitionsFunction[V2, W] with Serializable {
def apply(
thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V2)])
thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V2)])
: Iterator[IndexedRDDPartition[K, W]] = {
val thisPart = thisIter.next()
Iterator(thisPart.fullOuterJoin(otherIter)(f))
Expand All @@ -359,7 +417,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

private class JoinZipper[U: ClassTag](f: (K, V, U) => V)
extends ZipPartitionsFunction[U, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, U]]): Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[IndexedRDDPartition[K, U]]): Iterator[PartitionType] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
Iterator(thisPart.join(otherPart)(f))
Expand All @@ -368,15 +426,15 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

private class OtherJoinZipper[U: ClassTag](f: (K, V, U) => V)
extends OtherZipPartitionsFunction[U, V] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, U)]): Iterator[IndexedRDDPartition[K, V]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, U)]): Iterator[PartitionType] = {
val thisPart = thisIter.next()
Iterator(thisPart.join(otherIter)(f))
}
}

private class LeftJoinZipper[V2: ClassTag, V3: ClassTag](f: (K, V, Option[V2]) => V3)
extends ZipPartitionsFunction[V2, V3] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, V2]]): Iterator[IndexedRDDPartition[K, V3]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[IndexedRDDPartition[K, V2]]): Iterator[IndexedRDDPartition[K, V3]] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
Iterator(thisPart.leftJoin(otherPart)(f))
Expand All @@ -385,7 +443,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

private class OtherLeftJoinZipper[V2: ClassTag, V3: ClassTag](f: (K, V, Option[V2]) => V3)
extends OtherZipPartitionsFunction[V2, V3] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V3]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V3]] = {
val thisPart = thisIter.next()
Iterator(thisPart.leftJoin(otherIter)(f))
}
Expand All @@ -394,7 +452,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](
private class InnerJoinZipper[V2: ClassTag, V3: ClassTag](f: (K, V, V2) => V3)
extends ZipPartitionsFunction[V2, V3] with Serializable {
def apply(
thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[IndexedRDDPartition[K, V2]])
thisIter: Iterator[PartitionType], otherIter: Iterator[IndexedRDDPartition[K, V2]])
: Iterator[IndexedRDDPartition[K, V3]] = {
val thisPart = thisIter.next()
val otherPart = otherIter.next()
Expand All @@ -404,7 +462,7 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

private class OtherInnerJoinZipper[V2: ClassTag, V3: ClassTag](f: (K, V, V2) => V3)
extends OtherZipPartitionsFunction[V2, V3] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V2)])
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V2)])
: Iterator[IndexedRDDPartition[K, V3]] = {
val thisPart = thisIter.next()
Iterator(thisPart.innerJoin(otherIter)(f))
Expand All @@ -413,21 +471,35 @@ class IndexedRDD[K: ClassTag, V: ClassTag](

private class CreateUsingIndexZipper[V2: ClassTag]
extends OtherZipPartitionsFunction[V2, V2] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V2]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V2]] = {
val thisPart = thisIter.next()
Iterator(thisPart.createUsingIndex(otherIter))
}
}

private class AggregateUsingIndexZipper[V2: ClassTag](reduceFunc: (V2, V2) => V2)
extends OtherZipPartitionsFunction[V2, V2] with Serializable {
def apply(thisIter: Iterator[IndexedRDDPartition[K, V]], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V2]] = {
def apply(thisIter: Iterator[PartitionType], otherIter: Iterator[(K, V2)]): Iterator[IndexedRDDPartition[K, V2]] = {
val thisPart = thisIter.next()
Iterator(thisPart.aggregateUsingIndex(otherIter, reduceFunc))
}
}
}

/**
 * Read-only lookup interface over a key-value index.
 *
 * @tparam K the key type.
 * @tparam V the value type.
 */
trait ReadableIndex[K, V] {

  /** Returns the value stored for `k`, if any. */
  def apply(k: K): Option[V]

  /** Gets the values corresponding to the specified keys, if any, as (key, value) pairs. */
  def multiget(ks: Array[K]): Iterator[(K, V)]

  /** Tells whether `k` has a value; by default this is answered through `apply`. */
  def isDefined(k: K): Boolean =
    apply(k).nonEmpty
}

object IndexedRDD {
/**
* Constructs an updatable IndexedRDD from an RDD of pairs, merging duplicate keys arbitrarily.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,16 @@ import scala.reflect.ClassTag
* @tparam K the key associated with each entry in the set.
* @tparam V the value associated with each entry in the set.
*/
private[indexedrdd] abstract class IndexedRDDPartition[K, V] extends Serializable {
private[indexedrdd] abstract class IndexedRDDPartition[K, V]
extends Serializable with ReadableIndex[K, V] {

protected implicit def kTag: ClassTag[K]
protected implicit def vTag: ClassTag[V]

def size: Long

/** Return the value for the given key. */
def apply(k: K): Option[V]

def isDefined(k: K): Boolean =
apply(k).isDefined

def iterator: Iterator[(K, V)]

/**
* Gets the values corresponding to the specified keys, if any.
*/
def multiget(ks: Array[K]): Iterator[(K, V)]

/**
* Updates the keys in `kvs` to their corresponding values generated by running `f` on old and new
* values, if an old value exists, or `z` otherwise. Returns a new IndexedRDDPartition that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import scala.collection.Traversable
* A wrapper around several IndexedRDDPartition that avoids rebuilding
* the index for the combined partitions. Instead, each operation probes
* the nested partitions and merges the results.
*
* @param reducer a reduction function; at least one of the given V
* Options will not be None.
*/

private[indexedrdd] class LazyPartition[K, V]
Expand All @@ -36,14 +39,20 @@ private[indexedrdd] class LazyPartition[K, V]
/**
 * Single partition obtained by combining all nested partitions pairwise with
 * `fullOuterJoin` and `reducer`. Being a lazy val, it is only built on first access;
 * being `@transient`, the field itself is excluded from Java serialization.
 */
@transient private lazy val cached: IndexedRDDPartition[K, V] =
partitions.reduce((a, b) => a.fullOuterJoin(b)(reducer))

def size: Long =
cached.size
/**
 * The size is read from the combined, indexed partition rather than from the nested
 * partitions individually, in case any of them hold duplicate keys that need to be
 * reduced first.
 */
def size: Long = {
  val merged = cached
  merged.size
}

/** Return the value for the given key. */
def apply(k: K): Option[V] =
partitions.
map(_(k)).
reduce((a, b) => Option(reducer(k, a, b)))
reduce((a, b) => (a, b) match {
case (None, None) => None
case _ => Option(reducer(k, a, b))
})

/** Reports whether at least one of the nested partitions defines `k`. */
override def isDefined(k: K): Boolean =
  partitions.exists(_.isDefined(k))
Expand Down Expand Up @@ -133,4 +142,9 @@ private[indexedrdd] class LazyPartition[K, V]
*/
def reindex(): IndexedRDDPartition[K, V] = {
  // Reindex every nested partition on its own, then combine the results pairwise.
  val rebuilt = partitions.map(_.reindex)
  rebuilt.reduce((left, right) => left.fullOuterJoin(right)(reducer))
}

/**
 * Java serialization hook: the combined, indexed `cached` partition is serialized in place
 * of this wrapper.
 */
private def writeReplace(): Object = {
  val merged: Object = cached
  merged
}
}
Loading

0 comments on commit c73d3ee

Please sign in to comment.