diff --git a/python/cudf_polars/cudf_polars/__init__.py b/python/cudf_polars/cudf_polars/__init__.py index 66c15f694ee..ba4858c5619 100644 --- a/python/cudf_polars/cudf_polars/__init__.py +++ b/python/cudf_polars/cudf_polars/__init__.py @@ -12,7 +12,7 @@ from cudf_polars._version import __git_commit__, __version__ from cudf_polars.callback import execute_with_cudf -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator # Check we have a supported polars version from cudf_polars.utils.versions import _ensure_polars_version @@ -22,7 +22,7 @@ __all__: list[str] = [ "execute_with_cudf", - "translate_ir", + "Translator", "__git_commit__", "__version__", ] diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 76816ee0a61..ff4933c7564 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -18,7 +18,7 @@ import rmm from rmm._cuda import gpu -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator if TYPE_CHECKING: from collections.abc import Generator @@ -180,14 +180,30 @@ def execute_with_cudf( ) try: with nvtx.annotate(message="ConvertIR", domain="cudf_polars"): - nt.set_udf( - partial( - _callback, - translate_ir(nt), - device=device, - memory_resource=memory_resource, + translator = Translator(nt) + ir = translator.translate_ir() + ir_translation_errors = translator.errors + if len(ir_translation_errors): + # TODO: Display these errors in user-friendly way. + # tracked in https://github.com/rapidsai/cudf/issues/17051 + unique_errors = sorted(set(ir_translation_errors), key=str) + error_message = "Query contained unsupported operations" + verbose_error_message = ( + f"{error_message}\nThe errors were:\n{unique_errors}" + ) + unsupported_ops_exception = NotImplementedError( + error_message, unique_errors + ) + if bool(int(os.environ.get("POLARS_VERBOSE", 0))): + warnings.warn(verbose_error_message, UserWarning, stacklevel=2) + if raise_on_fail: + raise unsupported_ops_exception + else: + nt.set_udf( + partial( + _callback, ir, device=device, memory_resource=memory_resource + ) ) - ) except exception as e: if bool(int(os.environ.get("POLARS_VERBOSE", 0))): warnings.warn( diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index e748ec16f14..9d7c13bfde7 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -19,6 +19,7 @@ from cudf_polars.dsl.expressions.base import ( AggInfo, Col, + ErrorExpr, Expr, NamedExpr, ) @@ -35,6 +36,7 @@ __all__ = [ "Expr", + "ErrorExpr", "NamedExpr", "Literal", "LiteralColumn", diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/base.py b/python/cudf_polars/cudf_polars/dsl/expressions/base.py index effe8cb2378..eafffd52f25 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/base.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/base.py @@ -155,6 +155,17 @@ def collect_agg(self, *, depth: int) -> AggInfo: ) # pragma: no cover; check_agg trips first +class ErrorExpr(Expr): + __slots__ = ("error",) + _non_child = ("dtype", "error") + error: str + + def __init__(self, dtype: plc.DataType, error: str) -> None: + self.dtype = dtype + self.error = error + self.children = () + + class NamedExpr: # NamedExpr does not inherit from Expr since it does not appear # when evaluating expressions themselves, only when constructing diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index a242ff9300f..ee09488dea6 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -41,6 +41,7 @@ __all__ = [ "IR", + "ErrorNode", "PythonScan", "Scan", "Cache", @@ -210,6 +211,22 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: ) +class ErrorNode(IR): + """Represents an error translating the IR.""" + + __slots__ = ("error",) + _non_child = ( + "schema", + "error", + ) + error: str + """The error.""" + + def __init__(self, schema: Schema, error: str): + self.schema = schema + self.error = error + + class PythonScan(IR): """Representation of input from a python function.""" diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 5181214819e..93f9d560e57 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -28,7 +28,123 @@ if TYPE_CHECKING: from cudf_polars.typing import ExprTransformer -__all__ = ["translate_ir", "translate_named_expr"] +if TYPE_CHECKING: + from cudf_polars.typing import NodeTraverser + +__all__ = ["Translator", "translate_named_expr"] + + +class Translator: + """ + Translates polars-internal IR nodes and expressions to our representation. + + Parameters + ---------- + visitor + Polars NodeTraverser object + """ + + def __init__(self, visitor: NodeTraverser): + self.visitor = visitor + self.errors: list[Exception] = [] + + def translate_ir(self, *, n: int | None = None) -> ir.IR: + """ + Translate a polars-internal IR node to our representation. + + Parameters + ---------- + visitor + Polars NodeTraverser object + n + Optional node to start traversing from, if not provided uses + current polars-internal node. + + Returns + ------- + Translated IR object + + Raises + ------ + NotImplementedError + If the version of Polars IR is unsupported. + + Notes + ----- + Any expression nodes that cannot be translated are replaced by + :class:`expr.ErrorNode` nodes and collected in the the `errors` attribute. + After translation is complete, this list of errors should be inspected + to determine if the query is supported. + """ + ctx: AbstractContextManager[None] = ( + set_node(self.visitor, n) if n is not None else noop_context + ) + # IR is versioned with major.minor, minor is bumped for backwards + # compatible changes (e.g. adding new nodes), major is bumped for + # incompatible changes (e.g. renaming nodes). + if (version := self.visitor.version()) >= (4, 0): + e = NotImplementedError( + f"No support for polars IR {version=}" + ) # pragma: no cover; no such version for now. + self.errors.append(e) # pragma: no cover + raise e # pragma: no cover + + with ctx: + polars_schema = self.visitor.get_schema() + try: + node = self.visitor.view_current_node() + except Exception as e: + self.errors.append(e) + return ir.ErrorNode({}, str(e)) + try: + schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} + except Exception as e: + self.errors.append(e) + return ir.ErrorNode({}, str(e)) + try: + result = _translate_ir(node, self, schema) + except Exception as e: + self.errors.append(e) + return ir.ErrorNode(schema, str(e)) + if any( + isinstance(dtype, pl.Null) + for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values()) + ): + error = NotImplementedError( + f"No GPU support for {result} with Null column dtype." + ) + self.errors.append(error) + return ir.ErrorNode(schema, str(error)) + + return result + + def translate_expr(self, *, n: int) -> expr.Expr: + """ + Translate a polars-internal expression IR into our representation. + + Parameters + ---------- + n + Node to translate, an integer referencing a polars internal node. + + Returns + ------- + Translated IR object. + + Notes + ----- + Any expression nodes that cannot be translated are replaced by + :class:`expr.ErrorExpr` nodes and collected in the the `errors` attribute. + After translation is complete, this list of errors should be inspected + to determine if the query is supported. + """ + node = self.visitor.view_expression(n) + dtype = dtypes.from_polars(self.visitor.get_dtype(n)) + try: + return _translate_expr(node, self, dtype) + except Exception as e: + self.errors.append(e) + return expr.ErrorExpr(dtype, str(e)) class set_node(AbstractContextManager[None]): @@ -70,7 +186,7 @@ def __exit__(self, *args: Any) -> None: @singledispatch def _translate_ir( - node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: Any, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: raise NotImplementedError( f"Translation for {type(node).__name__}" @@ -79,19 +195,19 @@ def _translate_ir( @_translate_ir.register def _( - node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.PythonScan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: scan_fn, with_columns, source_type, predicate, nrows = node.options options = (scan_fn, with_columns, source_type, nrows) predicate = ( - translate_named_expr(visitor, n=predicate) if predicate is not None else None + translate_named_expr(translator, n=predicate) if predicate is not None else None ) return ir.PythonScan(schema, options, predicate) @_translate_ir.register def _( - node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Scan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: typ, *options = node.scan_type if typ == "ndjson": @@ -120,7 +236,7 @@ def _( skip_rows, n_rows, row_index, - translate_named_expr(visitor, n=node.predicate) + translate_named_expr(translator, n=node.predicate) if node.predicate is not None else None, ) @@ -128,20 +244,20 @@ def _( @_translate_ir.register def _( - node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Cache, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input)) + return ir.Cache(schema, node.id_, translator.translate_ir(n=node.input)) @_translate_ir.register def _( - node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.DataFrameScan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.DataFrameScan( schema, node.df, node.projection, - translate_named_expr(visitor, n=node.selection) + translate_named_expr(translator, n=node.selection) if node.selection is not None else None, ) @@ -149,22 +265,22 @@ def _( @_translate_ir.register def _( - node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Select, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.expr] return ir.Select(schema, exprs, node.should_broadcast, inp) @_translate_ir.register def _( - node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.GroupBy, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - aggs = [translate_named_expr(visitor, n=e) for e in node.aggs] - keys = [translate_named_expr(visitor, n=e) for e in node.keys] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + aggs = [translate_named_expr(translator, n=e) for e in node.aggs] + keys = [translate_named_expr(translator, n=e) for e in node.keys] return ir.GroupBy( schema, keys, @@ -177,17 +293,17 @@ def _( @_translate_ir.register def _( - node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Join, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: # Join key dtypes are dependent on the schema of the left and # right inputs, so these must be translated with the relevant # input active. - with set_node(visitor, node.input_left): - inp_left = translate_ir(visitor, n=None) - left_on = [translate_named_expr(visitor, n=e) for e in node.left_on] - with set_node(visitor, node.input_right): - inp_right = translate_ir(visitor, n=None) - right_on = [translate_named_expr(visitor, n=e) for e in node.right_on] + with set_node(translator.visitor, node.input_left): + inp_left = translator.translate_ir(n=None) + left_on = [translate_named_expr(translator, n=e) for e in node.left_on] + with set_node(translator.visitor, node.input_right): + inp_right = translator.translate_ir(n=None) + right_on = [translate_named_expr(translator, n=e) for e in node.right_on] if (how := node.options[0]) in { "inner", "left", @@ -257,27 +373,27 @@ def _rename(e: expr.Expr, rec: ExprTransformer) -> expr.Expr: @_translate_ir.register def _( - node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HStack, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.exprs] return ir.HStack(schema, exprs, node.should_broadcast, inp) @_translate_ir.register def _( - node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Reduce, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.expr] return ir.Reduce(schema, exprs, inp) @_translate_ir.register def _( - node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Distinct, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: (keep, subset, maintain_order, zlice) = node.options keep = ir.Distinct._KEEP_MAP[keep] @@ -288,17 +404,17 @@ def _( subset, zlice, maintain_order, - translate_ir(visitor, n=node.input), + translator.translate_ir(n=node.input), ) @_translate_ir.register def _( - node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Sort, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - by = [translate_named_expr(visitor, n=e) for e in node.by_column] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + by = [translate_named_expr(translator, n=e) for e in node.by_column] stable, nulls_last, descending = node.sort_options order, null_order = sorting.sort_order( descending, nulls_last=nulls_last, num_keys=len(by) @@ -308,33 +424,35 @@ def _( @_translate_ir.register def _( - node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Slice, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.Slice(schema, node.offset, node.len, translate_ir(visitor, n=node.input)) + return ir.Slice( + schema, node.offset, node.len, translator.translate_ir(n=node.input) + ) @_translate_ir.register def _( - node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Filter, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - mask = translate_named_expr(visitor, n=node.predicate) + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + mask = translate_named_expr(translator, n=node.predicate) return ir.Filter(schema, mask, inp) @_translate_ir.register def _( node: pl_ir.SimpleProjection, - visitor: NodeTraverser, + translator: Translator, schema: dict[str, plc.DataType], ) -> ir.IR: - return ir.Projection(schema, translate_ir(visitor, n=node.input)) + return ir.Projection(schema, translator.translate_ir(n=node.input)) @_translate_ir.register def _( - node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: name, *options = node.function return ir.MapFunction( @@ -342,83 +460,36 @@ def _( name, options, # TODO: merge_sorted breaks this pattern - translate_ir(visitor, n=node.input), + translator.translate_ir(n=node.input), ) @_translate_ir.register def _( - node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Union, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.Union( - schema, node.options, *(translate_ir(visitor, n=n) for n in node.inputs) + schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs) ) @_translate_ir.register def _( - node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HConcat, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.HConcat(schema, *(translate_ir(visitor, n=n) for n in node.inputs)) - - -def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: - """ - Translate a polars-internal IR node to our representation. - - Parameters - ---------- - visitor - Polars NodeTraverser object - n - Optional node to start traversing from, if not provided uses - current polars-internal node. - - Returns - ------- - Translated IR object - - Raises - ------ - NotImplementedError - If we can't translate the nodes due to unsupported functionality. - """ - ctx: AbstractContextManager[None] = ( - set_node(visitor, n) if n is not None else noop_context - ) - # IR is versioned with major.minor, minor is bumped for backwards - # compatible changes (e.g. adding new nodes), major is bumped for - # incompatible changes (e.g. renaming nodes). - if (version := visitor.version()) >= (4, 0): - raise NotImplementedError( - f"No support for polars IR {version=}" - ) # pragma: no cover; no such version for now. - - with ctx: - polars_schema = visitor.get_schema() - node = visitor.view_current_node() - schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} - result = _translate_ir(node, visitor, schema) - if any( - isinstance(dtype, pl.Null) - for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values()) - ): - raise NotImplementedError( - f"No GPU support for {result} with Null column dtype." - ) - return result + return ir.HConcat(schema, *(translator.translate_ir(n=n) for n in node.inputs)) def translate_named_expr( - visitor: NodeTraverser, *, n: pl_expr.PyExprIR + translator: Translator, *, n: pl_expr.PyExprIR ) -> expr.NamedExpr: """ Translate a polars-internal named expression IR object into our representation. Parameters ---------- - visitor - Polars NodeTraverser object + translator + Translator object n Node to translate, a named expression node. @@ -438,12 +509,12 @@ def translate_named_expr( NotImplementedError If any translation fails due to unsupported functionality. """ - return expr.NamedExpr(n.output_name, translate_expr(visitor, n=n.node)) + return expr.NamedExpr(n.output_name, translator.translate_expr(n=n.node)) @singledispatch def _translate_expr( - node: Any, visitor: NodeTraverser, dtype: plc.DataType + node: Any, translator: Translator, dtype: plc.DataType ) -> expr.Expr: raise NotImplementedError( f"Translation for {type(node).__name__}" @@ -451,7 +522,7 @@ def _translate_expr( @_translate_expr.register -def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> expr.Expr: name, *options = node.function_data options = tuple(options) if isinstance(name, pl_expr.StringFunction): @@ -460,7 +531,7 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex pl_expr.StringFunction.StripCharsStart, pl_expr.StringFunction.StripCharsEnd, }: - column, chars = (translate_expr(visitor, n=n) for n in node.input) + column, chars = (translator.translate_expr(n=n) for n in node.input) if isinstance(chars, expr.Literal): if chars.value == pa.scalar(""): # No-op in polars, but libcudf uses empty string @@ -477,11 +548,11 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) elif isinstance(name, pl_expr.BooleanFunction): if name == pl_expr.BooleanFunction.IsBetween: - column, lo, hi = (translate_expr(visitor, n=n) for n in node.input) + column, lo, hi = (translator.translate_expr(n=n) for n in node.input) (closed,) = options lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed] return expr.BinOp( @@ -494,7 +565,7 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) elif isinstance(name, pl_expr.TemporalFunction): # functions for which evaluation of the expression may not return @@ -514,14 +585,14 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) if name in needs_cast: return expr.Cast(dtype, result_expr) return result_expr elif isinstance(name, str): - children = (translate_expr(visitor, n=n) for n in node.input) + children = (translator.translate_expr(n=n) for n in node.input) if name == "log": (base,) = options (child,) = children @@ -540,26 +611,26 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex @_translate_expr.register -def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby? if isinstance(node.options, pl_expr.RollingGroupOptions): # pl.col("a").rolling(...) return expr.RollingWindow( - dtype, node.options, translate_expr(visitor, n=node.function) + dtype, node.options, translator.translate_expr(n=node.function) ) elif isinstance(node.options, pl_expr.WindowMapping): # pl.col("a").over(...) return expr.GroupedRollingWindow( dtype, node.options, - translate_expr(visitor, n=node.function), - *(translate_expr(visitor, n=n) for n in node.partition_by), + translator.translate_expr(n=node.function), + *(translator.translate_expr(n=n) for n in node.partition_by), ) assert_never(node.options) @_translate_expr.register -def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Literal, translator: Translator, dtype: plc.DataType) -> expr.Expr: if isinstance(node.value, plrs.PySeries): return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value)) value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype)) @@ -567,42 +638,42 @@ def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> exp @_translate_expr.register -def _(node: pl_expr.Sort, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby - return expr.Sort(dtype, node.options, translate_expr(visitor, n=node.expr)) + return expr.Sort(dtype, node.options, translator.translate_expr(n=node.expr)) @_translate_expr.register -def _(node: pl_expr.SortBy, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.SortBy( dtype, node.sort_options, - translate_expr(visitor, n=node.expr), - *(translate_expr(visitor, n=n) for n in node.by), + translator.translate_expr(n=node.expr), + *(translator.translate_expr(n=n) for n in node.by), ) @_translate_expr.register -def _(node: pl_expr.Gather, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Gather( dtype, - translate_expr(visitor, n=node.expr), - translate_expr(visitor, n=node.idx), + translator.translate_expr(n=node.expr), + translator.translate_expr(n=node.idx), ) @_translate_expr.register -def _(node: pl_expr.Filter, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Filter, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Filter( dtype, - translate_expr(visitor, n=node.input), - translate_expr(visitor, n=node.by), + translator.translate_expr(n=node.input), + translator.translate_expr(n=node.by), ) @_translate_expr.register -def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: - inner = translate_expr(visitor, n=node.expr) +def _(node: pl_expr.Cast, translator: Translator, dtype: plc.DataType) -> expr.Expr: + inner = translator.translate_expr(n=node.expr) # Push casts into literals so we can handle Cast(Literal(Null)) if isinstance(inner, expr.Literal): return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype))) @@ -614,17 +685,17 @@ def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.E @_translate_expr.register -def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Column, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Col(dtype, node.name) @_translate_expr.register -def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Agg, translator: Translator, dtype: plc.DataType) -> expr.Expr: value = expr.Agg( dtype, node.name, node.options, - *(translate_expr(visitor, n=n) for n in node.arguments), + *(translator.translate_expr(n=n) for n in node.arguments), ) if value.name == "count" and value.dtype.id() != plc.TypeId.INT32: return expr.Cast(value.dtype, value) @@ -632,55 +703,30 @@ def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Ex @_translate_expr.register -def _(node: pl_expr.Ternary, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Ternary, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Ternary( dtype, - translate_expr(visitor, n=node.predicate), - translate_expr(visitor, n=node.truthy), - translate_expr(visitor, n=node.falsy), + translator.translate_expr(n=node.predicate), + translator.translate_expr(n=node.truthy), + translator.translate_expr(n=node.falsy), ) @_translate_expr.register def _( - node: pl_expr.BinaryExpr, visitor: NodeTraverser, dtype: plc.DataType + node: pl_expr.BinaryExpr, translator: Translator, dtype: plc.DataType ) -> expr.Expr: return expr.BinOp( dtype, expr.BinOp._MAPPING[node.op], - translate_expr(visitor, n=node.left), - translate_expr(visitor, n=node.right), + translator.translate_expr(n=node.left), + translator.translate_expr(n=node.right), ) @_translate_expr.register -def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Len, translator: Translator, dtype: plc.DataType) -> expr.Expr: value = expr.Len(dtype) if dtype.id() != plc.TypeId.INT32: return expr.Cast(dtype, value) return value # pragma: no cover; never reached since polars len has uint32 dtype - - -def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr: - """ - Translate a polars-internal expression IR into our representation. - - Parameters - ---------- - visitor - Polars NodeTraverser object - n - Node to translate, an integer referencing a polars internal node. - - Returns - ------- - Translated IR object. - - Raises - ------ - NotImplementedError - If any translation fails due to unsupported functionality. - """ - node = visitor.view_expression(n) - dtype = dtypes.from_polars(visitor.get_dtype(n)) - return _translate_expr(node, visitor, dtype) diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 7b45c1eaa06..fae98c648f0 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -10,7 +10,7 @@ from polars import GPUEngine from polars.testing.asserts import assert_frame_equal -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator if TYPE_CHECKING: import polars as pl @@ -117,12 +117,12 @@ def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) AssertionError If the specified exceptions were not raised. """ - try: - _ = translate_ir(q._ldf.visit()) - except exceptions: + translator = Translator(q._ldf.visit()) + translator.translate_ir() + if errors := translator.errors: + for err in errors: + assert any(isinstance(err, err_type) for err_type in exceptions) return - except Exception as e: - raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e else: raise AssertionError(f"Translation DID NOT RAISE {exceptions}") diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py index e01ccd05527..45accba0859 100644 --- a/python/cudf_polars/cudf_polars/testing/plugin.py +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -45,6 +45,7 @@ def pytest_configure(config: pytest.Config) -> None: EXPECTED_FAILURES: Mapping[str, str] = { + "tests/unit/dataframe/test_df.py::test_extension": "AssertionError", "tests/unit/io/test_csv.py::test_compressed_csv": "Need to determine if file is compressed", "tests/unit/io/test_csv.py::test_read_csv_only_loads_selected_columns": "Memory usage won't be correct due to GPU", "tests/unit/io/test_lazy_count_star.py::test_count_compressed_csv_18057": "Need to determine if file is compressed", diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 1d0479802ca..5932ba05596 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -61,10 +61,15 @@ def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: ------- True if casting is supported, False otherwise """ + has_empty = from_.id() == plc.TypeId.EMPTY or to.id() == plc.TypeId.EMPTY return ( - plc.traits.is_fixed_width(to) - and plc.traits.is_fixed_width(from_) - and plc.unary.is_supported_cast(from_, to) + from_ == to + or not has_empty + and ( + plc.traits.is_fixed_width(to) + and plc.traits.is_fixed_width(from_) + and plc.unary.is_supported_cast(from_, to) + ) ) diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index 17a94c633f8..2f2361223d2 100644 --- a/python/cudf_polars/docs/overview.md +++ b/python/cudf_polars/docs/overview.md @@ -458,12 +458,12 @@ translate it to our intermediate representation (IR), and then execute and convert back to polars: ```python -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator q = ... # Convert to our IR -ir = translate_ir(q._ldf.visit()) +ir = Translator(q._ldf.visit()).translate_ir() # DataFrame living on the device result = ir.evaluate(cache={}) diff --git a/python/cudf_polars/tests/dsl/test_to_ast.py b/python/cudf_polars/tests/dsl/test_to_ast.py index 57d794d4890..2ee6c7ed020 100644 --- a/python/cudf_polars/tests/dsl/test_to_ast.py +++ b/python/cudf_polars/tests/dsl/test_to_ast.py @@ -11,7 +11,7 @@ import pylibcudf as plc import cudf_polars.dsl.ir as ir_nodes -from cudf_polars import translate_ir +from cudf_polars import Translator from cudf_polars.containers.dataframe import DataFrame, NamedColumn from cudf_polars.dsl.to_ast import to_ast @@ -58,7 +58,7 @@ def df(): ) def test_compute_column(expr, df): q = df.select(expr) - ir = translate_ir(q._ldf.visit()) + ir = Translator(q._ldf.visit()).translate_ir() assert isinstance(ir, ir_nodes.Select) table = ir.children[0].evaluate(cache={}) diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index 15c644d7978..8958c2a0f84 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -10,7 +10,7 @@ import pylibcudf as plc -from cudf_polars import translate_ir +from cudf_polars import Translator from cudf_polars.dsl import expr, ir from cudf_polars.dsl.traversal import ( CachingVisitor, @@ -109,7 +109,7 @@ def test_rewrite_ir_node(): df = pl.LazyFrame({"a": [1, 2, 1], "b": [1, 3, 4]}) q = df.group_by("a").agg(pl.col("b").sum()).sort("b") - orig = translate_ir(q._ldf.visit()) + orig = Translator(q._ldf.visit()).translate_ir() new_df = pl.DataFrame({"a": [1, 1, 2], "b": [-1, -2, -4]}) @@ -150,7 +150,7 @@ def replace_scan(node, rec): mapper = CachingVisitor(replace_scan) - orig = translate_ir(q._ldf.visit()) + orig = Translator(q._ldf.visit()).translate_ir() new = mapper(orig) result = new.evaluate(cache={}).to_polars() @@ -174,7 +174,7 @@ def test_rewrite_names_and_ops(): .collect() ) - qir = translate_ir(q._ldf.visit()) + qir = Translator(q._ldf.visit()).translate_ir() @singledispatch def _transform(e: expr.Expr, fn: ExprTransformer) -> expr.Expr: diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py index 62df8ce1498..6170281ad54 100644 --- a/python/cudf_polars/tests/expressions/test_sort.py +++ b/python/cudf_polars/tests/expressions/test_sort.py @@ -10,7 +10,7 @@ import pylibcudf as plc -from cudf_polars import translate_ir +from cudf_polars import Translator from cudf_polars.testing.asserts import assert_gpu_result_equal @@ -68,7 +68,7 @@ def test_setsorted(descending, nulls_last, with_nulls): assert_gpu_result_equal(q) - df = translate_ir(q._ldf.visit()).evaluate(cache={}) + df = Translator(q._ldf.visit()).translate_ir().evaluate(cache={}) a = df.column_map["a"]