Skip to content

Commit

Permalink
GH-44910: [Swift] Fix IPC stream reader and writer implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
abandy committed Dec 14, 2024
1 parent c3601a9 commit 83b65c4
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 21 deletions.
71 changes: 64 additions & 7 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import FlatBuffers
import Foundation

let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
let CONTINUATIONMARKER = 0xFFFFFFFF

public class ArrowReader { // swiftlint:disable:this type_body_length
private class RecordBatchData {
Expand Down Expand Up @@ -219,6 +219,64 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
/// Decodes a buffer of length-prefixed Arrow IPC messages into a schema and record batches.
///
/// Each message is expected to be preceded by a `UInt32` metadata length, optionally
/// preceded by `CONTINUATIONMARKER`. Reading stops successfully when a zero length is
/// found or when the buffer is exhausted exactly at a message boundary.
///
/// - Parameters:
///   - fileData: The bytes of the stream.
///   - useUnalignedBuffers: Not consulted by this method; the FlatBuffers `ByteBuffer`
///     is always created with `allowReadingUnalignedBuffers: true`.
/// - Returns: `.success` with the schema and batches decoded so far, or `.failure` when
///   the data is truncated, a record batch appears before any schema, a message header
///   cannot be decoded, or the header type is not handled.
public func fromStream( // swiftlint:disable:this function_body_length
    _ fileData: Data,
    useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
    let result = ArrowReaderResult()
    let prefixSize = MemoryLayout<UInt32>.size
    let totalBytes = fileData.count
    var offset = 0
    var schemaMessage: org_apache_arrow_flatbuf_Schema?

    while offset < totalBytes {
        // Never read a length prefix that is not fully contained in the buffer.
        guard offset + prefixSize <= totalBytes else {
            return .failure(.invalid("Truncated stream: incomplete message length at offset \(offset)"))
        }
        var length = getUInt32(fileData, offset: offset)
        if length == CONTINUATIONMARKER {
            // The real metadata length follows the continuation marker.
            offset += prefixSize
            guard offset + prefixSize <= totalBytes else {
                return .failure(.invalid("Truncated stream: missing message length at offset \(offset)"))
            }
            length = getUInt32(fileData, offset: offset)
        }
        if length == 0 {
            // A zero length marks the end of the stream.
            return .success(result)
        }

        offset += prefixSize
        guard Int64(offset) + Int64(length) <= Int64(totalBytes) else {
            return .failure(.invalid("Truncated stream: message metadata at offset \(offset) exceeds available data"))
        }

        // `Data` slices keep their parent's indices, so index relative to `startIndex`.
        let streamData = fileData[(fileData.startIndex + offset)...]
        let dataBuffer = ByteBuffer(
            data: streamData,
            allowReadingUnalignedBuffers: true)
        let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)

        // Metadata plus body is the distance to the next message's length prefix.
        let advance = message.bodyLength + Int64(length)
        guard message.bodyLength >= 0, Int64(offset) + advance <= Int64(totalBytes) else {
            return .failure(.invalid("Truncated stream: message body at offset \(offset) exceeds available data"))
        }

        switch message.headerType {
        case .recordbatch:
            guard let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self) else {
                return .failure(.invalid("Unable to read record batch header at offset \(offset)"))
            }
            guard let flatbufSchema = schemaMessage, let arrowSchema = result.schema else {
                return .failure(.invalid("Record batch encountered before a schema message"))
            }
            // NOTE(review): messageEndOffset is passed as metadata length + body length and
            // does not include the message's position (`offset`) in `fileData`. Confirm this
            // is what loadRecordBatch expects for batches that are not at the start of the data.
            let loaded = loadRecordBatch(
                rbMessage,
                schema: flatbufSchema,
                arrowSchema: arrowSchema,
                data: fileData,
                messageEndOffset: advance)
            switch loaded {
            case .success(let recordBatch):
                result.batches.append(recordBatch)
            case .failure(let error):
                return .failure(error)
            }
        case .schema:
            guard let flatbufSchema = message.header(type: org_apache_arrow_flatbuf_Schema.self) else {
                return .failure(.invalid("Unable to read schema header at offset \(offset)"))
            }
            schemaMessage = flatbufSchema
            switch loadSchema(flatbufSchema) {
            case .success(let schema):
                result.schema = schema
            case .failure(let error):
                return .failure(error)
            }
        default:
            return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
        }

        offset += Int(advance)
    }
    return .success(result)
}

public func fromFileStream( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
let footerLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: fileData.count - 4, as: Int32.self)
Expand All @@ -242,7 +300,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
for index in 0 ..< footer.recordBatchesCount {
let recordBatch = footer.recordBatches(at: index)!
var messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
}

var messageOffset: Int64 = 1
Expand All @@ -251,7 +309,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
as: Int32.self)
as: UInt32.self)
}
}

Expand All @@ -273,10 +331,8 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
data: fileData,
messageEndOffset: messageEndOffset).get()
result.batches.append(recordBatch)
} catch let error as ArrowError {
} catch let error {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}
default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
Expand All @@ -296,7 +352,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
let markerLength = FILEMARKER.utf8.count
let footerLengthEnd = Int(fileData.count - markerLength)
let data = fileData[..<(footerLengthEnd)]
return fromStream(data)
return fromFileStream(data)
} catch {
return .failure(.unknownError("Error loading file: \(error)"))
}
Expand Down Expand Up @@ -347,3 +403,4 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
}

}
// swiftlint:disable:this file_length
7 changes: 7 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool {
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
return startString == FILEMARKER && endString == FILEMARKER
}

/// Reads a 32-bit unsigned integer located `offset` bytes from the beginning of `data`.
///
/// The four bytes are decoded as little-endian (the byte order the Arrow IPC format
/// specifies for its length prefixes), so the result does not depend on the host's
/// byte order. The read does not require `offset` to be 4-byte aligned.
///
/// - Parameters:
///   - data: The buffer to read from. `offset` is relative to the first byte of `data`,
///     including when `data` is a slice.
///   - offset: Zero-based byte offset of the value.
/// - Returns: The decoded value.
/// - Precondition: `offset` is non-negative and at least four bytes are available at `offset`.
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
    // Trap deterministically instead of reading past the end of the buffer.
    precondition(
        offset >= 0 && offset <= data.count - MemoryLayout<UInt32>.size,
        "getUInt32: offset \(offset) is out of bounds for \(data.count) bytes")
    let raw = data.withUnsafeBytes { rawBuffer in
        rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
    }
    return UInt32(littleEndian: raw)
}
29 changes: 26 additions & 3 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
switch writeSchema(&fbb, schema: info.schema) {
case .success(let schemaOffset):
Expand Down Expand Up @@ -266,7 +266,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

/// Serializes the schema and record batches in `info` into an in-memory buffer.
///
/// The schema is built with `writeSchema`, finished, and appended to the writer; the
/// record batches are then appended by `writeRecordBatches`.
///
/// - Parameter info: The schema and batches to write.
/// - Returns: `.success` with the bytes written, or `.failure` if the schema or batches
///   could not be written or the writer is not an `InMemDataWriter`.
public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
    var writer: any DataWriter = InMemDataWriter()
    var fbb = FlatBufferBuilder()

    // NOTE(review): the finished schema buffer is appended as-is, with no length prefix
    // written here. Confirm this matches the framing the stream reader expects.
    switch writeSchema(&fbb, schema: info.schema) {
    case .success(let schemaOffset):
        fbb.finish(offset: schemaOffset)
        writer.append(fbb.data)
    case .failure(let error):
        return .failure(error)
    }

    switch writeRecordBatches(&writer, batches: info.batches) {
    case .success:
        guard let memWriter = writer as? InMemDataWriter else {
            return .failure(.invalid("Unable to cast writer"))
        }
        return .success(memWriter.data)
    case .failure(let error):
        return .failure(error)
    }
}

public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeFileStream(&writer, info: info) {
case .success:
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
Expand All @@ -293,7 +316,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
writer.append(FILEMARKER.data(using: .utf8)!)
case .failure(let error):
Expand Down
23 changes: 12 additions & 11 deletions swift/Arrow/Tests/ArrowTests/IPCTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func makeRecordBatch() throws -> RecordBatch {
}
}

final class IPCFileReaderTests: XCTestCase {
final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body_length
func testFileReader_double() throws {
let fileURL = currentDirectory().appendingPathComponent("../../testdata_double.arrow")
let arrowReader = ArrowReader()
Expand Down Expand Up @@ -167,10 +167,10 @@ final class IPCFileReaderTests: XCTestCase {
let arrowWriter = ArrowWriter()
// write data from file to a stream
let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
// read stream back into recordbatches
try checkBoolRecordBatch(arrowReader.fromStream(writeData))
try checkBoolRecordBatch(arrowReader.fromFileStream(writeData))
case .failure(let error):
throw error
}
Expand All @@ -190,10 +190,10 @@ final class IPCFileReaderTests: XCTestCase {
let recordBatch = try makeRecordBatch()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
let recordBatches = result.batches
XCTAssertEqual(recordBatches.count, 1)
Expand Down Expand Up @@ -242,10 +242,10 @@ final class IPCFileReaderTests: XCTestCase {
let schema = makeSchema()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.schema, schema: schema)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand Down Expand Up @@ -325,10 +325,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeBinaryDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand All @@ -354,10 +354,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeTimeDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand All @@ -384,3 +384,4 @@ final class IPCFileReaderTests: XCTestCase {
}
}
}
// swiftlint:disable:this file_length

0 comments on commit 83b65c4

Please sign in to comment.