diff --git a/crates/sparrow-main/tests/e2e/map_tests.rs b/crates/sparrow-main/tests/e2e/map_tests.rs index 0469a3f1a..51b516ea1 100644 --- a/crates/sparrow-main/tests/e2e/map_tests.rs +++ b/crates/sparrow-main/tests/e2e/map_tests.rs @@ -1,4 +1,5 @@ //! e2e tests for map types. + use sparrow_api::kaskada::v1alpha::TableConfig; use uuid::Uuid; @@ -108,20 +109,6 @@ async fn test_bool_to_s_get_static_key() { "###); } -#[tokio::test] -async fn test_first_last_map() { - // The csv writer does not support map types currently, so the output has been verified - // manually and now just compared as the hash of the parquet output. - let hash = - QueryFixture::new("{ first: Input.s_to_i64 | first(), last: Input.s_to_i64 | last() }") - .run_to_parquet_hash(&map_data_fixture().await) - .await - .unwrap(); - - let expected = "AB719CF6634779A5285D699A178AC69354696872E3733AA9388C9A6A"; - assert_eq!(hash, expected); -} - #[tokio::test] async fn test_s_to_i64_get_with_first_last_agg() { // Note that the last_f2 is empty. This is expected because the last() aggregation @@ -136,6 +123,18 @@ async fn test_s_to_i64_get_with_first_last_agg() { "###); } +#[tokio::test] +async fn test_map_output_into_sum_aggregation() { + insta::assert_snapshot!(QueryFixture::new("{ sum: Input.s_to_i64 | get(\"f1\") | sum(), value: Input.s_to_i64 | get(Input.s_to_i64_key) } | with_key(Input.s_to_i64_key)").run_to_csv(&map_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,sum,value + 1996-12-19T16:39:57.000000000,0,18146622110643880433,f1,0,0 + 1996-12-19T16:40:57.000000000,0,7541589802123724450,f2,1,10 + 1996-12-19T16:40:59.000000000,0,5533153676183607778,f3,6, + 1996-12-19T16:41:57.000000000,0,7541589802123724450,f2,6,13 + 1996-12-19T16:42:57.000000000,0,5533153676183607778,f3,21,11 + "###); +} + #[tokio::test] #[ignore = "https://docs.rs/arrow-ord/44.0.0/src/arrow_ord/comparison.rs.html#1746"] async fn test_map_equality() { @@ -144,6 +143,40 @@ async fn test_map_equality() { "###); } +#[tokio::test] +async fn test_query_with_merge_and_map_output() { + // This query produces a `merge` operations with `map` inputs, verifying + // we support maps within the _unlatched_ `spread` operation as well. + // Note that _latched_ spread is a separate implementation. + // + // It also produces a `map` as an output, verifying we can write maps to parquet. + let hash = QueryFixture::new( + "{ map: Input.s_to_i64, value: Input.s_to_i64 | get(Input.s_to_i64_key), lookup: lookup(Input.s_to_i64_key as u64, Input) }", + ) + .run_to_parquet_hash(&map_data_fixture().await) + .await + .unwrap(); + + assert_eq!( + "92C3C8B7E6AE6AF41266B63F3FBE11958DB5BFD23B58E891963F6287", + hash + ); +} + +#[tokio::test] +async fn test_first_last_map() { + // The csv writer does not support map types currently, so the output has been verified + // manually and now just compared as the hash of the parquet output. + let hash = + QueryFixture::new("{ first: Input.s_to_i64 | first(), last: Input.s_to_i64 | last() }") + .run_to_parquet_hash(&map_data_fixture().await) + .await + .unwrap(); + + let expected = "AB719CF6634779A5285D699A178AC69354696872E3733AA9388C9A6A"; + assert_eq!(hash, expected); +} + #[tokio::test] async fn test_swapped_args_for_get_map() { insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: get(Input.s_to_i64, \"f1\") }") diff --git a/crates/sparrow-runtime/src/execute/operation/spread.rs b/crates/sparrow-runtime/src/execute/operation/spread.rs index 7dfc6233c..ce6467ef9 100644 --- a/crates/sparrow-runtime/src/execute/operation/spread.rs +++ b/crates/sparrow-runtime/src/execute/operation/spread.rs @@ -4,15 +4,15 @@ use std::sync::Arc; use anyhow::Context; use arrow::array::{ new_null_array, Array, ArrayData, ArrayRef, BooleanArray, BooleanBufferBuilder, - Int32BufferBuilder, ListArray, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, - StructArray, + GenericStringArray, GenericStringBuilder, Int32BufferBuilder, ListArray, MapArray, + OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, StructArray, }; use arrow::datatypes::{self, ArrowPrimitiveType, DataType, Fields}; use bitvec::vec::BitVec; use itertools::{izip, Itertools}; use sparrow_arrow::downcast::{ - downcast_boolean_array, downcast_list_array, downcast_primitive_array, downcast_string_array, - downcast_struct_array, + downcast_boolean_array, downcast_list_array, downcast_map_array, downcast_primitive_array, + downcast_string_array, downcast_struct_array, }; use sparrow_arrow::utils::make_null_array; use sparrow_instructions::GroupingIndices; @@ -163,9 +163,12 @@ enum SerializedSpread<'a> { ), LatchedIntervalYearMonth(Boo<'a, LatchedPrimitiveSpread>), UnlatchedIntervalYearMonth(Boo<'a, UnlatchedPrimitiveSpread>), - LatchedString(Boo<'a, LatchedStringSpread>), - UnlatchedString(Boo<'a, UnlatchedStringSpread>), + LatchedString(Boo<'a, LatchedStringSpread>), + UnlatchedString(Boo<'a, UnlatchedStringSpread>), + LatchedLargeString(Boo<'a, LatchedStringSpread>), + UnlatchedLargeString(Boo<'a, UnlatchedStringSpread>), UnlatchedUInt64List(Boo<'a, UnlatchedUInt64ListSpread>), + UnlatchedMap(Boo<'a, UnlatchedMapSpread>), LatchedStruct(Boo<'a, StructSpread>), UnlatchedStruct(Boo<'a, StructSpread>), } @@ -241,7 +244,10 @@ impl<'a> SerializedSpread<'a> { SerializedSpread::UnlatchedIntervalYearMonth(spread) => into_spread_impl(spread), SerializedSpread::LatchedString(spread) => into_spread_impl(spread), SerializedSpread::UnlatchedString(spread) => into_spread_impl(spread), + SerializedSpread::LatchedLargeString(spread) => into_spread_impl(spread), + SerializedSpread::UnlatchedLargeString(spread) => into_spread_impl(spread), SerializedSpread::UnlatchedUInt64List(spread) => into_spread_impl(spread), + SerializedSpread::UnlatchedMap(spread) => into_spread_impl(spread), SerializedSpread::LatchedStruct(spread) => into_spread_impl(spread), SerializedSpread::UnlatchedStruct(spread) => into_spread_impl(spread), } @@ -322,9 +328,16 @@ impl Spread { } DataType::Utf8 => { if latched { - Box::::default() + Box::>::default() } else { - Box::new(UnlatchedStringSpread) + Box::>::default() + } + } + DataType::LargeUtf8 => { + if latched { + Box::>::default() + } else { + Box::>::default() } } DataType::Struct(fields) => { @@ -334,6 +347,10 @@ impl Spread { Box::new(StructSpread::try_new_unlatched(fields)?) } } + DataType::Map(_, _) => { + anyhow::ensure!(!latched, "Latched map spread not supported"); + Box::new(UnlatchedMapSpread) + } DataType::List(field) => { anyhow::ensure!(!latched, "Latched list spread not supported"); anyhow::ensure!( @@ -1214,16 +1231,31 @@ impl StructSpreadState for LatchedStructSpreadState { } } -#[derive(serde::Serialize, serde::Deserialize, Debug)] -struct UnlatchedStringSpread; +#[derive(Default, serde::Serialize, serde::Deserialize, Debug)] +struct UnlatchedStringSpread +where + O: OffsetSizeTrait, +{ + _phantom: PhantomData, +} -impl ToSerializedSpread for UnlatchedStringSpread { +impl ToSerializedSpread for UnlatchedStringSpread { fn to_serialized_spread(&self) -> SerializedSpread<'_> { SerializedSpread::UnlatchedString(Boo::Borrowed(self)) } } -impl SpreadImpl for UnlatchedStringSpread { +impl ToSerializedSpread for UnlatchedStringSpread { + fn to_serialized_spread(&self) -> SerializedSpread<'_> { + SerializedSpread::UnlatchedLargeString(Boo::Borrowed(self)) + } +} + +impl SpreadImpl for UnlatchedStringSpread +where + O: OffsetSizeTrait, + UnlatchedStringSpread: ToSerializedSpread, +{ fn spread_signaled( &mut self, grouping: &GroupingIndices, @@ -1263,18 +1295,29 @@ impl SpreadImpl for UnlatchedStringSpread { } #[derive(Default, Debug, serde::Serialize, serde::Deserialize)] -struct LatchedStringSpread { +struct LatchedStringSpread { values: Vec, valid: BitVec, + _phantom: PhantomData, } -impl ToSerializedSpread for LatchedStringSpread { +impl ToSerializedSpread for LatchedStringSpread { fn to_serialized_spread(&self) -> SerializedSpread<'_> { SerializedSpread::LatchedString(Boo::Borrowed(self)) } } -impl SpreadImpl for LatchedStringSpread { +impl ToSerializedSpread for LatchedStringSpread { + fn to_serialized_spread(&self) -> SerializedSpread<'_> { + SerializedSpread::LatchedLargeString(Boo::Borrowed(self)) + } +} + +impl SpreadImpl for LatchedStringSpread +where + O: OffsetSizeTrait, + LatchedStringSpread: ToSerializedSpread, +{ fn spread_signaled( &mut self, grouping: &GroupingIndices, @@ -1289,10 +1332,10 @@ impl SpreadImpl for LatchedStringSpread { self.valid.resize(grouping.num_groups(), false); } - let values: &StringArray = downcast_string_array(values.as_ref())?; + let values: &GenericStringArray = downcast_string_array(values.as_ref())?; let mut values = values.iter(); - let mut builder = StringBuilder::with_capacity(grouping.len(), 1024); + let mut builder = GenericStringBuilder::::with_capacity(grouping.len(), 1024); // TODO: Could use "next set bit" operations to more quickly handle // signal arrays. @@ -1339,7 +1382,7 @@ impl SpreadImpl for LatchedStringSpread { self.valid.resize(grouping.num_groups(), false); } - let values_array: &StringArray = downcast_string_array(values.as_ref())?; + let values_array: &GenericStringArray = downcast_string_array(values.as_ref())?; for (group, value) in grouping.group_iter().zip(values_array.iter()) { // SAFETY: Resized to contain groups above. @@ -1367,7 +1410,7 @@ impl SpreadImpl for LatchedStringSpread { self.valid.resize(grouping.num_groups(), false); } - let mut builder = StringBuilder::with_capacity(grouping.len(), 1024); + let mut builder = GenericStringBuilder::::with_capacity(grouping.len(), 1024); // TODO: Could use "next set bit" operations to more quickly handle // signal arrays. @@ -1708,13 +1751,84 @@ impl SpreadImpl for UnlatchedUInt64ListSpread { } } +#[derive(serde::Serialize, serde::Deserialize, Debug)] +struct UnlatchedMapSpread; + +impl ToSerializedSpread for UnlatchedMapSpread { + fn to_serialized_spread(&self) -> SerializedSpread<'_> { + SerializedSpread::UnlatchedMap(Boo::Borrowed(self)) + } +} + +impl SpreadImpl for UnlatchedMapSpread { + fn spread_signaled( + &mut self, + grouping: &GroupingIndices, + values: &ArrayRef, + signal: &BooleanArray, + ) -> anyhow::Result { + let map_values = downcast_map_array(values.as_ref())?; + + let mut offset_builder = Int32BufferBuilder::new(grouping.len() + 1); + let mut null_builder = BooleanBufferBuilder::new(grouping.len()); + + // Ensure the buffers are aligned to the offset. + offset_builder.append_n_zeroed(values.offset()); + null_builder.append_n(values.offset(), false); + + let mut offset_iter = map_values.value_offsets().iter(); + let mut offset = *offset_iter.next().context("missing offset")?; + offset_builder.append(offset); + + let mut index = 0; + for signal in signal.iter() { + if matches!(signal, Some(true)) { + offset = *offset_iter.next().context("missing offset")?; + null_builder.append(values.is_valid(index)); + index += 1; + } else { + null_builder.append(false); + } + offset_builder.append(offset); + } + + let data_builder = values.to_data().into_builder(); + let offset = offset_builder.finish(); + let array_data = data_builder + .len(grouping.len()) + .null_bit_buffer(Some(null_builder.finish().into_inner())) + .buffers(vec![offset]) + .build()?; + let result = MapArray::from(array_data); + + Ok(Arc::new(result)) + } + + fn spread_true( + &mut self, + grouping: &GroupingIndices, + values: &ArrayRef, + ) -> anyhow::Result { + anyhow::ensure!(grouping.len() == values.len()); + Ok(values.clone()) + } + + fn spread_false( + &mut self, + grouping: &GroupingIndices, + value_type: &DataType, + ) -> anyhow::Result { + Ok(new_null_array(value_type, grouping.len())) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; use arrow::array::{ - Array, ArrayRef, BooleanArray, Float64Array, Int64Array, ListArray, StringArray, - StructArray, UInt32Array, + Array, ArrayRef, BooleanArray, Float64Array, Int32Builder, Int64Array, LargeStringArray, + ListArray, MapBuilder, StringArray, StringBuilder, StructArray, UInt32Array, }; use arrow::datatypes::{DataType, Field, UInt64Type}; use sparrow_arrow::downcast::{ @@ -2057,6 +2171,44 @@ mod tests { ); } + #[test] + fn test_large_string_latched() { + let nums = LargeStringArray::from(vec![ + Some("5"), + Some("8"), + None, + Some("10"), + None, + Some("12"), + ]); + let result = run_spread( + Arc::new(nums), + vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 0], + vec![ + false, true, false, true, false, true, false, true, false, true, false, + ], + true, + ); + let result: &LargeStringArray = downcast_string_array(result.as_ref()).unwrap(); + + assert_eq!( + result, + &LargeStringArray::from(vec![ + None, + Some("5"), + None, + Some("8"), + Some("5"), // signal false, remember last value for key 1=5 + None, + Some("8"), // signal false, remember last value for key 0=8 + Some("10"), + None, + None, + None + ]) + ); + } + #[test] fn test_unlatched_uint64_list_spread() { let data = vec![ @@ -2126,6 +2278,74 @@ mod tests { assert_eq!(&result, &expected) } + #[test] + fn test_unlatched_map_spread() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(8); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + + builder.append(false).unwrap(); + + builder.keys().append_value("joe"); + builder.values().append_value(10); + builder.keys().append_value("foo"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("alice"); + builder.values().append_value(2); + builder.append(true).unwrap(); + + let map_array = builder.finish(); + + let result = run_spread( + Arc::new(map_array), + vec![0, 1, 2, 3, 4, 5, 6, 7], + vec![true, false, false, true, false, true, false, true], + false, + ); + + let string_builder2 = StringBuilder::new(); + let int_builder2 = Int32Builder::with_capacity(8); + let mut builder = MapBuilder::new(None, string_builder2, int_builder2); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.append(false).unwrap(); + builder.append(false).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + + builder.append(false).unwrap(); + builder.append(false).unwrap(); + builder.append(false).unwrap(); + + builder.keys().append_value("joe"); + builder.values().append_value(10); + builder.keys().append_value("foo"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + let expected = builder.finish(); + let expected: ArrayRef = Arc::new(expected); + assert_eq!(&result, &expected) + } + fn run_spread( values: ArrayRef, indices: Vec,