Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Cache dtype on ExprIR #20331

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static ERROR_STRATEGY: LazyLock<ErrorStrategy> = LazyLock::new(|| {
}
});

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ErrString(Cow<'static, str>);

impl ErrString {
Expand Down Expand Up @@ -74,7 +74,7 @@ impl Display for ErrString {
}
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Clone, thiserror::Error)]
pub enum PolarsError {
#[error("not found: {0}")]
ColumnNotFound(ErrString),
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/plans/builder_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ impl<'a> IRBuilder<'a> {
.to_field(&schema, Context::Default, self.expr_arena)
.unwrap();

expr_irs.push(ExprIR::new(node, OutputName::ColumnLhs(field.name.clone())));
expr_irs.push(
ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
.with_dtype(field.dtype.clone()),
);
new_schema.with_column(field.name().clone(), field.dtype().clone());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,10 +528,7 @@ fn raise_supertype(
) -> PolarsResult<()> {
let dtypes = inputs
.iter()
.map(|e| {
let ae = expr_arena.get(e.node());
ae.to_dtype(input_schema, Context::Default, expr_arena)
})
.map(|e| e.dtype(input_schema, Context::Default, expr_arena).cloned())
.collect::<PolarsResult<Vec<_>>>()?;

let st = dtypes
Expand Down
71 changes: 69 additions & 2 deletions crates/polars-plan/src/plans/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Borrow;
use std::hash::Hash;
#[cfg(feature = "cse")]
use std::hash::Hasher;
use std::sync::OnceLock;

use polars_utils::format_pl_smallstr;
#[cfg(feature = "ir_serde")]
Expand Down Expand Up @@ -48,14 +49,40 @@ impl OutputName {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Debug)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub struct ExprIR {
/// Output name of this expression.
output_name: OutputName,
/// Output dtype of this expression
/// Reduced expression.
/// This expression is pruned from `alias` and already expanded.
node: Node,
#[cfg_attr(feature = "ir_serde", serde(skip))]
output_dtype: OnceLock<DataType>,
}

impl Eq for ExprIR {}

impl PartialEq for ExprIR {
fn eq(&self, other: &Self) -> bool {
self.node == other.node && self.output_name == other.output_name
}
}

impl Clone for ExprIR {
fn clone(&self) -> Self {
let output_dtype = OnceLock::new();
if let Some(dt) = self.output_dtype.get() {
output_dtype.set(dt.clone()).unwrap()
}

ExprIR {
output_name: self.output_name.clone(),
node: self.node,
output_dtype,
}
}
}

impl Borrow<Node> for ExprIR {
Expand All @@ -67,13 +94,23 @@ impl Borrow<Node> for ExprIR {
impl ExprIR {
pub fn new(node: Node, output_name: OutputName) -> Self {
debug_assert!(!output_name.is_none());
ExprIR { output_name, node }
ExprIR {
output_name,
node,
output_dtype: OnceLock::new(),
}
}

pub fn with_dtype(self, dtype: DataType) -> Self {
let _ = self.output_dtype.set(dtype);
self
}

pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
let mut out = Self {
node,
output_name: OutputName::None,
output_dtype: OnceLock::new(),
};
out.node = node;
for (_, ae) in arena.iter(node) {
Expand Down Expand Up @@ -149,6 +186,7 @@ impl ExprIR {

pub(crate) fn set_node(&mut self, node: Node) {
self.node = node;
self.output_dtype = OnceLock::new();
}

#[cfg(feature = "cse")]
Expand Down Expand Up @@ -206,6 +244,35 @@ impl ExprIR {
pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {
is_scalar_ae(self.node, expr_arena)
}

pub fn dtype(
&self,
schema: &Schema,
ctxt: Context,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<&DataType> {
match self.output_dtype.get() {
Some(dtype) => Ok(dtype),
None => {
let dtype = expr_arena
.get(self.node)
.to_dtype(schema, ctxt, expr_arena)?;
let _ = self.output_dtype.set(dtype);
Ok(self.output_dtype.get().unwrap())
},
}
}

pub fn field(
&self,
schema: &Schema,
ctxt: Context,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<Field> {
let dtype = self.dtype(schema, ctxt, expr_arena)?;
let name = self.output_name();
Ok(Field::new(name.clone(), dtype.clone()))
}
}

impl AsRef<ExprIR> for ExprIR {
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,9 @@ pub(crate) fn expr_irs_to_schema<I: IntoIterator<Item = K>, K: AsRef<ExprIR>>(
expr.into_iter()
.map(|e| {
let e = e.as_ref();
let mut field = arena
.get(e.node())
.to_field(schema, ctxt, arena)
.expect("should be resolved");
let mut field = e.field(schema, ctxt, arena).expect("should be resolved");

// TODO! (can this be removed?)
if let Some(name) = e.get_alias() {
field.name = name.clone()
}
Expand Down
5 changes: 1 addition & 4 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,10 +648,7 @@ pub fn compute_output_schema(
.iter()
.map(|e| {
let name = e.output_name().clone();
let dtype =
expr_arena
.get(e.node())
.to_dtype(input_schema, Context::Default, expr_arena)?;
let dtype = e.dtype(input_schema, Context::Default, expr_arena)?.clone();
PolarsResult::Ok(Field::new(name, dtype))
})
.try_collect()?;
Expand Down
Loading