From df69ef57d055453c399fa925ad315d19211d7ab2 Mon Sep 17 00:00:00 2001 From: fan <75058860+fansehep@users.noreply.github.com> Date: Tue, 21 Nov 2023 16:42:51 +0800 Subject: [PATCH] fix: coerce_primitive for serde decoded data (#5101) * fix: fix json decode number Signed-off-by: fan * follow reviews Signed-off-by: fan * follow reviews Signed-off-by: fan * use fixed size space Signed-off-by: fan --------- Signed-off-by: fan --- arrow-json/src/reader/mod.rs | 43 ++++++++++++++++++++++++++- arrow-json/src/reader/string_array.rs | 33 +++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 71a73df9fedb..5afe0dec279a 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -717,7 +717,9 @@ mod tests { use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; - use arrow_array::{make_array, Array, BooleanArray, ListArray, StringArray, StructArray}; + use arrow_array::{ + make_array, Array, BooleanArray, Float64Array, ListArray, StringArray, StructArray, + }; use arrow_buffer::{ArrowNativeType, Buffer}; use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_data::ArrayDataBuilder; @@ -2259,4 +2261,43 @@ mod tests { .values(); assert_eq!(values, &[1699148028689, 2, 3, 4]); } + + #[test] + fn test_coercing_primitive_into_string_decoder() { + let buf = &format!( + r#"[{{"a": 1, "b": "A", "c": "T"}}, {{"a": 2, "b": "BB", "c": "F"}}, {{"a": {}, "b": 123, "c": false}}, {{"a": {}, "b": 789, "c": true}}]"#, + (std::i32::MAX as i64 + 10), + std::i64::MAX - 10 + ); + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Utf8, true), + ]); + let json_array: Vec = serde_json::from_str(buf).unwrap(); + let schema_ref = Arc::new(schema); + + // read record batches + let reader = ReaderBuilder::new(schema_ref.clone()).with_coerce_primitive(true); + let mut decoder = reader.build_decoder().unwrap(); + decoder.serialize(json_array.as_slice()).unwrap(); + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!( + batch, + RecordBatch::try_new( + schema_ref, + vec![ + Arc::new(Float64Array::from(vec![ + 1.0, + 2.0, + (std::i32::MAX as i64 + 10) as f64, + (std::i64::MAX - 10) as f64 + ])), + Arc::new(StringArray::from(vec!["A", "BB", "123", "789"])), + Arc::new(StringArray::from(vec!["T", "F", "false", "true"])), + ] + ) + .unwrap() + ); + } } diff --git a/arrow-json/src/reader/string_array.rs b/arrow-json/src/reader/string_array.rs index 63a9bcedb7d1..5ab4d09d5d63 100644 --- a/arrow-json/src/reader/string_array.rs +++ b/arrow-json/src/reader/string_array.rs @@ -61,7 +61,18 @@ impl ArrayDecoder for StringArrayDecoder { TapeElement::Number(idx) if coerce_primitive => { data_capacity += tape.get_string(idx).len(); } - _ => return Err(tape.error(*p, "string")), + TapeElement::I64(_) + | TapeElement::I32(_) + | TapeElement::F64(_) + | TapeElement::F32(_) + if coerce_primitive => + { + // An arbitrary estimate + data_capacity += 10; + } + _ => { + return Err(tape.error(*p, "string")); + } } } @@ -89,6 +100,26 @@ impl ArrayDecoder for StringArrayDecoder { TapeElement::Number(idx) if coerce_primitive => { builder.append_value(tape.get_string(idx)); } + TapeElement::I64(high) if coerce_primitive => match tape.get(p + 1) { + TapeElement::I32(low) => { + let val = (high as i64) << 32 | (low as u32) as i64; + builder.append_value(val.to_string()); + } + _ => unreachable!(), + }, + TapeElement::I32(n) if coerce_primitive => { + builder.append_value(n.to_string()); + } + TapeElement::F32(n) if coerce_primitive => { + builder.append_value(n.to_string()); + } + TapeElement::F64(high) if coerce_primitive => match tape.get(p + 1) { + TapeElement::F32(low) => { + let val = f64::from_bits((high as u64) << 32 | low as u64); + builder.append_value(val.to_string()); + } + _ => unreachable!(), + }, _ => unreachable!(), } }