diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 748b20a8740c..72e5d30e3541 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -184,17 +184,6 @@ impl ExprIR { } } - /// Gets any name except one deriving from `Column`. - pub(crate) fn get_non_projected_name(&self) -> Option<&PlSmallStr> { - match &self.output_name { - OutputName::Alias(name) => Some(name), - #[cfg(feature = "dtype-struct")] - OutputName::Field(name) => Some(name), - OutputName::LiteralLhs(name) => Some(name), - _ => None, - } - } - // Utility for debugging. #[cfg(debug_assertions)] #[allow(dead_code)] diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs index 79b461cb90aa..97066a306777 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs @@ -5,45 +5,6 @@ pub(super) fn is_count(node: Node, expr_arena: &Arena) -> bool { matches!(expr_arena.get(node), AExpr::Len) } -/// In this function we check a double projection case -/// df -/// .select(col("foo").alias("bar")) -/// .select(col("bar") -/// -/// In this query, bar cannot pass this projection, as it would not exist in DF. -/// THE ORDER IS IMPORTANT HERE! -/// this removes projection names, so any checks to upstream names should -/// be done before this branch. -fn check_double_projection( - expr: &ExprIR, - expr_arena: &mut Arena, - acc_projections: &mut Vec, - projected_names: &mut PlHashSet, -) { - // Factor out the pruning function - fn prune_projections_by_name( - acc_projections: &mut Vec, - name: &str, - expr_arena: &Arena, - ) { - acc_projections.retain(|node| column_node_to_name(*node, expr_arena) != name); - } - if let Some(name) = expr.get_non_projected_name() { - if projected_names.remove(name) { - prune_projections_by_name(acc_projections, name.as_ref(), expr_arena) - } - } - - for (_, ae) in (&*expr_arena).iter(expr.node()) { - if let AExpr::Literal(LiteralValue::Series(s)) = ae { - let name = s.name(); - if projected_names.remove(name) { - prune_projections_by_name(acc_projections, name, expr_arena) - } - } - } -} - #[allow(clippy::too_many_arguments)] pub(super) fn process_projection( proj_pd: &mut ProjectionPushDown, @@ -59,7 +20,7 @@ pub(super) fn process_projection( ) -> PolarsResult { let mut local_projection = Vec::with_capacity(exprs.len()); - // path for `SELECT count(*) FROM` + // Special path for `SELECT count(*) FROM` // as there would be no projections and we would read // the whole file while we only want the count if exprs.len() == 1 && is_count(exprs[0].node(), expr_arena) { @@ -67,102 +28,140 @@ pub(super) fn process_projection( let expr = if input_schema.is_empty() { // If the input schema is empty, we should just project // ourselves - Some(exprs[0].node()) + exprs[0].node() } else { // Select the last column projection. - let mut name = None; - 'outer: for (_, plan) in (&*lp_arena).iter(input) { - match plan { - IR::Select { expr: exprs, .. } | IR::HStack { exprs, .. } => { - for e in exprs { - if !e.is_scalar(expr_arena) { - name = Some(e.output_name()); - break 'outer; - } + let (last_name, _) = input_schema.try_get_at_index(input_schema.len() - 1)?; + + let name = match lp_arena.get(input) { + IR::Select { expr: exprs, .. } | IR::HStack { exprs, .. } => (|| { + for e in exprs { + if !e.is_scalar(expr_arena) { + return e.output_name(); } + } + + last_name + })(), + + IR::Scan { + file_info, + output_schema, + .. + } => { + let schema = output_schema.as_ref().unwrap_or(&file_info.schema); + // NOTE: the first can be the inserted index column, so that might not work + let (last_name, _) = schema.try_get_at_index(schema.len() - 1)?; + last_name + }, + + IR::DataFrameScan { + schema, + output_schema, + .. + } => { + // NOTE: the first can be the inserted index column, so that might not work + let schema = output_schema.as_ref().unwrap_or(schema); + let (last_name, _) = schema.try_get_at_index(schema.len() - 1)?; + last_name + }, + + _ => last_name, + }; + + expr_arena.add(AExpr::Column(name.clone())) + }; + + // Clear all accumulated projections since we only project a single column from this level. + acc_projections.clear(); + projected_names.clear(); + add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); + local_projection.push(exprs.pop().unwrap()); + proj_pd.is_count_star = true; + } else { + // `remove_names` tracks projected names that need to be removed as they may be aliased + // names that are created on this level. + let mut remove_names = PlHashSet::new(); + + // If there are non-scalar projections we must project at least one of them to maintain the + // output height. + let mut opt_non_scalar = None; + let mut projection_has_non_scalar = false; + + let projected_exprs: Vec = exprs + .into_iter() + .filter(|e| { + let is_non_scalar = !e.is_scalar(expr_arena); + + if opt_non_scalar.is_none() && is_non_scalar { + opt_non_scalar = Some(e.clone()) + } + + let name = match e.output_name_inner() { + OutputName::LiteralLhs(name) | OutputName::Alias(name) => { + remove_names.insert(name.clone()); + name }, - IR::Scan { - file_info, - output_schema, - .. - } => { - let schema = output_schema.as_ref().unwrap_or(&file_info.schema); - // NOTE: the first can be the inserted index column, so that might not work - let (last_name, _) = schema.try_get_at_index(schema.len() - 1)?; - name = Some(last_name); - break; + #[cfg(feature = "dtype-struct")] + OutputName::Field(name) => { + remove_names.insert(name.clone()); + name }, - IR::DataFrameScan { - schema, - output_schema, - .. - } => { - // NOTE: the first can be the inserted index column, so that might not work - let schema = output_schema.as_ref().unwrap_or(schema); - let (last_name, _) = schema.try_get_at_index(schema.len() - 1)?; - name = Some(last_name); - break; + OutputName::ColumnLhs(name) => name, + OutputName::None => { + if cfg!(debug_assertions) { + panic!() + } else { + return false; + } }, - _ => {}, - } - } + }; - if let Some(name) = name { - let expr = expr_arena.add(AExpr::Column(name.clone())); - if !acc_projections.is_empty() { - check_double_projection( - &exprs[0], - expr_arena, - &mut acc_projections, - &mut projected_names, - ); + let project = acc_projections.is_empty() || projected_names.contains(name); + projection_has_non_scalar |= project & is_non_scalar; + project + }) + .collect(); + + // Remove aliased before adding new ones. + if !remove_names.is_empty() { + if !projected_names.is_empty() { + for name in remove_names.iter() { + projected_names.remove(name); } - Some(expr) - } else { - None } - }; - if let Some(expr) = expr { - add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); - local_projection.push(exprs.pop().unwrap()); - proj_pd.is_count_star = true; - } - } else { - // A projection can consist of a chain of expressions followed by an alias. - // We want to do the chain locally because it can have complicated side effects. - // The only thing we push down is the root name of the projection. - // So we: - // - add the root of the projections to accumulation, - // - also do the complete projection locally to keep the schema (column order) and the alias. - - // set this flag outside the loop as we modify within the loop - let has_pushed_down = !acc_projections.is_empty(); - for e in exprs { - if has_pushed_down { - // remove projections that are not used upstream - if !projected_names.contains(e.output_name()) { - continue; - } - check_double_projection(&e, expr_arena, &mut acc_projections, &mut projected_names); - } - // do local as we still need the effect of the projection - // e.g. a projection is more than selecting a column, it can - // also be a function/ complicated expression - local_projection.push(e); + acc_projections.retain(|c| !remove_names.contains(column_node_to_name(*c, expr_arena))); } - // After we have checked double projections, we add the projections to the accumulated state. - // We do this in two passes, otherwise we mutate while checking. - for e in &local_projection { + for e in projected_exprs { add_expr_to_accumulated( e.node(), &mut acc_projections, &mut projected_names, expr_arena, ); + + // do local as we still need the effect of the projection + // e.g. a projection is more than selecting a column, it can + // also be a function/ complicated expression + local_projection.push(e); + } + + if !projection_has_non_scalar { + if let Some(non_scalar) = opt_non_scalar { + add_expr_to_accumulated( + non_scalar.node(), + &mut acc_projections, + &mut projected_names, + expr_arena, + ); + + local_projection.push(non_scalar); + } } } + proj_pd.pushdown_and_assign( input, acc_projections, diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 4aa6763e8fda..307b3e0811d1 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -614,3 +614,8 @@ def test_with_columns_projection_pushdown() -> None: # [dyn int: 1.alias("x"), dyn int: 1.alias("y")] # Csv SCAN [20 in-mem bytes] assert plan.endswith("PROJECT 1/6 COLUMNS") + + +def test_projection_pushdown_height_20221() -> None: + q = pl.LazyFrame({"a": range(5)}).select("a", b=pl.col("a").first()).select("b") + assert_frame_equal(q.collect(), pl.DataFrame({"b": [0, 0, 0, 0, 0]}))