Skip to content

Commit

Permalink
fix: Raise invalid predicate join_where (#19020)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 30, 2024
1 parent ab5200d commit f958881
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
8 changes: 5 additions & 3 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
options,
};

return run_conversion(lp, ctxt, "select");
return run_conversion(lp, ctxt, "select").map_err(|e| e.context(failed_here!(select)));
},
DslPlan::Sort {
input,
Expand Down Expand Up @@ -473,7 +473,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
sort_options,
};

return run_conversion(lp, ctxt, "sort");
return run_conversion(lp, ctxt, "sort").map_err(|e| e.context(failed_here!(sort)));
},
DslPlan::Cache { input, id } => {
let input =
Expand Down Expand Up @@ -527,7 +527,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
options,
};

return run_conversion(lp, ctxt, "group_by");
return run_conversion(lp, ctxt, "group_by")
.map_err(|e| e.context(failed_here!(group_by)));
},
DslPlan::Join {
input_left,
Expand All @@ -546,6 +547,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
options,
ctxt,
)
.map_err(|e| e.context(failed_here!(join)))
},
DslPlan::HStack {
input,
Expand Down
36 changes: 26 additions & 10 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,6 @@ fn resolve_join_where(
ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
check_join_keys(&predicates)?;
for e in &predicates {
let no_binary_comparisons = e
.into_iter()
.filter(|e| match e {
Expr::BinaryExpr { op, .. } => op.is_comparison(),
_ => false,
})
.count();
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition");
}
let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
.map_err(|e| e.context(failed_input!(join left)))?;
let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
Expand All @@ -174,6 +164,32 @@ fn resolve_join_where(
.schema(ctxt.lp_arena)
.into_owned();

for e in &predicates {
let no_binary_comparisons = e
.into_iter()
.filter(|e| match e {
Expr::BinaryExpr { op, .. } => op.is_comparison(),
_ => false,
})
.count();
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition");

/// Returns `true` when every name yielded by `expr_to_leaf_column_names_iter`
/// for `left` and for `right` is contained in `schema`.
///
/// If neither expression yields any name the result is `true`, because
/// `Iterator::all` is vacuously true on an empty iterator.
fn all_in_schema(schema: &Schema, left: &Expr, right: &Expr) -> bool {
    expr_to_leaf_column_names_iter(left)
        .chain(expr_to_leaf_column_names_iter(right))
        .all(|column| schema.contains(column.as_str()))
}

let valid = e.into_iter().all(|e| match e {
Expr::BinaryExpr { left, op, right } if op.is_comparison() => {
!(all_in_schema(&schema_left, left, right)
|| all_in_schema(&schema_right, left, right))
},
_ => true,
});
polars_ensure!( valid, InvalidOperation: "join predicate in 'join_where' only refers to columns of a single table")
}

let owned = |e: Arc<Expr>| (*e).clone();

// We do a few things
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,11 @@ def test_ie_join_projection_pd_19005() -> None:
[("index", pl.get_index_type()), ("index_right", pl.List(pl.get_index_type()))]
)
assert out.shape == (0, 2)


def test_raise_invalid_predicate() -> None:
    """`join_where` with the predicate `index >= a` must raise InvalidOperationError."""
    lf_left = pl.LazyFrame({"a": [1, 2]}).with_row_index()
    lf_right = pl.LazyFrame({"b": [1, 2]}).with_row_index()

    # Build the predicate up front; only the join and collect are expected to raise.
    predicate = pl.col.index >= pl.col.a

    with pytest.raises(pl.exceptions.InvalidOperationError):
        lf_left.join_where(lf_right, predicate).collect()

0 comments on commit f958881

Please sign in to comment.