diff --git a/R/reconciliation.R b/R/reconciliation.R index cc6ac024..31cdddd5 100644 --- a/R/reconciliation.R +++ b/R/reconciliation.R @@ -37,7 +37,8 @@ reconcile.mdl_df <- function(.data, ...){ #' @param models A column of models in a mable. #' @param method The reconciliation method to use. #' @param sparse If TRUE, the reconciliation will be computed using sparse -#' matrix algebra? By default, sparse matrices will be used if the MatrixM +#' @param immu Vector of logical indicating whether series should be immutable after reconciliation. +#' matrix algebra? By default, sparse matrices will be used if the Matrix #' package is installed. #' #' @seealso @@ -48,12 +49,12 @@ reconcile.mdl_df <- function(.data, ...){ #' #' @export min_trace <- function(models, method = c("wls_var", "ols", "wls_struct", "mint_cov", "mint_shrink"), - sparse = NULL){ + sparse = NULL, immu = NULL){ if(is.null(sparse)){ sparse <- requireNamespace("Matrix", quietly = TRUE) } structure(models, class = c("lst_mint_mdl", "lst_mdl", "list"), - method = match.arg(method), sparse = sparse) + method = match.arg(method), sparse = sparse, immu = immu) } #' @export @@ -62,6 +63,7 @@ forecast.lst_mint_mdl <- function(object, key_data, point_forecast = list(.mean = mean), ...){ method <- object%@%"method" sparse <- object%@%"sparse" + immu <- object%@%"immu" if(sparse){ require_package("Matrix") as.matrix <- Matrix::as.matrix @@ -145,9 +147,20 @@ forecast.lst_mint_mdl <- function(object, key_data, -S[row_agg,,drop = FALSE] ) U <- U[, order(c(row_agg, row_btm)), drop = FALSE] + if (!is.null(immu)) { + stopifnot("The specified immutable series can not be all immutable." = Matrix::rankMatrix(S[immu,,drop = FALSE]) == sum(immu)) + ind_vec <- Matrix::sparseMatrix(i=1:sum(immu), j=which(immu), x=1, dims = c(sum(immu), length(immu))) + U <- rbind(ind_vec, U) + } Ut <- t(U) WUt <- W %*% Ut - P <- J - J %*% WUt %*% solve(U %*% WUt, U) + if (!is.null(immu)) { + Ua <- U + Ua[seq_len(sum(immu)),] <- 0 + P <- J - J %*% WUt %*% solve(U %*% WUt, Ua) + } else { + P <- J - J %*% WUt %*% solve(U %*% WUt, U) + } # P <- J - J%*%W%*%t(U)%*%solve(U%*%W%*%t(U))%*%U } else {