Skip to content

Commit

Permalink
Fix some things for pandas 3 (#1110)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jul 24, 2024
1 parent 9cfed0f commit 184a22c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
9 changes: 8 additions & 1 deletion dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
)
from dask_expr._str_accessor import StringAccessor
from dask_expr._util import (
PANDAS_GE_300,
_BackendData,
_convert_to_list,
_get_shuffle_preferring_order,
Expand Down Expand Up @@ -400,7 +401,7 @@ def __getitem__(self, other):
if (
self.ndim == 2
and is_integer_slice
and not is_float_dtype(self.index.dtype)
and (not is_float_dtype(self.index.dtype) or PANDAS_GE_300)
):
return self.iloc[other]
else:
Expand Down Expand Up @@ -1577,6 +1578,10 @@ def std(
if needs_time_conversion:
numeric_dd = _convert_to_numeric(self, skipna)

units = None
if needs_time_conversion and time_cols is not None:
units = [getattr(self._meta[c].array, "unit", None) for c in time_cols]

if axis == 1:
return numeric_dd.map_partitions(
M.std if not needs_time_conversion else _sqrt_and_convert_to_timedelta,
Expand All @@ -1598,6 +1603,8 @@ def std(
"time_cols": time_cols,
"axis": axis,
"dtype": getattr(meta, "dtype", None),
"unit": getattr(meta, "unit", None),
"units": units,
}
sqrt_func = _sqrt_and_convert_to_timedelta
else:
Expand Down
2 changes: 2 additions & 0 deletions dask_expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,8 @@ def npartitions(self):


def groupby_get_group(df, *by_key, get_key=None, columns=None):
if PANDAS_GE_300 and is_scalar(get_key):
get_key = (get_key,)
return _groupby_get_group(df, list(by_key), get_key, columns)


Expand Down
8 changes: 5 additions & 3 deletions dask_expr/_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def _loc(self, iindexer, cindexer):
return self._loc(iindexer(self.obj), cindexer)

if self.obj.known_divisions:
iindexer = self._maybe_partial_time_string(iindexer)
idx = self.obj.index._meta
unit = idx.unit if hasattr(idx, "unit") else None
iindexer = self._maybe_partial_time_string(iindexer, unit=unit)

if isinstance(iindexer, slice):
return self._loc_slice(iindexer, cindexer)
Expand Down Expand Up @@ -133,13 +135,13 @@ def _loc_array(self, iindexer, cindexer):
)
return self._loc_series(iindexer_series, cindexer, check_alignment=False)

def _maybe_partial_time_string(self, iindexer):
def _maybe_partial_time_string(self, iindexer, unit):
"""
Convert index-indexer for partial time string slicing
if obj.index is DatetimeIndex / PeriodIndex
"""
idx = meta_nonempty(self.obj._meta.index)
iindexer = _maybe_partial_time_string(idx, iindexer)
iindexer = _maybe_partial_time_string(idx, iindexer, unit)
return iindexer

def _loc_slice(self, iindexer, cindexer):
Expand Down

0 comments on commit 184a22c

Please sign in to comment.