diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index a5b291e35a98..d67813ea5329 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -88,8 +88,8 @@ impl ConstExpr { } } - pub fn with_value(mut self, value: ScalarValue) -> Self { - self.value = Some(value); + pub fn with_value(mut self, value: Option) -> Self { + self.value = value; self } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 751edb1f72ed..759b2c17ceb4 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -259,31 +259,17 @@ impl EquivalenceProperties { if self.is_expr_constant(left) { // Left expression is constant, add right as constant if !const_exprs_contains(&self.constants, right) { - // Try to get value from left constant expression - let value = left - .as_any() - .downcast_ref::() - .map(|lit| lit.value().clone()); - - let mut const_expr = ConstExpr::from(right).with_across_partitions(true); - if let Some(val) = value { - const_expr = const_expr.with_value(val); - } + let const_expr = ConstExpr::from(right) + .with_across_partitions(true) + .with_value(self.get_expr_constant_value(left)); self.constants.push(const_expr); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant if !const_exprs_contains(&self.constants, left) { - // Try to get value from right constant expression - let value = right - .as_any() - .downcast_ref::() - .map(|lit| lit.value().clone()); - - let mut const_expr = ConstExpr::from(left).with_across_partitions(true); - if let Some(val) = value { - const_expr = const_expr.with_value(val); - } + let const_expr = ConstExpr::from(left) + .with_across_partitions(true) + .with_value(self.get_expr_constant_value(right)); self.constants.push(const_expr); } } @@ -325,12 +311,9 @@ impl EquivalenceProperties { return None; } - let mut const_expr = ConstExpr::from(normalized_expr) - .with_across_partitions(across_partitions); - - if let Some(value) = value { - const_expr = const_expr.with_value(value); - } + let const_expr = ConstExpr::from(normalized_expr) + .with_across_partitions(across_partitions) + .with_value(value); Some(const_expr) }) @@ -901,12 +884,9 @@ impl EquivalenceProperties { const_expr .map(|expr| self.eq_group.project_expr(mapping, expr)) .map(|projected_expr| { - let mut new_const_expr = projected_expr - .with_across_partitions(const_expr.across_partitions()); - if let Some(value) = const_expr.value() { - new_const_expr = new_const_expr.with_value(value.clone()); - } - new_const_expr + projected_expr + .with_across_partitions(const_expr.across_partitions()) + .with_value(const_expr.value().cloned()) }) }) .collect::>(); @@ -917,20 +897,14 @@ impl EquivalenceProperties { && !const_exprs_contains(&projected_constants, target) { let across_partitions = self.is_expr_constant_accross_partitions(source); - // Try to get value from source constant expression - let value = self - .constants - .iter() - .find(|c| c.expr().eq(source)) - .and_then(|c| c.value().cloned()); + let value = self.get_expr_constant_value(source); // Expression evaluates to single value - let mut const_expr = - ConstExpr::from(target).with_across_partitions(across_partitions); - if let Some(val) = value { - const_expr = const_expr.with_value(val); - } - projected_constants.push(const_expr); + projected_constants.push( + ConstExpr::from(target) + .with_across_partitions(across_partitions) + .with_value(value), + ); } } projected_constants @@ -1178,12 +1152,9 @@ impl EquivalenceProperties { let across_partitions = const_expr.across_partitions(); let value = const_expr.value().cloned(); let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; - let mut new_const_expr = ConstExpr::new(new_const_expr) - .with_across_partitions(across_partitions); - if let Some(value) = value { - new_const_expr = new_const_expr.with_value(value.clone()); - } - Ok(new_const_expr) + Ok(ConstExpr::new(new_const_expr) + .with_across_partitions(across_partitions) + .with_value(value)) }) .collect::>>()?; @@ -1962,7 +1933,7 @@ fn calculate_union_binary( if lhs_val == rhs_val { const_expr = const_expr .with_across_partitions(true) - .with_value(lhs_val.clone()); + .with_value(Some(lhs_val.clone())); } } const_expr @@ -3786,13 +3757,13 @@ mod tests { // Create first input with a=10 let const_expr1 = - ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone()); + ConstExpr::new(Arc::clone(&col_a)).with_value(Some(literal_10.clone())); let input1 = EquivalenceProperties::new(Arc::clone(&schema)) .with_constants(vec![const_expr1]); // Create second input with a=10 let const_expr2 = - ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone()); + ConstExpr::new(Arc::clone(&col_a)).with_value(Some(literal_10.clone())); let input2 = EquivalenceProperties::new(Arc::clone(&schema)) .with_constants(vec![const_expr2]); diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 8539cbaac8b6..3b21bb0b8f6f 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -218,29 +218,25 @@ impl FilterExec { if binary.op() == &Operator::Eq { // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { - // When left side is constant, extract value from right side if it's a literal let (expr, value) = ( binary.right(), input_eqs.get_expr_constant_value(binary.right()), ); - let mut const_expr = - ConstExpr::from(expr).with_across_partitions(true); - if let Some(value) = value { - const_expr = const_expr.with_value(value.clone()); - } - res_constants.push(const_expr); + res_constants.push( + ConstExpr::new(Arc::clone(expr)) + .with_across_partitions(true) + .with_value(value), + ); } else if input_eqs.is_expr_constant(binary.right()) { - // When right side is constant, extract value from left side if it's a literal let (expr, value) = ( binary.left(), input_eqs.get_expr_constant_value(binary.left()), ); - let mut const_expr = - ConstExpr::from(expr).with_across_partitions(true); - if let Some(value) = value { - const_expr = const_expr.with_value(value.clone()); - } - res_constants.push(const_expr); + res_constants.push( + ConstExpr::new(Arc::clone(expr)) + .with_across_partitions(true) + .with_value(value), + ); } } } @@ -272,11 +268,9 @@ impl FilterExec { .min_value .get_value(); let expr = Arc::new(column) as _; - let mut const_expr = ConstExpr::new(expr).with_across_partitions(true); - if let Some(value) = value { - const_expr = const_expr.with_value(value.clone()); - } - const_expr + ConstExpr::new(expr) + .with_across_partitions(true) + .with_value(value.cloned()) }); // This is for statistics eq_properties = eq_properties.with_constants(constants);