diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 5a88604716d2..b80b8b3b3435 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -16,30 +16,29 @@ // under the License. use arrow::array::builder::{Int32Builder, StringBuilder}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Array, ArrayRef, Int32Array}; +use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::ScalarValue; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal}; +use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; -fn make_col(name: &str, index: usize) -> Arc { - Arc::new(Column::new(name, index)) +fn make_x_cmp_y( + x: &Arc, + op: Operator, + y: i32, +) -> Arc { + Arc::new(BinaryExpr::new(Arc::clone(x), op, lit(y))) } -fn make_lit_i32(n: i32) -> Arc { - Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) -} - -fn criterion_benchmark(c: &mut Criterion) { - // create input data +fn make_batch(row_count: usize, column_count: usize) -> RecordBatch { let mut c1 = Int32Builder::new(); let mut c2 = StringBuilder::new(); let mut c3 = StringBuilder::new(); - for i in 0..1000 { - c1.append_value(i); + for i in 0..row_count { + c1.append_value(i as i32); if i % 7 == 0 { c2.append_null(); } else { @@ -54,69 +53,134 @@ fn criterion_benchmark(c: &mut Criterion) { let c1 = Arc::new(c1.finish()); let c2 = Arc::new(c2.finish()); let c3 = Arc::new(c3.finish()); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Utf8, true), - ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); - - // use same predicate for all benchmarks - let predicate = Arc::new(BinaryExpr::new( - make_col("c1", 0), - Operator::LtEq, - make_lit_i32(500), - )); + let mut columns: Vec = vec![c1, c2, c3]; + for _ in 3..column_count { + columns.push(Arc::new(Int32Array::from_value(0, row_count))); + } - // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END - c.bench_function("case_when: scalar or scalar", |b| { - let expr = Arc::new( - CaseExpr::try_new( - None, - vec![(predicate.clone(), make_lit_i32(1))], - Some(make_lit_i32(0)), + let fields = columns + .iter() + .enumerate() + .map(|(i, c)| { + Field::new( + format!("c{}", i + 1), + c.data_type().clone(), + c.is_nullable(), ) - .unwrap(), - ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) - }); + }) + .collect::>(); - // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END - c.bench_function("case_when: column or null", |b| { - let expr = Arc::new( - CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(Arc::clone(&schema), columns).unwrap() +} + +fn criterion_benchmark(c: &mut Criterion) { + run_benchmarks(c, &make_batch(8192, 3)); + run_benchmarks(c, &make_batch(8192, 50)); + run_benchmarks(c, &make_batch(8192, 100)); +} + +fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { + let c1 = col("c1", &batch.schema()).unwrap(); + let c2 = col("c2", &batch.schema()).unwrap(); + let c3 = col("c3", &batch.schema()).unwrap(); + + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN 1 ELSE 0 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), lit(1))], + Some(lit(0)), + &batch.schema(), + ) .unwrap(), - ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) - }); + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); - // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END - c.bench_function("case_when: expr or expr", |b| { - let expr = Arc::new( - CaseExpr::try_new( - None, - vec![(predicate.clone(), make_col("c2", 1))], - Some(make_col("c3", 2)), - ) - .unwrap(), - ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) - }); + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))], + None, + &batch.schema(), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 ELSE c3 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))], + Some(Arc::clone(&c3)), + &batch.schema(), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); - // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END - c.bench_function("case_when: CASE expr", |b| { + c.bench_function( + format!( + "case_when {}x{}: CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + Some(Arc::clone(&c1)), + vec![(lit(1), Arc::clone(&c2)), (lit(2), Arc::clone(&c3))], + None, + &batch.schema(), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + c.bench_function(format!("case_when {}x{}: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| { + let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Eq, i), lit(i))).collect(); let expr = Arc::new( - CaseExpr::try_new( - Some(make_col("c1", 0)), - vec![ - (make_lit_i32(1), make_col("c2", 1)), - (make_lit_i32(2), make_col("c3", 2)), - ], + case( None, + when_thens, + Some(lit(batch.num_rows() as i32)), + &batch.schema(), ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d14146a20d8b..10450755183f 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,7 +16,7 @@ // under the License. use crate::expressions::try_cast; -use crate::PhysicalExpr; +use crate::{expressions, PhysicalExpr}; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -600,8 +600,20 @@ pub fn case( expr: Option>, when_thens: Vec, else_expr: Option>, + input_schema: &Schema, ) -> Result> { - Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) + let case_expr = CaseExpr::try_new(expr, when_thens, else_expr)?; + + match case_expr.eval_method { + EvalMethod::NoExpression + | EvalMethod::WithExpression + | EvalMethod::ExpressionOrExpression => { + expressions::projected(Arc::new(case_expr), input_schema) + } + EvalMethod::InfallibleExprOrNull | EvalMethod::ScalarOrScalar => { + Ok(Arc::new(case_expr)) + } + } } #[cfg(test)] @@ -1381,7 +1393,7 @@ mod tests { Ok((left, right)) } }?; - case(expr, when_thens, else_expr) + case(expr, when_thens, else_expr, input_schema) } fn get_case_common_type( diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d98..78671eb422fa 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -32,6 +32,7 @@ mod literal; mod negative; mod no_op; mod not; +mod projected; mod try_cast; mod unknown_column; @@ -54,5 +55,6 @@ pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; +pub use projected::{projected, ProjectedExpr}; pub use try_cast::{try_cast, TryCastExpr}; pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/projected.rs b/datafusion/physical-expr/src/expressions/projected.rs new file mode 100644 index 000000000000..bc3a05812c46 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/projected.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::Column; +use crate::utils::collect_column_indices; +use arrow::array::RecordBatch; +use arrow::datatypes::{FieldRef, Schema}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{HashMap, HashSet, Result}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +/// A wrapper around a physical expression that optimizes evaluation by projecting only +/// the columns that are actually needed from the input batch. +/// +/// This expression analyzes the wrapped expression to determine which columns it references +/// and creates a "projected" version that operates on a subset of the input record batch containing +/// only those columns. +/// +/// For most expressions this projection step is not useful. For conditional expressions like `case` +/// which make extensive use of [PhysicalExpr::evaluate_selection] though this can make a +/// significant difference due to the reduction in the number of arrays that need to be filtered. +#[derive(Debug, Clone, Eq)] +pub struct ProjectedExpr { + /// The original expression before projection + original: Arc, + /// The rewritten expression that operates on projected column indices + projected: Arc, + projection: Vec, +} + +impl ProjectedExpr { + /// Create a new projected expression wrapper if needed. + /// + /// # Arguments + /// * `inner` - The expression to potentially wrap + /// * `input_schema` - The schema of the input data + /// + /// # Returns + /// Returns the original expression if it uses all columns from the input schema, + /// otherwise returns a new [ProjectedExpr] that only references the required columns. + pub fn maybe_wrap( + inner: Arc, + input_schema: &Schema, + ) -> Result> { + let column_indices = collect_column_indices(&inner); + if column_indices.len() == input_schema.fields().len() { + Ok(inner) + } else { + Ok(Arc::new(Self::try_new_with_column_indices( + inner, + &column_indices, + )?)) + } + } + + pub fn try_new(inner: Arc) -> Result { + let column_indices = collect_column_indices(&inner); + Self::try_new_with_column_indices(inner, &column_indices) + } + + fn try_new_with_column_indices( + inner: Arc, + column_indices: &HashSet, + ) -> Result { + let mut column_index_map = HashMap::::new(); + column_indices.iter().enumerate().for_each(|(i, c)| { + column_index_map.insert(*c, i); + }); + + let projected = Arc::clone(&inner) + .transform_down(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let projected_index = *column_index_map.get(&column.index()).unwrap(); + if projected_index != column.index() { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + projected_index, + )))); + } + } + Ok(Transformed::no(expr)) + }) + .data()?; + + let projection = column_index_map + .iter() + .sorted_by_key(|(_, v)| **v) + .map(|(k, _)| *k) + .collect::>(); + + Ok(Self { + original: Arc::clone(&inner), + projected, + projection, + }) + } +} + +/// Create a projected expression +pub fn projected( + expr: Arc, + input_schema: &Schema, +) -> Result> { + ProjectedExpr::maybe_wrap(expr, input_schema) +} + +impl PhysicalExpr for ProjectedExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn return_field(&self, input_schema: &Schema) -> Result { + self.original.return_field(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let projected_batch = batch.project(&self.projection)?; + self.projected.evaluate(&projected_batch) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.original] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::try_new( + children.into_iter().next().unwrap(), + )?)) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.original.fmt_sql(f) + } + + fn is_volatile_node(&self) -> bool { + self.projected.is_volatile_node() + } +} + +// Manually derive `PartialEq`/`Hash` as `Arc` does not +// implement these traits by default for the trait object. +impl PartialEq for ProjectedExpr { + fn eq(&self, other: &Self) -> bool { + self.original.eq(&other.original) && self.projection.eq(&other.projection) + } +} + +impl Hash for ProjectedExpr { + fn hash(&self, state: &mut H) { + self.original.hash(state); + self.projection.hash(state); + } +} + +impl Display for ProjectedExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.original.fmt(f) + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 73df60c42e96..16d1cb6488a1 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -286,7 +286,8 @@ pub fn create_physical_expr( } else { None }; - Ok(expressions::case(expr, when_then_expr, else_expr)?) + + expressions::case(expr, when_then_expr, else_expr, input_schema) } Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 745ae855efee..9de6ed258e68 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -238,6 +238,19 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns } +pub fn collect_column_indices(expr: &Arc) -> HashSet { + let mut columns = HashSet::::new(); + expr.apply(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + columns.insert(column.index()); + } + Ok(TreeNodeRecursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + columns +} + /// Re-assign indices of [`Column`]s within the given [`PhysicalExpr`] according to /// the provided [`Schema`]. ///