Skip to content

Commit

Permalink
Simplify constexpr
Browse files Browse the repository at this point in the history
  • Loading branch information
berkaysynnada committed Dec 25, 2024
1 parent 25e95f4 commit bee8667
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 111 deletions.
72 changes: 40 additions & 32 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use indexmap::{IndexMap, IndexSet};
// }
// ```
//
// This would provide a more type-safe representation of partition values.
// This would provide a more flexible representation of partition values.
// Note: This is a breaking change for the equivalence API and should be
// addressed in a separate issue/PR.
#[derive(Debug, Clone)]
Expand All @@ -73,16 +73,32 @@ pub struct ConstExpr {
expr: Arc<dyn PhysicalExpr>,
/// Does the constant have the same value across all partitions? See
/// struct docs for more details
across_partitions: bool,
/// The value of the constant expression
value: Option<ScalarValue>,
across_partitions: AcrossPartitions,
}

#[derive(PartialEq, Clone, Debug)]
/// Represents whether a constant expression's value is uniform or varies across partitions.
///
/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
/// in a physical execution plan:
///
/// - `Heterogeneous`: The constant expression may have different values for different partitions.
/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all
///   partitions. The payload carries that value, and is `None` if the value is not specified.
pub enum AcrossPartitions {
/// The constant expression may have a different value in each partition.
Heterogeneous,
/// The constant expression has the same value in every partition. The payload is
/// that value, or `None` if the value is not specified.
Uniform(Option<ScalarValue>),
}

impl Default for AcrossPartitions {
fn default() -> Self {
Self::Heterogeneous
}
}

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
self.across_partitions == other.across_partitions
&& self.expr.eq(&other.expr)
&& self.value == other.value
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
}
}

Expand All @@ -95,29 +111,23 @@ impl ConstExpr {
Self {
expr,
// By default, assume constant expressions are not the same across partitions.
across_partitions: false,
value: None,
across_partitions: Default::default(),
}
}

pub fn with_value(mut self, value: Option<ScalarValue>) -> Self {
self.value = value;
self
}

/// Set the `across_partitions` flag
///
/// See struct docs for more details
pub fn with_across_partitions(mut self, across_partitions: bool) -> Self {
pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
self.across_partitions = across_partitions;
self
}

/// Is the expression the same across all partitions?
///
/// See struct docs for more details
pub fn across_partitions(&self) -> bool {
self.across_partitions
pub fn across_partitions(&self) -> AcrossPartitions {
self.across_partitions.clone()
}

pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
Expand All @@ -128,19 +138,14 @@ impl ConstExpr {
self.expr
}

pub fn value(&self) -> Option<&ScalarValue> {
self.value.as_ref()
}

pub fn map<F>(&self, f: F) -> Option<Self>
where
F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
{
let maybe_expr = f(&self.expr);
maybe_expr.map(|expr| Self {
expr,
across_partitions: self.across_partitions,
value: self.value.clone(),
across_partitions: self.across_partitions.clone(),
})
}

Expand Down Expand Up @@ -170,17 +175,20 @@ impl ConstExpr {
}
}

/// Display implementation for `ConstExpr`
///
/// Example `c` or `c(across_partitions)`
impl Display for ConstExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.expr)?;
if self.across_partitions {
write!(f, "(across_partitions)")?;
}
if let Some(value) = self.value.as_ref() {
write!(f, "({})", value)?;
match &self.across_partitions {
AcrossPartitions::Heterogeneous => {
write!(f, "(heterogeneous)")?;
}
AcrossPartitions::Uniform(value) => {
if let Some(val) = value {
write!(f, "(uniform: {})", val)?;
} else {
write!(f, "(uniform: unknown)")?;
}
}
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/equivalence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod ordering;
mod projection;
mod properties;

pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup};
pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
pub use ordering::OrderingEquivalenceClass;
pub use projection::ProjectionMapping;
pub use properties::{
Expand Down
9 changes: 5 additions & 4 deletions datafusion/physical-expr/src/equivalence/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ mod tests {
};
use crate::expressions::{col, BinaryExpr, Column};
use crate::utils::tests::TestScalarUDF;
use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr};
use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr};

use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::SortOptions;
Expand Down Expand Up @@ -583,9 +583,10 @@ mod tests {
let eq_group = EquivalenceGroup::new(eq_group);
eq_properties.add_equivalence_group(eq_group);

let constants = constants
.into_iter()
.map(|expr| ConstExpr::from(expr).with_across_partitions(true));
let constants = constants.into_iter().map(|expr| {
ConstExpr::from(expr)
.with_across_partitions(AcrossPartitions::Uniform(None))
});
eq_properties = eq_properties.with_constants(constants);

let reqs = convert_to_sort_exprs(&reqs);
Expand Down
Loading

0 comments on commit bee8667

Please sign in to comment.