Skip to content

Commit

Permalink
Bug in explain_forecast() (#425)
Browse files Browse the repository at this point in the history
  • Loading branch information
LHBO authored Dec 17, 2024
1 parent a7efa3e commit db81ed7
Show file tree
Hide file tree
Showing 101 changed files with 90 additions and 72 deletions.
14 changes: 10 additions & 4 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ check_convergence <- function(internal) {
paired_shap_sampling <- internal$parameters$paired_shap_sampling
n_shapley_values <- internal$parameters$n_shapley_values

n_sampled_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions
exact <- internal$iter_list[[iter]]$exact

shap_names <- internal$parameters$shap_names
shap_names_with_none <- c("none", shap_names)

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions
if (!all.equal(names(dt_shapley_est), names(dt_shapley_sd))) {
stop("The column names of the dt_shapley_est and dt_shapley_df are not equal.")
}

max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = shap_names_with_none, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand All @@ -33,8 +39,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tol)) {
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
Expand Down
2 changes: 1 addition & 1 deletion R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
result <- cbind(internal$parameters$output_labels, rbindlist(result, fill = TRUE))
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
Expand Down
3 changes: 2 additions & 1 deletion R/prepare_next_iteration.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ prepare_next_iteration <- function(internal) {

est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions
n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor
current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions
current_n_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions + 2 # Used instead of n_coalitions to
# deal with forecast special case
current_coal_samples <- internal$iter_list[[iter]]$coal_samples

if (is.null(fixed_n_coalitions_per_iter)) {
Expand Down
9 changes: 4 additions & 5 deletions R/print_iter.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ print_iter <- function(internal) {
}

if ("shapley" %in% verbose) {
n_explain <- internal$parameters$n_explain

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, -1]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, -1]
shap_names_with_none <- c("none", internal$parameters$shap_names)
dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, shap_names_with_none, with = FALSE]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, shap_names_with_none, with = FALSE]

# Printing the current Shapley values
matrix1 <- format(round(dt_shapley_est, 3), nsmall = 2, justify = "right")
Expand All @@ -99,7 +98,7 @@ print_iter <- function(internal) {
print_dt <- as.data.table(matrix1)
} else {
msg <- paste0(msg, "estimated Shapley values (sd)")
print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = n_explain))
print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = nrow(matrix1)))
}

cli::cli_h3(msg)
Expand Down
42 changes: 20 additions & 22 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- length(group)
internal$parameters$group_names <- names(group)
internal$parameters$group <- group
internal$parameters$n_shapley_values <- internal$parameters$n_groups

if (type == "forecast") {
if (internal$parameters$group_lags) {
Expand All @@ -543,8 +542,9 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- NULL
internal$parameters$group_names <- NULL
internal$parameters$shap_names <- internal$parameters$feature_names
internal$parameters$n_shapley_values <- internal$parameters$n_features
}
internal$parameters$n_shapley_values <- length(internal$parameters$shap_names)


# Get the number of unique approaches
internal$parameters$n_approaches <- length(internal$parameters$approach)
Expand Down Expand Up @@ -898,36 +898,36 @@ adjust_max_n_coalitions <- function(internal) {
}
} else { # group wise
# Set max_n_coalitions to upper bound
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_groups) {
max_n_coalitions <- 2^n_groups
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_shapley_values) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_groups, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_shapley_values, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
}
# Set max_n_coalitions to lower bound
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_groups + 1)) {
if (n_groups <= 3) {
max_n_coalitions <- 2^n_groups
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_shapley_values + 1)) {
if (n_shapley_values <= 3) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (", 2^n_groups, ") ",
"that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (",
2^n_shapley_values, ") that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
} else {
max_n_coalitions <- min(10, n_groups + 1)
max_n_coalitions <- min(10, n_shapley_values + 1)
message(
paste0(
"Success with message:\n",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_groups + 1, "),",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_shapley_values + 1, "),",
"which will result in unreliable results.\n",
"It is therefore set to ", max(10, n_groups + 1), ".\n"
"It is therefore set to ", max(10, n_shapley_values + 1), ".\n"
)
)
}
Expand All @@ -943,6 +943,7 @@ check_max_n_coalitions_fc <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
n_shapley_values <- internal$parameters$n_shapley_values

type <- internal$parameters$type

Expand All @@ -953,7 +954,7 @@ check_max_n_coalitions_fc <- function(internal) {
xreg <- internal$data$xreg

if (!is_groupwise) {
if (max_n_coalitions <= n_features) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
Expand All @@ -962,7 +963,7 @@ check_max_n_coalitions_fc <- function(internal) {
))
}
} else {
if (max_n_coalitions <= n_groups) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
Expand Down Expand Up @@ -1168,18 +1169,15 @@ check_and_set_iterative <- function(internal) {

set_exact <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
is_groupwise <- internal$parameters$is_groupwise
n_shapley_values <- internal$parameters$n_shapley_values
iterative <- internal$parameters$iterative
asymmetric <- internal$parameters$asymmetric
max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal

if (isFALSE(iterative) &&
(
(isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) ||
(isFALSE(is_groupwise) && max_n_coalitions == 2^n_features) ||
(isTRUE(is_groupwise) && max_n_coalitions == 2^n_groups)
(max_n_coalitions == 2^n_shapley_values)
)
) {
exact <- TRUE
Expand Down
9 changes: 8 additions & 1 deletion R/shapley_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ shapley_setup <- function(internal) {
internal$parameters$exact <- TRUE # Since this means that all coalitions have been sampled
}

# Updating n_coalitions in the end based on what is actually used. I don't think this is necessary now. TODO: Check.
# Updating n_coalitions in the end based on what is actually used.
internal$iter_list[[iter]]$n_coalitions <- nrow(S)
# The number of sampled coalitions to be used for convergence detection only (exclude the zero and full prediction)
internal$iter_list[[iter]]$n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2


# This will be obsolete later
internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed
Expand Down Expand Up @@ -758,6 +761,10 @@ shapley_setup_forecast <- function(internal) {

internal$iter_list[[iter]]$n_coalitions <- nrow(S) # Updating this parameter in the end based on what is used.

# The number of sampled coalitions *per horizon* to be used for convergence detection only
# Exclude the zero and full prediction
internal$iter_list[[iter]]$n_sampled_coalitions <- length(unique(id_coalition_mapper_dt$horizon_id_coalition)) - 2

# This will be obsolete later
internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed
# instead of storing it
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
71 changes: 39 additions & 32 deletions tests/testthat/_snaps/forecast-output.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,27 +88,30 @@
i Using 10 of 512 coalitions, 10 new.
-- Iteration 2 -----------------------------------------------------------------
i Using 30 of 512 coalitions, 4 new.
i Using 60 of 512 coalitions, 50 new.
-- Iteration 3 -----------------------------------------------------------------
i Using 78 of 512 coalitions, 6 new.
i Using 106 of 512 coalitions, 46 new.
-- Iteration 4 -----------------------------------------------------------------
i Using 150 of 512 coalitions, 44 new.
Output
explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3
<int> <int> <num> <num> <num> <num> <num> <num> <num>
1: 149 1 77.88 -2.795 -4.5597 -1.114 1.564 -1.8995 0.2087
2: 150 1 77.88 4.024 -0.5774 -4.589 -2.234 0.1985 -2.2827
3: 149 2 77.88 -3.701 -4.2427 -1.326 1.465 -1.9227 0.7060
4: 150 2 77.88 3.460 -0.9158 -5.264 -2.452 0.7709 -1.7864
5: 149 3 77.88 -4.721 -3.4208 -1.503 1.172 -0.4564 -0.6058
6: 150 3 77.88 2.811 0.4206 -5.361 -1.388 0.0752 -0.2130
explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3
<int> <int> <num> <num> <num> <num> <num> <num> <num>
1: 149 1 77.88 -3.335 -4.2630 -1.527 1.7674 -1.6361 0.3304
2: 150 1 77.88 3.767 -0.4812 -4.734 -2.0593 0.8002 -2.4860
3: 149 2 77.88 -2.925 -4.0802 -1.061 0.7282 -2.1425 1.3892
4: 150 2 77.88 3.304 -0.8942 -5.255 -2.3629 1.1470 -2.1038
5: 149 3 77.88 -4.167 -4.7628 -1.615 1.2049 -0.8727 1.4791
6: 150 3 77.88 2.777 -0.7697 -5.938 -0.9178 0.5417 -0.9851
Wind.F1 Wind.F2 Wind.F3
<num> <num> <num>
1: -1.9118 NA NA
2: -0.1747 NA NA
3: -1.1883 -0.6744 NA
4: 0.7128 1.9982 NA
5: -1.5436 -0.5418 2.8952
6: -0.6202 -0.8545 0.4549
1: -1.8441 NA NA
2: -0.4417 NA NA
3: -2.1499 -0.6431 NA
4: 1.0132 1.6761 NA
5: -0.7669 -0.2837 1.05906
6: 0.3650 0.2094 0.04183

# forecast_output_arima_numeric_iterative_groups

Expand All @@ -118,31 +121,35 @@
Note: Feature names extracted from the model contains NA.
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
* Model class: <Arima>
* Approach: empirical
* Iterative estimation: TRUE
* Number of group-wise Shapley values: 10
* Number of group-wise Shapley values: 4
* Number of observations to explain: 2
-- iterative computation started --
-- Iteration 1 -----------------------------------------------------------------
i Using 10 of 1024 coalitions, 10 new.
i Using 10 of 16 coalitions, 10 new.
-- Iteration 2 -----------------------------------------------------------------
i Using 28 of 1024 coalitions, 2 new.
i Using 12 of 16 coalitions, 2 new.
-- Iteration 3 -----------------------------------------------------------------
i Using 56 of 1024 coalitions, 12 new.
i Using 14 of 16 coalitions, 2 new.
Output
explain_idx horizon none Temp Wind Solar.R Ozone
<int> <int> <num> <num> <num> <num> <num>
1: 149 1 77.88 -4.680 -3.6712 0.3230 -1.253
2: 150 1 77.88 -2.487 -3.6317 1.8415 -0.891
3: 149 2 77.88 -6.032 -4.1973 2.5973 -2.402
4: 150 2 77.88 -3.124 0.1986 0.8258 -2.245
5: 149 3 77.88 -7.777 1.1382 0.6962 -3.267
6: 150 3 77.88 -3.142 -1.6674 2.9047 -2.024
explain_idx horizon none Temp Wind Solar.R Ozone
<int> <int> <num> <num> <num> <num> <num>
1: 149 1 77.88 -3.896 -4.2285 -0.3807 -0.7759
2: 150 1 77.88 -2.011 -3.9476 1.4200 -0.6295
3: 149 2 77.88 -6.503 -4.5272 2.9701 -1.9733
4: 150 2 77.88 -3.574 -0.2358 1.2984 -1.8324
5: 149 3 77.88 -7.544 0.9077 0.9121 -3.4847
6: 150 3 77.88 -2.887 -1.9034 3.1385 -2.2767

# forecast_output_arima_numeric_no_xreg

Expand Down Expand Up @@ -184,18 +191,18 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
* Model class: <forecast_ARIMA/ARIMA/Arima>
* Approach: empirical
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 4
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 16 of 16 coalitions.
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Temp Wind
<int> <int> <num> <num> <num>
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/forecast-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
Condition
Error in `get_predict_model()`:
Expand Down Expand Up @@ -124,18 +124,18 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is smaller than max(10, n_groups + 1 = 5),which will result in unreliable results.
It is therefore set to 10.
n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (4) that we should use all to get reliable results.
max_n_coalitions is therefore set to 2^n_groups = 4.
* Model class: <Arima>
* Approach: independence
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 4
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 5 of 16 coalitions.
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Temp Wind
<int> <int> <num> <num> <num>
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/testthat/_snaps/iterative-output/output_verbose_1.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/iterative-output/output_verbose_1_3.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/iterative-output/output_verbose_1_3_4.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/iterative-output/output_verbose_1_3_4_5.rds
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds
Binary file not shown.
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds
Binary file not shown.
Binary file not shown.

0 comments on commit db81ed7

Please sign in to comment.