diff --git a/crates/sparrow-compiler/src/ast_to_dfg.rs b/crates/sparrow-compiler/src/ast_to_dfg.rs index 7c32c69e6..a6224228e 100644 --- a/crates/sparrow-compiler/src/ast_to_dfg.rs +++ b/crates/sparrow-compiler/src/ast_to_dfg.rs @@ -557,11 +557,11 @@ pub fn add_to_dfg( // // TODO: Flattening the window arguments is hacky and confusing. We should instead // incorporate the tick directly into the function containing the window. - let window_arg = original_ast.map(|e| &e.args()[2]); + let window_arg = original_ast.map(|e| &e.args()[3]); let (condition, duration) = match window_arg { Some(window) => { dfg.enter_env(); - dfg.bind("$condition_input", args[1].inner().clone()); + dfg.bind("$condition_input", args[0].inner().clone()); let result = flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?; @@ -574,12 +574,37 @@ pub fn add_to_dfg( // // Note that this won't define the `condition_input` for the // purposes of ticks. - (args[2].clone(), args[3].clone()) + (args[3].clone(), args[4].clone()) } }; - // [max, input, condition, duration] - vec![args[0].clone(), args[1].clone(), condition, duration] + let min = dfg.literal(args[2].value()); + let max = dfg.literal(args[1].value()); + match (min, max) { + (Some(ScalarValue::Int64(Some(min))), Some(ScalarValue::Int64(Some(max)))) => { + if min > max { + DiagnosticCode::IllegalCast + .builder() + .with_label(args[2].location().primary_label().with_message( + format!( + "min '{min}' must be less than or equal to max '{max}'" + ), + )) + .emit(diagnostics); + } + } + (Some(_), Some(_)) => (), + (_, _) => panic!("previously verified min and max are scalar types"), + } + + // [input, max, min, condition, duration] + vec![ + args[0].clone(), + args[1].clone(), + args[2].clone(), + condition, + duration, + ] } else if function.name() == "when" || function.name() == "if" { dfg.enter_env(); dfg.bind("$condition_input", args[1].inner().clone()); diff --git a/crates/sparrow-compiler/src/functions/collection.rs b/crates/sparrow-compiler/src/functions/collection.rs index aff252720..aec627445 100644 --- a/crates/sparrow-compiler/src/functions/collection.rs +++ b/crates/sparrow-compiler/src/functions/collection.rs @@ -14,9 +14,9 @@ pub(super) fn register(registry: &mut Registry) { .set_internal(); registry - .register("collect(const max: i64, input: T, window: window = null) -> list") + .register("collect(input: T, const max: i64, const min: i64 = 0, window: window = null) -> list") .with_dfg_signature( - "collect(const max: i64, input: T, window: bool = null, duration: i64 = null) -> list", + "collect(input: T, const max: i64, const min: i64 = 0, window: bool = null, duration: i64 = null) -> list", ) .with_implementation(Implementation::Instruction(InstOp::Collect)) .set_internal(); diff --git a/crates/sparrow-instructions/src/evaluators.rs b/crates/sparrow-instructions/src/evaluators.rs index 70e1f82d4..5241fb314 100644 --- a/crates/sparrow-instructions/src/evaluators.rs +++ b/crates/sparrow-instructions/src/evaluators.rs @@ -196,7 +196,7 @@ fn create_simple_evaluator( InstOp::Coalesce => CoalesceEvaluator::try_new(info), InstOp::Collect => { create_typed_evaluator!( - &info.args[1].data_type, + &info.args[0].data_type, CollectPrimitiveEvaluator, CollectStructEvaluator, CollectListEvaluator, diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs b/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs index 84c46d5d3..87beca43a 100644 --- a/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs +++ b/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs @@ -14,6 +14,11 @@ use std::sync::Arc; /// If the list is empty, an empty list is returned (rather than `null`). #[derive(Debug)] pub struct CollectBooleanEvaluator { + /// The min size of the buffer. + /// + /// If the buffer is smaller than this, a null value + /// will be produced. + min: usize, /// The max size of the buffer. /// /// Once the max size is reached, the front will be popped and the new @@ -28,14 +33,14 @@ pub struct CollectBooleanEvaluator { impl EvaluatorFactory for CollectBooleanEvaluator { fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { - let input_type = info.args[1].data_type(); + let input_type = info.args[0].data_type(); let result_type = info.result_type; match result_type { DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), other => anyhow::bail!("expected list result type, saw {:?}", other), }; - let max = match info.args[0].value_ref.literal_value() { + let max = match info.args[1].value_ref.literal_value() { Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { anyhow::bail!("unexpected value of `max` -- must be > 0") } @@ -47,8 +52,21 @@ impl EvaluatorFactory for CollectBooleanEvaluator { None => anyhow::bail!("expected literal value for max parameter"), }; - let (_, input, tick, duration) = info.unpack_arguments()?; + let min = match info.args[2].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v < 0 => { + anyhow::bail!("unexpected value of `min` -- must be >= 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `min = null`, default to 0. + Some(ScalarValue::Int64(None)) => 0, + Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for min parameter"), + }; + assert!(min < max, "min must be less than max"); + + let (input, _, _, tick, duration) = info.unpack_arguments()?; Ok(Box::new(Self { + min, max, input, tick, @@ -100,7 +118,11 @@ impl CollectBooleanEvaluator { self.token.add_value(self.max, entity_index, input); let cur_list = self.token.state(entity_index); - list_builder.append_value(cur_list.iter().copied()); + if cur_list.len() >= self.min { + list_builder.append_value(cur_list.iter().copied()); + } else { + list_builder.append_null(); + } }); Ok(Arc::new(list_builder.finish())) diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs b/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs index 176cb6a9f..b23c06adb 100644 --- a/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs +++ b/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs @@ -25,6 +25,11 @@ where T: ArrowPrimitiveType, T::Native: Serialize + DeserializeOwned + Copy, { + /// The min size of the buffer. + /// + /// If the buffer is smaller than this, a null value + /// will be produced. + min: usize, /// The max size of the buffer. /// /// Once the max size is reached, the front will be popped and the new @@ -43,14 +48,14 @@ where T::Native: Serialize + DeserializeOwned + Copy, { fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { - let input_type = info.args[1].data_type(); + let input_type = info.args[0].data_type(); let result_type = info.result_type; match result_type { DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), other => anyhow::bail!("expected list result type, saw {:?}", other), }; - let max = match info.args[0].value_ref.literal_value() { + let max = match info.args[1].value_ref.literal_value() { Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { anyhow::bail!("unexpected value of `max` -- must be > 0") } @@ -62,8 +67,21 @@ where None => anyhow::bail!("expected literal value for max parameter"), }; - let (_, input, tick, duration) = info.unpack_arguments()?; + let min = match info.args[2].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v < 0 => { + anyhow::bail!("unexpected value of `min` -- must be >= 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `min = null`, default to 0. + Some(ScalarValue::Int64(None)) => 0, + Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for min parameter"), + }; + debug_assert!(min <= max, "min must be less than max"); + + let (input, _, _, tick, duration) = info.unpack_arguments()?; Ok(Box::new(Self { + min, max, input, tick, @@ -123,7 +141,11 @@ where self.token.add_value(self.max, entity_index, input); let cur_list = self.token.state(entity_index); - list_builder.append_value(cur_list.iter().copied()); + if cur_list.len() >= self.min { + list_builder.append_value(cur_list.iter().copied()); + } else { + list_builder.append_null(); + } }); Ok(Arc::new(list_builder.finish())) diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_string.rs b/crates/sparrow-instructions/src/evaluators/list/collect_string.rs index 3886929b0..98633053e 100644 --- a/crates/sparrow-instructions/src/evaluators/list/collect_string.rs +++ b/crates/sparrow-instructions/src/evaluators/list/collect_string.rs @@ -14,6 +14,11 @@ use std::sync::Arc; /// If the list is empty, an empty list is returned (rather than `null`). #[derive(Debug)] pub struct CollectStringEvaluator { + /// The min size of the buffer. + /// + /// If the buffer is smaller than this, a null value + /// will be produced. + min: usize, /// The max size of the buffer. /// /// Once the max size is reached, the front will be popped and the new @@ -28,14 +33,14 @@ pub struct CollectStringEvaluator { impl EvaluatorFactory for CollectStringEvaluator { fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { - let input_type = info.args[1].data_type(); + let input_type = info.args[0].data_type(); let result_type = info.result_type; match result_type { DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), other => anyhow::bail!("expected list result type, saw {:?}", other), }; - let max = match info.args[0].value_ref.literal_value() { + let max = match info.args[1].value_ref.literal_value() { Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { anyhow::bail!("unexpected value of `max` -- must be > 0") } @@ -47,8 +52,21 @@ impl EvaluatorFactory for CollectStringEvaluator { None => anyhow::bail!("expected literal value for max parameter"), }; - let (_, input, tick, duration) = info.unpack_arguments()?; + let min = match info.args[2].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v < 0 => { + anyhow::bail!("unexpected value of `min` -- must be >= 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `min = null`, default to 0. + Some(ScalarValue::Int64(None)) => 0, + Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for min parameter"), + }; + assert!(min < max, "min must be less than max"); + + let (input, _, _, tick, duration) = info.unpack_arguments()?; Ok(Box::new(Self { + min, max, input, tick, @@ -101,7 +119,11 @@ impl CollectStringEvaluator { .add_value(self.max, entity_index, input.map(|s| s.to_owned())); let cur_list = self.token.state(entity_index); - list_builder.append_value(cur_list.clone()); + if cur_list.len() >= self.min { + list_builder.append_value(cur_list.clone()); + } else { + list_builder.append_null(); + } }); Ok(Arc::new(list_builder.finish())) diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_struct.rs b/crates/sparrow-instructions/src/evaluators/list/collect_struct.rs index 3a5b6c59d..c8ee1cd41 100644 --- a/crates/sparrow-instructions/src/evaluators/list/collect_struct.rs +++ b/crates/sparrow-instructions/src/evaluators/list/collect_struct.rs @@ -2,7 +2,7 @@ use crate::{CollectStructToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateT use arrow::array::{ new_empty_array, Array, ArrayRef, AsArray, ListArray, UInt32Array, UInt32Builder, }; -use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use arrow_schema::Field; use itertools::Itertools; @@ -19,6 +19,11 @@ use std::sync::Arc; /// If the list is empty, an empty list is returned (rather than `null`). #[derive(Debug)] pub struct CollectStructEvaluator { + /// The min size of the buffer. + /// + /// If the buffer is smaller than this, a null value + /// will be produced. + min: usize, /// The max size of the buffer. /// /// Once the max size is reached, the front will be popped and the new @@ -33,7 +38,7 @@ pub struct CollectStructEvaluator { impl EvaluatorFactory for CollectStructEvaluator { fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { - let input_type = info.args[1].data_type(); + let input_type = info.args[0].data_type(); let result_type = info.result_type; match result_type { DataType::List(t) => { @@ -43,7 +48,7 @@ impl EvaluatorFactory for CollectStructEvaluator { other => anyhow::bail!("expected list result type, saw {:?}", other), }; - let max = match info.args[0].value_ref.literal_value() { + let max = match info.args[1].value_ref.literal_value() { Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { anyhow::bail!("unexpected value of `max` -- must be > 0") } @@ -54,12 +59,23 @@ impl EvaluatorFactory for CollectStructEvaluator { Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other), None => anyhow::bail!("expected literal value for max parameter"), }; + let min = match info.args[2].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v < 0 => { + anyhow::bail!("unexpected value of `min` -- must be >= 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `min = null`, default to 0. + Some(ScalarValue::Int64(None)) => 0, + Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for min parameter"), + }; + debug_assert!(min <= max, "min must be less than max"); let accum = new_empty_array(result_type).as_list::().to_owned(); let token = CollectStructToken::new(Arc::new(accum)); - - let (_, input, tick, duration) = info.unpack_arguments()?; + let (input, _, _, tick, duration) = info.unpack_arguments()?; Ok(Box::new(Self { + min, max, input, tick, @@ -77,7 +93,14 @@ impl Evaluator for CollectStructEvaluator { let input = info.value(&self.input)?.array_ref()?; let key_capacity = info.grouping().num_groups(); let entity_indices = info.grouping().group_indices(); - Self::evaluate_non_windowed(token, key_capacity, entity_indices, input, self.max) + Self::evaluate_non_windowed( + token, + key_capacity, + entity_indices, + input, + self.min, + self.max, + ) } (false, true) => unimplemented!("since window aggregation unsupported"), (false, false) => panic!("sliding window aggregation should use other evaluator"), @@ -179,6 +202,7 @@ impl CollectStructEvaluator { key_capacity: usize, entity_indices: &UInt32Array, input: ArrayRef, + min: usize, max: usize, ) -> anyhow::Result { let input_structs = input.as_struct(); @@ -195,6 +219,9 @@ impl CollectStructEvaluator { let mut take_output_builder = UInt32Builder::new(); let mut output_offset_builder = vec![0]; + // Tracks the result's null values + let mut null_buffer = vec![]; + let mut cur_offset = 0; // For each entity, append the take indices for the new input to the existing // entity take indices @@ -213,14 +240,26 @@ impl CollectStructEvaluator { // already verified key exists, or created entry if not, in previous step let entity_take = entity_take_indices.get(entity_index).unwrap(); - // Append this entity's take indices to the take output builder - entity_take.iter().for_each(|i| { - take_output_builder.append_value(*i); - }); - - // Append this entity's current number of take indices to the output offset builder - cur_offset += entity_take.len(); - output_offset_builder.push(cur_offset as i32); + if entity_take.len() >= min { + // Append this entity's take indices to the take output builder + entity_take.iter().for_each(|i| { + take_output_builder.append_value(*i); + }); + + // Append this entity's current number of take indices to the output offset builder + cur_offset += entity_take.len(); + + output_offset_builder.push(cur_offset as i32); + null_buffer.push(true); + } else { + // Append null if there are not enough values + take_output_builder.append_null(); + null_buffer.push(false); + + // Cur offset increases by 1 to account for the null value + cur_offset += 1; + output_offset_builder.push(cur_offset as i32); + } } let output_values = sparrow_arrow::concat_take(old_state_flat, &input, &take_output_builder.finish())?; @@ -232,7 +271,7 @@ impl CollectStructEvaluator { field, OffsetBuffer::new(ScalarBuffer::from(output_offset_builder)), output_values, - None, + Some(NullBuffer::from(BooleanBuffer::from(null_buffer))), ); // Construct the new state offset and values using the current entity take indices @@ -264,7 +303,10 @@ impl CollectStructEvaluator { mod tests { use super::*; use arrow::{ - array::{AsArray, Int64Array, StringArray, StructArray}, + array::{ + ArrayBuilder, AsArray, Int64Array, Int64Builder, ListBuilder, StringArray, + StringBuilder, StructArray, StructBuilder, + }, buffer::ScalarBuffer, }; use arrow_schema::{DataType, Field, Fields}; @@ -308,6 +350,7 @@ mod tests { key_capacity, &key_indices, input, + 0, usize::MAX, ) .unwrap(); @@ -363,6 +406,7 @@ mod tests { key_capacity, &key_indices, input, + 0, usize::MAX, ) .unwrap(); @@ -431,6 +475,7 @@ mod tests { key_capacity, &key_indices, input, + 0, max, ) .unwrap(); @@ -480,6 +525,7 @@ mod tests { key_capacity, &key_indices, input, + 0, max, ) .unwrap(); @@ -508,4 +554,195 @@ mod tests { assert_eq!(expected.as_ref(), result); } + + #[test] + fn test_min() { + let min = 3; + let mut token = default_token(); + let fields = Fields::from(vec![ + Field::new("n", DataType::Int64, true), + Field::new("s", DataType::Utf8, true), + ]); + let field_builders: Vec> = vec![ + Box::new(Int64Builder::new()), + Box::new(StringBuilder::new()), + ]; + + // Batch 1 + let mut builder = StructBuilder::new(fields.clone(), field_builders); + builder + .field_builder::(0) + .unwrap() + .append_value(0); + builder + .field_builder::(1) + .unwrap() + .append_value("a"); + builder.append(true); + + builder + .field_builder::(0) + .unwrap() + .append_null(); + builder + .field_builder::(1) + .unwrap() + .append_null(); + builder.append(false); + + builder + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .field_builder::(1) + .unwrap() + .append_value("b"); + builder.append(true); + + builder + .field_builder::(0) + .unwrap() + .append_value(2); + builder + .field_builder::(1) + .unwrap() + .append_value("c"); + builder.append(true); + + let input = builder.finish(); + let input = Arc::new(input); + + let key_indices = UInt32Array::from(vec![0, 0, 0, 0]); + let key_capacity = 1; + + let result = CollectStructEvaluator::evaluate_non_windowed( + &mut token, + key_capacity, + &key_indices, + input, + min, + usize::MAX, + ) + .unwrap(); + let result = result.as_list::(); + + // build expected result 1 + let field_builders: Vec> = vec![ + Box::new(Int64Builder::new()), + Box::new(StringBuilder::new()), + ]; + + let mut builder = ListBuilder::new(StructBuilder::new(fields, field_builders)); + builder + .values() + .field_builder::(0) + .unwrap() + .append_null(); + builder + .values() + .field_builder::(1) + .unwrap() + .append_null(); + builder.values().append(false); + builder.append(false); + + builder + .values() + .field_builder::(0) + .unwrap() + .append_null(); + builder + .values() + .field_builder::(1) + .unwrap() + .append_null(); + builder.values().append(false); + builder.append(false); + + builder + .values() + .field_builder::(0) + .unwrap() + .append_value(0); + builder + .values() + .field_builder::(1) + .unwrap() + .append_value("a"); + builder.values().append(true); + builder + .values() + .field_builder::(0) + .unwrap() + .append_null(); + builder + .values() + .field_builder::(1) + .unwrap() + .append_null(); + builder.values().append(false); + builder + .values() + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .values() + .field_builder::(1) + .unwrap() + .append_value("b"); + builder.values().append(true); + builder.append(true); + + builder + .values() + .field_builder::(0) + .unwrap() + .append_value(0); + builder + .values() + .field_builder::(1) + .unwrap() + .append_value("a"); + builder.values().append(true); + builder + .values() + .field_builder::(0) + .unwrap() + .append_null(); + builder + .values() + .field_builder::(1) + .unwrap() + .append_null(); + builder.values().append(false); + builder + .values() + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .values() + .field_builder::(1) + .unwrap() + .append_value("b"); + builder.values().append(true); + builder + .values() + .field_builder::(0) + .unwrap() + .append_value(2); + builder + .values() + .field_builder::(1) + .unwrap() + .append_value("c"); + builder.values().append(true); + builder.append(true); + let expected = builder.finish(); + let expected = Arc::new(expected); + + assert_eq!(expected.as_ref(), result); + } } diff --git a/crates/sparrow-main/tests/e2e/collect_tests.rs b/crates/sparrow-main/tests/e2e/collect_tests.rs index 024b63e1c..ab73d468f 100644 --- a/crates/sparrow-main/tests/e2e/collect_tests.rs +++ b/crates/sparrow-main/tests/e2e/collect_tests.rs @@ -64,7 +64,7 @@ async fn test_collect_with_null_max() { #[tokio::test] async fn test_collect_to_list_i64() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max=10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,0 @@ -86,7 +86,7 @@ async fn test_collect_to_list_i64() { #[tokio::test] async fn test_collect_to_list_i64_dynamic() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max=10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,2 @@ -108,7 +108,7 @@ async fn test_collect_to_list_i64_dynamic() { #[tokio::test] async fn test_collect_to_small_list_i64() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max=2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,2 @@ -130,7 +130,7 @@ async fn test_collect_to_small_list_i64() { #[tokio::test] async fn test_collect_to_list_string() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(max=10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hEllo @@ -152,7 +152,7 @@ async fn test_collect_to_list_string() { #[tokio::test] async fn test_collect_to_list_string_dynamic() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(max=10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hi @@ -174,7 +174,7 @@ async fn test_collect_to_list_string_dynamic() { #[tokio::test] async fn test_collect_to_small_list_string() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(max=2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hi @@ -196,7 +196,7 @@ async fn test_collect_to_small_list_string() { #[tokio::test] async fn test_collect_to_list_boolean() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(max=10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,true @@ -218,7 +218,7 @@ async fn test_collect_to_list_boolean() { #[tokio::test] async fn test_collect_to_list_boolean_dynamic() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(max=10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,false @@ -240,7 +240,7 @@ async fn test_collect_to_list_boolean_dynamic() { #[tokio::test] async fn test_collect_to_small_list_boolean() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(max=2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,false @@ -289,9 +289,61 @@ async fn test_collect_structs() { "###); } +#[tokio::test] +async fn test_collect_with_minimum() { + insta::assert_snapshot!(QueryFixture::new("{ + min0: Collect.s | collect(max=10) | index(0), + min1: Collect.s | collect(min=2, max=10) | index(0), + min2: Collect.s | collect(min=3, max=10) | index(0) + }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,min0,min1,min2 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo,, + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hEllo,hEllo, + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,hEllo,hEllo,hEllo + 1996-12-20T00:42:00.000000000,9223372036854775808,12960666915911099378,A,hEllo,hEllo,hEllo + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,hEllo,hEllo,hEllo + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,hEllo,hEllo,hEllo + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,h,, + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,h,h, + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,h,h,h + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,h,h,h + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,h,h,h + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,g,, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,g,g, + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,g,g,g + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,g,g,g + "###); +} + +#[tokio::test] +#[ignore = "lag ignores nulls, so results are different"] +async fn test_collect_lag_equality() { + insta::assert_snapshot!(QueryFixture::new("{ + collect: Collect.n | collect(min=3, max=3) | index(0), + lag: Collect.n | lag(2) + }").with_dump_dot("asdf").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,collect,lag + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,, + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,, + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,0,0 + 1996-12-20T00:42:00.000000000,9223372036854775808,12960666915911099378,A,2,2 + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,9,9 + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,-7,-7 + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,, + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,, + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,5,5 + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,-2,5 + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,,-2 + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,, + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,1,1 + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,2,2 + "###); +} + #[tokio::test] async fn test_collect_primitive_since_minutely() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10, window=since(minutely())) | index(0) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max=10, window=since(minutely())) | index(0) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 1996-12-20T00:40:00.000000000,18446744073709551615,12960666915911099378,A,0 @@ -324,7 +376,7 @@ async fn test_collect_primitive_since_minutely() { async fn test_collect_primitive_since_minutely_1() { // Only two rows in this set exist within the same minute, hence these results when // getting the second item. - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10, window=since(minutely())) | index(1) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max=10, window=since(minutely())) | index(1) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:42:00.000000000,9223372036854775808,12960666915911099378,A,-7 1996-12-20T00:42:00.000000000,18446744073709551615,12960666915911099378,A,-7 @@ -334,7 +386,7 @@ async fn test_collect_primitive_since_minutely_1() { #[tokio::test] async fn test_collect_string_since_hourly() { // note that `B` is empty because we collect `null` as a valid value in a list currently - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(10, window=since(hourly())) | index(2) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(max=10, window=since(hourly())) | index(2) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,hey 1996-12-20T00:42:00.000000000,9223372036854775808,12960666915911099378,A,hey @@ -352,7 +404,7 @@ async fn test_collect_string_since_hourly() { #[tokio::test] async fn test_collect_boolean_since_hourly() { - insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(10, window=since(hourly())) | index(3) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(max=10, window=since(hourly())) | index(3) | when(is_valid($input)) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" _time,_subsort,_key_hash,_key,f1 1996-12-20T00:42:00.000000000,9223372036854775808,12960666915911099378,A,false 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,false @@ -369,7 +421,50 @@ async fn test_collect_boolean_since_hourly() { async fn test_require_literal_max() { // TODO: We should figure out how to not report the second error -- type variables with // error propagation needs some fixing. - insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(Collect.index) | index(1) }") + insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(max=Collect.index) | index(1) }") + .run_to_csv(&collect_data_fixture().await).await.unwrap_err(), @r###" + --- + code: Client specified an invalid argument + message: 2 errors in Fenl statements; see diagnostics + fenl_diagnostics: + - severity: error + code: E0014 + message: Invalid non-constant argument + formatted: + - "error[E0014]: Invalid non-constant argument" + - " --> Query:1:31" + - " |" + - "1 | { f1: Collect.s | collect(max=Collect.index) | index(1) }" + - " | ^^^^^^^^^^^^^ Argument 'max' to 'collect' must be constant, but was not" + - "" + - "" + - severity: error + code: E0010 + message: Invalid argument type(s) + formatted: + - "error[E0010]: Invalid argument type(s)" + - " --> Query:1:48" + - " |" + - "1 | { f1: Collect.s | collect(max=Collect.index) | index(1) }" + - " | ^^^^^ Invalid types for parameter 'list' in call to 'index'" + - " |" + - " --> internal:1:1" + - " |" + - 1 | $input + - " | ------ Actual type: error" + - " |" + - " --> built-in signature 'index(i: i64, list: list) -> T':1:29" + - " |" + - "1 | index(i: i64, list: list) -> T" + - " | ------- Expected type: list" + - "" + - "" + "###); +} + +#[tokio::test] +async fn test_require_literal_min() { + insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(min=Collect.index, max=10) | index(1) }") .run_to_csv(&collect_data_fixture().await).await.unwrap_err(), @r###" --- code: Client specified an invalid argument @@ -380,10 +475,10 @@ async fn test_require_literal_max() { message: Invalid non-constant argument formatted: - "error[E0014]: Invalid non-constant argument" - - " --> Query:1:27" + - " --> Query:1:31" - " |" - - "1 | { f1: Collect.s | collect(Collect.index) | index(1) }" - - " | ^^^^^^^^^^^^^ Argument 'max' to 'collect' must be constant, but was not" + - "1 | { f1: Collect.s | collect(min=Collect.index, max=10) | index(1) }" + - " | ^^^^^^^^^^^^^ Argument 'min' to 'collect' must be constant, but was not" - "" - "" - severity: error @@ -391,10 +486,10 @@ async fn test_require_literal_max() { message: Invalid argument type(s) formatted: - "error[E0010]: Invalid argument type(s)" - - " --> Query:1:44" + - " --> Query:1:56" - " |" - - "1 | { f1: Collect.s | collect(Collect.index) | index(1) }" - - " | ^^^^^ Invalid types for parameter 'list' in call to 'index'" + - "1 | { f1: Collect.s | collect(min=Collect.index, max=10) | index(1) }" + - " | ^^^^^ Invalid types for parameter 'list' in call to 'index'" - " |" - " --> internal:1:1" - " |" @@ -409,3 +504,25 @@ async fn test_require_literal_max() { - "" "###); } + +#[tokio::test] +async fn test_min_must_be_lte_max() { + insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(min=10, max=0) | index(1) }") + .run_to_csv(&collect_data_fixture().await).await.unwrap_err(), @r###" + --- + code: Client specified an invalid argument + message: 1 errors in Fenl statements; see diagnostics + fenl_diagnostics: + - severity: error + code: E0002 + message: Illegal cast + formatted: + - "error[E0002]: Illegal cast" + - " --> Query:1:31" + - " |" + - "1 | { f1: Collect.s | collect(min=10, max=0) | index(1) }" + - " | ^^ min '10' must be less than or equal to max '0'" + - "" + - "" + "###); +} diff --git a/crates/sparrow-plan/src/inst.rs b/crates/sparrow-plan/src/inst.rs index 609ceed73..9ff1aca11 100644 --- a/crates/sparrow-plan/src/inst.rs +++ b/crates/sparrow-plan/src/inst.rs @@ -61,7 +61,7 @@ pub enum InstOp { #[strum(props(signature = "coalesce(values+: T) -> T"))] Coalesce, #[strum(props( - signature = "collect(max: i64, input: T, ticks: bool = null, slide_duration: i64 = null) -> list" + signature = "collect(input: T, const max: i64, const min: i64 = 0, ticks: bool = null, slide_duration: i64 = null) -> list" ))] Collect, #[strum(props( diff --git a/sparrow-py/pysrc/sparrow_py/_timestream.py b/sparrow-py/pysrc/sparrow_py/_timestream.py index f39e79114..01a875451 100644 --- a/sparrow-py/pysrc/sparrow_py/_timestream.py +++ b/sparrow-py/pysrc/sparrow_py/_timestream.py @@ -553,7 +553,7 @@ def is_not_null(self) -> Timestream: return Timestream._call("is_valid", self) def collect( - self, max: Optional[int], window: Optional["kt.Window"] = None + self, max: Optional[int], min: Optional[int] = 0, window: Optional["kt.Window"] = None ) -> Timestream: """ Create a Timestream collecting up to the last `max` values in the `window`. @@ -565,6 +565,9 @@ def collect( max : Optional[int] The maximum number of values to collect. If `None` all values are collected. + min: Optional[int] + The minimum number of values to collect before + producing a value. Defaults to 0. window : Optional[Window] The window to use for the aggregation. If not specified, the entire Timestream is used. @@ -574,7 +577,7 @@ def collect( Timestream Timestream containing the collected list at each point. """ - return _aggregation("collect", self, window, max) + return _aggregation("collect", self, window, max, min) def sum(self, window: Optional["kt.Window"] = None) -> Timestream: """ @@ -698,7 +701,7 @@ def _aggregation( window : Optional[Window] The window to use for the aggregation. *args : Union[Timestream, Literal] - Additional arguments to provide before `input` and the flattened window. + Additional arguments to provide after `input` and before the flattened window. Returns ------- @@ -710,16 +713,12 @@ def _aggregation( NotImplementedError If the window is not a known type. """ - # Note: things would be easier if we had a more normal order, which - # we could do as part of "aligning" Sparrow signatures to the new direction. - # However, `collect` currently has `collect(max, input, window)`, requiring - # us to add the *args like so. if window is None: - return Timestream._call(op, *args, input, None, None) + return Timestream._call(op, input, *args, None, None) elif isinstance(window, kt.SinceWindow): - return Timestream._call(op, *args, input, window.predicate, None) + return Timestream._call(op, input, *args, window.predicate, None) elif isinstance(window, kt.SlidingWindow): - return Timestream._call(op, *args, input, window.predicate, window.duration) + return Timestream._call(op, input, *args, window.predicate, window.duration) else: raise NotImplementedError(f"Unknown window type {window!r}") diff --git a/sparrow-py/pytests/collect_test.py b/sparrow-py/pytests/collect_test.py index 8a8fa50c0..ba222abac 100644 --- a/sparrow-py/pytests/collect_test.py +++ b/sparrow-py/pytests/collect_test.py @@ -47,6 +47,34 @@ def test_collect_with_max(source, golden) -> None: ) ) +def test_collect_with_min(source, golden) -> None: + m = source["m"] + n = source["n"] + golden( + kt.record( + { + "m": m, + "collect_m_min_2": m.collect(min=2, max=None), + "n": n, + "collect_n_min_2": n.collect(min=2, max=None), + } + ) + ) + +def test_collect_with_min_and_max(source, golden) -> None: + m = source["m"] + n = source["n"] + golden( + kt.record( + { + "m": m, + "collect_m_min_2_max_2": m.collect(min=2, max=2), + "n": n, + "collect_n_min_2_max_2": n.collect(min=2, max=2), + } + ) + ) + def test_collect_since_window(source, golden) -> None: m = source["m"] diff --git a/sparrow-py/pytests/golden/collect_test/test_collect_with_min.json b/sparrow-py/pytests/golden/collect_test/test_collect_with_min.json new file mode 100644 index 000000000..cd8bb71ff --- /dev/null +++ b/sparrow-py/pytests/golden/collect_test/test_collect_with_min.json @@ -0,0 +1,6 @@ +{"_time":851042397000000000,"_subsort":0,"_key_hash":12960666915911099378,"_key":"A","m":5.0,"collect_m_min_2":null,"n":10.0,"collect_n_min_2":null} +{"_time":851042398000000000,"_subsort":1,"_key_hash":2867199309159137213,"_key":"B","m":24.0,"collect_m_min_2":null,"n":3.0,"collect_n_min_2":null} +{"_time":851042399000000000,"_subsort":2,"_key_hash":12960666915911099378,"_key":"A","m":17.0,"collect_m_min_2":[5.0,17.0],"n":6.0,"collect_n_min_2":[10.0,6.0]} +{"_time":851042400000000000,"_subsort":3,"_key_hash":12960666915911099378,"_key":"A","m":null,"collect_m_min_2":[5.0,17.0,null],"n":9.0,"collect_n_min_2":[10.0,6.0,9.0]} +{"_time":851042401000000000,"_subsort":4,"_key_hash":12960666915911099378,"_key":"A","m":12.0,"collect_m_min_2":[5.0,17.0,null,12.0],"n":null,"collect_n_min_2":[10.0,6.0,9.0,null]} +{"_time":851042402000000000,"_subsort":5,"_key_hash":12960666915911099378,"_key":"A","m":null,"collect_m_min_2":[5.0,17.0,null,12.0,null],"n":null,"collect_n_min_2":[10.0,6.0,9.0,null,null]} diff --git a/sparrow-py/pytests/golden/collect_test/test_collect_with_min_and_max.json b/sparrow-py/pytests/golden/collect_test/test_collect_with_min_and_max.json new file mode 100644 index 000000000..4ae49d627 --- /dev/null +++ b/sparrow-py/pytests/golden/collect_test/test_collect_with_min_and_max.json @@ -0,0 +1,6 @@ +{"_time":851042397000000000,"_subsort":0,"_key_hash":12960666915911099378,"_key":"A","m":5.0,"collect_m_min_2_max_2":null,"n":10.0,"collect_n_min_2_max_2":null} +{"_time":851042398000000000,"_subsort":1,"_key_hash":2867199309159137213,"_key":"B","m":24.0,"collect_m_min_2_max_2":null,"n":3.0,"collect_n_min_2_max_2":null} +{"_time":851042399000000000,"_subsort":2,"_key_hash":12960666915911099378,"_key":"A","m":17.0,"collect_m_min_2_max_2":[5.0,17.0],"n":6.0,"collect_n_min_2_max_2":[10.0,6.0]} +{"_time":851042400000000000,"_subsort":3,"_key_hash":12960666915911099378,"_key":"A","m":null,"collect_m_min_2_max_2":[17.0,null],"n":9.0,"collect_n_min_2_max_2":[6.0,9.0]} +{"_time":851042401000000000,"_subsort":4,"_key_hash":12960666915911099378,"_key":"A","m":12.0,"collect_m_min_2_max_2":[null,12.0],"n":null,"collect_n_min_2_max_2":[9.0,null]} +{"_time":851042402000000000,"_subsort":5,"_key_hash":12960666915911099378,"_key":"A","m":null,"collect_m_min_2_max_2":[12.0,null],"n":null,"collect_n_min_2_max_2":[null,null]}