Skip to content

Commit

Permalink
assert equal fits when weights are used
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed May 13, 2024
1 parent ca0e86b commit c47976f
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions tests/testthat/test-aorsf_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ test_that("regression model object", {

skip_if_not_installed("aorsf", "0.1.3")

set.seed(321)
wts <- sample(0:5, size = nrow(mtcars_orsf), replace = TRUE)

set.seed(1234)
aorsf_regr_fit <- aorsf::orsf(
# everyone's favorite
Expand All @@ -20,14 +23,27 @@ test_that("regression model object", {
n_thread = 1
)

set.seed(1234)
aorsf_regr_fit_wtd <- aorsf::orsf_update(aorsf_regr_fit, weights = wts)

# formula method
regr_spec <- rand_forest(trees = 10) %>%
set_engine("aorsf") %>%
set_mode("regression")

set.seed(1234)
expect_no_condition(
bonsai_regr_fit <- fit(regr_spec, data = mtcars_orsf, mpg ~ .)
bonsai_regr_fit <- fit(regr_spec,
data = mtcars_orsf,
formula = mpg ~ .)
)

set.seed(1234)
expect_no_condition(
bonsai_regr_fit_wtd <- fit(regr_spec,
data = mtcars_orsf,
formula = mpg ~ .,
case_weights = importance_weights(wts))
)

expect_equal(
Expand All @@ -36,27 +52,50 @@ test_that("regression model object", {
ignore_formula_env = TRUE
)

expect_equal(
bonsai_regr_fit_wtd$fit,
aorsf_regr_fit_wtd,
ignore_formula_env = TRUE
)

})

test_that("classification model object", {

skip_if_not_installed("aorsf", "0.1.3")

set.seed(321)
wts <- sample(0:5, size = nrow(mtcars_orsf), replace = TRUE)

set.seed(1234)
aorsf_clsf_fit <- aorsf::orsf(
data = mtcars_orsf, formula = vs ~ .,
n_tree = 10,
n_thread = 1
)

set.seed(1234)
aorsf_clsf_fit_wtd <- aorsf::orsf_update(aorsf_clsf_fit, weights = wts)


# formula method
clsf_spec <- rand_forest(trees = 10) %>%
set_engine("aorsf") %>%
set_mode("classification")

set.seed(1234)
expect_no_condition(
bonsai_clsf_fit <- fit(clsf_spec, data = mtcars_orsf, vs ~ .)
bonsai_clsf_fit <- fit(clsf_spec,
data = mtcars_orsf,
formula = vs ~ .)
)

set.seed(1234)
expect_no_condition(
bonsai_clsf_fit_wtd <- fit(clsf_spec,
data = mtcars_orsf,
formula = vs ~ .,
case_weights = importance_weights(wts))
)

expect_equal(
Expand All @@ -65,6 +104,12 @@ test_that("classification model object", {
ignore_formula_env = TRUE
)

expect_equal(
bonsai_clsf_fit_wtd$fit,
aorsf_clsf_fit_wtd,
ignore_formula_env = TRUE
)

})

test_that("regression predictions", {
Expand Down

0 comments on commit c47976f

Please sign in to comment.