From e6d412cba7c23df7ee500c28257ed9281cea49b9 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Tue, 23 Jul 2024 06:03:28 -0500 Subject: [PATCH] Fall back when casting a timestamp to numeric in cudf-polars (#16232) This PR adds logic that falls back to CPU when a cudf-polars query would cast a timestamp column to a numeric type, an unsupported operation in libcudf, which should fix a few polars tests. It could be cleaned up a bit with some of the utilities that will be added in https://github.com/rapidsai/cudf/pull/16150. Authors: - https://github.com/brandon-b-miller Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/16232 --- python/cudf_polars/cudf_polars/dsl/expr.py | 4 ++ .../tests/expressions/test_casting.py | 52 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 python/cudf_polars/tests/expressions/test_casting.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 8322d6bd6fb..9835e6f8461 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1188,6 +1188,10 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) + if not plc.unary.is_supported_cast(self.dtype, value.dtype): + raise NotImplementedError( + f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" + ) def do_evaluate( self, diff --git a/python/cudf_polars/tests/expressions/test_casting.py b/python/cudf_polars/tests/expressions/test_casting.py new file mode 100644 index 00000000000..3e003054338 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_casting.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) + +_supported_dtypes = [(pl.Int8(), pl.Int64())] + +_unsupported_dtypes = [ + (pl.String(), pl.Int64()), +] + + +@pytest.fixture +def dtypes(request): + return request.param + + +@pytest.fixture +def tests(dtypes): + fromtype, totype = dtypes + if fromtype == pl.String(): + data = ["a", "b", "c"] + else: + data = [1, 2, 3] + return pl.DataFrame( + { + "a": pl.Series(data, dtype=fromtype), + } + ).lazy(), totype + + +@pytest.mark.parametrize("dtypes", _supported_dtypes, indirect=True) +def test_cast_supported(tests): + df, totype = tests + q = df.select(pl.col("a").cast(totype)) + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("dtypes", _unsupported_dtypes, indirect=True) +def test_cast_unsupported(tests): + df, totype = tests + assert_ir_translation_raises( + df.select(pl.col("a").cast(totype)), NotImplementedError + )