From 6f88454f52ade23983a3171080fe539d904b2c14 Mon Sep 17 00:00:00 2001 From: Maximilian Girlich Date: Tue, 4 Jul 2023 10:58:11 +0200 Subject: [PATCH] Translation for `str_detect()` and `str_starts()` (#1325) * Translation for `str_detect()` and `str_starts()` * Fix snowflake * Fix `fixed()` translation * Don't use stringr in test --- NEWS.md | 3 + R/backend-.R | 30 +++++++- R/backend-postgres.R | 35 +++++++-- R/backend-snowflake.R | 18 +++-- R/translate-sql-string.R | 82 ++++++++++++++++++++++ tests/testthat/test-backend-postgres.R | 13 ++++ tests/testthat/test-translate-sql-string.R | 34 +++++++++ 7 files changed, 204 insertions(+), 11 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0f7116c6b..7d3acf053 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # dbplyr (development version) +* Added translation for `str_detect()`, `str_starts()` and `str_ends()` with + fixed patterns (@mgirlich, #1009). + * The `overwrite` argument of `db_copy_to()` now actually works. * `db_write_table()` and `db_save_query()` gain the `overwrite` argument. 
diff --git a/R/backend-.R b/R/backend-.R
index 606a4712a..0572b5b5a 100644
--- a/R/backend-.R
+++ b/R/backend-.R
@@ -285,8 +285,28 @@ base_scalar <- sql_translator(
 
   str_conv = sql_not_supported("str_conv"),
   str_count = sql_not_supported("str_count"),
-  str_detect = sql_not_supported("str_detect"),
+  # `fixed()` checks `ignore_case` and then returns the pattern unchanged.
+  fixed = function(pattern, ignore_case = FALSE) {
+    check_unsupported_arg(ignore_case, allowed = FALSE)
+    pattern
+  },
+  str_detect = function(string, pattern, negate = FALSE) {
+    sql_str_pattern_switch(
+      string = string,
+      pattern = {{ pattern }},
+      negate = negate,
+      f_fixed = sql_str_detect_fixed_instr("detect") # fixed patterns only: no f_regex is passed
+    )
+  },
   str_dup = sql_not_supported("str_dup"),
+  str_ends = function(string, pattern, negate = FALSE) {
+    sql_str_pattern_switch(
+      string = string,
+      pattern = {{ pattern }},
+      negate = negate,
+      f_fixed = sql_str_detect_fixed_instr("end") # fixed patterns only: no f_regex is passed
+    )
+  },
   str_extract = sql_not_supported("str_extract"),
   str_extract_all = sql_not_supported("str_extract_all"),
   str_flatten = sql_not_supported("str_flatten"),
@@ -308,6 +328,14 @@ base_scalar <- sql_translator(
   str_split = sql_not_supported("str_split"),
   str_split_fixed = sql_not_supported("str_split_fixed"),
   str_squish = sql_not_supported("str_squish"),
+  str_starts = function(string, pattern, negate = FALSE) {
+    sql_str_pattern_switch(
+      string = string,
+      pattern = {{ pattern }},
+      negate = negate,
+      f_fixed = sql_str_detect_fixed_instr("start") # fixed patterns only: no f_regex is passed
+    )
+  },
   str_subset = sql_not_supported("str_subset"),
   str_trunc = sql_not_supported("str_trunc"),
   str_view = sql_not_supported("str_view"),
diff --git a/R/backend-postgres.R b/R/backend-postgres.R
index cda139288..3cc32a92f 100644
--- a/R/backend-postgres.R
+++ b/R/backend-postgres.R
@@ -97,12 +97,29 @@ sql_translation.PqConnection <- function(con) {
       str_locate = function(string, pattern) {
         sql_expr(strpos(!!string, !!pattern))
       },
+      # https://www.postgresql.org/docs/9.1/functions-string.html
       str_detect = function(string, pattern, negate = FALSE) {
-        if (isTRUE(negate)) {
-          sql_expr(!(!!string ~ !!pattern))
-        } else {
-          sql_expr(!!string ~ !!pattern)
-        }
+        sql_str_pattern_switch(
+          string = string,
+          pattern = {{ pattern }},
+          negate = negate,
+          f_fixed = sql_str_detect_fixed_position("detect"),
+          f_regex = function(string, pattern, negate = FALSE) { # same `~` match as the removed code
+            if (isTRUE(negate)) {
+              sql_expr(!(!!string ~ !!pattern))
+            } else {
+              sql_expr(!!string ~ !!pattern)
+            }
+          }
+        )
+      },
+      str_ends = function(string, pattern, negate = FALSE) {
+        sql_str_pattern_switch(
+          string = string,
+          pattern = {{ pattern }},
+          negate = negate,
+          f_fixed = sql_str_detect_fixed_position("end") # fixed patterns only: no f_regex is passed
+        )
       },
       # https://www.postgresql.org/docs/current/functions-matching.html
       str_like = function(string, pattern, ignore_case = TRUE) {
@@ -127,6 +144,14 @@ sql_translation.PqConnection <- function(con) {
       str_remove_all = function(string, pattern){
         sql_expr(regexp_replace(!!string, !!pattern, '', 'g'))
       },
+      str_starts = function(string, pattern, negate = FALSE) {
+        sql_str_pattern_switch(
+          string = string,
+          pattern = {{ pattern }},
+          negate = negate,
+          f_fixed = sql_str_detect_fixed_position("start") # fixed patterns only: no f_regex is passed
+        )
+      },
 
       # lubridate functions
       # https://www.postgresql.org/docs/9.1/functions-datetime.html
diff --git a/R/backend-snowflake.R b/R/backend-snowflake.R
index cef7b7057..4747ce82b 100644
--- a/R/backend-snowflake.R
+++ b/R/backend-snowflake.R
@@ -42,11 +42,19 @@ sql_translation.Snowflake <- function(con) {
       # str_detect does not.
Left- and right-pad `pattern` with .* to get
       # str_detect-like behavior
       str_detect = function(string, pattern, negate = FALSE) {
-        if (isTRUE(negate)) {
-          sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")))
-        } else {
-          sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))
-        }
+        sql_str_pattern_switch(
+          string = string,
+          pattern = {{ pattern }},
+          negate = negate,
+          f_fixed = sql_str_detect_fixed_instr("detect"),
+          f_regex = function(string, pattern, negate = FALSE) { # same %REGEXP% match as the removed code
+            if (isTRUE(negate)) {
+              sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")))
+            } else {
+              sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))
+            }
+          }
+        )
       },
       # On Snowflake, REGEXP_REPLACE is used like this:
       # REGEXP_REPLACE( <subject> , <pattern> [ , <replacement> ,
diff --git a/R/translate-sql-string.R b/R/translate-sql-string.R
index 9f068d633..2334bd666 100644
--- a/R/translate-sql-string.R
+++ b/R/translate-sql-string.R
@@ -72,4 +72,86 @@ sql_str_trim <- function(string, side = c("both", "left", "right")) {
     both = sql_expr(ltrim(rtrim(!!string))),
   )
 }
+
+# Route a stringr pattern to `f_fixed` when it is a `fixed()` call or a
+# "stringr_fixed" object, otherwise to `f_regex`; aborts if `f_regex` is NULL.
+sql_str_pattern_switch <- function(string,
+                                   pattern,
+                                   negate = FALSE,
+                                   f_fixed = NULL,
+                                   f_regex = NULL,
+                                   error_call = caller_env()) {
+  pattern_quo <- enquo(pattern) # unevaluated pattern, to recognise a literal fixed() call
+  is_fixed <- quo_is_call(pattern_quo, "fixed") || inherits(pattern, "stringr_fixed")
+  # NOTE(review): `f_fixed` defaults to NULL but is called unguarded when the pattern is fixed.
+  if (is_fixed) {
+    f_fixed(string, pattern, negate)
+  } else {
+    if (is_null(f_regex)) {
+      cli_abort("Only fixed patterns are supported on database backends.", call = error_call)
+    } else {
+      f_regex(string, pattern, negate)
+    }
+  }
+}
+
+# INSTR
+# * SQLite https://www.sqlitetutorial.net/sqlite-functions/sqlite-instr/
+# * MySQL https://dev.mysql.com/doc/refman/8.0/en/string-functions.html#function_instr
+# * Oracle https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/INSTR.html#GUID-47E3A7C4-ED72-458D-A1FA-25A9AD3BE113
+# * Teradata
https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/String-Operators-and-Functions/INSTR
+# * Access https://support.microsoft.com/de-de/office/instr-funktion-85d3392c-3b1c-4232-bb18-77cd0cb8a55b
+# * Hana https://help.sap.com/docs/SAP_HANA_PLATFORM/e8e6c8142e60469bb401de5fdb6f7c00/f5a9ca3718354a499a98ba61ae3da170.html
+# * Hive https://www.revisitclass.com/hadoop/instr-function-in-hive-with-examples/
+# * Impala https://impala.apache.org/docs/build/html/topics/impala_string_functions.html#string_functions__instr
+sql_str_detect_fixed_instr <- function(type = c("detect", "start", "end")) {
+  type <- arg_match(type)
+  # Returns the translation closure; `type` selects how the INSTR index is compared.
+  function(string, pattern, negate = FALSE) {
+    con <- sql_current_con()
+    pattern <- unclass(pattern) # strip any class (e.g. "stringr_fixed") before interpolation
+    index_sql <- glue_sql2(con, "INSTR({.val string}, {.val pattern})")
+    # NOTE(review): "end" may misfire if INSTR hits an earlier match or returns 0 - verify.
+    if (negate) {
+      switch(type,
+        detect = translate_sql(!!index_sql == 0L, con = con),
+        start = translate_sql(!!index_sql != 1L, con = con),
+        end = translate_sql(!!index_sql != nchar(!!string) - nchar(!!pattern) + 1L, con = con)
+      )
+    } else {
+      switch(type,
+        detect = translate_sql(!!index_sql > 0L, con = con),
+        start = translate_sql(!!index_sql == 1L, con = con),
+        end = translate_sql(!!index_sql == nchar(!!string) - nchar(!!pattern) + 1L, con = con)
+      )
+    }
+  }
+}
+
+# POSITION
+# * Snowflake https://docs.snowflake.com/en/sql-reference/functions/position
+sql_str_detect_fixed_position <- function(type = c("detect", "start", "end")) {
+  type <- arg_match(type)
+  # Returns the translation closure; `type` selects how the POSITION index is compared.
+  function(string, pattern, negate = FALSE) {
+    con <- sql_current_con()
+    pattern <- unclass(pattern) # strip any class (e.g. "stringr_fixed") before interpolation
+    index_sql <- glue_sql2(con, "POSITION({.val pattern} in {.val string})")
+    # NOTE(review): "end" may misfire if POSITION hits an earlier match or returns 0 - verify.
+    if (negate) {
+      switch(type,
+        detect = translate_sql(!!index_sql == 0L, con = con),
+        start = translate_sql(!!index_sql != 1L, con = con),
+        end = translate_sql(!!index_sql != nchar(!!string) - nchar(!!pattern) + 1L, con = con)
+      )
+    } else {
+      switch(type,
+        detect = translate_sql(!!index_sql > 0L, con = con),
+        start = translate_sql(!!index_sql == 1L, con = con),
+        end = translate_sql(!!index_sql == nchar(!!string) - nchar(!!pattern) + 1L, con = con)
+      )
+    }
+  }
+}
+
 utils::globalVariables(c("ltrim", "rtrim"))
diff --git a/tests/testthat/test-backend-postgres.R b/tests/testthat/test-backend-postgres.R
index 5afb7087b..340ed06ce 100644
--- a/tests/testthat/test-backend-postgres.R
+++ b/tests/testthat/test-backend-postgres.R
@@ -24,6 +24,19 @@ test_that("custom stringr functions translated correctly", {
   expect_equal(test_translate_sql(str_squish(x)), sql("LTRIM(RTRIM(REGEXP_REPLACE(`x`, '\\s+', ' ', 'g')))"))
   expect_equal(test_translate_sql(str_remove(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '')"))
   expect_equal(test_translate_sql(str_remove_all(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '', 'g')"))
+
+  expect_equal(
+    test_translate_sql(str_detect(x, fixed("%0"))),
+    sql("POSITION('%0' in `x`) > 0")
+  )
+  expect_equal(
+    test_translate_sql(str_starts(x, fixed("%0"))),
+    sql("POSITION('%0' in `x`) = 1")
+  )
+  expect_equal(
+    test_translate_sql(str_ends(x, fixed("%0"))),
+    sql("POSITION('%0' in `x`) = ((LENGTH(`x`) - LENGTH('%0')) + 1)")
+  )
 })
 
 test_that("two variable aggregates are translated correctly", {
diff --git a/tests/testthat/test-translate-sql-string.R b/tests/testthat/test-translate-sql-string.R
index 95f6ed5d2..b62c2d844 100644
--- a/tests/testthat/test-translate-sql-string.R
+++ b/tests/testthat/test-translate-sql-string.R
@@ -67,3 +67,37 @@ test_that("str_sub() returns consistent results", {
   expect_equal(mf %>% transmute(str_sub(t, 0, 1)) %>% pull(1), "a")
   expect_equal(mf %>% transmute(str_sub(t, 1, 3)) %>% pull(1), "abc")
 })
+
+test_that("str_detect(), str_starts(), str_ends() support fixed patterns", {
+  mf <- memdb_frame(x = c("%0 start", "end %0", "detect %0 detect", "no", NA))
+
+  # detects fixed pattern
+  expect_equal(
+    mf %>% transmute(str_starts(x, fixed("%0"))) %>% pull(1),
+    c(1, 0, 0, 0, NA)
+  )
+  # namespaced stringr::fixed(); parsed so R CMD check sees no undeclared stringr import
+  pattern <- rlang::parse_expr("stringr::fixed('%0')")
+  expect_equal(
+    mf %>% transmute(str_starts(x, !!pattern)) %>% pull(1),
+    c(1, 0, 0, 0, NA)
+  )
+
+  # also works with ends and detect
+  expect_equal(
+    mf %>% transmute(str_ends(x, fixed("%0"))) %>% pull(1),
+    c(0, 1, 0, 0, NA)
+  )
+  expect_equal(
+    mf %>% transmute(str_detect(x, fixed("%0"))) %>% pull(1),
+    c(1, 1, 1, 0, NA)
+  )
+
+  # negate works
+  expect_equal(
+    mf %>% transmute(str_detect(x, fixed("%0"), negate = TRUE)) %>% pull(1),
+    c(0, 0, 0, 1, NA)
+  )
+
+  expect_error(translate_sql(str_detect(x, "a"), con = simulate_dbi())) # a non-fixed pattern is expected to error
+})