Skip to content

Commit

Permalink
Fall back when casting a timestamp to numeric in cudf-polars (#16232)
Browse files Browse the repository at this point in the history
This PR adds logic that falls back to CPU when a cudf-polars query would cast a timestamp column to a numeric type, an unsupported operation in libcudf, which should fix a few polars tests. It could be cleaned up a bit with some of the utilities that will be added in #16150.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #16232
  • Loading branch information
brandon-b-miller authored Jul 23, 2024
1 parent c7b28ce commit e6d412c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,10 @@ class Cast(Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
    """
    Build a cast of ``value`` to ``dtype``.

    Parameters
    ----------
    dtype
        Target data type of the cast; forwarded to the parent constructor.
    value
        Expression whose result is cast; stored as the only child.

    Raises
    ------
    NotImplementedError
        If ``plc.unary.is_supported_cast`` reports that casting from
        ``value.dtype`` to the target dtype is not supported.
    """
    super().__init__(dtype)
    self.children = (value,)
    # is_supported_cast takes (from, to): the source type is the child's
    # dtype and the target type is this node's dtype, so pass them in that
    # order and report them in that order in the error message.
    # NOTE(review): the check is not symmetric for every type pair, so
    # confirm the set of rejected casts still matches expectations.
    if not plc.unary.is_supported_cast(value.dtype, self.dtype):
        raise NotImplementedError(
            f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
        )

def do_evaluate(
self,
Expand Down
52 changes: 52 additions & 0 deletions python/cudf_polars/tests/expressions/test_casting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
)

# (source dtype, target dtype) pairs used to parametrize test_cast_supported.
_supported_dtypes = [
    (pl.Int8(), pl.Int64()),
]

# (source dtype, target dtype) pairs used to parametrize test_cast_unsupported,
# which expects NotImplementedError for each of them.
_unsupported_dtypes = [(pl.String(), pl.Int64())]


@pytest.fixture
def dtypes(request):
    """Return the value supplied through indirect parametrization."""
    dtype_pair = request.param
    return dtype_pair


@pytest.fixture
def tests(dtypes):
    """
    Build the input for a cast test from a ``(fromtype, totype)`` pair.

    Returns a tuple of a lazy frame with a single column ``"a"`` of the
    source dtype, and the target dtype to cast that column to.
    """
    source_type, target_type = dtypes
    # String columns need string data; any other source dtype gets integers.
    values = ["a", "b", "c"] if source_type == pl.String() else [1, 2, 3]
    column = pl.Series(values, dtype=source_type)
    return pl.DataFrame({"a": column}).lazy(), target_type


@pytest.mark.parametrize("dtypes", _supported_dtypes, indirect=True)
def test_cast_supported(tests):
    """Cast column ``"a"`` to the target dtype and hand the query to assert_gpu_result_equal."""
    frame, target = tests
    cast_expr = pl.col("a").cast(target)
    assert_gpu_result_equal(frame.select(cast_expr))


@pytest.mark.parametrize("dtypes", _unsupported_dtypes, indirect=True)
def test_cast_unsupported(tests):
    """Check that translating a cast of column ``"a"`` to the target dtype raises NotImplementedError."""
    frame, target = tests
    query = frame.select(pl.col("a").cast(target))
    assert_ir_translation_raises(query, NotImplementedError)

0 comments on commit e6d412c

Please sign in to comment.