Skip to content

Commit

Permalink
feat: dask expr cast (#821)
Browse files Browse the repository at this point in the history
* feat: dask expr cast

* replace temporary hack with `.cast`

* test cast raises for unknown dtype

* remove unused filterwarnings

* use walrus for simplicity

Co-authored-by: Francesco Bruzzesi <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove duplicated test

* add missing function parameter types

* skip coverage for some paths

* skip coevrage for unknown dtype exception

Co-authored-by: Francesco Bruzzesi <[email protected]>

* remove redundant modin xfail

* empty commit to trigger CI

---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: FBruzzesi <[email protected]>
  • Loading branch information
4 people authored Aug 25, 2024
1 parent 1b7bd73 commit 706b7c8
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 29 deletions.
17 changes: 17 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals.dependencies import get_dask
from narwhals.utils import generate_unique_token

Expand All @@ -17,6 +18,7 @@

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals.dtypes import DType


class DaskExpr:
Expand Down Expand Up @@ -654,6 +656,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)

def cast(
self: Self,
dtype: DType | type[DType],
) -> Self:
def func(_input: Any, dtype: DType | type[DType]) -> Any:
dtype = reverse_translate_dtype(dtype)
return _input.astype(dtype)

return self._from_call(
func,
"cast",
dtype,
returns_scalar=False,
)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
49 changes: 49 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from typing import Any

from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

if TYPE_CHECKING:
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals.dtypes import DType


def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
Expand Down Expand Up @@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs(
def add_row_index(frame: Any, name: str) -> Any:
frame = frame.assign(**{name: 1})
return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1})


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
from narwhals import dtypes

if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
return "float32"
if isinstance_or_issubclass(dtype, dtypes.Int64):
return "int64"
if isinstance_or_issubclass(dtype, dtypes.Int32):
return "int32"
if isinstance_or_issubclass(dtype, dtypes.Int16):
return "int16"
if isinstance_or_issubclass(dtype, dtypes.Int8):
return "int8"
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return "uint64"
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return "uint32"
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return "uint16"
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return "uint8"
if isinstance_or_issubclass(dtype, dtypes.String):
if (pd := get_pandas()) is not None and parse_version(
pd.__version__
) >= parse_version("2.0.0"):
if get_pyarrow() is not None:
return "string[pyarrow]"
return "string[python]" # pragma: no cover
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return "bool"
if isinstance_or_issubclass(dtype, dtypes.Categorical):
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return "datetime64[us]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
return "timedelta64[ns]"

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
6 changes: 1 addition & 5 deletions tests/expr_and_series/binary_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts


def test_expr_binary(constructor: Any, request: Any) -> None:
def test_expr_binary(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df_raw = constructor(data)
result = nw.from_native(df_raw).with_columns(
a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),
Expand Down
35 changes: 28 additions & 7 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
def test_cast(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand Down Expand Up @@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None:
assert dict(result.collect_schema()) == expected


def test_cast_series(constructor_eager: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor_eager) and parse_version(
def test_cast_series(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
if "modin" in str(constructor_eager):
if "modin" in str(constructor):
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager(data), eager_only=True).select(
nw.col(key).cast(value) for key, value in schema.items()
df = (
nw.from_native(constructor(data))
.select(nw.col(key).cast(value) for key, value in schema.items())
.lazy()
.collect()
)

expected = {
"a": nw.Int32,
"b": nw.Int16,
Expand Down Expand Up @@ -158,3 +160,22 @@ def test_cast_string() -> None:
s = s.cast(nw.String)
result = nw.to_native(s)
assert str(result.dtype) in ("string", "object", "dtype('O')")


def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data)).select(
nw.col(key).cast(value) for key, value in schema.items()
)

class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))
14 changes: 0 additions & 14 deletions tests/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.dependencies import get_dask_dataframe
from narwhals.selectors import all
from narwhals.selectors import boolean
from narwhals.selectors import by_dtype
Expand Down Expand Up @@ -57,8 +56,6 @@ def test_string(constructor: Any, request: Any) -> None:


def test_categorical(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand All @@ -70,17 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask")
def test_dask_categorical() -> None:
import dask.dataframe as dd

expected = {"b": ["a", "b", "c"]}
df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"})
df = nw.from_native(df_raw)
result = df.select(categorical())
compare_dicts(result, expected)


@pytest.mark.parametrize(
("selector", "expected"),
[
Expand Down
3 changes: 0 additions & 3 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None:


def test_key_with_nulls(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "modin" in str(constructor):
# TODO(unassigned): Modin flaky here?
request.applymarker(pytest.mark.skip)
Expand Down

0 comments on commit 706b7c8

Please sign in to comment.