@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, SpecialLengths}
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -43,6 +44,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[

   protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT

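+  // Maximum number of rows per ColumnarBatch returned to Spark; a value <= 0 keeps the
+  // existing behavior of returning each incoming Arrow batch as-is.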
+  protected def arrowMaxRecordsPerOutputBatch: Int = SQLConf.get.arrowMaxRecordsPerOutputBatch
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writer: Writer,
@@ -62,7 +65,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
       private var reader: ArrowStreamReader = _
       private var root: VectorSchemaRoot = _
       private var schema: StructType = _
-      private var vectors: Array[ColumnVector] = _
+      private var processor: ArrowOutputProcessor = _

       context.addTaskCompletionListener[Unit] { _ =>
         if (reader != null) {
@@ -84,17 +87,12 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
         }
         try {
           if (reader != null && batchLoaded) {
-            val bytesReadStart = reader.bytesRead()
-            batchLoaded = reader.loadNextBatch()
+            batchLoaded = processor.loadBatch()
             if (batchLoaded) {
-              val batch = new ColumnarBatch(vectors)
-              val rowCount = root.getRowCount
-              batch.setNumRows(root.getRowCount)
-              val bytesReadEnd = reader.bytesRead()
-              pythonMetrics("pythonNumRowsReceived") += rowCount
-              pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+              val batch = processor.produceBatch()
               deserializeColumnarBatch(batch, schema)
             } else {
+              processor.close()
               reader.close(false)
               allocator.close()
               // Reach end of stream. Call `read()` again to read control data.
@@ -106,9 +104,14 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
                 reader = new ArrowStreamReader(stream, allocator)
                 root = reader.getVectorSchemaRoot()
                 schema = ArrowUtils.fromArrowSchema(root.getSchema())
-                vectors = root.getFieldVectors().asScala.map { vector =>
-                  new ArrowColumnVector(vector)
-                }.toArray[ColumnVector]
+
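+                // Only go through the slicing processor when an output-batch cap is set.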
+                if (arrowMaxRecordsPerOutputBatch > 0) {
+                  processor = new SliceArrowOutputProcessorImpl(
+                    reader, pythonMetrics, arrowMaxRecordsPerOutputBatch)
+                } else {
+                  processor = new ArrowOutputProcessorImpl(reader, pythonMetrics)
+                }
+
                 read()
               case SpecialLengths.TIMING_DATA =>
                 handleTimingData()
@@ -133,3 +136,114 @@ private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarB
       batch: ColumnarBatch,
       schema: StructType): ColumnarBatch = batch
 }
+
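+/**
+ * Abstraction over how Arrow batches read from the Python worker are exposed as
+ * [[ColumnarBatch]]es, so the reader loop above does not need to know whether incoming
+ * batches are passed through unchanged or re-sliced into smaller output batches.
+ */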
+trait ArrowOutputProcessor {
+  def loadBatch(): Boolean
+  protected def getRoot: VectorSchemaRoot
+  protected def getVectors(root: VectorSchemaRoot): Array[ColumnVector]
+  def produceBatch(): ColumnarBatch
+  def close(): Unit
+}
+
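+/**
+ * Pass-through implementation: each batch loaded from the stream becomes exactly one
+ * [[ColumnarBatch]], and the existing bytes/rows metrics are updated as batches load.
+ */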
+class ArrowOutputProcessorImpl(reader: ArrowStreamReader, pythonMetrics: Map[String, SQLMetric])
+  extends ArrowOutputProcessor {
+  protected val root = reader.getVectorSchemaRoot()
+  protected val schema: StructType = ArrowUtils.fromArrowSchema(root.getSchema())
+  private val vectors: Array[ColumnVector] = root.getFieldVectors().asScala.map { vector =>
+    new ArrowColumnVector(vector)
+  }.toArray[ColumnVector]
+
+  protected var rowCount = -1
+
+  override def loadBatch(): Boolean = {
+    val bytesReadStart = reader.bytesRead()
+    val batchLoaded = reader.loadNextBatch()
+    if (batchLoaded) {
+      rowCount = root.getRowCount
+      val bytesReadEnd = reader.bytesRead()
+      pythonMetrics("pythonNumRowsReceived") += rowCount
+      pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+    }
+    batchLoaded
+  }
+
+  protected override def getRoot: VectorSchemaRoot = root
+  protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = vectors
+  override def produceBatch(): ColumnarBatch = {
+    val batchRoot = getRoot
+    val vectors = getVectors(batchRoot)
+    val batch = new ColumnarBatch(vectors)
+    batch.setNumRows(batchRoot.getRowCount)
+    batch
+  }
+  override def close(): Unit = {
+    vectors.foreach(_.close())
+    root.close()
+  }
+}
+
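+/**
+ * Slicing implementation: when a batch read from the Python worker holds more than
+ * `arrowMaxRecordsPerOutputBatch` rows, it is surfaced to Spark as a series of
+ * [[VectorSchemaRoot]] slices of at most that many rows, closing the previous slice
+ * once the next one is requested.
+ */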
+class SliceArrowOutputProcessorImpl(
+    reader: ArrowStreamReader,
+    pythonMetrics: Map[String, SQLMetric],
+    arrowMaxRecordsPerOutputBatch: Int)
+  extends ArrowOutputProcessorImpl(reader, pythonMetrics) {
+
+  private var currentRowIdx = -1
+  private var prevRoot: VectorSchemaRoot = null
+  private var prevVectors: Array[ColumnVector] = _
+
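+  // Wrap the current slice in a ColumnarBatch; the previous slice's vectors are closed
+  // first, since callers only ever hold one output batch at a time.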
+  override def produceBatch(): ColumnarBatch = {
+    val batchRoot = getRoot
+
+    if (batchRoot != prevRoot) {
+      if (prevRoot != null) {
+        prevVectors.foreach(_.close())
+        prevRoot.close()
+      }
+      prevRoot = batchRoot
+    }
+
+    val vectors = getVectors(batchRoot)
+    prevVectors = vectors
+
+    val batch = new ColumnarBatch(vectors)
+    batch.setNumRows(batchRoot.getRowCount)
+    batch
+  }
+
+  override def loadBatch(): Boolean = {
+    if (rowCount > 0 && currentRowIdx < rowCount) {
+      true
+    } else {
+      val loaded = super.loadBatch()
+      currentRowIdx = 0
+      loaded
+    }
+  }
+
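+  // Return the next slice of the current root and advance currentRowIdx by however
+  // many rows were handed out; the last slice covers the remaining rows.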
+  protected override def getRoot: VectorSchemaRoot = {
+    val remainingRows = rowCount - currentRowIdx
+    val rootSlice = if (remainingRows > arrowMaxRecordsPerOutputBatch) {
+      root.slice(currentRowIdx, arrowMaxRecordsPerOutputBatch)
+    } else {
+      root.slice(currentRowIdx, remainingRows)
+    }
+
+    currentRowIdx = currentRowIdx + rootSlice.getRowCount
+
+    rootSlice
+  }
+
+  protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = {
+    root.getFieldVectors.asScala.map { vector =>
+      new ArrowColumnVector(vector)
+    }.toArray[ColumnVector]
+  }
+
+  override def close(): Unit = {
+    if (prevRoot != null) {
+      prevVectors.foreach(_.close())
+      prevRoot.close()
+    }
+  }
+}
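
For anyone who wants to poke at the slicing behavior in isolation, here is a minimal, self-contained sketch (not part of this change). It writes one 10-row Arrow batch into an in-memory IPC stream, reads it back through `SliceArrowOutputProcessorImpl` with a cap of 4, and expects output batches of 4, 4 and 2 rows. The object name `SliceProcessorSketch` is made up, and it assumes the new classes stay publicly reachable under `org.apache.spark.sql.execution.python` and that `SQLMetric` can be constructed directly with a metric-type string; adapt to the project's test helpers as needed.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels
import java.util.Arrays

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.SliceArrowOutputProcessorImpl

// Hypothetical driver, not part of the patch.
object SliceProcessorSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()
    val schema = new Schema(Arrays.asList(
      new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null)))

    // Write a single 10-row batch into an in-memory Arrow IPC stream,
    // standing in for what the Python worker would send back.
    val out = new ByteArrayOutputStream()
    val writeRoot = VectorSchemaRoot.create(schema, allocator)
    val writer = new ArrowStreamWriter(writeRoot, null, Channels.newChannel(out))
    writer.start()
    val ids = writeRoot.getVector("id").asInstanceOf[BigIntVector]
    ids.allocateNew(10)
    (0 until 10).foreach(i => ids.setSafe(i, i.toLong))
    ids.setValueCount(10)
    writeRoot.setRowCount(10)
    writer.writeBatch()
    writer.end()
    writeRoot.close()

    // Read it back through the slicing processor, capped at 4 rows per output batch.
    val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
    val metrics = Map(
      "pythonNumRowsReceived" -> new SQLMetric("sum"),
      "pythonDataReceived" -> new SQLMetric("size"))
    val processor = new SliceArrowOutputProcessorImpl(reader, metrics, 4)
    while (processor.loadBatch()) {
      val batch = processor.produceBatch()
      println(s"output batch with ${batch.numRows()} rows") // expected: 4, 4, 2
    }
    processor.close()
    reader.close()
    allocator.close()
  }
}
```

The same loop works as a unit-test body with assertions on `batch.numRows()` in place of the println, which is probably where a check like this belongs if the classes end up `private[python]`.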