Skip to content

Commit

Permalink
feat: add min parameter to collect (#607)
Browse files Browse the repository at this point in the history
Adds the `min` parameter to `collect`, which allows a user to specify a
minimum list length before a non-null value is produced from
`collect()`. Also reorders the parameters to make the Python builder
argument pattern a little cleaner.

Unfortunately, we still aren't able to use `collect` as a replacement
for `lag`, given the current behavior. `lag` does not count `null` as a
valid value, while `collect` does. See the ignored unit test for an
example of the differences.
  • Loading branch information
jordanrfrazier authored Aug 7, 2023
1 parent 4ba4936 commit 08f5e9c
Show file tree
Hide file tree
Showing 13 changed files with 551 additions and 67 deletions.
35 changes: 30 additions & 5 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,11 +557,11 @@ pub fn add_to_dfg(
//
// TODO: Flattening the window arguments is hacky and confusing. We should instead
// incorporate the tick directly into the function containing the window.
let window_arg = original_ast.map(|e| &e.args()[2]);
let window_arg = original_ast.map(|e| &e.args()[3]);
let (condition, duration) = match window_arg {
Some(window) => {
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());
dfg.bind("$condition_input", args[0].inner().clone());

let result =
flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?;
Expand All @@ -574,12 +574,37 @@ pub fn add_to_dfg(
//
// Note that this won't define the `condition_input` for the
// purposes of ticks.
(args[2].clone(), args[3].clone())
(args[3].clone(), args[4].clone())
}
};

// [max, input, condition, duration]
vec![args[0].clone(), args[1].clone(), condition, duration]
let min = dfg.literal(args[2].value());
let max = dfg.literal(args[1].value());
match (min, max) {
(Some(ScalarValue::Int64(Some(min))), Some(ScalarValue::Int64(Some(max)))) => {
if min > max {
DiagnosticCode::IllegalCast
.builder()
.with_label(args[2].location().primary_label().with_message(
format!(
"min '{min}' must be less than or equal to max '{max}'"
),
))
.emit(diagnostics);
}
}
(Some(_), Some(_)) => (),
(_, _) => panic!("previously verified min and max are scalar types"),
}

// [input, max, min, condition, duration]
vec![
args[0].clone(),
args[1].clone(),
args[2].clone(),
condition,
duration,
]
} else if function.name() == "when" || function.name() == "if" {
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());
Expand Down
4 changes: 2 additions & 2 deletions crates/sparrow-compiler/src/functions/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ pub(super) fn register(registry: &mut Registry) {
.set_internal();

registry
.register("collect<T: any>(const max: i64, input: T, window: window = null) -> list<T>")
.register("collect<T: any>(input: T, const max: i64, const min: i64 = 0, window: window = null) -> list<T>")
.with_dfg_signature(
"collect<T: any>(const max: i64, input: T, window: bool = null, duration: i64 = null) -> list<T>",
"collect<T: any>(input: T, const max: i64, const min: i64 = 0, window: bool = null, duration: i64 = null) -> list<T>",
)
.with_implementation(Implementation::Instruction(InstOp::Collect))
.set_internal();
Expand Down
2 changes: 1 addition & 1 deletion crates/sparrow-instructions/src/evaluators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ fn create_simple_evaluator(
InstOp::Coalesce => CoalesceEvaluator::try_new(info),
InstOp::Collect => {
create_typed_evaluator!(
&info.args[1].data_type,
&info.args[0].data_type,
CollectPrimitiveEvaluator,
CollectStructEvaluator,
CollectListEvaluator,
Expand Down
30 changes: 26 additions & 4 deletions crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ use std::sync::Arc;
/// If the list is empty, an empty list is returned (rather than `null`).
#[derive(Debug)]
pub struct CollectBooleanEvaluator {
/// The min size of the buffer.
///
/// If the buffer is smaller than this, a null value
/// will be produced.
min: usize,
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
Expand All @@ -28,14 +33,14 @@ pub struct CollectBooleanEvaluator {

impl EvaluatorFactory for CollectBooleanEvaluator {
fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
let input_type = info.args[1].data_type();
let input_type = info.args[0].data_type();
let result_type = info.result_type;
match result_type {
DataType::List(t) => anyhow::ensure!(t.data_type() == input_type),
other => anyhow::bail!("expected list result type, saw {:?}", other),
};

let max = match info.args[0].value_ref.literal_value() {
let max = match info.args[1].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v <= 0 => {
anyhow::bail!("unexpected value of `max` -- must be > 0")
}
Expand All @@ -47,8 +52,21 @@ impl EvaluatorFactory for CollectBooleanEvaluator {
None => anyhow::bail!("expected literal value for max parameter"),
};

let (_, input, tick, duration) = info.unpack_arguments()?;
let min = match info.args[2].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v < 0 => {
anyhow::bail!("unexpected value of `min` -- must be >= 0")
}
Some(ScalarValue::Int64(Some(v))) => *v as usize,
// If a user specifies `min = null`, default to 0.
Some(ScalarValue::Int64(None)) => 0,
Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other),
None => anyhow::bail!("expected literal value for min parameter"),
};
assert!(min < max, "min must be less than max");

let (input, _, _, tick, duration) = info.unpack_arguments()?;
Ok(Box::new(Self {
min,
max,
input,
tick,
Expand Down Expand Up @@ -100,7 +118,11 @@ impl CollectBooleanEvaluator {
self.token.add_value(self.max, entity_index, input);
let cur_list = self.token.state(entity_index);

list_builder.append_value(cur_list.iter().copied());
if cur_list.len() >= self.min {
list_builder.append_value(cur_list.iter().copied());
} else {
list_builder.append_null();
}
});

Ok(Arc::new(list_builder.finish()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ where
T: ArrowPrimitiveType,
T::Native: Serialize + DeserializeOwned + Copy,
{
/// The min size of the buffer.
///
/// If the buffer is smaller than this, a null value
/// will be produced.
min: usize,
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
Expand All @@ -43,14 +48,14 @@ where
T::Native: Serialize + DeserializeOwned + Copy,
{
fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
let input_type = info.args[1].data_type();
let input_type = info.args[0].data_type();
let result_type = info.result_type;
match result_type {
DataType::List(t) => anyhow::ensure!(t.data_type() == input_type),
other => anyhow::bail!("expected list result type, saw {:?}", other),
};

let max = match info.args[0].value_ref.literal_value() {
let max = match info.args[1].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v <= 0 => {
anyhow::bail!("unexpected value of `max` -- must be > 0")
}
Expand All @@ -62,8 +67,21 @@ where
None => anyhow::bail!("expected literal value for max parameter"),
};

let (_, input, tick, duration) = info.unpack_arguments()?;
let min = match info.args[2].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v < 0 => {
anyhow::bail!("unexpected value of `min` -- must be >= 0")
}
Some(ScalarValue::Int64(Some(v))) => *v as usize,
// If a user specifies `min = null`, default to 0.
Some(ScalarValue::Int64(None)) => 0,
Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other),
None => anyhow::bail!("expected literal value for min parameter"),
};
debug_assert!(min <= max, "min must be less than max");

let (input, _, _, tick, duration) = info.unpack_arguments()?;
Ok(Box::new(Self {
min,
max,
input,
tick,
Expand Down Expand Up @@ -123,7 +141,11 @@ where
self.token.add_value(self.max, entity_index, input);
let cur_list = self.token.state(entity_index);

list_builder.append_value(cur_list.iter().copied());
if cur_list.len() >= self.min {
list_builder.append_value(cur_list.iter().copied());
} else {
list_builder.append_null();
}
});

Ok(Arc::new(list_builder.finish()))
Expand Down
30 changes: 26 additions & 4 deletions crates/sparrow-instructions/src/evaluators/list/collect_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ use std::sync::Arc;
/// If the list is empty, an empty list is returned (rather than `null`).
#[derive(Debug)]
pub struct CollectStringEvaluator {
/// The min size of the buffer.
///
/// If the buffer is smaller than this, a null value
/// will be produced.
min: usize,
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
Expand All @@ -28,14 +33,14 @@ pub struct CollectStringEvaluator {

impl EvaluatorFactory for CollectStringEvaluator {
fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
let input_type = info.args[1].data_type();
let input_type = info.args[0].data_type();
let result_type = info.result_type;
match result_type {
DataType::List(t) => anyhow::ensure!(t.data_type() == input_type),
other => anyhow::bail!("expected list result type, saw {:?}", other),
};

let max = match info.args[0].value_ref.literal_value() {
let max = match info.args[1].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v <= 0 => {
anyhow::bail!("unexpected value of `max` -- must be > 0")
}
Expand All @@ -47,8 +52,21 @@ impl EvaluatorFactory for CollectStringEvaluator {
None => anyhow::bail!("expected literal value for max parameter"),
};

let (_, input, tick, duration) = info.unpack_arguments()?;
let min = match info.args[2].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v < 0 => {
anyhow::bail!("unexpected value of `min` -- must be >= 0")
}
Some(ScalarValue::Int64(Some(v))) => *v as usize,
// If a user specifies `min = null`, default to 0.
Some(ScalarValue::Int64(None)) => 0,
Some(other) => anyhow::bail!("expected i64 for min parameter, saw {:?}", other),
None => anyhow::bail!("expected literal value for min parameter"),
};
assert!(min < max, "min must be less than max");

let (input, _, _, tick, duration) = info.unpack_arguments()?;
Ok(Box::new(Self {
min,
max,
input,
tick,
Expand Down Expand Up @@ -101,7 +119,11 @@ impl CollectStringEvaluator {
.add_value(self.max, entity_index, input.map(|s| s.to_owned()));
let cur_list = self.token.state(entity_index);

list_builder.append_value(cur_list.clone());
if cur_list.len() >= self.min {
list_builder.append_value(cur_list.clone());
} else {
list_builder.append_null();
}
});

Ok(Arc::new(list_builder.finish()))
Expand Down
Loading

0 comments on commit 08f5e9c

Please sign in to comment.