Skip to content

Commit

Permalink
GH-44910: [Swift] fix ipc stream reader and writer impl
Browse files Browse the repository at this point in the history
  • Loading branch information
abandy committed Dec 14, 2024
1 parent c3601a9 commit fc1e28e
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 20 deletions.
72 changes: 66 additions & 6 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import FlatBuffers
import Foundation

let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
let CONTINUATIONMARKER = UInt32(0xFFFFFFFF)

public class ArrowReader { // swiftlint:disable:this type_body_length
private class RecordBatchData {
Expand Down Expand Up @@ -216,7 +216,67 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
return .success(RecordBatch(arrowSchema, columns: columns))
}

public func fromStream( // swiftlint:disable:this function_body_length
/// Reads an Arrow IPC *streaming format* payload from an in-memory buffer.
///
/// The stream is a sequence of encapsulated messages, each prefixed by the
/// 0xFFFFFFFF continuation marker and a little-endian 32-bit message length,
/// terminated by a continuation marker with a zero length (end-of-stream).
/// The first message must be a schema message; record batches follow.
///
/// - Parameter fileData: The complete stream bytes.
/// - Parameter useUnalignedBuffers: Currently unused; flatbuffer reads always
///   allow unaligned access. // NOTE(review): wire through or remove — confirm
/// - Returns: `.success` with the decoded schema and record batches, or
///   `.failure` with an `ArrowError` describing the malformed input.
public func fromMemoryStream( // swiftlint:disable:this function_body_length
    _ fileData: Data,
    useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
    let result = ArrowReaderResult()
    var offset: Int = 0
    var length = getUInt32(fileData, offset: offset)
    var schemaMessage: org_apache_arrow_flatbuf_Schema?
    while length != 0 {
        if length == CONTINUATIONMARKER {
            // Skip the continuation marker and read the real message length;
            // a zero length here is the end-of-stream sentinel.
            offset += Int(MemoryLayout<Int32>.size)
            length = getUInt32(fileData, offset: offset)
            if length == 0 {
                return .success(result)
            }
        }

        // Advance past the length prefix; the flatbuffer metadata starts here.
        offset += Int(MemoryLayout<Int32>.size)
        let streamData = fileData[offset...]
        let dataBuffer = ByteBuffer(
            data: streamData,
            allowReadingUnalignedBuffers: true)
        let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
        switch message.headerType {
        case .recordbatch:
            // A record batch is only decodable once a schema message has been seen;
            // a malformed stream must produce an error, not a crash.
            guard let flatSchema = schemaMessage, let arrowSchema = result.schema else {
                return .failure(.invalid("Record batch message received before schema message"))
            }
            guard let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self) else {
                return .failure(.invalid("Unable to read record batch header"))
            }
            do {
                // Message end = flatbuffer metadata (length) + message body.
                offset += Int(message.bodyLength + Int64(length))
                let recordBatch = try loadRecordBatch(
                    rbMessage,
                    schema: flatSchema,
                    arrowSchema: arrowSchema,
                    data: fileData,
                    messageEndOffset: (message.bodyLength + Int64(length))).get()
                result.batches.append(recordBatch)
                length = getUInt32(fileData, offset: offset)
            } catch let error as ArrowError {
                return .failure(error)
            } catch {
                return .failure(.unknownError("Unexpected error: \(error)"))
            }
        case .schema:
            guard let foundSchema = message.header(type: org_apache_arrow_flatbuf_Schema.self) else {
                return .failure(.invalid("Unable to read schema header"))
            }
            schemaMessage = foundSchema
            let schemaResult = loadSchema(foundSchema)
            switch schemaResult {
            case .success(let schema):
                result.schema = schema
            case .failure(let error):
                return .failure(error)
            }
            offset += Int(message.bodyLength + Int64(length))
            length = getUInt32(fileData, offset: offset)
        default:
            return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
        }
    }
    return .success(result)
}

public func fromFileStream( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
Expand All @@ -242,7 +302,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
for index in 0 ..< footer.recordBatchesCount {
let recordBatch = footer.recordBatches(at: index)!
var messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
}

var messageOffset: Int64 = 1
Expand All @@ -251,7 +311,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
as: Int32.self)
as: UInt32.self)
}
}

Expand Down Expand Up @@ -296,7 +356,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
let markerLength = FILEMARKER.utf8.count
let footerLengthEnd = Int(fileData.count - markerLength)
let data = fileData[..<(footerLengthEnd)]
return fromStream(data)
return fromFileStream(data)
} catch {
return .failure(.unknownError("Error loading file: \(error)"))
}
Expand Down Expand Up @@ -340,10 +400,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}

default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}

}
// swiftlint:disable:this file_length
7 changes: 7 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool {
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
return startString == FILEMARKER && endString == FILEMARKER
}

/// Loads a native-endian `UInt32` from `data` at the given byte offset.
/// The read is unaligned-safe, so `offset` need not be 4-byte aligned.
/// Note: `offset` is relative to the start of the underlying buffer, not
/// to `data.startIndex` (relevant when `data` is a slice).
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
    return data.withUnsafeBytes {
        $0.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
    }
}
41 changes: 37 additions & 4 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
case .success(let rbResult):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))}
writer.append(rbResult.0)
switch writeRecordBatchData(&writer, batch: batch) {
Expand Down Expand Up @@ -232,7 +233,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
switch writeSchema(&fbb, schema: info.schema) {
case .success(let schemaOffset):
Expand Down Expand Up @@ -264,9 +265,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
/// Serializes the schema and record batches in `info` to the Arrow IPC
/// *streaming format*, entirely in memory.
///
/// Layout: an encapsulated schema message, then one encapsulated message per
/// record batch (metadata followed by body), and finally the end-of-stream
/// sentinel (continuation marker with a zero length).
///
/// - Parameter info: The schema and record batches to write.
/// - Returns: `.success` with the serialized stream bytes, or `.failure`
///   if any message fails to serialize.
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
    let memWriter = InMemDataWriter()
    let writer: any DataWriter = memWriter

    // Every encapsulated message is prefixed by the 0xFFFFFFFF continuation
    // marker followed by a little-endian 32-bit metadata length.
    func writeMessagePrefix(_ metadataLength: UInt32) {
        withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
        withUnsafeBytes(of: metadataLength.littleEndian) { writer.append(Data($0)) }
    }

    switch toMessage(info.schema) {
    case .success(let schemaData):
        writeMessagePrefix(UInt32(schemaData.count))
        writer.append(schemaData)
    case .failure(let error):
        return .failure(error)
    }

    for batch in info.batches {
        switch toMessage(batch) {
        case .success(let batchData):
            writeMessagePrefix(UInt32(batchData[0].count))
            writer.append(batchData[0]) // flatbuffer metadata
            writer.append(batchData[1]) // message body
        case .failure(let error):
            return .failure(error)
        }
    }

    // End-of-stream: continuation marker with zero length.
    writeMessagePrefix(0)
    return .success(memWriter.data)
}

public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
Expand All @@ -293,7 +326,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
writer.append(FILEMARKER.data(using: .utf8)!)
case .failure(let error):
Expand Down
20 changes: 10 additions & 10 deletions swift/Arrow/Tests/ArrowTests/IPCTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ final class IPCFileReaderTests: XCTestCase {
let arrowWriter = ArrowWriter()
// write data from file to a stream
let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
// read stream back into recordbatches
try checkBoolRecordBatch(arrowReader.fromStream(writeData))
try checkBoolRecordBatch(arrowReader.fromFileStream(writeData))
case .failure(let error):
throw error
}
Expand All @@ -190,10 +190,10 @@ final class IPCFileReaderTests: XCTestCase {
let recordBatch = try makeRecordBatch()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
let recordBatches = result.batches
XCTAssertEqual(recordBatches.count, 1)
Expand Down Expand Up @@ -242,10 +242,10 @@ final class IPCFileReaderTests: XCTestCase {
let schema = makeSchema()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.schema, schema: schema)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand Down Expand Up @@ -325,10 +325,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeBinaryDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand All @@ -354,10 +354,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeTimeDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand Down

0 comments on commit fc1e28e

Please sign in to comment.