Skip to content

Commit

Permalink
perf: Use bitflags for function options (#17723)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 19, 2024
1 parent a4442f4 commit 90cdf95
Show file tree
Hide file tree
Showing 25 changed files with 188 additions and 134 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ impl ApplyExpr {
output_dtype: Option<DataType>,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar {
if matches!(options.collect_groups, ApplyOptions::ElementWise)
&& options.flags.contains(FunctionFlags::RETURNS_SCALAR)
{
panic!("expr {:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr)
}

Expand All @@ -48,13 +50,13 @@ impl ApplyExpr {
function,
expr,
collect_groups: options.collect_groups,
returns_scalar: options.returns_scalar,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME),
pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY),
input_schema,
allow_threading,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE),
output_dtype,
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-expr/src/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl WindowExpr {
},
Expr::Function { options, .. }
| Expr::AnonymousFunction { options, .. } => {
if options.returns_scalar
if options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise)
{
agg_col = true;
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ fn create_physical_expr_inner(
.ok()
});

let is_reducing_aggregation =
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
// Will be reset in the function so get that here.
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand Down Expand Up @@ -534,8 +534,8 @@ fn create_physical_expr_inner(
.to_dtype(schema, Context::Default, expr_arena)
.ok()
});
let is_reducing_aggregation =
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
// Will be reset in the function so get that here.
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ polars-utils = { workspace = true }

ahash = { workspace = true }
arrow = { workspace = true }
bitflags = { workspace = true }
bytemuck = { workspace = true }
chrono = { workspace = true, optional = true }
chrono-tz = { workspace = true, optional = true }
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl ArrayNameSpace {
false,
)
.with_function_options(|mut options| {
options.input_wildcard_expansion = true;
options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
options
})
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn business_day_count(
holidays,
}),
options: FunctionOptions {
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-plan/src/dsl/functions/coerce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ pub fn as_struct(exprs: Vec<Expr>) -> Expr {
input: exprs,
function: FunctionExpr::AsStruct,
options: FunctionOptions {
input_wildcard_expansion: true,
pass_name_to_apply: true,
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default()
| FunctionFlags::PASS_NAME_TO_APPLY
| FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
}
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ pub fn concat_str<E: AsRef<[Expr]>>(s: E, separator: &str, ignore_nulls: bool) -
.into(),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
}
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn concat_list<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> PolarsResult
function: FunctionExpr::ListExpr(ListFunction::Concat),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
})
Expand All @@ -81,7 +81,7 @@ pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
function: FunctionExpr::ConcatExpr(rechunk),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
cast_to_supertypes: Some(Default::default()),
..Default::default()
},
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
}
Expand All @@ -36,7 +36,7 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
}
Expand Down Expand Up @@ -64,7 +64,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
returns_scalar: true,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
}
Expand Down
46 changes: 25 additions & 21 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ where
output_type: GetOutput::super_type(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
input_wildcard_expansion: true,
returns_scalar: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::RETURNS_SCALAR,
fmt_str: "fold",
..Default::default()
},
Expand Down Expand Up @@ -90,8 +91,9 @@ where
output_type: GetOutput::super_type(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
input_wildcard_expansion: true,
returns_scalar: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::RETURNS_SCALAR,
fmt_str: "reduce",
..Default::default()
},
Expand Down Expand Up @@ -136,8 +138,9 @@ where
output_type: cum_fold_dtype(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
input_wildcard_expansion: true,
returns_scalar: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::RETURNS_SCALAR,
fmt_str: "cum_reduce",
..Default::default()
},
Expand Down Expand Up @@ -181,8 +184,9 @@ where
output_type: cum_fold_dtype(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
input_wildcard_expansion: true,
returns_scalar: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::RETURNS_SCALAR,
fmt_str: "cum_fold",
..Default::default()
},
Expand All @@ -200,7 +204,7 @@ pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal),
options: FunctionOptions {
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
})
Expand All @@ -217,7 +221,7 @@ pub fn any_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal),
options: FunctionOptions {
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
})
Expand All @@ -235,9 +239,9 @@ pub fn max_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
function: FunctionExpr::MaxHorizontal,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
allow_rename: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR
| FunctionFlags::ALLOW_RENAME,
..Default::default()
},
})
Expand All @@ -255,9 +259,9 @@ pub fn min_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
function: FunctionExpr::MinHorizontal,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
allow_rename: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR
| FunctionFlags::ALLOW_RENAME,
..Default::default()
},
})
Expand All @@ -273,8 +277,8 @@ pub fn sum_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
function: FunctionExpr::SumHorizontal,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
cast_to_supertypes: None,
..Default::default()
},
Expand All @@ -291,8 +295,8 @@ pub fn mean_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
function: FunctionExpr::MeanHorizontal,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
cast_to_supertypes: None,
..Default::default()
},
Expand All @@ -310,7 +314,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: Some(Default::default()),
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
}
Expand Down
16 changes: 8 additions & 8 deletions crates/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub fn int_range(start: Expr, end: Expr, step: i64, dtype: DataType) -> Expr {
input,
function: FunctionExpr::Range(RangeFunction::IntRange { step, dtype }),
options: FunctionOptions {
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand All @@ -29,7 +29,7 @@ pub fn int_ranges(start: Expr, end: Expr, step: Expr) -> Expr {
input,
function: FunctionExpr::Range(RangeFunction::IntRanges),
options: FunctionOptions {
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand All @@ -45,7 +45,7 @@ pub fn date_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWind
function: FunctionExpr::Range(RangeFunction::DateRange { interval, closed }),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand All @@ -61,7 +61,7 @@ pub fn date_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWin
function: FunctionExpr::Range(RangeFunction::DateRanges { interval, closed }),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand Down Expand Up @@ -90,7 +90,7 @@ pub fn datetime_range(
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand Down Expand Up @@ -119,7 +119,7 @@ pub fn datetime_ranges(
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand All @@ -135,7 +135,7 @@ pub fn time_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWind
function: FunctionExpr::Range(RangeFunction::TimeRange { interval, closed }),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand All @@ -151,7 +151,7 @@ pub fn time_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWin
function: FunctionExpr::Range(RangeFunction::TimeRanges { interval, closed }),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
allow_rename: true,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
Expand Down
7 changes: 4 additions & 3 deletions crates/polars-plan/src/dsl/functions/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
}),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
allow_rename: true,
input_wildcard_expansion: true,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::ALLOW_RENAME,
fmt_str: "datetime",
..Default::default()
},
Expand Down Expand Up @@ -271,7 +272,7 @@ pub fn duration(args: DurationArgs) -> Expr {
function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
..Default::default()
},
}
Expand Down
Loading

0 comments on commit 90cdf95

Please sign in to comment.