Skip to content

Commit

Permalink
feat: add replace and replace_strict (#1327)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Edoardo Abati <[email protected]>
  • Loading branch information
MarcoGorelli and EdAbati authored Nov 8, 2024
1 parent 8efc161 commit 50c0a0f
Show file tree
Hide file tree
Showing 17 changed files with 377 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- over
- pipe
- quantile
- replace_strict
- round
- sample
- shift
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
- pipe
- quantile
- rename
- replace_strict
- round
- sample
- scatter
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence

from narwhals._expression_parsing import reuse_series_implementation
from narwhals._expression_parsing import reuse_series_namespace_implementation
Expand Down Expand Up @@ -320,6 +321,13 @@ def is_last_distinct(self: Self) -> Self:
def unique(self: Self, *, maintain_order: bool = False) -> Self:
    """Delegate to the series-level ``unique`` implementation.

    ``maintain_order`` is forwarded unchanged as a keyword argument to
    ``reuse_series_implementation``.
    """
    options = {"maintain_order": maintain_order}
    return reuse_series_implementation(self, "unique", **options)

def replace_strict(
    self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Delegate to the series-level ``replace_strict`` implementation.

    ``old`` and ``new`` are forwarded positionally and ``return_dtype`` as a
    keyword argument to ``reuse_series_implementation``.
    """
    series_method = "replace_strict"
    forwarded = {"return_dtype": return_dtype}
    return reuse_series_implementation(self, series_method, old, new, **forwarded)

def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
return reuse_series_implementation(
self, "sort", descending=descending, nulls_last=nulls_last
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr:
def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries:
from narwhals._arrow.series import ArrowSeries

if self._backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
if self._backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return ArrowSeries._from_iterable(
[value],
Expand Down
20 changes: 20 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,26 @@ def unique(self: Self, *, maintain_order: bool = False) -> ArrowSeries:

return self._from_native_series(pc.unique(self._native_series))

def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> ArrowSeries:
    """Replace every value found in ``old`` with the value at the same position in ``new``.

    The result is cast to the native equivalent of ``return_dtype``.

    Raises:
        ValueError: if the result has a different number of nulls than the
            input; the message lists the distinct non-null input values that
            ended up null.
    """
    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    # https://stackoverflow.com/a/79111029/4451315
    # Look up each value's position in `old`, then pick the entry at that
    # position from `new`.
    positions = pc.index_in(self._native_series, pa.array(old))
    picked = pc.take(pa.array(new), positions)
    native_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
    result = self._from_native_series(picked.cast(native_dtype))

    # Same null count before and after is taken to mean everything non-null
    # was replaced.
    if result.is_null().sum() == self.is_null().sum():
        return result

    not_replaced = (
        self.filter(~self.is_null() & result.is_null()).unique().to_list()
    )
    msg = (
        "replace_strict did not replace all non-null values.\n\n"
        f"The following did not get replaced: {not_replaced}"
    )
    raise ValueError(msg)

def sort(
self: Self, *, descending: bool = False, nulls_last: bool = False
) -> ArrowSeries:
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import Sequence

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
Expand Down Expand Up @@ -477,6 +478,12 @@ def head(self) -> NoReturn:
msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead."
raise NotImplementedError(msg)

def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> NoReturn:
    """Reject ``replace_strict``: it is not implemented for Dask expressions.

    The signature is kept so the call is accepted, but the arguments are
    never used. The return type is ``NoReturn``, like the other unsupported
    Dask expression methods, because this method always raises.

    Raises:
        NotImplementedError: always.
    """
    msg = "`replace_strict` is not yet supported for Dask expressions"
    raise NotImplementedError(msg)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead."
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def with_columns(
)
else:
# This is the logic in pandas' DataFrame.assign
if self._backend_version < (2,): # pragma: no cover
if self._backend_version < (2,):
df = self._native_frame.copy(deep=True)
else:
df = self._native_frame.copy(deep=False)
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence

from narwhals._expression_parsing import reuse_series_implementation
from narwhals._expression_parsing import reuse_series_namespace_implementation
Expand All @@ -14,6 +15,7 @@

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals.dtypes import DType
from narwhals.typing import DTypes
from narwhals.utils import Implementation

Expand Down Expand Up @@ -271,6 +273,13 @@ def filter(self, *predicates: Any) -> Self:
def drop_nulls(self) -> Self:
    """Delegate to the series-level ``drop_nulls`` implementation."""
    series_method = "drop_nulls"
    return reuse_series_implementation(self, series_method)

def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Delegate to the series-level ``replace_strict`` implementation.

    ``old`` and ``new`` are forwarded positionally and ``return_dtype`` as a
    keyword argument to ``reuse_series_implementation``.
    """
    series_method = "replace_strict"
    forwarded = {"return_dtype": return_dtype}
    return reuse_series_implementation(self, series_method, old, new, **forwarded)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
return reuse_series_implementation(
self, "sort", descending=descending, nulls_last=nulls_last
Expand Down
30 changes: 30 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,36 @@ def diff(self) -> PandasLikeSeries:
def shift(self, n: int) -> PandasLikeSeries:
    """Call ``shift(n)`` on the native series and re-wrap the result."""
    shifted = self._native_series.shift(n)
    return self._from_native_series(shifted)

def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> PandasLikeSeries:
    """Replace every value found in ``old`` with the value at the same position in ``new``.

    The replacement is done by left-merging the series against a two-column
    lookup table built from ``old`` and ``new``; the ``new`` column is created
    with the native equivalent of ``return_dtype``.

    Raises:
        ValueError: if the result has a different number of nulls than the
            input; the message lists the distinct non-null input values that
            ended up null.
    """
    key = self.name
    tmp_name = f"{key}_tmp"
    dtype = narwhals_to_native_dtype(
        return_dtype,
        self._native_series.dtype,
        self._implementation,
        self._backend_version,
        self._dtypes,
    )
    native_namespace = self.__native_namespace__()

    # Lookup table: the series' own name holds the keys, a temporary column
    # holds the replacement values.
    lookup = native_namespace.DataFrame(
        {key: old, tmp_name: native_namespace.Series(new, dtype=dtype)}
    )
    # Left merge on the key column, then keep only the replacement column
    # under the original series name.
    merged = self._native_series.to_frame().merge(lookup, on=key, how="left")
    result = self._from_native_series(merged[tmp_name].rename(key))

    # Same null count before and after is taken to mean everything non-null
    # was replaced.
    if result.is_null().sum() == self.is_null().sum():
        return result

    not_replaced = (
        self.filter(~self.is_null() & result.is_null()).unique().to_list()
    )
    msg = (
        "replace_strict did not replace all non-null values.\n\n"
        f"The following did not get replaced: {not_replaced}"
    )
    raise ValueError(msg)

def sort(
self, *, descending: bool = False, nulls_last: bool = False
) -> PandasLikeSeries:
Expand Down
20 changes: 10 additions & 10 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def func(*args: Any, **kwargs: Any) -> Any:
return func

def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
if self._backend_version < (0, 20, 28) and copy is not None: # pragma: no cover
if self._backend_version < (0, 20, 28) and copy is not None:
msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
raise NotImplementedError(msg)
if self._backend_version < (0, 20, 28): # pragma: no cover
if self._backend_version < (0, 20, 28):
return self._native_frame.__array__(dtype)
return self._native_frame.__array__(dtype)

def collect_schema(self) -> dict[str, Any]:
if self._backend_version < (1,): # pragma: no cover
if self._backend_version < (1,):
schema = self._native_frame.schema
else:
schema = dict(self._native_frame.collect_schema())
Expand Down Expand Up @@ -209,12 +209,12 @@ def group_by(self, *by: str, drop_null_keys: bool) -> Any:
return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys)

def with_row_index(self, name: str) -> Any:
if self._backend_version < (0, 20, 4): # pragma: no cover
if self._backend_version < (0, 20, 4):
return self._from_native_frame(self._native_frame.with_row_count(name))
return self._from_native_frame(self._native_frame.with_row_index(name))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
if self._backend_version < (1, 0, 0): # pragma: no cover
if self._backend_version < (1, 0, 0):
to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
Expand All @@ -228,7 +228,7 @@ def unpivot(
variable_name: str | None,
value_name: str | None,
) -> Self:
if self._backend_version < (1, 0, 0): # pragma: no cover
if self._backend_version < (1, 0, 0):
return self._from_native_frame(
self._native_frame.melt(
id_vars=index,
Expand Down Expand Up @@ -296,7 +296,7 @@ def schema(self) -> dict[str, Any]:
}

def collect_schema(self) -> dict[str, Any]:
if self._backend_version < (1,): # pragma: no cover
if self._backend_version < (1,):
schema = self._native_frame.schema
else:
schema = dict(self._native_frame.collect_schema())
Expand All @@ -318,12 +318,12 @@ def group_by(self, *by: str, drop_null_keys: bool) -> Any:
return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys)

def with_row_index(self, name: str) -> Any:
if self._backend_version < (0, 20, 4): # pragma: no cover
if self._backend_version < (0, 20, 4):
return self._from_native_frame(self._native_frame.with_row_count(name))
return self._from_native_frame(self._native_frame.with_row_index(name))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
if self._backend_version < (1, 0, 0): # pragma: no cover
if self._backend_version < (1, 0, 0):
return self._from_native_frame(self._native_frame.drop(columns))
return self._from_native_frame(self._native_frame.drop(columns, strict=strict))

Expand All @@ -334,7 +334,7 @@ def unpivot(
variable_name: str | None,
value_name: str | None,
) -> Self:
if self._backend_version < (1, 0, 0): # pragma: no cover
if self._backend_version < (1, 0, 0):
return self._from_native_frame(
self._native_frame.melt(
id_vars=index,
Expand Down
22 changes: 20 additions & 2 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native
Expand All @@ -16,16 +17,21 @@


class PolarsExpr:
def __init__(self, expr: Any, dtypes: DTypes) -> None:
def __init__(
self, expr: Any, dtypes: DTypes, backend_version: tuple[int, ...]
) -> None:
self._native_expr = expr
self._implementation = Implementation.POLARS
self._dtypes = dtypes
self._backend_version = backend_version

def __repr__(self) -> str:  # pragma: no cover
    """Return the fixed label ``"PolarsExpr"``."""
    label = "PolarsExpr"
    return label

def _from_native_expr(self, expr: Any) -> Self:
return self.__class__(expr, dtypes=self._dtypes)
return self.__class__(
expr, dtypes=self._dtypes, backend_version=self._backend_version
)

def __getattr__(self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
Expand All @@ -41,6 +47,18 @@ def cast(self, dtype: DType) -> Self:
dtype = narwhals_to_native_dtype(dtype, self._dtypes)
return self._from_native_expr(expr.cast(dtype))

def replace_strict(
    self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
) -> Self:
    """Apply the native expression's ``replace_strict`` and re-wrap the result.

    ``return_dtype`` is converted to its native equivalent before being
    passed to the native call.

    Raises:
        NotImplementedError: if the backend version is below 1.0.
    """
    # Check the version first, so an unsupported Polars always gets this
    # message rather than an unrelated failure from the steps below.
    if self._backend_version < (1,):
        msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
        raise NotImplementedError(msg)
    # Use a separate name rather than rebinding the `return_dtype` parameter
    # to a value of a different type.
    native_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
    return self._from_native_expr(
        self._native_expr.replace_strict(old, new, return_dtype=native_dtype)
    )

def __eq__(self, other: object) -> Self:  # type: ignore[override]
    """Build an equality expression against ``other``.

    ``other`` is passed through ``extract_native`` first, and the result of
    the native ``__eq__`` call is re-wrapped.
    """
    native_other = extract_native(other)
    comparison = self._native_expr.__eq__(native_other)
    return self._from_native_expr(comparison)

Expand Down
Loading

0 comments on commit 50c0a0f

Please sign in to comment.