From 8e2b7484a325b6718b0c4e5a4f2ea1f3b04fb416 Mon Sep 17 00:00:00 2001 From: jordanrfrazier <122494242+jordanrfrazier@users.noreply.github.com> Date: Mon, 7 Aug 2023 09:31:54 -0700 Subject: [PATCH] feat: add list_len function (#609) This is the same implementation for all types, so we can move to a generic `length` evaluator when we move to `sparrow-expressions` package. --- .../src/functions/collection.rs | 5 +++ crates/sparrow-instructions/src/evaluators.rs | 1 + .../src/evaluators/list.rs | 2 ++ .../src/evaluators/list/list_len.rs | 36 +++++++++++++++++++ crates/sparrow-main/tests/e2e/list_tests.rs | 18 ++++++++++ crates/sparrow-plan/src/inst.rs | 2 ++ 6 files changed, 64 insertions(+) create mode 100644 crates/sparrow-instructions/src/evaluators/list/list_len.rs diff --git a/crates/sparrow-compiler/src/functions/collection.rs b/crates/sparrow-compiler/src/functions/collection.rs index add4b2ae3..aff252720 100644 --- a/crates/sparrow-compiler/src/functions/collection.rs +++ b/crates/sparrow-compiler/src/functions/collection.rs @@ -20,4 +20,9 @@ pub(super) fn register(registry: &mut Registry) { ) .with_implementation(Implementation::Instruction(InstOp::Collect)) .set_internal(); + + registry + .register("list_len<T: any>(input: list<T>) -> i32") + .with_implementation(Implementation::Instruction(InstOp::ListLen)) + .set_internal(); } diff --git a/crates/sparrow-instructions/src/evaluators.rs b/crates/sparrow-instructions/src/evaluators.rs index cbd67ed0a..70e1f82d4 100644 --- a/crates/sparrow-instructions/src/evaluators.rs +++ b/crates/sparrow-instructions/src/evaluators.rs @@ -288,6 +288,7 @@ fn create_simple_evaluator( ) } InstOp::Len => LenEvaluator::try_new(info), + InstOp::ListLen => ListLenEvaluator::try_new(info), InstOp::LogicalAnd => LogicalAndKleeneEvaluator::try_new(info), InstOp::LogicalOr => LogicalOrKleeneEvaluator::try_new(info), InstOp::Lower => LowerEvaluator::try_new(info), diff --git a/crates/sparrow-instructions/src/evaluators/list.rs 
b/crates/sparrow-instructions/src/evaluators/list.rs index 1ee1f8c93..2d0148e46 100644 --- a/crates/sparrow-instructions/src/evaluators/list.rs +++ b/crates/sparrow-instructions/src/evaluators/list.rs @@ -5,6 +5,7 @@ mod collect_primitive; mod collect_string; mod collect_struct; mod index; +mod list_len; pub(super) use collect_boolean::*; pub(super) use collect_list::*; @@ -13,3 +14,4 @@ pub(super) use collect_primitive::*; pub(super) use collect_string::*; pub(super) use collect_struct::*; pub(super) use index::*; +pub(super) use list_len::*; diff --git a/crates/sparrow-instructions/src/evaluators/list/list_len.rs b/crates/sparrow-instructions/src/evaluators/list/list_len.rs new file mode 100644 index 000000000..2110e2d54 --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/list/list_len.rs @@ -0,0 +1,36 @@ +use arrow::array::ArrayRef; + +use arrow_schema::DataType; +use sparrow_plan::ValueRef; +use std::sync::Arc; + +use crate::{Evaluator, EvaluatorFactory, StaticInfo}; + +/// Evaluator for `len` on lists. +/// +/// Produces the length of the list. 
+#[derive(Debug)] +pub(in crate::evaluators) struct ListLenEvaluator { + list: ValueRef, +} + +impl EvaluatorFactory for ListLenEvaluator { + fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> { + let input_type = info.args[0].data_type.clone(); + match input_type { + DataType::List(_) => (), + other => anyhow::bail!("expected list type, saw {:?}", other), + }; + + let list = info.unpack_argument()?; + Ok(Box::new(Self { list })) + } +} + +impl Evaluator for ListLenEvaluator { + fn evaluate(&mut self, info: &dyn crate::RuntimeInfo) -> anyhow::Result<ArrayRef> { + let input = info.value(&self.list)?.array_ref()?; + let result = arrow::compute::kernels::length::length(input.as_ref())?; + Ok(Arc::new(result)) + } +} diff --git a/crates/sparrow-main/tests/e2e/list_tests.rs b/crates/sparrow-main/tests/e2e/list_tests.rs index 8766466cf..3e9cfa322 100644 --- a/crates/sparrow-main/tests/e2e/list_tests.rs +++ b/crates/sparrow-main/tests/e2e/list_tests.rs @@ -127,6 +127,24 @@ async fn test_last_list() { "###); } +#[tokio::test] +async fn test_list_len() { + insta::assert_snapshot!(QueryFixture::new("{ + len_struct: { s: Input.string_list } | collect(max=null) | list_len(), + len_num: Input.i64_list | list_len(), + len_str: Input.string_list | list_len(), + len_bool: Input.bool_list | list_len(), + } + ").run_to_csv(&list_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,len_struct,len_num,len_str,len_bool + 1996-12-19T16:39:57.000000000,0,18433805721903975440,1,1,3,2,2 + 1996-12-19T16:40:57.000000000,0,18433805721903975440,1,2,3,3,2 + 1996-12-19T16:40:59.000000000,0,18433805721903975440,1,3,3,0,3 + 1996-12-19T16:41:57.000000000,0,18433805721903975440,1,4,3,2,3 + 1996-12-19T16:42:57.000000000,0,18433805721903975440,1,5,3,1,1 + "###); +} + #[tokio::test] async fn test_list_schemas_are_compatible() { // This query puts a collect() into a record, which diff --git a/crates/sparrow-plan/src/inst.rs b/crates/sparrow-plan/src/inst.rs index 2dc123a5f..609ceed73 
100644 --- a/crates/sparrow-plan/src/inst.rs +++ b/crates/sparrow-plan/src/inst.rs @@ -123,6 +123,8 @@ pub enum InstOp { Last, #[strum(props(signature = "len(s: string) -> i32"))] Len, + #[strum(props(signature = "list_len<T: any>(input: list<T>) -> i32"))] + ListLen, #[strum(props(signature = "logical_and(a: bool, b: bool) -> bool"))] LogicalAnd, #[strum(props(signature = "logical_or(a: bool, b: bool) -> bool"))]