Bug in explain_forecast() #425

Merged
martinju merged 13 commits into master from Lars/Fix_forecast_bug_Shapley_sd_dt on Dec 17, 2024

@LHBO (Collaborator) commented Dec 11, 2024

In this PR, we fix two main bugs in explain_forecast():

A: Using verbose = "shapley" caused the function to crash.
B: The number of coalitions used and the number of Shapley values computed by explain_forecast() were inconsistent.

Bug A stemmed from three issues:

  1. There was an inconsistency: internal$iter_list[[iter]]$dt_shapley_est contained the index columns "explain_idx" and "horizon", while internal$iter_list[[iter]]$dt_shapley_sd did not. We added the index columns to the latter.

  2. When creating the data.table used to print the iterative Shapley value estimates with standard deviations, the code did not take into account that, when forecasting, the number of rows is not n_explain but n_times x horizon. This is now fixed by using the number of rows in internal$iter_list[[iter]]$dt_shapley_est instead.

  3. To be consistent with regular Shapley values, we remove "explain_idx" and "horizon" when printing the iterative Shapley value estimates and standard deviations. For regular Shapley values, we likewise omit the explain_id column.
    Note. It might be a good idea to include "explain_idx" (and "horizon"), but then the print function must be rewritten to reflect that "explain_idx" (and "horizon") are not estimates with standard deviations.
    Note 2. There is an inconsistency between explain() and explain_forecast(): the former uses explain_id, while the latter uses explain_idx.
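A minimal base-R sketch of points 2 and 3, using small hypothetical data.frames (shapley_est, shapley_sd) in place of the internal data.tables; the feature names none, Temp and Wind and all values are illustrative assumptions. It builds the "estimate (sd)" print matrix with the row count taken from the estimates table and drops the index columns by name:

```r
# Hypothetical stand-ins for dt_shapley_est and dt_shapley_sd from a forecast
# with explain_idx = 149:150 and horizon = 3 (all values are illustrative)
explain_idx <- rep(149:150, each = 3)
horizon <- rep(1:3, times = 2)
shapley_est <- data.frame(explain_idx, horizon, none = 20,
                          Temp = c(1.2, 0.8, 0.5, -0.4, -0.1, 0.3),
                          Wind = c(-0.6, -0.2, 0.1, 0.9, 0.4, 0.2))
shapley_sd <- data.frame(explain_idx, horizon, none = 0,
                         Temp = c(0.10, 0.12, 0.15, 0.08, 0.11, 0.09),
                         Wind = c(0.05, 0.07, 0.06, 0.10, 0.04, 0.03))

# Point 3: drop the index columns before printing
id_cols <- c("explain_idx", "horizon")
value_cols <- setdiff(names(shapley_est), id_cols)

# Point 2: one row per (explained time point, horizon) pair, i.e. 2 x 3 = 6
n_rows <- nrow(shapley_est)

# "estimate (sd)" strings; as.matrix() values are flattened column by column
# and matrix() refills column by column, so estimates and sds stay aligned
cells <- paste0(as.matrix(shapley_est[value_cols]), " (",
                as.matrix(shapley_sd[value_cols]), ")")
print_mat <- matrix(cells, nrow = n_rows, dimnames = list(NULL, value_cols))
print(print_mat)

# Using n_explain = 2 rows instead would reshape the 18 cells into 2 x 9
wrong_shape <- matrix(cells, nrow = 2)
```

Taking the row count from the estimates table keeps one printed row per explained time point and horizon; shapr's actual printing code may differ in its details.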

Bug B was fixed by defining the number of coalitions as the number of coalitions used per horizon. We also introduced a new parameter, n_sampled_coalitions, which for explain_forecast() is the number of coalitions per horizon minus 2 (the zero and full predictions), and for regular explain() is the number of coalitions. Several additional minor fixes related to the number of Shapley values computed with and without grouping were also added. As a side effect, this also fixed the incorrect verbose printout for explain_forecast().

PS. The line
print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = nrow(matrix1)))
produces an extra trailing space in every cell. It should be changed to, e.g.,
print_dt <- as.data.table(matrix(paste0(matrix1, " (", matrix2, ")"), nrow = nrow(matrix1)))
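A small base-R check of the trailing whitespace, with hypothetical 2 x 2 character matrices standing in for matrix1 and matrix2 (the as.data.table() wrapper is omitted so that only base R is needed):

```r
# Hypothetical formatted estimates (matrix1) and standard deviations (matrix2)
matrix1 <- matrix(c("1.2", "0.5", "-0.3", "2.1"), nrow = 2)
matrix2 <- matrix(c("0.1", "0.2", "0.3", "0.4"), nrow = 2)

# Original construction: the literal ") " leaves a trailing space in each cell
print_old <- matrix(paste(matrix1, " (", matrix2, ") ", sep = ""),
                    nrow = nrow(matrix1))
# Suggested construction: paste0() with ")" leaves no trailing space
print_new <- matrix(paste0(matrix1, " (", matrix2, ")"),
                    nrow = nrow(matrix1))

print(print_old[1, 1])  # "1.2 (0.1) " (note the trailing space)
print(print_new[1, 1])  # "1.2 (0.1)"
```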

…ich was present in shapley_est data.table. Added id and horizon columns.

2. Updated the print to remove both explain_idx and horizon from the printouts for forecast. We may want to show the horizons, but then the print function must be rewritten, as horizon does not have a standard deviation.

3. Fixed that the created matrix did not take into account that, when forecasting, the number of rows is not n_explain but n_explain times horizon.
@LHBO requested a review from martinju on December 11, 2024 14:11
@LHBO (Collaborator, Author) commented Dec 11, 2024

To see that the bug no longer occurs, run:

devtools::load_all()
explain_forecast(
 testing = TRUE,
 model = model_arima_temp,
 y = data_arima[1:150, "Temp"],
 xreg = data_arima[, "Wind"],
 train_idx = 3:148,
 explain_idx = 149:150,
 explain_y_lags = 3,
 explain_xreg_lags = 3,
 horizon = 3,
 approach = "empirical",
 phi0 = p0_ar,
 group_lags = FALSE,
 max_n_coalitions = 150,
 iterative = TRUE,
 iterative_args = list(initial_n_coalitions = 10),
 verbose = c("basic", "progress", "convergence", "shapley", "vS_details")
)

@LHBO (Collaborator, Author) commented Dec 13, 2024

I am confident that the tests that now fail are actually correct, in the sense that the outputs of the old runs were incorrect.
In the previous code in check_convergence.R, the forecast dt_shapley_sd did not contain the index columns, so max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 removed the none column, while dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I] removed explain_idx and horizon but kept none, since dt_shapley_est had the index columns.
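A base-R sketch of the column mismatch described above, with hypothetical data.frames in place of dt_shapley_est and the old dt_shapley_sd (names and values are illustrative). Positional exclusion drops different columns from the two tables; once the index columns are present in both, excluding them by name (an illustrative choice here, not necessarily what check_convergence.R does) yields the same value columns:

```r
# Hypothetical tables before the fix (names and values are illustrative)
shapley_est <- data.frame(explain_idx = c(149, 150), horizon = c(1, 1),
                          none = c(20, 20), Temp = c(1.2, -0.4),
                          Wind = c(-0.6, 0.9))
# Old forecast layout: the sd table lacked the index columns
shapley_sd_old <- data.frame(none = c(0, 0), Temp = c(0.10, 0.08),
                             Wind = c(0.05, 0.10))

# Positional exclusion, mirroring .SDcols = -c(1, 2) and .SDcols = -1
cols_est <- names(shapley_est)[-c(1, 2)]   # "none" "Temp" "Wind"
cols_sd_old <- names(shapley_sd_old)[-1]   # "Temp" "Wind": none is dropped

# With the index columns added to the sd table, name-based exclusion
# selects the same value columns in both tables
id_cols <- c("explain_idx", "horizon")
shapley_sd_new <- cbind(shapley_est[id_cols], shapley_sd_old)
cols_sd_new <- setdiff(names(shapley_sd_new), id_cols)  # "none" "Temp" "Wind"
```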

This means that @martinju can accept the changes to the test output and merge.

@martinju changed the title from Bug in explain_forecast() when verbose = "shapley" to Bug in explain_forecast() on Dec 16, 2024
@martinju merged commit db81ed7 into master on Dec 17, 2024
7 checks passed
@martinju deleted the Lars/Fix_forecast_bug_Shapley_sd_dt branch on December 17, 2024 07:46