diff --git a/R/backend-snowflake.R b/R/backend-snowflake.R index f6fea41a0..e5bfb609f 100644 --- a/R/backend-snowflake.R +++ b/R/backend-snowflake.R @@ -186,31 +186,33 @@ sql_translation.Snowflake <- function(con) { }, # https://docs.snowflake.com/en/sql-reference/functions/date_trunc.html floor_date = function(x, unit = "seconds") { - unit <- arg_match(unit, - c("second", "minute", "hour", "day", "week", "month", "quarter", "year", - "seconds", "minutes", "hours", "days", "weeks", "months", "quarters", "years") + unit <- arg_match( + unit, + c( + "second", "minute", "hour", "day", "week", "month", "quarter", "year", + "seconds", "minutes", "hours", "days", "weeks", "months", "quarters", "years" + ) ) sql_expr(DATE_TRUNC(!!unit, !!x)) }, # LEAST / GREATEST on Snowflake will not respect na.rm = TRUE by default (similar to Oracle/Access) # https://docs.snowflake.com/en/sql-reference/functions/least # https://docs.snowflake.com/en/sql-reference/functions/greatest - # Solution source: https://stackoverflow.com/a/74529349/22193215 + # Support two columns only with na.rm = TRUE (mirrors Access) pmin = function(..., na.rm = FALSE) { + dots <- snowflake_pmin_pmax_na_rm(..., na.rm = na.rm) if (identical(na.rm, TRUE)) { - dots <- snowflake_pmin_pmax_concat_dots(..., na.rm = na.rm, negate = TRUE) - glue_sql2(sql_current_con(), "-(GREATEST({dots})[0]::FLOAT)") + sql_expr(IFF(!!dots$x <= !!dots$y, !!dots$x, !!dots$y)) } else { - dots <- snowflake_pmin_pmax_concat_dots(..., na.rm = na.rm) - glue_sql2(sql_current_con(), "LEAST({dots})") + glue_sql2(sql_current_con(), "LEAST({.val dots*})") } }, pmax = function(..., na.rm = FALSE) { - dots <- snowflake_pmin_pmax_concat_dots(..., na.rm = na.rm) + dots <- snowflake_pmin_pmax_na_rm(..., na.rm = na.rm) if (identical(na.rm, TRUE)) { - glue_sql2(sql_current_con(), "GREATEST({dots})[0]::FLOAT") + sql_expr(IFF(!!dots$x <= !!dots$y, !!dots$y, !!dots$x)) } else { - glue_sql2(sql_current_con(), "GREATEST({dots})") + glue_sql2(sql_current_con(), "GREATEST({.val dots*})") } } ), @@ -267,7 +269,7 @@ snowflake_grepl <- function(pattern, # REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which # grepl does not. Left- and right-pad `pattern` with .* to get grepl-like # behavior - sql_expr(((!!x)) %REGEXP% (".*" || !!paste0('(', pattern, ')') || ".*")) + sql_expr(((!!x)) %REGEXP% (".*" || !!paste0("(", pattern, ")") || ".*")) } snowflake_round <- function(x, digits = 0L) { @@ -287,19 +289,17 @@ snowflake_paste <- function(default_sep) { } } -snowflake_pmin_pmax_concat_dots <- function(..., na.rm = FALSE, negate = FALSE){ +snowflake_pmin_pmax_na_rm <- function(..., na.rm = FALSE) { dots <- list(...) - if(isTRUE(negate)) { - dots <- dots %>% - purrr::map(~glue('-{.x}')) - } - if (isTRUE(na.rm)){ - dots <- dots %>% - purrr::map(~glue('[{.x}]')) + if (identical(na.rm, TRUE)) { + if (length(dots) > 2) cli::cli_abort("pmin()/pmax() with na.rm = TRUE currently only supports two columns for Snowflake") + list( + x = dots[[1]], + y = dots[[2]] + ) + } else { + dots } - - dots %>% - paste(collapse = ", ") } utils::globalVariables(c("%REGEXP%", "DAYNAME", "DECODE", "FLOAT", "MONTHNAME", "POSITION", "trim")) diff --git a/tests/testthat/test-backend-snowflake.R b/tests/testthat/test-backend-snowflake.R index b01a851e4..75bc847c4 100644 --- a/tests/testthat/test-backend-snowflake.R +++ b/tests/testthat/test-backend-snowflake.R @@ -2,17 +2,17 @@ test_that("custom scalar translated correctly", { local_con(simulate_snowflake()) expect_equal(test_translate_sql(log10(x)), sql("LOG(10.0, `x`)")) expect_equal(test_translate_sql(round(x, digits = 1.1)), sql("ROUND((`x`) :: FLOAT, 1)")) - expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')")) + expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')")) expect_snapshot((expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE))))) }) test_that("pasting translated correctly", { local_con(simulate_snowflake()) - expect_equal(test_translate_sql(paste(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), ' ')")) + expect_equal(test_translate_sql(paste(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), ' ')")) expect_equal(test_translate_sql(paste0(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), '')")) expect_equal(test_translate_sql(str_c(x, y)), sql("CONCAT_WS('', `x`, `y`)")) - expect_equal(test_translate_sql(str_c(x, y, sep = '|')), sql("CONCAT_WS('|', `x`, `y`)")) + expect_equal(test_translate_sql(str_c(x, y, sep = "|")), sql("CONCAT_WS('|', `x`, `y`)")) expect_error(test_translate_sql(paste0(x, collapse = "")), "`collapse` not supported") @@ -40,33 +40,33 @@ test_that("aggregates are translated correctly", { local_con(simulate_snowflake()) expect_equal(test_translate_sql(cor(x, y), window = FALSE), sql("CORR(`x`, `y`)")) - expect_equal(test_translate_sql(cor(x, y), window = TRUE), sql("CORR(`x`, `y`) OVER ()")) + expect_equal(test_translate_sql(cor(x, y), window = TRUE), sql("CORR(`x`, `y`) OVER ()")) expect_equal(test_translate_sql(cov(x, y), window = FALSE), sql("COVAR_SAMP(`x`, `y`)")) - expect_equal(test_translate_sql(cov(x, y), window = TRUE), sql("COVAR_SAMP(`x`, `y`) OVER ()")) + expect_equal(test_translate_sql(cov(x, y), window = TRUE), sql("COVAR_SAMP(`x`, `y`) OVER ()")) expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = FALSE), sql("BOOLAND_AGG(`x`)")) - expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = TRUE), sql("BOOLAND_AGG(`x`) OVER ()")) + expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = TRUE), sql("BOOLAND_AGG(`x`) OVER ()")) expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = FALSE), sql("BOOLOR_AGG(`x`)")) - expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = TRUE), sql("BOOLOR_AGG(`x`) OVER ()")) + expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = TRUE), sql("BOOLOR_AGG(`x`) OVER ()")) expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = FALSE), sql("STDDEV(`x`)")) - expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = TRUE), sql("STDDEV(`x`) OVER ()")) + expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = TRUE), sql("STDDEV(`x`) OVER ()")) }) test_that("snowflake mimics two argument log", { local_con(simulate_snowflake()) - expect_equal(test_translate_sql(log(x)), sql('LN(`x`)')) - expect_equal(test_translate_sql(log(x, 10)), sql('LOG(10.0, `x`)')) - expect_equal(test_translate_sql(log(x, 10L)), sql('LOG(10, `x`)')) + expect_equal(test_translate_sql(log(x)), sql("LN(`x`)")) + expect_equal(test_translate_sql(log(x, 10)), sql("LOG(10.0, `x`)")) + expect_equal(test_translate_sql(log(x, 10L)), sql("LOG(10, `x`)")) }) test_that("custom lubridate functions translated correctly", { local_con(simulate_snowflake()) - expect_equal(test_translate_sql(day(x)), sql("EXTRACT(DAY FROM `x`)")) + expect_equal(test_translate_sql(day(x)), sql("EXTRACT(DAY FROM `x`)")) expect_equal(test_translate_sql(mday(x)), sql("EXTRACT(DAY FROM `x`)")) expect_equal(test_translate_sql(yday(x)), sql("EXTRACT('dayofyear', `x`)")) expect_equal(test_translate_sql(wday(x)), sql("EXTRACT('dayofweek', DATE(`x`) + 0) + 1.0")) @@ -88,21 +88,21 @@ test_that("custom lubridate functions translated correctly", { expect_equal(test_translate_sql(seconds(x)), sql("INTERVAL '`x` second'")) expect_equal(test_translate_sql(minutes(x)), sql("INTERVAL '`x` minute'")) - expect_equal(test_translate_sql(hours(x)), sql("INTERVAL '`x` hour'")) - expect_equal(test_translate_sql(days(x)), sql("INTERVAL '`x` day'")) - expect_equal(test_translate_sql(weeks(x)), sql("INTERVAL '`x` week'")) - expect_equal(test_translate_sql(months(x)), sql("INTERVAL '`x` month'")) - expect_equal(test_translate_sql(years(x)), sql("INTERVAL '`x` year'")) - - expect_equal(test_translate_sql(floor_date(x, 'month')), sql("DATE_TRUNC('month', `x`)")) - expect_equal(test_translate_sql(floor_date(x, 'week')), sql("DATE_TRUNC('week', `x`)")) + expect_equal(test_translate_sql(hours(x)), sql("INTERVAL '`x` hour'")) + expect_equal(test_translate_sql(days(x)), sql("INTERVAL '`x` day'")) + expect_equal(test_translate_sql(weeks(x)), sql("INTERVAL '`x` week'")) + expect_equal(test_translate_sql(months(x)), sql("INTERVAL '`x` month'")) + expect_equal(test_translate_sql(years(x)), sql("INTERVAL '`x` year'")) + + expect_equal(test_translate_sql(floor_date(x, "month")), sql("DATE_TRUNC('month', `x`)")) + expect_equal(test_translate_sql(floor_date(x, "week")), sql("DATE_TRUNC('week', `x`)")) }) test_that("min() and max()", { local_con(simulate_snowflake()) - expect_equal(test_translate_sql(min(x, na.rm = TRUE)), sql('MIN(`x`) OVER ()')) - expect_equal(test_translate_sql(max(x, na.rm = TRUE)), sql('MAX(`x`) OVER ()')) + expect_equal(test_translate_sql(min(x, na.rm = TRUE)), sql("MIN(`x`) OVER ()")) + expect_equal(test_translate_sql(max(x, na.rm = TRUE)), sql("MAX(`x`) OVER ()")) # na.rm = FALSE is ignored # https://docs.snowflake.com/en/sql-reference/functions/min @@ -121,16 +121,19 @@ test_that("min() and max()", { test_that("pmin() and pmax() respect na.rm", { local_con(simulate_snowflake()) + test_translate_sql(pmin(x, y, na.rm = TRUE)) # Snowflake default for LEAST/GREATEST: If any of the argument values is NULL, the result is NULL. # https://docs.snowflake.com/en/sql-reference/functions/least # https://docs.snowflake.com/en/sql-reference/functions/greatest - # na.rm = TRUE: override default behavior for Snowflake - expect_equal(test_translate_sql(pmin(x, y, z, na.rm = TRUE)), sql('-(GREATEST([-`x`], [-`y`], [-`z`])[0]::FLOAT)')) - expect_equal(test_translate_sql(pmax(x, y, z, na.rm = TRUE)), sql('GREATEST([`x`], [`y`], [`z`])[0]::FLOAT')) + # na.rm = TRUE: override default behavior for Snowflake (only supports pairs) + expect_equal(test_translate_sql(pmin(x, y, na.rm = TRUE)), sql("IFF(`x` <= `y`, `x`, `y`)")) + expect_equal(test_translate_sql(pmax(x, y, na.rm = TRUE)), sql("IFF(`x` <= `y`, `y`, `x`)")) + + expect_error(test_translate_sql(pmax(x, y, z, na.rm = TRUE))) # na.rm = FALSE: leverage default behavior for Snowflake - expect_equal(test_translate_sql(pmin(x, y, z, na.rm = FALSE)), sql('LEAST(`x`, `y`, `z`)')) - expect_equal(test_translate_sql(pmax(x, y, z, na.rm = FALSE)), sql('GREATEST(`x`, `y`, `z`)')) + expect_equal(test_translate_sql(pmin(x, y, z, na.rm = FALSE)), sql("LEAST(`x`, `y`, `z`)")) + expect_equal(test_translate_sql(pmax(x, y, z, na.rm = FALSE)), sql("GREATEST(`x`, `y`, `z`)")) })