31
31
# ' page in the references below). This enables the user to compute performance
32
32
# ' metrics in the \pkg{yardstick} package.
33
33
# '
34
+ # ' ## Quantile Regression
35
+ # '
36
+ # ' For quantile regression models, a `.pred_quantile` column is added that
37
+ # ' contains the quantile predictions for each row. This column has a special
38
+ # ' class `"quantile_pred"` and can be unnested using [tidyr::unnest()]
39
+ # '
34
40
# ' @param new_data A data frame or matrix.
35
41
# ' @param ... Not currently used.
36
42
# ' @rdname augment
78
84
# ' augment(cls_xy, cls_tst)
79
85
# ' augment(cls_xy, cls_tst[, -3])
80
86
# '
87
+ # ' # ------------------------------------------------------------------------------
88
+ # '
89
+ # ' # Quantile regression example
90
+ # ' qr_form <-
91
+ # ' linear_reg() |>
92
+ # ' set_engine("quantreg") |>
93
+ # ' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
94
+ # ' fit(mpg ~ ., data = car_trn)
95
+ # '
96
+ # ' augment(qr_form, car_tst)
97
+ # ' augment(qr_form, car_tst[, -1])
98
+ # '
81
99
augment.model_fit <- function (x , new_data , eval_time = NULL , ... ) {
82
100
new_data <- tibble :: new_tibble(new_data )
83
101
res <-
84
102
switch (
85
103
x $ spec $ mode ,
86
- " regression" = augment_regression(x , new_data ),
87
- " classification" = augment_classification(x , new_data ),
88
- " censored regression" = augment_censored(x , new_data , eval_time = eval_time ),
104
+ " regression" = augment_regression(x , new_data ),
105
+ " classification" = augment_classification(x , new_data ),
106
+ " censored regression" = augment_censored(
107
+ x ,
108
+ new_data ,
109
+ eval_time = eval_time
110
+ ),
111
+ " quantile regression" = augment_quantile_regression(x , new_data ),
89
112
cli :: cli_abort(
90
113
c(
91
114
" Unknown mode {.val {x$spec$mode}}." ,
@@ -106,7 +129,11 @@ augment_regression <- function(x, new_data) {
106
129
ret <- dplyr :: mutate(ret , .resid = !! rlang :: sym(y_nm ) - .pred )
107
130
}
108
131
}
109
- dplyr :: relocate(ret , dplyr :: starts_with(" .pred" ), dplyr :: starts_with(" .resid" ))
132
+ dplyr :: relocate(
133
+ ret ,
134
+ dplyr :: starts_with(" .pred" ),
135
+ dplyr :: starts_with(" .resid" )
136
+ )
110
137
}
111
138
112
139
augment_classification <- function (x , new_data ) {
@@ -117,11 +144,15 @@ augment_classification <- function(x, new_data) {
117
144
}
118
145
119
146
if (spec_has_pred_type(x , " class" )) {
120
- ret <- dplyr :: bind_cols(predict(x , new_data = new_data , type = " class" ), ret )
147
+ ret <- dplyr :: bind_cols(
148
+ predict(x , new_data = new_data , type = " class" ),
149
+ ret
150
+ )
121
151
}
122
152
ret
123
153
}
124
154
155
+
125
156
# nocov start
126
157
# tested in tidymodels/extratests#
127
158
augment_censored <- function (x , new_data , eval_time = NULL ) {
@@ -145,7 +176,8 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
145
176
.filter_eval_time(eval_time )
146
177
ret <- dplyr :: bind_cols(
147
178
predict(x , new_data = new_data , type = " survival" , eval_time = eval_time ),
148
- ret )
179
+ ret
180
+ )
149
181
# Add inverse probability weights when the outcome is present in new_data
150
182
y_col <- .find_surv_col(new_data , fail = FALSE )
151
183
if (length(y_col ) != 0 ) {
@@ -155,3 +187,10 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
155
187
ret
156
188
}
157
189
# nocov end
190
+
191
+ augment_quantile_regression <- function (x , new_data ) {
192
+ ret <- new_data
193
+ check_spec_pred_type(x , " quantile" )
194
+ ret <- dplyr :: bind_cols(predict(x , new_data = new_data ), ret )
195
+ dplyr :: relocate(ret , dplyr :: starts_with(" .pred" ))
196
+ }
0 commit comments