From dfc40110188094fe45c5d37e7fc9fdefba84519f Mon Sep 17 00:00:00 2001 From: mitchelloharawild Date: Wed, 18 Oct 2023 10:12:49 +1100 Subject: [PATCH] Improve handling of transformations for combination models --- NAMESPACE | 1 + NEWS.md | 5 +++ R/model.R | 5 ++- R/model_combination.R | 80 +++++++++++++++++++++++++------------------ 4 files changed, 57 insertions(+), 34 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index f8553dbc..6d4846b5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -139,6 +139,7 @@ S3method(report,null_mdl) S3method(residuals,"NULL") S3method(residuals,mdl_df) S3method(residuals,mdl_ts) +S3method(residuals,model_combination) S3method(residuals,null_mdl) S3method(response,mdl_df) S3method(response,mdl_ts) diff --git a/NEWS.md b/NEWS.md index ed32cc6c..4045c482 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # fabletools (development version) +## Improvements + +* Improved handling of `combination_model()` when used with transformed + component models. + # fabletools 0.3.4 ## New features diff --git a/R/model.R b/R/model.R index 12dd938f..d8ad403c 100644 --- a/R/model.R +++ b/R/model.R @@ -160,7 +160,10 @@ model_lhs <- function(model){ if(is_quosure(f)){ f <- get_expr(f) } - + if(is.call(f)) { + if(call_name(f) == "~") + return(f[[2]]) + } if(is.formula(f)){ f_lhs(f) } diff --git a/R/model_combination.R b/R/model_combination.R index 5da114b1..0d29f2a9 100644 --- a/R/model_combination.R +++ b/R/model_combination.R @@ -49,10 +49,15 @@ combination_model <- function(..., cmbn_fn = combination_ensemble, if(!any(map_lgl(mdls, inherits, "mdl_defn"))){ abort("`combination_model()` must contain at least one valid model definition.") } + # Guess the response variable without transformations + resp <- Reduce(intersect, lapply(mdls, function(x) all.vars(model_lhs(x)))) + if(length(resp) == 0) abort("`combination_model()` must use component models with the same response variable.") cmbn_model <- new_model_class("cmbn_mdl", train = train_combination, specials = new_specials(xreg = function(...) NULL)) - new_model_definition(cmbn_model, !!quo(!!model_lhs(mdls[[1]])), ..., + fml <- mdls[[1]]$formula + fml <- quo_set_expr(fml, sym(resp[[1]])) + new_model_definition(cmbn_model, !!fml, ..., cmbn_fn = cmbn_fn, cmbn_args = cmbn_args) } @@ -276,51 +281,51 @@ forecast.model_combination <- function(object, new_data, specials, ...){ # Compute residual covariance to adjust the forecast variance # Assumes correlation across h is identical - if(all(mdls)){ - fc_cov <- var( - cbind( - residuals(object[[1]], type = "response")[[".resid"]], - residuals(object[[2]], type = "response")[[".resid"]] - ), - na.rm = TRUE - ) - } - else{ - fc_cov <- 0 - } - object[mdls] <- map(object[mdls], forecast, new_data = new_data, ...) - object[mdls] <- map(object[mdls], function(x) x[[distribution_var(x)]]) - - if(all(mdls)){ - fc_sd <- object %>% - map(function(x) sqrt(distributional::variance(x))) %>% - transpose_dbl() - fc_cov <- suppressWarnings(stats::cov2cor(fc_cov)) - fc_cov[!is.finite(fc_cov)] <- 0 # In case of perfect forecasts - fc_cov <- map_dbl(fc_sd, function(sigma) (diag(sigma)%*%fc_cov%*%t(diag(sigma)))[1,2]) - } + fbl <- object + fbl[mdls] <- map(fbl[mdls], forecast, new_data = new_data, ...) + fbl[mdls] <- map(fbl[mdls], function(x) x[[distribution_var(x)]]) - is_normal <- map_lgl(object[mdls], function(x) all(dist_types(x) == "dist_normal")) - if(all(is_normal)){ # Improve check to ensure all distributions are normal - .dist <- eval_tidy(expr, object) + is_normal <- map_lgl(fbl[mdls], function(x) all(dist_types(x) == "dist_normal")) + if(all(is_normal)){ + .dist <- eval_tidy(expr, fbl) + + # Adjust for covariance # var(x) + var(y) + 2*cov(x,y) - .dist <- distributional::dist_normal(mean(.dist), sqrt(distributional::variance(.dist) + 2*fc_cov)) + if(all(mdls)) { + fc_cov <- var( + cbind( + residuals(object[[1]], type = "response")[[".resid"]], + residuals(object[[2]], type = "response")[[".resid"]] + ), + na.rm = TRUE + ) + fc_sd <- fbl %>% + map(function(x) sqrt(distributional::variance(x))) %>% + transpose_dbl() + fc_cov <- suppressWarnings(stats::cov2cor(fc_cov)) + fc_cov[!is.finite(fc_cov)] <- 0 # In case of perfect forecasts + fc_cov <- map_dbl(fc_sd, function(sigma) (diag(sigma)%*%fc_cov%*%t(diag(sigma)))[1,2]) + .dist <- distributional::dist_normal(mean(.dist), sqrt(distributional::variance(.dist) + 2*fc_cov)) + } } else { - .dist <- distributional::dist_degenerate(eval_tidy(expr, map(object, mean))) + .dist <- distributional::dist_degenerate(eval_tidy(expr, map(fbl, mean))) } .dist } #' @export -generate.model_combination <- function(x, new_data, specials, ...){ - if(".innov" %in% new_data){ - abort("Providing innovations for simulating combination models is not supported.") +generate.model_combination <- function(x, new_data, specials, bootstrap = FALSE, ...){ + if(".innov" %in% names(new_data)){ + # Assume bootstrapped paths are requested (this needs future work) + bootstrap <- TRUE + new_data[[".innov"]] <- NULL + # abort("Providing innovations for simulating combination models is not supported.") } mdls <- map_lgl(x, is_model) expr <- attr(x, "combination") - x[mdls] <- map(x[mdls], generate, new_data, ...) + x[mdls] <- map(x[mdls], generate, new_data, bootstrap = bootstrap, ...) out <- x[[which(mdls)[1]]] sims <- map(x, function(x) if(is_tsibble(x)) x[[".sim"]] else x) out[[".sim"]] <- eval_tidy(expr, sims) @@ -335,3 +340,12 @@ fitted.model_combination <- function(object, ...){ fits <- map(object, function(x) if(is_tsibble(x)) x[[".fitted"]] else x) eval_tidy(expr, fits) } + +#' @export +residuals.model_combination <- function(object, ...) { + mdls <- map_lgl(object, is_model) + expr <- attr(object, "combination") + object[mdls] <- map(object[mdls], residuals, type = "response", ...) + res <- map(object, function(x) if(is_tsibble(x)) x[[".resid"]] else x) + eval_tidy(expr, res) +} \ No newline at end of file