Skip to content

Commit

Permalink
perf: only rechunk once per aggregate (pola-rs#16469)
Browse files (browse the repository at this point in the history)
  • Loading branch information
orlp authored May 24, 2024
1 parent d4c3aba commit 6de7422
Showing 1 changed file with 87 additions and 41 deletions.
128 changes: 87 additions & 41 deletions crates/polars-core/src/frame/group_by/aggregations/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,34 @@ impl Series {

#[doc(hidden)]
pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 && self.null_count() > 0 {
self.rechunk()
} else {
self.clone()
};

match groups {
GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
debug_assert!(idx.len() <= self.len());
debug_assert!(idx.len() <= s.len());
if idx.is_empty() {
None
} else if !self.has_validity() {
} else if s.null_count() == 0 {
Some(idx.len() as IdxSize)
} else {
let take = unsafe { self.take_slice_unchecked(idx) };
let take = unsafe { s.take_slice_unchecked(idx) };
Some((take.len() - take.null_count()) as IdxSize)
}
}),
GroupsProxy::Slice { groups, .. } => {
_agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
debug_assert!(len <= s.len() as IdxSize);
if len == 0 {
None
} else if !self.has_validity() {
} else if s.null_count() == 0 {
Some(len)
} else {
let take = self.slice_from_offsets(first, len);
let take = s.slice_from_offsets(first, len);
Some((take.len() - take.null_count()) as IdxSize)
}
})
Expand All @@ -46,6 +53,13 @@ impl Series {

#[doc(hidden)]
pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

let mut out = match groups {
GroupsProxy::Idx(groups) => {
let indices = groups
Expand All @@ -61,44 +75,51 @@ impl Series {
)
.collect_ca("");
// SAFETY: groups are always in bounds.
self.take_unchecked(&indices)
s.take_unchecked(&indices)
},
GroupsProxy::Slice { groups, .. } => {
let indices = groups
.iter()
.map(|&[first, len]| if len == 0 { None } else { Some(first) })
.collect_ca("");
// SAFETY: groups are always in bounds.
self.take_unchecked(&indices)
s.take_unchecked(&indices)
},
};
if groups.is_sorted_flag() {
out.set_sorted_flag(self.is_sorted_flag())
out.set_sorted_flag(s.is_sorted_flag())
}
self.restore_logical(out)
s.restore_logical(out)
}

#[doc(hidden)]
pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

match groups {
GroupsProxy::Idx(groups) => {
agg_helper_idx_on_all_no_null::<IdxType, _>(groups, |idx| {
debug_assert!(idx.len() <= self.len());
debug_assert!(idx.len() <= s.len());
if idx.is_empty() {
0
} else {
let take = self.take_slice_unchecked(idx);
let take = s.take_slice_unchecked(idx);
take.n_unique().unwrap() as IdxSize
}
})
},
GroupsProxy::Slice { groups, .. } => {
_agg_helper_slice_no_null::<IdxType, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
debug_assert!(len <= s.len() as IdxSize);
if len == 0 {
0
} else {
let take = self.slice_from_offsets(first, len);
let take = s.slice_from_offsets(first, len);
take.n_unique().unwrap() as IdxSize
}
})
Expand All @@ -108,30 +129,36 @@ impl Series {

#[doc(hidden)]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series {
use DataType::*;
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

match self.dtype() {
Boolean => self.cast(&Float64).unwrap().agg_median(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_median(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups),
use DataType::*;
match s.dtype() {
Boolean => s.cast(&Float64).unwrap().agg_median(groups),
Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_median, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => self
dt @ (Datetime(_, _) | Duration(_) | Time) => s
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ Date => {
let ca = self.to_physical_repr();
let ca = s.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_median, groups);
// Cast the aggregated result back to the physical type first,
// and then from the physical type to the logical type.
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
_ => Series::full_null("", groups.len(), self.dtype()),
_ => Series::full_null("", groups.len(), s.dtype()),
}
}

Expand All @@ -142,13 +169,19 @@ impl Series {
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Series {
use DataType::*;
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

match self.dtype() {
Float32 => self.f32().unwrap().agg_quantile(groups, quantile, interpol),
Float64 => self.f64().unwrap().agg_quantile(groups, quantile, interpol),
use DataType::*;
match s.dtype() {
Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol),
Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol),
dt if dt.is_numeric() || dt.is_temporal() => {
let ca = self.to_physical_repr();
let ca = s.to_physical_repr();
let physical_type = ca.dtype();
let s =
apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol);
Expand All @@ -160,41 +193,54 @@ impl Series {
s
}
},
_ => Series::full_null("", groups.len(), self.dtype()),
_ => Series::full_null("", groups.len(), s.dtype()),
}
}

#[doc(hidden)]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series {
use DataType::*;
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

match self.dtype() {
Boolean => self.cast(&Float64).unwrap().agg_mean(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_mean, groups),
use DataType::*;
match s.dtype() {
Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => self
dt @ (Datetime(_, _) | Duration(_) | Time) => s
.to_physical_repr()
.agg_mean(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ Date => {
let ca = self.to_physical_repr();
let ca = s.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_mean, groups);
// Cast the aggregated result back to the physical type first,
// and then from the physical type to the logical type.
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
_ => Series::full_null("", groups.len(), self.dtype()),
_ => Series::full_null("", groups.len(), s.dtype()),
}
}

#[doc(hidden)]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

let out = match groups {
GroupsProxy::Idx(groups) => {
let indices = groups
Expand All @@ -208,7 +254,7 @@ impl Series {
}
})
.collect_ca("");
self.take_unchecked(&indices)
s.take_unchecked(&indices)
},
GroupsProxy::Slice { groups, .. } => {
let indices = groups
Expand All @@ -221,9 +267,9 @@ impl Series {
}
})
.collect_ca("");
self.take_unchecked(&indices)
s.take_unchecked(&indices)
},
};
self.restore_logical(out)
s.restore_logical(out)
}
}

0 comments on commit 6de7422

Please sign in to comment.