Skip to content

Commit

Permalink
Better support for string matching in Snowflake (#1406)
Browse files Browse the repository at this point in the history
* Added support for `str_starts()` and `str_ends()` by using Snowflake's
`REGEXP_INSTR()` function
* Refactored `str_detect()` to use Snowflake's `REGEXP_INSTR()` function.
* Ensure escape characters are escaped
  • Loading branch information
nathanhaigh authored Dec 20, 2023
1 parent 68767be commit e89aa96
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 36 deletions.
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# dbplyr (development version)

* Snowflake (@nathanhaigh, #1406)
* Added support for `str_starts()` and `str_ends()` via `REGEXP_INSTR()`
* Refactored `str_detect()` to use `REGEXP_INSTR()` so now supports
regular expressions.
* Refactored `grepl()` to use `REGEXP_INSTR()` so now supports
case-insensitive matching through `grepl(..., ignore.case = TRUE)`

* Functions qualified with the base namespace are now also translated, e.g.
`base::paste0(x, "_1")` is now translated (@mgirlich, #1022).

Expand Down
63 changes: 41 additions & 22 deletions R/backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,38 @@ sql_translation.Snowflake <- function(con) {
str_locate = function(string, pattern) {
sql_expr(POSITION(!!pattern, !!string))
},
# REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which
# str_detect does not. Left- and right-pad `pattern` with .* to get
# str_detect-like behavior
str_detect = function(string, pattern, negate = FALSE) {
sql_str_pattern_switch(
string = string,
pattern = {{ pattern }},
negate = negate,
f_fixed = sql_str_detect_fixed_instr("detect"),
f_regex = function(string, pattern, negate = FALSE) {
if (isTRUE(negate)) {
sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")))
} else {
sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))
}
}
)
con <- sql_current_con()

# Snowflake needs backslashes escaped, so we must increase the level of escaping
pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
if (negate) {
translate_sql(REGEXP_INSTR(!!string, !!pattern) == 0L, con = con)
} else {
translate_sql(REGEXP_INSTR(!!string, !!pattern) != 0L, con = con)
}
},
str_starts = function(string, pattern, negate = FALSE) {
con <- sql_current_con()

# Snowflake needs backslashes escaped, so we must increase the level of escaping
pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
if (negate) {
translate_sql(REGEXP_INSTR(!!string, !!pattern) != 1L, con = con)
} else {
translate_sql(REGEXP_INSTR(!!string, !!pattern) == 1L, con = con)
}
},
str_ends = function(string, pattern, negate = FALSE) {
con <- sql_current_con()

# Snowflake needs backslashes escaped, so we must increase the level of escaping
pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
if (negate) {
translate_sql(REGEXP_INSTR(!!string, !!pattern, 1L, 1L, 1L) != LENGTH(!!string) + 1L, con = con)
} else {
translate_sql(REGEXP_INSTR(!!string, !!pattern, 1L, 1L, 1L) == LENGTH(!!string) + 1L, con = con)
}
},
# On Snowflake, REGEXP_REPLACE is used like this:
# REGEXP_REPLACE( <subject> , <pattern> [ , <replacement> ,
Expand Down Expand Up @@ -261,15 +276,19 @@ snowflake_grepl <- function(pattern,
perl = FALSE,
fixed = FALSE,
useBytes = FALSE) {
# https://docs.snowflake.com/en/sql-reference/functions/regexp.html
check_unsupported_arg(ignore.case, FALSE, backend = "Snowflake")
con <- sql_current_con()

check_unsupported_arg(perl, FALSE, backend = "Snowflake")
check_unsupported_arg(fixed, FALSE, backend = "Snowflake")
check_unsupported_arg(useBytes, FALSE, backend = "Snowflake")
# REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which
# grepl does not. Left- and right-pad `pattern` with .* to get grepl-like
# behavior
sql_expr(((!!x)) %REGEXP% (".*" || !!paste0("(", pattern, ")") || ".*"))

# https://docs.snowflake.com/en/sql-reference/functions/regexp_instr.html
# REGEXP_INSTR optional parameters: position, occurrance, option, regex_parameters
regexp_parameters <- "c"
if(ignore.case) { regexp_parameters <- "i" }
# Snowflake needs backslashes escaped, so we must increase the level of escaping
pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE)
translate_sql(REGEXP_INSTR(!!x, !!pattern, 1L, 1L, 0L, !!regexp_parameters) != 0L, con = con)
}

snowflake_round <- function(x, digits = 0L) {
Expand Down
10 changes: 0 additions & 10 deletions tests/testthat/_snaps/backend-snowflake.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
# custom scalar translated correctly

Code
(expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE))))
Output
<error/rlang_error>
Error in `grepl()`:
! `ignore.case = TRUE` isn't supported in Snowflake translation.
i It must be FALSE instead.

# pmin() and pmax() respect na.rm

Code
Expand Down
12 changes: 8 additions & 4 deletions tests/testthat/test-backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ test_that("custom scalar translated correctly", {
local_con(simulate_snowflake())
expect_equal(test_translate_sql(log10(x)), sql("LOG(10.0, `x`)"))
expect_equal(test_translate_sql(round(x, digits = 1.1)), sql("ROUND((`x`) :: FLOAT, 1)"))
expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')"))
expect_snapshot((expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE)))))
expect_equal(test_translate_sql(grepl("exp", x)), sql("REGEXP_INSTR(`x`, 'exp', 1, 1, 0, 'c') != 0"))
expect_equal(test_translate_sql(grepl("exp", x, ignore.case = TRUE)), sql("REGEXP_INSTR(`x`, 'exp', 1, 1, 0, 'i') != 0"))
})

test_that("pasting translated correctly", {
Expand All @@ -25,15 +25,19 @@ test_that("custom stringr functions translated correctly", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(str_locate(x, y)), sql("POSITION(`y`, `x`)"))
expect_equal(test_translate_sql(str_detect(x, y)), sql("(`x`) REGEXP ('.*' || `y` || '.*')"))
expect_equal(test_translate_sql(str_detect(x, y, negate = TRUE)), sql("!((`x`) REGEXP ('.*' || `y` || '.*'))"))
expect_equal(test_translate_sql(str_detect(x, y)), sql("REGEXP_INSTR(`x`, `y`) != 0"))
expect_equal(test_translate_sql(str_detect(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`) = 0"))
expect_equal(test_translate_sql(str_replace(x, y, z)), sql("REGEXP_REPLACE(`x`, `y`, `z`, 1.0, 1.0)"))
expect_equal(test_translate_sql(str_replace(x, "\\d", z)), sql("REGEXP_REPLACE(`x`, '\\\\d', `z`, 1.0, 1.0)"))
expect_equal(test_translate_sql(str_replace_all(x, y, z)), sql("REGEXP_REPLACE(`x`, `y`, `z`)"))
expect_equal(test_translate_sql(str_squish(x)), sql("REGEXP_REPLACE(TRIM(`x`), '\\\\s+', ' ')"))
expect_equal(test_translate_sql(str_remove(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '', 1.0, 1.0)"))
expect_equal(test_translate_sql(str_remove_all(x, y)), sql("REGEXP_REPLACE(`x`, `y`)"))
expect_equal(test_translate_sql(str_trim(x)), sql("TRIM(`x`)"))
expect_equal(test_translate_sql(str_starts(x, y)), sql("REGEXP_INSTR(`x`, `y`) = 1"))
expect_equal(test_translate_sql(str_starts(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`) != 1"))
expect_equal(test_translate_sql(str_ends(x, y)), sql("REGEXP_INSTR(`x`, `y`, 1, 1, 1) = (LENGTH(`x`) + 1)"))
expect_equal(test_translate_sql(str_ends(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`, 1, 1, 1) != (LENGTH(`x`) + 1)"))
})

test_that("aggregates are translated correctly", {
Expand Down

0 comments on commit e89aa96

Please sign in to comment.