Skip to content

Commit

Permalink
fix: Fix projection pushdown bug in IEJOINS
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 30, 2024
1 parent e402e70 commit 0932842
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 12 deletions.
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2091,7 +2091,7 @@ impl JoinBuilder {
/// Finish builder
pub fn finish(self) -> LazyFrame {
let mut opt_state = self.lf.opt_state;
let other = self.other.expect("with not set");
let other = self.other.expect("'with' not set in join builder");

// If any of the nodes reads from files we must activate this plan as well.
if other.opt_state.contains(OptFlags::FILE_CACHING) {
Expand Down
23 changes: 15 additions & 8 deletions crates/polars-lazy/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,24 @@ pub(crate) fn get_df() -> DataFrame {
#[test]
fn test_foo() -> PolarsResult<()> {
let df = df![
"A" => [1],
"B" => [1],
"a" => [1],
"b" => [1],
]?;

let q = df.lazy();
let q = q.with_row_index("index", None);
let q = q
.clone()
.join_builder()
.with(q)
.join_where(vec![
col("index").lt(col("index_right")),
(col("index") + col("a")).gt(col("a_right")),
])
.group_by([col("index")])
.agg([col("index_right")])
.collect()
.unwrap();

let out = q
.group_by([col("A")])
.agg([cols(["A", "B"]).name().prefix("_agg")])
.explain(false)?;

println!("{out}");
Ok(())
}
11 changes: 11 additions & 0 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ impl JoinType {
false
}
}

/// Returns `true` when this join type is the `IEJoin` variant.
///
/// The variant is only matched when the `iejoin` feature is compiled in;
/// without that feature this always returns `false`.
pub fn is_ie(&self) -> bool {
    // The whole statement disappears when the feature is off, leaving only
    // the trailing `false`.
    #[cfg(feature = "iejoin")]
    if let JoinType::IEJoin(_) = self {
        return true;
    }
    false
}
}

#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-plan/src/plans/ir/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ impl<'a, T: AsExpr> Display for ExprIRSliceDisplay<'a, T> {
}
}

/// `Debug` for the expression-slice display wrapper simply reuses its
/// `Display` implementation, so `{:?}` and `{}` render identically.
impl<T: AsExpr> fmt::Debug for ExprIRSliceDisplay<'_, T> {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        // Forward the formatter untouched so any flags set by the caller
        // reach the `Display` implementation unchanged.
        <Self as Display>::fmt(self, formatter)
    }
}

impl<'a> Display for ExprIRDisplay<'a> {
#[recursive]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -657,6 +663,12 @@ impl<'a> Display for ExprIRDisplay<'a> {
}
}

/// `Debug` for the single-expression display wrapper delegates to its
/// `Display` implementation, so `{:?}` and `{}` render identically.
impl fmt::Debug for ExprIRDisplay<'_> {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        // Forward the formatter untouched so any flags set by the caller
        // reach the `Display` implementation unchanged.
        <Self as Display>::fmt(self, formatter)
    }
}

pub(crate) struct ColumnsDisplay<'a>(pub(crate) &'a Schema);

impl fmt::Display for ColumnsDisplay<'_> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ fn add_keys_to_accumulated_state(
local_projection: &mut Vec<ColumnNode>,
projected_names: &mut PlHashSet<PlSmallStr>,
expr_arena: &mut Arena<AExpr>,
// only for left hand side table we add local names
// Only for left hand side table we add local names.
add_local: bool,
) -> Option<PlSmallStr> {
add_expr_to_accumulated(expr, acc_projections, projected_names, expr_arena);
// the projections may do more than simply project.
// The projections may do more than simply project.
// e.g. col("foo").truncate() * col("bar")
// that means we don't want to execute the projection as that is already done by
// the JOIN executor
Expand Down Expand Up @@ -234,7 +234,7 @@ pub(super) fn process_join(
let mut names_right = PlHashSet::with_capacity(n);
let mut local_projection = Vec::with_capacity(n);

// if there are no projections we don't have to do anything (all columns are projected)
// If there are no projections we don't have to do anything (all columns are projected)
// otherwise we build local projections to sort out proper column names due to the
// join operation
//
Expand All @@ -253,6 +253,16 @@ pub(super) fn process_join(
// We need the join columns so we push the projection downwards
for e in &left_on {
if !local_projected_names.insert(e.output_name().clone()) {
// An IE join key expression can reference multiple leaf columns, so we must still ensure all leaf names are projected.
if options.args.how.is_ie() {
add_expr_to_accumulated(
e.node(),
&mut pushdown_left,
&mut names_left,
expr_arena,
);
}

continue;
}

Expand Down Expand Up @@ -300,9 +310,12 @@ pub(super) fn process_join(
// In both columns remain. So `add_local=true` also for the right table
let add_local = !options.args.should_coalesce();
for e in &right_on {
dbg!(e.display(expr_arena));

// In case of full outer joins we also add the columns.
// But before we do that we must check if the column wasn't already added by the lhs.
let add_local = if add_local {
dbg!("CONTINUE");
!local_projected_names.contains(e.output_name())
} else {
false
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,22 @@ def test_single_inequality_with_slice(offset: int, length: int) -> None:
expected_rows = set(expected_full.iter_rows())
for row in actual.iter_rows():
assert row in expected_rows, f"{row} not in expected rows"


def test_ie_join_projection_pd_19005() -> None:
    """Check projection pushdown through an inequality self-join.

    A lazy frame with a row index is joined to itself on two inequality
    predicates, one of which combines two columns on the left-hand side.
    Only `index` and `index_right` are used after the join. With this data
    no row pair satisfies both predicates, so the result is empty but must
    still have the expected schema.
    """
    index_dtype = pl.get_index_type()
    frame = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).with_row_index()

    predicates = [
        pl.col.index < pl.col.index_right,
        pl.col.index.cast(pl.Int64) + pl.col.a > pl.col.a_right,
    ]
    joined = frame.join_where(frame, *predicates)
    result = joined.group_by(pl.col.index).agg(pl.col.index_right).collect()

    expected_schema = pl.Schema(
        [("index", index_dtype), ("index_right", pl.List(index_dtype))]
    )
    assert result.schema == expected_schema
    assert result.shape == (0, 2)

0 comments on commit 0932842

Please sign in to comment.