Skip to content

Add custom decoder in arrow-json #7442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-json/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
pub mod reader;
pub mod writer;

pub use self::reader::{Reader, ReaderBuilder};
pub use self::reader::{ArrayDecoder, DecoderFactory, Reader, ReaderBuilder, Tape, TapeElement};
pub use self::writer::{
ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer,
WriterBuilder,
Expand Down
5 changes: 5 additions & 0 deletions arrow-json/src/reader/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ use arrow_buffer::buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};
use std::marker::PhantomData;
use std::sync::Arc;

use super::DecoderFactory;

pub struct ListArrayDecoder<O> {
data_type: DataType,
Expand All @@ -39,6 +42,7 @@ impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
strict_mode: bool,
is_nullable: bool,
struct_mode: StructMode,
decoder_factory: Option<Arc<dyn DecoderFactory>>,
) -> Result<Self, ArrowError> {
let field = match &data_type {
DataType::List(f) if !O::IS_LARGE => f,
Expand All @@ -51,6 +55,7 @@ impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
strict_mode,
field.is_nullable(),
struct_mode,
decoder_factory,
)?;

Ok(Self {
Expand Down
7 changes: 7 additions & 0 deletions arrow-json/src/reader/map_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::reader::tape::{Tape, TapeElement};
use crate::reader::{make_decoder, ArrayDecoder};
use crate::StructMode;
Expand All @@ -24,6 +26,8 @@ use arrow_buffer::ArrowNativeType;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};

use super::DecoderFactory;

pub struct MapArrayDecoder {
data_type: DataType,
keys: Box<dyn ArrayDecoder>,
Expand All @@ -38,6 +42,7 @@ impl MapArrayDecoder {
strict_mode: bool,
is_nullable: bool,
struct_mode: StructMode,
decoder_factory: Option<Arc<dyn DecoderFactory>>,
) -> Result<Self, ArrowError> {
let fields = match &data_type {
DataType::Map(_, true) => {
Expand All @@ -62,13 +67,15 @@ impl MapArrayDecoder {
strict_mode,
fields[0].is_nullable(),
struct_mode,
decoder_factory.clone(),
)?;
let values = make_decoder(
fields[1].data_type().clone(),
coerce_primitive,
strict_mode,
fields[1].is_nullable(),
struct_mode,
decoder_factory,
)?;

Ok(Self {
Expand Down
191 changes: 185 additions & 6 deletions arrow-json/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader,
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit};
pub use schema::*;
pub use tape::*;

use crate::reader::boolean_array::BooleanArrayDecoder;
use crate::reader::decimal_array::DecimalArrayDecoder;
Expand All @@ -156,7 +157,6 @@ use crate::reader::primitive_array::PrimitiveArrayDecoder;
use crate::reader::string_array::StringArrayDecoder;
use crate::reader::string_view_array::StringViewArrayDecoder;
use crate::reader::struct_array::StructArrayDecoder;
use crate::reader::tape::{Tape, TapeDecoder};
use crate::reader::timestamp_array::TimestampArrayDecoder;

mod boolean_array;
Expand All @@ -180,6 +180,7 @@ pub struct ReaderBuilder {
strict_mode: bool,
is_field: bool,
struct_mode: StructMode,
decoder_factory: Option<Arc<dyn DecoderFactory>>,

schema: SchemaRef,
}
Expand All @@ -201,6 +202,7 @@ impl ReaderBuilder {
is_field: false,
struct_mode: Default::default(),
schema,
decoder_factory: None,
}
}

Expand Down Expand Up @@ -242,6 +244,7 @@ impl ReaderBuilder {
is_field: true,
struct_mode: Default::default(),
schema: Arc::new(Schema::new([field.into()])),
decoder_factory: None,
}
}

Expand Down Expand Up @@ -281,6 +284,14 @@ impl ReaderBuilder {
}
}

/// Set an optional hook for customizing decoding behavior.
pub fn with_decoder_factory(self, decoder_factory: Arc<dyn DecoderFactory>) -> Self {
Self {
decoder_factory: Some(decoder_factory),
..self
}
}

/// Create a [`Reader`] with the provided [`BufRead`]
pub fn build<R: BufRead>(self, reader: R) -> Result<Reader<R>, ArrowError> {
Ok(Reader {
Expand All @@ -305,6 +316,7 @@ impl ReaderBuilder {
self.strict_mode,
nullable,
self.struct_mode,
self.decoder_factory,
)?;

let num_fields = self.schema.flattened_fields().len();
Expand Down Expand Up @@ -369,6 +381,95 @@ impl<R: BufRead> RecordBatchReader for Reader<R> {
}
}

/// A trait to create custom decoders for specific data types.
///
/// This allows overriding the default decoders for specific data types,
/// or adding new decoders for custom data types.
///
/// # Examples
///
/// ```
/// use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode};
/// use arrow_schema::ArrowError;
/// use arrow_schema::{DataType, Field, Fields, Schema};
/// use arrow_array::cast::AsArray;
/// use arrow_array::Array;
/// use arrow_array::builder::StringBuilder;
/// use arrow_data::ArrayData;
/// use std::sync::Arc;
///
/// struct IncorrectStringAsNullDecoder {}
///
/// impl ArrayDecoder for IncorrectStringAsNullDecoder {
/// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
/// let mut builder = StringBuilder::new();
/// for p in pos {
/// match tape.get(*p) {
/// TapeElement::String(idx) => {
/// builder.append_value(tape.get_string(idx));
/// }
/// _ => builder.append_null(),
/// }
/// }
/// Ok(builder.finish().into_data())
/// }
/// }
///
/// #[derive(Debug)]
/// struct IncorrectStringAsNullDecoderFactory;
///
/// impl DecoderFactory for IncorrectStringAsNullDecoderFactory {
/// fn make_default_decoder<'a>(
/// &self,
/// data_type: DataType,
/// _coerce_primitive: bool,
/// _strict_mode: bool,
/// _is_nullable: bool,
/// _struct_mode: StructMode,
/// ) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError> {
/// match data_type {
/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder {}))),
/// _ => Ok(None),
/// }
/// }
/// }
///
/// let json = r#"
/// {"a": "a"}
/// {"a": 12}
/// "#;
/// let batch = ReaderBuilder::new(Arc::new(Schema::new(Fields::from(vec![Field::new(
/// "a",
/// DataType::Utf8,
/// true,
/// )]))))
/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory))
/// .build(json.as_bytes())
/// .unwrap()
/// .next()
/// .unwrap()
/// .unwrap();
///
/// let values = batch.column(0).as_string::<i32>();
/// assert_eq!(values.len(), 2);
/// assert_eq!(values.value(0), "a");
/// assert!(values.is_null(1));
/// ```
pub trait DecoderFactory: std::fmt::Debug + Send + Sync {
/// Make a decoder that overrides the default decoder for a specific data type.
/// This can be used to override how e.g. error in decoding are handled.
fn make_default_decoder(
&self,
_data_type: DataType,
_coerce_primitive: bool,
_strict_mode: bool,
_is_nullable: bool,
_struct_mode: StructMode,
) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError> {
Ok(None)
}
}

/// A low-level interface for reading JSON data from a byte stream
///
/// See [`Reader`] for a higher-level interface for interface with [`BufRead`]
Expand Down Expand Up @@ -668,7 +769,8 @@ impl Decoder {
}
}

trait ArrayDecoder: Send {
/// A trait to decode JSON values into arrow arrays
pub trait ArrayDecoder: Send {
/// Decode elements from `tape` starting at the indexes contained in `pos`
fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError>;
}
Expand All @@ -685,7 +787,20 @@ fn make_decoder(
strict_mode: bool,
is_nullable: bool,
struct_mode: StructMode,
decoder_factory: Option<Arc<dyn DecoderFactory>>,
) -> Result<Box<dyn ArrayDecoder>, ArrowError> {
if let Some(ref factory) = decoder_factory {
if let Some(decoder) = factory.make_default_decoder(
data_type.clone(),
coerce_primitive,
strict_mode,
is_nullable,
struct_mode,
)? {
return Ok(decoder);
}
}

downcast_integer! {
data_type => (primitive_decoder, data_type),
DataType::Null => Ok(Box::<NullArrayDecoder>::default()),
Expand Down Expand Up @@ -736,13 +851,13 @@ fn make_decoder(
DataType::Utf8 => Ok(Box::new(StringArrayDecoder::<i32>::new(coerce_primitive))),
DataType::Utf8View => Ok(Box::new(StringViewArrayDecoder::new(coerce_primitive))),
DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::<i64>::new(coerce_primitive))),
DataType::List(_) => Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)),
DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)),
DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)),
DataType::List(_) => Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)),
DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)),
DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)),
DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => {
Err(ArrowError::JsonError(format!("{data_type} is not supported by JSON")))
}
DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)),
DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)),
d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader")))
}
}
Expand Down Expand Up @@ -2808,4 +2923,68 @@ mod tests {
"Json error: whilst decoding field 'a': failed to parse \"a\" as Int32".to_owned()
);
}

#[test]
fn test_decoder_factory() {
use arrow_array::builder;

struct AlwaysNullStringArrayDecoder;

impl ArrayDecoder for AlwaysNullStringArrayDecoder {
fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
let mut builder = builder::StringBuilder::new();
for _ in pos {
builder.append_null();
}
Ok(builder.finish().into_data())
}
}

#[derive(Debug)]
struct AlwaysNullStringArrayDecoderFactory;

impl DecoderFactory for AlwaysNullStringArrayDecoderFactory {
fn make_default_decoder<'a>(
&self,
data_type: DataType,
_coerce_primitive: bool,
_strict_mode: bool,
_is_nullable: bool,
_struct_mode: StructMode,
) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError> {
match data_type {
DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder {}))),
_ => Ok(None),
}
}
}

let buf = r#"
{"a": "1", "b": 2}
{"a": "hello", "b": 23}
"#;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, true),
Field::new("b", DataType::Int32, true),
]));

let batches = ReaderBuilder::new(schema.clone())
.with_batch_size(2)
.with_decoder_factory(Arc::new(AlwaysNullStringArrayDecoderFactory))
.build(Cursor::new(buf.as_bytes()))
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();

assert_eq!(batches.len(), 1);

let col1 = batches[0].column(0).as_string::<i32>();
assert_eq!(col1.null_count(), 2);
assert!(col1.is_null(0));
assert!(col1.is_null(1));

let col2 = batches[0].column(1).as_primitive::<Int32Type>();
assert_eq!(col2.value(0), 2);
assert_eq!(col2.value(1), 23);
}
}
6 changes: 6 additions & 0 deletions arrow-json/src/reader/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::reader::tape::{Tape, TapeElement};
use crate::reader::{make_decoder, ArrayDecoder, StructMode};
use arrow_array::builder::BooleanBufferBuilder;
use arrow_buffer::buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Fields};

use super::DecoderFactory;

pub struct StructArrayDecoder {
data_type: DataType,
decoders: Vec<Box<dyn ArrayDecoder>>,
Expand All @@ -37,6 +41,7 @@ impl StructArrayDecoder {
strict_mode: bool,
is_nullable: bool,
struct_mode: StructMode,
decoder_factory: Option<Arc<dyn DecoderFactory>>,
) -> Result<Self, ArrowError> {
let decoders = struct_fields(&data_type)
.iter()
Expand All @@ -51,6 +56,7 @@ impl StructArrayDecoder {
strict_mode,
nullable,
struct_mode,
decoder_factory.clone(),
)
})
.collect::<Result<Vec<_>, ArrowError>>()?;
Expand Down
1 change: 1 addition & 0 deletions arrow-json/src/reader/tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ impl TapeDecoder {
}
}

/// Decodes JSON data from the provided buffer, returning the number of bytes consumed
pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
let mut iter = BufIter::new(buf);

Expand Down