From 0f13f78813ebfce0ee6f0c8380fa1a3d9b36262e Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Tue, 2 Jan 2024 16:17:31 +0100 Subject: [PATCH] Return 0 instead of NA when heuristic fails --- NEWS.md | 9 ++-- R/potential_interactions.R | 26 ++++++----- README.md | 2 +- man/potential_interactions.Rd | 17 +++++--- tests/testthat/test-potential_interactions.R | 46 ++++++++++---------- 5 files changed, 56 insertions(+), 44 deletions(-) diff --git a/NEWS.md b/NEWS.md index 4704449..233e3cc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,10 +7,10 @@ If no SHAP interaction values are available, by default, the color feature `v'` is selected by the heuristic `potential_interaction()`, which works as follows: 1. If the feature `v` (the on the x-axis) is numeric, it is binned into `nbins` bins. -2. Per bin, the SHAP values of `v` are regressed onto `v` and the R-squared is calculated. -3. The R-squared are averaged over bins, weighted by the bin size. +2. Per bin, the SHAP values of `v` are regressed onto `v'` and the R-squared is calculated. Rows with missing `v'` are discarded. +3. The R-squared are averaged over bins, weighted by the number of non-missing `v'` values. -This measures how much variability in the SHAP values is explained by `v'`, after accounting for `v`. +This measures how much variability in the SHAP values of `v` is explained by `v'`, after accounting for `v`. We have introduced four parameters to control the heuristic. Their defaults are in line with the old behaviour. @@ -35,7 +35,8 @@ We will continue to experiment with the defaults, which might change in the futu ## Other user-visible changes -- `sv_dependence()`: If `color_var = "auto"` (default) and no color feature seems to be relevant (SHAP interaction is `NULL`, or heuristic returns no positive value), there won't be any color scale. +- `sv_dependence()`: If `color_var = "auto"` (default) and no color feature seems to be relevant (SHAP interaction is `NULL`, or heuristic returns no positive value), there won't be any color scale. Furthermore, in some edge cases, a different +color feature might be selected. - `mshapviz()` objects can now be rowbinded via `rbind()` or `+`. Implemented by [@jmaspons](https://github.com/jmaspons) in [#110](https://github.com/ModelOriented/shapviz/pull/110). - `mshapviz()` is more strict when combining multiple "shapviz" objects. These now need to have identical column names, see [#114](https://github.com/ModelOriented/shapviz/pull/114). diff --git a/R/potential_interactions.R b/R/potential_interactions.R index ca0072f..9b38806 100644 --- a/R/potential_interactions.R +++ b/R/potential_interactions.R @@ -1,18 +1,21 @@ #' Interaction Strength #' -#' Returns vector of interaction strengths between variable `v` and all other variables, -#' see Details. +#' Returns a vector of interaction strengths between variable `v` and all other +#' variables, see Details. #' #' If SHAP interaction values are available, the interaction strength #' between feature `v` and another feature `v'` is measured by twice their #' mean absolute SHAP interaction values. #' -#' Otherwise, we use a heuristic calculated as follows to calculate interaction strength -#' between `v` and each other "color" feature `v': +#' Otherwise, we use a heuristic calculated as follows: #' 1. If `v` is numeric, it is binned into `nbins` bins. #' 2. Per bin, the SHAP values of `v` are regressed onto `v`, and the R-squared -#' is calculated. -#' 3. The R-squared are averaged over bins, weighted by the bin size. +#' is calculated. Rows with missing `v'` are discarded. +#' 3. The R-squared are averaged over bins, weighted by the number of +#' non-missing `v'` values. +#' +#' This measures how much variability in the SHAP values of `v` is explained by `v'`, +#' after accounting for `v`. #' #' Set `scale = TRUE` to multiply the R-squared by the within-bin variance #' of the SHAP values. This will put higher weight to bins with larger scatter. @@ -22,6 +25,8 @@ #' #' Finally, set `adjusted = TRUE` to use *adjusted* R-squared. #' +#' The algorithm does not consider observations with missing `v'` values. +#' #' @param obj An object of class "shapviz". #' @param v Variable name to calculate potential SHAP interactions for. #' @param nbins Into how many quantile bins should a numeric `v` be binned? @@ -60,7 +65,7 @@ potential_interactions <- function(obj, v, nbins = NULL, color_num = TRUE, nbins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20)) } out <- vapply( - X[v_other], # data.frame is a list + X[v_other], FUN = heuristic, FUN.VALUE = 1.0, s = S[, v], @@ -108,7 +113,8 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) { #' #' @inheritParams heuristic #' @returns -#' A (1x2) matrix with heuristic and number of observations. +#' A (1x2) matrix with the heuristic and the number of observations with non-missing +#' `v'`. heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) { ok <- !is.na(color) color <- color[ok] @@ -116,7 +122,7 @@ heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) { n <- length(s) var_s <- stats::var(s) if (n < 2L || var_s < .Machine$double.eps || length(unique(color)) < 2L) { - return(cbind(stat = NA, n = n)) + return(cbind(stat = 0, n = n)) } z <- stats::lm(s ~ color) var_r <- sum(z$residuals^2) / (if (adjusted) z$df.residual else n - 1) @@ -125,7 +131,7 @@ heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) { stat <- stat * var_s } if (!is.finite(stat)) { - stat <- NA + stat <- 0 } cbind(stat = stat, n = n) } diff --git a/README.md b/README.md index 23dc44d..b77b3ea 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x] shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000) sv_importance(shp, show_numbers = TRUE) -sv_dependence(shp, v = x)} +sv_dependence(shp, v = x) ``` ![](man/figures/README-imp.svg) diff --git a/man/potential_interactions.Rd b/man/potential_interactions.Rd index 5cd141c..8f71981 100644 --- a/man/potential_interactions.Rd +++ b/man/potential_interactions.Rd @@ -36,23 +36,26 @@ higher weight. The default is \code{FALSE}. Ignored if \code{obj} contains SHAP A named vector of decreasing interaction strengths. } \description{ -Returns vector of interaction strengths between variable \code{v} and all other variables, -see Details. +Returns a vector of interaction strengths between variable \code{v} and all other +variables, see Details. } \details{ If SHAP interaction values are available, the interaction strength between feature \code{v} and another feature \verb{v'} is measured by twice their mean absolute SHAP interaction values. -Otherwise, we use a heuristic calculated as follows to calculate interaction strength -between \code{v} and each other "color" feature `v': +Otherwise, we use a heuristic calculated as follows: \enumerate{ \item If \code{v} is numeric, it is binned into \code{nbins} bins. \item Per bin, the SHAP values of \code{v} are regressed onto \code{v}, and the R-squared -is calculated. -\item The R-squared are averaged over bins, weighted by the bin size. +is calculated. Rows with missing \verb{v'} are discarded. +\item The R-squared are averaged over bins, weighted by the number of +non-missing \verb{v'} values. } +This measures how much variability in the SHAP values of \code{v} is explained by \verb{v'}, +after accounting for \code{v}. + Set \code{scale = TRUE} to multiply the R-squared by the within-bin variance of the SHAP values. This will put higher weight to bins with larger scatter. @@ -60,6 +63,8 @@ Set \code{color_num = FALSE} to \emph{not} turn the values of the "color" featur to numeric. Finally, set \code{adjusted = TRUE} to use \emph{adjusted} R-squared. + +The algorithm does not consider observations with missing \verb{v'} values. } \seealso{ \code{\link[=sv_dependence]{sv_dependence()}} diff --git a/tests/testthat/test-potential_interactions.R b/tests/testthat/test-potential_interactions.R index 102039e..6052859 100644 --- a/tests/testthat/test-potential_interactions.R +++ b/tests/testthat/test-potential_interactions.R @@ -66,45 +66,45 @@ test_that("heuristic_in_bin() returns R-squared", { ) }) -test_that("Failing heuristic_in_bin() returns NA", { - expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = NA, n = 0)) +test_that("Failing heuristic_in_bin() returns 0", { + expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = 0, n = 0)) }) -test_that("heuristic_in_bin() returns NA for constant response", { +test_that("heuristic_in_bin() returns 0 for constant response", { expect_equal( heuristic_in_bin(color = 1:3, s = c(1, 1, 1)), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(color = 1:3, s = c(1, 1, 1), scale = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjust = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjust = TRUE, scale = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) }) -test_that("heuristic_in_bin() returns NA for constant color", { +test_that("heuristic_in_bin() returns 0 for constant color", { expect_equal( heuristic_in_bin(s = 1:3, color = c(1, 1, 1)), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(s = 1:3, color = c(1, 1, 1), scale = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjust = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjust = TRUE, scale = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) }) @@ -112,38 +112,38 @@ test_that("heuristic_in_bin() returns 0 if response and color are constant", { z <- c(1, 1) expect_equal( heuristic_in_bin(color = z, s = z), - cbind(stat = NA, n = 2L) + cbind(stat = 0, n = 2L) ) expect_equal( heuristic_in_bin(color = z, s = z, scale = TRUE), - cbind(stat = NA, n = 2L) + cbind(stat = 0, n = 2L) ) expect_equal( heuristic_in_bin(color = z, s = z, adjust = TRUE), - cbind(stat = NA, n = 2L) + cbind(stat = 0, n = 2L) ) expect_equal( heuristic_in_bin(color = z, s = z, adjust = TRUE, scale = TRUE), - cbind(stat = NA, n = 2L) + cbind(stat = 0, n = 2L) ) }) -test_that("heuristic_in_bin() returns NA for single obs", { +test_that("heuristic_in_bin() returns 0 for single obs", { expect_equal( heuristic_in_bin(color = 2, s = 2), - cbind(stat = NA, n = 1L) + cbind(stat = 0, n = 1L) ) expect_equal( heuristic_in_bin(color = 2, s = 2, scale = TRUE), - cbind(stat = NA, n = 1L) + cbind(stat = 0, n = 1L) ) expect_equal( heuristic_in_bin(color = 2, s = 2, adjust = TRUE), - cbind(stat = NA, n = 1L) + cbind(stat = 0, n = 1L) ) expect_equal( heuristic_in_bin(color = 2, s = 2, adjust = TRUE, scale = TRUE), - cbind(stat = NA, n = 1L) + cbind(stat = 0, n = 1L) ) }) @@ -159,11 +159,11 @@ test_that("heuristic_in_bin() returns NA for single obs", { ) expect_equal( heuristic_in_bin(color = cc, s = 1:3, adjust = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) expect_equal( heuristic_in_bin(color = cc, s = 2*(1:3), adjust = TRUE, scale = TRUE), - cbind(stat = NA, n = 3L) + cbind(stat = 0, n = 3L) ) })