Skip to content

Commit

Permalink
fix(rust,python): streamline is_in handling of mismatched dtypes an…
Browse files Browse the repository at this point in the history
…d fix a minor regression (#11533)
  • Loading branch information
alexander-beedie authored Oct 5, 2023
1 parent 29c1269 commit 205355c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 51 deletions.
85 changes: 35 additions & 50 deletions crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,77 +350,62 @@ impl OptimizationRule for TypeCoercionRule {
let casted_expr = match (&type_left, &type_other) {
// types are equal, do nothing
(a, b) if a == b => return Ok(None),
// all-null can represent anything (and/or empty list), so cast to target dtype
(_, DataType::Null) => AExpr::Cast {
expr: other_node,
data_type: type_left,
strict: false,
},
// cast both local and global string cache
// note that there might not yet be a rev
#[cfg(feature = "dtype-categorical")]
(DataType::Categorical(_), DataType::Utf8) => {
AExpr::Cast {
expr: other_node,
data_type: DataType::Categorical(None),
strict: false,
}
(DataType::Categorical(_), DataType::Utf8) => AExpr::Cast {
expr: other_node,
data_type: DataType::Categorical(None),
strict: false,
},
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
},
// can't check for more granular time_unit in less-granular time_unit data,
// can't check for more granular time_unit in less-granular time_unit data,
// or we'll cast away valid/necessary precision (eg: nanosecs to millisecs)
(DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => {
if lhs_unit <= rhs_unit { return Ok(None) }
else {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit)
if lhs_unit <= rhs_unit {
return Ok(None);
} else {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit)
}
},
(DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => {
if lhs_unit <= rhs_unit { return Ok(None) }
else {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit)
if lhs_unit <= rhs_unit {
return Ok(None);
} else {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit)
}
},
// don't attempt to cast between obviously mismatched types;
// we should error early/explicitly on invalid comparisons
(
_,
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Date
| DataType::Time
| DataType::Boolean
| DataType::Binary
| DataType::Utf8,
)
| (
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Date
| DataType::Time
| DataType::Boolean
| DataType::Binary
| DataType::Utf8,
_,
) => {
match type_other {
// all-null can represent anything (and/or empty list), so cast to target dtype
DataType::Null => AExpr::Cast {expr: other_node, data_type: type_left, strict: false},
_ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
(_, DataType::List(other_inner)) => {
if other_inner.as_ref() == &type_left
|| (type_left == DataType::Null)
|| (other_inner.as_ref() == &DataType::Null)
|| (other_inner.as_ref().is_numeric() && type_left.is_numeric())
{
return Ok(None);
}
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_left, &type_other)
},
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None),
(DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None),
// if rhs is another type, we cast it to lhs (we do not use supertype
// as `is_in` operation should not implicitly cast the whole column)
(a, b)
// for integer/float comparison we let them use supertypes.
if !(a.is_integer() && b.is_float()) =>
{
AExpr::Cast {expr: other_node, data_type: type_left, strict: false }

// don't attempt to cast between obviously mismatched types, but
// allow integer/float comparison (will use their supertypes).
(a, b) => {
if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) {
return Ok(None);
}
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
},
// do nothing
_ => return Ok(None),
};

let mut input = input.clone();
let other_input = expr_arena.add(casted_expr);
input[1] = other_input;
Expand Down
45 changes: 44 additions & 1 deletion py-polars/tests/unit/operations/test_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def test_is_in_float_list_10764() -> None:
"n": [3.0, 2.0],
}
)

assert df.select(pl.col("n").is_in("lst").alias("is_in")).to_dict(False) == {
"is_in": [True, False]
}
Expand Down Expand Up @@ -165,3 +164,47 @@ def test_is_in_null() -> None:
def test_is_in_invalid_shape() -> None:
with pytest.raises(pl.ComputeError):
pl.Series("a", [1, 2, 3]).is_in([[]])


@pytest.mark.parametrize(
("df", "matches", "expected_error"),
[
(
pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}),
[True, False],
None,
),
(
pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}),
[False, True],
None,
),
(
pl.DataFrame(
{"a": [None, None], "b": [[1, 2], [3, 4]]},
schema_overrides={"a": pl.Null},
),
[None, None],
None,
),
(
pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}),
None,
r"`is_in` cannot check for Utf8 values in List\(Int64\) data",
),
(
pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}),
None,
r"`is_in` cannot check for Date values in List\(Int64\) data",
),
],
)
def test_is_in_expr_list_series(
df: pl.DataFrame, matches: list[bool] | None, expected_error: str | None
) -> None:
expr_is_in = pl.col("a").is_in(pl.col("b"))
if matches:
assert df.select(expr_is_in).to_series().to_list() == matches
else:
with pytest.raises(pl.InvalidOperationError, match=expected_error):
df.select(expr_is_in)

0 comments on commit 205355c

Please sign in to comment.