Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Oct 9, 2024
1 parent ded4dd2 commit 054b271
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 177 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

# Check we have a supported polars version
from cudf_polars.utils.versions import _ensure_polars_version
Expand All @@ -22,7 +22,7 @@

__all__: list[str] = [
"execute_with_cudf",
"translate_ir",
"Translator",
"__git_commit__",
"__version__",
]
12 changes: 9 additions & 3 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import rmm
from rmm._cuda import gpu

from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -174,16 +174,22 @@ def execute_with_cudf(
device = config.device
memory_resource = config.memory_resource
raise_on_fail = config.config.get("raise_on_fail", False)
if unsupported := (config.config.keys() - {"raise_on_fail"}):
debug_mode = config.config.get("debug_mode", False)
if unsupported := (config.config.keys() - {"raise_on_fail", "debug_mode"}):
raise ValueError(
f"Engine configuration contains unsupported settings {unsupported}"
)
try:
with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
translator = Translator(nt, debug_mode=debug_mode)
ir = translator.translate_ir()
if debug_mode and len(translator.errors):
print(set(translator.errors))
raise NotImplementedError("Query contained unsupported operations")
nt.set_udf(
partial(
_callback,
translate_ir(nt),
ir,
device=device,
memory_resource=memory_resource,
)
Expand Down
7 changes: 7 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

__all__ = [
"Expr",
"ErrorExpr",
"NamedExpr",
"Literal",
"Col",
Expand Down Expand Up @@ -275,6 +276,12 @@ def collect_agg(self, *, depth: int) -> AggInfo:
) # pragma: no cover; check_agg trips first


class ErrorExpr(Expr):
def __init__(self, dtype: plc.DataType, error: str) -> None:
super().__init__(dtype)
self.error = error


class NamedExpr:
# NamedExpr does not inherit from Expr since it does not appear
# when evaluating expressions themselves, only when constructing
Expand Down
9 changes: 9 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

__all__ = [
"IR",
"ErrorNode",
"PythonScan",
"Scan",
"Cache",
Expand Down Expand Up @@ -159,6 +160,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
) # pragma: no cover


@dataclasses.dataclass
class ErrorNode(IR):
"""Represents an error translating the IR."""

error: str
"""The error."""


@dataclasses.dataclass
class PythonScan(IR):
"""Representation of input from a python function."""
Expand Down
Loading

0 comments on commit 054b271

Please sign in to comment.