From 706b7c8ac6ff19ca9929956c3651488c0023acbd Mon Sep 17 00:00:00 2001
From: Aidos Kanapyanov <65722512+aidoskanapyanov@users.noreply.github.com>
Date: Mon, 26 Aug 2024 00:24:30 +0500
Subject: [PATCH] feat: dask expr `cast` (#821)

* feat: dask expr cast

* replace temporary hack with `.cast`

* test cast raises for unknown dtype

* remove unused filterwarnings

* use walrus for simplicity

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove duplicated test

* add missing function parameter types

* skip coverage for some paths

* skip coevrage for unknown dtype exception

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* remove redundant modin xfail

* empty commit to trigger CI

---------

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com>
---
 narwhals/_dask/expr.py               | 17 ++++++++++
 narwhals/_dask/utils.py              | 49 ++++++++++++++++++++++++++++
 tests/expr_and_series/binary_test.py |  6 +---
 tests/expr_and_series/cast_test.py   | 35 ++++++++++++++++----
 tests/selectors_test.py              | 14 --------
 tests/test_group_by.py               |  3 --
 6 files changed, 95 insertions(+), 29 deletions(-)

diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py
index 31bc4852f8..62aaa85e67 100644
--- a/narwhals/_dask/expr.py
+++ b/narwhals/_dask/expr.py
@@ -9,6 +9,7 @@
 
 from narwhals._dask.utils import add_row_index
 from narwhals._dask.utils import maybe_evaluate
+from narwhals._dask.utils import reverse_translate_dtype
 from narwhals.dependencies import get_dask
 from narwhals.utils import generate_unique_token
 
@@ -17,6 +18,7 @@
 
     from narwhals._dask.dataframe import DaskLazyFrame
     from narwhals._dask.namespace import DaskNamespace
+    from narwhals.dtypes import DType
 
 
 class DaskExpr:
@@ -654,6 +656,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
     def name(self: Self) -> DaskExprNameNamespace:
         return DaskExprNameNamespace(self)
 
+    def cast(
+        self: Self,
+        dtype: DType | type[DType],
+    ) -> Self:
+        def func(_input: Any, dtype: DType | type[DType]) -> Any:
+            dtype = reverse_translate_dtype(dtype)
+            return _input.astype(dtype)
+
+        return self._from_call(
+            func,
+            "cast",
+            dtype,
+            returns_scalar=False,
+        )
+
 
 class DaskExprStringNamespace:
     def __init__(self, expr: DaskExpr) -> None:
diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py
index 27f0d7643e..d3f02a07c4 100644
--- a/narwhals/_dask/utils.py
+++ b/narwhals/_dask/utils.py
@@ -4,9 +4,14 @@
 from typing import Any
 
 from narwhals.dependencies import get_dask_expr
+from narwhals.dependencies import get_pandas
+from narwhals.dependencies import get_pyarrow
+from narwhals.utils import isinstance_or_issubclass
+from narwhals.utils import parse_version
 
 if TYPE_CHECKING:
     from narwhals._dask.dataframe import DaskLazyFrame
+    from narwhals.dtypes import DType
 
 
 def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
@@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs(
 def add_row_index(frame: Any, name: str) -> Any:
     frame = frame.assign(**{name: 1})
     return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1})
+
+
+def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
+    from narwhals import dtypes
+
+    if isinstance_or_issubclass(dtype, dtypes.Float64):
+        return "float64"
+    if isinstance_or_issubclass(dtype, dtypes.Float32):
+        return "float32"
+    if isinstance_or_issubclass(dtype, dtypes.Int64):
+        return "int64"
+    if isinstance_or_issubclass(dtype, dtypes.Int32):
+        return "int32"
+    if isinstance_or_issubclass(dtype, dtypes.Int16):
+        return "int16"
+    if isinstance_or_issubclass(dtype, dtypes.Int8):
+        return "int8"
+    if isinstance_or_issubclass(dtype, dtypes.UInt64):
+        return "uint64"
+    if isinstance_or_issubclass(dtype, dtypes.UInt32):
+        return "uint32"
+    if isinstance_or_issubclass(dtype, dtypes.UInt16):
+        return "uint16"
+    if isinstance_or_issubclass(dtype, dtypes.UInt8):
+        return "uint8"
+    if isinstance_or_issubclass(dtype, dtypes.String):
+        if (pd := get_pandas()) is not None and parse_version(
+            pd.__version__
+        ) >= parse_version("2.0.0"):
+            if get_pyarrow() is not None:
+                return "string[pyarrow]"
+            return "string[python]"  # pragma: no cover
+        return "object"  # pragma: no cover
+    if isinstance_or_issubclass(dtype, dtypes.Boolean):
+        return "bool"
+    if isinstance_or_issubclass(dtype, dtypes.Categorical):
+        return "category"
+    if isinstance_or_issubclass(dtype, dtypes.Datetime):
+        return "datetime64[us]"
+    if isinstance_or_issubclass(dtype, dtypes.Duration):
+        return "timedelta64[ns]"
+
+    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
+    raise AssertionError(msg)
diff --git a/tests/expr_and_series/binary_test.py b/tests/expr_and_series/binary_test.py
index 9d4e6cf6c0..2d55af2285 100644
--- a/tests/expr_and_series/binary_test.py
+++ b/tests/expr_and_series/binary_test.py
@@ -1,15 +1,11 @@
 from typing import Any
 
-import pytest
-
 import narwhals.stable.v1 as nw
 from tests.utils import compare_dicts
 
 
-def test_expr_binary(constructor: Any, request: Any) -> None:
+def test_expr_binary(constructor: Any) -> None:
     data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
-    if "dask" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
     df_raw = constructor(data)
     result = nw.from_native(df_raw).with_columns(
         a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),
diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py
index 3999b11097..0b496d7ae2 100644
--- a/tests/expr_and_series/cast_test.py
+++ b/tests/expr_and_series/cast_test.py
@@ -47,8 +47,6 @@
 
 @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
 def test_cast(constructor: Any, request: Any) -> None:
-    if "dask" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
     if "pyarrow_table_constructor" in str(constructor) and parse_version(
         pa.__version__
     ) <= (15,):  # pragma: no cover
@@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None:
     assert dict(result.collect_schema()) == expected
 
 
-def test_cast_series(constructor_eager: Any, request: Any) -> None:
-    if "pyarrow_table_constructor" in str(constructor_eager) and parse_version(
+def test_cast_series(constructor: Any, request: Any) -> None:
+    if "pyarrow_table_constructor" in str(constructor) and parse_version(
         pa.__version__
     ) <= (15,):  # pragma: no cover
         request.applymarker(pytest.mark.xfail)
-    if "modin" in str(constructor_eager):
+    if "modin" in str(constructor):
         # TODO(unassigned): in modin, we end up with `'<U0'` dtype
         request.applymarker(pytest.mark.xfail)
-    df = nw.from_native(constructor_eager(data), eager_only=True).select(
-        nw.col(key).cast(value) for key, value in schema.items()
+    df = (
+        nw.from_native(constructor(data))
+        .select(nw.col(key).cast(value) for key, value in schema.items())
+        .lazy()
+        .collect()
     )
+
     expected = {
         "a": nw.Int32,
         "b": nw.Int16,
@@ -158,3 +160,22 @@ def test_cast_string() -> None:
     s = s.cast(nw.String)
     result = nw.to_native(s)
     assert str(result.dtype) in ("string", "object", "dtype('O')")
+
+
+def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None:
+    if "pyarrow_table_constructor" in str(constructor) and parse_version(
+        pa.__version__
+    ) <= (15,):  # pragma: no cover
+        request.applymarker(pytest.mark.xfail)
+    if "polars" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(data)).select(
+        nw.col(key).cast(value) for key, value in schema.items()
+    )
+
+    class Banana:
+        pass
+
+    with pytest.raises(AssertionError, match=r"Unknown dtype"):
+        df.select(nw.col("a").cast(Banana))
diff --git a/tests/selectors_test.py b/tests/selectors_test.py
index 128cb73c03..ababee4a7a 100644
--- a/tests/selectors_test.py
+++ b/tests/selectors_test.py
@@ -7,7 +7,6 @@
 import pytest
 
 import narwhals.stable.v1 as nw
-from narwhals.dependencies import get_dask_dataframe
 from narwhals.selectors import all
 from narwhals.selectors import boolean
 from narwhals.selectors import by_dtype
@@ -57,8 +56,6 @@ def test_string(constructor: Any, request: Any) -> None:
 
 
 def test_categorical(request: Any, constructor: Any) -> None:
-    if "dask" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
     if "pyarrow_table_constructor" in str(constructor) and parse_version(
         pa.__version__
     ) <= (15,):  # pragma: no cover
@@ -70,17 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None:
     compare_dicts(result, expected)
 
 
-@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask")
-def test_dask_categorical() -> None:
-    import dask.dataframe as dd
-
-    expected = {"b": ["a", "b", "c"]}
-    df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"})
-    df = nw.from_native(df_raw)
-    result = df.select(categorical())
-    compare_dicts(result, expected)
-
-
 @pytest.mark.parametrize(
     ("selector", "expected"),
     [
diff --git a/tests/test_group_by.py b/tests/test_group_by.py
index 40b598de2a..2bb8d435b4 100644
--- a/tests/test_group_by.py
+++ b/tests/test_group_by.py
@@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None:
 
 
 def test_key_with_nulls(constructor: Any, request: Any) -> None:
-    if "dask" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     if "modin" in str(constructor):
         # TODO(unassigned): Modin flaky here?
         request.applymarker(pytest.mark.skip)