Skip to content

Commit 1096461

Browse files
authored
Add support for Map writing to Arrow vtab (#439)
1 parent f594f39 commit 1096461

File tree

2 files changed

+122
-9
lines changed

2 files changed

+122
-9
lines changed

crates/duckdb/src/core/vector.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ impl ListVector {
192192
}
193193

194194
/// Take the child as [StructVector].
195-
pub fn struct_child(&self) -> StructVector {
195+
pub fn struct_child(&self, capacity: usize) -> StructVector {
196+
self.reserve(capacity);
196197
StructVector::from(unsafe { duckdb_list_vector_get_child(self.entries.ptr) })
197198
}
198199

@@ -300,8 +301,11 @@ impl From<duckdb_vector> for StructVector {
300301

301302
impl StructVector {
302303
/// Returns the child by idx in the list vector.
303-
pub fn child(&self, idx: usize) -> FlatVector {
304-
FlatVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
304+
pub fn child(&self, idx: usize, capacity: usize) -> FlatVector {
305+
FlatVector::with_capacity(
306+
unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) },
307+
capacity,
308+
)
305309
}
306310

307311
/// Take the child as [StructVector].

crates/duckdb/src/vtab/arrow.rs

+115-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{
66
core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector},
77
types::DuckString,
88
};
9+
use arrow::array::as_map_array;
910
use arrow::{
1011
array::{
1112
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array,
@@ -204,6 +205,7 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle,
204205
// DuckDB does not support negative decimal scales
205206
Ok(LogicalTypeHandle::decimal(*width, (*scale).try_into().unwrap()))
206207
}
208+
DataType::Map(field, _) => arrow_map_to_duckdb_logical_type(field),
207209
DataType::Boolean
208210
| DataType::Utf8
209211
| DataType::LargeUtf8
@@ -220,6 +222,35 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle,
220222
}
221223
}
222224

225+
fn arrow_map_to_duckdb_logical_type(field: &FieldRef) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
226+
// Map is a logical nested type that is represented as `List<entries: Struct<key: K, value: V>>`
227+
let DataType::Struct(ref fields) = field.data_type() else {
228+
return Err(format!(
229+
"The inner field of a Map must be a Struct, got: {:?}",
230+
field.data_type()
231+
)
232+
.into());
233+
};
234+
235+
if fields.len() != 2 {
236+
return Err(format!(
237+
"The inner Struct field of a Map must have 2 fields, got {} fields",
238+
fields.len()
239+
)
240+
.into());
241+
}
242+
243+
let (Some(key_field), Some(value_field)) = (fields.first(), fields.get(1)) else {
244+
// number of fields is verified above
245+
unreachable!()
246+
};
247+
248+
Ok(LogicalTypeHandle::map(
249+
&LogicalTypeHandle::from(to_duckdb_type_id(key_field.data_type())?),
250+
&LogicalTypeHandle::from(to_duckdb_type_id(value_field.data_type())?),
251+
))
252+
}
253+
223254
// FIXME: flat vectors don't have all of thsese types. I think they only
224255
/// Converts flat vector to an arrow array
225256
pub fn flat_vector_to_arrow_array(
@@ -586,6 +617,19 @@ pub fn write_arrow_array_to_vector(
586617
let mut struct_vector = chunk.struct_vector();
587618
struct_array_to_vector(struct_array, &mut struct_vector)?;
588619
}
620+
DataType::Map(_, _) => {
621+
// [`MapArray`] is physically a [`ListArray`] of key values pairs stored as an `entries` [`StructArray`] with 2 child fields.
622+
let map_array = as_map_array(col.as_ref());
623+
let out = &mut chunk.list_vector();
624+
struct_array_to_vector(map_array.entries(), &mut out.struct_child(map_array.entries().len()))?;
625+
626+
for i in 0..map_array.len() {
627+
let offset = map_array.value_offsets()[i];
628+
let length = map_array.value_length(i);
629+
out.set_entry(i, offset.as_(), length.as_());
630+
}
631+
set_nulls_in_list_vector(map_array, out);
632+
}
589633
dt => {
590634
return Err(format!(
591635
"column with data_type {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs",
@@ -935,7 +979,10 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
935979
fixed_size_list_array_to_vector(as_fixed_size_list_array(value_array.as_ref()), &mut out.array_child())?;
936980
}
937981
DataType::Struct(_) => {
938-
struct_array_to_vector(as_struct_array(value_array.as_ref()), &mut out.struct_child())?;
982+
struct_array_to_vector(
983+
as_struct_array(value_array.as_ref()),
984+
&mut out.struct_child(value_array.len()),
985+
)?;
939986
}
940987
_ => {
941988
return Err(format!(
@@ -993,13 +1040,13 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
9931040
let column = array.column(i);
9941041
match column.data_type() {
9951042
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
996-
primitive_array_to_vector(column, &mut out.child(i))?;
1043+
primitive_array_to_vector(column, &mut out.child(i, array.len()))?;
9971044
}
9981045
DataType::Utf8 => {
999-
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
1046+
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i, array.len()));
10001047
}
10011048
DataType::Binary => {
1002-
binary_array_to_vector(as_generic_binary_array(column.as_ref()), &mut out.child(i));
1049+
binary_array_to_vector(as_generic_binary_array(column.as_ref()), &mut out.child(i, array.len()));
10031050
}
10041051
DataType::List(_) => {
10051052
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i))?;
@@ -1112,10 +1159,10 @@ mod test {
11121159
Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
11131160
Decimal128Array, Decimal256Array, DurationSecondArray, FixedSizeListArray, FixedSizeListBuilder,
11141161
GenericByteArray, GenericListArray, Int32Array, Int32Builder, IntervalDayTimeArray,
1115-
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray, ListBuilder,
1162+
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray, ListBuilder, MapArray,
11161163
OffsetSizeTrait, PrimitiveArray, StringArray, StringViewArray, StructArray, Time32SecondArray,
11171164
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1118-
TimestampSecondArray,
1165+
TimestampSecondArray, UInt32Array,
11191166
},
11201167
buffer::{OffsetBuffer, ScalarBuffer},
11211168
datatypes::{
@@ -1894,4 +1941,66 @@ mod test {
18941941

18951942
Ok(())
18961943
}
1944+
1945+
fn check_map_array_roundtrip(array: MapArray) -> Result<(), Box<dyn Error>> {
1946+
let expected = array.clone();
1947+
1948+
let db = Connection::open_in_memory()?;
1949+
db.register_table_function::<ArrowVTab>("arrow")?;
1950+
1951+
// Roundtrip a record batch from Rust to DuckDB and back to Rust
1952+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]);
1953+
1954+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
1955+
let param = arrow_recordbatch_to_query_params(rb.clone());
1956+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
1957+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1958+
let output_array = rb
1959+
.column(0)
1960+
.as_any()
1961+
.downcast_ref::<MapArray>()
1962+
.expect("Expected MapArray");
1963+
1964+
assert_eq!(output_array.keys(), expected.keys());
1965+
assert_eq!(output_array.values(), expected.values());
1966+
1967+
Ok(())
1968+
}
1969+
1970+
#[test]
1971+
fn test_map_roundtrip() -> Result<(), Box<dyn Error>> {
1972+
// Test 1 - simple MapArray
1973+
let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
1974+
let values_data = UInt32Array::from(vec![
1975+
Some(0u32),
1976+
None,
1977+
Some(20),
1978+
Some(30),
1979+
None,
1980+
Some(50),
1981+
Some(60),
1982+
Some(70),
1983+
]);
1984+
// Construct a buffer for value offsets, for the nested array:
1985+
// [[a, b, c], [d, e, f], [g, h]]
1986+
let entry_offsets = [0, 3, 6, 8];
1987+
let map_array = MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets).unwrap();
1988+
check_map_array_roundtrip(map_array)?;
1989+
1990+
// Test 2 - large MapArray of 4000 elements to test buffers capacity adjustment
1991+
let keys: Vec<String> = (0..4000).map(|i| format!("key-{}", i)).collect();
1992+
let values_data = UInt32Array::from(
1993+
(0..4000)
1994+
.map(|i| if i % 5 == 0 { None } else { Some(i as u32) })
1995+
.collect::<Vec<_>>(),
1996+
);
1997+
let mut entry_offsets: Vec<u32> = (0..=4000).step_by(3).collect();
1998+
entry_offsets.push(4000);
1999+
let map_array =
2000+
MapArray::new_from_strings(keys.iter().map(String::as_str), &values_data, entry_offsets.as_slice())
2001+
.unwrap();
2002+
check_map_array_roundtrip(map_array)?;
2003+
2004+
Ok(())
2005+
}
18972006
}

0 commit comments

Comments
 (0)