Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix keep_samp_for_vS with iterative approach #417

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ compute_MCint <- function(dt, pred_cols = "p_hat") {
#' @keywords internal
append_vS_list <- function(vS_list, internal) {
iter <- length(internal$iter_list)
keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS

# Adds v_S output above to any vS_list already computed
if (iter > 1) {
Expand All @@ -249,17 +250,53 @@ append_vS_list <- function(vS_list, internal) {
prev_vS_list_new <- list()

# Applies the mapper to update the prev_vS_list to the new id_coalition numbering
for (k in seq_along(prev_vS_list)) {
prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]],
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)
prev_vS_list_new[[k]][, id_coalition := id_coalition_new]
prev_vS_list_new[[k]][, id_coalition_new := NULL]
if (isFALSE(keep_samp_for_vS)) {
for (k in seq_along(prev_vS_list)) {
this_vS <- prev_vS_list[[k]]

this_vS_new <- merge(this_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_vS_new[, id_coalition := id_coalition_new]
this_vS_new[, id_coalition_new := NULL]


prev_vS_list_new[[k]] <- this_vS_new
}
} else {
for (k in seq_along(prev_vS_list)) {
this_vS <- prev_vS_list[[k]]$dt_vS
this_samp_for_vS <- prev_vS_list[[k]]$dt_samp_for_vS


this_vS_new <- merge(this_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_vS_new[, id_coalition := id_coalition_new]
this_vS_new[, id_coalition_new := NULL]

this_samp_for_vS_new <- merge(this_samp_for_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_samp_for_vS_new[, id_coalition := id_coalition_new]
this_samp_for_vS_new[, id_coalition_new := NULL]


prev_vS_list_new[[k]] <- list(dt_vS = this_vS_new, dt_samp_for_vS = this_samp_for_vS_new)
}
}
names(prev_vS_list_new) <- names(prev_vS_list)

# Merge the new vS_list with the old vS_list
vS_list <- c(prev_vS_list_new, vS_list)
}


return(vS_list)
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,38 @@
2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714
3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978

# output_lm_numeric_independence_keep_samp_for_vS

Code
(out <- code)
Message
Success with message:
max_n_coalitions is NULL or larger than or 2^n_features = 32,
and is therefore set to 2^n_features = 32.

* Model class: <lm>
* Approach: independence
* Iterative estimation: TRUE
* Number of feature-wise Shapley values: 5
* Number of observations to explain: 3

-- iterative computation started --

-- Iteration 1 -----------------------------------------------------------------
i Using 5 of 32 coalitions, 5 new.

-- Iteration 2 -----------------------------------------------------------------
i Using 10 of 32 coalitions, 4 new.

-- Iteration 3 -----------------------------------------------------------------
i Using 12 of 32 coalitions, 2 new.

-- Iteration 4 -----------------------------------------------------------------
i Using 16 of 32 coalitions, 4 new.
Output
explain_id none Solar.R Wind Temp Month Day
<int> <num> <num> <num> <num> <num> <num>
1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093
2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997
3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,23 @@ test_that("output_verbose_1_3_4_5", {
"output_verbose_1_3_4_5"
)
})


# Checks that internal$output$dt_samp_for_vS is available when explain() is run
# with iterative = TRUE and output_args = list(keep_samp_for_vS = TRUE).
# NOTE(review): model_lm_numeric, x_explain_numeric, x_train_numeric and p0 are
# not defined in this test; presumably shared fixtures from the test setup.
test_that("output_lm_numeric_independence_keep_samp_for_vS", {
# Snapshot the explain() result under the given name. The parenthesised
# assignment stores the result in `out`, which the expectation below reads.
expect_snapshot_rds(
(out <- explain(
testing = TRUE,
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = "independence",
phi0 = p0,
output_args = list(keep_samp_for_vS = TRUE),
iterative = TRUE
)),
"output_lm_numeric_independence_keep_samp_for_vS"
)

# dt_samp_for_vS must be present (non-NULL) in the returned internal output.
expect_false(is.null(out$internal$output$dt_samp_for_vS))
})
Loading