diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index a03094ce50ac..3913eb8aec2c 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -38,6 +38,50 @@ fn build_slice_node( } } +fn build_filter_node( + input: PhysNodeKey, + predicate: ExprIR, + expr_arena: &mut Arena, + phys_sm: &mut SlotMap, + expr_cache: &mut ExprCache, +) -> PolarsResult { + let predicate = predicate.clone(); + let cols_and_predicate = phys_sm[input].output_schema + .iter_names() + .cloned() + .map(|name| { + ExprIR::new( + expr_arena.add(AExpr::Column(name.clone())), + OutputName::ColumnLhs(name), + ) + }) + .chain([predicate]) + .collect_vec(); + let (trans_input, mut trans_cols_and_predicate) = lower_exprs( + input, + &cols_and_predicate, + expr_arena, + phys_sm, + expr_cache, + )?; + + let filter_schema = phys_sm[trans_input].output_schema.clone(); + let filter = PhysNodeKind::Filter { + input: trans_input, + predicate: trans_cols_and_predicate.last().unwrap().clone(), + }; + + let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); + trans_cols_and_predicate.pop(); // Remove predicate. + build_select_node( + post_filter, + &trans_cols_and_predicate, + expr_arena, + phys_sm, + expr_cache, + ) +} + #[recursive::recursive] pub fn lower_ir( node: Node, @@ -128,40 +172,7 @@ pub fn lower_ir( IR::Filter { input, predicate } => { let predicate = predicate.clone(); let phys_input = lower_ir!(*input)?; - let cols_and_predicate = output_schema - .iter_names() - .cloned() - .map(|name| { - ExprIR::new( - expr_arena.add(AExpr::Column(name.clone())), - OutputName::ColumnLhs(name), - ) - }) - .chain([predicate]) - .collect_vec(); - let (trans_input, mut trans_cols_and_predicate) = lower_exprs( - phys_input, - &cols_and_predicate, - expr_arena, - phys_sm, - expr_cache, - )?; - - let filter_schema = phys_sm[trans_input].output_schema.clone(); - let filter = PhysNodeKind::Filter { - input: trans_input, - predicate: trans_cols_and_predicate.last().unwrap().clone(), - }; - - let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); - trans_cols_and_predicate.pop(); // Remove predicate. - return build_select_node( - post_filter, - &trans_cols_and_predicate, - expr_arena, - phys_sm, - expr_cache, - ); + return build_filter_node(phys_input, predicate, expr_arena, phys_sm, expr_cache); }, IR::DataFrameScan { @@ -192,15 +203,8 @@ pub fn lower_ir( } if let Some(predicate) = filter.clone() { - if !is_elementwise_rec_cached(predicate.node(), expr_arena, expr_cache) { - todo!() - } - let phys_input = phys_sm.insert(PhysNode::new(schema, node_kind)); - node_kind = PhysNodeKind::Filter { - input: phys_input, - predicate, - }; + return build_filter_node(phys_input, predicate, expr_arena, phys_sm, expr_cache); } node_kind @@ -287,16 +291,21 @@ pub fn lower_ir( }, IR::Union { inputs, options } => { - if options.slice.is_some() { - todo!() - } - + let options = options.clone(); let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() .map(|input| lower_ir!(input)) .collect::>()?; - PhysNodeKind::OrderedUnion { inputs } + + let mut node = phys_sm.insert(PhysNode { + output_schema, + kind: PhysNodeKind::OrderedUnion { inputs }, + }); + if let Some((offset, length)) = options.slice { + node = build_slice_node(node, offset, length, phys_sm); + } + return Ok(node); }, IR::HConcat { @@ -379,7 +388,7 @@ pub fn lower_ir( _ => todo!(), }; - let phys_node = PhysNodeKind::FileScan { + let node_kind = PhysNodeKind::FileScan { scan_sources, file_info, hive_parts, @@ -391,12 +400,12 @@ pub fn lower_ir( let (row_index, slice, predicate) = opt_rewrite_to_nodes; - let phys_node = if let Some(ri) = row_index { + let node_kind = if let Some(ri) = row_index { let mut schema = Arc::unwrap_or_clone(output_schema.clone()); let v = schema.shift_remove_index(0).unwrap().0; assert_eq!(v, ri.name); - let input = phys_sm.insert(PhysNode::new(Arc::new(schema), phys_node)); + let input = phys_sm.insert(PhysNode::new(Arc::new(schema), node_kind)); PhysNodeKind::WithRowIndex { input, @@ -404,32 +413,23 @@ pub fn lower_ir( offset: Some(ri.offset), } } else { - phys_node + node_kind }; + + let mut node = phys_sm.insert(PhysNode { + output_schema, + kind: node_kind + }); - let phys_node = if let Some((offset, length)) = slice { - let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node)); - - if offset < 0 { - todo!() - } - - PhysNodeKind::StreamingSlice { - input, - offset: offset as usize, - length, - } - } else { - phys_node - }; + if let Some((offset, length)) = slice { + node = build_slice_node(node, offset, length, phys_sm); + } if let Some(predicate) = predicate { - let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node)); - - PhysNodeKind::Filter { input, predicate } - } else { - phys_node + node = build_filter_node(node, predicate, expr_arena, phys_sm, expr_cache)?; } + + return Ok(node); } },