From 706b7c8ac6ff19ca9929956c3651488c0023acbd Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov <65722512+aidoskanapyanov@users.noreply.github.com> Date: Mon, 26 Aug 2024 00:24:30 +0500 Subject: [PATCH] feat: dask expr `cast` (#821) * feat: dask expr cast * replace temporary hack with `.cast` * test cast raises for unknown dtype * remove unused filterwarnings * use walrus for simplicity Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove duplicated test * add missing function parameter types * skip coverage for some paths * skip coverage for unknown dtype exception Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> * remove redundant modin xfail * empty commit to trigger CI --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com> --- narwhals/_dask/expr.py | 17 ++++++++++ narwhals/_dask/utils.py | 49 ++++++++++++++++++++++++++++ tests/expr_and_series/binary_test.py | 6 +--- tests/expr_and_series/cast_test.py | 35 ++++++++++++++++---- tests/selectors_test.py | 14 -------- tests/test_group_by.py | 3 -- 6 files changed, 95 insertions(+), 29 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 31bc4852f8..62aaa85e67 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -9,6 +9,7 @@ from narwhals._dask.utils import add_row_index from narwhals._dask.utils import maybe_evaluate +from narwhals._dask.utils import reverse_translate_dtype from narwhals.dependencies import get_dask from narwhals.utils import generate_unique_token @@ -17,6 +18,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace + from narwhals.dtypes import DType class 
DaskExpr: @@ -654,6 +656,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace: def name(self: Self) -> DaskExprNameNamespace: return DaskExprNameNamespace(self) + def cast( + self: Self, + dtype: DType | type[DType], + ) -> Self: + def func(_input: Any, dtype: DType | type[DType]) -> Any: + dtype = reverse_translate_dtype(dtype) + return _input.astype(dtype) + + return self._from_call( + func, + "cast", + dtype, + returns_scalar=False, + ) + class DaskExprStringNamespace: def __init__(self, expr: DaskExpr) -> None: diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 27f0d7643e..d3f02a07c4 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -4,9 +4,14 @@ from typing import Any from narwhals.dependencies import get_dask_expr +from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_pyarrow +from narwhals.utils import isinstance_or_issubclass +from narwhals.utils import parse_version if TYPE_CHECKING: from narwhals._dask.dataframe import DaskLazyFrame + from narwhals.dtypes import DType def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: @@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs( def add_row_index(frame: Any, name: str) -> Any: frame = frame.assign(**{name: 1}) return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1}) + + +def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: + from narwhals import dtypes + + if isinstance_or_issubclass(dtype, dtypes.Float64): + return "float64" + if isinstance_or_issubclass(dtype, dtypes.Float32): + return "float32" + if isinstance_or_issubclass(dtype, dtypes.Int64): + return "int64" + if isinstance_or_issubclass(dtype, dtypes.Int32): + return "int32" + if isinstance_or_issubclass(dtype, dtypes.Int16): + return "int16" + if isinstance_or_issubclass(dtype, dtypes.Int8): + return "int8" + if isinstance_or_issubclass(dtype, dtypes.UInt64): + return "uint64" + if isinstance_or_issubclass(dtype, dtypes.UInt32): + return "uint32" + if 
isinstance_or_issubclass(dtype, dtypes.UInt16): + return "uint16" + if isinstance_or_issubclass(dtype, dtypes.UInt8): + return "uint8" + if isinstance_or_issubclass(dtype, dtypes.String): + if (pd := get_pandas()) is not None and parse_version( + pd.__version__ + ) >= parse_version("2.0.0"): + if get_pyarrow() is not None: + return "string[pyarrow]" + return "string[python]" # pragma: no cover + return "object" # pragma: no cover + if isinstance_or_issubclass(dtype, dtypes.Boolean): + return "bool" + if isinstance_or_issubclass(dtype, dtypes.Categorical): + return "category" + if isinstance_or_issubclass(dtype, dtypes.Datetime): + return "datetime64[us]" + if isinstance_or_issubclass(dtype, dtypes.Duration): + return "timedelta64[ns]" + + msg = f"Unknown dtype: {dtype}" # pragma: no cover + raise AssertionError(msg) diff --git a/tests/expr_and_series/binary_test.py b/tests/expr_and_series/binary_test.py index 9d4e6cf6c0..2d55af2285 100644 --- a/tests/expr_and_series/binary_test.py +++ b/tests/expr_and_series/binary_test.py @@ -1,15 +1,11 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import compare_dicts -def test_expr_binary(constructor: Any, request: Any) -> None: +def test_expr_binary(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) df_raw = constructor(data) result = nw.from_native(df_raw).with_columns( a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 3999b11097..0b496d7ae2 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -47,8 +47,6 @@ @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning") def test_cast(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "pyarrow_table_constructor" in 
str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None: assert dict(result.collect_schema()) == expected -def test_cast_series(constructor_eager: Any, request: Any) -> None: - if "pyarrow_table_constructor" in str(constructor_eager) and parse_version( +def test_cast_series(constructor: Any, request: Any) -> None: + if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover request.applymarker(pytest.mark.xfail) - if "modin" in str(constructor_eager): + if "modin" in str(constructor): # TODO(unassigned): in modin, we end up with `'<U0'` dtype request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor_eager(data), eager_only=True).select( - nw.col(key).cast(value) for key, value in schema.items() + df = ( + nw.from_native(constructor(data)) + .select(nw.col(key).cast(value) for key, value in schema.items()) + .lazy() + .collect() ) + expected = { "a": nw.Int32, "b": nw.Int16, @@ -158,3 +160,22 @@ def test_cast_string() -> None: s = s.cast(nw.String) result = nw.to_native(s) assert str(result.dtype) in ("string", "object", "dtype('O')") + + +def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None: + if "pyarrow_table_constructor" in str(constructor) and parse_version( + pa.__version__ + ) <= (15,): # pragma: no cover + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)).select( + nw.col(key).cast(value) for key, value in schema.items() + ) + + class Banana: + pass + + with pytest.raises(AssertionError, match=r"Unknown dtype"): + df.select(nw.col("a").cast(Banana)) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 128cb73c03..ababee4a7a 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -7,7 +7,6 @@ import pytest import 
narwhals.stable.v1 as nw -from narwhals.dependencies import get_dask_dataframe from narwhals.selectors import all from narwhals.selectors import boolean from narwhals.selectors import by_dtype @@ -57,8 +56,6 @@ def test_string(constructor: Any, request: Any) -> None: def test_categorical(request: Any, constructor: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -70,17 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask") -def test_dask_categorical() -> None: - import dask.dataframe as dd - - expected = {"b": ["a", "b", "c"]} - df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"}) - df = nw.from_native(df_raw) - result = df.select(categorical()) - compare_dicts(result, expected) - - @pytest.mark.parametrize( ("selector", "expected"), [ diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 40b598de2a..2bb8d435b4 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None: def test_key_with_nulls(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) - if "modin" in str(constructor): # TODO(unassigned): Modin flaky here? request.applymarker(pytest.mark.skip)