From 5e8855ddd57c88d1c7a4922926502f85d93d98ae Mon Sep 17 00:00:00 2001
From: Tyler White <50381805+IndexSeek@users.noreply.github.com>
Date: Sat, 28 Dec 2024 06:25:29 -0500
Subject: [PATCH] feat(polars): add `StringFind` operation (#10624)

## Description of changes

Adds the
[`StringFind`](https://ibis-project.org/reference/operations#ibis.expr.operations.strings.StringFind)
operation for the Polars backend using
[`pl.Expr.str.find`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.find.html#).
---
 ibis/backends/polars/compiler.py   | 24 ++++++++++++++++++++++++
 ibis/backends/tests/test_string.py |  5 -----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index 04c8a8cc928d..7ba22b9a1b80 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -1555,3 +1555,27 @@ def visit_ArrayIntersect(op, **kw):
     left = translate(op.left, **kw)
     right = translate(op.right, **kw)
     return left.list.set_intersection(right)
+
+
+@translate.register(ops.StringFind)
+def visit_StringFind(op, **kw):
+    """Translate ``ops.StringFind`` into a Polars expression.
+
+    The search is limited to the ``[start, end)`` window of ``op.arg``
+    (``start`` defaults to 0, a missing ``end`` means "to the end of the
+    string"). Hits are reported relative to the whole string, not the window.
+    """
+    arg = translate(op.arg, **kw)
+    start = translate(op.start, **kw) if op.start is not None else 0
+    # ``str.slice`` takes ``(offset, length)`` rather than ``(start, end)``,
+    # so an explicit end bound has to be converted into a window length.
+    length = None
+    if op.end is not None:
+        length = translate(op.end, **kw) - start
+    # NOTE(review): Polars documents ``str.find`` as returning a byte offset,
+    # whereas ``str.slice`` counts characters -- confirm non-ASCII behaviour.
+    window = arg.str.slice(start, length)
+    found = window.str.find(_literal_value(op.substr), literal=True)
+    # A null result (no match) becomes -1; real hits are shifted back by
+    # ``start``. NOTE(review): null input rows also end up as -1 -- confirm.
+    return pl.when(found.is_null()).then(-1).otherwise(found + start)
diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py
index b849e365f82a..6e45c9979457 100644
--- a/ibis/backends/tests/test_string.py
+++ b/ibis/backends/tests/test_string.py
@@ -406,15 +406,11 @@ def uses_java_re(t):
             lambda t: t.string_col.find("a"),
             lambda t: t.string_col.str.find("a"),
             id="find",
-            marks=pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
         ),
         param(
             lambda t: t.date_string_col.find("13", 3),
             lambda t: t.date_string_col.str.find("13", 3),
             id="find_start",
-            marks=[
-                pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
-            ],
         ),
         param(
             lambda t: t.string_col.lpad(10, "a"),
@@ -1084,7 +1080,6 @@ def string_temp_table(backend, con):
             lambda t: t.string_col.find("123"),
             lambda t: t.str.find("123"),
             id="find",
-            marks=pytest.mark.notimpl("polars", raises=com.OperationNotDefinedError),
         ),
         param(
             lambda t: t.string_col.rpad(4, "-"),