Skip to content

Commit

Permalink
refactor(rust): Dispatch slice/filter lowering properly (#20390)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Dec 20, 2024
1 parent 530e70a commit d309fd0
Showing 1 changed file with 66 additions and 70 deletions.
136 changes: 66 additions & 70 deletions crates/polars-stream/src/physical_plan/lower_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,46 @@ fn build_slice_node(
}
}

/// Builds a filter on top of `input` and returns the key of the resulting node.
///
/// One column expression is created for every name in `input`'s output
/// schema; those expressions and `predicate` are lowered together through
/// `lower_exprs`. A `PhysNodeKind::Filter` node using the lowered predicate is
/// inserted on top of the lowered input, and the lowered column expressions
/// are then handed to `build_select_node` on top of that filter.
///
/// # Errors
///
/// Propagates any error returned by `lower_exprs` or `build_select_node`.
fn build_filter_node(
    input: PhysNodeKey,
    predicate: ExprIR,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
) -> PolarsResult<PhysNodeKey> {
    // `predicate` is already owned, so it is moved into the chain below
    // rather than cloned first.
    let cols_and_predicate = phys_sm[input]
        .output_schema
        .iter_names()
        .cloned()
        .map(|name| {
            ExprIR::new(
                expr_arena.add(AExpr::Column(name.clone())),
                OutputName::ColumnLhs(name),
            )
        })
        .chain([predicate])
        .collect_vec();
    let (trans_input, mut trans_cols) =
        lower_exprs(input, &cols_and_predicate, expr_arena, phys_sm, expr_cache)?;

    // The predicate was chained last, so popping it both retrieves it for the
    // filter and leaves only the column expressions for the final select.
    let trans_predicate = trans_cols
        .pop()
        .expect("lowered expressions must contain the predicate");

    let filter_schema = phys_sm[trans_input].output_schema.clone();
    let filter = PhysNodeKind::Filter {
        input: trans_input,
        predicate: trans_predicate,
    };
    let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter));

    build_select_node(post_filter, &trans_cols, expr_arena, phys_sm, expr_cache)
}

#[recursive::recursive]
pub fn lower_ir(
node: Node,
Expand Down Expand Up @@ -128,40 +168,7 @@ pub fn lower_ir(
IR::Filter { input, predicate } => {
let predicate = predicate.clone();
let phys_input = lower_ir!(*input)?;
let cols_and_predicate = output_schema
.iter_names()
.cloned()
.map(|name| {
ExprIR::new(
expr_arena.add(AExpr::Column(name.clone())),
OutputName::ColumnLhs(name),
)
})
.chain([predicate])
.collect_vec();
let (trans_input, mut trans_cols_and_predicate) = lower_exprs(
phys_input,
&cols_and_predicate,
expr_arena,
phys_sm,
expr_cache,
)?;

let filter_schema = phys_sm[trans_input].output_schema.clone();
let filter = PhysNodeKind::Filter {
input: trans_input,
predicate: trans_cols_and_predicate.last().unwrap().clone(),
};

let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter));
trans_cols_and_predicate.pop(); // Remove predicate.
return build_select_node(
post_filter,
&trans_cols_and_predicate,
expr_arena,
phys_sm,
expr_cache,
);
return build_filter_node(phys_input, predicate, expr_arena, phys_sm, expr_cache);
},

IR::DataFrameScan {
Expand Down Expand Up @@ -192,15 +199,8 @@ pub fn lower_ir(
}

if let Some(predicate) = filter.clone() {
if !is_elementwise_rec_cached(predicate.node(), expr_arena, expr_cache) {
todo!()
}

let phys_input = phys_sm.insert(PhysNode::new(schema, node_kind));
node_kind = PhysNodeKind::Filter {
input: phys_input,
predicate,
};
return build_filter_node(phys_input, predicate, expr_arena, phys_sm, expr_cache);
}

node_kind
Expand Down Expand Up @@ -287,16 +287,21 @@ pub fn lower_ir(
},

IR::Union { inputs, options } => {
if options.slice.is_some() {
todo!()
}

let options = *options;
let inputs = inputs
.clone() // Needed to borrow ir_arena mutably.
.into_iter()
.map(|input| lower_ir!(input))
.collect::<Result<_, _>>()?;
PhysNodeKind::OrderedUnion { inputs }

let mut node = phys_sm.insert(PhysNode {
output_schema,
kind: PhysNodeKind::OrderedUnion { inputs },
});
if let Some((offset, length)) = options.slice {
node = build_slice_node(node, offset, length, phys_sm);
}
return Ok(node);
},

IR::HConcat {
Expand Down Expand Up @@ -379,7 +384,7 @@ pub fn lower_ir(
_ => todo!(),
};

let phys_node = PhysNodeKind::FileScan {
let node_kind = PhysNodeKind::FileScan {
scan_sources,
file_info,
hive_parts,
Expand All @@ -391,45 +396,36 @@ pub fn lower_ir(

let (row_index, slice, predicate) = opt_rewrite_to_nodes;

let phys_node = if let Some(ri) = row_index {
let node_kind = if let Some(ri) = row_index {
let mut schema = Arc::unwrap_or_clone(output_schema.clone());

let v = schema.shift_remove_index(0).unwrap().0;
assert_eq!(v, ri.name);
let input = phys_sm.insert(PhysNode::new(Arc::new(schema), phys_node));
let input = phys_sm.insert(PhysNode::new(Arc::new(schema), node_kind));

PhysNodeKind::WithRowIndex {
input,
name: ri.name,
offset: Some(ri.offset),
}
} else {
phys_node
node_kind
};

let phys_node = if let Some((offset, length)) = slice {
let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node));

if offset < 0 {
todo!()
}
let mut node = phys_sm.insert(PhysNode {
output_schema,
kind: node_kind,
});

PhysNodeKind::StreamingSlice {
input,
offset: offset as usize,
length,
}
} else {
phys_node
};
if let Some((offset, length)) = slice {
node = build_slice_node(node, offset, length, phys_sm);
}

if let Some(predicate) = predicate {
let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node));

PhysNodeKind::Filter { input, predicate }
} else {
phys_node
node = build_filter_node(node, predicate, expr_arena, phys_sm, expr_cache)?;
}

return Ok(node);
}
},

Expand Down

0 comments on commit d309fd0

Please sign in to comment.