Skip to content

Commit

Permalink
allow selection of class to summarize
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed May 1, 2024
1 parent 22c1360 commit 6ec32d8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
30 changes: 30 additions & 0 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -732,11 +732,13 @@ ObliqueForest <- R6::R6Class(
pred_horizon = NULL,
pred_type = NULL,
importance_type = NULL,
class = NULL,
verbose_progress = FALSE){

# check incoming values if they were specified.
private$check_n_variables(n_variables)
private$check_verbose_progress(verbose_progress)
private$check_class(class)

if(!is.null(pred_horizon)){
private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)
Expand Down Expand Up @@ -838,6 +840,13 @@ ObliqueForest <- R6::R6Class(

if(self$tree_type == 'classification'){
new_order <- insert_vals(new_order, 2, 'class')
if(!is.null(class)){
.class <- class # prevents mix-up with class in dt
pd_output <- pd_output[class == .class]
} else {
# put the highest level class on top
pd_output <- pd_output[order(-class)]
}
}

setcolorder(pd_output, new_order)
Expand Down Expand Up @@ -2153,6 +2162,27 @@ ObliqueForest <- R6::R6Class(

},

check_class = function(class = NULL){

if(!is.null(class)){

check_arg_is(arg_value = class,
arg_name = "class",
expected_class = "character")

check_arg_length(arg_value = class,
arg_name = "class",
expected_length = 1L)

check_arg_is_valid(arg_value = class,
arg_name = "class",
valid_options = self$class_levels)

}


},

# runs checks and sets defaults where needed.
# data is NULL when we are creating a new forest,
# but may be non-NULL if we update an existing one
Expand Down
9 changes: 9 additions & 0 deletions R/orsf_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,20 @@
#'
#' orsf_summarize_uni(object, n_variables = 2, importance = 'negate')
#'
#' # for multi-category fits, you can specify which class
#' # you want to summarize:
#'
#' fit = orsf(species ~ ., data = penguins_orsf, n_tree = 25)
#' orsf_summarize_uni(fit, class = "Adelie", n_variables = 1)
#' orsf_summarize_uni(fit, class = "Gentoo", n_variables = 1)
#'
#'
orsf_summarize_uni <- function(object,
n_variables = NULL,
pred_horizon = NULL,
pred_type = NULL,
importance = NULL,
class = NULL,
verbose_progress = FALSE,
...){

Expand All @@ -72,6 +80,7 @@ orsf_summarize_uni <- function(object,
pred_horizon = pred_horizon,
pred_type = pred_type,
importance_type = importance,
class = class,
verbose_progress = verbose_progress)

}
Expand Down

0 comments on commit 6ec32d8

Please sign in to comment.