Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dask expr cast #821

Merged
merged 13 commits into from
Aug 25, 2024
16 changes: 16 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals.dependencies import get_dask
from narwhals.utils import generate_unique_token

Expand Down Expand Up @@ -610,6 +611,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)

def cast(
    self,
    dtype: Any,
) -> Self:
    """Build an expression that converts its input to ``dtype``.

    Args:
        dtype: Target dtype. It is handed to ``reverse_translate_dtype``
            to obtain the value that is passed on to ``astype``.

    Returns:
        Whatever ``self._from_call`` produces for the registered callback
        (annotated as ``Self``).
    """

    def _astype(_input: Any, dtype: Any) -> Any:
        # Translation happens inside the callback, so an unsupported dtype
        # is only reported once the callback is actually invoked.
        native_dtype = reverse_translate_dtype(dtype)
        return _input.astype(native_dtype)

    return self._from_call(_astype, "cast", dtype, returns_scalar=False)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
49 changes: 49 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from typing import Any

from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

if TYPE_CHECKING:
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals.dtypes import DType


def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
Expand Down Expand Up @@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs(
def add_row_index(frame: Any, name: str) -> Any:
    """Return ``frame`` with a column ``name`` holding a running count from 0.

    The column is first filled with the constant ``1``; its cumulative sum
    minus one then replaces it.

    Args:
        frame: Object exposing ``assign`` and column access via ``frame[name]``
            (presumably a Dask DataFrame -- confirm against callers).
        name: Name of the column to create.

    Returns:
        The result of the final ``assign`` call.
    """
    seeded = frame.assign(**{name: 1})
    # `method="blelloch"` is forwarded to `cumsum` exactly as before.
    zero_based = seeded[name].cumsum(method="blelloch") - 1
    return seeded.assign(**{name: zero_based})


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
    """Translate a narwhals dtype into the dtype string used for casting.

    Args:
        dtype: A narwhals dtype, given either as an instance or as the class.

    Returns:
        A dtype name such as ``"float64"`` or ``"category"``. For
        ``String`` the result depends on the environment: ``"object"`` when
        pandas is unavailable or older than 2.0.0, otherwise
        ``"string[pyarrow]"`` if pyarrow is available and ``"string[python]"``
        if it is not.

    Raises:
        AssertionError: If ``dtype`` matches none of the supported dtypes.
    """
    from narwhals import dtypes

    # Checked in this order, before the String special case.
    numeric_names = (
        (dtypes.Float64, "float64"),
        (dtypes.Float32, "float32"),
        (dtypes.Int64, "int64"),
        (dtypes.Int32, "int32"),
        (dtypes.Int16, "int16"),
        (dtypes.Int8, "int8"),
        (dtypes.UInt64, "uint64"),
        (dtypes.UInt32, "uint32"),
        (dtypes.UInt16, "uint16"),
        (dtypes.UInt8, "uint8"),
    )
    for candidate, native_name in numeric_names:
        if isinstance_or_issubclass(dtype, candidate):
            return native_name

    if isinstance_or_issubclass(dtype, dtypes.String):
        pd = get_pandas()
        has_pandas_2 = pd is not None and parse_version(
            pd.__version__
        ) >= parse_version("2.0.0")
        if not has_pandas_2:
            return "object"
        return "string[python]" if get_pyarrow() is None else "string[pyarrow]"

    # Checked in this order, after the String special case.
    other_names = (
        (dtypes.Boolean, "bool"),
        (dtypes.Categorical, "category"),
        (dtypes.Datetime, "datetime64[us]"),
        (dtypes.Duration, "timedelta64[ns]"),
    )
    for candidate, native_name in other_names:
        if isinstance_or_issubclass(dtype, candidate):
            return native_name

    msg = f"Unknown dtype: {dtype}"
    raise AssertionError(msg)
6 changes: 1 addition & 5 deletions tests/expr_and_series/binary_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts


def test_expr_binary(constructor: Any, request: Any) -> None:
def test_expr_binary(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df_raw = constructor(data)
result = nw.from_native(df_raw).with_columns(
a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),
Expand Down
38 changes: 31 additions & 7 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
def test_cast(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand Down Expand Up @@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None:
assert dict(result.collect_schema()) == expected


def test_cast_series(constructor_eager: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor_eager) and parse_version(
def test_cast_series(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
if "modin" in str(constructor_eager):
if "modin" in str(constructor):
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager(data), eager_only=True).select(
nw.col(key).cast(value) for key, value in schema.items()
df = (
nw.from_native(constructor(data))
.select(nw.col(key).cast(value) for key, value in schema.items())
.lazy()
.collect()
)

expected = {
"a": nw.Int32,
"b": nw.Int16,
Expand Down Expand Up @@ -158,3 +160,25 @@ def test_cast_string() -> None:
s = s.cast(nw.String)
result = nw.to_native(s)
assert str(result.dtype) in ("string", "object", "dtype('O')")


def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None:
    """Casting to a class that is not a known dtype raises ``AssertionError``."""
    backend = str(constructor)

    outdated_pyarrow = "pyarrow_table_constructor" in backend and parse_version(
        pa.__version__
    ) <= (15,)
    if outdated_pyarrow:  # pragma: no cover
        request.applymarker(pytest.mark.xfail)
    if "modin" in backend:
        # TODO(unassigned): in modin, we end up with `'<U0'` dtype
        request.applymarker(pytest.mark.xfail)
    if "polars" in backend:
        # NOTE(review): presumably polars reports the unknown dtype with a
        # different exception -- confirm.
        request.applymarker(pytest.mark.xfail)

    casts = (nw.col(key).cast(value) for key, value in schema.items())
    df = nw.from_native(constructor(data)).select(casts)

    class Banana:
        pass

    with pytest.raises(AssertionError, match=r"Unknown dtype"):
        df.select(nw.col("a").cast(Banana))
9 changes: 2 additions & 7 deletions tests/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def test_string(constructor: Any, request: Any) -> None:


def test_categorical(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand All @@ -71,12 +69,9 @@ def test_categorical(request: Any, constructor: Any) -> None:


@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask")
def test_dask_categorical() -> None:
import dask.dataframe as dd

def test_dask_categorical(constructor: Any) -> None:
aidoskanapyanov marked this conversation as resolved.
Show resolved Hide resolved
expected = {"b": ["a", "b", "c"]}
df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"})
df = nw.from_native(df_raw)
df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical))
result = df.select(categorical())
compare_dicts(result, expected)

Expand Down
3 changes: 0 additions & 3 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None:


def test_key_with_nulls(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "modin" in str(constructor):
# TODO(unassigned): Modin flaky here?
request.applymarker(pytest.mark.skip)
Expand Down
Loading