diff --git a/NAMESPACE b/NAMESPACE index 5f21cf0a8..d7a2d697a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -50,9 +50,12 @@ S3method(db_copy_to,DBIConnection) S3method(db_create_index,DBIConnection) S3method(db_desc,DBIConnection) S3method(db_explain,DBIConnection) +S3method(db_explain,OraConnection) +S3method(db_explain,Oracle) S3method(db_query_fields,DBIConnection) S3method(db_query_fields,PostgreSQLConnection) S3method(db_save_query,DBIConnection) +S3method(db_sql_render,"Microsoft SQL Server") S3method(db_sql_render,DBIConnection) S3method(db_supports_table_alias_with_as,DBIConnection) S3method(db_supports_table_alias_with_as,OraConnection) diff --git a/NEWS.md b/NEWS.md index 45c2e46f0..c5a3c5fc5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,44 @@ # dbplyr (development version) +* Namespaced dplyr calls now error if the function doesn't exist, or + a translation is not available (#1426). + +* `db_sql_render()` correctly passes on `...` when re-calling with + `sql_options` set (#1394). + +* `-1 + x` is now translated correctly (#1420). + +* SQL server: clear error if you attempt to use `n_distinct()` in `mutate()` + or `filter()` (#1366). + +* Add translations for clock functions `add_years()`, `add_days()`, + `date_build()`, `get_year()`, `get_month()`, `get_day()`, + and `base::difftime()` on SQL server, Redshift, Snowflake, and Postgres. + +* SQL server: `filter()` does a better job of converting logical vectors + from bit to boolean (@ejneer, #1288). + +* Oracle: Added support for `str_replace()` and `str_replace_all()` via + `REGEXP_REPLACE()` (@thomashulst, #1402). + +* Allow additional arguments to be passed from `compute()` all the way to + `sql_query_save()`-method (@rsund). + +* The class of remote sources now includes all S4 class names, not just + the first (#918). + +* `db_explain()` now works for Oracle (@thomashulst, #1353). + +* Database errors now show the generated SQL, which hopefully will make it + faster to track down problems (#1401). + +* Snowflake (@nathanhaigh, #1406) + * Added support for `str_starts()` and `str_ends()` via `REGEXP_INSTR()` + * Refactored `str_detect()` to use `REGEXP_INSTR()` so now supports + regular expressions. + * Refactored `grepl()` to use `REGEXP_INSTR()` so now supports + case-insensitive matching through `grepl(..., ignore.case = TRUE)` + * Functions qualified with the base namespace are now also translated, e.g. `base::paste0(x, "_1")` is now translated (@mgirlich, #1022). diff --git a/R/backend-mssql.R b/R/backend-mssql.R index a50fe7df7..e4e9c4671 100644 --- a/R/backend-mssql.R +++ b/R/backend-mssql.R @@ -5,13 +5,13 @@ #' details of overall translation technology. Key differences for this backend #' are: #' -#' * `SELECT` uses `TOP` not `LIMIT` -#' * Automatically prefixes `#` to create temporary tables. Add the prefix +#' - `SELECT` uses `TOP` not `LIMIT` +#' - Automatically prefixes `#` to create temporary tables. Add the prefix #' yourself to avoid the message. -#' * String basics: `paste()`, `substr()`, `nchar()` -#' * Custom types for `as.*` functions -#' * Lubridate extraction functions, `year()`, `month()`, `day()` etc -#' * Semi-automated bit <-> boolean translation (see below) +#' - String basics: `paste()`, `substr()`, `nchar()` +#' - Custom types for `as.*` functions +#' - Lubridate extraction functions, `year()`, `month()`, `day()` etc +#' - Semi-automated bit <-> boolean translation (see below) #' #' Use `simulate_mssql()` with `lazy_frame()` to see simulated SQL without #' converting to live access database. @@ -350,6 +350,41 @@ simulate_mssql <- function(version = "15.0") { sql_expr(DATEPART(QUARTER, !!x)) } }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(DATEFROMPARTS(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(DATEPART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATEPART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATEPART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) + } ) if (mssql_version(con) >= "11.0") { # MSSQL 2012 @@ -434,7 +469,13 @@ simulate_mssql <- function(version = "15.0") { }, all = mssql_bit_int_bit(win_aggregate("MIN")), any = mssql_bit_int_bit(win_aggregate("MAX")), - row_number = win_rank("ROW_NUMBER", empty_order = TRUE) + row_number = win_rank("ROW_NUMBER", empty_order = TRUE), + + n_distinct = function(x) { + cli_abort( + "No translation available in `mutate()`/`filter()` for SQL server." + ) + } ) )} @@ -582,4 +623,40 @@ mssql_bit_int_bit <- function(f) { dplyr::if_else(x, "1", "0", "NULL") } +#' @export +`db_sql_render.Microsoft SQL Server` <- function(con, sql, ..., cte = FALSE, use_star = TRUE) { + # Post-process WHERE to cast logicals from BIT to BOOLEAN + sql$lazy_query <- purrr::modify_tree( + sql$lazy_query, + is_node = function(x) inherits(x, "lazy_query"), + post = mssql_update_where_clause + ) + + NextMethod() +} + +mssql_update_where_clause <- function(qry) { + if (!has_name(qry, "where")) { + return(qry) + } + + qry$where <- lapply( + qry$where, + function(x) set_expr(x, bit_to_boolean(get_expr(x))) + ) + qry +} + +bit_to_boolean <- function(x_expr) { + if (is_atomic(x_expr) || is_symbol(x_expr)) { + expr(cast(!!x_expr %AS% BIT) == 1L) + } else if (is_call(x_expr, c("|", "&", "||", "&&", "!", "("))) { + idx <- seq2(2, length(x_expr)) + x_expr[idx] <- lapply(x_expr[idx], bit_to_boolean) + x_expr + } else { + x_expr + } +} + utils::globalVariables(c("BIT", "CAST", "%AS%", "%is%", "convert", "DATE", "DATENAME", "DATEPART", "IIF", "NOT", "SUBSTRING", "LTRIM", "RTRIM", "CHARINDEX", "SYSDATETIME", "SECOND", "MINUTE", "HOUR", "DAY", "DAYOFWEEK", "DAYOFYEAR", "MONTH", "QUARTER", "YEAR", "BIGINT", "INT", "%AND%", "%BETWEEN%")) diff --git a/R/backend-oracle.R b/R/backend-oracle.R index dcae5b1d2..fca852660 100644 --- a/R/backend-oracle.R +++ b/R/backend-oracle.R @@ -133,9 +133,43 @@ sql_translation.Oracle <- function(con) { paste0 = sql_paste_infix("", "||", function(x) sql_expr(cast(!!x %as% text))), str_c = sql_paste_infix("", "||", function(x) sql_expr(cast(!!x %as% text))), + # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/REGEXP_REPLACE.html + # 4th argument is starting position (default: 1 => first char of string) + # 5th argument is occurrence (default: 0 => match all occurrences) + str_replace = function(string, pattern, replacement){ + sql_expr(regexp_replace(!!string, !!pattern, !!replacement, 1L, 1L)) + }, + str_replace_all = function(string, pattern, replacement){ + sql_expr(regexp_replace(!!string, !!pattern, !!replacement)) + }, + # lubridate -------------------------------------------------------------- today = function() sql_expr(TRUNC(CURRENT_TIMESTAMP)), - now = function() sql_expr(CURRENT_TIMESTAMP) + now = function() sql_expr(CURRENT_TIMESTAMP), + + # clock ------------------------------------------------------------------ + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr((!!x + NUMTODSINTERVAL(!!n, 'day'))) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr((!!x + NUMTODSINTERVAL(!!n * 365.25, 'day'))) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(CEIL(CAST(!!time2 %AS% DATE) - CAST(!!time1 %AS% DATE))) + } + ), base_odbc_agg, base_odbc_win @@ -144,10 +178,11 @@ sql_translation.Oracle <- function(con) { #' @export sql_query_explain.Oracle <- function(con, sql, ...) { - glue_sql2( - con, - "EXPLAIN PLAN FOR {sql};\n", - "SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY()));", + + # https://docs.oracle.com/en/database/oracle/oracle-database/19/tgsql/generating-and-displaying-execution-plans.html + c( + glue_sql2(con, "EXPLAIN PLAN FOR {sql}"), + glue_sql2(con, "SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY())") ) } @@ -182,6 +217,18 @@ sql_expr_matches.Oracle <- function(con, x, y, ...) { glue_sql2(con, "decode({x}, {y}, 0, 1) = 0") } +#' @export +db_explain.Oracle <- function(con, sql, ...) { + sql <- sql_query_explain(con, sql, ...) + + msg <- "Can't explain query." + db_execute(con, sql[[1]], msg) # EXPLAIN PLAN + expl <- db_get_query(con, sql[[2]], msg) # DBMS_XPLAN.DISPLAY + + out <- utils::capture.output(print(expl)) + paste(out, collapse = "\n") +} + #' @export db_supports_table_alias_with_as.Oracle <- function(con) { FALSE @@ -219,6 +266,9 @@ setdiff.OraConnection <- setdiff.tbl_Oracle #' @export sql_expr_matches.OraConnection <- sql_expr_matches.Oracle +#' @export +db_explain.OraConnection <- db_explain.Oracle + #' @export db_supports_table_alias_with_as.OraConnection <- db_supports_table_alias_with_as.Oracle diff --git a/R/backend-postgres.R b/R/backend-postgres.R index 929b1114b..b2fe756ce 100644 --- a/R/backend-postgres.R +++ b/R/backend-postgres.R @@ -235,6 +235,41 @@ sql_translation.PqConnection <- function(con) { ) sql_expr(DATE_TRUNC(!!unit, !!x)) }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 day')") + }, + add_years = function(x, n, ...) { + check_dots_empty() + glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 year')") + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(make_date(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(date_part('year', !!x)) + }, + get_month = function(x) { + sql_expr(date_part('month', !!x)) + }, + get_day = function(x) { + sql_expr(date_part('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr((CAST(!!time2 %AS% DATE) - CAST(!!time1 %AS% DATE))) + }, ), sql_translator(.parent = base_agg, cor = sql_aggregate_2("CORR"), diff --git a/R/backend-redshift.R b/R/backend-redshift.R index 735085ebb..f40186f3e 100644 --- a/R/backend-redshift.R +++ b/R/backend-redshift.R @@ -60,6 +60,41 @@ sql_translation.RedshiftConnection <- function(con) { str_replace = sql_not_supported("str_replace"), str_replace_all = function(string, pattern, replacement) { sql_expr(REGEXP_REPLACE(!!string, !!pattern, !!replacement)) + }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + glue_sql2(sql_current_con(), "TO_DATE(CAST({.val year} AS TEXT) || '-' CAST({.val month} AS TEXT) || '-' || CAST({.val day} AS TEXT)), 'YYYY-MM-DD')") + }, + get_year = function(x) { + sql_expr(DATE_PART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATE_PART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATE_PART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) } ), sql_translator(.parent = postgres$aggregate, diff --git a/R/backend-snowflake.R b/R/backend-snowflake.R index 5774be3cf..a72561524 100644 --- a/R/backend-snowflake.R +++ b/R/backend-snowflake.R @@ -38,23 +38,38 @@ sql_translation.Snowflake <- function(con) { str_locate = function(string, pattern) { sql_expr(POSITION(!!pattern, !!string)) }, - # REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which - # str_detect does not. Left- and right-pad `pattern` with .* to get - # str_detect-like behavior str_detect = function(string, pattern, negate = FALSE) { - sql_str_pattern_switch( - string = string, - pattern = {{ pattern }}, - negate = negate, - f_fixed = sql_str_detect_fixed_instr("detect"), - f_regex = function(string, pattern, negate = FALSE) { - if (isTRUE(negate)) { - sql_expr(!(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*"))) - } else { - sql_expr(((!!string)) %REGEXP% (".*" || (!!pattern) || ".*")) - } - } - ) + con <- sql_current_con() + + # Snowflake needs backslashes escaped, so we must increase the level of escaping + pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE) + if (negate) { + translate_sql(REGEXP_INSTR(!!string, !!pattern) == 0L, con = con) + } else { + translate_sql(REGEXP_INSTR(!!string, !!pattern) != 0L, con = con) + } + }, + str_starts = function(string, pattern, negate = FALSE) { + con <- sql_current_con() + + # Snowflake needs backslashes escaped, so we must increase the level of escaping + pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE) + if (negate) { + translate_sql(REGEXP_INSTR(!!string, !!pattern) != 1L, con = con) + } else { + translate_sql(REGEXP_INSTR(!!string, !!pattern) == 1L, con = con) + } + }, + str_ends = function(string, pattern, negate = FALSE) { + con <- sql_current_con() + + # Snowflake needs backslashes escaped, so we must increase the level of escaping + pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE) + if (negate) { + translate_sql(REGEXP_INSTR(!!string, !!pattern, 1L, 1L, 1L) != LENGTH(!!string) + 1L, con = con) + } else { + translate_sql(REGEXP_INSTR(!!string, !!pattern, 1L, 1L, 1L) == LENGTH(!!string) + 1L, con = con) + } }, # On Snowflake, REGEXP_REPLACE is used like this: # REGEXP_REPLACE( , [ , , @@ -195,6 +210,41 @@ sql_translation.Snowflake <- function(con) { ) sql_expr(DATE_TRUNC(!!unit, !!x)) }, + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + # https://docs.snowflake.com/en/sql-reference/functions/date_from_parts + sql_expr(DATE_FROM_PARTS(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(DATE_PART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATE_PART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATE_PART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) + }, # LEAST / GREATEST on Snowflake will not respect na.rm = TRUE by default (similar to Oracle/Access) # https://docs.snowflake.com/en/sql-reference/functions/least # https://docs.snowflake.com/en/sql-reference/functions/greatest @@ -261,15 +311,19 @@ snowflake_grepl <- function(pattern, perl = FALSE, fixed = FALSE, useBytes = FALSE) { - # https://docs.snowflake.com/en/sql-reference/functions/regexp.html - check_unsupported_arg(ignore.case, FALSE, backend = "Snowflake") + con <- sql_current_con() + check_unsupported_arg(perl, FALSE, backend = "Snowflake") check_unsupported_arg(fixed, FALSE, backend = "Snowflake") check_unsupported_arg(useBytes, FALSE, backend = "Snowflake") - # REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which - # grepl does not. Left- and right-pad `pattern` with .* to get grepl-like - # behavior - sql_expr(((!!x)) %REGEXP% (".*" || !!paste0("(", pattern, ")") || ".*")) + + # https://docs.snowflake.com/en/sql-reference/functions/regexp_instr.html + # REGEXP_INSTR optional parameters: position, occurrance, option, regex_parameters + regexp_parameters <- "c" + if(ignore.case) { regexp_parameters <- "i" } + # Snowflake needs backslashes escaped, so we must increase the level of escaping + pattern <- gsub("\\", "\\\\", pattern, fixed = TRUE) + translate_sql(REGEXP_INSTR(!!x, !!pattern, 1L, 1L, 0L, !!regexp_parameters) != 0L, con = con) } snowflake_round <- function(x, digits = 0L) { @@ -301,4 +355,4 @@ snowflake_pmin_pmax_builder <- function(dot_1, dot_2, comparison){ glue_sql2(sql_current_con(), glue("COALESCE(IFF({dot_2} {comparison} {dot_1}, {dot_2}, {dot_1}), {dot_2}, {dot_1})")) } -utils::globalVariables(c("%REGEXP%", "DAYNAME", "DECODE", "FLOAT", "MONTHNAME", "POSITION", "trim")) +utils::globalVariables(c("%REGEXP%", "DAYNAME", "DECODE", "FLOAT", "MONTHNAME", "POSITION", "trim", "LENGTH")) diff --git a/R/backend-spark-sql.R b/R/backend-spark-sql.R index 2ac282427..70b6d4d4b 100644 --- a/R/backend-spark-sql.R +++ b/R/backend-spark-sql.R @@ -36,7 +36,42 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL") #' @export `sql_translation.Spark SQL` <- function(con) { sql_variant( - base_odbc_scalar, + sql_translator(.parent = base_odbc_scalar, + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(date_add(!!x, !!n)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(add_months(!!!x, !!n*12)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(make_date(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(date_part('YEAR', !!x)) + }, + get_month = function(x) { + sql_expr(date_part('MONTH', !!x)) + }, + get_day = function(x) { + sql_expr(date_part('DAY', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(datediff(!!time2, !!time1)) + } + ), sql_translator(.parent = base_odbc_agg, var = sql_aggregate("VARIANCE", "var"), quantile = sql_quantile("PERCENTILE"), diff --git a/R/db-io.R b/R/db-io.R index 084520620..fffcdb991 100644 --- a/R/db-io.R +++ b/R/db-io.R @@ -71,25 +71,23 @@ db_copy_to.DBIConnection <- function(con, temporary <- new$temporary call <- current_env() - with_transaction(con, in_transaction, { - withCallingHandlers( - { - table <- dplyr::db_write_table(con, table, - types = types, - values = values, - temporary = temporary, - overwrite = overwrite, - ... - ) - create_indexes(con, table, unique_indexes, unique = TRUE) - create_indexes(con, table, indexes) - if (analyze) dbplyr_analyze(con, table) - }, - error = function(cnd) { - cli_abort("Can't copy to table {.field {format(table, con = con)}}.", parent = cnd, call = call) - } - ) - }) + with_transaction( + con, + in_transaction, + "Can't copy data to table {.field {format(table, con = con)}}.", + { + table <- dplyr::db_write_table(con, table, + types = types, + values = values, + temporary = temporary, + overwrite = overwrite, + ... + ) + create_indexes(con, table, unique_indexes, unique = TRUE) + create_indexes(con, table, indexes) + if (analyze) dbplyr_analyze(con, table) + } + ) table } @@ -131,18 +129,24 @@ db_compute.DBIConnection <- function(con, table <- new$table temporary <- new$temporary - with_transaction(con, in_transaction, { - table <- dbplyr_save_query( - con, - sql, - table, - temporary = temporary, - overwrite = overwrite - ) - create_indexes(con, table, unique_indexes, unique = TRUE) - create_indexes(con, table, indexes) - if (analyze) dbplyr_analyze(con, table) - }) + with_transaction( + con, + in_transaction, + "Can't copy query to table {.field {format(table, con = con)}}.", + { + table <- dbplyr_save_query( + con, + sql, + table, + ..., + temporary = temporary, + overwrite = overwrite + ) + create_indexes(con, table, unique_indexes, unique = TRUE) + create_indexes(con, table, indexes) + if (analyze) dbplyr_analyze(con, table) + } + ) table } @@ -156,14 +160,12 @@ db_collect <- function(con, sql, n = -1, warn_incomplete = TRUE, ...) { #' @export db_collect.DBIConnection <- function(con, sql, n = -1, warn_incomplete = TRUE, ...) { res <- dbSendQuery(con, sql) - tryCatch({ - out <- dbFetch(res, n = n) - if (warn_incomplete) { - res_warn_incomplete(res, "n = Inf") - } - }, finally = { - dbClearResult(res) - }) + on.exit(dbClearResult(res), add = TRUE) + + out <- dbFetch(res, n = n) + if (warn_incomplete) { + res_warn_incomplete(res, "n = Inf") + } out } @@ -217,14 +219,23 @@ create_indexes <- function(con, table, indexes = NULL, unique = FALSE, ...) { } } -# Don't use `tryCatch()` because it messes with the callstack -with_transaction <- function(con, in_transaction, code) { +with_transaction <- function(con, + in_transaction, + msg, + code, + call = caller_env(), + env = caller_env()) { if (in_transaction) { dbBegin(con) on.exit(dbRollback(con)) } - code + withCallingHandlers( + code, + error = function(cnd) { + cli_abort(msg, parent = cnd, call = call, .envir = env) + } + ) if (in_transaction) { on.exit() diff --git a/R/db-sql.R b/R/db-sql.R index 536adcc1d..22e66f510 100644 --- a/R/db-sql.R +++ b/R/db-sql.R @@ -1067,13 +1067,8 @@ db_analyze.DBIConnection <- function(con, table, ...) { if (is.null(sql)) { return() # nocov } - withCallingHandlers( - DBI::dbExecute(con, sql), - error = function(cnd) { - msg <- "Can't analyze table {.field {format(table, con = con)}}." - cli_abort(msg, parent = cnd) - } - ) + + db_execute(con, sql, "Can't analyze table {.field {format(table, con = con)}}.") } dbplyr_create_index <- function(con, ...) { @@ -1088,13 +1083,7 @@ db_create_index.DBIConnection <- function(con, unique = FALSE, ...) { sql <- sql_table_index(con, table, columns, name = name, unique = unique, ...) - withCallingHandlers( - DBI::dbExecute(con, sql), - error = function(cnd) { - msg <- "Can't create index on table {.field {format(table, con = con)}}." - cli_abort(msg, parent = cnd) - } - ) + db_execute(con, sql, "Can't create index on table {.field {format(table, con = con)}}.") } dbplyr_explain <- function(con, ...) { @@ -1104,13 +1093,7 @@ dbplyr_explain <- function(con, ...) { #' @importFrom dplyr db_explain db_explain.DBIConnection <- function(con, sql, ...) { sql <- sql_query_explain(con, sql, ...) - call <- current_call() - expl <- withCallingHandlers( - DBI::dbGetQuery(con, sql), - error = function(cnd) { - cli_abort("Can't explain query.", parent = cnd) - } - ) + expl <- db_get_query(con, sql, "Can't explain query.") out <- utils::capture.output(print(expl)) paste(out, collapse = "\n") @@ -1123,12 +1106,7 @@ dbplyr_query_fields <- function(con, ...) { #' @importFrom dplyr db_query_fields db_query_fields.DBIConnection <- function(con, sql, ...) { sql <- sql_query_fields(con, sql, ...) - df <- withCallingHandlers( - DBI::dbGetQuery(con, sql), - error = function(cnd) { - cli_abort("Can't query fields.", parent = cnd) - } - ) + df <- db_get_query(con, sql, "Can't query fields.") names(df) } @@ -1144,24 +1122,17 @@ db_save_query.DBIConnection <- function(con, ..., overwrite = FALSE) { name <- as_table_name(name, con) - sql <- sql_query_save(con, sql(sql), name, temporary = temporary, ...) - withCallingHandlers( - { - if (overwrite) { - found <- DBI::dbExistsTable(con, SQL(name)) - if (found) { - DBI::dbRemoveTable(con, SQL(name)) - } - } - DBI::dbExecute(con, sql, immediate = TRUE) - }, - error = function(cnd) { - cli_abort( - "Can't save query to table {.table {format(name, con = con)}}.", - parent = cnd - ) + sql <- sql_query_save(con, sql, name, temporary = temporary, ...) + + if (overwrite) { + found <- DBI::dbExistsTable(con, SQL(name)) + if (found) { + DBI::dbRemoveTable(con, SQL(name)) } - ) + } + + db_execute(con, sql, "Can't save query to table {.table {format(name, con = con)}}.") + name } @@ -1177,3 +1148,30 @@ sql_subquery.DBIConnection <- function(con, lvl = 0) { sql_query_wrap(con, from = from, name = name, ..., lvl = lvl) } + +# Helpers ------------------------------------------------------------------- + +db_execute <- function(con, sql, msg, call = caller_env(), env = caller_env()) { + dbi_wrap( + dbExecute(con, sql, immediate = TRUE), + sql = sql, + msg = msg, + call = call, + env = env + ) + invisible() +} + +db_get_query <- function(con, sql, msg, call = caller_env(), env = caller_env()) { + dbi_wrap(dbGetQuery(con, sql), sql, msg, call = call, env = env) +} + +dbi_wrap <- function(code, sql, msg, call = caller_env(), env = caller_env()) { + withCallingHandlers( + code, + error = function(cnd) { + msg <- c(msg, i = paste0("Using SQL: ", sql)) + cli_abort(msg, parent = cnd, call = call, .envir = env) + } + ) +} diff --git a/R/db.R b/R/db.R index 8576b5384..942f7f647 100644 --- a/R/db.R +++ b/R/db.R @@ -60,20 +60,20 @@ sql_join_suffix.DBIConnection <- function(con, suffix, ...) { db_sql_render <- function(con, sql, ..., cte = FALSE, sql_options = NULL) { check_bool(cte) if (cte) { - lifecycle::deprecate_soft( + lifecycle::deprecate_soft( when = "2.4.0", what = "db_sql_render(cte)", with = I("db_sql_render(sql_options = sql_options(cte = TRUE))") ) sql_options <- sql_options %||% sql_options(cte = TRUE) - out <- db_sql_render(con, sql, sql_options = sql_options) + out <- db_sql_render(con, sql, ..., sql_options = sql_options) return(out) } if (is.null(sql_options)) { sql_options <- sql_options() - out <- db_sql_render(con, sql, sql_options = sql_options) + out <- db_sql_render(con, sql, ..., sql_options = sql_options) return(out) } diff --git a/R/rows.R b/R/rows.R index 9655e470b..7aa18cb71 100644 --- a/R/rows.R +++ b/R/rows.R @@ -759,25 +759,15 @@ rows_auto_copy <- function(x, y, copy, call = caller_env()) { } rows_get_or_execute <- function(x, sql, returning_cols, call = caller_env()) { + error <- "Can't modify database table {.val {remote_name(x)}}." con <- remote_con(x) - withCallingHandlers( - { - if (is_empty(returning_cols)) { - DBI::dbExecute(con, sql, immediate = TRUE) - } else { - returned_rows <- DBI::dbGetQuery(con, sql, immediate = TRUE) - x <- set_returned_rows(x, returned_rows) - } - }, - error = function(cnd) { - cli_abort( - "Can't modify database table {.val {remote_name(x)}}.", - parent = cnd, - call = call - ) - } - ) + if (is_empty(returning_cols)) { + db_execute(con, sql, error, call = call) + } else { + returned_rows <- db_get_query(con, sql, error, call = call) + x <- set_returned_rows(x, returned_rows) + } invisible(x) } diff --git a/R/src_dbi.R b/R/src_dbi.R index 57e6c193f..978be8439 100644 --- a/R/src_dbi.R +++ b/R/src_dbi.R @@ -123,17 +123,20 @@ src_dbi <- function(con, auto_disconnect = FALSE) { disco <- db_disconnector(con, quiet = is_true(auto_disconnect)) # nocov } - subclass <- paste0("src_", class(con)[[1]]) - structure( list( con = con, disco = disco ), - class = c(subclass, "src_dbi", "src_sql", "src") + class = connection_s3_class(con) ) } +connection_s3_class <- function(con) { + subclass <- setdiff(methods::is(con), methods::extends("DBIConnection")) + c(paste0("src_", subclass), "src_dbi", "src_sql", "src") +} + methods::setOldClass(c("src_dbi", "src_sql", "src")) # nocov start diff --git a/R/tbl-sql.R b/R/tbl-sql.R index a6a320553..fd93ccc9e 100644 --- a/R/tbl-sql.R +++ b/R/tbl-sql.R @@ -14,7 +14,7 @@ #' multiple `tbl` objects. #' @param check_from Check if `from` is likely misspecified SQL or a table in a schema. tbl_sql <- function(subclass, src, from, ..., vars = NULL, check_from = TRUE) { - check_dots_used() + # Can't use check_dots_used(), #1429 check_character(vars, allow_null = TRUE) from <- as_table_source(from, con = src$con) diff --git a/R/tidyeval-across.R b/R/tidyeval-across.R index d0b4ec8c2..8b4327ecb 100644 --- a/R/tidyeval-across.R +++ b/R/tidyeval-across.R @@ -213,13 +213,14 @@ across_setup <- function(data, dots <- call$... for (i in seq_along(call$...)) { dot <- call$...[[i]] - try_fetch({ - dots[[i]] <- partial_eval(dot, data = data, env = env, error_call = error_call) - }, error = function(cnd) { - label <- expr_as_label(dot, names2(call$...)[[i]]) - msg <- "Problem while evaluating {.code {label}}." - cli_abort(msg, call = call(fn), parent = cnd) - }) + withCallingHandlers( + dots[[i]] <- partial_eval(dot, data = data, env = env, error_call = error_call), + error = function(cnd) { + label <- expr_as_label(dot, names2(call$...)[[i]]) + msg <- "Problem while evaluating {.code {label}}." + cli_abort(msg, call = call(fn), parent = cnd) + } + ) } names_spec <- eval(call$.names, env) diff --git a/R/tidyeval.R b/R/tidyeval.R index 9270051f9..68a1f1724 100644 --- a/R/tidyeval.R +++ b/R/tidyeval.R @@ -124,7 +124,7 @@ partial_eval_dots <- function(.data, partial_eval_quo <- function(x, data, error_call, dot_name, was_named) { # no direct equivalent in `dtplyr`, mostly handled in `dt_squash()` - try_fetch( + withCallingHandlers( expr <- partial_eval(get_expr(x), data, get_env(x), error_call = error_call), error = function(cnd) { label <- expr_as_label(x, dot_name) @@ -162,11 +162,6 @@ partial_eval_sym <- function(sym, data, env) { } } -is_namespaced_dplyr_call <- function(call) { - packages <- c("base", "dplyr", "stringr", "lubridate") - is_symbol(call[[1]], "::") && is_symbol(call[[2]], packages) -} - is_mask_pronoun <- function(call) { is_call(call, c("$", "[["), n = 2) && is_symbol(call[[2]], c(".data", ".env")) } @@ -190,11 +185,9 @@ partial_eval_call <- function(call, data, env) { call[[1]] <- fun <- sym(fun_name) } - # So are compound calls, EXCEPT dplyr::foo() - if (is.call(fun)) { - if (is_namespaced_dplyr_call(fun)) { - call[[1]] <- fun[[3]] - } else if (is_mask_pronoun(fun)) { + # Compound calls, apart from `::` aren't translatable + if (is_call(fun) && !is_call(fun, "::")) { + if (is_mask_pronoun(fun)) { stop("Use local() or remote() to force evaluation of functions", call. = FALSE) } else { return(eval_bare(call, env)) @@ -216,10 +209,9 @@ partial_eval_call <- function(call, data, env) { } else { # Process call arguments recursively, unless user has manually called # remote/local - name <- as_string(call[[1]]) - if (name == "local") { + if (is_call(call, "local")) { eval_bare(call[[2]], env) - } else if (name == "remote") { + } else if (is_call(call, "remote")) { call[[2]] } else { call[-1] <- lapply(call[-1], partial_eval, data = data, env = env) diff --git a/R/translate-sql-helpers.R b/R/translate-sql-helpers.R index 5fafce1d9..baed5ad68 100644 --- a/R/translate-sql-helpers.R +++ b/R/translate-sql-helpers.R @@ -175,10 +175,11 @@ sql_infix <- function(f, pad = TRUE) { escape_infix_expr <- function(xq, x, escape_unary_minus = FALSE) { infix_calls <- c("+", "-", "*", "/", "%%", "^") is_infix <- is_call(xq, infix_calls, n = 2) - is_unary_minus <- escape_unary_minus && is_call(xq, "-", n = 1) + is_unary_minus <- escape_unary_minus && + is_call(xq, "-", n = 1) && !is_atomic(x, n = 1) if (is_infix || is_unary_minus) { - enpared <- glue_sql2(sql_current_con(), "({x})") + enpared <- glue_sql2(sql_current_con(), "({.val x})") return(enpared) } diff --git a/R/translate-sql.R b/R/translate-sql.R index 044b16374..fe281017c 100644 --- a/R/translate-sql.R +++ b/R/translate-sql.R @@ -185,7 +185,7 @@ sql_data_mask <- function(expr, if (env_has(special_calls2, name) || env_has(special_calls, name)) { env_get(special_calls2, name, inherit = TRUE) } else { - cli_abort("No known translation for {.fun {pkg}::{name}}") + cli_abort("No known translation", call = call2(call2("::", sym(pkg), sym(name)))) } } diff --git a/R/utils-check.R b/R/utils-check.R index e7e986007..edf407936 100644 --- a/R/utils-check.R +++ b/R/utils-check.R @@ -182,7 +182,7 @@ with_indexed_errors <- function(expr, ..., .error_call = caller_env(), .frame = caller_env()) { - try_fetch( + withCallingHandlers( expr, purrr_error_indexed = function(cnd) { message <- message(cnd) diff --git a/R/utils.R b/R/utils.R index 945039294..931ebde15 100644 --- a/R/utils.R +++ b/R/utils.R @@ -20,12 +20,12 @@ named_commas <- function(x) { commas <- function(...) paste0(..., collapse = ", ") -unique_table_name <- function() { - # Needs to use option to unique names across reloads while testing - i <- getOption("dbplyr_table_name", 0) + 1 - options(dbplyr_table_name = i) - sprintf("dbplyr_%03i", i) +unique_table_name <- function(prefix = "") { + vals <- c(letters, LETTERS, 0:9) + name <- paste0(sample(vals, 10, replace = TRUE), collapse = "") + paste0(prefix, "dbplyr_", name) } + unique_subquery_name <- function() { # Needs to use option so can reset at the start of each query i <- getOption("dbplyr_subquery_name", 0) + 1 diff --git a/R/verb-compute.R b/R/verb-compute.R index 3b87592e3..5b9e40898 100644 --- a/R/verb-compute.R +++ b/R/verb-compute.R @@ -25,7 +25,7 @@ collapse.tbl_sql <- function(x, ...) { #' @rdname collapse.tbl_sql #' @param name Table name in remote database. -#' @param temporary Should the table be temporary (`TRUE`, the default`) or +#' @param temporary Should the table be temporary (`TRUE`, the default) or #' persistent (`FALSE`)? #' @inheritParams copy_to.src_sql #' @inheritParams collect.tbl_sql @@ -128,8 +128,8 @@ collect.tbl_sql <- function(x, ..., n = Inf, warn_incomplete = TRUE, cte = FALSE } sql <- db_sql_render(x$src$con, x, cte = cte) - out <- withCallingHandlers( - db_collect(x$src$con, sql, n = n, warn_incomplete = warn_incomplete, ...), + withCallingHandlers( + out <- db_collect(x$src$con, sql, n = n, warn_incomplete = warn_incomplete, ...), error = function(cnd) { cli_abort("Failed to collect lazy table.", parent = cnd) } diff --git a/R/verb-pivot-longer.R b/R/verb-pivot-longer.R index 8d4aed53d..028934a7f 100644 --- a/R/verb-pivot-longer.R +++ b/R/verb-pivot-longer.R @@ -3,7 +3,7 @@ #' @description #' `pivot_longer()` "lengthens" data, increasing the number of rows and #' decreasing the number of columns. The inverse transformation is -#' `tidyr::pivot_wider()] +#' [tidyr::pivot_wider()]. #' #' Learn more in `vignette("pivot", "tidyr")`. #' diff --git a/R/verb-pivot-wider.R b/R/verb-pivot-wider.R index 6d915193a..476577fe5 100644 --- a/R/verb-pivot-wider.R +++ b/R/verb-pivot-wider.R @@ -359,7 +359,7 @@ select_wider_id_cols <- function(data, return(names(sim_data)) } - try_fetch( + withCallingHandlers( id_cols <- tidyselect::eval_select( enquo(id_cols), sim_data, diff --git a/man/collapse.tbl_sql.Rd b/man/collapse.tbl_sql.Rd index b322e4b2e..c93a83aa2 100644 --- a/man/collapse.tbl_sql.Rd +++ b/man/collapse.tbl_sql.Rd @@ -28,7 +28,8 @@ \item{name}{Table name in remote database.} -\item{temporary}{Should the table be temporary (\code{TRUE}, the default\verb{) or persistent (}FALSE`)?} +\item{temporary}{Should the table be temporary (\code{TRUE}, the default) or +persistent (\code{FALSE})?} \item{unique_indexes}{a list of character vectors. Each element of the list will create a new unique index over the specified column(s). Duplicate rows diff --git a/man/pivot_longer.tbl_lazy.Rd b/man/pivot_longer.tbl_lazy.Rd index 22952c699..0bd531529 100644 --- a/man/pivot_longer.tbl_lazy.Rd +++ b/man/pivot_longer.tbl_lazy.Rd @@ -60,7 +60,7 @@ in the \code{value_to} column.} \description{ \code{pivot_longer()} "lengthens" data, increasing the number of rows and decreasing the number of columns. The inverse transformation is -`tidyr::pivot_wider()] +\code{\link[tidyr:pivot_wider]{tidyr::pivot_wider()}}. Learn more in \code{vignette("pivot", "tidyr")}. diff --git a/tests/testthat/_snaps/backend-mssql.md b/tests/testthat/_snaps/backend-mssql.md index 0abf90281..98ab650dc 100644 --- a/tests/testthat/_snaps/backend-mssql.md +++ b/tests/testthat/_snaps/backend-mssql.md @@ -15,6 +15,14 @@ i Use a combination of `distinct()` and `mutate()` for the same result: `mutate( = median(x, na.rm = TRUE)) %>% distinct()` +# custom window functions translated correctly + + Code + test_translate_sql(n_distinct(x), vars_group = "x") + Condition + Error in `n_distinct()`: + ! No translation available in `mutate()`/`filter()` for SQL server. + # custom lubridate functions translated correctly Code @@ -370,6 +378,51 @@ OUTPUT `INSERTED`.`a`, `INSERTED`.`b` AS `b2` ; +# atoms and symbols are cast to bit in `filter` + + Code + mf %>% filter(x) + Output + + SELECT `df`.* + FROM `df` + WHERE (cast(`x` AS `BIT`) = 1) + +--- + + Code + mf %>% filter(TRUE) + Output + + SELECT `df`.* + FROM `df` + WHERE (cast(1 AS `BIT`) = 1) + +--- + + Code + mf %>% filter((!x) | FALSE) + Output + + SELECT `df`.* + FROM `df` + WHERE ((NOT(cast(`x` AS `BIT`) = 1)) OR cast(0 AS `BIT`) = 1) + +--- + + Code + mf %>% filter(x) %>% inner_join(mf, by = "x") + Output + + SELECT `LHS`.`x` AS `x` + FROM ( + SELECT `df`.* + FROM `df` + WHERE (cast(`x` AS `BIT`) = 1) + ) AS `LHS` + INNER JOIN `df` + ON (`LHS`.`x` = `df`.`x`) + # row_number() with and without group_by() and arrange(): unordered defaults to Ordering by NULL (per empty_order) Code @@ -400,6 +453,20 @@ FROM `df` ORDER BY `y` +# can copy_to() and compute() with temporary tables (#438) + + Code + db <- copy_to(con, df, name = unique_table_name(), temporary = TRUE) + Message + Created a temporary table named #dbplyr_{tmp} + +--- + + Code + db2 <- db %>% mutate(y = x + 1) %>% compute() + Message + Created a temporary table named #dbplyr_{tmp} + # add prefix to temporary table Code diff --git a/tests/testthat/_snaps/backend-oracle.md b/tests/testthat/_snaps/backend-oracle.md index 72e3f07bb..e50731905 100644 --- a/tests/testthat/_snaps/backend-oracle.md +++ b/tests/testthat/_snaps/backend-oracle.md @@ -1,3 +1,14 @@ +# string functions translate correctly + + Code + test_translate_sql(str_replace(col, "pattern", "replacement")) + Output + REGEXP_REPLACE(`col`, 'pattern', 'replacement', 1, 1) + Code + test_translate_sql(str_replace_all(col, "pattern", "replacement")) + Output + REGEXP_REPLACE(`col`, 'pattern', 'replacement') + # queries translate correctly Code @@ -41,8 +52,8 @@ Code sql_query_explain(con, sql("SELECT * FROM foo")) Output - EXPLAIN PLAN FOR SELECT * FROM foo; - SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY())); + EXPLAIN PLAN FOR SELECT * FROM foo + SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY()) --- diff --git a/tests/testthat/_snaps/backend-postgres.md b/tests/testthat/_snaps/backend-postgres.md index 50bad613a..b6c98be73 100644 --- a/tests/testthat/_snaps/backend-postgres.md +++ b/tests/testthat/_snaps/backend-postgres.md @@ -126,6 +126,7 @@ Condition Error in `rows_insert()`: ! Can't modify database table "df_x". + i Using SQL: INSERT INTO "df_x" ("a", "b", "c", "d") SELECT * FROM ( SELECT "a", "b", "c" + 1.0 AS "c", "d" FROM "df_y" ) AS "...y" ON CONFLICT ("a", "b") DO NOTHING RETURNING "df_x"."a", "df_x"."b", "df_x"."c", "df_x"."d" Caused by error: ! dummy DBI error @@ -137,6 +138,7 @@ Condition Error in `rows_upsert()`: ! Can't modify database table "df_x". + i Using SQL: INSERT INTO "df_x" ("a", "b", "c", "d") SELECT "a", "b", "c", "d" FROM ( SELECT "a", "b", "c" + 1.0 AS "c", "d" FROM "df_y" ) AS "...y" WHERE true ON CONFLICT ("a", "b") DO UPDATE SET "c" = "excluded"."c", "d" = "excluded"."d" RETURNING "df_x"."a", "df_x"."b", "df_x"."c", "df_x"."d" Caused by error: ! dummy DBI error diff --git a/tests/testthat/_snaps/backend-snowflake.md b/tests/testthat/_snaps/backend-snowflake.md index e43d3042e..2eef02c36 100644 --- a/tests/testthat/_snaps/backend-snowflake.md +++ b/tests/testthat/_snaps/backend-snowflake.md @@ -1,13 +1,3 @@ -# custom scalar translated correctly - - Code - (expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE)))) - Output - - Error in `grepl()`: - ! `ignore.case = TRUE` isn't supported in Snowflake translation. - i It must be FALSE instead. - # pmin() and pmax() respect na.rm Code diff --git a/tests/testthat/_snaps/db-io.md b/tests/testthat/_snaps/db-io.md index cfa33ee80..537255689 100644 --- a/tests/testthat/_snaps/db-io.md +++ b/tests/testthat/_snaps/db-io.md @@ -6,9 +6,10 @@ Output Error in `db_copy_to()`: - ! Can't copy to table `tmp2`. + ! Can't copy data to table `tmp2`. Caused by error in `db_create_index.DBIConnection()`: ! Can't create index on table `tmp2`. + i Using SQL: CREATE UNIQUE INDEX `tmp2_x` ON `tmp2` (`x`) Caused by error: ! dummy DBI error @@ -20,7 +21,7 @@ Output Error in `db_copy_to()`: - ! Can't copy to table `tmp`. + ! Can't copy data to table `tmp`. Caused by error in `dplyr::db_write_table()`: ! Can't write table table `tmp`. Caused by error: @@ -35,6 +36,7 @@ Error in `db_save_query()`: ! Can't save query to table `tmp`. + i Using SQL: CREATE TEMPORARY TABLE `tmp` AS `SELECT 2 FROM tmp` Caused by error: ! dummy DBI error diff --git a/tests/testthat/_snaps/db-sql.md b/tests/testthat/_snaps/db-sql.md index 1b3fcb28f..faec92a58 100644 --- a/tests/testthat/_snaps/db-sql.md +++ b/tests/testthat/_snaps/db-sql.md @@ -16,6 +16,7 @@ Error in `db_analyze()`: ! Can't analyze table tbl. + i Using SQL: ANALYZE `tbl` Caused by error: ! dummy DBI error Code @@ -24,6 +25,7 @@ Error in `db_create_index()`: ! Can't create index on table tbl. + i Using SQL: CREATE INDEX `tbl_col` ON `tbl` (`col`) Caused by error: ! dummy DBI error Code @@ -32,6 +34,7 @@ Error in `db_explain()`: ! Can't explain query. + i Using SQL: EXPLAIN QUERY PLAN invalid sql Caused by error: ! dummy DBI error Code @@ -40,6 +43,7 @@ Error in `db_query_fields()`: ! Can't query fields. + i Using SQL: SELECT * FROM `does not exist` AS `q01` WHERE (0 = 1) Caused by error: ! dummy DBI error Code @@ -48,6 +52,7 @@ Error in `db_save_query()`: ! Can't save query to table `tbl`. + i Using SQL: CREATE TEMPORARY TABLE `tbl` AS `invalid sql` Caused by error: ! dummy DBI error diff --git a/tests/testthat/_snaps/rows.md b/tests/testthat/_snaps/rows.md index 02b447cb3..0f9229fc9 100644 --- a/tests/testthat/_snaps/rows.md +++ b/tests/testthat/_snaps/rows.md @@ -97,6 +97,7 @@ Error in `rows_append()`: ! Can't modify database table "mtcars". + i Using SQL: INSERT INTO `mtcars` (`x`) SELECT * FROM ( SELECT * FROM `dbplyr_{tmp}` ) AS `...y` Caused by error: ! dummy DBI error Code @@ -106,6 +107,7 @@ Error in `rows_append()`: ! Can't modify database table "mtcars". + i Using SQL: INSERT INTO `mtcars` (`x`) SELECT * FROM ( SELECT * FROM `dbplyr_{tmp}` ) AS `...y` RETURNING `mtcars`.`x` Caused by error: ! dummy DBI error diff --git a/tests/testthat/_snaps/translate-sql.md b/tests/testthat/_snaps/translate-sql.md index 60fefbb90..db49b687c 100644 --- a/tests/testthat/_snaps/translate-sql.md +++ b/tests/testthat/_snaps/translate-sql.md @@ -18,20 +18,32 @@ Condition Error: ! There is no package called NOSUCHPACKAGE - ---- - Code test_translate_sql(dbplyr::NOSUCHFUNCTION()) Condition Error: ! "NOSUCHFUNCTION" is not an exported object from dbplyr + Code + test_translate_sql(base::abbreviate(x)) + Condition + Error in `base::abbreviate()`: + ! No known translation --- Code - test_translate_sql(base::abbreviate(x)) + lz %>% mutate(x = NOSUCHPACKAGE::foo()) Condition Error: - ! No known translation for `base::abbreviate()` + ! There is no package called NOSUCHPACKAGE + Code + lz %>% mutate(x = dbplyr::NOSUCHFUNCTION()) + Condition + Error: + ! "NOSUCHFUNCTION" is not an exported object from dbplyr + Code + lz %>% mutate(x = base::abbreviate(x)) + Condition + Error in `base::abbreviate()`: + ! No known translation diff --git a/tests/testthat/_snaps/verb-compute.md b/tests/testthat/_snaps/verb-compute.md index b79714bae..133926a3d 100644 --- a/tests/testthat/_snaps/verb-compute.md +++ b/tests/testthat/_snaps/verb-compute.md @@ -20,10 +20,13 @@ Code df %>% compute(name = in_schema("main", "db1"), temporary = FALSE) Condition - Error in `db_save_query.DBIConnection()`: + Error in `db_compute()`: + ! Can't copy query to table `main`.`db1`. + Caused by error in `db_save_query.DBIConnection()`: ! Can't save query to table `main`.`db1`. + i Using SQL: CREATE TABLE `main`.`db1` AS SELECT * FROM `dbplyr_{tmp}` Caused by error: - ! table `db1` already exists + ! dummy DBI error # collect() handles DBI error diff --git a/tests/testthat/helper-src.R b/tests/testthat/helper-src.R index ac7e3c5a9..1957b3f9d 100644 --- a/tests/testthat/helper-src.R +++ b/tests/testthat/helper-src.R @@ -3,14 +3,14 @@ on_cran <- function() !identical(Sys.getenv("NOT_CRAN"), "true") if (test_srcs$length() == 0) { - # test_register_src("df", dplyr::src_df(env = new.env(parent = emptyenv()))) test_register_con("sqlite", RSQLite::SQLite(), ":memory:") if (identical(Sys.getenv("GITHUB_POSTGRES"), "true")) { test_register_con("postgres", RPostgres::Postgres(), dbname = "test", user = "postgres", - password = "password" + password = "password", + host = "127.0.0.1" ) } else if (identical(Sys.getenv("GITHUB_MSSQL"), "true")) { test_register_con("mssql", odbc::odbc(), @@ -23,7 +23,7 @@ if (test_srcs$length() == 0) { ) } else if (on_gha() || on_cran()) { # Only test with sqlite - } else { + } else { test_register_con("MariaDB", RMariaDB::MariaDB(), dbname = "test", host = "localhost", @@ -50,8 +50,16 @@ local_sqlite_con_with_aux <- function(envir = parent.frame()) { } snap_transform_dbi <- function(x) { + x <- gsub("dbplyr_[a-zA-Z0-9]+", "dbplyr_{tmp}", x) + # use the last line matching this in case of multiple chained errors - dbi_line_id <- max(which(x == "Caused by error:")) + caused_by <- which(x == "Caused by error:") + if (length(caused_by) == 0) { + return(x) + } + + dbi_line_id <- max(caused_by) + n <- length(x) x <- x[-seq2(dbi_line_id + 1, n)] c(x, "! dummy DBI error") diff --git a/tests/testthat/test-backend-.R b/tests/testthat/test-backend-.R index db1808f28..b96e39886 100644 --- a/tests/testthat/test-backend-.R +++ b/tests/testthat/test-backend-.R @@ -39,6 +39,7 @@ test_that("unary plus works for non-numeric expressions", { test_that("unary minus flips sign of number", { local_con(simulate_dbi()) expect_equal(test_translate_sql(-10L), sql("-10")) + expect_equal(test_translate_sql(-10L + x), sql("-10 + `x`")) expect_equal(test_translate_sql(x == -10), sql('`x` = -10.0')) expect_equal(test_translate_sql(x %in% c(-1L, 0L)), sql('`x` IN (-1, 0)')) }) diff --git a/tests/testthat/test-backend-mssql.R b/tests/testthat/test-backend-mssql.R index 57de1ecf5..dea21af51 100644 --- a/tests/testthat/test-backend-mssql.R +++ b/tests/testthat/test-backend-mssql.R @@ -101,6 +101,11 @@ test_that("custom window functions translated correctly", { test_translate_sql(any(x, na.rm = TRUE)), sql("CAST(MAX(CAST(`x` AS INT)) OVER () AS BIT)") ) + + expect_snapshot( + test_translate_sql(n_distinct(x), vars_group = "x"), + error = TRUE + ) }) test_that("custom lubridate functions translated correctly", { @@ -124,6 +129,27 @@ test_that("custom lubridate functions translated correctly", { expect_error(test_translate_sql(quarter(x, fiscal_start = 5))) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_mssql()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("DATEFROMPARTS(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("DATEFROMPARTS(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATEPART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATEPART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATEPART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_mssql()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("last_value_sql() translated correctly", { con <- simulate_mssql() expect_equal( @@ -322,6 +348,20 @@ test_that("`sql_query_upsert()` is correct", { ) }) +test_that("atoms and symbols are cast to bit in `filter`", { + mf <- lazy_frame(x = TRUE, con = simulate_mssql()) + + # as simple symbol and atom + expect_snapshot(mf %>% filter(x)) + expect_snapshot(mf %>% filter(TRUE)) + + # when involved in a (perhaps nested) logical expression + expect_snapshot(mf %>% filter((!x) | FALSE)) + + # in a subquery + expect_snapshot(mf %>% filter(x) %>% inner_join(mf, by = "x")) +}) + test_that("row_number() with and without group_by() and arrange(): unordered defaults to Ordering by NULL (per empty_order)", { mf <- lazy_frame(x = c(1:5), y = c(rep("A", 5)), con = simulate_mssql()) expect_snapshot(mf %>% mutate(rown = row_number())) @@ -331,18 +371,20 @@ test_that("row_number() with and without group_by() and arrange(): unordered def # Live database ----------------------------------------------------------- -test_that("can copy_to() and compute() with temporary tables (#272)", { +test_that("can copy_to() and compute() with temporary tables (#438)", { con <- src_test("mssql") df <- tibble(x = 1:3) - expect_message( - db <- copy_to(con, df, name = "temp", temporary = TRUE), - "Created a temporary table", + + # converts to name automatically with message + expect_snapshot( + db <- copy_to(con, df, name = unique_table_name(), temporary = TRUE), + transform = snap_transform_dbi ) expect_equal(db %>% pull(), 1:3) - expect_message( + expect_snapshot( db2 <- db %>% mutate(y = x + 1) %>% compute(), - "Created a temporary table" + transform = snap_transform_dbi ) expect_equal(db2 %>% pull(), 2:4) }) @@ -362,12 +404,12 @@ test_that("add prefix to temporary table", { test_that("bit conversion works for important cases", { df <- tibble(x = 1:3, y = 3:1) - db <- copy_to(src_test("mssql"), df, name = unique_table_name()) + db <- copy_to(src_test("mssql"), df, name = unique_table_name("#")) expect_equal(db %>% mutate(z = x == y) %>% pull(), c(FALSE, TRUE, FALSE)) expect_equal(db %>% filter(x == y) %>% pull(), 2) df <- tibble(x = c(TRUE, FALSE, FALSE), y = c(TRUE, FALSE, TRUE)) - db <- copy_to(src_test("mssql"), df, name = unique_table_name()) + db <- copy_to(src_test("mssql"), df, name = unique_table_name("#")) expect_equal(db %>% filter(x == 1) %>% pull(), TRUE) expect_equal(db %>% mutate(z = TRUE) %>% pull(), c(1, 1, 1)) @@ -381,7 +423,7 @@ test_that("bit conversion works for important cases", { test_that("as.integer and as.integer64 translations if parsing failures", { df <- data.frame(x = c("1.3", "2x")) - db <- copy_to(src_test("mssql"), df, name = unique_table_name()) + db <- copy_to(src_test("mssql"), df, name = unique_table_name("#")) out <- db %>% mutate( diff --git a/tests/testthat/test-backend-oracle.R b/tests/testthat/test-backend-oracle.R index 5c299600f..1ab5462c6 100644 --- a/tests/testthat/test-backend-oracle.R +++ b/tests/testthat/test-backend-oracle.R @@ -16,6 +16,16 @@ test_that("paste and paste0 translate correctly", { expect_equal(test_translate_sql(str_c(x, y)), sql("`x` || `y`")) }) + +test_that("string functions translate correctly", { + local_con(simulate_oracle()) + + expect_snapshot({ + test_translate_sql(str_replace(col, "pattern", "replacement")) + test_translate_sql(str_replace_all(col, "pattern", "replacement")) + }) +}) + test_that("queries translate correctly", { mf <- lazy_frame(x = 1, con = simulate_oracle()) expect_snapshot(mf %>% head()) @@ -72,3 +82,19 @@ test_that("copy_inline uses UNION ALL", { copy_inline(con, y, types = types) %>% remote_query() }) }) + +test_that("custom clock functions translated correctly", { + local_con(simulate_oracle()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("(`x` + NUMTODSINTERVAL(1.0 * 365.25, 'day'))")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("(`x` + NUMTODSINTERVAL(1.0, 'day'))")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_oracle()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("CEIL(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("CEIL(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-backend-postgres-old.R b/tests/testthat/test-backend-postgres-old.R index 7b1b2cd50..cbbe284cf 100644 --- a/tests/testthat/test-backend-postgres-old.R +++ b/tests/testthat/test-backend-postgres-old.R @@ -11,7 +11,9 @@ test_that("RPostgreSQL backend", { ) ) - copy_to(src, mtcars, "mtcars", overwrite = TRUE, temporary = FALSE) + suppressWarnings( + copy_to(src, mtcars, "mtcars", overwrite = TRUE, temporary = FALSE) + ) withr::defer(DBI::dbRemoveTable(src, "mtcars")) expect_identical(colnames(tbl(src, "mtcars")), colnames(mtcars)) diff --git a/tests/testthat/test-backend-postgres.R b/tests/testthat/test-backend-postgres.R index d62d737b4..0517f195e 100644 --- a/tests/testthat/test-backend-postgres.R +++ b/tests/testthat/test-backend-postgres.R @@ -88,6 +88,27 @@ test_that("custom lubridate functions translated correctly", { expect_equal(test_translate_sql(floor_date(x, 'week')), sql("DATE_TRUNC('week', `x`)")) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_postgres()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("(`x` + 1.0*INTERVAL'1 year')")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("(`x` + 1.0*INTERVAL'1 day')")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("MAKE_DATE(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("MAKE_DATE(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_postgres()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("custom window functions translated correctly", { local_con(simulate_postgres()) diff --git a/tests/testthat/test-backend-redshift.R b/tests/testthat/test-backend-redshift.R index 9e90aa192..55e66b20f 100644 --- a/tests/testthat/test-backend-redshift.R +++ b/tests/testthat/test-backend-redshift.R @@ -57,3 +57,24 @@ test_that("copy_inline uses UNION ALL", { copy_inline(con, y, types = types) %>% remote_query() }) }) + +test_that("custom clock functions translated correctly", { + local_con(simulate_redshift()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("TO_DATE(CAST(2020.0 AS TEXT) || '-' CAST(1.0 AS TEXT) || '-' || CAST(1.0 AS TEXT)), 'YYYY-MM-DD')")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("TO_DATE(CAST(`year_column` AS TEXT) || '-' CAST(1 AS TEXT) || '-' || CAST(1 AS TEXT)), 'YYYY-MM-DD')")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_redshift()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-backend-snowflake.R b/tests/testthat/test-backend-snowflake.R index b1f37b14a..c628308d5 100644 --- a/tests/testthat/test-backend-snowflake.R +++ b/tests/testthat/test-backend-snowflake.R @@ -2,8 +2,8 @@ test_that("custom scalar translated correctly", { local_con(simulate_snowflake()) expect_equal(test_translate_sql(log10(x)), sql("LOG(10.0, `x`)")) expect_equal(test_translate_sql(round(x, digits = 1.1)), sql("ROUND((`x`) :: FLOAT, 1)")) - expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')")) - expect_snapshot((expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE))))) + expect_equal(test_translate_sql(grepl("exp", x)), sql("REGEXP_INSTR(`x`, 'exp', 1, 1, 0, 'c') != 0")) + expect_equal(test_translate_sql(grepl("exp", x, ignore.case = TRUE)), sql("REGEXP_INSTR(`x`, 'exp', 1, 1, 0, 'i') != 0")) }) test_that("pasting translated correctly", { @@ -25,8 +25,8 @@ test_that("custom stringr functions translated correctly", { local_con(simulate_snowflake()) expect_equal(test_translate_sql(str_locate(x, y)), sql("POSITION(`y`, `x`)")) - expect_equal(test_translate_sql(str_detect(x, y)), sql("(`x`) REGEXP ('.*' || `y` || '.*')")) - expect_equal(test_translate_sql(str_detect(x, y, negate = TRUE)), sql("!((`x`) REGEXP ('.*' || `y` || '.*'))")) + expect_equal(test_translate_sql(str_detect(x, y)), sql("REGEXP_INSTR(`x`, `y`) != 0")) + expect_equal(test_translate_sql(str_detect(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`) = 0")) expect_equal(test_translate_sql(str_replace(x, y, z)), sql("REGEXP_REPLACE(`x`, `y`, `z`, 1.0, 1.0)")) expect_equal(test_translate_sql(str_replace(x, "\\d", z)), sql("REGEXP_REPLACE(`x`, '\\\\d', `z`, 1.0, 1.0)")) expect_equal(test_translate_sql(str_replace_all(x, y, z)), sql("REGEXP_REPLACE(`x`, `y`, `z`)")) @@ -34,6 +34,10 @@ test_that("custom stringr functions translated correctly", { expect_equal(test_translate_sql(str_remove(x, y)), sql("REGEXP_REPLACE(`x`, `y`, '', 1.0, 1.0)")) expect_equal(test_translate_sql(str_remove_all(x, y)), sql("REGEXP_REPLACE(`x`, `y`)")) expect_equal(test_translate_sql(str_trim(x)), sql("TRIM(`x`)")) + expect_equal(test_translate_sql(str_starts(x, y)), sql("REGEXP_INSTR(`x`, `y`) = 1")) + expect_equal(test_translate_sql(str_starts(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`) != 1")) + expect_equal(test_translate_sql(str_ends(x, y)), sql("REGEXP_INSTR(`x`, `y`, 1, 1, 1) = (LENGTH(`x`) + 1)")) + expect_equal(test_translate_sql(str_ends(x, y, negate = TRUE)), sql("REGEXP_INSTR(`x`, `y`, 1, 1, 1) != (LENGTH(`x`) + 1)")) }) test_that("aggregates are translated correctly", { @@ -98,6 +102,27 @@ test_that("custom lubridate functions translated correctly", { expect_equal(test_translate_sql(floor_date(x, "week")), sql("DATE_TRUNC('week', `x`)")) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_snowflake()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("DATE_FROM_PARTS(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("DATE_FROM_PARTS(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_snowflake()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("min() and max()", { local_con(simulate_snowflake()) diff --git a/tests/testthat/test-backend-spark-sql.R b/tests/testthat/test-backend-spark-sql.R new file mode 100644 index 000000000..e1276c7a0 --- /dev/null +++ b/tests/testthat/test-backend-spark-sql.R @@ -0,0 +1,20 @@ +test_that("custom clock functions translated correctly", { + local_con(simulate_spark_sql()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("ADD_MONTHS('`x`', 1.0 * 12.0)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATE_ADD(`x`, 1.0)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("MAKE_DATE(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("MAKE_DATE(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('YEAR', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('MONTH', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('DAY', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_spark_sql()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(`end_date`, `start_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(`end_date`, `start_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-db-sql.R b/tests/testthat/test-db-sql.R index 398f77456..c6239082b 100644 --- a/tests/testthat/test-db-sql.R +++ b/tests/testthat/test-db-sql.R @@ -1,4 +1,5 @@ test_that("2nd edition uses sql methods", { + reset_warning_verbosity("Test-edition") local_methods( db_analyze.Test = function(con, ...) abort("db_method") ) @@ -21,6 +22,7 @@ test_that("sql_query_rows() works", { }) test_that("handles DBI error", { + unique_subquery_name_reset() con <- local_sqlite_connection() expect_snapshot({ diff --git a/tests/testthat/test-src_dbi.R b/tests/testthat/test-src_dbi.R index e7b8c23a3..337808779 100644 --- a/tests/testthat/test-src_dbi.R +++ b/tests/testthat/test-src_dbi.R @@ -4,3 +4,21 @@ test_that("tbl and src classes include connection class", { expect_true(inherits(mf, "tbl_SQLiteConnection")) expect_true(inherits(mf$src, "src_SQLiteConnection")) }) + +test_that("generates S3 class based on S4 class name", { + con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:") + expect_equal( + connection_s3_class(con), + c("src_SQLiteConnection", "src_dbi", "src_sql", "src") + ) + + on.exit(removeClass("Foo2")) + on.exit(removeClass("Foo1")) + + Foo1 <- setClass("Foo1", contains = "DBIConnection") + Foo2 <- setClass("Foo2", contains = "Foo1") + expect_equal( + connection_s3_class(Foo2()), + c("src_Foo2", "src_Foo1", "src_dbi", "src_sql", "src") + ) +}) diff --git a/tests/testthat/test-tidyeval.R b/tests/testthat/test-tidyeval.R index 652e90207..ac1fefee3 100644 --- a/tests/testthat/test-tidyeval.R +++ b/tests/testthat/test-tidyeval.R @@ -30,16 +30,6 @@ test_that("using environment of inlined quosures", { expect_equal(capture_dot(lf, f(!!quo)), quote(f(x + 20))) }) -test_that("namespaced calls to dplyr functions are stripped", { - lf <- lazy_frame(x = 1, y = 2) - - expect_equal(partial_eval(quote(dplyr::n()), lf), expr(n())) - expect_equal(partial_eval(quote(base::paste(x, "a")), lf), expr(paste(x, "a"))) - # hack to avoid check complaining about not declared imports - expect_equal(partial_eval(rlang::parse_expr("stringr::str_to_lower(x)"), lf), expr(str_to_lower(x))) - expect_equal(partial_eval(rlang::parse_expr("lubridate::today()"), lf), expr(today())) -}) - test_that("use quosure environment for unevaluted formulas", { lf <- lazy_frame(x = 1, y = 2) diff --git a/tests/testthat/test-translate-sql-helpers.R b/tests/testthat/test-translate-sql-helpers.R index 2657d5f34..34673286c 100644 --- a/tests/testthat/test-translate-sql-helpers.R +++ b/tests/testthat/test-translate-sql-helpers.R @@ -63,7 +63,7 @@ test_that("can translate infix expression without parentheses", { test_that("unary minus works with expressions", { local_con(simulate_dbi()) expect_equal(test_translate_sql(-!!expr(x+2)), sql("-(`x` + 2.0)")) - expect_equal(test_translate_sql(--x), sql("-(-`x`)")) + expect_equal(test_translate_sql(--x), sql("--`x`")) }) test_that("pad = FALSE works", { diff --git a/tests/testthat/test-translate-sql.R b/tests/testthat/test-translate-sql.R index 3a1d8873b..fb7ace800 100644 --- a/tests/testthat/test-translate-sql.R +++ b/tests/testthat/test-translate-sql.R @@ -14,9 +14,19 @@ test_that("namespace calls are translated", { expect_equal(test_translate_sql(dplyr::n(), window = FALSE), sql("COUNT(*)")) expect_equal(test_translate_sql(base::ceiling(x)), sql("CEIL(`x`)")) - expect_snapshot(error = TRUE, test_translate_sql(NOSUCHPACKAGE::foo())) - expect_snapshot(error = TRUE, test_translate_sql(dbplyr::NOSUCHFUNCTION())) - expect_snapshot(error = TRUE, test_translate_sql(base::abbreviate(x))) + expect_snapshot(error = TRUE, { + test_translate_sql(NOSUCHPACKAGE::foo()) + test_translate_sql(dbplyr::NOSUCHFUNCTION()) + test_translate_sql(base::abbreviate(x)) + }) + + lz <- lazy_frame(x = 1) + # Also test full pipeline to ensure that they make it through partial_eval + expect_snapshot(error = TRUE, { + lz %>% mutate(x = NOSUCHPACKAGE::foo()) + lz %>% mutate(x = dbplyr::NOSUCHFUNCTION()) + lz %>% mutate(x = base::abbreviate(x)) + }) }) test_that("Wrong number of arguments raises error", { diff --git a/tests/testthat/test-verb-compute.R b/tests/testthat/test-verb-compute.R index 4dc1e2282..21ac3ac49 100644 --- a/tests/testthat/test-verb-compute.R +++ b/tests/testthat/test-verb-compute.R @@ -104,10 +104,11 @@ test_that("compute can handle schema", { ) # errors because name already exists - expect_snapshot(error = TRUE, { - df %>% - compute(name = in_schema("main", "db1"), temporary = FALSE) - }) + expect_snapshot( + df %>% compute(name = in_schema("main", "db1"), temporary = FALSE), + transform = snap_transform_dbi, + error = TRUE + ) }) test_that("collect() handles DBI error", { diff --git a/tests/testthat/test-verb-set-ops.R b/tests/testthat/test-verb-set-ops.R index 17909a199..a8db07ca6 100644 --- a/tests/testthat/test-verb-set-ops.R +++ b/tests/testthat/test-verb-set-ops.R @@ -110,14 +110,17 @@ test_that("SQLite warns if set op attempted when tbl has LIMIT", { test_that("other backends can combine with a limit", { df <- tibble(x = 1:2) - # sqlite only allows limit at top level - tbls_full <- test_load(df, ignore = "sqlite") - tbls_head <- lapply(test_load(df, ignore = "sqlite"), head, n = 1) + ignore <- c( + "sqlite", # only allows limit at top level + "mssql" # unusual execution order gives unintuitive result + ) + tbls_full <- test_load(df, ignore = ignore) + tbls_head <- lapply(test_load(df, ignore = ignore), head, n = 1) tbls_full %>% purrr::map2(tbls_head, union) %>% - expect_equal_tbls() + expect_equal_tbls(head(df, 1)) tbls_full %>% purrr::map2(tbls_head, union_all) %>% - expect_equal_tbls() + expect_equal_tbls(head(df, 1)) })