Skip to content

Commit

Permalink
make sure all possible bernoulli models are detected (#816)
Browse files Browse the repository at this point in the history
* make sure all possible bernoulli models are detected

* tests

* fix test issues

* suppress warnings

* typo
  • Loading branch information
strengejacke authored Oct 5, 2023
1 parent ed0c2a8 commit 3d07438
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 58 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: insight
Title: Easy Access to Model Information for Various Model Objects
Version: 0.19.5.11
Version: 0.19.5.12
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
* `clean_parameters()` did not return the `"clean_parameters"` class attributes
for some object. This caused issued in upstream packages.

* Fixed issue in `model_info()`, which did not correctly detect "Bernoulli"
property for some models classes (like `glmmTMB` or `glmerMod`).

# insight 0.19.5

## Bug fixes
Expand Down
13 changes: 12 additions & 1 deletion R/utils_model_info.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,18 @@

is_bernoulli <- FALSE

if (binom_fam && inherits(x, "glm") && !neg_bin_fam && !poisson_fam) {
# These models can be logistic regresion models with bernoulli outcome
potential_bernoulli <- inherits(
x,
c(
"glm", "gee", "glmmTMB", "glmerMod", "merMod", "stanreg", "MixMod",
"logistf", "bigglm", "brglm", "feglm", "geeglm", "glmm", "glmmadmb",
"glmmPQL", "glmrob", "glmRob", "logitmfx", "logitor", "logitr",
"mixed", "mixor", "svyglm"
)
)

if (binom_fam && potential_bernoulli && !neg_bin_fam && !poisson_fam) {
if (inherits(x, "gee")) {
resp <- .safe(get_response(x))
} else {
Expand Down
32 changes: 13 additions & 19 deletions tests/testthat/test-bigglm.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
skip_if_not_installed("glmmTMB")
suppressWarnings(skip_if_not_installed("glmmTMB"))
skip_if_not_installed("biglm")

data(Salamanders, package = "glmmTMB")
Expand Down Expand Up @@ -43,26 +43,20 @@ test_that("find_response", {
})

test_that("get_response", {
expect_equal(get_response(m1), Salamanders$count)
expect_identical(get_response(m1), Salamanders$count)
})

test_that("get_predictors", {
expect_equal(
colnames(get_predictors(m1)),
c("mined", "cover", "sample")
)
expect_named(get_predictors(m1), c("mined", "cover", "sample"))
})

test_that("link_inverse", {
expect_equal(link_inverse(m1)(0.2), exp(0.2), tolerance = 1e-5)
})

test_that("get_data", {
expect_equal(nrow(get_data(m1)), 644)
expect_equal(
colnames(get_data(m1)),
c("count", "mined", "cover", "sample")
)
expect_identical(nrow(get_data(m1)), 644L)
expect_named(get_data(m1), c("count", "mined", "cover", "sample"))
})

test_that("find_formula", {
Expand All @@ -75,36 +69,36 @@ test_that("find_formula", {
})

test_that("find_variables", {
expect_equal(
expect_identical(
find_variables(m1),
list(
response = "count",
conditional = c("mined", "cover", "sample")
)
)
expect_equal(
expect_identical(
find_variables(m1, flatten = TRUE),
c("count", "mined", "cover", "sample")
)
})

test_that("n_obs", {
expect_equal(n_obs(m1), 644)
expect_identical(n_obs(m1), 644)
})

test_that("linkfun", {
expect_false(is.null(link_function(m1)))
})

test_that("find_parameters", {
expect_equal(
expect_identical(
find_parameters(m1),
list(
conditional = c("(Intercept)", "minedno", "log(cover)", "sample")
)
)
expect_equal(nrow(get_parameters(m1)), 4)
expect_equal(
expect_identical(nrow(get_parameters(m1)), 4L)
expect_identical(
get_parameters(m1)$Parameter,
c("(Intercept)", "minedno", "log(cover)", "sample")
)
Expand All @@ -115,7 +109,7 @@ test_that("is_multivariate", {
})

test_that("find_terms", {
expect_equal(
expect_identical(
find_terms(m1),
list(
response = "count",
Expand All @@ -125,7 +119,7 @@ test_that("find_terms", {
})

test_that("find_algorithm", {
expect_equal(find_algorithm(m1), list(algorithm = "ML"))
expect_identical(find_algorithm(m1), list(algorithm = "ML"))
})

test_that("find_statistic", {
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-gee.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ junk <- capture.output({
test_that("model_info", {
expect_true(model_info(m1)$is_linear)
expect_true(model_info(dep_gee)$is_binomial)
expect_true(model_info(dep_gee)$is_bernoulli)
})

test_that("find_predictors", {
Expand Down
41 changes: 20 additions & 21 deletions tests/testthat/test-geeglm.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
skip_if_not_installed("geepack")

data(warpbreaks)
m1 <-
geepack::geeglm(
breaks ~ tension,
id = wool,
data = warpbreaks,
family = poisson,
corstr = "ar1"
)
m1 <- geepack::geeglm(
breaks ~ tension,
id = wool,
data = warpbreaks,
family = poisson,
corstr = "ar1"
)

test_that("model_info", {
expect_true(model_info(m1)$is_count)
Expand All @@ -34,32 +33,32 @@ test_that("find_response", {

test_that("get_varcov", {
out <- get_varcov(m1)
expect_equal(colnames(out), names(coef(m1)))
expect_identical(colnames(out), names(coef(m1)))
})

test_that("get_response", {
expect_equal(get_response(m1), warpbreaks$breaks)
expect_identical(get_response(m1), warpbreaks$breaks)
})

test_that("find_random", {
expect_equal(find_random(m1), list(random = "wool"))
expect_identical(find_random(m1), list(random = "wool"))
})

test_that("get_random", {
expect_equal(get_random(m1), warpbreaks[, "wool", drop = FALSE], ignore_attr = TRUE)
})

test_that("get_predictors", {
expect_equal(get_predictors(m1), warpbreaks[, "tension", drop = FALSE])
expect_equal(get_predictors(m1), warpbreaks[, "tension", drop = FALSE], ignore_attr = TRUE)
})

test_that("link_inverse", {
expect_equal(link_inverse(m1)(0.2), exp(0.2), tolerance = 1e-5)
})

test_that("get_data", {
expect_equal(nrow(get_data(m1)), 54)
expect_equal(colnames(get_data(m1)), c("breaks", "tension", "wool"))
expect_identical(nrow(get_data(m1)), 54L)
expect_named(get_data(m1), c("breaks", "tension", "wool"))
})

test_that("find_formula", {
Expand All @@ -75,37 +74,37 @@ test_that("find_formula", {
})

test_that("find_terms", {
expect_equal(
expect_identical(
find_terms(m1),
list(
response = "breaks",
conditional = "tension",
random = "wool"
)
)
expect_equal(
expect_identical(
find_terms(m1, flatten = TRUE),
c("breaks", "tension", "wool")
)
})

test_that("n_obs", {
expect_equal(n_obs(m1), 54)
expect_identical(n_obs(m1), 54L)
})

test_that("linkfun", {
expect_false(is.null(link_function(m1)))
})

test_that("find_parameters", {
expect_equal(
expect_identical(
find_parameters(m1),
list(conditional = c(
"(Intercept)", "tensionM", "tensionH"
))
)
expect_equal(nrow(get_parameters(m1)), 3)
expect_equal(
expect_identical(nrow(get_parameters(m1)), 3L)
expect_identical(
get_parameters(m1)$Parameter,
c("(Intercept)", "tensionM", "tensionH")
)
Expand All @@ -116,7 +115,7 @@ test_that("is_multivariate", {
})

test_that("find_algorithm", {
expect_equal(find_algorithm(m1), list(algorithm = "ML"))
expect_identical(find_algorithm(m1), list(algorithm = "ML"))
})

test_that("find_statistic", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-get_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ test_that("MASS::rlm", {
# =========================================================================

test_that("get_predicted - lmerMod", {
skip_if_not_installed("glmmTMB")
suppressWarnings(skip_if_not_installed("glmmTMB"))
skip_if_not_installed("lme4")
skip_if_not_installed("merTools")
skip_if_not_installed("rstanarm")
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-logistf.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ m1 <- logistf::logistf(case ~ age + oc + vic + vicl + vis + dia, data = sex2)

test_that("model_info", {
expect_true(model_info(m1)$is_binomial)
expect_true(model_info(m1)$is_bernoulli)
expect_true(model_info(m1)$is_logit)
expect_false(model_info(m1)$is_linear)
})
Expand Down
83 changes: 70 additions & 13 deletions tests/testthat/test-model_info.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,85 @@
skip_if_not_installed("BayesFactor")
test_that("glm bernoulli", {
data(mtcars)
model <- glm(vs ~ disp, data = mtcars, family = binomial())
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)
})

test_that("geeglm bernoulli", {
skip_if_not_installed("geepack")
data(mtcars)
model <- geepack::geeglm(
vs ~ disp,
data = mtcars,
id = cyl,
family = binomial()
)
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)
})

test_that("bigglm bernoulli", {
skip_if_not_installed("bigglm")
data(mtcars)
model <- biglm::bigglm(
vs ~ disp,
family = binomial(),
data = mtcars
)
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)
})

test_that("glmmTMB bernoulli", {
skip_if_not_installed("glmmTMB")
data(mtcars)
model <- glmmTMB::glmmTMB(vs ~ disp, data = mtcars, family = binomial())
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)

model <- glmmTMB::glmmTMB(vs ~ disp + (1 | cyl), data = mtcars, family = binomial())
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)
})

test_that("glmer bernoulli", {
skip_if_not_installed("lme4")
data(mtcars)
model <- lme4::glmer(vs ~ disp + (1 | cyl), data = mtcars, family = binomial())
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_true(mi$is_bernoulli)
})

model <- BayesFactor::proportionBF(15, 25, p = 0.5)
mi <- insight::model_info(model)
test_that("model_info-BF-proptest", {
skip_if_not_installed("BayesFactor")
model <- BayesFactor::proportionBF(15, 25, p = 0.5)
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_false(mi$is_linear)
})

model <- prop.test(15, 25, p = 0.5)
mi <- insight::model_info(model)
test_that("model_info-BF-proptest", {
test_that("model_info-proptest", {
model <- prop.test(15, 25, p = 0.5)
mi <- model_info(model)
expect_true(mi$is_binomial)
expect_false(mi$is_linear)
expect_false(mi$is_correlation)
})

skip_if_not_installed("tweedie")

d <- data.frame(x = 1:20, y = rgamma(20, shape = 5))
# Fit a poisson generalized linear model with identity link
model <- glm(y ~ x, data = d, family = statmod::tweedie(var.power = 1, link.power = 1))
mi <- insight::model_info(model)
test_that("model_info-tweedie", {
skip_if_not_installed("tweedie")
skip_if_not_installed("statmod")
d <- data.frame(x = 1:20, y = rgamma(20, shape = 5))
# Fit a poisson generalized linear model with identity link
model <- glm(y ~ x, data = d, family = statmod::tweedie(var.power = 1, link.power = 1))
mi <- model_info(model)
expect_true(mi$is_tweedie)
expect_false(mi$is_poisson)
expect_equal(mi$family, "Tweedie")
expect_identical(mi$family, "Tweedie")
})
4 changes: 2 additions & 2 deletions tests/testthat/test-spatial.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
skip_if_offline()
skip_if_not_installed("glmmTMB")
skip_if_not_installed("geoR")
suppressWarnings(skip_if_not_installed("glmmTMB"))
suppressWarnings(skip_if_not_installed("geoR"))
skip_if_not_installed("TMB")

data(ca20, package = "geoR")
Expand Down

0 comments on commit 3d07438

Please sign in to comment.