Skip to content

Commit

Permalink
Make PruningPredicate's rewrite public (apache#12850)
Browse files Browse the repository at this point in the history
* Make PruningPredicate's rewrite public

* feedback

* Improve documentation and add default to ConstantUnhandledPredicateHook

* Update pruning.rs

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
adriangb and alamb authored Oct 13, 2024
1 parent 646f40a commit 1b10c9f
Showing 1 changed file with 188 additions and 24 deletions.
212 changes: 188 additions & 24 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ pub trait PruningStatistics {
/// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
/// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
/// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
#[derive(Debug, Clone)]
pub struct PruningPredicate {
/// The input schema against which the predicate will be evaluated
Expand All @@ -478,6 +478,36 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Hook that rewrites predicates that [`PredicateRewriter`] cannot handle, e.g.
/// certain complex expressions or predicates that reference columns that are
/// not in the schema.
///
/// The expression returned by [`Self::handle`] is used in place of the
/// unhandled predicate in the rewritten expression.
pub trait UnhandledPredicateHook {
/// Called when a predicate cannot be rewritten in terms of statistics or
/// references a column that is not in the schema.
///
/// `expr` is the predicate that could not be rewritten; the return value is
/// the expression substituted for it.
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}

/// An [`UnhandledPredicateHook`] that replaces every unhandled predicate with
/// the same constant expression.
///
/// The default handling for unhandled predicates is to return a constant `true`
/// (meaning don't prune the container)
#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
/// Expression returned by `handle` for every unhandled predicate
default: Arc<dyn PhysicalExpr>,
}

impl Default for ConstantUnhandledPredicateHook {
    /// Builds a hook whose replacement expression is the boolean literal
    /// `true`.
    fn default() -> Self {
        let always_true = phys_expr::Literal::new(ScalarValue::from(true));
        Self {
            default: Arc::new(always_true),
        }
    }
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
    /// Ignores the unhandled expression and hands back a new reference to the
    /// stored constant expression.
    fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
        Arc::clone(&self.default)
    }
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +532,16 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _;

// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1312,27 +1348,78 @@ fn build_is_null_column_expr(
/// an OR chain
const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;

/// Rewrite a predicate expression in terms of statistics (min/max/null_counts)
/// for use as a [`PruningPredicate`].
///
/// Predicates that cannot be rewritten are passed to an
/// [`UnhandledPredicateHook`], which can be replaced with
/// [`Self::with_unhandled_hook`].
pub struct PredicateRewriter {
/// Hook invoked for predicates that cannot be rewritten
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
}

impl Default for PredicateRewriter {
fn default() -> Self {
Self {
unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()),
}
}
}

impl PredicateRewriter {
    /// Create a new `PredicateRewriter` with the default unhandled hook
    /// (equivalent to [`PredicateRewriter::default`]).
    pub fn new() -> Self {
        Default::default()
    }

    /// Replace the hook that is called when a predicate cannot be rewritten,
    /// returning the updated rewriter.
    pub fn with_unhandled_hook(
        mut self,
        unhandled_hook: Arc<dyn UnhandledPredicateHook>,
    ) -> Self {
        self.unhandled_hook = unhandled_hook;
        self
    }

    /// Translate a logical filter expression into a pruning predicate
    /// expression that will evaluate to FALSE if it can be determined that no
    /// rows between the min/max values could pass the predicates.
    ///
    /// Any predicates that cannot be translated will be passed to
    /// `unhandled_hook`.
    ///
    /// Returns the pruning predicate as a [`PhysicalExpr`]
    ///
    /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20,
    /// which will fall back to calling `unhandled_hook`
    pub fn rewrite_predicate_to_statistics_predicate(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Arc<dyn PhysicalExpr> {
        // The builder needs somewhere to record required columns; this method
        // only returns the rewritten expression, so the set is dropped on return.
        let mut scratch_columns = RequiredColumns::new();
        let hook = &self.unhandled_hook;
        build_predicate_expression(expr, schema, &mut scratch_columns, hook)
    }
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as a [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1428,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1469,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(
&change_expr,
schema,
required_columns,
unhandled_hook,
);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1488,23 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr =
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr =
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1423,12 +1517,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
}
Err(_) => return unhandled_hook.handle(expr),
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder)
.unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1582,6 +1675,8 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_functions_nested::expr_fn::{array_has, make_array};
use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -3397,6 +3492,74 @@ mod tests {
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
    struct CustomUnhandledHook;

    impl UnhandledPredicateHook for CustomUnhandledHook {
        /// Replaces every unhandled predicate with the literal `42`
        /// (the replacement is arbitrary, the point is that the hook can
        /// return whatever expression it wants)
        fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
            Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
        }
    }

    // `schema` is what the rewriter sees; `schema_with_b` is only used to plan
    // physical expressions that reference column `b`, which the rewriter's
    // schema does not contain.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let schema_with_b = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]);

    let rewriter =
        PredicateRewriter::new().with_unhandled_hook(Arc::new(CustomUnhandledHook));

    let transform_expr = |expr| {
        let expr = logical2physical(&expr, &schema_with_b);
        rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema)
    };

    // transform an arbitrary valid expression that we know is handled
    let known_expression = col("a").eq(lit(12));
    let known_expression_transformed = PredicateRewriter::new()
        .rewrite_predicate_to_statistics_predicate(
            &logical2physical(&known_expression, &schema),
            &schema,
        );

    // an expression referencing an unknown column (that is not in the schema) gets passed to the hook
    let input = col("b").eq(lit(12));
    let expected = logical2physical(&lit(42), &schema);
    let transformed = transform_expr(input.clone());
    assert_eq!(transformed.to_string(), expected.to_string());

    // more complex case with unknown column
    let input = known_expression.clone().and(input);
    let expected = phys_expr::BinaryExpr::new(
        known_expression_transformed.clone(),
        Operator::And,
        logical2physical(&lit(42), &schema),
    );
    let transformed = transform_expr(input);
    assert_eq!(transformed.to_string(), expected.to_string());

    // an unknown expression gets passed to the hook
    let input = array_has(make_array(vec![lit(1)]), col("a"));
    let expected = logical2physical(&lit(42), &schema);
    let transformed = transform_expr(input.clone());
    assert_eq!(transformed.to_string(), expected.to_string());

    // more complex case with unknown expression
    let input = known_expression.and(input);
    let expected = phys_expr::BinaryExpr::new(
        known_expression_transformed,
        Operator::And,
        logical2physical(&lit(42), &schema),
    );
    let transformed = transform_expr(input);
    assert_eq!(transformed.to_string(), expected.to_string());
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4049,7 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _;
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}

0 comments on commit 1b10c9f

Please sign in to comment.