diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index c19fe4a38ae9..6b0258f7ac77 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -16,27 +16,34 @@ impl Series { #[doc(hidden)] pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 && self.null_count() > 0 { + self.rechunk() + } else { + self.clone() + }; + match groups { GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= self.len()); + debug_assert!(idx.len() <= s.len()); if idx.is_empty() { None - } else if !self.has_validity() { + } else if s.null_count() == 0 { Some(idx.len() as IdxSize) } else { - let take = unsafe { self.take_slice_unchecked(idx) }; + let take = unsafe { s.take_slice_unchecked(idx) }; Some((take.len() - take.null_count()) as IdxSize) } }), GroupsProxy::Slice { groups, .. } => { _agg_helper_slice::(groups, |[first, len]| { - debug_assert!(len <= self.len() as IdxSize); + debug_assert!(len <= s.len() as IdxSize); if len == 0 { None - } else if !self.has_validity() { + } else if s.null_count() == 0 { Some(len) } else { - let take = self.slice_from_offsets(first, len); + let take = s.slice_from_offsets(first, len); Some((take.len() - take.null_count()) as IdxSize) } }) @@ -46,6 +53,13 @@ impl Series { #[doc(hidden)] pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + let mut out = match groups { GroupsProxy::Idx(groups) => { let indices = groups @@ -61,7 +75,7 @@ impl Series { ) .collect_ca(""); // SAFETY: groups are always in bounds. - self.take_unchecked(&indices) + s.take_unchecked(&indices) }, GroupsProxy::Slice { groups, .. } => { let indices = groups @@ -69,36 +83,43 @@ impl Series { .map(|&[first, len]| if len == 0 { None } else { Some(first) }) .collect_ca(""); // SAFETY: groups are always in bounds. - self.take_unchecked(&indices) + s.take_unchecked(&indices) }, }; if groups.is_sorted_flag() { - out.set_sorted_flag(self.is_sorted_flag()) + out.set_sorted_flag(s.is_sorted_flag()) } - self.restore_logical(out) + s.restore_logical(out) } #[doc(hidden)] pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + match groups { GroupsProxy::Idx(groups) => { agg_helper_idx_on_all_no_null::(groups, |idx| { - debug_assert!(idx.len() <= self.len()); + debug_assert!(idx.len() <= s.len()); if idx.is_empty() { 0 } else { - let take = self.take_slice_unchecked(idx); + let take = s.take_slice_unchecked(idx); take.n_unique().unwrap() as IdxSize } }) }, GroupsProxy::Slice { groups, .. } => { _agg_helper_slice_no_null::(groups, |[first, len]| { - debug_assert!(len <= self.len() as IdxSize); + debug_assert!(len <= s.len() as IdxSize); if len == 0 { 0 } else { - let take = self.slice_from_offsets(first, len); + let take = s.slice_from_offsets(first, len); take.n_unique().unwrap() as IdxSize } }) @@ -108,15 +129,21 @@ impl Series { #[doc(hidden)] pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { - use DataType::*; + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; - match self.dtype() { - Boolean => self.cast(&Float64).unwrap().agg_median(groups), - Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_median(groups), - Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups), - dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups), + use DataType::*; + match s.dtype() { + Boolean => s.cast(&Float64).unwrap().agg_median(groups), + Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups), + Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups), + dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_median, groups), #[cfg(feature = "dtype-datetime")] - dt @ (Datetime(_, _) | Duration(_) | Time) => self + dt @ (Datetime(_, _) | Duration(_) | Time) => s .to_physical_repr() .agg_median(groups) .cast(&Int64) @@ -124,14 +151,14 @@ impl Series { .cast(dt) .unwrap(), dt @ Date => { - let ca = self.to_physical_repr(); + let ca = s.to_physical_repr(); let physical_type = ca.dtype(); let s = apply_method_physical_integer!(ca, agg_median, groups); // back to physical and then // back to logical type s.cast(physical_type).unwrap().cast(dt).unwrap() }, - _ => Series::full_null("", groups.len(), self.dtype()), + _ => Series::full_null("", groups.len(), s.dtype()), } } @@ -142,13 +169,19 @@ impl Series { quantile: f64, interpol: QuantileInterpolOptions, ) -> Series { - use DataType::*; + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; - match self.dtype() { - Float32 => self.f32().unwrap().agg_quantile(groups, quantile, interpol), - Float64 => self.f64().unwrap().agg_quantile(groups, quantile, interpol), + use DataType::*; + match s.dtype() { + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol), dt if dt.is_numeric() || dt.is_temporal() => { - let ca = self.to_physical_repr(); + let ca = s.to_physical_repr(); let physical_type = ca.dtype(); let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol); @@ -160,21 +193,27 @@ impl Series { s } }, - _ => Series::full_null("", groups.len(), self.dtype()), + _ => Series::full_null("", groups.len(), s.dtype()), } } #[doc(hidden)] pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series { - use DataType::*; + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; - match self.dtype() { - Boolean => self.cast(&Float64).unwrap().agg_mean(groups), - Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups), - Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups), - dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_mean, groups), + use DataType::*; + match s.dtype() { + Boolean => s.cast(&Float64).unwrap().agg_mean(groups), + Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups), + Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups), + dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_mean, groups), #[cfg(feature = "dtype-datetime")] - dt @ (Datetime(_, _) | Duration(_) | Time) => self + dt @ (Datetime(_, _) | Duration(_) | Time) => s .to_physical_repr() .agg_mean(groups) .cast(&Int64) @@ -182,19 +221,26 @@ impl Series { .cast(dt) .unwrap(), dt @ Date => { - let ca = self.to_physical_repr(); + let ca = s.to_physical_repr(); let physical_type = ca.dtype(); let s = apply_method_physical_integer!(ca, agg_mean, groups); // back to physical and then // back to logical type s.cast(physical_type).unwrap().cast(dt).unwrap() }, - _ => Series::full_null("", groups.len(), self.dtype()), + _ => Series::full_null("", groups.len(), s.dtype()), } } #[doc(hidden)] pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + let out = match groups { GroupsProxy::Idx(groups) => { let indices = groups @@ -208,7 +254,7 @@ impl Series { } }) .collect_ca(""); - self.take_unchecked(&indices) + s.take_unchecked(&indices) }, GroupsProxy::Slice { groups, .. } => { let indices = groups @@ -221,9 +267,9 @@ impl Series { } }) .collect_ca(""); - self.take_unchecked(&indices) + s.take_unchecked(&indices) }, }; - self.restore_logical(out) + s.restore_logical(out) } }