Skip to content

Commit

Permalink
perf: Cache dtype on ExprIR (#20331)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 17, 2024
1 parent 6142ee5 commit b82b2b2
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 17 deletions.
4 changes: 2 additions & 2 deletions crates/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static ERROR_STRATEGY: LazyLock<ErrorStrategy> = LazyLock::new(|| {
}
});

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ErrString(Cow<'static, str>);

impl ErrString {
Expand Down Expand Up @@ -74,7 +74,7 @@ impl Display for ErrString {
}
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Clone, thiserror::Error)]
pub enum PolarsError {
#[error("not found: {0}")]
ColumnNotFound(ErrString),
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/plans/builder_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ impl<'a> IRBuilder<'a> {
.to_field(&schema, Context::Default, self.expr_arena)
.unwrap();

expr_irs.push(ExprIR::new(node, OutputName::ColumnLhs(field.name.clone())));
expr_irs.push(
ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
.with_dtype(field.dtype.clone()),
);
new_schema.with_column(field.name().clone(), field.dtype().clone());
}

Expand Down
5 changes: 1 addition & 4 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,10 +528,7 @@ fn raise_supertype(
) -> PolarsResult<()> {
let dtypes = inputs
.iter()
.map(|e| {
let ae = expr_arena.get(e.node());
ae.to_dtype(input_schema, Context::Default, expr_arena)
})
.map(|e| e.dtype(input_schema, Context::Default, expr_arena).cloned())
.collect::<PolarsResult<Vec<_>>>()?;

let st = dtypes
Expand Down
71 changes: 69 additions & 2 deletions crates/polars-plan/src/plans/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Borrow;
use std::hash::Hash;
#[cfg(feature = "cse")]
use std::hash::Hasher;
use std::sync::OnceLock;

use polars_utils::format_pl_smallstr;
#[cfg(feature = "ir_serde")]
Expand Down Expand Up @@ -48,14 +49,40 @@ impl OutputName {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Debug)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub struct ExprIR {
/// Output name of this expression.
output_name: OutputName,
/// Output dtype of this expression
/// Reduced expression.
/// This expression is pruned from `alias` and already expanded.
node: Node,
#[cfg_attr(feature = "ir_serde", serde(skip))]
output_dtype: OnceLock<DataType>,
}

impl Eq for ExprIR {}

impl PartialEq for ExprIR {
fn eq(&self, other: &Self) -> bool {
self.node == other.node && self.output_name == other.output_name
}
}

impl Clone for ExprIR {
fn clone(&self) -> Self {
let output_dtype = OnceLock::new();
if let Some(dt) = self.output_dtype.get() {
output_dtype.set(dt.clone()).unwrap()
}

ExprIR {
output_name: self.output_name.clone(),
node: self.node,
output_dtype,
}
}
}

impl Borrow<Node> for ExprIR {
Expand All @@ -67,13 +94,23 @@ impl Borrow<Node> for ExprIR {
impl ExprIR {
pub fn new(node: Node, output_name: OutputName) -> Self {
debug_assert!(!output_name.is_none());
ExprIR { output_name, node }
ExprIR {
output_name,
node,
output_dtype: OnceLock::new(),
}
}

pub fn with_dtype(self, dtype: DataType) -> Self {
let _ = self.output_dtype.set(dtype);
self
}

pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
let mut out = Self {
node,
output_name: OutputName::None,
output_dtype: OnceLock::new(),
};
out.node = node;
for (_, ae) in arena.iter(node) {
Expand Down Expand Up @@ -149,6 +186,7 @@ impl ExprIR {

pub(crate) fn set_node(&mut self, node: Node) {
self.node = node;
self.output_dtype = OnceLock::new();
}

#[cfg(feature = "cse")]
Expand Down Expand Up @@ -206,6 +244,35 @@ impl ExprIR {
pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {
is_scalar_ae(self.node, expr_arena)
}

pub fn dtype(
&self,
schema: &Schema,
ctxt: Context,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<&DataType> {
match self.output_dtype.get() {
Some(dtype) => Ok(dtype),
None => {
let dtype = expr_arena
.get(self.node)
.to_dtype(schema, ctxt, expr_arena)?;
let _ = self.output_dtype.set(dtype);
Ok(self.output_dtype.get().unwrap())
},
}
}

pub fn field(
&self,
schema: &Schema,
ctxt: Context,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<Field> {
let dtype = self.dtype(schema, ctxt, expr_arena)?;
let name = self.output_name();
Ok(Field::new(name.clone(), dtype.clone()))
}
}

impl AsRef<ExprIR> for ExprIR {
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,9 @@ pub(crate) fn expr_irs_to_schema<I: IntoIterator<Item = K>, K: AsRef<ExprIR>>(
expr.into_iter()
.map(|e| {
let e = e.as_ref();
let mut field = arena
.get(e.node())
.to_field(schema, ctxt, arena)
.expect("should be resolved");
let mut field = e.field(schema, ctxt, arena).expect("should be resolved");

// TODO! (can this be removed?)
if let Some(name) = e.get_alias() {
field.name = name.clone()
}
Expand Down
5 changes: 1 addition & 4 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,10 +648,7 @@ pub fn compute_output_schema(
.iter()
.map(|e| {
let name = e.output_name().clone();
let dtype =
expr_arena
.get(e.node())
.to_dtype(input_schema, Context::Default, expr_arena)?;
let dtype = e.dtype(input_schema, Context::Default, expr_arena)?.clone();
PolarsResult::Ok(Field::new(name, dtype))
})
.try_collect()?;
Expand Down

0 comments on commit b82b2b2

Please sign in to comment.