diff --git a/Cargo.lock b/Cargo.lock index a8ffdd804c23..e344e7f0f078 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3306,6 +3306,7 @@ name = "polars-plan" version = "0.41.3" dependencies = [ "ahash", + "bitflags", "bytemuck", "chrono", "chrono-tz", diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 5f34b91167c1..4d13d784540e 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -39,7 +39,9 @@ impl ApplyExpr { output_dtype: Option, ) -> Self { #[cfg(debug_assertions)] - if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar { + if matches!(options.collect_groups, ApplyOptions::ElementWise) + && options.flags.contains(FunctionFlags::RETURNS_SCALAR) + { panic!("expr {:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr) } @@ -48,13 +50,13 @@ impl ApplyExpr { function, expr, collect_groups: options.collect_groups, - returns_scalar: options.returns_scalar, - allow_rename: options.allow_rename, - pass_name_to_apply: options.pass_name_to_apply, + returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR), + allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME), + pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY), input_schema, allow_threading, check_lengths: options.check_lengths(), - allow_group_aware: options.allow_group_aware, + allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE), output_dtype, } } diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 5a71230a9d09..c2ccf7028b03 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -299,7 +299,7 @@ impl WindowExpr { }, Expr::Function { options, .. } | Expr::AnonymousFunction { options, .. } => { - if options.returns_scalar + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise) { agg_col = true; diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 85968c74e77c..4ba657df9fa7 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -496,8 +496,8 @@ fn create_physical_expr_inner( .ok() }); - let is_reducing_aggregation = - options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise); + let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) + && matches!(options.collect_groups, ApplyOptions::GroupWise); // Will be reset in the function so get that here. let has_window = state.local.has_window; let input = create_physical_expressions_check_state( @@ -534,8 +534,8 @@ fn create_physical_expr_inner( .to_dtype(schema, Context::Default, expr_arena) .ok() }); - let is_reducing_aggregation = - options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise); + let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) + && matches!(options.collect_groups, ApplyOptions::GroupWise); // Will be reset in the function so get that here. let has_window = state.local.has_window; let input = create_physical_expressions_check_state( diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 55de822d1a4a..ebdae1d33b4a 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -24,6 +24,7 @@ polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } +bitflags = { workspace = true } bytemuck = { workspace = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 8e6efd50b84c..e781e65201fe 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -152,7 +152,7 @@ impl ArrayNameSpace { false, ) .with_function_options(|mut options| { - options.input_wildcard_expansion = true; + options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; options }) } diff --git a/crates/polars-plan/src/dsl/functions/business.rs b/crates/polars-plan/src/dsl/functions/business.rs index 2aa21727e8f1..180e27b61063 100644 --- a/crates/polars-plan/src/dsl/functions/business.rs +++ b/crates/polars-plan/src/dsl/functions/business.rs @@ -16,7 +16,7 @@ pub fn business_day_count( holidays, }), options: FunctionOptions { - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs index 12e1f930f5f3..b821ae247d70 100644 --- a/crates/polars-plan/src/dsl/functions/coerce.rs +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -6,9 +6,10 @@ pub fn as_struct(exprs: Vec) -> Expr { input: exprs, function: FunctionExpr::AsStruct, options: FunctionOptions { - input_wildcard_expansion: true, - pass_name_to_apply: true, collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() + | FunctionFlags::PASS_NAME_TO_APPLY + | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs index f2dda1cfbab7..6f420c72f768 100644 --- a/crates/polars-plan/src/dsl/functions/concat.rs +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -15,8 +15,8 @@ pub fn concat_str>(s: E, separator: &str, ignore_nulls: bool) - .into(), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } @@ -63,7 +63,7 @@ pub fn concat_list, IE: Into + Clone>(s: E) -> PolarsResult function: FunctionExpr::ListExpr(ListFunction::Concat), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, }) @@ -81,7 +81,7 @@ pub fn concat_expr, IE: Into + Clone>( function: FunctionExpr::ConcatExpr(rechunk), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, cast_to_supertypes: Some(Default::default()), ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index a7483c448851..bb0fc5aa3cf1 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -13,7 +13,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: Some(Default::default()), - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } @@ -36,7 +36,7 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: Some(Default::default()), - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } @@ -64,7 +64,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: Some(Default::default()), - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index f341e7ec6c65..f120570a9dc3 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -46,8 +46,9 @@ where output_type: GetOutput::super_type(), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - input_wildcard_expansion: true, - returns_scalar: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, fmt_str: "fold", ..Default::default() }, @@ -90,8 +91,9 @@ where output_type: GetOutput::super_type(), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - input_wildcard_expansion: true, - returns_scalar: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, fmt_str: "reduce", ..Default::default() }, @@ -136,8 +138,9 @@ where output_type: cum_fold_dtype(), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - input_wildcard_expansion: true, - returns_scalar: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, fmt_str: "cum_reduce", ..Default::default() }, @@ -181,8 +184,9 @@ where output_type: cum_fold_dtype(), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - input_wildcard_expansion: true, - returns_scalar: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, fmt_str: "cum_fold", ..Default::default() }, @@ -200,7 +204,7 @@ pub fn all_horizontal>(exprs: E) -> PolarsResult { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal), options: FunctionOptions { - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, }) @@ -217,7 +221,7 @@ pub fn any_horizontal>(exprs: E) -> PolarsResult { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal), options: FunctionOptions { - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, }) @@ -235,9 +239,9 @@ pub fn max_horizontal>(exprs: E) -> PolarsResult { function: FunctionExpr::MaxHorizontal, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, - allow_rename: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR + | FunctionFlags::ALLOW_RENAME, ..Default::default() }, }) @@ -255,9 +259,9 @@ pub fn min_horizontal>(exprs: E) -> PolarsResult { function: FunctionExpr::MinHorizontal, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, - allow_rename: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR + | FunctionFlags::ALLOW_RENAME, ..Default::default() }, }) @@ -273,8 +277,8 @@ pub fn sum_horizontal>(exprs: E) -> PolarsResult { function: FunctionExpr::SumHorizontal, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, cast_to_supertypes: None, ..Default::default() }, @@ -291,8 +295,8 @@ pub fn mean_horizontal>(exprs: E) -> PolarsResult { function: FunctionExpr::MeanHorizontal, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, cast_to_supertypes: None, ..Default::default() }, @@ -310,7 +314,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, cast_to_supertypes: Some(Default::default()), - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs index a0c070560969..0fb81aeb0da0 100644 --- a/crates/polars-plan/src/dsl/functions/range.rs +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -15,7 +15,7 @@ pub fn int_range(start: Expr, end: Expr, step: i64, dtype: DataType) -> Expr { input, function: FunctionExpr::Range(RangeFunction::IntRange { step, dtype }), options: FunctionOptions { - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -29,7 +29,7 @@ pub fn int_ranges(start: Expr, end: Expr, step: Expr) -> Expr { input, function: FunctionExpr::Range(RangeFunction::IntRanges), options: FunctionOptions { - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -45,7 +45,7 @@ pub fn date_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWind function: FunctionExpr::Range(RangeFunction::DateRange { interval, closed }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -61,7 +61,7 @@ pub fn date_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWin function: FunctionExpr::Range(RangeFunction::DateRanges { interval, closed }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -90,7 +90,7 @@ pub fn datetime_range( options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: Some(Default::default()), - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -119,7 +119,7 @@ pub fn datetime_ranges( options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: Some(Default::default()), - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -135,7 +135,7 @@ pub fn time_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWind function: FunctionExpr::Range(RangeFunction::TimeRange { interval, closed }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } @@ -151,7 +151,7 @@ pub fn time_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWin function: FunctionExpr::Range(RangeFunction::TimeRanges { interval, closed }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - allow_rename: true, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index ec6ba9ac6668..0289e69c6514 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -142,8 +142,9 @@ pub fn datetime(args: DatetimeArgs) -> Expr { }), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - allow_rename: true, - input_wildcard_expansion: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::ALLOW_RENAME, fmt_str: "datetime", ..Default::default() }, @@ -271,7 +272,7 @@ pub fn duration(args: DurationArgs) -> Expr { function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index d3dec58ee6f2..87691f263757 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -337,7 +337,7 @@ impl ListNameSpace { false, ) .with_function_options(|mut options| { - options.input_wildcard_expansion = true; + options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; options }) } @@ -354,7 +354,7 @@ impl ListNameSpace { false, ) .with_function_options(|mut options| { - options.input_wildcard_expansion = true; + options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; options }) } @@ -366,9 +366,9 @@ impl ListNameSpace { function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - returns_scalar: false, cast_to_supertypes: Some(SuperTypeOptions { implode_list: true }), - input_wildcard_expansion: true, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index cd10ce9c8503..b9df4a21412d 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -304,7 +304,7 @@ impl Expr { pub fn arg_min(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::GroupWise, - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "arg_min", ..Default::default() }; @@ -325,7 +325,7 @@ impl Expr { pub fn arg_max(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::GroupWise, - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "arg_max", ..Default::default() }; @@ -366,7 +366,7 @@ impl Expr { function: FunctionExpr::SearchSorted(side), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "search_sorted", cast_to_supertypes: Some(Default::default()), ..Default::default() @@ -691,12 +691,17 @@ impl Expr { None }; + let mut flags = FunctionFlags::default(); + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } + Expr::Function { input, function: function_expr, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - returns_scalar, + flags, cast_to_supertypes, ..Default::default() }, @@ -719,13 +724,17 @@ impl Expr { } else { None }; + let mut flags = FunctionFlags::default(); + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } Expr::Function { input, function: function_expr, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - returns_scalar, + flags, cast_to_supertypes, ..Default::default() }, @@ -803,7 +812,7 @@ impl Expr { pub fn product(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::GroupWise, - returns_scalar: true, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "product", ..Default::default() }; @@ -1084,7 +1093,7 @@ impl Expr { pub fn approx_n_unique(self) -> Self { self.apply_private(FunctionExpr::ApproxNUnique) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options }) } @@ -1578,7 +1587,7 @@ impl Expr { include_breaks, }) .with_function_options(|mut opt| { - opt.pass_name_to_apply = true; + opt.flags |= FunctionFlags::PASS_NAME_TO_APPLY; opt }) } @@ -1601,7 +1610,7 @@ impl Expr { include_breaks, }) .with_function_options(|mut opt| { - opt.pass_name_to_apply = true; + opt.flags |= FunctionFlags::PASS_NAME_TO_APPLY; opt }) } @@ -1625,7 +1634,7 @@ impl Expr { include_breaks, }) .with_function_options(|mut opt| { - opt.pass_name_to_apply = true; + opt.flags |= FunctionFlags::PASS_NAME_TO_APPLY; opt }) } @@ -1667,7 +1676,7 @@ impl Expr { pub fn skew(self, bias: bool) -> Expr { self.apply_private(FunctionExpr::Skew(bias)) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options }) } @@ -1683,7 +1692,7 @@ impl Expr { pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { self.apply_private(FunctionExpr::Kurtosis(fisher, bias)) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options }) } @@ -1742,7 +1751,7 @@ impl Expr { pub fn any(self, ignore_nulls: bool) -> Self { self.apply_private(BooleanFunction::Any { ignore_nulls }.into()) .with_function_options(|mut opt| { - opt.returns_scalar = true; + opt.flags |= FunctionFlags::RETURNS_SCALAR; opt }) } @@ -1757,7 +1766,7 @@ impl Expr { pub fn all(self, ignore_nulls: bool) -> Self { self.apply_private(BooleanFunction::All { ignore_nulls }.into()) .with_function_options(|mut opt| { - opt.returns_scalar = true; + opt.flags |= FunctionFlags::RETURNS_SCALAR; opt }) } @@ -1780,7 +1789,7 @@ impl Expr { normalize, }) .with_function_options(|mut opts| { - opts.pass_name_to_apply = true; + opts.flags |= FunctionFlags::PASS_NAME_TO_APPLY; opts }) } @@ -1817,7 +1826,7 @@ impl Expr { pub fn entropy(self, base: f64, normalize: bool) -> Self { self.apply_private(FunctionExpr::Entropy { base, normalize }) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options }) } @@ -1825,7 +1834,7 @@ impl Expr { pub fn null_count(self) -> Expr { self.apply_private(FunctionExpr::NullCount) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options }) } @@ -1963,8 +1972,8 @@ where output_type, options: FunctionOptions { collect_groups: ApplyOptions::ApplyList, - returns_scalar: true, fmt_str: "", + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, } @@ -1990,6 +1999,10 @@ where E: AsRef<[Expr]>, { let input = expr.as_ref().to_vec(); + let mut flags = FunctionFlags::default(); + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } Expr::AnonymousFunction { input, @@ -1999,8 +2012,8 @@ where collect_groups: ApplyOptions::GroupWise, // don't set this to true // this is for the caller to decide - returns_scalar, fmt_str: "", + flags, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 038b6c33bb1f..71828bccc5ac 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -251,6 +251,10 @@ impl Expr { }, }) }); + let mut flags = FunctionFlags::default(); + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } Expr::AnonymousFunction { input: vec![self], @@ -259,7 +263,7 @@ impl Expr { options: FunctionOptions { collect_groups, fmt_str: name, - returns_scalar, + flags, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index aec6952cb057..e03859326662 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -339,7 +339,7 @@ impl StringNameSpace { .into(), ) .with_function_options(|mut options| { - options.returns_scalar = true; + options.flags |= FunctionFlags::RETURNS_SCALAR; options.collect_groups = ApplyOptions::GroupWise; options }) diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs index d609547c1ff7..b5a1afafa698 100644 --- a/crates/polars-plan/src/dsl/struct_.rs +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -11,7 +11,7 @@ impl StructNameSpace { index, ))) .with_function_options(|mut options| { - options.allow_rename = true; + options.flags |= FunctionFlags::ALLOW_RENAME; options }) } @@ -33,7 +33,7 @@ impl StructNameSpace { names, ))) .with_function_options(|mut options| { - options.allow_rename = true; + options.flags |= FunctionFlags::ALLOW_RENAME; options }) } @@ -49,7 +49,7 @@ impl StructNameSpace { ColumnName::from(name), ))) .with_function_options(|mut options| { - options.allow_rename = true; + options.flags |= FunctionFlags::ALLOW_RENAME; options }) } @@ -98,9 +98,9 @@ impl StructNameSpace { function: FunctionExpr::StructExpr(StructFunction::WithFields), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - pass_name_to_apply: true, - allow_group_aware: false, - input_wildcard_expansion: true, + flags: FunctionFlags::default() & !FunctionFlags::ALLOW_GROUP_AWARE + | FunctionFlags::PASS_NAME_TO_APPLY + | FunctionFlags::INPUT_WILDCARD_EXPANSION, ..Default::default() }, }) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 5c613e05f249..07f2a94f6b34 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -249,7 +249,8 @@ impl AExpr { options, .. } => { - *nested = nested.saturating_sub(options.returns_scalar as _); + *nested = nested + .saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _); let tmp = function.get_output(); let output_type = tmp.as_ref().unwrap_or(output_type); let fields = func_args_to_fields(input, schema, arena, nested)?; @@ -261,7 +262,8 @@ impl AExpr { input, options, } => { - *nested = nested.saturating_sub(options.returns_scalar as _); + *nested = nested + .saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _); let fields = func_args_to_fields(input, schema, arena, nested)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); function.get_field(schema, Context::Default, &fields) diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index 4f168d1080d1..d14e51139eb6 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -544,7 +544,9 @@ fn prepare_excluded( fn expand_function_inputs(expr: Expr, schema: &Schema) -> Expr { expr.map_expr(|mut e| match &mut e { Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } - if options.input_wildcard_expansion => + if options + .flags + .contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) => { *input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap(); e diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs index c5fcadc5ce80..d7480c463b7c 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs @@ -99,7 +99,7 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) // group sensitive and doesn't auto-explode (e.g. is a reduction/aggregation // like sum, min, etc). // function that match this are `cum_sum`, `shift`, `sort`, etc. - options.is_groups_sensitive() && !options.returns_scalar + options.is_groups_sensitive() && !options.flags.contains(FunctionFlags::RETURNS_SCALAR) }, _ => false, }; diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr.rs index c8d5e26e09b0..86cbcc0e5e82 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr.rs @@ -413,8 +413,9 @@ fn string_addition_to_linear_concat( .into(), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - input_wildcard_expansion: true, - returns_scalar: false, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + & !FunctionFlags::RETURNS_SCALAR, ..Default::default() }, }), diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index f2f53240e96e..e6a3733b2d79 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -2,6 +2,7 @@ use std::num::NonZeroUsize; use std::path::PathBuf; +use bitflags::bitflags; use polars_core::prelude::*; use polars_core::utils::SuperTypeOptions; #[cfg(feature = "csv")] @@ -108,6 +109,53 @@ impl Default for UnsafeBool { } } +bitflags!( + #[repr(transparent)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] + pub struct FunctionFlags: u8 { + // Raise if use in group by + const ALLOW_GROUP_AWARE = 1 << 0; + // For example a `unique` or a `slice` + const CHANGES_LENGTH = 1 << 1; + // The physical expression may rename the output of this function. + // If set to `false` the physical engine will ensure the left input + // expression is the output name. + const ALLOW_RENAME = 1 << 2; + // if set, then the `Series` passed to the function in the group_by operation + // will ensure the name is set. This is an extra heap allocation per group. + const PASS_NAME_TO_APPLY = 1 << 3; + /// There can be two ways of expanding wildcards: + /// + /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can expand + /// to: + /// 1. `f('a', 'b')` + /// 2. `f('a'), f('b')` + /// + /// Setting this to true, will lead to behavior 1. + /// + /// This also accounts for regex expansion. + const INPUT_WILDCARD_EXPANSION = 1 << 4; + /// Automatically explode on unit length if it ran as final aggregation. + /// + /// this is the case for aggregations like sum, min, covariance etc. + /// We need to know this because we cannot see the difference between + /// the following functions based on the output type and number of elements: + /// + /// x: {1, 2, 3} + /// + /// head_1(x) -> {1} + /// sum(x) -> {4} + const RETURNS_SCALAR = 1 << 5; + } +); + +impl Default for FunctionFlags { + fn default() -> Self { + Self::from_bits_truncate(0) | Self::ALLOW_GROUP_AWARE + } +} + #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct FunctionOptions { @@ -117,47 +165,15 @@ pub struct FunctionOptions { // used for formatting, (only for anonymous functions) #[cfg_attr(feature = "serde", serde(skip_deserializing))] pub fmt_str: &'static str, - /// There can be two ways of expanding wildcards: - /// - /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can expand - /// to: - /// 1. `f('a', 'b')` - /// 2. `f('a'), f('b')` - /// - /// Setting this to true, will lead to behavior 1. - /// - /// This also accounts for regex expansion. - pub input_wildcard_expansion: bool, - /// Automatically explode on unit length if it ran as final aggregation. - /// - /// this is the case for aggregations like sum, min, covariance etc. - /// We need to know this because we cannot see the difference between - /// the following functions based on the output type and number of elements: - /// - /// x: {1, 2, 3} - /// - /// head_1(x) -> {1} - /// sum(x) -> {4} - pub returns_scalar: bool, // if the expression and its inputs should be cast to supertypes // `None` -> Don't cast. // `Some` -> cast with given options. #[cfg_attr(feature = "serde", serde(skip))] pub cast_to_supertypes: Option, - // The physical expression may rename the output of this function. - // If set to `false` the physical engine will ensure the left input - // expression is the output name. - pub allow_rename: bool, - // if set, then the `Series` passed to the function in the group_by operation - // will ensure the name is set. This is an extra heap allocation per group. - pub pass_name_to_apply: bool, - // For example a `unique` or a `slice` - pub changes_length: bool, // Validate the output of a `map`. // this should always be true or we could OOB pub check_lengths: UnsafeBool, - // Raise if use in group by - pub allow_group_aware: bool, + pub flags: FunctionFlags, } impl FunctionOptions { @@ -182,15 +198,10 @@ impl Default for FunctionOptions { fn default() -> Self { FunctionOptions { collect_groups: ApplyOptions::GroupWise, - input_wildcard_expansion: false, - returns_scalar: false, fmt_str: "", cast_to_supertypes: None, - allow_rename: false, - pass_name_to_apply: false, - changes_length: false, check_lengths: UnsafeBool(true), - allow_group_aware: true, + flags: Default::default(), } } } diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index e1d857bce9e3..cd7e6c3e0c7e 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -113,7 +113,7 @@ pub(crate) fn has_leaf_literal(e: &Expr) -> bool { pub(crate) fn all_return_scalar(e: &Expr) -> bool { match e { Expr::Literal(lv) => lv.projects_as_scalar(), - Expr::Function { options: opt, .. } => opt.returns_scalar, + Expr::Function { options: opt, .. } => opt.flags.contains(FunctionFlags::RETURNS_SCALAR), Expr::Agg(_) => true, Expr::Column(_) | Expr::Wildcard => false, _ => { diff --git a/py-polars/src/functions/misc.rs b/py-polars/src/functions/misc.rs index d9c878989764..2b02abb5f3f2 100644 --- a/py-polars/src/functions/misc.rs +++ b/py-polars/src/functions/misc.rs @@ -40,6 +40,20 @@ pub fn register_plugin_function( None }; + let mut flags = FunctionFlags::from_bits_truncate(0); + if changes_length { + flags |= FunctionFlags::CHANGES_LENGTH; + } + if pass_name_to_apply { + flags |= FunctionFlags::PASS_NAME_TO_APPLY; + } + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } + if input_wildcard_expansion { + flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; + } + Ok(Expr::Function { input: args.to_exprs(), function: FunctionExpr::FfiPlugin { @@ -49,11 +63,8 @@ pub fn register_plugin_function( }, options: FunctionOptions { collect_groups, - input_wildcard_expansion, - returns_scalar, cast_to_supertypes, - pass_name_to_apply, - changes_length, + flags, ..Default::default() }, }