From 97a7c410ef045024ca469486e569bf109bb2856d Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Tue, 20 Aug 2024 16:36:27 +0500 Subject: [PATCH 01/12] feat: dask expr cast --- narwhals/_dask/expr.py | 16 +++++++++ narwhals/_dask/utils.py | 49 ++++++++++++++++++++++++++++ tests/expr_and_series/binary_test.py | 6 +--- tests/expr_and_series/cast_test.py | 16 +++++---- tests/selectors_test.py | 2 -- tests/test_group_by.py | 3 -- 6 files changed, 75 insertions(+), 17 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index faedb6095..0d6465c38 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -8,6 +8,7 @@ from narwhals._dask.utils import add_row_index from narwhals._dask.utils import maybe_evaluate +from narwhals._dask.utils import reverse_translate_dtype from narwhals.dependencies import get_dask from narwhals.utils import generate_unique_token @@ -610,6 +611,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace: def name(self: Self) -> DaskExprNameNamespace: return DaskExprNameNamespace(self) + def cast( + self, + dtype: Any, + ) -> Self: + def func(_input: Any, dtype: Any) -> Any: + dtype = reverse_translate_dtype(dtype) + return _input.astype(dtype) + + return self._from_call( + func, + "cast", + dtype, + returns_scalar=False, + ) + class DaskExprStringNamespace: def __init__(self, expr: DaskExpr) -> None: diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 27f0d7643..c271c1038 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -4,9 +4,14 @@ from typing import Any from narwhals.dependencies import get_dask_expr +from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_pyarrow +from narwhals.utils import isinstance_or_issubclass +from narwhals.utils import parse_version if TYPE_CHECKING: from narwhals._dask.dataframe import DaskLazyFrame + from narwhals.dtypes import DType def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: @@ -73,3 +78,47 @@ def 
parse_exprs_and_named_exprs( def add_row_index(frame: Any, name: str) -> Any: frame = frame.assign(**{name: 1}) return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1}) + + +def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: + from narwhals import dtypes + + if isinstance_or_issubclass(dtype, dtypes.Float64): + return "float64" + if isinstance_or_issubclass(dtype, dtypes.Float32): + return "float32" + if isinstance_or_issubclass(dtype, dtypes.Int64): + return "int64" + if isinstance_or_issubclass(dtype, dtypes.Int32): + return "int32" + if isinstance_or_issubclass(dtype, dtypes.Int16): + return "int16" + if isinstance_or_issubclass(dtype, dtypes.Int8): + return "int8" + if isinstance_or_issubclass(dtype, dtypes.UInt64): + return "uint64" + if isinstance_or_issubclass(dtype, dtypes.UInt32): + return "uint32" + if isinstance_or_issubclass(dtype, dtypes.UInt16): + return "uint16" + if isinstance_or_issubclass(dtype, dtypes.UInt8): + return "uint8" + if isinstance_or_issubclass(dtype, dtypes.String): + pd = get_pandas() + + if pd is not None and parse_version(pd.__version__) >= parse_version("2.0.0"): + if get_pyarrow() is not None: + return "string[pyarrow]" + return "string[python]" + return "object" + if isinstance_or_issubclass(dtype, dtypes.Boolean): + return "bool" + if isinstance_or_issubclass(dtype, dtypes.Categorical): + return "category" + if isinstance_or_issubclass(dtype, dtypes.Datetime): + return "datetime64[us]" + if isinstance_or_issubclass(dtype, dtypes.Duration): + return "timedelta64[ns]" + + msg = f"Unknown dtype: {dtype}" + raise AssertionError(msg) diff --git a/tests/expr_and_series/binary_test.py b/tests/expr_and_series/binary_test.py index 9d4e6cf6c..2d55af228 100644 --- a/tests/expr_and_series/binary_test.py +++ b/tests/expr_and_series/binary_test.py @@ -1,15 +1,11 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import compare_dicts -def test_expr_binary(constructor: Any, 
request: Any) -> None: +def test_expr_binary(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) df_raw = constructor(data) result = nw.from_native(df_raw).with_columns( a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 3999b1109..fde2de99a 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -47,8 +47,6 @@ @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning") def test_cast(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None: assert dict(result.collect_schema()) == expected -def test_cast_series(constructor_eager: Any, request: Any) -> None: - if "pyarrow_table_constructor" in str(constructor_eager) and parse_version( +def test_cast_series(constructor: Any, request: Any) -> None: + if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover request.applymarker(pytest.mark.xfail) - if "modin" in str(constructor_eager): + if "modin" in str(constructor): # TODO(unassigned): in modin, we end up with `' None: def test_categorical(request: Any, constructor: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 40b598de2..2bb8d435b 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None: def 
test_key_with_nulls(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) - if "modin" in str(constructor): # TODO(unassigned): Modin flaky here? request.applymarker(pytest.mark.skip) From 2054694f09d70bbffc59a3390f4feaabb465f76b Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Tue, 20 Aug 2024 16:45:28 +0500 Subject: [PATCH 02/12] replace temporary hack with `.cast` --- tests/selectors_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 025b931d1..221318cff 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -69,12 +69,9 @@ def test_categorical(request: Any, constructor: Any) -> None: @pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask") -def test_dask_categorical() -> None: - import dask.dataframe as dd - +def test_dask_categorical(constructor: Any) -> None: expected = {"b": ["a", "b", "c"]} - df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"}) - df = nw.from_native(df_raw) + df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical)) result = df.select(categorical()) compare_dicts(result, expected) From 91ee57f60c59495696e3af04d822cd8eaa2a1c86 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Tue, 20 Aug 2024 17:06:59 +0500 Subject: [PATCH 03/12] test cast raises for unknown dtype --- tests/expr_and_series/cast_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index fde2de99a..b34f6b071 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -160,3 +160,26 @@ def test_cast_string() -> None: s = s.cast(nw.String) result = nw.to_native(s) assert str(result.dtype) in ("string", "object", "dtype('O')") + + +@pytest.mark.filterwarnings("ignore:casting period[M] values to 
int64:FutureWarning") +def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None: + if "pyarrow_table_constructor" in str(constructor) and parse_version( + pa.__version__ + ) <= (15,): # pragma: no cover + request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # TODO(unassigned): in modin, we end up with `' Date: Tue, 20 Aug 2024 17:08:23 +0500 Subject: [PATCH 04/12] remove unused filterwarnings --- tests/expr_and_series/cast_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index b34f6b071..d870a371c 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -162,7 +162,6 @@ def test_cast_string() -> None: assert str(result.dtype) in ("string", "object", "dtype('O')") -@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning") def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ From 90d43ce8a16bc644d294f998d9eb7c31b0082860 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov <65722512+aidoskanapyanov@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:51:21 +0500 Subject: [PATCH 05/12] use walrus for simplicity Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- narwhals/_dask/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index c271c1038..499d5a194 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -104,9 +104,7 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: if isinstance_or_issubclass(dtype, dtypes.UInt8): return "uint8" if isinstance_or_issubclass(dtype, dtypes.String): - pd = get_pandas() - - if pd is not None and parse_version(pd.__version__) >= parse_version("2.0.0"): + if (pd := get_pandas()) is not None and 
parse_version(pd.__version__) >= parse_version("2.0.0"): if get_pyarrow() is not None: return "string[pyarrow]" return "string[python]" From ededb95ae130e4daad2ae9e6ce4558ba87f6c7c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 17:51:44 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_dask/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 499d5a194..281c4bcd8 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -104,7 +104,9 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: if isinstance_or_issubclass(dtype, dtypes.UInt8): return "uint8" if isinstance_or_issubclass(dtype, dtypes.String): - if (pd := get_pandas()) is not None and parse_version(pd.__version__) >= parse_version("2.0.0"): + if (pd := get_pandas()) is not None and parse_version( + pd.__version__ + ) >= parse_version("2.0.0"): if get_pyarrow() is not None: return "string[pyarrow]" return "string[python]" From d7e053d799c8c0a9be0e01d933b944981a9776cb Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Wed, 21 Aug 2024 22:53:13 +0500 Subject: [PATCH 07/12] remove duplicated test --- tests/selectors_test.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 221318cff..ababee4a7 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -7,7 +7,6 @@ import pytest import narwhals.stable.v1 as nw -from narwhals.dependencies import get_dask_dataframe from narwhals.selectors import all from narwhals.selectors import boolean from narwhals.selectors import by_dtype @@ -68,14 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for 
dask") -def test_dask_categorical(constructor: Any) -> None: - expected = {"b": ["a", "b", "c"]} - df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical)) - result = df.select(categorical()) - compare_dicts(result, expected) - - @pytest.mark.parametrize( ("selector", "expected"), [ From 994e25268cf77399aee1994d97c374131e553bb0 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Wed, 21 Aug 2024 22:58:07 +0500 Subject: [PATCH 08/12] add missing function parameter types --- narwhals/_dask/expr.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 0d6465c38..9f6129bab 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -17,6 +17,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace + from narwhals.dtypes import DType class DaskExpr: @@ -612,10 +613,10 @@ def name(self: Self) -> DaskExprNameNamespace: return DaskExprNameNamespace(self) def cast( - self, - dtype: Any, + self: Self, + dtype: DType | type[DType], ) -> Self: - def func(_input: Any, dtype: Any) -> Any: + def func(_input: Any, dtype: DType | type[DType]) -> Any: dtype = reverse_translate_dtype(dtype) return _input.astype(dtype) From c9912dde7f55bd68b518aae9e1313a460228e9c5 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Wed, 21 Aug 2024 23:14:22 +0500 Subject: [PATCH 09/12] skip coverage for some paths --- narwhals/_dask/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 281c4bcd8..0aa3030e2 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -109,8 +109,8 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: ) >= parse_version("2.0.0"): if get_pyarrow() is not None: return "string[pyarrow]" - return "string[python]" - return "object" + return "string[python]" # pragma: no cover + return "object" # pragma: no cover 
if isinstance_or_issubclass(dtype, dtypes.Boolean): return "bool" if isinstance_or_issubclass(dtype, dtypes.Categorical): From 0556b0d897687db8b15b1a495b90c99f68f45802 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov <65722512+aidoskanapyanov@users.noreply.github.com> Date: Thu, 22 Aug 2024 12:58:48 +0500 Subject: [PATCH 10/12] skip coverage for unknown dtype exception Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- narwhals/_dask/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 0aa3030e2..d3f02a07c 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -120,5 +120,5 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: if isinstance_or_issubclass(dtype, dtypes.Duration): return "timedelta64[ns]" - msg = f"Unknown dtype: {dtype}" + msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) From 0af7233466a24b5d8685d0d0f514480ff5d847e0 Mon Sep 17 00:00:00 2001 From: Aidos Kanapyanov Date: Thu, 22 Aug 2024 23:17:25 +0500 Subject: [PATCH 11/12] remove redundant modin xfail --- tests/expr_and_series/cast_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index d870a371c..0b496d7ae 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -167,9 +167,6 @@ def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None: pa.__version__ ) <= (15,): # pragma: no cover request.applymarker(pytest.mark.xfail) - if "modin" in str(constructor): - # TODO(unassigned): in modin, we end up with `' Date: Sun, 25 Aug 2024 19:33:44 +0200 Subject: [PATCH 12/12] empty commit to trigger CI