From 236a39b2a394a1ac03b8f8c8c793f7524476c0e7 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 6 Sep 2024 10:28:10 -0700 Subject: [PATCH] happy path for sparse matrix passed to fit() --- R/fit.R | 8 +------- tests/testthat/_snaps/sparsevctrs.md | 4 ++-- tests/testthat/test-sparsevctrs.R | 1 - 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/R/fit.R b/R/fit.R index fe3d4dcea..ff6fb71ff 100644 --- a/R/fit.R +++ b/R/fit.R @@ -138,13 +138,7 @@ fit.model_spec <- } if (is_sparse_matrix(data)) { - outcome_names <- all.names(rlang::f_lhs(formula)) - outcome_ind <- match(outcome_names, colnames(data)) - - y <- data[, outcome_ind] - x <- data[, -outcome_ind, drop = TRUE] - - return(fit_xy(object, x, y, case_weights, control, ...)) + data <- sparsevctrs::coerce_to_sparse_tibble(data) } dots <- quos(...) diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 728a9706e..797dd2285 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -11,8 +11,8 @@ Code lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) Condition - Error in `fit_xy()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse tibble can be passed to `fit_xy() diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 1ce77da3f..067498bf1 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -39,7 +39,6 @@ test_that("sparse matrix can be passed to `fit()", { set_engine("lm") expect_snapshot( - error = TRUE, lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) ) })