GH-43169: [Swift] Add StructArray to ArrowReader (#43335)
### Rationale for this change
Struct arrays have been added to the Swift implementation, but the ArrowReader does not yet support reading them. This PR adds that support to the ArrowReader.

### What changes are included in this PR?
Adds StructArray support to the ArrowReader.
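
For context, the core of the change is that `ArrowReader` now walks a record batch with a small cursor (`RecordBatchData`, exposing `nextNode()`, `nextBuffer()`, and `nextField()`) and dispatches per field through `loadField`, which recurses into child fields for structs via `loadStructData`. The sketch below is a minimal, self-contained illustration of that pattern; `FieldNode`, `FieldType`, `Field`, and `BatchCursor` are hypothetical stand-ins invented for this example, not the actual Arrow Swift or FlatBuffers types.

```swift
// Simplified sketch of the node/buffer cursor plus recursive field loading
// introduced in this PR. All types here are stand-ins for illustration only.

struct FieldNode {                 // stands in for org_apache_arrow_flatbuf_FieldNode
    let length: Int
    let nullCount: Int
}

enum FieldType {
    case primitive                 // fixed width: validity + value buffer
    case variable                  // e.g. utf8: validity + offsets + values
    case strct([Field])            // struct: validity buffer, then child fields
}

struct Field {
    let name: String
    let type: FieldType
}

// Plays the role of RecordBatchData: hands out nodes and buffers in order.
final class BatchCursor {
    private var nodeIndex = 0
    private var bufferIndex = 0
    private let nodes: [FieldNode]
    private let buffers: [[UInt8]]

    init(nodes: [FieldNode], buffers: [[UInt8]]) {
        self.nodes = nodes
        self.buffers = buffers
    }

    func nextNode() -> FieldNode? {
        guard nodeIndex < nodes.count else { return nil }
        defer { nodeIndex += 1 }
        return nodes[nodeIndex]
    }

    func nextBuffer() -> [UInt8]? {
        guard bufferIndex < buffers.count else { return nil }
        defer { bufferIndex += 1 }
        return buffers[bufferIndex]
    }
}

// Same shape as loadField/loadStructData in ArrowReader.swift: every field
// consumes one node and its validity buffer, then either its data buffers
// (primitive/variable) or, for a struct, recurses into each child so the
// shared cursor keeps advancing in schema order.
func loadField(_ cursor: BatchCursor, field: Field) -> String? {
    guard let node = cursor.nextNode(),
          let validity = cursor.nextBuffer() else { return nil }
    switch field.type {
    case .primitive:
        guard let values = cursor.nextBuffer() else { return nil }
        return "\(field.name): primitive(length: \(node.length), validity: \(validity.count)B, values: \(values.count)B)"
    case .variable:
        guard let offsets = cursor.nextBuffer(),
              let values = cursor.nextBuffer() else { return nil }
        return "\(field.name): variable(length: \(node.length), validity: \(validity.count)B, offsets: \(offsets.count)B, values: \(values.count)B)"
    case .strct(let children):
        let loaded = children.compactMap { loadField(cursor, field: $0) }
        return "\(field.name): struct(nullCount: \(node.nullCount), validity: \(validity.count)B) { \(loaded.joined(separator: ", ")) }"
    }
}

// Example: a struct<a: primitive, b: primitive> column of length 3.
let cursor = BatchCursor(
    nodes: [FieldNode(length: 3, nullCount: 0),    // struct node
            FieldNode(length: 3, nullCount: 1),    // child "a"
            FieldNode(length: 3, nullCount: 0)],   // child "b"
    buffers: [[0b111],                             // struct validity
              [0b101], [1, 2, 3],                  // child "a": validity, values
              [0b111], [4, 5, 6]])                 // child "b": validity, values
let person = Field(name: "person",
                   type: .strct([Field(name: "a", type: .primitive),
                                 Field(name: "b", type: .primitive)]))
print(loadField(cursor, field: person) ?? "load failed")
```

The same idea drives the real implementation: because nodes and buffers are consumed strictly in order from a shared cursor, nested columns no longer need the per-type buffer-index arithmetic (`bufferIndex += 2` / `+= 3`) that the previous `loadRecordBatch` used.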

### Are these changes tested?
The follow-up ArrowWriter PR will include a test for reading and writing structs.

* GitHub Issue: #43169

Authored-by: Alva Bandy <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
abandy authored Jul 25, 2024
1 parent 0fbea66 commit 85684fe
Showing 4 changed files with 194 additions and 69 deletions.
3 changes: 2 additions & 1 deletion swift/Arrow/Sources/Arrow/ArrowCImporter.swift
@@ -153,7 +153,8 @@ public class ArrowCImporter {
}
}

switch makeArrayHolder(arrowField, buffers: arrowBuffers, nullCount: nullCount) {
switch makeArrayHolder(arrowField, buffers: arrowBuffers,
nullCount: nullCount, children: nil, rbLength: 0) {
case .success(let holder):
return .success(ImportArrayHolder(holder, cArrayPtr: cArrayPtr))
case .failure(let err):
199 changes: 140 additions & 59 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
@@ -21,14 +21,46 @@ import Foundation
let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1

public class ArrowReader {
private struct DataLoadInfo {
public class ArrowReader { // swiftlint:disable:this type_body_length
private class RecordBatchData {
let schema: org_apache_arrow_flatbuf_Schema
let recordBatch: org_apache_arrow_flatbuf_RecordBatch
let field: org_apache_arrow_flatbuf_Field
let nodeIndex: Int32
let bufferIndex: Int32
private var fieldIndex: Int32 = 0
private var nodeIndex: Int32 = 0
private var bufferIndex: Int32 = 0
init(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
schema: org_apache_arrow_flatbuf_Schema) {
self.recordBatch = recordBatch
self.schema = schema
}

func nextNode() -> org_apache_arrow_flatbuf_FieldNode? {
if nodeIndex >= self.recordBatch.nodesCount {return nil}
defer {nodeIndex += 1}
return self.recordBatch.nodes(at: nodeIndex)
}

func nextBuffer() -> org_apache_arrow_flatbuf_Buffer? {
if bufferIndex >= self.recordBatch.buffersCount {return nil}
defer {bufferIndex += 1}
return self.recordBatch.buffers(at: bufferIndex)
}

func nextField() -> org_apache_arrow_flatbuf_Field? {
if fieldIndex >= self.schema.fieldsCount {return nil}
defer {fieldIndex += 1}
return self.schema.fields(at: fieldIndex)
}

func isDone() -> Bool {
return nodeIndex >= self.recordBatch.nodesCount
}
}

private struct DataLoadInfo {
let fileData: Data
let messageOffset: Int64
var batchData: RecordBatchData
}

public class ArrowReaderResult {
@@ -54,49 +86,104 @@ public class ArrowReader {
return .success(builder.finish())
}

private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
do {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: nullLength, messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("\(error)"))
private func loadStructData(_ loadInfo: DataLoadInfo,
field: org_apache_arrow_flatbuf_Field)
-> Result<ArrowArrayHolder, ArrowError> {
guard let node = loadInfo.batchData.nextNode() else {
return .failure(.invalid("Node not found"))
}

guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Null buffer not found"))
}

let nullLength = UInt(ceil(Double(node.length) / 8))
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: nullLength, messageOffset: loadInfo.messageOffset)
var children = [ArrowData]()
for index in 0..<field.childrenCount {
let childField = field.children(at: index)!
switch loadField(loadInfo, field: childField) {
case .success(let holder):
children.append(holder.array.arrowData)
case .failure(let error):
return .failure(error)
}
}

return makeArrayHolder(field, buffers: [arrowNullBuffer],
nullCount: UInt(node.nullCount), children: children,
rbLength: UInt(loadInfo.batchData.recordBatch.length))
}

private func loadVariableData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
do {
let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: nullLength, messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
let offsetBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 2)
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 2)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("\(error)"))
private func loadPrimitiveData(
_ loadInfo: DataLoadInfo,
field: org_apache_arrow_flatbuf_Field)
-> Result<ArrowArrayHolder, ArrowError> {
guard let node = loadInfo.batchData.nextNode() else {
return .failure(.invalid("Node not found"))
}

guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Null buffer not found"))
}

guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Value buffer not found"))
}

let nullLength = UInt(ceil(Double(node.length) / 8))
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: nullLength, messageOffset: loadInfo.messageOffset)
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(field, buffers: [arrowNullBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount), children: nil,
rbLength: UInt(loadInfo.batchData.recordBatch.length))
}

private func loadVariableData(
_ loadInfo: DataLoadInfo,
field: org_apache_arrow_flatbuf_Field)
-> Result<ArrowArrayHolder, ArrowError> {
guard let node = loadInfo.batchData.nextNode() else {
return .failure(.invalid("Node not found"))
}

guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Null buffer not found"))
}

guard let offsetBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Offset buffer not found"))
}

guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
return .failure(.invalid("Value buffer not found"))
}

let nullLength = UInt(ceil(Double(node.length) / 8))
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: nullLength, messageOffset: loadInfo.messageOffset)
let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount), children: nil,
rbLength: UInt(loadInfo.batchData.recordBatch.length))
}

private func loadField(
_ loadInfo: DataLoadInfo,
field: org_apache_arrow_flatbuf_Field)
-> Result<ArrowArrayHolder, ArrowError> {
if isNestedType(field.typeType) {
return loadStructData(loadInfo, field: field)
} else if isFixedPrimitive(field.typeType) {
return loadPrimitiveData(loadInfo, field: field)
} else {
return loadVariableData(loadInfo, field: field)
}
}

@@ -107,23 +194,17 @@ public class ArrowReader {
data: Data,
messageEndOffset: Int64
) -> Result<RecordBatch, ArrowError> {
let nodesCount = recordBatch.nodesCount
var bufferIndex: Int32 = 0
var columns: [ArrowArrayHolder] = []
for nodeIndex in 0 ..< nodesCount {
let field = schema.fields(at: nodeIndex)!
let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
nodeIndex: nodeIndex, bufferIndex: bufferIndex,
fileData: data, messageOffset: messageEndOffset)
var result: Result<ArrowArrayHolder, ArrowError>
if isFixedPrimitive(field.typeType) {
result = loadPrimitiveData(loadInfo)
bufferIndex += 2
} else {
result = loadVariableData(loadInfo)
bufferIndex += 3
let batchData = RecordBatchData(recordBatch, schema: schema)
let loadInfo = DataLoadInfo(fileData: data,
messageOffset: messageEndOffset,
batchData: batchData)
while !batchData.isDone() {
guard let field = batchData.nextField() else {
return .failure(.invalid("Field not found"))
}

let result = loadField(loadInfo, field: field)
switch result {
case .success(let holder):
columns.append(holder)
59 changes: 51 additions & 8 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
@@ -117,19 +117,42 @@ private func makeFixedHolder<T>(
}
}

func makeStructHolder(
_ field: ArrowField,
buffers: [ArrowBuffer],
nullCount: UInt,
children: [ArrowData],
rbLength: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(field.type,
buffers: buffers, children: children,
nullCount: nullCount, length: rbLength)
return .success(ArrowArrayHolderImpl(try StructArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("\(error)"))
}
}

func makeArrayHolder(
_ field: org_apache_arrow_flatbuf_Field,
buffers: [ArrowBuffer],
nullCount: UInt
nullCount: UInt,
children: [ArrowData]?,
rbLength: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let arrowField = fromProto(field: field)
return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount)
return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount, children: children, rbLength: rbLength)
}

func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
_ field: ArrowField,
buffers: [ArrowBuffer],
nullCount: UInt
nullCount: UInt,
children: [ArrowData]?,
rbLength: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let typeId = field.type.id
switch typeId {
@@ -159,12 +182,12 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
return makeStringHolder(buffers, nullCount: nullCount)
case .binary:
return makeBinaryHolder(buffers, nullCount: nullCount)
case .date32:
case .date32, .date64:
return makeDateHolder(field, buffers: buffers, nullCount: nullCount)
case .time32:
return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
case .time64:
case .time32, .time64:
return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
case .strct:
return makeStructHolder(field, buffers: buffers, nullCount: nullCount, children: children!, rbLength: rbLength)
default:
return .failure(.unknownType("Type \(typeId) currently not supported"))
}
@@ -187,7 +210,16 @@ func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
}
}

func findArrowType( // swiftlint:disable:this cyclomatic_complexity
func isNestedType(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
switch type {
case .struct_:
return true
default:
return false
}
}

func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_body_length
_ field: org_apache_arrow_flatbuf_Field) -> ArrowType {
let type = field.typeType
switch type {
@@ -229,6 +261,17 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity
}

return ArrowTypeTime64(timeType.unit == .microsecond ? .microseconds : .nanoseconds)
case .struct_:
_ = field.type(type: org_apache_arrow_flatbuf_Struct_.self)!
var fields = [ArrowField]()
for index in 0..<field.childrenCount {
let childField = field.children(at: index)!
let childType = findArrowType(childField)
fields.append(
ArrowField(childField.name ?? "", type: childType, isNullable: childField.nullable))
}

return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
default:
return ArrowType(ArrowType.ArrowUnknown)
}
2 changes: 1 addition & 1 deletion swift/Arrow/Tests/ArrowTests/ArrayTests.swift
@@ -279,7 +279,7 @@ final class ArrayTests: XCTestCase { // swiftlint:disable:this type_body_length
ArrowBuffer(length: 0, capacity: 0,
rawPointer: UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero))]
let field = ArrowField("", type: checkType, isNullable: true)
switch makeArrayHolder(field, buffers: buffers, nullCount: 0) {
switch makeArrayHolder(field, buffers: buffers, nullCount: 0, children: nil, rbLength: 0) {
case .success(let holder):
XCTAssertEqual(holder.type.id, checkType.id)
case .failure(let err):
