From 054b271459acb1511328d0bf5cb8801f3a142c53 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Wed, 9 Oct 2024 08:15:40 -0700 Subject: [PATCH] address review --- python/cudf_polars/cudf_polars/__init__.py | 4 +- python/cudf_polars/cudf_polars/callback.py | 12 +- python/cudf_polars/cudf_polars/dsl/expr.py | 7 + python/cudf_polars/cudf_polars/dsl/ir.py | 9 + .../cudf_polars/cudf_polars/dsl/translate.py | 366 ++++++++++-------- .../cudf_polars/testing/asserts.py | 4 +- 6 files changed, 225 insertions(+), 177 deletions(-) diff --git a/python/cudf_polars/cudf_polars/__init__.py b/python/cudf_polars/cudf_polars/__init__.py index 66c15f694ee..ba4858c5619 100644 --- a/python/cudf_polars/cudf_polars/__init__.py +++ b/python/cudf_polars/cudf_polars/__init__.py @@ -12,7 +12,7 @@ from cudf_polars._version import __git_commit__, __version__ from cudf_polars.callback import execute_with_cudf -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator # Check we have a supported polars version from cudf_polars.utils.versions import _ensure_polars_version @@ -22,7 +22,7 @@ __all__: list[str] = [ "execute_with_cudf", - "translate_ir", + "Translator", "__git_commit__", "__version__", ] diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 76816ee0a61..8b17bfc7f47 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -18,7 +18,7 @@ import rmm from rmm._cuda import gpu -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator if TYPE_CHECKING: from collections.abc import Generator @@ -174,16 +174,22 @@ def execute_with_cudf( device = config.device memory_resource = config.memory_resource raise_on_fail = config.config.get("raise_on_fail", False) - if unsupported := (config.config.keys() - {"raise_on_fail"}): + debug_mode = config.config.get("debug_mode", False) + if unsupported := (config.config.keys() - {"raise_on_fail", "debug_mode"}): raise ValueError( f"Engine configuration contains unsupported settings {unsupported}" ) try: with nvtx.annotate(message="ConvertIR", domain="cudf_polars"): + translator = Translator(nt, debug_mode=debug_mode) + ir = translator.translate_ir() + if debug_mode and len(translator.errors): + print(set(translator.errors)) + raise NotImplementedError("Query contained unsupported operations") nt.set_udf( partial( _callback, - translate_ir(nt), + ir, device=device, memory_resource=memory_resource, ) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index f7775ceb238..666db9f56df 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -40,6 +40,7 @@ __all__ = [ "Expr", + "ErrorExpr", "NamedExpr", "Literal", "Col", @@ -275,6 +276,12 @@ def collect_agg(self, *, depth: int) -> AggInfo: ) # pragma: no cover; check_agg trips first +class ErrorExpr(Expr): + def __init__(self, dtype: plc.DataType, error: str) -> None: + super().__init__(dtype) + self.error = error + + class NamedExpr: # NamedExpr does not inherit from Expr since it does not appear # when evaluating expressions themselves, only when constructing diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index e319c363a23..a61629ef438 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -38,6 +38,7 @@ __all__ = [ "IR", + "ErrorNode", "PythonScan", "Scan", "Cache", @@ -159,6 +160,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: ) # pragma: no cover +@dataclasses.dataclass +class ErrorNode(IR): + """Represents an error translating the IR.""" + + error: str + """The error.""" + + @dataclasses.dataclass class PythonScan(IR): """Representation of input from a python function.""" diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index a0291037f01..1faec73891b 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -8,7 +8,7 @@ import json from contextlib import AbstractContextManager, nullcontext from functools import singledispatch -from typing import Any +from typing import TYPE_CHECKING, Any import pyarrow as pa import pylibcudf as plc @@ -19,10 +19,107 @@ from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir from cudf_polars.dsl import expr, ir -from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes -__all__ = ["translate_ir", "translate_named_expr"] +if TYPE_CHECKING: + from cudf_polars.typing import NodeTraverser +__all__ = ["Translator", "translate_named_expr"] + + +class Translator: + """ + Translates polars-internal IR nodes and expressions to our representation. + + Parameters + ---------- + visitor + Polars NodeTraverser object + debug_mode + Setting this mode to True allows Translator to collect + errors raised during translation + """ + + def __init__(self, visitor: NodeTraverser, *, debug_mode: bool = False): + self.visitor = visitor + self.debug_mode = debug_mode + self.errors: list[str] = [] + + def translate_ir(self, *, n: int | None = None) -> ir.IR: + """ + Translate a polars-internal IR node to our representation. + + Parameters + ---------- + n + Optional node to start traversing from, if not provided uses + current polars-internal node. + + Returns + ------- + Translated IR object + + Raises + ------ + NotImplementedError + If we can't translate the nodes due to unsupported functionality. + """ + ctx: AbstractContextManager[None] = ( + set_node(self.visitor, n) if n is not None else noop_context + ) + # IR is versioned with major.minor, minor is bumped for backwards + # compatible changes (e.g. adding new nodes), major is bumped for + # incompatible changes (e.g. renaming nodes). + # Polars 1.7 changes definition of the CSV reader options schema name. + if (version := self.visitor.version()) >= (3, 0): + raise NotImplementedError( + f"No support for polars IR {version=}" + ) # pragma: no cover; no such version for now. + + with ctx: + polars_schema = self.visitor.get_schema() + node = self.visitor.view_current_node() + schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} + try: + result = _translate_ir(node, self, schema) + except Exception as e: + self.errors.append(str(e)) + return ir.ErrorNode(polars_schema, str(e)) + if any( + isinstance(dtype, pl.Null) + for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values()) + ): + error = f"No GPU support for {result} with Null column dtype." + if self.debug_mode: + self.errors.append(error) + return ir.ErrorNode(polars_schema, error) + raise NotImplementedError(error) + return result + + def translate_expr(self, *, n: int) -> expr.Expr: + """ + Translate a polars-internal expression IR into our representation. + + Parameters + ---------- + n + Node to translate, an integer referencing a polars internal node. + + Returns + ------- + Translated IR object. + + Raises + ------ + NotImplementedError + If any translation fails due to unsupported functionality. + """ + node = self.visitor.view_expression(n) + dtype = dtypes.from_polars(self.visitor.get_dtype(n)) + try: + return _translate_expr(node, self, dtype) + except Exception as e: + self.errors.append(str(e)) + return expr.ErrorExpr(dtype, str(e)) class set_node(AbstractContextManager[None]): @@ -64,7 +161,7 @@ def __exit__(self, *args: Any) -> None: @singledispatch def _translate_ir( - node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: Any, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: raise NotImplementedError( f"Translation for {type(node).__name__}" @@ -73,19 +170,19 @@ def _translate_ir( @_translate_ir.register def _( - node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.PythonScan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: scan_fn, with_columns, source_type, predicate, nrows = node.options options = (scan_fn, with_columns, source_type, nrows) predicate = ( - translate_named_expr(visitor, n=predicate) if predicate is not None else None + translate_named_expr(translator, n=predicate) if predicate is not None else None ) return ir.PythonScan(schema, options, predicate) @_translate_ir.register def _( - node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Scan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: typ, *options = node.scan_type if typ == "ndjson": @@ -114,7 +211,7 @@ def _( skip_rows, n_rows, row_index, - translate_named_expr(visitor, n=node.predicate) + translate_named_expr(translator, n=node.predicate) if node.predicate is not None else None, ) @@ -122,20 +219,20 @@ def _( @_translate_ir.register def _( - node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Cache, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input)) + return ir.Cache(schema, node.id_, translator.translate_ir(n=node.input)) @_translate_ir.register def _( - node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.DataFrameScan, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.DataFrameScan( schema, node.df, node.projection, - translate_named_expr(visitor, n=node.selection) + translate_named_expr(translator, n=node.selection) if node.selection is not None else None, ) @@ -143,22 +240,22 @@ def _( @_translate_ir.register def _( - node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Select, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.expr] return ir.Select(schema, inp, exprs, node.should_broadcast) @_translate_ir.register def _( - node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.GroupBy, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - aggs = [translate_named_expr(visitor, n=e) for e in node.aggs] - keys = [translate_named_expr(visitor, n=e) for e in node.keys] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + aggs = [translate_named_expr(translator, n=e) for e in node.aggs] + keys = [translate_named_expr(translator, n=e) for e in node.keys] return ir.GroupBy( schema, inp, @@ -171,96 +268,98 @@ def _( @_translate_ir.register def _( - node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Join, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: # Join key dtypes are dependent on the schema of the left and # right inputs, so these must be translated with the relevant # input active. - with set_node(visitor, node.input_left): - inp_left = translate_ir(visitor, n=None) - left_on = [translate_named_expr(visitor, n=e) for e in node.left_on] - with set_node(visitor, node.input_right): - inp_right = translate_ir(visitor, n=None) - right_on = [translate_named_expr(visitor, n=e) for e in node.right_on] + with set_node(translator.visitor, node.input_left): + inp_left = translator.translate_ir(n=None) + left_on = [translate_named_expr(translator, n=e) for e in node.left_on] + with set_node(translator.visitor, node.input_right): + inp_right = translator.translate_ir(n=None) + right_on = [translate_named_expr(translator, n=e) for e in node.right_on] return ir.Join(schema, inp_left, inp_right, left_on, right_on, node.options) @_translate_ir.register def _( - node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HStack, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.exprs] return ir.HStack(schema, inp, exprs, node.should_broadcast) @_translate_ir.register def _( - node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Reduce, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + exprs = [translate_named_expr(translator, n=e) for e in node.expr] return ir.Reduce(schema, inp, exprs) @_translate_ir.register def _( - node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Distinct, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.Distinct( schema, - translate_ir(visitor, n=node.input), + translator.translate_ir(n=node.input), node.options, ) @_translate_ir.register def _( - node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Sort, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - by = [translate_named_expr(visitor, n=e) for e in node.by_column] + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + by = [translate_named_expr(translator, n=e) for e in node.by_column] return ir.Sort(schema, inp, by, node.sort_options, node.slice) @_translate_ir.register def _( - node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Slice, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len) + return ir.Slice( + schema, translator.translate_ir(n=node.input), node.offset, node.len + ) @_translate_ir.register def _( - node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Filter, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - with set_node(visitor, node.input): - inp = translate_ir(visitor, n=None) - mask = translate_named_expr(visitor, n=node.predicate) + with set_node(translator.visitor, node.input): + inp = translator.translate_ir(n=None) + mask = translate_named_expr(translator, n=node.predicate) return ir.Filter(schema, inp, mask) @_translate_ir.register def _( node: pl_ir.SimpleProjection, - visitor: NodeTraverser, + translator: Translator, schema: dict[str, plc.DataType], ) -> ir.IR: - return ir.Projection(schema, translate_ir(visitor, n=node.input)) + return ir.Projection(schema, translator.translate_ir(n=node.input)) @_translate_ir.register def _( - node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: name, *options = node.function return ir.MapFunction( schema, # TODO: merge_sorted breaks this pattern - translate_ir(visitor, n=node.input), + translator.translate_ir(n=node.input), name, options, ) @@ -268,78 +367,30 @@ def _( @_translate_ir.register def _( - node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Union, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.Union( - schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options + schema, [translator.translate_ir(n=n) for n in node.inputs], node.options ) @_translate_ir.register def _( - node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HConcat, translator: Translator, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs]) - - -def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: - """ - Translate a polars-internal IR node to our representation. - - Parameters - ---------- - visitor - Polars NodeTraverser object - n - Optional node to start traversing from, if not provided uses - current polars-internal node. - - Returns - ------- - Translated IR object - - Raises - ------ - NotImplementedError - If we can't translate the nodes due to unsupported functionality. - """ - ctx: AbstractContextManager[None] = ( - set_node(visitor, n) if n is not None else noop_context - ) - # IR is versioned with major.minor, minor is bumped for backwards - # compatible changes (e.g. adding new nodes), major is bumped for - # incompatible changes (e.g. renaming nodes). - # Polars 1.7 changes definition of the CSV reader options schema name. - if (version := visitor.version()) >= (3, 0): - raise NotImplementedError( - f"No support for polars IR {version=}" - ) # pragma: no cover; no such version for now. - - with ctx: - polars_schema = visitor.get_schema() - node = visitor.view_current_node() - schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} - result = _translate_ir(node, visitor, schema) - if any( - isinstance(dtype, pl.Null) - for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values()) - ): - raise NotImplementedError( - f"No GPU support for {result} with Null column dtype." - ) - return result + return ir.HConcat(schema, [translator.translate_ir(n=n) for n in node.inputs]) def translate_named_expr( - visitor: NodeTraverser, *, n: pl_expr.PyExprIR + translator: Translator, *, n: pl_expr.PyExprIR ) -> expr.NamedExpr: """ Translate a polars-internal named expression IR object into our representation. Parameters ---------- - visitor - Polars NodeTraverser object + translator + Translator object n Node to translate, a named expression node. @@ -359,12 +410,12 @@ def translate_named_expr( NotImplementedError If any translation fails due to unsupported functionality. """ - return expr.NamedExpr(n.output_name, translate_expr(visitor, n=n.node)) + return expr.NamedExpr(n.output_name, translator.translate_expr(n=n.node)) @singledispatch def _translate_expr( - node: Any, visitor: NodeTraverser, dtype: plc.DataType + node: Any, translator: Translator, dtype: plc.DataType ) -> expr.Expr: raise NotImplementedError( f"Translation for {type(node).__name__}" @@ -372,7 +423,7 @@ def _translate_expr( @_translate_expr.register -def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> expr.Expr: name, *options = node.function_data options = tuple(options) if isinstance(name, pl_expr.StringFunction): @@ -381,7 +432,7 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex pl_expr.StringFunction.StripCharsStart, pl_expr.StringFunction.StripCharsEnd, }: - column, chars = (translate_expr(visitor, n=n) for n in node.input) + column, chars = (translator.translate_expr(n=n) for n in node.input) if isinstance(chars, expr.Literal): if chars.value == pa.scalar(""): # No-op in polars, but libcudf uses empty string @@ -398,11 +449,11 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) elif isinstance(name, pl_expr.BooleanFunction): if name == pl_expr.BooleanFunction.IsBetween: - column, lo, hi = (translate_expr(visitor, n=n) for n in node.input) + column, lo, hi = (translator.translate_expr(n=n) for n in node.input) (closed,) = options lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed] return expr.BinOp( @@ -415,7 +466,7 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) elif isinstance(name, pl_expr.TemporalFunction): # functions for which evaluation of the expression may not return @@ -435,14 +486,14 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex dtype, name, options, - *(translate_expr(visitor, n=n) for n in node.input), + *(translator.translate_expr(n=n) for n in node.input), ) if name in needs_cast: return expr.Cast(dtype, result_expr) return result_expr elif isinstance(name, str): - children = (translate_expr(visitor, n=n) for n in node.input) + children = (translator.translate_expr(n=n) for n in node.input) if name == "log": (base,) = options (child,) = children @@ -461,26 +512,26 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex @_translate_expr.register -def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby? if isinstance(node.options, pl_expr.RollingGroupOptions): # pl.col("a").rolling(...) return expr.RollingWindow( - dtype, node.options, translate_expr(visitor, n=node.function) + dtype, node.options, translator.translate_expr(n=node.function) ) elif isinstance(node.options, pl_expr.WindowMapping): # pl.col("a").over(...) return expr.GroupedRollingWindow( dtype, node.options, - translate_expr(visitor, n=node.function), - *(translate_expr(visitor, n=n) for n in node.partition_by), + translator.translate_expr(n=node.function), + *(translator.translate_expr(n=n) for n in node.partition_by), ) assert_never(node.options) @_translate_expr.register -def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Literal, translator: Translator, dtype: plc.DataType) -> expr.Expr: if isinstance(node.value, plrs.PySeries): return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value)) value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype)) @@ -488,42 +539,42 @@ def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> exp @_translate_expr.register -def _(node: pl_expr.Sort, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby - return expr.Sort(dtype, node.options, translate_expr(visitor, n=node.expr)) + return expr.Sort(dtype, node.options, translator.translate_expr(n=node.expr)) @_translate_expr.register -def _(node: pl_expr.SortBy, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.SortBy( dtype, node.sort_options, - translate_expr(visitor, n=node.expr), - *(translate_expr(visitor, n=n) for n in node.by), + translator.translate_expr(n=node.expr), + *(translator.translate_expr(n=n) for n in node.by), ) @_translate_expr.register -def _(node: pl_expr.Gather, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Gather( dtype, - translate_expr(visitor, n=node.expr), - translate_expr(visitor, n=node.idx), + translator.translate_expr(n=node.expr), + translator.translate_expr(n=node.idx), ) @_translate_expr.register -def _(node: pl_expr.Filter, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Filter, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Filter( dtype, - translate_expr(visitor, n=node.input), - translate_expr(visitor, n=node.by), + translator.translate_expr(n=node.input), + translator.translate_expr(n=node.by), ) @_translate_expr.register -def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: - inner = translate_expr(visitor, n=node.expr) +def _(node: pl_expr.Cast, translator: Translator, dtype: plc.DataType) -> expr.Expr: + inner = translator.translate_expr(n=node.expr) # Push casts into literals so we can handle Cast(Literal(Null)) if isinstance(inner, expr.Literal): return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype))) @@ -535,17 +586,17 @@ def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.E @_translate_expr.register -def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Column, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Col(dtype, node.name) @_translate_expr.register -def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Agg, translator: Translator, dtype: plc.DataType) -> expr.Expr: value = expr.Agg( dtype, node.name, node.options, - *(translate_expr(visitor, n=n) for n in node.arguments), + *(translator.translate_expr(n=n) for n in node.arguments), ) if value.name == "count" and value.dtype.id() != plc.TypeId.INT32: return expr.Cast(value.dtype, value) @@ -553,55 +604,30 @@ def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Ex @_translate_expr.register -def _(node: pl_expr.Ternary, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Ternary, translator: Translator, dtype: plc.DataType) -> expr.Expr: return expr.Ternary( dtype, - translate_expr(visitor, n=node.predicate), - translate_expr(visitor, n=node.truthy), - translate_expr(visitor, n=node.falsy), + translator.translate_expr(n=node.predicate), + translator.translate_expr(n=node.truthy), + translator.translate_expr(n=node.falsy), ) @_translate_expr.register def _( - node: pl_expr.BinaryExpr, visitor: NodeTraverser, dtype: plc.DataType + node: pl_expr.BinaryExpr, translator: Translator, dtype: plc.DataType ) -> expr.Expr: return expr.BinOp( dtype, expr.BinOp._MAPPING[node.op], - translate_expr(visitor, n=node.left), - translate_expr(visitor, n=node.right), + translator.translate_expr(n=node.left), + translator.translate_expr(n=node.right), ) @_translate_expr.register -def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Len, translator: Translator, dtype: plc.DataType) -> expr.Expr: value = expr.Len(dtype) if dtype.id() != plc.TypeId.INT32: return expr.Cast(dtype, value) return value # pragma: no cover; never reached since polars len has uint32 dtype - - -def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr: - """ - Translate a polars-internal expression IR into our representation. - - Parameters - ---------- - visitor - Polars NodeTraverser object - n - Node to translate, an integer referencing a polars internal node. - - Returns - ------- - Translated IR object. - - Raises - ------ - NotImplementedError - If any translation fails due to unsupported functionality. - """ - node = visitor.view_expression(n) - dtype = dtypes.from_polars(visitor.get_dtype(n)) - return _translate_expr(node, visitor, dtype) diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 7b6f3848fc4..7073231497a 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -10,7 +10,7 @@ from polars import GPUEngine from polars.testing.asserts import assert_frame_equal -from cudf_polars.dsl.translate import translate_ir +from cudf_polars.dsl.translate import Translator if TYPE_CHECKING: import polars as pl @@ -118,7 +118,7 @@ def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) If the specified exceptions were not raised. """ try: - _ = translate_ir(q._ldf.visit()) + _ = Translator(q._ldf.visit()).translate_ir() except exceptions: return except Exception as e: