Skip to content

Commit

Permalink
refactor Kryo serializer support to use chill/chill-java
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanlecompte committed Jul 25, 2013
1 parent c258718 commit 8e0939f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 143 deletions.
143 changes: 28 additions & 115 deletions core/src/main/scala/spark/KryoSerializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,15 @@ package spark

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.immutable
import scala.collection.mutable

import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
import com.esotericsoftware.kryo.{Kryo, KryoException}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport

import com.twitter.chill.ScalaKryoInstantiator
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._

private[spark]
class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {

private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
val output = new KryoOutput(outStream)

def writeObject[T](t: T): SerializationStream = {
Expand All @@ -48,17 +39,15 @@ class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends Seria
def close() { output.close() }
}

private[spark]
class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {

private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
val input = new KryoInput(inStream)

def readObject[T](): T = {
try {
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
// DeserializationStream uses the EOF exception to indicate stopping condition.
case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException
case _: KryoException => throw new EOFException
}
}

Expand All @@ -69,10 +58,9 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
}

private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {

val kryo = ks.kryo.get()
val output = ks.output.get()
val input = ks.input.get()
val kryo = ks.newKryo()
val output = ks.newKryoOutput()
val input = ks.newKryoInput()

def serialize[T](t: T): ByteBuffer = {
output.clear()
Expand Down Expand Up @@ -108,125 +96,51 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
* serialization.
*/
trait KryoRegistrator {
def registerClasses(kryo: Kryo): Unit
def registerClasses(kryo: Kryo)
}

/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024

val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
def newKryoOutput() = new KryoOutput(bufferSize)

val kryo = new ThreadLocal[Kryo] {
override def initialValue = createKryo()
}

val output = new ThreadLocal[KryoOutput] {
override def initialValue = new KryoOutput(bufferSize)
}

val input = new ThreadLocal[KryoInput] {
override def initialValue = new KryoInput(bufferSize)
}
def newKryoInput() = new KryoInput(bufferSize)

def createKryo(): Kryo = {
val kryo = new KryoReflectionFactorySupport()
def newKryo(): Kryo = {
val instantiator = new ScalaKryoInstantiator
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader

// Register some commonly used classes
val toRegister: Seq[AnyRef] = Seq(
// Arrays
Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
// Specialized Tuple2s
("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
(1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
// Scala collections
List(1), mutable.ArrayBuffer(1),
// Options and Either
Some(1), Left(1), Right(1),
// Higher-dimensional tuples
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
None,
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY,
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
GotBlock("1", ByteBuffer.allocate(1)),
GetBlock("1")
)
for (obj <- toRegister) {
kryo.register(obj.getClass)
}

for (obj <- toRegister) kryo.register(obj.getClass)

// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())

// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
class SingletonSerializer[T](obj: T) extends KSerializer[T] {
override def write(kryo: Kryo, output: KryoOutput, obj: T) {}
override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj
}
kryo.register(None.getClass, new SingletonSerializer[AnyRef](None))
kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil))

// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {

//hack, look at https://groups.google.com/forum/#!msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ
private final val FAKE_REFERENCE = new Object()
override def write(
kryo: Kryo,
output: KryoOutput,
obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
output.writeInt(map.size)
for ((k, v) <- map) {
kryo.writeClassAndObject(output, k)
kryo.writeClassAndObject(output, v)
}
}
override def read (
kryo: Kryo,
input: KryoInput,
cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
: Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
kryo.reference(FAKE_REFERENCE)
val size = input.readInt()
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size) {
val k = kryo.readClassAndObject(input)
val v = kryo.readClassAndObject(input)
elems(i)=(k,v)
}
buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
// Allow the user to register their own classes by setting spark.kryo.registrator
try {
Option(System.getProperty("spark.kryo.registrator")).foreach { regCls =>
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo)
}
} catch {
case _: Exception => println("Failed to register spark.kryo.registrator")
}
kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _))
// TODO: add support for immutable maps too; this is more annoying because there are many
// subclasses of immutable.Map for small maps (with <= 4 entries)
val map1 = Map[Any, Any](1 -> 1)
val map2 = Map[Any, Any](1 -> 1, 2 -> 2)
val map3 = Map[Any, Any](1 -> 1, 2 -> 2, 3 -> 3)
val map4 = Map[Any, Any](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4)
val map5 = Map[Any, Any](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5)
kryo.register(map1.getClass, new ScalaMapSerializer(mutable.HashMap() ++ _ toMap))
kryo.register(map2.getClass, new ScalaMapSerializer(mutable.HashMap() ++ _ toMap))
kryo.register(map3.getClass, new ScalaMapSerializer(mutable.HashMap() ++ _ toMap))
kryo.register(map4.getClass, new ScalaMapSerializer(mutable.HashMap() ++ _ toMap))
kryo.register(map5.getClass, new ScalaMapSerializer(mutable.HashMap() ++ _ toMap))

// Allow the user to register their own classes by setting spark.kryo.registrator
val regCls = System.getProperty("spark.kryo.registrator")
if (regCls != null) {
logInfo("Running user registrator: " + regCls)
val classLoader = Thread.currentThread.getContextClassLoader
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo)
}
kryo.setClassLoader(classLoader)

// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
Expand All @@ -235,7 +149,6 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
}

def newInstance(): SerializerInstance = {
this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
new KryoSerializerInstance(this)
}
}
}
32 changes: 6 additions & 26 deletions core/src/test/scala/spark/KryoSerializerSuite.scala
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package spark

import scala.collection.mutable
import scala.collection.immutable

import org.scalatest.FunSuite
import com.esotericsoftware.kryo._

import SparkContext._

class KryoSerializerSuite extends FunSuite {
test("basic types") {
val ser = (new KryoSerializer).newInstance()
Expand Down Expand Up @@ -53,6 +33,7 @@ class KryoSerializerSuite extends FunSuite {
check(Array(true, false, true))
check(Array('a', 'b', 'c'))
check(Array[Int]())
check(Array(Array("1", "2"), Array("1", "2", "3", "4")))
}

test("pairs") {
Expand Down Expand Up @@ -99,11 +80,10 @@ class KryoSerializerSuite extends FunSuite {
check(mutable.HashMap(1 -> "one", 2 -> "two"))
check(mutable.HashMap("one" -> 1, "two" -> 2))
check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4))))
check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three")))
}

test("custom registrator") {
import spark.test._
import KryoTest._
System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)

val ser = (new KryoSerializer).newInstance()
Expand All @@ -123,14 +103,14 @@ class KryoSerializerSuite extends FunSuite {
val hashMap = new java.util.HashMap[String, String]
hashMap.put("foo", "bar")
check(hashMap)

System.clearProperty("spark.kryo.registrator")
}
}

package test {
object KryoTest {
case class CaseClass(i: Int, s: String) {}

class ClassWithNoArgConstructor {
var x: Int = 0
override def equals(other: Any) = other match {
Expand All @@ -154,4 +134,4 @@ package test {
k.register(classOf[java.util.HashMap[_, _]])
}
}
}
}
5 changes: 3 additions & 2 deletions project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ object SparkBuild extends Build {
"com.ning" % "compress-lzf" % "0.8.4",
"org.ow2.asm" % "asm" % "4.0",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.22",
"com.typesafe.akka" % "akka-actor" % "2.0.5" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-remote" % "2.0.5" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-slf4j" % "2.0.5" excludeAll(excludeNetty),
Expand All @@ -181,7 +180,9 @@ object SparkBuild extends Build {
"io.netty" % "netty-all" % "4.0.0.Beta2",
"org.apache.derby" % "derby" % "10.4.2.0" % "test",
"com.codahale.metrics" % "metrics-core" % "3.0.0",
"com.codahale.metrics" % "metrics-jvm" % "3.0.0"
"com.codahale.metrics" % "metrics-jvm" % "3.0.0",
"com.twitter" % "chill_2.9.3" % "0.3.0",
"com.twitter" % "chill-java" % "0.3.0"
) ++ (
if (HADOOP_MAJOR_VERSION == "2") {
if (HADOOP_YARN) {
Expand Down

0 comments on commit 8e0939f

Please sign in to comment.