Skip to content

Commit

Permalink
Set up horizon_group for when users want to set groups manually in ex…
Browse files Browse the repository at this point in the history
…plain_forecast (#433)
  • Loading branch information
jonlachmann authored Jan 22, 2025
1 parent dfb572c commit b21462f
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 14 deletions.
4 changes: 2 additions & 2 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100) {
for (i in seq_along(X_list)) {
X <- X_list[[i]]
if (is_groupwise) {
n_shapley_values <- length(internal$data$shap_names)
shap_names <- internal$data$shap_names
n_shapley_values <- internal$parameters$n_shapley_values
shap_names <- internal$parameters$shap_names
} else {
n_shapley_values <- length(internal$parameters$horizon_features[[i]])
shap_names <- internal$parameters$horizon_features[[i]]
Expand Down
31 changes: 31 additions & 0 deletions R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,34 @@ reg_forecast_setup <- function(x, horizon, group) {
colnames(fcast) <- names
return(list(fcast = fcast, group = group, horizon_group = horizon_group))
}

#' Set up user provided groups for explanation in a forecast model.
#'
#' @param group The list of groups to be explained.
#' @param horizon_features A list of features per horizon, to split appropriate groups over.
#'
#' @return A list containing
#' - group The list group with entries that differ per horizon split accordingly.
#' - horizon_group A list of which groups are applicable per horizon.
#' @keywords internal
group_forecast_setup <- function(group, horizon_features) {
horizon_group <- vector("list", length(horizon_features))
new_group <- list()

for (i in seq_along(group)) {
if (!all(group[[i]] %in% horizon_features[[1]])) {
for (h in seq_along(horizon_group)) {
new_name <- paste0(names(group)[i], ".", h)
new_group[[new_name]] <- group[[i]][group[[i]] %in% horizon_features[[h]]]
horizon_group[[h]] <- c(horizon_group[[h]], new_name)
}
} else {
name <- names(group)[i]
new_group[[name]] <- group[[i]]
for (h in seq_along(horizon_group)) {
horizon_group[[h]] <- c(horizon_group[[h]], name)
}
}
}
return(list(group = new_group, horizon_group = horizon_group))
}
23 changes: 11 additions & 12 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ get_parameters <- function(approach,
stop("`n_MC_samples` must be a single positive integer.")
}


# type
if (!(type %in% c("regular", "forecast"))) {
stop("`type` must be either `regular` or `forecast`.\n")
Expand Down Expand Up @@ -471,13 +470,20 @@ compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_trai
#' @keywords internal
get_extra_parameters <- function(internal, type) {
if (type == "forecast") {
if (internal$parameters$group_lags) {
internal$parameters$group <- internal$data$group
}
internal$parameters$horizon_features <- lapply(
internal$data$horizon_group,
function(x) as.character(unlist(internal$data$group[x]))
)
if (internal$parameters$group_lags) {
internal$parameters$shap_names <- internal$data$shap_names
internal$parameters$group <- internal$data$group
internal$parameters$horizon_group <- internal$data$horizon_group
} else if (!is.null(internal$parameters$group)) {
internal$parameters$shap_names <- names(internal$parameters$group)
group_setup <- group_forecast_setup(internal$parameters$group, internal$parameters$horizon_features)
internal$parameters$group <- group_setup$group
internal$parameters$horizon_group <- group_setup$horizon_group
}
}

# get number of features and observations to explain
Expand Down Expand Up @@ -515,14 +521,7 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$group_names <- names(group)
internal$parameters$group <- group

if (type == "forecast") {
if (internal$parameters$group_lags) {
internal$parameters$horizon_group <- internal$data$horizon_group
internal$parameters$shap_names <- internal$data$shap_names
} else {
internal$parameters$shap_names <- internal$parameters$group_names
}
} else {
if (type != "forecast") {
# For regular explain
internal$parameters$shap_names <- internal$parameters$group_names
}
Expand Down
24 changes: 24 additions & 0 deletions man/group_forecast_setup.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions tests/testthat/_snaps/forecast-output.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,61 @@
5: 149 3 77.88 -3.133 -3.133 -2.46
6: 150 3 77.88 -1.383 -1.383 -1.91

# forecast_output_forecast_ARIMA_manual_group_numeric

Code
(out <- code)
Message
Note: Feature names extracted from the model contains NA.
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
* Model class: <forecast_ARIMA/ARIMA/Arima>
* Approach: empirical
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Temp Wind
<int> <int> <num> <num> <num>
1: 149 1 77.88 -5.3063 -5.201
2: 150 1 77.88 -1.4435 -4.192
3: 149 2 77.88 -3.6824 -7.202
4: 150 2 77.88 -0.2568 -3.220

# forecast_output_forecast_ARIMA_manual_group_numeric2

Code
(out <- code)
Message
Note: Feature names extracted from the model contains NA.
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
* Model class: <forecast_ARIMA/ARIMA/Arima>
* Approach: empirical
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Group1 Group2
<int> <int> <num> <num> <num>
1: 149 1 77.88 -2.5593 -7.948
2: 150 1 77.88 -0.5681 -5.067
3: 149 2 77.88 -2.1223 -8.762
4: 150 2 77.88 0.7271 -4.203

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
50 changes: 50 additions & 0 deletions tests/testthat/test-forecast-output.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,56 @@ test_that("forecast_output_arima_numeric_no_lags", {
)
})

test_that("forecast_output_forecast_ARIMA_manual_group_numeric", {
expect_snapshot_rds(
explain_forecast(
testing = TRUE,
model = model_forecast_ARIMA_temp,
y = data_arima[1:150, "Temp"],
xreg = data_arima[, "Wind"],
train_idx = 2:148,
explain_idx = 149:150,
explain_y_lags = 2,
explain_xreg_lags = 2,
horizon = 2,
approach = "empirical",
phi0 = p0_ar[1:2],
group_lags = FALSE,
group = list(
Temp = c("Temp.1", "Temp.2"),
Wind = c("Wind.1", "Wind.2", "Wind.F1", "Wind.F2")
),
n_batches = 1
),
"forecast_output_forecast_ARIMA_manual_group_numeric"
)
})

test_that("forecast_output_forecast_ARIMA_manual_group_numeric2", {
expect_snapshot_rds(
explain_forecast(
testing = TRUE,
model = model_forecast_ARIMA_temp,
y = data_arima[1:150, "Temp"],
xreg = data_arima[, "Wind"],
train_idx = 2:148,
explain_idx = 149:150,
explain_y_lags = 2,
explain_xreg_lags = 2,
horizon = 2,
approach = "empirical",
phi0 = p0_ar[1:2],
group_lags = FALSE,
group = list(
Group1 = c("Wind.1", "Temp.1", "Wind.F2"),
Group2 = c("Wind.2", "Temp.2", "Wind.F1")
),
n_batches = 1
),
"forecast_output_forecast_ARIMA_manual_group_numeric2"
)
})

test_that("ARIMA gives the same output with different horizons", {
h3 <- explain_forecast(
testing = TRUE,
Expand Down

0 comments on commit b21462f

Please sign in to comment.