diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 1628780d7b0e..4a6d717e3ef5 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -53,16 +53,94 @@ fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { } } +/// Arithmetic operations that can be applied to a Series +#[derive(Clone, Copy)] +enum Op { + Add, + Subtract, + Multiply, + Divide, + Remainder, +} + +impl Op { + /// Apply the operation to a pair of Series. + fn apply_with_series(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + use Op::*; + + match self { + Add => lhs + rhs, + Subtract => lhs - rhs, + Multiply => lhs * rhs, + Divide => lhs / rhs, + Remainder => lhs % rhs, + } + } + + /// Apply the operation to a Series and scalar. + fn apply_with_scalar(&self, lhs: &Series, rhs: T) -> Series { + use Op::*; + + match self { + Add => lhs + rhs, + Subtract => lhs - rhs, + Multiply => lhs * rhs, + Divide => lhs / rhs, + Remainder => lhs % rhs, + } + } +} + impl ListChunked { + /// Helper function for NumOpsDispatchInner implementation for ListChunked. + /// + /// Run the given `op` on `self` and `rhs`, for cases where `rhs` has a + /// primitive numeric dtype. + fn arithm_helper_numeric(&self, rhs: &Series, op: Op) -> PolarsResult { + let mut result = AnonymousListBuilder::new( + self.name().clone(), + self.len(), + Some(self.inner_dtype().clone()), + ); + macro_rules! combine { + ($ca:expr) => {{ + self.amortized_iter() + .zip($ca.iter()) + .map(|(a, b)| { + let (Some(a_owner), Some(b)) = (a, b) else { + // Operations with nulls always result in nulls: + return Ok(None); + }; + let a = a_owner.as_ref().rechunk(); + let leaf_result = op.apply_with_scalar(&a.get_leaf_array(), b); + let result = + reshape_list_based_on(&leaf_result.chunks()[0], &a.chunks()[0]); + Ok(Some(result)) + }) + .collect::>>>>()? + }}; + } + let combined = downcast_as_macro_arg_physical!(rhs, combine); + for arr in combined.iter() { + if let Some(arr) = arr { + result.append_array(arr.as_ref()); + } else { + result.append_null(); + } + } + Ok(result.finish().into()) + } + /// Helper function for NumOpsDispatchInner implementation for ListChunked. /// /// Run the given `op` on `self` and `rhs`. - fn arithm_helper( - &self, - rhs: &Series, - op: &dyn Fn(&Series, &Series) -> PolarsResult, - has_nulls: Option, - ) -> PolarsResult { + fn arithm_helper(&self, rhs: &Series, op: Op, has_nulls: Option) -> PolarsResult { + polars_ensure!( + self.dtype().leaf_dtype().is_numeric() && rhs.dtype().leaf_dtype().is_numeric(), + InvalidOperation: "List Series can only do arithmetic operations if they and other Series are numeric, left and right dtypes are {:?} and {:?}", + self.dtype(), + rhs.dtype() + ); polars_ensure!( self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", @@ -70,6 +148,17 @@ impl ListChunked { rhs.len() ); + if rhs.dtype().is_numeric() { + return self.arithm_helper_numeric(rhs, op); + } + + polars_ensure!( + self.dtype() == rhs.dtype(), + InvalidOperation: "List Series doing arithmetic operations to each other should have same dtype; got {:?} and {:?}", + self.dtype(), + rhs.dtype() + ); + let mut has_nulls = has_nulls.unwrap_or(false); if !has_nulls { for chunk in self.chunks().iter() { @@ -118,7 +207,7 @@ impl ListChunked { // along. a_listchunked.arithm_helper(b, op, Some(true)) } else { - op(a, b) + op.apply_with_series(a, b) }; chunk_result.map(Some) }).collect::>>>()?; @@ -139,8 +228,7 @@ impl ListChunked { InvalidOperation: "can only do arithmetic operations on lists of the same size" ); - let result = op(&l_leaf_array, &r_leaf_array)?; - + let result = op.apply_with_series(&l_leaf_array, &r_leaf_array)?; // We now need to wrap the Arrow arrays with the metadata that turns // them into lists: // TODO is there a way to do this without cloning the underlying data? @@ -160,18 +248,18 @@ impl ListChunked { impl NumOpsDispatchInner for ListType { fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + lhs.arithm_helper(rhs, Op::Add, None) } fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + lhs.arithm_helper(rhs, Op::Subtract, None) } fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + lhs.arithm_helper(rhs, Op::Multiply, None) } fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + lhs.arithm_helper(rhs, Op::Divide, None) } fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + lhs.arithm_helper(rhs, Op::Remainder, None) } }