Skip to content

Commit

Permalink
Translation for str_detect() and str_starts() (#1325)
Browse files Browse the repository at this point in the history
* Translation for `str_detect()` and `str_starts()`

* Fix snowflake

* Fix `fixed()` translation

* Don't use stringr in test
  • Loading branch information
mgirlich authored Jul 4, 2023
1 parent 1b10877 commit 6f88454
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 11 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# dbplyr (development version)

* Added translation for `str_detect()`, `str_starts()` and `str_ends()` with
fixed patterns (@mgirlich, #1009).

* The `overwrite` argument of `db_copy_to()` now actually works.

* `db_write_table()` and `db_save_query()` gain the `overwrite` argument.
Expand Down
30 changes: 29 additions & 1 deletion R/backend-.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,28 @@ base_scalar <- sql_translator(

# `str_detect` is intentionally not listed among these unsupported stubs: it is
# translated by the `str_detect = function(...)` entry that follows `fixed`.
# Keeping both would supply the `str_detect` argument twice to the same call.
str_conv = sql_not_supported("str_conv"),
str_count = sql_not_supported("str_count"),

# Translation of `fixed()`: returns `pattern` unchanged.
# `ignore_case` is handed to check_unsupported_arg() with `allowed = FALSE`
# (presumably this aborts for any other value -- confirm against that helper).
fixed = function(pattern, ignore_case = FALSE) {
check_unsupported_arg(ignore_case, allowed = FALSE)
pattern
},
# Translation of `str_detect()` for the base translator.
# `pattern` is forwarded with `{{ }}` so sql_str_pattern_switch() can inspect
# the unevaluated expression for a `fixed()` call. Fixed patterns are handled
# by the INSTR()-based helper in "detect" mode. No `f_regex` is supplied, so
# any non-fixed pattern makes sql_str_pattern_switch() abort.
str_detect = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_instr("detect")
)
},
str_dup = sql_not_supported("str_dup"),
# Translation of `str_ends()` for the base translator.
# Fixed patterns are handled by the INSTR()-based helper in "end" mode, which
# compares the INSTR() result with nchar(string) - nchar(pattern) + 1.
# No `f_regex` is supplied, so any non-fixed pattern makes
# sql_str_pattern_switch() abort.
str_ends = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_instr("end")
)
},
str_extract = sql_not_supported("str_extract"),
str_extract_all = sql_not_supported("str_extract_all"),
str_flatten = sql_not_supported("str_flatten"),
Expand All @@ -308,6 +328,14 @@ base_scalar <- sql_translator(
str_split = sql_not_supported("str_split"),
str_split_fixed = sql_not_supported("str_split_fixed"),
str_squish = sql_not_supported("str_squish"),
# Translation of `str_starts()` for the base translator.
# Fixed patterns are handled by the INSTR()-based helper in "start" mode, which
# compares the INSTR() result with 1. No `f_regex` is supplied, so any
# non-fixed pattern makes sql_str_pattern_switch() abort.
str_starts = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_instr("start")
)
},
str_subset = sql_not_supported("str_subset"),
str_trunc = sql_not_supported("str_trunc"),
str_view = sql_not_supported("str_view"),
Expand Down
35 changes: 30 additions & 5 deletions R/backend-postgres.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,29 @@ sql_translation.PqConnection <- function(con) {
# Translation of `str_locate()`: builds a `strpos(string, pattern)` expression
# with sql_expr(), splicing in the already-translated arguments.
str_locate = function(string, pattern) {
sql_expr(strpos(!!string, !!pattern))
},
# https://www.postgresql.org/docs/9.1/functions-string.html
# Translation of `str_detect()` for PostgreSQL.
# `pattern` is forwarded with `{{ }}` so sql_str_pattern_switch() can tell a
# `fixed()` pattern from a regular expression:
# * fixed patterns use the POSITION()-based helper in "detect" mode;
# * any other pattern is matched with the `~` operator, wrapped in `!( )`
#   when `negate` is TRUE.
# The function body consists solely of the dispatch call, so its value is the
# translation that is returned.
str_detect = function(string, pattern, negate = FALSE) {
  sql_str_pattern_switch(
    string = string,
    pattern = {{ pattern }},
    negate = negate,
    f_fixed = sql_str_detect_fixed_position("detect"),
    f_regex = function(string, pattern, negate = FALSE) {
      if (isTRUE(negate)) {
        sql_expr(!(!!string ~ !!pattern))
      } else {
        sql_expr(!!string ~ !!pattern)
      }
    }
  )
},
# Translation of `str_ends()` for PostgreSQL.
# Fixed patterns are handled by the POSITION()-based helper in "end" mode.
# No `f_regex` is supplied, so any non-fixed pattern makes
# sql_str_pattern_switch() abort.
str_ends = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_position("end")
)
},
# https://www.postgresql.org/docs/current/functions-matching.html
str_like = function(string, pattern, ignore_case = TRUE) {
Expand All @@ -127,6 +144,14 @@ sql_translation.PqConnection <- function(con) {
# Translation of `str_remove_all()`: builds a `regexp_replace()` call that
# replaces matches of `pattern` in `string` with the empty string, passing the
# 'g' flag as the fourth argument.
str_remove_all = function(string, pattern){
sql_expr(regexp_replace(!!string, !!pattern, '', 'g'))
},
# Translation of `str_starts()` for PostgreSQL.
# Fixed patterns are handled by the POSITION()-based helper in "start" mode.
# No `f_regex` is supplied, so any non-fixed pattern makes
# sql_str_pattern_switch() abort.
str_starts = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_position("start")
)
},

# lubridate functions
# https://www.postgresql.org/docs/9.1/functions-datetime.html
Expand Down
18 changes: 13 additions & 5 deletions R/backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ sql_translation.Snowflake <- function(con) {
# str_detect does not. Left- and right-pad `pattern` with .* to get
# str_detect-like behavior
# Translation of `str_detect()` for Snowflake.
# `pattern` is forwarded with `{{ }}` so sql_str_pattern_switch() can tell a
# `fixed()` pattern from a regular expression:
# * fixed patterns use the INSTR()-based helper in "detect" mode;
# * any other pattern is matched with REGEXP after padding the pattern with
#   ".*" on both sides, wrapped in `!( )` when `negate` is TRUE.
# The function body consists solely of the dispatch call, so its value is the
# translation that is returned.
str_detect = function(string, pattern, negate = FALSE) {
  sql_str_pattern_switch(
    string = string,
    pattern = {{ pattern }},
    negate = negate,
    f_fixed = sql_str_detect_fixed_instr("detect"),
    f_regex = function(string, pattern, negate = FALSE) {
      if (isTRUE(negate)) {
        sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")))
      } else {
        sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))
      }
    }
  )
},
# On Snowflake, REGEXP_REPLACE is used like this:
# REGEXP_REPLACE( <subject> , <pattern> [ , <replacement> ,
Expand Down
82 changes: 82 additions & 0 deletions R/translate-sql-string.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,86 @@ sql_str_trim <- function(string, side = c("both", "left", "right")) {
both = sql_expr(ltrim(rtrim(!!string))),
)
}



# Dispatch a string-pattern translation to a fixed or a regex implementation.
#
# * `string`, `negate`: passed through unchanged to the chosen implementation.
# * `pattern`: captured with `enquo()`; callers should forward it with `{{ }}`
#   so the unevaluated expression is available for inspection.
# * `f_fixed`: function(string, pattern, negate) used for fixed patterns.
# * `f_regex`: function(string, pattern, negate) used for all other patterns.
# * `error_call`: call reported in error messages.
#
# A pattern counts as fixed when its expression is a call to `fixed()` or when
# its value inherits from "stringr_fixed". Aborts when the implementation
# needed for the detected pattern kind was not supplied.
sql_str_pattern_switch <- function(string,
                                   pattern,
                                   negate = FALSE,
                                   f_fixed = NULL,
                                   f_regex = NULL,
                                   error_call = caller_env()) {
  pattern_quo <- enquo(pattern)
  # `||` short-circuits: `pattern` is only evaluated for the class check when
  # the captured expression is not itself a `fixed()` call.
  is_fixed <- quo_is_call(pattern_quo, "fixed") || inherits(pattern, "stringr_fixed")

  if (is_fixed) {
    # `f_fixed` defaults to NULL; fail with a clear message instead of an
    # "attempt to apply non-function" error.
    if (is_null(f_fixed)) {
      cli_abort(
        "Fixed patterns are not supported by this translation.",
        call = error_call
      )
    }
    return(f_fixed(string, pattern, negate))
  }

  if (is_null(f_regex)) {
    cli_abort("Only fixed patterns are supported on database backends.", call = error_call)
  }
  f_regex(string, pattern, negate)
}

# INSTR
# * SQLite https://www.sqlitetutorial.net/sqlite-functions/sqlite-instr/
# * MySQL https://dev.mysql.com/doc/refman/8.0/en/string-functions.html#function_instr
# * Oracle https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/INSTR.html#GUID-47E3A7C4-ED72-458D-A1FA-25A9AD3BE113
# * Teradata https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/String-Operators-and-Functions/INSTR
# * Access https://support.microsoft.com/de-de/office/instr-funktion-85d3392c-3b1c-4232-bb18-77cd0cb8a55b
# * Hana https://help.sap.com/docs/SAP_HANA_PLATFORM/e8e6c8142e60469bb401de5fdb6f7c00/f5a9ca3718354a499a98ba61ae3da170.html
# * Hive https://www.revisitclass.com/hadoop/instr-function-in-hive-with-examples/
# * Impala https://impala.apache.org/docs/build/html/topics/impala_string_functions.html#string_functions__instr
# POSITION
# * Snowflake https://docs.snowflake.com/en/sql-reference/functions/position
# Create a fixed-pattern translator built on `INSTR(string, pattern)`.
#
# `type` chooses how the INSTR() result is interpreted:
# * "detect": result > 0          (negated: result == 0)
# * "start":  result == 1         (negated: result != 1)
# * "end":    result == nchar(string) - nchar(pattern) + 1
#                                 (negated: the same comparison with !=)
#
# Returns a function(string, pattern, negate = FALSE) that produces the SQL
# for the current connection.
#
# NOTE(review): the "end" comparison relies on the single index returned by
# INSTR(). If that is the first occurrence (as the linked docs suggest), a
# string in which the pattern also occurs earlier would not be reported as
# ending with it -- confirm and decide whether this needs handling.
sql_str_detect_fixed_instr <- function(type = c("detect", "start", "end")) {
  type <- arg_match(type)

  function(string, pattern, negate = FALSE) {
    con <- sql_current_con()
    # Strip any class (e.g. "stringr_fixed") before interpolating the value.
    pattern <- unclass(pattern)
    found_at <- glue_sql2(con, "INSTR({.val string}, {.val pattern})")

    switch(type,
      detect = if (negate) {
        translate_sql(!!found_at == 0L, con = con)
      } else {
        translate_sql(!!found_at > 0L, con = con)
      },
      start = if (negate) {
        translate_sql(!!found_at != 1L, con = con)
      } else {
        translate_sql(!!found_at == 1L, con = con)
      },
      end = if (negate) {
        translate_sql(!!found_at != nchar(!!string) - nchar(!!pattern) + 1L, con = con)
      } else {
        translate_sql(!!found_at == nchar(!!string) - nchar(!!pattern) + 1L, con = con)
      }
    )
  }
}

# Create a fixed-pattern translator built on `POSITION(pattern in string)`.
#
# `type` chooses how the POSITION() result is interpreted:
# * "detect": result > 0          (negated: result == 0)
# * "start":  result == 1         (negated: result != 1)
# * "end":    result == nchar(string) - nchar(pattern) + 1
#                                 (negated: the same comparison with !=)
#
# Returns a function(string, pattern, negate = FALSE) that produces the SQL
# for the current connection.
#
# NOTE(review): the "end" comparison relies on the single index returned by
# POSITION(). If that is the first occurrence, a string in which the pattern
# also occurs earlier would not be reported as ending with it -- confirm and
# decide whether this needs handling.
sql_str_detect_fixed_position <- function(type = c("detect", "start", "end")) {
  type <- arg_match(type)

  function(string, pattern, negate = FALSE) {
    con <- sql_current_con()
    # Strip any class (e.g. "stringr_fixed") before interpolating the value.
    pattern <- unclass(pattern)
    found_at <- glue_sql2(con, "POSITION({.val pattern} in {.val string})")

    switch(type,
      detect = if (negate) {
        translate_sql(!!found_at == 0L, con = con)
      } else {
        translate_sql(!!found_at > 0L, con = con)
      },
      start = if (negate) {
        translate_sql(!!found_at != 1L, con = con)
      } else {
        translate_sql(!!found_at == 1L, con = con)
      },
      end = if (negate) {
        translate_sql(!!found_at != nchar(!!string) - nchar(!!pattern) + 1L, con = con)
      } else {
        translate_sql(!!found_at == nchar(!!string) - nchar(!!pattern) + 1L, con = con)
      }
    )
  }
}

utils::globalVariables(c("ltrim", "rtrim"))
13 changes: 13 additions & 0 deletions tests/testthat/test-backend-postgres.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ test_that("custom stringr functions translated correctly", {
expect_equal(test_translate_sql(str_squish(x)), sql("LTRIM(RTRIM(REGEXP_REPLACE(`x`, '\\s+', ' ', 'g')))"))
expect_equal(test_translate_sql(str_remove(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '')"))
expect_equal(test_translate_sql(str_remove_all(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '', 'g')"))

expect_equal(
test_translate_sql(str_detect(x, fixed("%0"))),
sql("POSITION('%0' in `x`) > 0")
)
expect_equal(
test_translate_sql(str_starts(x, fixed("%0"))),
sql("POSITION('%0' in `x`) = 1")
)
expect_equal(
test_translate_sql(str_ends(x, fixed("%0"))),
sql("POSITION('%0' in `x`) = ((LENGTH(`x`) - LENGTH('%0')) + 1)")
)
})

test_that("two variable aggregates are translated correctly", {
Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/test-translate-sql-string.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,37 @@ test_that("str_sub() returns consistent results", {
expect_equal(mf %>% transmute(str_sub(t, 0, 1)) %>% pull(1), "a")
expect_equal(mf %>% transmute(str_sub(t, 1, 3)) %>% pull(1), "abc")
})

test_that("str_detect(), str_starts(), str_ends() support fixed patterns", {
  # One row per scenario: pattern at the start, at the end, in the middle,
  # absent, and a missing value.
  db <- memdb_frame(x = c("%0 start", "end %0", "detect %0 detect", "no", NA))

  # An unqualified `fixed()` call is recognised as a fixed pattern.
  starts_plain <- db %>% transmute(str_starts(x, fixed("%0"))) %>% pull(1)
  expect_equal(starts_plain, c(1, 0, 0, 0, NA))

  # A namespaced `stringr::fixed()` call is recognised too. The expression is
  # parsed from text so that R CMD check does not flag an undeclared import.
  qualified_fixed <- rlang::parse_expr("stringr::fixed('%0')")
  starts_qualified <- db %>% transmute(str_starts(x, !!qualified_fixed)) %>% pull(1)
  expect_equal(starts_qualified, c(1, 0, 0, 0, NA))

  # The same holds for str_ends() and str_detect().
  ends <- db %>% transmute(str_ends(x, fixed("%0"))) %>% pull(1)
  expect_equal(ends, c(0, 1, 0, 0, NA))

  detected <- db %>% transmute(str_detect(x, fixed("%0"))) %>% pull(1)
  expect_equal(detected, c(1, 1, 1, 0, NA))

  # `negate = TRUE` inverts the result and keeps NA as NA.
  not_detected <- db %>%
    transmute(str_detect(x, fixed("%0"), negate = TRUE)) %>%
    pull(1)
  expect_equal(not_detected, c(0, 0, 0, 1, NA))

  # A non-fixed pattern is rejected by the generic translation.
  expect_error(translate_sql(str_detect(x, "a"), con = simulate_dbi()))
})

0 comments on commit 6f88454

Please sign in to comment.