Skip to content

Commit

Permalink
refactor: Add FunctionCastOptions and conservative IR-level cast ty…
Browse files Browse the repository at this point in the history
…pe-checking (#20286)
  • Loading branch information
coastalwhite authored Dec 14, 2024
1 parent 286eb84 commit bcca075
Show file tree
Hide file tree
Showing 14 changed files with 315 additions and 197 deletions.
59 changes: 59 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,65 @@ impl DataType {
}
}

/// Return whether the cast to `to` makes sense.
///
/// If it `None`, we are not sure.
pub fn can_cast_to(&self, to: &DataType) -> Option<bool> {
if self == to {
return Some(true);
}
if self.is_numeric() && to.is_numeric() {
return Some(true);
}

use DataType as D;
Some(match (self, to) {
#[cfg(feature = "dtype-categorical")]
(D::Categorical(_, _) | D::Enum(_, _), D::Binary)
| (D::Binary, D::Categorical(_, _) | D::Enum(_, _)) => false,

#[cfg(feature = "object")]
(D::Object(_, _), D::Object(_, _)) => true,
#[cfg(feature = "object")]
(D::Object(_, _), _) | (_, D::Object(_, _)) => false,

(D::Boolean, dt) | (dt, D::Boolean) => match dt {
dt if dt.is_numeric() => true,
#[cfg(feature = "dtype-decimal")]
D::Decimal(_, _) => true,
D::String | D::Binary => true,
_ => false,
},

(D::List(from), D::List(to)) => from.can_cast_to(to)?,
#[cfg(feature = "dtype-array")]
(D::Array(from, l_width), D::Array(to, r_width)) => {
l_width == r_width && from.can_cast_to(to)?
},
#[cfg(feature = "dtype-struct")]
(D::Struct(l_fields), D::Struct(r_fields)) => {
if l_fields.is_empty() {
return Some(true);
}

if l_fields.len() != r_fields.len() {
return Some(false);
}

for (l, r) in l_fields.iter().zip(r_fields) {
if !l.dtype().can_cast_to(r.dtype())? {
return Some(false);
}
}

true
},

// @NOTE: we are being conversative
_ => return None,
})
}

pub fn implode(self) -> DataType {
DataType::List(Box::new(self))
}
Expand Down
5 changes: 1 addition & 4 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
fmt_str: "",
cast_to_supertypes: None,
check_lengths: Default::default(),
flags: Default::default(),
..Default::default()
},
state.allow_threading,
schema.clone(),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
..Default::default()
},
})
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand All @@ -29,7 +29,7 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand All @@ -54,7 +54,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
function,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ pub fn sum_horizontal<E: AsRef<[Expr]>>(exprs: E, ignore_nulls: bool) -> PolarsR
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
cast_to_supertypes: None,
..Default::default()
},
})
Expand All @@ -303,7 +302,6 @@ pub fn mean_horizontal<E: AsRef<[Expr]>>(exprs: E, ignore_nulls: bool) -> Polars
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
cast_to_supertypes: None,
..Default::default()
},
})
Expand All @@ -319,8 +317,8 @@ pub fn coalesce(exprs: &[Expr]) -> Expr {
function: FunctionExpr::Coalesce,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: Some(Default::default()),
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
cast_options: FunctionCastOptions::cast_to_supertypes(),
..Default::default()
},
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub fn datetime_range(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
Expand Down Expand Up @@ -118,7 +118,7 @@ pub fn datetime_ranges(
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
Expand Down
9 changes: 6 additions & 3 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,12 @@ impl ListNameSpace {
function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: Some(SuperTypeOptions {
flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
}),
cast_options: FunctionCastOptions {
supertype: Some(SuperTypeOptions {
flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
}),
..Default::default()
},
flags: FunctionFlags::default() & !FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
Expand Down
24 changes: 17 additions & 7 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,13 @@ impl Expr {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
fmt_str: "search_sorted",
cast_to_supertypes: Some(
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),
),
cast_options: FunctionCastOptions {
supertype: Some(
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING)
.into(),
),
..Default::default()
},
..Default::default()
},
}
Expand Down Expand Up @@ -709,7 +713,7 @@ impl Expr {
input.push(self);
input.extend_from_slice(arguments);

let cast_to_supertypes = if cast_to_supertypes {
let supertype = if cast_to_supertypes {
Some(Default::default())
} else {
None
Expand All @@ -726,7 +730,10 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags,
cast_to_supertypes,
cast_options: FunctionCastOptions {
supertype,
..Default::default()
},
..Default::default()
},
}
Expand Down Expand Up @@ -754,7 +761,10 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags,
cast_to_supertypes,
cast_options: FunctionCastOptions {
supertype: cast_to_supertypes,
..Default::default()
},
..Default::default()
},
}
Expand Down Expand Up @@ -1053,7 +1063,7 @@ impl Expr {
function: FunctionExpr::FillNull,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: Some(Default::default()),
cast_options: FunctionCastOptions::cast_to_supertypes(),
..Default::default()
},
}
Expand Down
26 changes: 8 additions & 18 deletions crates/polars-plan/src/plans/conversion/type_coercion/functions.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
use either::Either;

use super::*;

/// Get the datatypes of function arguments.
///
/// If all arguments give the same datatype or a datatype cannot be found, `Ok(None)` is returned.
pub(super) fn get_function_dtypes(
input: &[ExprIR],
expr_arena: &Arena<AExpr>,
input_schema: &Schema,
function: &FunctionExpr,
mut options: FunctionOptions,
) -> PolarsResult<Either<Vec<DataType>, AExpr>> {
let mut early_return = move || {
// Next iteration this will not hit anymore as options is updated.
options.cast_to_supertypes = None;
Ok(Either::Right(AExpr::Function {
function: function.clone(),
input: input.to_vec(),
options,
}))
};

) -> PolarsResult<Option<Vec<DataType>>> {
let mut dtypes = Vec::with_capacity(input.len());
let mut first = true;
for e in input {
let Some((_, dtype)) = get_aexpr_and_type(expr_arena, e.node(), input_schema) else {
return early_return();
return Ok(None);
};

if first {
Expand All @@ -34,16 +24,16 @@ pub(super) fn get_function_dtypes(
// We will raise if we cannot find the supertype later.
match dtype {
DataType::Unknown(UnknownKind::Any) => {
return early_return();
return Ok(None);
},
_ => dtypes.push(dtype),
}
}

if dtypes.iter().all_equal() {
return early_return();
return Ok(None);
}
Ok(Either::Left(dtypes))
Ok(Some(dtypes))
}

// `str` namespace belongs to `String`
Expand Down
Loading

0 comments on commit bcca075

Please sign in to comment.