Skip to content

Commit

Permalink
Improve handling of transformations for combination models
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelloharawild committed Oct 17, 2023
1 parent 782848d commit dfc4011
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 34 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ S3method(report,null_mdl)
S3method(residuals,"NULL")
S3method(residuals,mdl_df)
S3method(residuals,mdl_ts)
S3method(residuals,model_combination)
S3method(residuals,null_mdl)
S3method(response,mdl_df)
S3method(response,mdl_ts)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# fabletools (development version)

## Improvements

* Improved handling of `combination_model()` when used with transformed
component models.

# fabletools 0.3.4

## New features
Expand Down
5 changes: 4 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ model_lhs <- function(model){
if(is_quosure(f)){
f <- get_expr(f)
}

if(is.call(f)) {
if(call_name(f) == "~")
return(f[[2]])
}
if(is.formula(f)){
f_lhs(f)
}
Expand Down
80 changes: 47 additions & 33 deletions R/model_combination.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,15 @@ combination_model <- function(..., cmbn_fn = combination_ensemble,
if(!any(map_lgl(mdls, inherits, "mdl_defn"))){
abort("`combination_model()` must contain at least one valid model definition.")
}
# Guess the response variable without transformations
resp <- Reduce(intersect, lapply(mdls, function(x) all.vars(model_lhs(x))))
if(length(resp) == 0) abort("`combination_model()` must use component models with the same response variable.")

cmbn_model <- new_model_class("cmbn_mdl", train = train_combination,
specials = new_specials(xreg = function(...) NULL))
new_model_definition(cmbn_model, !!quo(!!model_lhs(mdls[[1]])), ...,
fml <- mdls[[1]]$formula
fml <- quo_set_expr(fml, sym(resp[[1]]))
new_model_definition(cmbn_model, !!fml, ...,
cmbn_fn = cmbn_fn, cmbn_args = cmbn_args)
}

Expand Down Expand Up @@ -276,51 +281,51 @@ forecast.model_combination <- function(object, new_data, specials, ...){
# Compute residual covariance to adjust the forecast variance
# Assumes correlation across h is identical

if(all(mdls)){
fc_cov <- var(
cbind(
residuals(object[[1]], type = "response")[[".resid"]],
residuals(object[[2]], type = "response")[[".resid"]]
),
na.rm = TRUE
)
}
else{
fc_cov <- 0
}
object[mdls] <- map(object[mdls], forecast, new_data = new_data, ...)
object[mdls] <- map(object[mdls], function(x) x[[distribution_var(x)]])

if(all(mdls)){
fc_sd <- object %>%
map(function(x) sqrt(distributional::variance(x))) %>%
transpose_dbl()
fc_cov <- suppressWarnings(stats::cov2cor(fc_cov))
fc_cov[!is.finite(fc_cov)] <- 0 # In case of perfect forecasts
fc_cov <- map_dbl(fc_sd, function(sigma) (diag(sigma)%*%fc_cov%*%t(diag(sigma)))[1,2])
}
fbl <- object
fbl[mdls] <- map(fbl[mdls], forecast, new_data = new_data, ...)
fbl[mdls] <- map(fbl[mdls], function(x) x[[distribution_var(x)]])

is_normal <- map_lgl(object[mdls], function(x) all(dist_types(x) == "dist_normal"))
if(all(is_normal)){ # Improve check to ensure all distributions are normal
.dist <- eval_tidy(expr, object)
is_normal <- map_lgl(fbl[mdls], function(x) all(dist_types(x) == "dist_normal"))
if(all(is_normal)){
.dist <- eval_tidy(expr, fbl)

# Adjust for covariance
# var(x) + var(y) + 2*cov(x,y)
.dist <- distributional::dist_normal(mean(.dist), sqrt(distributional::variance(.dist) + 2*fc_cov))
if(all(mdls)) {
fc_cov <- var(
cbind(
residuals(object[[1]], type = "response")[[".resid"]],
residuals(object[[2]], type = "response")[[".resid"]]
),
na.rm = TRUE
)
fc_sd <- fbl %>%
map(function(x) sqrt(distributional::variance(x))) %>%
transpose_dbl()
fc_cov <- suppressWarnings(stats::cov2cor(fc_cov))
fc_cov[!is.finite(fc_cov)] <- 0 # In case of perfect forecasts
fc_cov <- map_dbl(fc_sd, function(sigma) (diag(sigma)%*%fc_cov%*%t(diag(sigma)))[1,2])
.dist <- distributional::dist_normal(mean(.dist), sqrt(distributional::variance(.dist) + 2*fc_cov))
}
} else {
.dist <- distributional::dist_degenerate(eval_tidy(expr, map(object, mean)))
.dist <- distributional::dist_degenerate(eval_tidy(expr, map(fbl, mean)))
}

.dist
}

#' @export
generate.model_combination <- function(x, new_data, specials, ...){
if(".innov" %in% new_data){
abort("Providing innovations for simulating combination models is not supported.")
generate.model_combination <- function(x, new_data, specials, bootstrap = FALSE, ...){
if(".innov" %in% names(new_data)){
# Assume bootstrapped paths are requested (this needs future work)
bootstrap <- TRUE
new_data[[".innov"]] <- NULL
# abort("Providing innovations for simulating combination models is not supported.")
}

mdls <- map_lgl(x, is_model)
expr <- attr(x, "combination")
x[mdls] <- map(x[mdls], generate, new_data, ...)
x[mdls] <- map(x[mdls], generate, new_data, bootstrap = bootstrap, ...)
out <- x[[which(mdls)[1]]]
sims <- map(x, function(x) if(is_tsibble(x)) x[[".sim"]] else x)
out[[".sim"]] <- eval_tidy(expr, sims)
Expand All @@ -335,3 +340,12 @@ fitted.model_combination <- function(object, ...){
fits <- map(object, function(x) if(is_tsibble(x)) x[[".fitted"]] else x)
eval_tidy(expr, fits)
}

# Response residuals of a combination model.
#
# Replaces every component model in `object` with its response residuals and
# then evaluates the stored combination expression (the "combination"
# attribute of `object`) on those residuals.
#
# Arguments:
#   object: A combination object. Elements for which `is_model()` is TRUE are
#           replaced by their residuals; all other elements are left as-is.
#   type:   Accepted so that callers may supply a residual type without
#           breaking the call. It is deliberately NOT forwarded: components
#           are always asked for `type = "response"`, and letting a second
#           `type` travel through `...` would make R fail with
#           'formal argument "type" matched by multiple actual arguments'.
#   ...:    Further arguments forwarded to each component's `residuals()`.
#
# Returns the value of the combination expression evaluated on the component
# residuals.
#' @export
residuals.model_combination <- function(object, type = "response", ...) {
  mdls <- map_lgl(object, is_model)
  expr <- attr(object, "combination")
  # `type` is captured by the formal above, so `...` can never carry a
  # conflicting `type` into the component calls below.
  object[mdls] <- map(object[mdls], residuals, type = "response", ...)
  # tsibble elements contribute their `.resid` column; anything else
  # (presumably constants used in the combination) passes through unchanged.
  res <- map(object, function(x) if(is_tsibble(x)) x[[".resid"]] else x)
  eval_tidy(expr, res)
}

0 comments on commit dfc4011

Please sign in to comment.