diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index 57ee5da352b5..5810c28160de 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -397,17 +397,16 @@ macro_rules! polars_ensure { pub fn to_compute_err(err: impl Display) -> PolarsError { PolarsError::ComputeError(err.to_string().into()) } - #[macro_export] macro_rules! feature_gated { - ($feature:expr, $content:expr) => {{ - #[cfg(feature = $feature)] + ($($feature:literal);*, $content:expr) => {{ + #[cfg(all($(feature = $feature),*))] { $content } - #[cfg(not(feature = $feature))] + #[cfg(not(all($(feature = $feature),*)))] { - panic!("activate '{}' feature", $feature) + panic!("activate '{}' feature", concat!($($feature, ", "),*)) } }}; } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index b771a717050d..8878b420af02 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -500,7 +500,7 @@ fn create_physical_expr_inner( Ok(Arc::new(ApplyExpr::new( input, - function.clone(), + function.clone().materialize()?, node_to_expr(expression, expr_arena), *options, state.allow_threading, diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 2b9dff7cb5f2..33eb20e86da6 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -1,7 +1,9 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use bytes::Bytes; use polars_core::chunked_array::cast::CastOptions; +use polars_core::error::feature_gated; use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -157,7 +159,7 @@ pub enum Expr { /// function arguments input: Vec, /// function to apply - function: SpecialEq>, + function: OpaqueColumnUdf, /// output dtype of the function output_type: GetOutput, options: FunctionOptions, @@ -172,6 +174,50 @@ pub enum Expr { Selector(super::selector::Selector), } +pub type OpaqueColumnUdf = LazySerde>>; +pub(crate) fn new_column_udf(func: F) -> OpaqueColumnUdf { + LazySerde::Deserialized(SpecialEq::new(Arc::new(func))) +} + +#[derive(Clone)] +pub enum LazySerde { + Deserialized(T), + Bytes(Bytes), +} + +impl PartialEq for LazySerde { + fn eq(&self, other: &Self) -> bool { + use LazySerde as L; + match (self, other) { + (L::Deserialized(a), L::Deserialized(b)) => a == b, + (L::Bytes(a), L::Bytes(b)) => a.as_ptr() == b.as_ptr() && a.len() == b.len(), + _ => false, + } + } +} + +impl Debug for LazySerde { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bytes(_) => write!(f, "lazy-serde"), + Self::Deserialized(_) => write!(f, "lazy-serde"), + } + } +} + +impl OpaqueColumnUdf { + pub fn materialize(self) -> PolarsResult>> { + match self { + Self::Deserialized(t) => Ok(t), + Self::Bytes(b) => { + feature_gated!("serde";"python", { + python_udf::PythonUdfExpression::try_deserialize(b.as_ref()).map(SpecialEq::new) + }) + }, + } + } +} + #[allow(clippy::derived_hash_with_manual_eq)] impl Hash for Expr { fn hash(&self, state: &mut H) { diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index e134e8b556ef..2c0acfeeb13b 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -3,9 +3,7 @@ use std::ops::Deref; use std::sync::Arc; #[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; -#[cfg(feature = "serde")] -use serde::{Deserializer, Serializer}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::*; @@ -20,14 +18,6 @@ pub trait ColumnsUdf: Send + Sync { fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { polars_bail!(ComputeError: "serialization not supported for this 'opaque' function") } - - // Needed for python functions. After they are deserialized we first check if they - // have a function that generates an output - // This will be slower during optimization, so it is up to us to move - // all expression to the known function architecture. - fn get_output(&self) -> Option { - None - } } #[cfg(feature = "serde")] @@ -46,6 +36,31 @@ impl Serialize for SpecialEq> { } #[cfg(feature = "serde")] +impl Serialize for LazySerde { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Deserialized(t) => t.serialize(serializer), + Self::Bytes(b) => serializer.serialize_bytes(b), + } + } +} + +#[cfg(feature = "serde")] +impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'a>, + { + let buf = Vec::::deserialize(deserializer)?; + Ok(Self::Bytes(bytes::Bytes::from(buf))) + } +} + +#[cfg(feature = "serde")] +// impl Deserialize for crate::dsl::expr::LazySerde { impl<'a> Deserialize<'a> for SpecialEq> { fn deserialize(deserializer: D) -> std::result::Result where diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index b792c98956b3..542212f8de82 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -28,7 +28,7 @@ where let mut exprs = exprs.as_ref().to_vec(); exprs.push(acc); - let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| { + let function = new_column_udf(move |columns: &mut [Column]| { let mut columns = columns.to_vec(); let mut acc = columns.pop().unwrap(); @@ -38,7 +38,7 @@ where } } Ok(Some(acc)) - }) as Arc); + }); Expr::AnonymousFunction { input: exprs, @@ -67,7 +67,7 @@ where { let exprs = exprs.as_ref().to_vec(); - let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| { + let function = new_column_udf(move |columns: &mut [Column]| { let mut c_iter = columns.iter(); match c_iter.next() { @@ -83,7 +83,7 @@ where }, None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), } - }) as Arc); + }); Expr::AnonymousFunction { input: exprs, @@ -109,7 +109,7 @@ where { let exprs = exprs.as_ref().to_vec(); - let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| { + let function = new_column_udf(move |columns: &mut [Column]| { let mut c_iter = columns.iter(); match c_iter.next() { @@ -131,7 +131,7 @@ where }, None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), } - }) as Arc); + }); Expr::AnonymousFunction { input: exprs, @@ -158,7 +158,7 @@ where let mut exprs = exprs.as_ref().to_vec(); exprs.push(acc); - let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| { + let function = new_column_udf(move |columns: &mut [Column]| { let mut columns = columns.to_vec(); let mut acc = columns.pop().unwrap(); @@ -177,7 +177,7 @@ where } StructChunked::from_columns(acc.name().clone(), &result).map(|ca| Some(ca.into_column())) - }) as Arc); + }); Expr::AnonymousFunction { input: exprs, diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 477d97b9c299..9dd20bc813f5 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -548,7 +548,7 @@ impl Expr { Expr::AnonymousFunction { input: vec![self], - function: SpecialEq::new(Arc::new(f)), + function: new_column_udf(f), output_type, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, @@ -582,7 +582,7 @@ impl Expr { Expr::AnonymousFunction { input, - function: SpecialEq::new(Arc::new(function)), + function: new_column_udf(function), output_type, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, @@ -607,7 +607,7 @@ impl Expr { Expr::AnonymousFunction { input: vec![self], - function: SpecialEq::new(Arc::new(f)), + function: new_column_udf(f), output_type, options: FunctionOptions { collect_groups: ApplyOptions::ApplyList, @@ -631,7 +631,7 @@ impl Expr { Expr::AnonymousFunction { input: vec![self], - function: SpecialEq::new(Arc::new(f)), + function: new_column_udf(f), output_type, options, } @@ -654,7 +654,7 @@ impl Expr { Expr::AnonymousFunction { input: vec![self], - function: SpecialEq::new(Arc::new(f)), + function: new_column_udf(f), output_type, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, @@ -687,7 +687,7 @@ impl Expr { Expr::AnonymousFunction { input, - function: SpecialEq::new(Arc::new(function)), + function: new_column_udf(function), output_type, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, @@ -1983,7 +1983,7 @@ where Expr::AnonymousFunction { input, - function: SpecialEq::new(Arc::new(function)), + function: new_column_udf(function), output_type, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, @@ -2009,7 +2009,7 @@ where Expr::AnonymousFunction { input, - function: SpecialEq::new(Arc::new(function)), + function: new_column_udf(function), output_type, options: FunctionOptions { collect_groups: ApplyOptions::ApplyList, @@ -2047,7 +2047,7 @@ where Expr::AnonymousFunction { input, - function: SpecialEq::new(Arc::new(function)), + function: new_column_udf(function), output_type, options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 0f9ac4a3dc9a..a813dbf64e87 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -214,20 +214,6 @@ impl ColumnsUdf for PythonUdfExpression { Ok(()) }) } - - fn get_output(&self) -> Option { - let output_type = self.output_type.clone(); - Some(GetOutput::map_field(move |fld| { - Ok(match output_type { - Some(ref dt) => Field::new(fld.name().clone(), dt.clone()), - None => { - let mut fld = fld.clone(); - fld.coerce(DataType::Unknown(Default::default())); - fld - }, - }) - })) - } } /// Serializable version of [`GetOutput`] for Python UDFs. @@ -301,7 +287,7 @@ impl Expr { Expr::AnonymousFunction { input: vec![self], - function: SpecialEq::new(Arc::new(func)), + function: new_column_udf(func), output_type, options: FunctionOptions { collect_groups, diff --git a/crates/polars-plan/src/dsl/udf.rs b/crates/polars-plan/src/dsl/udf.rs index b09ef6f556a2..cc889db7975a 100644 --- a/crates/polars-plan/src/dsl/udf.rs +++ b/crates/polars-plan/src/dsl/udf.rs @@ -1,12 +1,10 @@ -use std::sync::Arc; - use arrow::legacy::error::{polars_bail, PolarsResult}; use polars_core::prelude::Field; use polars_core::schema::Schema; use polars_utils::pl_str::PlSmallStr; -use super::{ColumnsUdf, Expr, GetOutput, SpecialEq}; -use crate::prelude::{Context, FunctionOptions}; +use super::{ColumnsUdf, Expr, GetOutput, OpaqueColumnUdf}; +use crate::prelude::{new_column_udf, Context, FunctionOptions}; /// Represents a user-defined function #[derive(Clone)] @@ -18,7 +16,7 @@ pub struct UserDefinedFunction { /// The function output type. pub return_type: GetOutput, /// The function implementation. - pub fun: SpecialEq>, + pub fun: OpaqueColumnUdf, /// Options for the function. pub options: FunctionOptions, } @@ -46,7 +44,7 @@ impl UserDefinedFunction { name, input_fields, return_type, - fun: SpecialEq::new(Arc::new(fun)), + fun: new_column_udf(fun), options: FunctionOptions::default(), } } diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 53bf24ff838e..565710c0dbaf 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -182,7 +182,7 @@ pub enum AExpr { }, AnonymousFunction { input: Vec, - function: SpecialEq>, + function: OpaqueColumnUdf, output_type: GetOutput, options: FunctionOptions, }, diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index af37357502ee..88c44233175a 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -260,14 +260,11 @@ impl AExpr { AnonymousFunction { output_type, input, - function, options, .. } => { *nested = nested .saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _); - let tmp = function.get_output(); - let output_type = tmp.as_ref().unwrap_or(output_type); let fields = func_args_to_fields(input, schema, arena, nested)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); output_type.get_field(schema, Context::Default, &fields)