From 2e0438bd8ac675bfbaaf0b027ec3002e461c4776 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Wed, 9 Oct 2024 12:32:36 +0200 Subject: [PATCH] refactor: Make `pl.repeat` part of the IR (#19152) --- crates/polars-core/src/frame/column/mod.rs | 17 +++++++----- .../polars-plan/src/dsl/function_expr/mod.rs | 5 ++++ .../src/dsl/function_expr/repeat.rs | 18 +++++++++++++ .../src/dsl/function_expr/schema.rs | 1 + .../polars-plan/src/dsl/functions/repeat.rs | 27 ++++++++++--------- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 1 + 7 files changed, 51 insertions(+), 20 deletions(-) create mode 100644 crates/polars-plan/src/dsl/function_expr/repeat.rs diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 74b6f0e7da0a..2eb2bbc14f4f 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -372,15 +372,18 @@ impl Column { #[inline] pub fn new_from_index(&self, index: usize, length: usize) -> Self { + if index >= self.len() { + return Self::full_null(self.name().clone(), length, self.dtype()); + } + match self { - Column::Series(s) => s.new_from_index(index, length).into(), - Column::Scalar(s) => { - if index >= s.len() { - Self::full_null(s.name().clone(), length, s.dtype()) - } else { - s.resize(length).into() - } + Column::Series(s) => { + // SAFETY: Bounds check done before. 
+ let av = unsafe { s.get_unchecked(index) }; + let scalar = Scalar::new(self.dtype().clone(), av.into_static()); + Self::new_scalar(self.name().clone(), scalar, length) }, + Column::Scalar(s) => s.resize(length).into(), } } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0458b2b4a1d0..6cebaa301b85 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -47,6 +47,7 @@ pub mod pow; mod random; #[cfg(feature = "range")] mod range; +mod repeat; #[cfg(feature = "rolling_window")] pub mod rolling; #[cfg(feature = "rolling_window_by")] @@ -189,6 +190,7 @@ pub enum FunctionExpr { options: RankOptions, seed: Option<u64>, }, + Repeat, #[cfg(feature = "round_series")] Clip { has_min: bool, @@ -452,6 +454,7 @@ impl Hash for FunctionExpr { a.hash(state); b.hash(state); }, + Repeat => {}, #[cfg(feature = "rank")] Rank { options, seed } => { options.hash(state); @@ -651,6 +654,7 @@ impl Display for FunctionExpr { #[cfg(feature = "moment")] Kurtosis(..) => "kurtosis", ArgUnique => "arg_unique", + Repeat => "repeat", #[cfg(feature = "rank")] Rank { ..
} => "rank", #[cfg(feature = "round_series")] @@ -996,6 +1000,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> { #[cfg(feature = "moment")] Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias), ArgUnique => map!(dispatch::arg_unique), + Repeat => map_as_slice!(repeat::repeat), #[cfg(feature = "rank")] Rank { options, seed } => map!(dispatch::rank, options, seed), #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-plan/src/dsl/function_expr/repeat.rs b/crates/polars-plan/src/dsl/function_expr/repeat.rs new file mode 100644 index 000000000000..cebc2ce792c4 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/repeat.rs @@ -0,0 +1,18 @@ +use polars_core::prelude::{polars_ensure, polars_err, Column, PolarsResult}; + +pub fn repeat(args: &[Column]) -> PolarsResult<Column> { + let c = &args[0]; + let n = &args[1]; + + polars_ensure!( + n.dtype().is_integer(), + SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() + ); + + let first_value = n.get(0)?; + let n = first_value.extract::<usize>().ok_or_else( + || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), + )?; + + Ok(c.new_from_index(0, n)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 7cc5b8c5c7ad..bc09cca94215 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -90,6 +90,7 @@ impl FunctionExpr { #[cfg(feature = "moment")] Kurtosis(..) => mapper.with_dtype(DataType::Float64), ArgUnique => mapper.with_dtype(IDX_DTYPE), + Repeat => mapper.with_same_dtype(), #[cfg(feature = "rank")] Rank { options, ..
} => mapper.with_dtype(match options.method { RankMethod::Average => DataType::Float64, diff --git a/crates/polars-plan/src/dsl/functions/repeat.rs b/crates/polars-plan/src/dsl/functions/repeat.rs index 21d27a542e99..ea80e2598186 100644 --- a/crates/polars-plan/src/dsl/functions/repeat.rs +++ b/crates/polars-plan/src/dsl/functions/repeat.rs @@ -5,17 +5,20 @@ use super::*; /// Generally you won't need this function, as `lit(value)` already represents a column containing /// only `value` whose length is automatically set to the correct number of rows. pub fn repeat<E: Into<Expr>>(value: E, n: Expr) -> Expr { - let function = |s: Column, n: Column| { - polars_ensure!( - n.dtype().is_integer(), - SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() - ); - let first_value = n.get(0)?; - let n = first_value.extract::<usize>().ok_or_else( - || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), - )?; - Ok(Some(s.new_from_index(0, n))) + let input = vec![value.into(), n]; + + let expr = Expr::Function { + input, + function: FunctionExpr::Repeat, + options: FunctionOptions { + flags: FunctionFlags::default() + | FunctionFlags::ALLOW_RENAME + | FunctionFlags::CHANGES_LENGTH, + ..Default::default() + }, }; - apply_binary(value.into(), n, function, GetOutput::same_type()) - .alias(PlSmallStr::from_static("repeat")) + + // @NOTE: This alias should probably not be here for consistency, but it is here for backwards + // compatibility until 2.0. + expr.alias(PlSmallStr::from_static("repeat")) } diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 32c8d3d23b7d..19f0f06a2c8c 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node).
- const VERSION: Version = (2, 2); + const VERSION: Version = (2, 3); pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index e8832e9b5488..cb1e817e0feb 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1202,6 +1202,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> { #[cfg(feature = "repeat_by")] FunctionExpr::RepeatBy => ("repeat_by",).to_object(py), FunctionExpr::ArgUnique => ("arg_unique",).to_object(py), + FunctionExpr::Repeat => ("repeat",).to_object(py), FunctionExpr::Rank { options: _, seed: _,