From 64652f43dd200adc13b797489fab6ab3ea36d467 Mon Sep 17 00:00:00 2001 From: Deegue Date: Tue, 28 Dec 2021 09:58:01 +0000 Subject: [PATCH] arrow zerocopy for read and write in object store --- .../apache/spark/raydp/RayExecutorUtils.java | 5 +- .../org/apache/spark/rdd/RayDatasetRDD.scala | 5 +- .../spark/sql/raydp/ObjectStoreReader.scala | 63 ++++++++++++++++--- .../spark/sql/raydp/ObjectStoreWriter.scala | 47 +++++--------- python/raydp/spark/dataset.py | 11 +--- 5 files changed, 77 insertions(+), 54 deletions(-) diff --git a/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java b/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java index 0af7c8e5..2b8e03d0 100644 --- a/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java +++ b/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java @@ -21,11 +21,12 @@ import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.call.ActorCreator; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.runtime.object.ObjectRefImpl; + import java.util.Map; import java.util.List; -import io.ray.api.placementgroup.PlacementGroup; -import io.ray.runtime.object.ObjectRefImpl; import org.apache.spark.executor.RayDPExecutor; public class RayExecutorUtils { diff --git a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala index 925d88bc..b41db9c9 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala @@ -22,6 +22,7 @@ import java.util.List; import scala.collection.JavaConverters._ import io.ray.runtime.generated.Common.Address +import org.apache.arrow.vector.VectorSchemaRoot import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.api.java.JavaSparkContext @@ -37,7 +38,7 @@ class RayDatasetRDD( jsc: JavaSparkContext, @transient val objectIds: List[Array[Byte]], locations: List[Array[Byte]]) - extends RDD[Array[Byte]](jsc.sc, Nil) { + extends RDD[VectorSchemaRoot](jsc.sc, Nil) { override def getPartitions: Array[Partition] = { objectIds.asScala.zipWithIndex.map { case (k, i) => @@ -45,7 +46,7 @@ class RayDatasetRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + override def compute(split: Partition, context: TaskContext): Iterator[VectorSchemaRoot] = { val ref = split.asInstanceOf[RayDatasetRDDPartition].ref ObjectStoreReader.getBatchesFromStream(ref) } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala index 31fbc366..345c2f77 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala @@ -18,18 +18,26 @@ package org.apache.spark.sql.raydp import java.io.ByteArrayInputStream +import java.nio.ByteBuffer import java.nio.channels.{Channels, ReadableByteChannel} import java.util.List +import scala.collection.JavaConverters._ + import com.intel.raydp.shims.SparkShimLoader +import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD import org.apache.spark.raydp.RayDPUtils import 
org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD} import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} object ObjectStoreReader { def createRayObjectRefDF( @@ -40,17 +48,56 @@ object ObjectStoreReader { spark.createDataFrame(rdd, schema) } + def fromRootIterator( + arrowRootIter: Iterator[VectorSchemaRoot], + schema: StructType, + timeZoneId: String): Iterator[InternalRow] = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + + new Iterator[InternalRow] { + private var rowIter = if (arrowRootIter.hasNext) nextBatch() else Iterator.empty + + override def hasNext: Boolean = rowIter.hasNext || { + if (arrowRootIter.hasNext) { + rowIter = nextBatch() + true + } else { + false + } + } + + override def next(): InternalRow = rowIter.next() + + private def nextBatch(): Iterator[InternalRow] = { + val root = arrowRootIter.next() + val columns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector).asInstanceOf[ColumnVector] + }.toArray + + val batch = new ColumnarBatch(columns) + batch.setNumRows(root.getRowCount) + root.close() + batch.rowIterator().asScala + } + } + } + def RayDatasetToDataFrame( sparkSession: SparkSession, rdd: RayDatasetRDD, - schema: String): DataFrame = { - SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession) + schemaString: String): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val sqlContext = new SQLContext(sparkSession) + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val resultRDD = JavaRDD.fromRDD(rdd).rdd.mapPartitions { it => + fromRootIterator(it, schema, timeZoneId) + } + sqlContext.internalCreateDataFrame(resultRDD.setName("arrow"), schema) } def getBatchesFromStream( - ref: Array[Byte]): Iterator[Array[Byte]] = { - val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]]) - ArrowConverters.getBatchesFromStream( - Channels.newChannel(new ByteArrayInputStream(objectRef.get))) + ref: Array[Byte]): Iterator[VectorSchemaRoot] = { + val objectRef = RayDPUtils.readBinary(ref, classOf[VectorSchemaRoot]) + Iterator[VectorSchemaRoot](objectRef.get) } } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 0afcb204..1359643d 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.raydp - import java.io.ByteArrayOutputStream import java.util.{List, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} @@ -61,17 +60,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable { val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID()) def writeToRay( - data: Array[Byte], + root: VectorSchemaRoot, numRecords: Int, queue: ObjectRefHolder.Queue, ownerName: String): RecordBatch = { - - var objectRef: ObjectRef[Array[Byte]] = null + var objectRef: 
ObjectRef[VectorSchemaRoot] = null
     if (ownerName == "") {
-      objectRef = Ray.put(data)
+      objectRef = Ray.put(root)
     } else {
       var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
-      objectRef = Ray.put(data, dataOwner)
+      objectRef = Ray.put(root, dataOwner)
     }
 
     // add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,7 +109,6 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
       val results = new ArrayBuffer[RecordBatch]()
 
-      val byteOut = new ByteArrayOutputStream()
       val arrowWriter = ArrowWriter.create(root)
       var numRecords: Int = 0
 
@@ -119,13 +116,8 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
         while (batchIter.hasNext) {
           // reset the state
           numRecords = 0
-          byteOut.reset()
           arrowWriter.reset()
 
-          // write out the schema meta data
-          val writer = new ArrowStreamWriter(root, null, byteOut)
-          writer.start()
-
           // get the next record batch
           val nextBatch = batchIter.next()
 
@@ -136,19 +128,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
           // set the write record count
           arrowWriter.finish()
 
-          // write out the record batch to the underlying out
-          writer.writeBatch()
-
-          // get the wrote ByteArray and save to Ray ObjectStore
-          val byteArray = byteOut.toByteArray
-          results += writeToRay(byteArray, numRecords, queue, ownerName)
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
+
+          // write the schema root directly and save it to the Ray ObjectStore
+          results += writeToRay(root, numRecords, queue, ownerName)
        }
        arrowWriter.reset()
-        byteOut.close()
      } {
        // If we close root and allocator in TaskCompletionListener, there could be a race
        // condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -173,7 +157,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   /**
    * For test.
*/ - def getRandomRef(): List[Array[Byte]] = { + def getRandomRef(): List[VectorSchemaRoot] = { df.queryExecution.toRdd.mapPartitions { _ => Iterator(ObjectRefHolder.getRandom(uuid)) @@ -233,7 +217,7 @@ object ObjectStoreWriter { var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) - .get.asInstanceOf[ActorHandle[RayAppMaster]] + .get.asInstanceOf[ActorHandle[RayAppMaster]] val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle) // Check if there is any restarted executors if (!restartedExecutors.isEmpty) { @@ -251,8 +235,8 @@ object ObjectStoreWriter { val refs = new Array[ObjectRef[Array[Byte]]](numPartitions) val handles = executorIds.map {id => Ray.getActor("raydp-executor-" + id) - .get - .asInstanceOf[ActorHandle[RayDPExecutor]] + .get + .asInstanceOf[ActorHandle[RayDPExecutor]] } val handlesMap = (executorIds zip handles).toMap val locations = RayExecutorUtils.getBlockLocations( @@ -261,18 +245,15 @@ object ObjectStoreWriter { // TODO use getPreferredLocs, but we don't have a host ip to actor table now refs(i) = RayExecutorUtils.getRDDPartition( handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl) - queue.add(refs(i)) - } - for (i <- 0 until numPartitions) { + queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot])) results(i) = RayDPUtils.convert(refs(i)).getId.getBytes } results } - } object ObjectRefHolder { - type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]] + type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]] private val dfToQueue = new ConcurrentHashMap[UUID, Queue]() def getQueue(df: UUID): Queue = { @@ -297,7 +278,7 @@ object ObjectRefHolder { queue.size() } - def getRandom(df: UUID): Array[Byte] = { + def getRandom(df: UUID): VectorSchemaRoot = { val queue = checkQueueExists(df) val ref = RayDPUtils.convert(queue.peek()) ref.get() diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index d6a64764..ec1414c4 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -237,7 +237,7 @@ def _convert_blocks_to_dataframe(blocks): return df def _convert_by_rdd(spark: sql.SparkSession, - blocks: Dataset, + blocks: List[ObjectRef], locations: List[bytes], schema: StructType) -> DataFrame: object_ids = [block.binary() for block in blocks] @@ -269,14 +269,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession, schema = StructType() for field in arrow_schema: schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable) - #TODO how to branch on type of block? - sample = ray.get(blocks[0]) - if isinstance(sample, bytes): - return _convert_by_rdd(spark, blocks, locations, schema) - elif isinstance(sample, pa.Table): - return _convert_by_udf(spark, blocks, locations, schema) - else: - raise RuntimeError("ray.to_spark only supports arrow type blocks") + return _convert_by_rdd(spark, blocks, locations, schema) if HAS_MLDATASET: class RecordBatch(_SourceShard):
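
Note on the read path (illustration only, not part of the patch): once a VectorSchemaRoot is fetched back from the Ray object store, ObjectStoreReader wraps its Arrow field vectors in Spark's ArrowColumnVector/ColumnarBatch classes instead of re-parsing an Arrow IPC byte stream, which avoids the extra serialize/deserialize pass on the read side. The self-contained Scala sketch below shows that conversion pattern on a locally built root; the object name RootToRowsSketch and the sample IntVector column are made up for illustration, and it needs only Arrow and Spark SQL on the classpath, not a running Ray cluster.

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}

import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}

// Hypothetical standalone example of the VectorSchemaRoot -> ColumnarBatch -> row
// conversion used by ObjectStoreReader.fromRootIterator in the patch above.
object RootToRowsSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)

    // Build a tiny single-column Arrow batch: id = 0, 1, 2.
    val vector = new IntVector("id", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i))
    vector.setValueCount(3)
    val root = VectorSchemaRoot.of(vector)
    root.setRowCount(3)

    // Wrap the Arrow vectors in Spark column vectors; the underlying buffers are
    // shared, not copied.
    val columns = root.getFieldVectors.asScala
      .map(v => new ArrowColumnVector(v).asInstanceOf[ColumnVector])
      .toArray
    val batch = new ColumnarBatch(columns)
    batch.setNumRows(root.getRowCount)

    // Iterate the rows, as the RDD partition iterator does in the patch.
    batch.rowIterator().asScala.foreach(row => println(row.getInt(0)))

    // Release the Arrow buffers once the rows have been consumed.
    root.close()
    allocator.close()
  }
}

The write path is the mirror image: ObjectStoreWriter now calls Ray.put(root) on the populated VectorSchemaRoot instead of first streaming it through ArrowStreamWriter into a ByteArrayOutputStream.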