From dbd2c5fdb1e3fc1a7ec29c41fa2c6c81c1218b05 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 21 Dec 2024 11:16:01 +0100 Subject: [PATCH] feat: Add FirstArgLossless supertype (#20394) --- crates/polars-core/src/utils/supertype.rs | 10 +++ .../polars-plan/src/dsl/functions/concat.rs | 2 +- .../src/dsl/functions/correlation.rs | 6 +- .../src/dsl/functions/horizontal.rs | 2 +- crates/polars-plan/src/dsl/functions/range.rs | 4 +- crates/polars-plan/src/dsl/list.rs | 10 ++- crates/polars-plan/src/dsl/mod.rs | 24 ++----- .../src/plans/conversion/type_coercion/mod.rs | 72 ++++++++++++------- .../polars-plan/src/plans/optimizer/fused.rs | 2 +- crates/polars-plan/src/plans/options.rs | 38 ++++------ crates/polars-python/src/functions/misc.rs | 7 +- 11 files changed, 91 insertions(+), 86 deletions(-) diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 8938ec244b9b..48555a7eb520 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -12,6 +12,16 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult { ) } +pub fn try_get_supertype_with_options( + l: &DataType, + r: &DataType, + options: SuperTypeOptions, +) -> PolarsResult { + get_supertype_with_options(l, r, options).ok_or_else( + || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r), + ) +} + /// Returns a numeric supertype that `l` and `r` can be safely upcasted to if it exists. pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option { use DataType::*; diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs index 44d596684eea..c07e7242dce1 100644 --- a/crates/polars-plan/src/dsl/functions/concat.rs +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -99,7 +99,7 @@ pub fn concat_expr, IE: Into + Clone>( options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), ..Default::default() }, }) diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index 730076049be3..99ad08f4b804 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -11,7 +11,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr { function, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, @@ -29,7 +29,7 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr { function, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, @@ -54,7 +54,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr { function, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 4ea8c68c81aa..8d01d3696086 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -318,7 +318,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs index 74d8ddb909cf..16d806c0ce74 100644 --- a/crates/polars-plan/src/dsl/functions/range.rs +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -89,7 +89,7 @@ pub fn datetime_range( }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, @@ -118,7 +118,7 @@ pub fn datetime_ranges( }), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 8265f1a39e8f..d5c2622b5afb 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -329,12 +329,10 @@ impl ListNameSpace { function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - cast_options: FunctionCastOptions { - supertype: Some(SuperTypeOptions { - flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST, - }), - ..Default::default() - }, + cast_options: Some(CastingRules::Supertype(SuperTypeOptions { + flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST, + })), + flags: FunctionFlags::default() & !FunctionFlags::RETURNS_SCALAR, ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 9f796232db13..b095587d4ef0 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -388,13 +388,9 @@ impl Expr { collect_groups: ApplyOptions::GroupWise, flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "search_sorted", - cast_options: FunctionCastOptions { - supertype: Some( - (SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING) - .into(), - ), - ..Default::default() - }, + cast_options: Some(CastingRules::Supertype( + (SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(), + )), ..Default::default() }, } @@ -714,7 +710,7 @@ impl Expr { input.extend_from_slice(arguments); let supertype = if cast_to_supertypes { - Some(Default::default()) + Some(CastingRules::cast_to_supertypes()) } else { None }; @@ -730,10 +726,7 @@ impl Expr { options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, flags, - cast_options: FunctionCastOptions { - supertype, - ..Default::default() - }, + cast_options: supertype, ..Default::default() }, } @@ -761,10 +754,7 @@ impl Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, flags, - cast_options: FunctionCastOptions { - supertype: cast_to_supertypes, - ..Default::default() - }, + cast_options: cast_to_supertypes.map(CastingRules::Supertype), ..Default::default() }, } @@ -1063,7 +1053,7 @@ impl Expr { function: FunctionExpr::FillNull, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), ..Default::default() }, } diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index b3ba2861a115..c37f792ce797 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -271,7 +271,8 @@ impl OptimizationRule for TypeCoercionRule { ref function, ref input, mut options, - } if options.cast_options.supertype.is_some() => { + } if options.cast_options.is_some() => { + let casting_rules = options.cast_options.unwrap(); let input_schema = get_schema(lp_arena, lp_node); let function = function.clone(); @@ -280,33 +281,54 @@ impl OptimizationRule for TypeCoercionRule { if let Some(dtypes) = functions::get_function_dtypes(&input, expr_arena, &input_schema, &function)? { - // TODO! use args_to_supertype. let self_e = input[0].clone(); let (self_ae, type_self) = unpack!(get_aexpr_and_type(expr_arena, self_e.node(), &input_schema)); - let mut super_type = type_self.clone(); - for other in &input[1..] { - let (other, type_other) = - unpack!(get_aexpr_and_type(expr_arena, other.node(), &input_schema)); - - let Some(new_st) = get_supertype_with_options( - &super_type, - &type_other, - options.cast_options.supertype.unwrap(), - ) else { - raise_supertype(&function, &input, &input_schema, expr_arena)?; - unreachable!() - }; - if input.len() == 2 { - // modify_supertype is a bit more conservative of casting columns - // to literals - super_type = - modify_supertype(new_st, self_ae, other, &type_self, &type_other) - } else { - // when dealing with more than 1 argument, we simply find the supertypes - super_type = new_st - } + match casting_rules { + CastingRules::Supertype(super_type_opts) => { + for other in &input[1..] { + let (other, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + other.node(), + &input_schema + )); + + let Some(new_st) = get_supertype_with_options( + &super_type, + &type_other, + super_type_opts, + ) else { + raise_supertype(&function, &input, &input_schema, expr_arena)?; + unreachable!() + }; + if input.len() == 2 { + // modify_supertype is a bit more conservative of casting columns + // to literals + super_type = modify_supertype( + new_st, + self_ae, + other, + &type_self, + &type_other, + ) + } else { + // when dealing with more than 1 argument, we simply find the supertypes + super_type = new_st + } + } + }, + CastingRules::FirstArgLossless => { + if super_type.is_integer() { + for other in &input[1..] { + let other = + other.dtype(&input_schema, Context::Default, expr_arena)?; + if other.is_float() { + polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other) + } + } + } + }, } if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { @@ -334,7 +356,7 @@ impl OptimizationRule for TypeCoercionRule { } // Ensure we don't go through this on next iteration. - options.cast_options.supertype = None; + options.cast_options = None; Some(AExpr::Function { function, input, diff --git a/crates/polars-plan/src/plans/optimizer/fused.rs b/crates/polars-plan/src/plans/optimizer/fused.rs index 185b8312da3e..5009a9fa061d 100644 --- a/crates/polars-plan/src/plans/optimizer/fused.rs +++ b/crates/polars-plan/src/plans/optimizer/fused.rs @@ -10,7 +10,7 @@ fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena) -> AEx .collect(); let mut options = FunctionOptions { collect_groups: ApplyOptions::ElementWise, - cast_options: FunctionCastOptions::cast_to_supertypes(), + cast_options: Some(CastingRules::cast_to_supertypes()), ..Default::default() }; // order of operations change because of FMA diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 4cb3ae064b75..a194a0f5434d 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -179,30 +179,18 @@ impl Default for FunctionFlags { } } -bitflags::bitflags! { - #[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash)] - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] - pub struct FunctionCastFlags: u8 {} -} - -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FunctionCastOptions { - pub flags: FunctionCastFlags, - - // if the expression and its inputs should be cast to supertypes - // `None` -> Don't cast. - // `Some` -> cast with given options. - #[cfg_attr(feature = "serde", serde(skip))] - pub supertype: Option, +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum CastingRules { + /// Whether information may be lost during cast. E.g. a float to int is considered lossy, + /// whereas int to int is considered lossless. + /// Overflowing is not considered in this flag, that's handled in `strict` casting + FirstArgLossless, + Supertype(SuperTypeOptions), } -impl FunctionCastOptions { - pub fn cast_to_supertypes() -> FunctionCastOptions { - Self { - supertype: Some(Default::default()), - ..Default::default() - } +impl CastingRules { + pub fn cast_to_supertypes() -> CastingRules { + Self::Supertype(Default::default()) } } @@ -213,9 +201,6 @@ pub struct FunctionOptions { /// This can be important in aggregation context. pub collect_groups: ApplyOptions, - /// Options used when deciding how to cast the arguments of the function. - pub cast_options: FunctionCastOptions, - // Validate the output of a `map`. // this should always be true or we could OOB pub check_lengths: UnsafeBool, @@ -224,6 +209,9 @@ pub struct FunctionOptions { // used for formatting, (only for anonymous functions) #[cfg_attr(feature = "serde", serde(skip))] pub fmt_str: &'static str, + /// Options used when deciding how to cast the arguments of the function. + #[cfg_attr(feature = "serde", serde(skip))] + pub cast_options: Option, } impl FunctionOptions { diff --git a/crates/polars-python/src/functions/misc.rs b/crates/polars-python/src/functions/misc.rs index ae65273ef700..f6f9db254a54 100644 --- a/crates/polars-python/src/functions/misc.rs +++ b/crates/polars-python/src/functions/misc.rs @@ -33,7 +33,7 @@ pub fn register_plugin_function( }; let cast_to_supertypes = if cast_to_supertype { - Some(Default::default()) + Some(CastingRules::cast_to_supertypes()) } else { None }; @@ -56,10 +56,7 @@ pub fn register_plugin_function( }, options: FunctionOptions { collect_groups, - cast_options: FunctionCastOptions { - supertype: cast_to_supertypes, - ..Default::default() - }, + cast_options: cast_to_supertypes, flags, ..Default::default() },