Skip to content

Commit cf9a823

Browse files
authored
Fixed # 1212 enabled augment() for quartile regression (#1292)
1 parent 765da5a commit cf9a823

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed

R/augment.R

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
#' page in the references below). This enables the user to compute performance
3232
#' metrics in the \pkg{yardstick} package.
3333
#'
34+
#' ## Quantile Regression
35+
#'
36+
#' For quantile regression models, a `.pred_quantile` column is added that
37+
#' contains the quantile predictions for each row. This column has a special
38+
#' class `"quantile_pred"` and can be unnested using [tidyr::unnest()]
39+
#'
3440
#' @param new_data A data frame or matrix.
3541
#' @param ... Not currently used.
3642
#' @rdname augment
@@ -78,14 +84,31 @@
7884
#' augment(cls_xy, cls_tst)
7985
#' augment(cls_xy, cls_tst[, -3])
8086
#'
87+
#' # ------------------------------------------------------------------------------
88+
#'
89+
#' # Quantile regression example
90+
#' qr_form <-
91+
#' linear_reg() |>
92+
#' set_engine("quantreg") |>
93+
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
94+
#' fit(mpg ~ ., data = car_trn)
95+
#'
96+
#' augment(qr_form, car_tst)
97+
#' augment(qr_form, car_tst[, -1])
98+
#'
8199
augment.model_fit <- function(x, new_data, eval_time = NULL, ...) {
82100
new_data <- tibble::new_tibble(new_data)
83101
res <-
84102
switch(
85103
x$spec$mode,
86-
"regression" = augment_regression(x, new_data),
87-
"classification" = augment_classification(x, new_data),
88-
"censored regression" = augment_censored(x, new_data, eval_time = eval_time),
104+
"regression" = augment_regression(x, new_data),
105+
"classification" = augment_classification(x, new_data),
106+
"censored regression" = augment_censored(
107+
x,
108+
new_data,
109+
eval_time = eval_time
110+
),
111+
"quantile regression" = augment_quantile_regression(x, new_data),
89112
cli::cli_abort(
90113
c(
91114
"Unknown mode {.val {x$spec$mode}}.",
@@ -106,7 +129,11 @@ augment_regression <- function(x, new_data) {
106129
ret <- dplyr::mutate(ret, .resid = !!rlang::sym(y_nm) - .pred)
107130
}
108131
}
109-
dplyr::relocate(ret, dplyr::starts_with(".pred"), dplyr::starts_with(".resid"))
132+
dplyr::relocate(
133+
ret,
134+
dplyr::starts_with(".pred"),
135+
dplyr::starts_with(".resid")
136+
)
110137
}
111138

112139
augment_classification <- function(x, new_data) {
@@ -117,11 +144,15 @@ augment_classification <- function(x, new_data) {
117144
}
118145

119146
if (spec_has_pred_type(x, "class")) {
120-
ret <- dplyr::bind_cols(predict(x, new_data = new_data, type = "class"), ret)
147+
ret <- dplyr::bind_cols(
148+
predict(x, new_data = new_data, type = "class"),
149+
ret
150+
)
121151
}
122152
ret
123153
}
124154

155+
125156
# nocov start
126157
# tested in tidymodels/extratests#
127158
augment_censored <- function(x, new_data, eval_time = NULL) {
@@ -145,7 +176,8 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
145176
.filter_eval_time(eval_time)
146177
ret <- dplyr::bind_cols(
147178
predict(x, new_data = new_data, type = "survival", eval_time = eval_time),
148-
ret)
179+
ret
180+
)
149181
# Add inverse probability weights when the outcome is present in new_data
150182
y_col <- .find_surv_col(new_data, fail = FALSE)
151183
if (length(y_col) != 0) {
@@ -155,3 +187,10 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
155187
ret
156188
}
157189
# nocov end
190+
191+
augment_quantile_regression <- function(x, new_data) {
192+
ret <- new_data
193+
check_spec_pred_type(x, "quantile")
194+
ret <- dplyr::bind_cols(predict(x, new_data = new_data), ret)
195+
dplyr::relocate(ret, dplyr::starts_with(".pred"))
196+
}

0 commit comments

Comments
 (0)