Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DataFrame|LazyFrame.cast method #1045

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
members:
- __arrow_c_stream__
- __getitem__
- cast
- clone
- collect_schema
- columns
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/lazyframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
handler: python
options:
members:
- cast
- clone
- collect
- collect_schema
Expand Down
16 changes: 16 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,19 @@ def sample(
mask = rng.choice(idx, size=n, replace=with_replacement)

return self._from_native_frame(pc.take(frame, mask))

def cast(
    self: Self,
    dtypes: dict[str, type[DType] | DType] | type[DType] | DType,
    *,
    strict: bool,
) -> Self:
    """Cast columns to new dtypes by building one cast expression per column.

    Arguments:
        dtypes: Either a mapping of column name to target dtype, or a single
            dtype which is then applied to every column in `self.columns`.
        strict: Forwarded unchanged to each expression's `cast`.

    Returns:
        Whatever `self.with_columns` returns when called with the cast
        expressions passed as keyword arguments keyed by column name.
    """
    namespace = self.__narwhals_namespace__()
    if isinstance(dtypes, dict):
        targets = dtypes.items()
    else:
        # A single dtype fans out to every column of the frame.
        targets = ((name, dtypes) for name in self.columns)
    casted = {
        name: namespace.col(name).cast(dtype, strict=strict)
        for name, dtype in targets
    }
    return self.with_columns(**casted)
4 changes: 2 additions & 2 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def n_unique(self) -> Self:
def std(self, ddof: int = 1) -> Self:
return reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True)

def cast(self, dtype: DType) -> Self:
return reuse_series_implementation(self, "cast", dtype)
def cast(self, dtype: DType | type[DType], *, strict: bool) -> Self:
return reuse_series_implementation(self, "cast", dtype=dtype, strict=strict)

def abs(self) -> Self:
return reuse_series_implementation(self, "abs")
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
backend_version=self._backend_version,
)
if dtype:
return arrow_series.cast(dtype)
return arrow_series.cast(dtype, strict=True)
return arrow_series

return ArrowExpr(
Expand Down Expand Up @@ -184,7 +184,7 @@ def mean_horizontal(self, *exprs: IntoArrowExpr) -> IntoArrowExpr:
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in arrow_exprs))
n_non_zero = reduce(
lambda x, y: x + y,
((1 - e.is_null().cast(self.Int64())) for e in arrow_exprs),
((1 - e.is_null().cast(self.Int64(), strict=True)) for e in arrow_exprs),
)
return total / n_non_zero

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,12 @@ def is_null(self) -> Self:
ser = self._native_series
return self._from_native_series(ser.is_null())

def cast(self, dtype: DType) -> Self:
def cast(self, dtype: DType | type[DType], *, strict: bool) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_series(pc.cast(ser, dtype))
return self._from_native_series(pc.cast(ser, dtype, safe=strict))

def null_count(self: Self) -> int:
return self._native_series.null_count # type: ignore[no-any-return]
Expand Down
17 changes: 17 additions & 0 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Sequence

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._pandas_like.utils import translate_dtype
from narwhals.dependencies import get_dask_dataframe
Expand Down Expand Up @@ -355,3 +356,19 @@ def gather_every(self: Self, n: int, offset: int) -> Self:
)
.drop([row_index_token], strict=False)
)

def cast(
    self: Self,
    dtypes: dict[str, type[DType] | DType] | type[DType] | DType,
    *,
    strict: bool,
) -> Self:
    """Cast columns by calling `astype` on the native frame.

    `strict` exists only for signature compatibility and is never read here,
    since dask's `astype` does not support an `errors` argument the way
    pandas' does.

    Arguments:
        dtypes: Either a mapping of column name to target dtype, or a single
            dtype. Each dtype is translated with `narwhals_to_native_dtype`
            before being handed to `astype`.
        strict: Unused (see above).

    Returns:
        The result of `self._from_native_frame` applied to the cast frame.
    """
    if isinstance(dtypes, dict):
        native_dtypes = {
            name: narwhals_to_native_dtype(dtype) for name, dtype in dtypes.items()
        }
    else:
        native_dtypes = narwhals_to_native_dtype(dtypes)
    return self._from_native_frame(self._native_frame.astype(native_dtypes))
13 changes: 7 additions & 6 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals.dependencies import get_dask
from narwhals.utils import generate_unique_token

Expand Down Expand Up @@ -675,12 +675,13 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)

def cast(
self: Self,
dtype: DType | type[DType],
) -> Self:
def cast(self: Self, dtype: DType | type[DType], *, strict: bool) -> Self:
"""`strict` exists for compatibility as dask `astype` does not support
`errors` argument as pandas does.
"""

def func(_input: Any, dtype: DType | type[DType]) -> Any:
dtype = reverse_translate_dtype(dtype)
dtype = narwhals_to_native_dtype(dtype)
return _input.astype(dtype)

return self._from_call(
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import reverse_translate_dtype
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._dask.utils import validate_comparand
from narwhals._expression_parsing import parse_into_exprs

Expand Down Expand Up @@ -75,7 +75,7 @@ def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr:
def convert_if_dtype(
series: dask_expr.Series, dtype: DType | type[DType]
) -> dask_expr.Series:
return series.astype(reverse_translate_dtype(dtype)) if dtype else series
return series.astype(narwhals_to_native_dtype(dtype)) if dtype else series

return DaskExpr(
lambda df: [
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
raise RuntimeError(msg)


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
def narwhals_to_native_dtype(dtype: DType | type[DType]) -> Any:
from narwhals import dtypes

if isinstance_or_issubclass(dtype, dtypes.Float64):
Expand Down
30 changes: 30 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import narwhals_to_native_dtype
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals.dependencies import get_cudf
Expand Down Expand Up @@ -728,3 +729,32 @@ def sample(
n=n, frac=fraction, replace=with_replacement, random_state=seed
)
)

def cast(
    self: Self,
    dtypes: dict[str, type[DType] | DType] | type[DType] | DType,
    *,
    strict: bool,
) -> Self:
    """Cast columns by calling `astype` on the native frame.

    Arguments:
        dtypes: Either a mapping of column name to target dtype, or a single
            dtype which is then applied to every column in `self.columns`.
            Each dtype is translated with `narwhals_to_native_dtype`, using
            the current dtype of the corresponding native column as
            `starting_dtype`.
        strict: If true, `astype` is called with `errors="raise"`; otherwise
            with `errors="ignore"`.

    Returns:
        The result of `self._from_native_frame` applied to the cast frame.
    """
    frame = self._native_frame
    backend = self._implementation
    if isinstance(dtypes, dict):
        requested = dtypes.items()
    else:
        # A single dtype fans out to every column of the frame.
        requested = ((name, dtypes) for name in self.columns)
    native_dtypes = {
        name: narwhals_to_native_dtype(
            dtype,
            starting_dtype=frame[name].dtype,
            implementation=backend,
        )
        for name, dtype in requested
    }
    if strict:
        errors = "raise"
    else:
        errors = "ignore"
    return self._from_native_frame(frame.astype(native_dtypes, errors=errors))
8 changes: 3 additions & 5 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Implementation


Expand Down Expand Up @@ -81,11 +82,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
backend_version=backend_version,
)

def cast(
self,
dtype: Any,
) -> Self:
return reuse_series_implementation(self, "cast", dtype=dtype)
def cast(self: Self, dtype: DType | type[DType], *, strict: bool) -> Self:
return reuse_series_implementation(self, "cast", dtype=dtype, strict=strict)

def __eq__(self, other: PandasLikeExpr | Any) -> Self: # type: ignore[override]
return reuse_series_implementation(self, "__eq__", other=other)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
backend_version=self._backend_version,
)
if dtype:
return pandas_series.cast(dtype)
return pandas_series.cast(dtype, strict=True)
return pandas_series

return PandasLikeExpr(
Expand Down
8 changes: 3 additions & 5 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,11 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
s.name = self.name
return self._from_native_series(s)

def cast(
self,
dtype: Any,
) -> Self:
def cast(self, dtype: DType | type[DType], *, strict: bool) -> Self:
ser = self._native_series
dtype = narwhals_to_native_dtype(dtype, ser.dtype, self._implementation)
return self._from_native_series(ser.astype(dtype))
errors = "raise" if strict else "ignore"
return self._from_native_series(ser.astype(dtype, errors=errors))

def item(self: Self, index: int | None = None) -> Any:
# cuDF doesn't have Series.item().
Expand Down
29 changes: 29 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals._polars.utils import translate_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
Expand All @@ -16,6 +17,8 @@
import numpy as np
from typing_extensions import Self

from narwhals.dtypes import DType


class PolarsDataFrame:
def __init__(self, df: Any, *, backend_version: tuple[int, ...]) -> None:
Expand Down Expand Up @@ -186,6 +189,19 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
return self._from_native_frame(self._native_frame.drop(to_drop))
return self._from_native_frame(self._native_frame.drop(columns, strict=strict))

def cast(
    self: Self,
    dtypes: dict[str, type[DType] | DType] | type[DType] | DType,
    *,
    strict: bool,
) -> Self:
    """Cast columns by calling `cast` on the native frame.

    Arguments:
        dtypes: Either a mapping of column name to target dtype, or a single
            dtype. Each dtype is translated with `narwhals_to_native_dtype`
            before being handed to the native `cast`.
        strict: Forwarded unchanged to the native `cast`.

    Returns:
        The result of `self._from_native_frame` applied to the cast frame.
    """
    if isinstance(dtypes, dict):
        native_dtypes = {
            name: narwhals_to_native_dtype(dtype) for name, dtype in dtypes.items()
        }
    else:
        native_dtypes = narwhals_to_native_dtype(dtypes)
    casted = self._native_frame.cast(native_dtypes, strict=strict)
    return self._from_native_frame(casted)


class PolarsLazyFrame:
def __init__(self, df: Any, *, backend_version: tuple[int, ...]) -> None:
Expand Down Expand Up @@ -251,3 +267,16 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
if self._backend_version < (1, 0, 0): # pragma: no cover
return self._from_native_frame(self._native_frame.drop(columns))
return self._from_native_frame(self._native_frame.drop(columns, strict=strict))

def cast(
    self: Self,
    dtypes: dict[str, type[DType] | DType] | type[DType] | DType,
    *,
    strict: bool,
) -> Self:
    """Cast columns by calling `cast` on the native lazy frame.

    Arguments:
        dtypes: Either a mapping of column name to target dtype, or a single
            dtype. Each dtype is translated with `narwhals_to_native_dtype`
            before being handed to the native `cast`.
        strict: Forwarded unchanged to the native `cast`.

    Returns:
        The result of `self._from_native_frame` applied to the cast frame.
    """
    if isinstance(dtypes, dict):
        native_dtypes = {
            name: narwhals_to_native_dtype(dtype) for name, dtype in dtypes.items()
        }
    else:
        native_dtypes = narwhals_to_native_dtype(dtypes)
    casted = self._native_frame.cast(native_dtypes, strict=strict)
    return self._from_native_frame(casted)
4 changes: 2 additions & 2 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def func(*args: Any, **kwargs: Any) -> Any:

return func

def cast(self, dtype: DType) -> Self:
def cast(self: Self, dtype: DType | type[DType], *, strict: bool) -> Self:
expr = self._native_expr
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_expr(expr.cast(dtype))
return self._from_native_expr(expr.cast(dtype, strict=strict))

def __eq__(self, other: object) -> Self: # type: ignore[override]
return self._from_native_expr(self._native_expr.__eq__(extract_native(other)))
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals._polars.utils import translate_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation

Expand All @@ -17,8 +19,6 @@
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals.dtypes import DType

from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals._polars.utils import translate_dtype

PL = get_polars()

Expand Down Expand Up @@ -88,10 +88,10 @@ def __getitem__(self, item: slice | Sequence[int]) -> Self: ...
def __getitem__(self, item: int | slice | Sequence[int]) -> Any | Self:
return self._from_native_object(self._native_series.__getitem__(item))

def cast(self, dtype: DType) -> Self:
def cast(self, dtype: DType | type[DType], *, strict: bool) -> Self:
ser = self._native_series
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_series(ser.cast(dtype))
return self._from_native_series(ser.cast(dtype, strict=strict))

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
if self._backend_version < (0, 20, 29): # pragma: no cover
Expand Down
Loading
Loading