diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 864293e9..8c832576 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -732,11 +732,13 @@ ObliqueForest <- R6::R6Class( pred_horizon = NULL, pred_type = NULL, importance_type = NULL, + class = NULL, verbose_progress = FALSE){ # check incoming values if they were specified. private$check_n_variables(n_variables) private$check_verbose_progress(verbose_progress) + private$check_class(class) if(!is.null(pred_horizon)){ private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) @@ -838,6 +840,13 @@ ObliqueForest <- R6::R6Class( if(self$tree_type == 'classification'){ new_order <- insert_vals(new_order, 2, 'class') + if(!is.null(class)){ + .class <- class # prevents mix-up with class in dt + pd_output <- pd_output[class == .class] + } else { + # put the highest level class on top + pd_output <- pd_output[order(-class)] + } } setcolorder(pd_output, new_order) @@ -2153,6 +2162,27 @@ ObliqueForest <- R6::R6Class( }, + check_class = function(class = NULL){ + + if(!is.null(class)){ + + check_arg_is(arg_value = class, + arg_name = "class", + expected_class = "character") + + check_arg_length(arg_value = class, + arg_name = "class", + expected_length = 1L) + + check_arg_is_valid(arg_value = class, + arg_name = "class", + valid_options = self$class_levels) + + } + + + }, + # runs checks and sets defaults where needed. # data is NULL when we are creating a new forest, # but may be non-NULL if we update an existing one diff --git a/R/orsf_summary.R b/R/orsf_summary.R index a2024ffb..5680d94b 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -53,12 +53,20 @@ #' #' orsf_summarize_uni(object, n_variables = 2, importance = 'negate') #' +#' # for multi-category fits, you can specify which class +#' # you want to summarize: +#' +#' fit = orsf(species ~ ., data = penguins_orsf, n_tree = 25) +#' orsf_summarize_uni(fit, class = "Adelie", n_variables = 1) +#' orsf_summarize_uni(fit, class = "Gentoo", n_variables = 1) +#' #' orsf_summarize_uni <- function(object, n_variables = NULL, pred_horizon = NULL, pred_type = NULL, importance = NULL, + class = NULL, verbose_progress = FALSE, ...){ @@ -72,6 +80,7 @@ orsf_summarize_uni <- function(object, pred_horizon = pred_horizon, pred_type = pred_type, importance_type = importance, + class = class, verbose_progress = verbose_progress) }