From 91ba6d1036de5d6618e754a1e4099244e51eae50 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Sep 2024 12:47:29 -0700 Subject: [PATCH 1/6] test that sparse matrices doesn't work with fit() --- R/fit.R | 7 +++++++ R/sparsevctrs.R | 6 +++++- tests/testthat/_snaps/sparsevctrs.md | 9 +++++++++ tests/testthat/test-sparsevctrs.R | 15 +++++++++++++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/R/fit.R b/R/fit.R index 7be77c3ca..67006aebc 100644 --- a/R/fit.R +++ b/R/fit.R @@ -174,6 +174,13 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts + if (is_sparse_matrix(data)) { + cli::cli_abort(c( + x = "Sparse matrices cannot be used with {.fn fit}.", + i = "Please use {.fn fit_xy} interface instead." + )) + } + data <- materialize_sparse_tibble(data, object, "data") fit_interface <- diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 5fe3633ae..4a01e0218 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,5 +1,5 @@ to_sparse_data_frame <- function(x, object) { - if (methods::is(x, "sparseMatrix")) { + if (is_sparse_matrix(x)) { if (allow_sparse(object)) { x <- sparsevctrs::coerce_to_sparse_data_frame(x) } else { @@ -21,6 +21,10 @@ is_sparse_tibble <- function(x) { any(vapply(x, sparsevctrs::is_sparse_vector, logical(1))) } +is_sparse_matrix <- function(x) { + methods::is(x, "sparseMatrix") +} + materialize_sparse_tibble <- function(x, object, input) { if (is_sparse_tibble(x) && (!allow_sparse(object))) { if (inherits(object, "model_fit")) { diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 7eb9d3a55..1888752e3 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -22,6 +22,15 @@ Error in `to_sparse_data_frame()`: ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. +# sparse matrices can not be passed to `fit() + + Code + hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + Condition + Error in `fit()`: + x Sparse matrices cannot be used with `fit()`. + i Please use `fit_xy()` interface instead. + # sparse tibble can be passed to `predict() Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index aa452f2e3..a85f59273 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -67,6 +67,21 @@ test_that("sparse matrices can be passed to `fit_xy()", { ) }) +test_that("sparse matrices can not be passed to `fit()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_snapshot( + error = TRUE, + hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + ) +}) + test_that("sparse tibble can be passed to `predict()", { skip_if_not_installed("ranger") From 6140db3ef45417558d614a67dec7d9ea4d09b080 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Sep 2024 12:56:43 -0700 Subject: [PATCH 2/6] pass call through to_sparse_data_frame() --- R/sparsevctrs.R | 8 +++++--- tests/testthat/_snaps/sparsevctrs.md | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 4a01e0218..2f70dca83 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,4 +1,4 @@ -to_sparse_data_frame <- function(x, object) { +to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) { if (is_sparse_matrix(x)) { if (allow_sparse(object)) { x <- sparsevctrs::coerce_to_sparse_data_frame(x) @@ -8,8 +8,10 @@ to_sparse_data_frame <- function(x, object) { } cli::cli_abort( - "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that.") + "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that.", + call = call + ) } } else if (is.data.frame(x)) { x <- materialize_sparse_tibble(x, object, "x") diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 1888752e3..85b4fac63 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -19,7 +19,7 @@ Code lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) Condition - Error in `to_sparse_data_frame()`: + Error in `fit_xy()`: ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. # sparse matrices can not be passed to `fit() @@ -44,7 +44,7 @@ Code predict(lm_fit, sparse_mtcars) Condition - Error in `to_sparse_data_frame()`: + Error in `predict()`: ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. # to_sparse_data_frame() is used correctly From e9ff9970d2e54640d8ea7835f3b632f4d94a323b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Sep 2024 13:02:51 -0700 Subject: [PATCH 3/6] .code to .val for engine argument in sparse data functions --- R/sparsevctrs.R | 4 ++-- tests/testthat/_snaps/sparsevctrs.md | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 2f70dca83..73b6b6443 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -9,7 +9,7 @@ to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) { cli::cli_abort( "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that.", + engine {.val {object$engine}} doesn't accept that.", call = call ) } @@ -35,7 +35,7 @@ materialize_sparse_tibble <- function(x, object, input) { cli::cli_warn( "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that. Converting to + engine {.val {object$engine}} doesn't accept that. Converting to non-sparse." ) for (i in seq_along(ncol(x))) { diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 85b4fac63..b098a8c8c 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -4,7 +4,7 @@ lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) Condition Warning: - `data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse tibble can be passed to `fit_xy() @@ -12,7 +12,7 @@ lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) Condition Warning: - `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse matrices can be passed to `fit_xy() @@ -20,7 +20,7 @@ lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) Condition Error in `fit_xy()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. # sparse matrices can not be passed to `fit() @@ -37,7 +37,7 @@ preds <- predict(lm_fit, sparse_mtcars) Condition Warning: - `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse matrices can be passed to `predict() @@ -45,7 +45,7 @@ predict(lm_fit, sparse_mtcars) Condition Error in `predict()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. # to_sparse_data_frame() is used correctly From 31506c79cdf3c09b9e128a4f08183fadfe196e8a Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Sep 2024 16:39:30 -0700 Subject: [PATCH 4/6] switch to fit_xy() if sparse matrix is passed to fit() --- R/fit.R | 11 +++++++---- tests/testthat/_snaps/sparsevctrs.md | 17 ++++++++--------- tests/testthat/test-sparsevctrs.R | 28 ++++++++++++++++++---------- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/R/fit.R b/R/fit.R index 67006aebc..802addec1 100644 --- a/R/fit.R +++ b/R/fit.R @@ -175,10 +175,13 @@ fit.model_spec <- eval_env$weights <- wts if (is_sparse_matrix(data)) { - cli::cli_abort(c( - x = "Sparse matrices cannot be used with {.fn fit}.", - i = "Please use {.fn fit_xy} interface instead." - )) + outcome_names <- all.names(rlang::f_lhs(formula)) + outcome_ind <- match(outcome_names, colnames(data)) + + y <- data[, outcome_ind] + x <- data[, -outcome_ind, drop = TRUE] + + return(fit_xy(object, x, y, case_weights, control, ...)) } data <- materialize_sparse_tibble(data, object, "data") diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index b098a8c8c..728a9706e 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -6,6 +6,14 @@ Warning: `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. +# sparse matrix can be passed to `fit() + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Error in `fit_xy()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. + # sparse tibble can be passed to `fit_xy() Code @@ -22,15 +30,6 @@ Error in `fit_xy()`: ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. -# sparse matrices can not be passed to `fit() - - Code - hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) - Condition - Error in `fit()`: - x Sparse matrices cannot be used with `fit()`. - i Please use `fit_xy()` interface instead. - # sparse tibble can be passed to `predict() Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index a85f59273..1ce77da3f 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -21,18 +21,17 @@ test_that("sparse tibble can be passed to `fit()", { ) }) -test_that("sparse tibble can be passed to `fit_xy()", { +test_that("sparse matrix can be passed to `fit()", { skip_if_not_installed("xgboost") hotel_data <- sparse_hotel_rates() - hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) spec <- boost_tree() %>% set_mode("regression") %>% set_engine("xgboost") expect_no_error( - lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) ) spec <- linear_reg() %>% @@ -40,14 +39,16 @@ test_that("sparse tibble can be passed to `fit_xy()", { set_engine("lm") expect_snapshot( - lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + error = TRUE, + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) ) }) -test_that("sparse matrices can be passed to `fit_xy()", { +test_that("sparse tibble can be passed to `fit_xy()", { skip_if_not_installed("xgboost") hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) spec <- boost_tree() %>% set_mode("regression") %>% @@ -62,12 +63,11 @@ test_that("sparse matrices can be passed to `fit_xy()", { set_engine("lm") expect_snapshot( - lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]), - error = TRUE + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) ) }) -test_that("sparse matrices can not be passed to `fit()", { +test_that("sparse matrices can be passed to `fit_xy()", { skip_if_not_installed("xgboost") hotel_data <- sparse_hotel_rates() @@ -76,9 +76,17 @@ test_that("sparse matrices can not be passed to `fit()", { set_mode("regression") %>% set_engine("xgboost") + expect_no_error( + lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + expect_snapshot( - error = TRUE, - hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]), + error = TRUE ) }) From 0e37996124e89126b88988b2106226c5b1c7311d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Sep 2024 17:30:00 -0700 Subject: [PATCH 5/6] earlier exit for sparse matrix in fit() --- R/fit.R | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R/fit.R b/R/fit.R index 802addec1..fe3d4dcea 100644 --- a/R/fit.R +++ b/R/fit.R @@ -137,6 +137,16 @@ fit.model_spec <- cli::cli_abort(msg) } + if (is_sparse_matrix(data)) { + outcome_names <- all.names(rlang::f_lhs(formula)) + outcome_ind <- match(outcome_names, colnames(data)) + + y <- data[, outcome_ind] + x <- data[, -outcome_ind, drop = TRUE] + + return(fit_xy(object, x, y, case_weights, control, ...)) + } + dots <- quos(...) if (length(possible_engines(object)) == 0) { @@ -174,16 +184,6 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts - if (is_sparse_matrix(data)) { - outcome_names <- all.names(rlang::f_lhs(formula)) - outcome_ind <- match(outcome_names, colnames(data)) - - y <- data[, outcome_ind] - x <- data[, -outcome_ind, drop = TRUE] - - return(fit_xy(object, x, y, case_weights, control, ...)) - } - data <- materialize_sparse_tibble(data, object, "data") fit_interface <- From 236a39b2a394a1ac03b8f8c8c793f7524476c0e7 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 6 Sep 2024 10:28:10 -0700 Subject: [PATCH 6/6] happy path for sparse matrix passed to fit() --- R/fit.R | 8 +------- tests/testthat/_snaps/sparsevctrs.md | 4 ++-- tests/testthat/test-sparsevctrs.R | 1 - 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/R/fit.R b/R/fit.R index fe3d4dcea..ff6fb71ff 100644 --- a/R/fit.R +++ b/R/fit.R @@ -138,13 +138,7 @@ fit.model_spec <- } if (is_sparse_matrix(data)) { - outcome_names <- all.names(rlang::f_lhs(formula)) - outcome_ind <- match(outcome_names, colnames(data)) - - y <- data[, outcome_ind] - x <- data[, -outcome_ind, drop = TRUE] - - return(fit_xy(object, x, y, case_weights, control, ...)) + data <- sparsevctrs::coerce_to_sparse_tibble(data) } dots <- quos(...) diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 728a9706e..797dd2285 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -11,8 +11,8 @@ Code lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) Condition - Error in `fit_xy()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse tibble can be passed to `fit_xy() diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 1ce77da3f..067498bf1 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -39,7 +39,6 @@ test_that("sparse matrix can be passed to `fit()", { set_engine("lm") expect_snapshot( - error = TRUE, lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) ) })