Skip to content

Commit

Permalink
feat: Add FirstArgLossless supertype (#20394)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 21, 2024
1 parent 5f791b4 commit dbd2c5f
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 86 deletions.
10 changes: 10 additions & 0 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
)
}

pub fn try_get_supertype_with_options(
l: &DataType,
r: &DataType,
options: SuperTypeOptions,
) -> PolarsResult<DataType> {
get_supertype_with_options(l, r, options).ok_or_else(
|| polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r),
)
}

/// Returns a numeric supertype that `l` and `r` can be safely upcasted to if it exists.
pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option<DataType> {
use DataType::*;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
..Default::default()
},
})
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand All @@ -29,7 +29,7 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand All @@ -54,7 +54,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
..Default::default()
},
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub fn datetime_range(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
Expand Down Expand Up @@ -118,7 +118,7 @@ pub fn datetime_ranges(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
Expand Down
10 changes: 4 additions & 6 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,10 @@ impl ListNameSpace {
function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_options: FunctionCastOptions {
supertype: Some(SuperTypeOptions {
flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
}),
..Default::default()
},
cast_options: Some(CastingRules::Supertype(SuperTypeOptions {
flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
})),

flags: FunctionFlags::default() & !FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand Down
24 changes: 7 additions & 17 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,9 @@ impl Expr {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
fmt_str: "search_sorted",
cast_options: FunctionCastOptions {
supertype: Some(
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING)
.into(),
),
..Default::default()
},
cast_options: Some(CastingRules::Supertype(
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),
)),
..Default::default()
},
}
Expand Down Expand Up @@ -714,7 +710,7 @@ impl Expr {
input.extend_from_slice(arguments);

let supertype = if cast_to_supertypes {
Some(Default::default())
Some(CastingRules::cast_to_supertypes())
} else {
None
};
Expand All @@ -730,10 +726,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags,
cast_options: FunctionCastOptions {
supertype,
..Default::default()
},
cast_options: supertype,
..Default::default()
},
}
Expand Down Expand Up @@ -761,10 +754,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags,
cast_options: FunctionCastOptions {
supertype: cast_to_supertypes,
..Default::default()
},
cast_options: cast_to_supertypes.map(CastingRules::Supertype),
..Default::default()
},
}
Expand Down Expand Up @@ -1063,7 +1053,7 @@ impl Expr {
function: FunctionExpr::FillNull,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
..Default::default()
},
}
Expand Down
72 changes: 47 additions & 25 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ impl OptimizationRule for TypeCoercionRule {
ref function,
ref input,
mut options,
} if options.cast_options.supertype.is_some() => {
} if options.cast_options.is_some() => {
let casting_rules = options.cast_options.unwrap();
let input_schema = get_schema(lp_arena, lp_node);

let function = function.clone();
Expand All @@ -280,33 +281,54 @@ impl OptimizationRule for TypeCoercionRule {
if let Some(dtypes) =
functions::get_function_dtypes(&input, expr_arena, &input_schema, &function)?
{
// TODO! use args_to_supertype.
let self_e = input[0].clone();
let (self_ae, type_self) =
unpack!(get_aexpr_and_type(expr_arena, self_e.node(), &input_schema));

let mut super_type = type_self.clone();
for other in &input[1..] {
let (other, type_other) =
unpack!(get_aexpr_and_type(expr_arena, other.node(), &input_schema));

let Some(new_st) = get_supertype_with_options(
&super_type,
&type_other,
options.cast_options.supertype.unwrap(),
) else {
raise_supertype(&function, &input, &input_schema, expr_arena)?;
unreachable!()
};
if input.len() == 2 {
// modify_supertype is a bit more conservative of casting columns
// to literals
super_type =
modify_supertype(new_st, self_ae, other, &type_self, &type_other)
} else {
// when dealing with more than 1 argument, we simply find the supertypes
super_type = new_st
}
match casting_rules {
CastingRules::Supertype(super_type_opts) => {
for other in &input[1..] {
let (other, type_other) = unpack!(get_aexpr_and_type(
expr_arena,
other.node(),
&input_schema
));

let Some(new_st) = get_supertype_with_options(
&super_type,
&type_other,
super_type_opts,
) else {
raise_supertype(&function, &input, &input_schema, expr_arena)?;
unreachable!()
};
if input.len() == 2 {
// modify_supertype is a bit more conservative of casting columns
// to literals
super_type = modify_supertype(
new_st,
self_ae,
other,
&type_self,
&type_other,
)
} else {
// when dealing with more than 1 argument, we simply find the supertypes
super_type = new_st
}
}
},
CastingRules::FirstArgLossless => {
if super_type.is_integer() {
for other in &input[1..] {
let other =
other.dtype(&input_schema, Context::Default, expr_arena)?;
if other.is_float() {
polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other)
}
}
}
},
}

if matches!(super_type, DataType::Unknown(UnknownKind::Any)) {
Expand Down Expand Up @@ -334,7 +356,7 @@ impl OptimizationRule for TypeCoercionRule {
}

// Ensure we don't go through this on next iteration.
options.cast_options.supertype = None;
options.cast_options = None;
Some(AExpr::Function {
function,
input,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AEx
.collect();
let mut options = FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_options: FunctionCastOptions::cast_to_supertypes(),
cast_options: Some(CastingRules::cast_to_supertypes()),
..Default::default()
};
// order of operations change because of FMA
Expand Down
38 changes: 13 additions & 25 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,30 +179,18 @@ impl Default for FunctionFlags {
}
}

bitflags::bitflags! {
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FunctionCastFlags: u8 {}
}

#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FunctionCastOptions {
pub flags: FunctionCastFlags,

// if the expression and its inputs should be cast to supertypes
// `None` -> Don't cast.
// `Some` -> cast with given options.
#[cfg_attr(feature = "serde", serde(skip))]
pub supertype: Option<SuperTypeOptions>,
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum CastingRules {
/// Whether information may be lost during cast. E.g. a float to int is considered lossy,
/// whereas int to int is considered lossless.
/// Overflowing is not considered in this flag, that's handled in `strict` casting
FirstArgLossless,
Supertype(SuperTypeOptions),
}

impl FunctionCastOptions {
pub fn cast_to_supertypes() -> FunctionCastOptions {
Self {
supertype: Some(Default::default()),
..Default::default()
}
impl CastingRules {
pub fn cast_to_supertypes() -> CastingRules {
Self::Supertype(Default::default())
}
}

Expand All @@ -213,9 +201,6 @@ pub struct FunctionOptions {
/// This can be important in aggregation context.
pub collect_groups: ApplyOptions,

/// Options used when deciding how to cast the arguments of the function.
pub cast_options: FunctionCastOptions,

// Validate the output of a `map`.
// this should always be true or we could OOB
pub check_lengths: UnsafeBool,
Expand All @@ -224,6 +209,9 @@ pub struct FunctionOptions {
// used for formatting, (only for anonymous functions)
#[cfg_attr(feature = "serde", serde(skip))]
pub fmt_str: &'static str,
/// Options used when deciding how to cast the arguments of the function.
#[cfg_attr(feature = "serde", serde(skip))]
pub cast_options: Option<CastingRules>,
}

impl FunctionOptions {
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-python/src/functions/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub fn register_plugin_function(
};

let cast_to_supertypes = if cast_to_supertype {
Some(Default::default())
Some(CastingRules::cast_to_supertypes())
} else {
None
};
Expand All @@ -56,10 +56,7 @@ pub fn register_plugin_function(
},
options: FunctionOptions {
collect_groups,
cast_options: FunctionCastOptions {
supertype: cast_to_supertypes,
..Default::default()
},
cast_options: cast_to_supertypes,
flags,
..Default::default()
},
Expand Down

0 comments on commit dbd2c5f

Please sign in to comment.