Skip to content

Commit

Permalink
apacheGH-42245: [Swift] Ensure map behavior is the same for all key t…
Browse files Browse the repository at this point in the history
…ypes
  • Loading branch information
abandy committed Jun 21, 2024
1 parent cb1f9b7 commit f00d5aa
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 40 deletions.
68 changes: 46 additions & 22 deletions swift/Arrow/Sources/Arrow/ArrowDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Foundation

public class ArrowDecoder: Decoder {
var rbIndex: UInt = 0
var singleRBCol: Int = 0
public var codingPath: [CodingKey] = []
public var userInfo: [CodingUserInfoKey: Any] = [:]
public let rb: RecordBatch
Expand Down Expand Up @@ -47,6 +48,25 @@ public class ArrowDecoder: Decoder {
self.nameToCol = colMapping
}

public func decode<T: Decodable, U: Decodable>(_ type: [T: U].Type) throws -> [T: U] {
var output = [T: U]()
if rb.columnCount != 2 {
throw ArrowError.invalid("RecordBatch column count of 2 is required to decode to map")
}

for index in 0..<rb.length {
self.rbIndex = index
self.singleRBCol = 0
let key = try T.init(from: self)
self.singleRBCol = 1
let value = try U.init(from: self)
output[key] = value
}

self.singleRBCol = 0
return output
}

public func decode<T: Decodable>(_ type: T.Type) throws -> [T] {
var output = [T]()
for index in 0..<rb.length {
Expand Down Expand Up @@ -105,6 +125,11 @@ public class ArrowDecoder: Decoder {
return array.asAny(self.rbIndex) as? T
}

func doDecodeSingleValue<T>() throws -> T {
let array: AnyArray = try self.getCol(self.singleRBCol)
return array.asAny(self.rbIndex) as! T // swiftlint:disable:this force_cast
}

func isNull(_ key: CodingKey) throws -> Bool {
let array: AnyArray = try self.getCol(key.stringValue)
return array.asAny(self.rbIndex) == nil
Expand All @@ -114,6 +139,11 @@ public class ArrowDecoder: Decoder {
let array: AnyArray = try self.getCol(col)
return array.asAny(self.rbIndex) == nil
}

func isNullSingleValue() throws -> Bool {
let array: AnyArray = try self.getCol(self.singleRBCol)
return array.asAny(self.rbIndex) == nil
}
}

private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
Expand Down Expand Up @@ -252,7 +282,7 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
}

func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable {
if type == Date.self {
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
return try self.decoder.doDecode(key)!
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
Expand Down Expand Up @@ -290,26 +320,26 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {

func decodeNil() -> Bool {
do {
return try self.decoder.isNull(0)
return try self.decoder.isNullSingleValue()
} catch {
return false
}
}

func decode(_ type: Bool.Type) throws -> Bool {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: String.Type) throws -> String {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Double.Type) throws -> Double {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Float.Type) throws -> Float {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Int.Type) throws -> Int {
Expand All @@ -318,19 +348,19 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
}

func decode(_ type: Int8.Type) throws -> Int8 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Int16.Type) throws -> Int16 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Int32.Type) throws -> Int32 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: Int64.Type) throws -> Int64 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()!
}

func decode(_ type: UInt.Type) throws -> UInt {
Expand All @@ -339,30 +369,24 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
}

func decode(_ type: UInt8.Type) throws -> UInt8 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: UInt16.Type) throws -> UInt16 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: UInt32.Type) throws -> UInt32 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode(_ type: UInt64.Type) throws -> UInt64 {
return try self.decoder.doDecode(0)!
return try self.decoder.doDecodeSingleValue()
}

func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if type == Int8.self || type == Int16.self ||
type == Int32.self || type == Int64.self ||
type == UInt8.self || type == UInt16.self ||
type == UInt32.self || type == UInt64.self ||
type == String.self || type == Double.self ||
type == Float.self || type == Date.self ||
type == Bool.self {
return try self.decoder.doDecode(0)!
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
return try self.decoder.doDecodeSingleValue()
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
}
Expand Down
138 changes: 120 additions & 18 deletions swift/Arrow/Tests/ArrowTests/CodableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import XCTest
@testable import Arrow

final class CodableTests: XCTestCase {
final class CodableTests: XCTestCase { // swiftlint:disable:this type_body_length
public class TestClass: Codable {
public var propBool: Bool
public var propInt8: Int8
Expand Down Expand Up @@ -166,35 +166,45 @@ final class CodableTests: XCTestCase {
}
}

func testArrowUnkeyedDecoderWithoutNull() throws {
func testArrowMapDecoderWithoutNull() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12, 13)
stringBuilder.append("test0", "test1", "test2", "test3")
let result = RecordBatch.Builder()
stringBuilder.append("test10", "test11", "test12", "test13")
switch RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringBuilder.toHolder())
.finish()
switch result {
.finish() {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([Int8: String].self)
var index: Int8 = 0
for data in testData {
let str = data[10 + index]
XCTAssertEqual(str, "test\(index)")
index += 1
XCTAssertEqual("test\(data.key)", data.value)
}
case .failure(let err):
throw err
}

switch RecordBatch.Builder()
.addColumn("propString", arrowArray: try stringBuilder.toHolder())
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.finish() {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([String: Int8].self)
for data in testData {
XCTAssertEqual("test\(data.value)", data.key)
}
case .failure(let err):
throw err
}
}

func testArrowUnkeyedDecoderWithNull() throws {
func testArrowMapDecoderWithNull() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12, 13)
stringWNilBuilder.append(nil, "test1", nil, "test3")
stringWNilBuilder.append(nil, "test11", nil, "test13")
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringWNilBuilder.toHolder())
Expand All @@ -203,19 +213,111 @@ final class CodableTests: XCTestCase {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([Int8: String?].self)
var index: Int8 = 0
for data in testData {
let str = data[10 + index]
if index % 2 == 0 {
XCTAssertNil(str!)
let str = data.value
if data.key % 2 == 0 {
XCTAssertNil(str)
} else {
XCTAssertEqual(str, "test\(index)")
XCTAssertEqual(str, "test\(data.key)")
}
index += 1
}
case .failure(let err):
throw err
}
}

func getArrayValue<T>(_ rb: RecordBatch, colIndex: Int, rowIndex: UInt) -> T {
let anyArray = rb.columns[colIndex].array as! AnyArray // swiftlint:disable:this force_cast
return anyArray.asAny(UInt(rowIndex)) as! T // swiftlint:disable:this force_cast
}

func testArrowKeyedEncoder() throws { // swiftlint:disable:this function_body_length
var infos = [TestClass]()
for index in 0..<10 {
let tClass = TestClass()
let offset = index * 12
tClass.propBool = index % 2 == 0
tClass.propInt8 = Int8(offset + 1)
tClass.propInt16 = Int16(offset + 2)
tClass.propInt32 = Int32(offset + 3)
tClass.propInt64 = Int64(offset + 4)
tClass.propUInt8 = UInt8(offset + 5)
tClass.propUInt16 = UInt16(offset + 6)
tClass.propUInt32 = UInt32(offset + 7)
tClass.propUInt64 = UInt64(offset + 8)
tClass.propFloat = Float(offset + 9)
tClass.propDouble = index % 2 == 0 ? Double(offset + 10) : nil
tClass.propString = "\(offset + 11)"
tClass.propDate = Date.now
infos.append(tClass)
}

let rb = try ArrowEncoder.encode(infos)!
XCTAssertEqual(Int(rb.length), infos.count)
XCTAssertEqual(rb.columns.count, 13)
XCTAssertEqual(rb.columns[0].type.id, ArrowTypeId.boolean)
XCTAssertEqual(rb.columns[1].type.id, ArrowTypeId.int8)
XCTAssertEqual(rb.columns[2].type.id, ArrowTypeId.int16)
XCTAssertEqual(rb.columns[3].type.id, ArrowTypeId.int32)
XCTAssertEqual(rb.columns[4].type.id, ArrowTypeId.int64)
XCTAssertEqual(rb.columns[5].type.id, ArrowTypeId.uint8)
XCTAssertEqual(rb.columns[6].type.id, ArrowTypeId.uint16)
XCTAssertEqual(rb.columns[7].type.id, ArrowTypeId.uint32)
XCTAssertEqual(rb.columns[8].type.id, ArrowTypeId.uint64)
XCTAssertEqual(rb.columns[9].type.id, ArrowTypeId.float)
XCTAssertEqual(rb.columns[10].type.id, ArrowTypeId.double)
XCTAssertEqual(rb.columns[11].type.id, ArrowTypeId.string)
XCTAssertEqual(rb.columns[12].type.id, ArrowTypeId.date64)
for index in 0..<10 {
let offset = index * 12
XCTAssertEqual(getArrayValue(rb, colIndex: 0, rowIndex: UInt(index)), index % 2 == 0)
XCTAssertEqual(getArrayValue(rb, colIndex: 1, rowIndex: UInt(index)), Int8(offset + 1))
XCTAssertEqual(getArrayValue(rb, colIndex: 2, rowIndex: UInt(index)), Int16(offset + 2))
XCTAssertEqual(getArrayValue(rb, colIndex: 3, rowIndex: UInt(index)), Int32(offset + 3))
XCTAssertEqual(getArrayValue(rb, colIndex: 4, rowIndex: UInt(index)), Int64(offset + 4))
XCTAssertEqual(getArrayValue(rb, colIndex: 5, rowIndex: UInt(index)), UInt8(offset + 5))
XCTAssertEqual(getArrayValue(rb, colIndex: 6, rowIndex: UInt(index)), UInt16(offset + 6))
XCTAssertEqual(getArrayValue(rb, colIndex: 7, rowIndex: UInt(index)), UInt32(offset + 7))
XCTAssertEqual(getArrayValue(rb, colIndex: 8, rowIndex: UInt(index)), UInt64(offset + 8))
XCTAssertEqual(getArrayValue(rb, colIndex: 9, rowIndex: UInt(index)), Float(offset + 9))
if index % 2 == 0 {
XCTAssertEqual(getArrayValue(rb, colIndex: 10, rowIndex: UInt(index)), Double(offset + 10))
} else {
XCTAssertEqual(getArrayValue(rb, colIndex: 10, rowIndex: UInt(index)), Double?(nil))
}

XCTAssertEqual(getArrayValue(rb, colIndex: 11, rowIndex: UInt(index)), String(offset + 11))
}
}

func testArrowUnkeyedEncoder() throws {
var testMap = [Int8: String?]()
for index in 0..<10 {
testMap[Int8(index)] = "test\(index)"
}

let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12, 13)
stringWNilBuilder.append(nil, "test11", nil, "test13")
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringWNilBuilder.toHolder())
.finish()
switch resultWNil {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([Int8: String?].self)
for data in testData {
let str = data.value
if data.key % 2 == 0 {
XCTAssertNil(str)
} else {
XCTAssertEqual(str, "test\(data.key)")
}
}
case .failure(let err):
throw err
}
}
}

0 comments on commit f00d5aa

Please sign in to comment.