Skip to content

Commit

Permalink
add_regexp_replace (#314)
Browse files Browse the repository at this point in the history
* add regexp replace

* add test

* make case sensitive by default, minor cleanup

---------

Co-authored-by: Ivan Shcheklein <[email protected]>
  • Loading branch information
EdwardLi-coder and shcheklein authored Aug 19, 2024
1 parent 61aeed4 commit c74b95c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/datachain/sql/functions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,17 @@ class split(GenericFunction): # noqa: N801
inherit_cache = True


class regexp_replace(GenericFunction):  # noqa: N801
    """
    Replaces substrings that match a regular expression.

    Called with three arguments, in this order: the input string, the
    regular expression pattern, and the replacement text. The expression
    is typed as a string.

    The class only declares the SQL function; the SQL it renders to (and
    therefore the regex flavour that applies) is supplied per dialect by a
    separately registered compiler.
    """

    # SQL type of the function's result.
    type = String()
    # Package and name under which SQLAlchemy registers this generic function.
    package = "string"
    name = "regexp_replace"
    # Opt in to SQLAlchemy's compiled-statement caching for this construct.
    inherit_cache = True


# NOTE(review): judging by its name, `compiler_not_implemented` presumably
# installs a default compilation that reports the function as unsupported,
# so each dialect has to register its own compiler for these functions
# -- confirm against the helper's definition.
compiler_not_implemented(length)
compiler_not_implemented(split)
compiler_not_implemented(regexp_replace)
12 changes: 12 additions & 0 deletions src/datachain/sql/sqlite/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
import sqlite3
from collections.abc import Iterable
from datetime import MAXYEAR, MINYEAR, datetime, timezone
Expand Down Expand Up @@ -77,6 +78,7 @@ def setup():
compiles(array.length, "sqlite")(compile_array_length)
compiles(string.length, "sqlite")(compile_string_length)
compiles(string.split, "sqlite")(compile_string_split)
compiles(string.regexp_replace, "sqlite")(compile_regexp_replace)
compiles(conditional.greatest, "sqlite")(compile_greatest)
compiles(conditional.least, "sqlite")(compile_least)
compiles(Values, "sqlite")(compile_values)
Expand Down Expand Up @@ -178,9 +180,15 @@ def create_vector_functions(conn):

_registered_function_creators["vector_functions"] = create_vector_functions

def sqlite_regexp_replace(
    string: "str | None", pattern: "str | None", replacement: "str | None"
) -> "str | None":
    """
    Python implementation of the ``regexp_replace`` SQL function for SQLite.

    Replaces every non-overlapping match of ``pattern`` in ``string`` with
    ``replacement`` using Python's ``re`` syntax: matching is case-sensitive
    and ``replacement`` may contain backreferences such as ``\\1``.

    Args:
        string: Text to search; ``None`` when the SQL argument is NULL.
        pattern: Regular expression in Python ``re`` syntax.
        replacement: Replacement template.

    Returns:
        The text after substitution, or ``None`` if any argument is NULL.

    Raises:
        re.error: If ``pattern`` or ``replacement`` is not valid ``re`` syntax.
    """
    # SQLite hands SQL NULL to user-defined functions as None, and re.sub
    # raises TypeError on None, which would abort the whole query. Propagate
    # NULL instead, as SQL scalar functions conventionally do.
    if string is None or pattern is None or replacement is None:
        return None
    return re.sub(pattern, replacement, string)

def create_string_functions(conn):
    """
    Register the string user-defined functions on a SQLite connection.

    ``split`` is registered for both its two- and three-argument forms and
    ``regexp_replace`` for three arguments; all are marked deterministic.
    """
    # (SQL name, number of arguments, Python implementation)
    registrations = (
        ("split", 2, sqlite_string_split),
        ("split", 3, sqlite_string_split),
        ("regexp_replace", 3, sqlite_regexp_replace),
    )
    for sql_name, arg_count, implementation in registrations:
        conn.create_function(
            sql_name, arg_count, implementation, deterministic=True
        )

_registered_function_creators["string_functions"] = create_string_functions

Expand Down Expand Up @@ -265,6 +273,10 @@ def path_file_ext(path):
return func.substr(path, func.length(path) - path_file_ext_length(path) + 1)


def compile_regexp_replace(element, compiler, **kwargs):
    """
    Render a ``regexp_replace`` expression as a SQL function call.

    The element's clause list is compiled by ``compiler`` and the resulting
    text is wrapped in ``regexp_replace(...)``.
    """
    rendered_args = compiler.process(element.clauses, **kwargs)
    return f"regexp_replace({rendered_args})"


def compile_path_parent(element, compiler, **kwargs):
    """
    Compile a path-parent expression.

    The element's argument clauses are passed to ``path_parent`` to build
    the equivalent expression, which is then compiled by ``compiler``.
    """
    arguments = element.clauses.clauses
    parent_expression = path_parent(*arguments)
    return compiler.process(parent_expression, **kwargs)

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/sql/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,28 @@ def test_split(warehouse, args, expected):
query = select(string.split(*args))
result = tuple(warehouse.dataset_rows_select(query))
assert result == ((expected,),)


@pytest.mark.parametrize(
    "input_string,pattern,replacement,expected",
    [
        ("hello world", "world", "universe", "hello universe"),
        ("abc123def456", r"\d+", "X", "abcXdefX"),
        ("cat.1001.jpg", r"\.(\w+)\.", r"_\1_", "cat_1001_jpg"),
        (
            "dog_photo.jpg",
            r"(\w+)\.(jpg|jpeg|png|gif)$",
            r"\1_thumb.\2",
            "dog_photo_thumb.jpg",
        ),
        ("file.with...dots.txt", r"\.+", ".", "file.with.dots.txt"),
        # Matching is case-sensitive by default: "world" must not match "World".
        ("Hello World", "world", "universe", "Hello World"),
        # A pattern that matches nothing leaves the input untouched.
        ("no digits here", r"\d+", "X", "no digits here"),
    ],
)
def test_regexp_replace(warehouse, input_string, pattern, replacement, expected):
    """
    Check ``string.regexp_replace`` end to end against the warehouse database.

    Covers literal replacement, character classes, backreferences, anchored
    alternation, repeated matches, case sensitivity and the no-match case.
    """
    query = select(
        string.regexp_replace(
            literal(input_string), literal(pattern), literal(replacement)
        )
    )
    result = tuple(warehouse.db.execute(query))
    assert result == ((expected,),)

0 comments on commit c74b95c

Please sign in to comment.