Skip to content

Commit

Permalink
Make with_value() accept optional value
Browse files Browse the repository at this point in the history
  • Loading branch information
gokselk committed Dec 23, 2024
1 parent 57913f8 commit 1917c0e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 74 deletions.
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ impl ConstExpr {
}
}

pub fn with_value(mut self, value: ScalarValue) -> Self {
self.value = Some(value);
pub fn with_value(mut self, value: Option<ScalarValue>) -> Self {
self.value = value;
self
}

Expand Down
77 changes: 24 additions & 53 deletions datafusion/physical-expr/src/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,31 +259,17 @@ impl EquivalenceProperties {
if self.is_expr_constant(left) {
// Left expression is constant, add right as constant
if !const_exprs_contains(&self.constants, right) {
// Try to get value from left constant expression
let value = left
.as_any()
.downcast_ref::<Literal>()
.map(|lit| lit.value().clone());

let mut const_expr = ConstExpr::from(right).with_across_partitions(true);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
let const_expr = ConstExpr::from(right)
.with_across_partitions(true)
.with_value(self.get_expr_constant_value(left));
self.constants.push(const_expr);
}
} else if self.is_expr_constant(right) {
// Right expression is constant, add left as constant
if !const_exprs_contains(&self.constants, left) {
// Try to get value from right constant expression
let value = right
.as_any()
.downcast_ref::<Literal>()
.map(|lit| lit.value().clone());

let mut const_expr = ConstExpr::from(left).with_across_partitions(true);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
let const_expr = ConstExpr::from(left)
.with_across_partitions(true)
.with_value(self.get_expr_constant_value(right));
self.constants.push(const_expr);
}
}
Expand Down Expand Up @@ -325,12 +311,9 @@ impl EquivalenceProperties {
return None;
}

let mut const_expr = ConstExpr::from(normalized_expr)
.with_across_partitions(across_partitions);

if let Some(value) = value {
const_expr = const_expr.with_value(value);
}
let const_expr = ConstExpr::from(normalized_expr)
.with_across_partitions(across_partitions)
.with_value(value);

Some(const_expr)
})
Expand Down Expand Up @@ -901,12 +884,9 @@ impl EquivalenceProperties {
const_expr
.map(|expr| self.eq_group.project_expr(mapping, expr))
.map(|projected_expr| {
let mut new_const_expr = projected_expr
.with_across_partitions(const_expr.across_partitions());
if let Some(value) = const_expr.value() {
new_const_expr = new_const_expr.with_value(value.clone());
}
new_const_expr
projected_expr
.with_across_partitions(const_expr.across_partitions())
.with_value(const_expr.value().cloned())
})
})
.collect::<Vec<_>>();
Expand All @@ -917,20 +897,14 @@ impl EquivalenceProperties {
&& !const_exprs_contains(&projected_constants, target)
{
let across_partitions = self.is_expr_constant_accross_partitions(source);
// Try to get value from source constant expression
let value = self
.constants
.iter()
.find(|c| c.expr().eq(source))
.and_then(|c| c.value().cloned());
let value = self.get_expr_constant_value(source);

// Expression evaluates to single value
let mut const_expr =
ConstExpr::from(target).with_across_partitions(across_partitions);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
projected_constants.push(const_expr);
projected_constants.push(
ConstExpr::from(target)
.with_across_partitions(across_partitions)
.with_value(value),
);
}
}
projected_constants
Expand Down Expand Up @@ -1178,12 +1152,9 @@ impl EquivalenceProperties {
let across_partitions = const_expr.across_partitions();
let value = const_expr.value().cloned();
let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?;
let mut new_const_expr = ConstExpr::new(new_const_expr)
.with_across_partitions(across_partitions);
if let Some(value) = value {
new_const_expr = new_const_expr.with_value(value.clone());
}
Ok(new_const_expr)
Ok(ConstExpr::new(new_const_expr)
.with_across_partitions(across_partitions)
.with_value(value))
})
.collect::<Result<Vec<_>>>()?;

Expand Down Expand Up @@ -1962,7 +1933,7 @@ fn calculate_union_binary(
if lhs_val == rhs_val {
const_expr = const_expr
.with_across_partitions(true)
.with_value(lhs_val.clone());
.with_value(Some(lhs_val.clone()));
}
}
const_expr
Expand Down Expand Up @@ -3786,13 +3757,13 @@ mod tests {

// Create first input with a=10
let const_expr1 =
ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone());
ConstExpr::new(Arc::clone(&col_a)).with_value(Some(literal_10.clone()));
let input1 = EquivalenceProperties::new(Arc::clone(&schema))
.with_constants(vec![const_expr1]);

// Create second input with a=10
let const_expr2 =
ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone());
ConstExpr::new(Arc::clone(&col_a)).with_value(Some(literal_10.clone()));
let input2 = EquivalenceProperties::new(Arc::clone(&schema))
.with_constants(vec![const_expr2]);

Expand Down
32 changes: 13 additions & 19 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,29 +218,25 @@ impl FilterExec {
if binary.op() == &Operator::Eq {
// Filter evaluates to single value for all partitions
if input_eqs.is_expr_constant(binary.left()) {
// When left side is constant, extract value from right side if it's a literal
let (expr, value) = (
binary.right(),
input_eqs.get_expr_constant_value(binary.right()),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(value) = value {
const_expr = const_expr.with_value(value.clone());
}
res_constants.push(const_expr);
res_constants.push(
ConstExpr::new(Arc::clone(expr))
.with_across_partitions(true)
.with_value(value),
);
} else if input_eqs.is_expr_constant(binary.right()) {
// When right side is constant, extract value from left side if it's a literal
let (expr, value) = (
binary.left(),
input_eqs.get_expr_constant_value(binary.left()),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(value) = value {
const_expr = const_expr.with_value(value.clone());
}
res_constants.push(const_expr);
res_constants.push(
ConstExpr::new(Arc::clone(expr))
.with_across_partitions(true)
.with_value(value),
);
}
}
}
Expand Down Expand Up @@ -272,11 +268,9 @@ impl FilterExec {
.min_value
.get_value();
let expr = Arc::new(column) as _;
let mut const_expr = ConstExpr::new(expr).with_across_partitions(true);
if let Some(value) = value {
const_expr = const_expr.with_value(value.clone());
}
const_expr
ConstExpr::new(expr)
.with_across_partitions(true)
.with_value(value.cloned())
});
// This is for statistics
eq_properties = eq_properties.with_constants(constants);
Expand Down

0 comments on commit 1917c0e

Please sign in to comment.