diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index c19e90ec20b3..5e06799b997a 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -28,7 +28,7 @@ static ERROR_STRATEGY: LazyLock = LazyLock::new(|| { } }); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ErrString(Cow<'static, str>); impl ErrString { @@ -74,7 +74,7 @@ impl Display for ErrString { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Clone, thiserror::Error)] pub enum PolarsError { #[error("not found: {0}")] ColumnNotFound(ErrString), diff --git a/crates/polars-plan/src/plans/builder_ir.rs b/crates/polars-plan/src/plans/builder_ir.rs index 7eddfdfea5da..f60c1d5ec95a 100644 --- a/crates/polars-plan/src/plans/builder_ir.rs +++ b/crates/polars-plan/src/plans/builder_ir.rs @@ -182,7 +182,10 @@ impl<'a> IRBuilder<'a> { .to_field(&schema, Context::Default, self.expr_arena) .unwrap(); - expr_irs.push(ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))); + expr_irs.push( + ExprIR::new(node, OutputName::ColumnLhs(field.name.clone())) + .with_dtype(field.dtype.clone()), + ); new_schema.with_column(field.name().clone(), field.dtype().clone()); } diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index 7dac7ed0dc60..b3ba2861a115 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -528,10 +528,7 @@ fn raise_supertype( ) -> PolarsResult<()> { let dtypes = inputs .iter() - .map(|e| { - let ae = expr_arena.get(e.node()); - ae.to_dtype(input_schema, Context::Default, expr_arena) - }) + .map(|e| e.dtype(input_schema, Context::Default, expr_arena).cloned()) .collect::>>()?; let st = dtypes diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 72e5d30e3541..c61aa21f03e6 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use std::hash::Hash; #[cfg(feature = "cse")] use std::hash::Hasher; +use std::sync::OnceLock; use polars_utils::format_pl_smallstr; #[cfg(feature = "ir_serde")] @@ -48,14 +49,40 @@ impl OutputName { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] #[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct ExprIR { /// Output name of this expression. output_name: OutputName, + /// Output dtype of this expression /// Reduced expression. /// This expression is pruned from `alias` and already expanded. node: Node, + #[cfg_attr(feature = "ir_serde", serde(skip))] + output_dtype: OnceLock, +} + +impl Eq for ExprIR {} + +impl PartialEq for ExprIR { + fn eq(&self, other: &Self) -> bool { + self.node == other.node && self.output_name == other.output_name + } +} + +impl Clone for ExprIR { + fn clone(&self) -> Self { + let output_dtype = OnceLock::new(); + if let Some(dt) = self.output_dtype.get() { + output_dtype.set(dt.clone()).unwrap() + } + + ExprIR { + output_name: self.output_name.clone(), + node: self.node, + output_dtype, + } + } } impl Borrow for ExprIR { @@ -67,13 +94,23 @@ impl Borrow for ExprIR { impl ExprIR { pub fn new(node: Node, output_name: OutputName) -> Self { debug_assert!(!output_name.is_none()); - ExprIR { output_name, node } + ExprIR { + output_name, + node, + output_dtype: OnceLock::new(), + } + } + + pub fn with_dtype(self, dtype: DataType) -> Self { + let _ = self.output_dtype.set(dtype); + self } pub fn from_node(node: Node, arena: &Arena) -> Self { let mut out = Self { node, output_name: OutputName::None, + output_dtype: OnceLock::new(), }; out.node = node; for (_, ae) in arena.iter(node) { @@ -149,6 +186,7 @@ impl ExprIR { pub(crate) fn set_node(&mut self, node: Node) { self.node = node; + self.output_dtype = OnceLock::new(); } #[cfg(feature = "cse")] @@ -206,6 +244,35 @@ impl ExprIR { pub fn is_scalar(&self, expr_arena: &Arena) -> bool { is_scalar_ae(self.node, expr_arena) } + + pub fn dtype( + &self, + schema: &Schema, + ctxt: Context, + expr_arena: &Arena, + ) -> PolarsResult<&DataType> { + match self.output_dtype.get() { + Some(dtype) => Ok(dtype), + None => { + let dtype = expr_arena + .get(self.node) + .to_dtype(schema, ctxt, expr_arena)?; + let _ = self.output_dtype.set(dtype); + Ok(self.output_dtype.get().unwrap()) + }, + } + } + + pub fn field( + &self, + schema: &Schema, + ctxt: Context, + expr_arena: &Arena, + ) -> PolarsResult { + let dtype = self.dtype(schema, ctxt, expr_arena)?; + let name = self.output_name(); + Ok(Field::new(name.clone(), dtype.clone())) + } } impl AsRef for ExprIR { diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 0d433e9e74d7..ddd3e1ef42ee 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -330,11 +330,9 @@ pub(crate) fn expr_irs_to_schema, K: AsRef>( expr.into_iter() .map(|e| { let e = e.as_ref(); - let mut field = arena - .get(e.node()) - .to_field(schema, ctxt, arena) - .expect("should be resolved"); + let mut field = e.field(schema, ctxt, arena).expect("should be resolved"); + // TODO! (can this be removed?) if let Some(name) = e.get_alias() { field.name = name.clone() } diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index b3bef0cf43e5..fcd93d975fa9 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -648,10 +648,7 @@ pub fn compute_output_schema( .iter() .map(|e| { let name = e.output_name().clone(); - let dtype = - expr_arena - .get(e.node()) - .to_dtype(input_schema, Context::Default, expr_arena)?; + let dtype = e.dtype(input_schema, Context::Default, expr_arena)?.clone(); PolarsResult::Ok(Field::new(name, dtype)) }) .try_collect()?;