diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index f1ce5fa857d..2b288445c74 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::*; -use crate::types::Int32Type; use crate::StructArray; +use crate::{ + builder::*, + types::{Int16Type, Int32Type, Int64Type, Int8Type}, +}; use arrow_buffer::NullBufferBuilder; use arrow_schema::{DataType, Fields, IntervalUnit, SchemaBuilder, TimeUnit}; use std::sync::Arc; @@ -290,29 +292,42 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box panic!("The field of Map data type {t:?} should has a child Struct field"), }, DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), - DataType::Dictionary(key_type, value_type) if **key_type == DataType::Int32 => { - match &**value_type { - DataType::Utf8 => { - let dict_builder: StringDictionaryBuilder = - StringDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::LargeUtf8 => { - let dict_builder: LargeStringDictionaryBuilder = - LargeStringDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::Binary => { - let dict_builder: BinaryDictionaryBuilder = - BinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) - } - DataType::LargeBinary => { - let dict_builder: LargeBinaryDictionaryBuilder = - LargeBinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); - Box::new(dict_builder) + t @ DataType::Dictionary(key_type, value_type) => { + macro_rules! dict_builder { + ($key_type:ty) => { + match &**value_type { + DataType::Utf8 => { + let dict_builder: StringDictionaryBuilder<$key_type> = + StringDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::LargeUtf8 => { + let dict_builder: LargeStringDictionaryBuilder<$key_type> = + LargeStringDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::Binary => { + let dict_builder: BinaryDictionaryBuilder<$key_type> = + BinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + DataType::LargeBinary => { + let dict_builder: LargeBinaryDictionaryBuilder<$key_type> = + LargeBinaryDictionaryBuilder::with_capacity(capacity, 256, 1024); + Box::new(dict_builder) + } + t => panic!("Dictionary value type {t:?} is not currently supported"), + } + }; + } + match &**key_type { + DataType::Int8 => dict_builder!(Int8Type), + DataType::Int16 => dict_builder!(Int16Type), + DataType::Int32 => dict_builder!(Int32Type), + DataType::Int64 => dict_builder!(Int64Type), + _ => { + panic!("Data type {t:?} with key type {key_type:?} is not currently supported") } - t => panic!("Unsupported dictionary value type {t:?} is not currently supported"), } } t => panic!("Data type {t:?} is not currently supported"), @@ -430,12 +445,14 @@ impl StructBuilder { #[cfg(test)] mod tests { + use std::any::type_name; + use super::*; use arrow_buffer::Buffer; use arrow_data::ArrayData; use arrow_schema::Field; - use crate::array::Array; + use crate::{array::Array, types::ArrowDictionaryKeyType}; #[test] fn test_struct_array_builder() { @@ -690,10 +707,31 @@ mod tests { } #[test] - fn test_struct_array_builder_from_dictionary_type() { + fn test_struct_array_builder_from_dictionary_type_int8_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int8); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int16_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int16); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int32_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int32); + } + + #[test] + fn test_struct_array_builder_from_dictionary_type_int64_key() { + test_struct_array_builder_from_dictionary_type_inner::(DataType::Int64); + } + + fn test_struct_array_builder_from_dictionary_type_inner( + key_type: DataType, + ) { let dict_field = Field::new( "f1", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(Box::new(key_type), Box::new(DataType::Utf8)), false, ); let fields = vec![dict_field.clone()]; @@ -701,10 +739,14 @@ mod tests { let cloned_dict_field = dict_field.clone(); let expected_child_dtype = dict_field.data_type(); let mut struct_builder = StructBuilder::from_fields(vec![cloned_dict_field], 5); - struct_builder - .field_builder::>(0) - .expect("Builder should be StringDictionaryBuilder") - .append_value("dict string"); + let Some(dict_builder) = struct_builder.field_builder::>(0) + else { + panic!( + "Builder should be StringDictionaryBuilder<{}>", + type_name::() + ) + }; + dict_builder.append_value("dict string"); struct_builder.append(true); let array = struct_builder.finish(); @@ -714,13 +756,15 @@ mod tests { } #[test] - #[should_panic(expected = "Data type Dictionary(Int16, Utf8) is not currently supported")] + #[should_panic( + expected = "Data type Dictionary(UInt64, Utf8) with key type UInt64 is not currently supported" + )] fn test_struct_array_builder_from_schema_unsupported_type() { let fields = vec![ - Field::new("f1", DataType::Int16, false), + Field::new("f1", DataType::UInt64, false), Field::new( "f2", - DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), false, ), ]; @@ -729,7 +773,7 @@ mod tests { } #[test] - #[should_panic(expected = "Unsupported dictionary value type Int32 is not currently supported")] + #[should_panic(expected = "Dictionary value type Int32 is not currently supported")] fn test_struct_array_builder_from_dict_with_unsupported_value_type() { let fields = vec![Field::new( "f1",