
Commit 01c16af

[SPARK-51769][SQL] Add maxRecordsPerOutputBatch to limit the number of records per Arrow output batch
### What changes were proposed in this pull request?

This patch adds a new config `maxRecordsPerOutputBatch` to limit the number of records per output Arrow batch.

### Why are the changes needed?

When a columnar-based operator for Spark takes its input from an Arrow-based evaluation operator, the number of records per output batch is currently unlimited. For such columnar-based operators we sometimes want to cap the size of each input batch, but there is no existing way to limit the batch size in rows.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50301 from viirya/arrow_output_size.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 9bc8cd5 commit 01c16af
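
As a usage sketch (not part of the commit itself), the new limit can be combined with an Arrow-based map. The column name `v`, the identity UDF, and the session setup below are assumptions for illustration; the two config names come from this change:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Python returns Arrow batches of up to 10 records; before each batch reaches
# the downstream operator it is re-sliced into batches of at most 3 records.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerOutputBatch", "3")

def passthrough(batches):
    # Identity UDF: yields each incoming pyarrow.RecordBatch unchanged.
    for batch in batches:
        yield batch

df = spark.range(100).toDF("v")
df.mapInArrow(passthrough, df.schema).collect()
```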

File tree

3 files changed: +149 −12 lines


python/pyspark/sql/tests/arrow/test_arrow_map.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -208,6 +208,14 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
 
 
+class MapInArrowWithOutputArrowBatchSlicingTests(MapInArrowTests):
+    @classmethod
+    def setUpClass(cls):
+        MapInArrowTests.setUpClass()
+        cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
+        cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerOutputBatch", "3")
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.arrow.test_arrow_map import *  # noqa: F401
```

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 0 deletions
```diff
@@ -3551,6 +3551,19 @@ object SQLConf {
       .intConf
       .createWithDefault(10000)
 
+  val ARROW_EXECUTION_MAX_RECORDS_PER_OUTPUT_BATCH =
+    buildConf("spark.sql.execution.arrow.maxRecordsPerOutputBatch")
+      .doc("When using Apache Arrow, limit the maximum number of records that can be output " +
+        "in a single ArrowRecordBatch to the downstream operator. If set to zero or negative " +
+        "there is no limit. Note that the complete ArrowRecordBatch is actually created but " +
+        "the number of records is limited when sending it to the downstream operator. This is " +
+        "used to avoid large batches being sent to the downstream operator including " +
+        "the columnar-based operator implemented by third-party libraries.")
+      .version("4.1.0")
+      .internal()
+      .intConf
+      .createWithDefault(-1)
+
   val ARROW_EXECUTION_MAX_BYTES_PER_BATCH =
     buildConf("spark.sql.execution.arrow.maxBytesPerBatch")
       .internal()
@@ -6553,6 +6566,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
 
+  def arrowMaxRecordsPerOutputBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_OUTPUT_BATCH)
+
   def arrowMaxBytesPerBatch: Long = getConf(ARROW_EXECUTION_MAX_BYTES_PER_BATCH)
 
   def arrowTransformWithStateInPandasMaxRecordsPerBatch: Int =
```
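
A tiny Python mirror of the semantics described in the config doc string above (an illustration, not code from this commit; the helper name is hypothetical): a zero or negative value disables the limit, otherwise a fully built batch of `batch_rows` records is handed downstream in chunks of at most `limit` records.

```python
def output_slice_sizes(batch_rows: int, limit: int) -> list[int]:
    # Zero or negative limit: the whole batch is passed through unchanged.
    if limit <= 0:
        return [batch_rows]
    # Positive limit: emit chunks of at most `limit` rows.
    return [min(limit, batch_rows - start) for start in range(0, batch_rows, limit)]

assert output_slice_sizes(10, -1) == [10]        # default (-1): no limit
assert output_slice_sizes(10, 3) == [3, 3, 3, 1]
```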

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala

Lines changed: 126 additions & 12 deletions
```diff
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, SpecialLengths}
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -43,6 +44,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
 
   protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
 
+  protected def arrowMaxRecordsPerOutputBatch: Int = SQLConf.get.arrowMaxRecordsPerOutputBatch
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writer: Writer,
@@ -62,7 +65,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
       private var reader: ArrowStreamReader = _
       private var root: VectorSchemaRoot = _
       private var schema: StructType = _
-      private var vectors: Array[ColumnVector] = _
+      private var processor: ArrowOutputProcessor = _
 
       context.addTaskCompletionListener[Unit] { _ =>
         if (reader != null) {
@@ -84,17 +87,12 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
         }
         try {
           if (reader != null && batchLoaded) {
-            val bytesReadStart = reader.bytesRead()
-            batchLoaded = reader.loadNextBatch()
+            batchLoaded = processor.loadBatch()
             if (batchLoaded) {
-              val batch = new ColumnarBatch(vectors)
-              val rowCount = root.getRowCount
-              batch.setNumRows(root.getRowCount)
-              val bytesReadEnd = reader.bytesRead()
-              pythonMetrics("pythonNumRowsReceived") += rowCount
-              pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+              val batch = processor.produceBatch()
               deserializeColumnarBatch(batch, schema)
             } else {
+              processor.close()
               reader.close(false)
               allocator.close()
               // Reach end of stream. Call `read()` again to read control data.
@@ -106,9 +104,14 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
                 reader = new ArrowStreamReader(stream, allocator)
                 root = reader.getVectorSchemaRoot()
                 schema = ArrowUtils.fromArrowSchema(root.getSchema())
-                vectors = root.getFieldVectors().asScala.map { vector =>
-                  new ArrowColumnVector(vector)
-                }.toArray[ColumnVector]
+
+                if (arrowMaxRecordsPerOutputBatch > 0) {
+                  processor = new SliceArrowOutputProcessorImpl(
+                    reader, pythonMetrics, arrowMaxRecordsPerOutputBatch)
+                } else {
+                  processor = new ArrowOutputProcessorImpl(reader, pythonMetrics)
+                }
+
                 read()
               case SpecialLengths.TIMING_DATA =>
                 handleTimingData()
@@ -133,3 +136,114 @@ private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarB
       batch: ColumnarBatch,
       schema: StructType): ColumnarBatch = batch
 }
+
+trait ArrowOutputProcessor {
+  def loadBatch(): Boolean
+  protected def getRoot: VectorSchemaRoot
+  protected def getVectors(root: VectorSchemaRoot): Array[ColumnVector]
+  def produceBatch(): ColumnarBatch
+  def close(): Unit
+}
+
+class ArrowOutputProcessorImpl(reader: ArrowStreamReader, pythonMetrics: Map[String, SQLMetric])
+  extends ArrowOutputProcessor {
+  protected val root = reader.getVectorSchemaRoot()
+  protected val schema: StructType = ArrowUtils.fromArrowSchema(root.getSchema())
+  private val vectors: Array[ColumnVector] = root.getFieldVectors().asScala.map { vector =>
+    new ArrowColumnVector(vector)
+  }.toArray[ColumnVector]
+
+  protected var rowCount = -1
+
+  override def loadBatch(): Boolean = {
+    val bytesReadStart = reader.bytesRead()
+    val batchLoaded = reader.loadNextBatch()
+    if (batchLoaded) {
+      rowCount = root.getRowCount
+      val bytesReadEnd = reader.bytesRead()
+      pythonMetrics("pythonNumRowsReceived") += rowCount
+      pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+    }
+    batchLoaded
+  }
+
+  protected override def getRoot: VectorSchemaRoot = root
+  protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = vectors
+  override def produceBatch(): ColumnarBatch = {
+    val batchRoot = getRoot
+    val vectors = getVectors(batchRoot)
+    val batch = new ColumnarBatch(vectors)
+    batch.setNumRows(batchRoot.getRowCount)
+    batch
+  }
+  override def close(): Unit = {
+    vectors.foreach(_.close())
+    root.close()
+  }
+}
+
+class SliceArrowOutputProcessorImpl(
+    reader: ArrowStreamReader,
+    pythonMetrics: Map[String, SQLMetric],
+    arrowMaxRecordsPerOutputBatch: Int)
+  extends ArrowOutputProcessorImpl(reader, pythonMetrics) {
+
+  private var currentRowIdx = -1
+  private var prevRoot: VectorSchemaRoot = null
+  private var prevVectors: Array[ColumnVector] = _
+
+  override def produceBatch(): ColumnarBatch = {
+    val batchRoot = getRoot
+
+    if (batchRoot != prevRoot) {
+      if (prevRoot != null) {
+        prevVectors.foreach(_.close())
+        prevRoot.close()
+      }
+      prevRoot = batchRoot
+    }
+
+    val vectors = getVectors(batchRoot)
+    prevVectors = vectors
+
+    val batch = new ColumnarBatch(vectors)
+    batch.setNumRows(batchRoot.getRowCount)
+    batch
+  }
+
+  override def loadBatch(): Boolean = {
+    if (rowCount > 0 && currentRowIdx < rowCount) {
+      true
+    } else {
+      val loaded = super.loadBatch()
+      currentRowIdx = 0
+      loaded
+    }
+  }
+
+  protected override def getRoot: VectorSchemaRoot = {
+    val remainingRows = rowCount - currentRowIdx
+    val rootSlice = if (remainingRows > arrowMaxRecordsPerOutputBatch) {
+      root.slice(currentRowIdx, arrowMaxRecordsPerOutputBatch)
+    } else {
+      root
+    }
+
+    currentRowIdx = currentRowIdx + rootSlice.getRowCount
+
+    rootSlice
+  }
+
+  protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = {
+    root.getFieldVectors.asScala.map { vector =>
+      new ArrowColumnVector(vector)
+    }.toArray[ColumnVector]
+  }
+
+  override def close(): Unit = {
+    if (prevRoot != null) {
+      prevVectors.foreach(_.close())
+      prevRoot.close()
+    }
+  }
+}
```
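
For intuition only, here is a hedged PyArrow sketch of the slicing idea that `SliceArrowOutputProcessorImpl` implements on the JVM side with `VectorSchemaRoot.slice`: the complete batch is still materialized, and zero-copy slices of at most `limit` rows are handed out one at a time. The function name and data below are assumptions for illustration, not part of this change.

```python
import pyarrow as pa

def sliced_batches(batch: pa.RecordBatch, limit: int):
    # No positive limit: hand the whole batch downstream unchanged.
    if limit <= 0:
        yield batch
        return
    offset = 0
    while offset < batch.num_rows:
        # slice() is zero-copy: the slices share the parent batch's buffers.
        yield batch.slice(offset, min(limit, batch.num_rows - offset))
        offset += limit

full = pa.record_batch([pa.array(range(8))], names=["v"])
assert [b.num_rows for b in sliced_batches(full, 3)] == [3, 3, 2]
```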
