Skip to content

Commit

Permalink
use standard number of spaces for tab (#168)
Browse files Browse the repository at this point in the history
* set `NumSpacesForTab` to usual value

* run `styler::style_pkg()`
  • Loading branch information
simonpcouch authored Oct 15, 2024
1 parent 6c2bb00 commit 7ae74bb
Show file tree
Hide file tree
Showing 37 changed files with 2,175 additions and 2,149 deletions.
17 changes: 8 additions & 9 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,14 @@ rank_plot <- function(object, rank_metric = NULL, metric = NULL,
has_std_error <- !all(is.na(res$std_err))

p <-
switch(
type,
class =
ggplot(res, aes(x = rank, y = mean, col = model)) +
geom_point(aes(shape = preprocessor)),
wflow_id =
ggplot(res, aes(x = rank, y = mean, col = wflow_id)) +
geom_point()
)
switch(type,
class =
ggplot(res, aes(x = rank, y = mean, col = model)) +
geom_point(aes(shape = preprocessor)),
wflow_id =
ggplot(res, aes(x = rank, y = mean, col = wflow_id)) +
geom_point()
)

if (num_metrics > 1) {
res$.metric <- factor(as.character(res$.metric), levels = metrics$metric)
Expand Down
20 changes: 11 additions & 9 deletions R/checks.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
check_wf_set <- function(x, arg = caller_arg(x), call = caller_env()) {
if (!inherits(x, "workflow_set")) {
cli::cli_abort(
"{arg} must be a workflow set, not {obj_type_friendly(x)}.",
call = call
"{arg} must be a workflow set, not {obj_type_friendly(x)}.",
call = call
)
}

Expand Down Expand Up @@ -78,8 +78,10 @@ check_options <- function(model, id, global, action = "fail") {
}

check_tune_args <- function(x) {
arg_names <- c("resamples", "param_info", "grid", "metrics", "control",
"iter", "objective", "initial", "eval_time")
arg_names <- c(
"resamples", "param_info", "grid", "metrics", "control",
"iter", "objective", "initial", "eval_time"
)
bad_args <- setdiff(x, arg_names)
if (length(bad_args) > 0) {
cli::cli_abort(
Expand Down Expand Up @@ -129,7 +131,7 @@ check_names <- function(x) {
xtab <- table(nms)
if (any(xtab > 1)) {
cli::cli_abort(
"The workflow names should be unique: {.val {names(xtab)[xtab > 1]}}."
"The workflow names should be unique: {.val {names(xtab)[xtab > 1]}}."
)
}
invisible(NULL)
Expand All @@ -140,10 +142,10 @@ check_for_workflow <- function(x) {
if (any(no_wflow)) {
bad <- names(no_wflow)[no_wflow]
cli::cli_abort(
c(
"The objects {.val {bad}} do not have workflows.",
"i" = "Use the control option {.code save_workflow} and re-run."
)
c(
"The objects {.val {bad}} do not have workflows.",
"i" = "Use the control option {.code save_workflow} and re-run."
)
)
}
invisible(NULL)
Expand Down
2 changes: 1 addition & 1 deletion R/collect.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ collect_notes.workflow_set <- function(x, ...) {
res
}

#'
#'
#' @export
#' @rdname collect_metrics.workflow_set
collect_extracts.workflow_set <- function(x, ...) {
Expand Down
4 changes: 2 additions & 2 deletions R/comments.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ comment_reset <- function(x, id) {
#' @export
#' @rdname comment_add
comment_print <- function(x, id = NULL, ...) {
check_wf_set(x)
if (is.null(id)) {
check_wf_set(x)
if (is.null(id)) {
id <- x$wflow_id
}

Expand Down
1 change: 0 additions & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,3 @@ NULL
#'
#' chi_features_set
NULL

18 changes: 9 additions & 9 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ check_empty_dots <- function(...) {
cli::cli_abort("{.arg estimated} should be a named argument.")
}
if (length(opts) > 0) {
cli::cli_abort("{.arg ...} are not used in this function.")
cli::cli_abort("{.arg ...} are not used in this function.")
}
invisible(NULL)
}
Expand Down Expand Up @@ -142,22 +142,22 @@ extract_preprocessor.workflow_set <- function(x, id, ...) {
#' @export
#' @rdname extract_workflow_set_result
extract_parameter_set_dials.workflow_set <- function(x, id, ...) {
y <- filter_id(x, id)
y <- filter_id(x, id)

if ("param_info" %in% names(y$option[[1]])) {
return(y$option[[1]][["param_info"]])
}
if ("param_info" %in% names(y$option[[1]])) {
return(y$option[[1]][["param_info"]])
}

extract_parameter_set_dials(y$info[[1]]$workflow[[1]])
extract_parameter_set_dials(y$info[[1]]$workflow[[1]])
}

#' @export
#' @rdname extract_workflow_set_result
extract_parameter_dials.workflow_set <- function(x, id, parameter, ...) {
res <- extract_parameter_set_dials(x, id)
res <- extract_parameter_dials(res, parameter)
res <- extract_parameter_set_dials(x, id)
res <- extract_parameter_dials(res, parameter)

res
res
}

# ------------------------------------------------------------------------------
Expand Down
32 changes: 17 additions & 15 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,23 @@
#' @method fit workflow_set
#' @export
fit.workflow_set <- function(object, ...) {
msg <- "`fit()` is not well-defined for workflow sets."
msg <- "`fit()` is not well-defined for workflow sets."

# supply a different message depending on whether the
# workflow set has been (attempted to have been) fitted or not
if (!all(purrr::map_lgl(object$result, ~ identical(.x, list())))) {
# if fitted:
msg <-
c(msg,
"i" = "Please see {.help [{.fun fit_best}](workflowsets::fit_best.workflow_set)}.")
} else {
# if not fitted:
msg <-
c(msg,
"i" = "Please see {.help [{.fun workflow_map}](workflowsets::workflow_map)}.")
}
# supply a different message depending on whether the
# workflow set has been (attempted to have been) fitted or not
if (!all(purrr::map_lgl(object$result, ~ identical(.x, list())))) {
# if fitted:
msg <-
c(msg,
"i" = "Please see {.help [{.fun fit_best}](workflowsets::fit_best.workflow_set)}."
)
} else {
# if not fitted:
msg <-
c(msg,
"i" = "Please see {.help [{.fun workflow_map}](workflowsets::workflow_map)}."
)
}

cli::cli_abort(msg)
cli::cli_abort(msg)
}
68 changes: 34 additions & 34 deletions R/fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,25 @@ tune::fit_best
#' library(rsample)
#'
#' data(Chicago)
#' Chicago <- Chicago[1:1195,]
#' Chicago <- Chicago[1:1195, ]
#'
#' time_val_split <-
#' sliding_period(
#' Chicago,
#' date,
#' "month",
#' lookback = 38,
#' assess_stop = 1
#' )
#' sliding_period(
#' Chicago,
#' date,
#' "month",
#' lookback = 38,
#' assess_stop = 1
#' )
#'
#' chi_features_set
#'
#' chi_features_res_new <-
#' chi_features_set %>%
#' # note: must set `save_workflow = TRUE` to use `fit_best()`
#' option_add(control = control_grid(save_workflow = TRUE)) %>%
#' # evaluate with resamples
#' workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE)
#' chi_features_set %>%
#' # note: must set `save_workflow = TRUE` to use `fit_best()`
#' option_add(control = control_grid(save_workflow = TRUE)) %>%
#' # evaluate with resamples
#' workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE)
#'
#' chi_features_res_new
#'
Expand All @@ -73,33 +73,33 @@ tune::fit_best
#' @name fit_best.workflow_set
#' @export
fit_best.workflow_set <- function(x, metric = NULL, eval_time = NULL, ...) {
check_string(metric, allow_null = TRUE)
result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]])
met_set <- tune::.get_tune_metrics(result_1)
check_string(metric, allow_null = TRUE)
result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]])
met_set <- tune::.get_tune_metrics(result_1)

if (is.null(metric)) {
metric <- .get_tune_metric_names(result_1)[1]
} else {
tune::check_metric_in_tune_results(tibble::as_tibble(met_set), metric)
}
if (is.null(metric)) {
metric <- .get_tune_metric_names(result_1)[1]
} else {
tune::check_metric_in_tune_results(tibble::as_tibble(met_set), metric)
}

if (is.null(eval_time) & is_dyn(met_set, metric)) {
eval_time <- tune::.get_tune_eval_times(result_1)[1]
}
if (is.null(eval_time) & is_dyn(met_set, metric)) {
eval_time <- tune::.get_tune_eval_times(result_1)[1]
}

rankings <-
rank_results(
x,
rank_metric = metric,
select_best = TRUE,
eval_time = eval_time
)
rankings <-
rank_results(
x,
rank_metric = metric,
select_best = TRUE,
eval_time = eval_time
)

tune_res <- extract_workflow_set_result(x, id = rankings$wflow_id[1])
tune_res <- extract_workflow_set_result(x, id = rankings$wflow_id[1])

best_params <- select_best(tune_res, metric = metric, eval_time = eval_time)
best_params <- select_best(tune_res, metric = metric, eval_time = eval_time)

fit_best(tune_res, parameters = best_params, ...)
fit_best(tune_res, parameters = best_params, ...)
}

# from unexported
Expand Down
1 change: 0 additions & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

make_workflow <- function(x, y) {
exp_classes <- c("formula", "recipe", "workflow_variables")
w <-
Expand Down
28 changes: 14 additions & 14 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
#' @method predict workflow_set
#' @export
predict.workflow_set <- function(object, ...) {
cli::cli_abort(c(
"`predict()` is not well-defined for workflow sets.",
"i" = "To predict with the optimal model configuration from a workflow \\
set, ensure that the workflow set was fitted with the \\
{.help [control option](workflowsets::option_add)} \\
{.help [{.code save_workflow = TRUE}](tune::control_grid)}, run \\
{.help [{.fun fit_best}](tune::fit_best)}, and then predict using \\
{.help [{.fun predict}](workflows::predict.workflow)} on its output.",
"i" = "To collect predictions from a workflow set, ensure that \\
the workflow set was fitted with the \\
{.help [control option](workflowsets::option_add)} \\
{.help [{.code save_pred = TRUE}](tune::control_grid)} and run \\
{.help [{.fun collect_predictions}](tune::collect_predictions)}."
))
cli::cli_abort(c(
"`predict()` is not well-defined for workflow sets.",
"i" = "To predict with the optimal model configuration from a workflow \\
set, ensure that the workflow set was fitted with the \\
{.help [control option](workflowsets::option_add)} \\
{.help [{.code save_workflow = TRUE}](tune::control_grid)}, run \\
{.help [{.fun fit_best}](tune::fit_best)}, and then predict using \\
{.help [{.fun predict}](workflows::predict.workflow)} on its output.",
"i" = "To collect predictions from a workflow set, ensure that \\
the workflow set was fitted with the \\
{.help [control option](workflowsets::option_add)} \\
{.help [{.code save_pred = TRUE}](tune::control_grid)} and run \\
{.help [{.fun collect_predictions}](tune::collect_predictions)}."
))
}
6 changes: 4 additions & 2 deletions R/rank_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best =
eval_time <- tune::choose_eval_time(result_1, metric, eval_time = eval_time)

results <- collect_metrics(x) %>%
dplyr::select(wflow_id, .config, .metric, mean, std_err, n,
dplyr::any_of(".eval_time")) %>%
dplyr::select(
wflow_id, .config, .metric, mean, std_err, n,
dplyr::any_of(".eval_time")
) %>%
dplyr::full_join(wflow_info, by = "wflow_id") %>%
dplyr::select(-comment, -workflow)

Expand Down
46 changes: 23 additions & 23 deletions R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,35 @@
#' extract_workflow(new_set, id = "none_cart")
#' @export
update_workflow_model <- function(x, id, spec, formula = NULL) {
check_wf_set(x)
check_string(id)
check_formula(formula, allow_null = TRUE)
check_wf_set(x)
check_string(id)
check_formula(formula, allow_null = TRUE)

wflow <- extract_workflow(x, id = id)
wflow <- workflows::update_model(wflow, spec = spec, formula = formula)
id_ind <- which(x$wflow_id == id)
x$info[[id_ind]]$workflow[[1]] <- wflow
# Remove any existing results since they are now inconsistent
if (!identical(x$result[[id_ind]], list())) {
x$result[[id_ind]] <- list()
}
x
wflow <- extract_workflow(x, id = id)
wflow <- workflows::update_model(wflow, spec = spec, formula = formula)
id_ind <- which(x$wflow_id == id)
x$info[[id_ind]]$workflow[[1]] <- wflow
# Remove any existing results since they are now inconsistent
if (!identical(x$result[[id_ind]], list())) {
x$result[[id_ind]] <- list()
}
x
}


#' @rdname update_workflow_model
#' @export
update_workflow_recipe <- function(x, id, recipe, blueprint = NULL) {
check_wf_set(x)
check_string(id)
check_wf_set(x)
check_string(id)

wflow <- extract_workflow(x, id = id)
wflow <- workflows::update_recipe(wflow, recipe = recipe, blueprint = blueprint)
id_ind <- which(x$wflow_id == id)
x$info[[id_ind]]$workflow[[1]] <- wflow
# Remove any existing results since they are now inconsistent
if (!identical(x$result[[id_ind]], list())) {
x$result[[id_ind]] <- list()
}
x
wflow <- extract_workflow(x, id = id)
wflow <- workflows::update_recipe(wflow, recipe = recipe, blueprint = blueprint)
id_ind <- which(x$wflow_id == id)
x$info[[id_ind]]$workflow[[1]] <- wflow
# Remove any existing results since they are now inconsistent
if (!identical(x$result[[id_ind]], list())) {
x$result[[id_ind]] <- list()
}
x
}
Loading

0 comments on commit 7ae74bb

Please sign in to comment.