From f958881047be9a33efeeb903d600c5d8f5b98b4f Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 30 Sep 2024 14:11:00 +0200 Subject: [PATCH] fix: Raise invalid predicate join_where (#19020) --- .../src/plans/conversion/dsl_to_ir.rs | 8 +++-- .../polars-plan/src/plans/conversion/join.rs | 36 +++++++++++++------ .../unit/operations/test_inequality_join.py | 8 +++++ 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 25a605124496..3f17e83eefb5 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -401,7 +401,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult options, }; - return run_conversion(lp, ctxt, "select"); + return run_conversion(lp, ctxt, "select").map_err(|e| e.context(failed_here!(select))); }, DslPlan::Sort { input, @@ -473,7 +473,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult sort_options, }; - return run_conversion(lp, ctxt, "sort"); + return run_conversion(lp, ctxt, "sort").map_err(|e| e.context(failed_here!(sort))); }, DslPlan::Cache { input, id } => { let input = @@ -527,7 +527,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult options, }; - return run_conversion(lp, ctxt, "group_by"); + return run_conversion(lp, ctxt, "group_by") + .map_err(|e| e.context(failed_here!(group_by))); }, DslPlan::Join { input_left, @@ -546,6 +547,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult options, ctxt, ) + .map_err(|e| e.context(failed_here!(join))) }, DslPlan::HStack { input, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 21474c992eb4..954714f13b26 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ 
b/crates/polars-plan/src/plans/conversion/join.rs @@ -152,16 +152,6 @@ fn resolve_join_where( ctxt: &mut DslConversionContext, ) -> PolarsResult<Node> { check_join_keys(&predicates)?; - for e in &predicates { - let no_binary_comparisons = e - .into_iter() - .filter(|e| match e { - Expr::BinaryExpr { op, .. } => op.is_comparison(), - _ => false, - }) - .count(); - polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition"); - } let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt) .map_err(|e| e.context(failed_input!(join left)))?; let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt) @@ -174,6 +164,32 @@ fn resolve_join_where( .schema(ctxt.lp_arena) .into_owned(); + for e in &predicates { + let no_binary_comparisons = e + .into_iter() + .filter(|e| match e { + Expr::BinaryExpr { op, .. } => op.is_comparison(), + _ => false, + }) + .count(); + polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition"); + + fn all_in_schema(schema: &Schema, left: &Expr, right: &Expr) -> bool { + let mut iter = + expr_to_leaf_column_names_iter(left).chain(expr_to_leaf_column_names_iter(right)); + iter.all(|name| schema.contains(name.as_str())) + } + + let valid = e.into_iter().all(|e| match e { + Expr::BinaryExpr { left, op, right } if op.is_comparison() => { + !(all_in_schema(&schema_left, left, right) + || all_in_schema(&schema_right, left, right)) + }, + _ => true, + }); + polars_ensure!( valid, InvalidOperation: "join predicate in 'join_where' only refers to columns of a single table") + } + let owned = |e: Arc<Expr>| (*e).clone(); // We do a few things diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 89a81a4c0923..18a24c1d8a7d 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -567,3
+567,11 @@ def test_ie_join_projection_pd_19005() -> None: [("index", pl.get_index_type()), ("index_right", pl.List(pl.get_index_type()))] ) assert out.shape == (0, 2) + + +def test_raise_invalid_predicate() -> None: + left = pl.LazyFrame({"a": [1, 2]}).with_row_index() + right = pl.LazyFrame({"b": [1, 2]}).with_row_index() + + with pytest.raises(pl.exceptions.InvalidOperationError): + left.join_where(right, pl.col.index >= pl.col.a).collect()