From 7b95f2002aaf5f7b8f68568d82d210a52d34edec Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 13 Nov 2024 14:16:42 -0500 Subject: [PATCH 1/3] Changes for #956 --- R/linear_reg.R | 21 +++++++++++++++++ tests/testthat/_snaps/linear_reg.md | 36 +++++++++++++++++++++++++++++ tests/testthat/test-linear_reg.R | 30 ++++++++++++++++++++++++ 3 files changed, 87 insertions(+) diff --git a/R/linear_reg.R b/R/linear_reg.R index 0b7b636b4..c80f7c14b 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -73,6 +73,27 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) } + + # ------------------------------------------------------------------------------ + # We want to avoid folks passing in a poisson family instead of using + # poisson_reg(). It's hard to detect this. + + is_fam <- names(x$eng_args) == "family" + if (any(is_fam)) { + eng_args <- rlang::eval_tidy(x$eng_args[[which(is_fam)]]) + if (is.function(eng_args)) { + eng_args <- try(eng_args(), silent = TRUE) + } + if (inherits(eng_args, "family")) { + eng_args <- eng_args$family + } + if (eng_args == "poisson") { + cli::cli_abort( + "A Poisson family was requested for {.fn linear_reg}. Please use + {.fn poisson_reg} and the engines in the {.pkg poissonreg} package.", + call = rlang::call2("linear_reg")) + } + } x } diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index f497ce3da..0ed2c9274 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -139,3 +139,39 @@ Error in `fit()`: ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. +# Poisson family (#956) + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>% + translate() + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson) %>% + translate() + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson()) %>% + translate() + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + linear_reg(penalty = 1) %>% set_engine("glmnet", family = "poisson") %>% + translate() + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + diff --git a/tests/testthat/test-linear_reg.R b/tests/testthat/test-linear_reg.R index 62567fc44..36964a66e 100644 --- a/tests/testthat/test-linear_reg.R +++ b/tests/testthat/test-linear_reg.R @@ -358,3 +358,33 @@ test_that("check_args() works", { } ) }) + + +test_that('Poisson family (#956)', { + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = poisson) %>% + translate(), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = stats::poisson) %>% + translate(), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = stats::poisson()) %>% + translate(), + error = TRUE + ) + expect_snapshot( + linear_reg(penalty = 1) %>% + set_engine("glmnet", family = "poisson") %>% + translate(), + error = TRUE + ) + + +}) From b9f45eb52bd5733b366f7a220b168d676257cfb6 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 13 Nov 2024 14:23:12 -0500 Subject: [PATCH 2/3] update news --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index e35af6bcd..692f5f89a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,8 @@ * Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083). +* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#956) + ## Bug Fixes * Make sure that parsnip does not convert ordered factor predictions to be unordered. From 72e01d1b030faf16781f4782ab1e7a6b1693aafe Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Mon, 9 Dec 2024 16:53:31 +0000 Subject: [PATCH 3/3] point to parsnip PR instead of tune issue --- NEWS.md | 2 +- tests/testthat/_snaps/linear_reg.md | 2 +- tests/testthat/test-linear_reg.R | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 692f5f89a..23a269c07 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,7 +23,7 @@ * Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083). -* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#956) +* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#1219). ## Bug Fixes diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index 0ed2c9274..1fc287ac4 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -139,7 +139,7 @@ Error in `fit()`: ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. -# Poisson family (#956) +# prevent using a Poisson family Code linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>% diff --git a/tests/testthat/test-linear_reg.R b/tests/testthat/test-linear_reg.R index 36964a66e..969cb994c 100644 --- a/tests/testthat/test-linear_reg.R +++ b/tests/testthat/test-linear_reg.R @@ -360,7 +360,7 @@ test_that("check_args() works", { }) -test_that('Poisson family (#956)', { +test_that("prevent using a Poisson family", { expect_snapshot( linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>% @@ -385,6 +385,4 @@ test_that('Poisson family (#956)', { translate(), error = TRUE ) - - })