add max order 1 to tabular explainer (#168)
* add max order 1 to tabular explainer

* changed test name

---------

Co-authored-by: Maximilian <[email protected]>
hbaniecki and mmschlk authored Jun 5, 2024
1 parent 5295a28 commit 08bb0db
Showing 6 changed files with 71 additions and 19 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
## Changelog

### v1.0.1 (2024-06-05)

- add `max_order=1` to `TabularExplainer`
-

### v1.0.0 (2024-06-04)

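In practice, the new entry amounts to the following minimal usage sketch (the model, data, and import path are illustrative stand-ins inferred from the diffs below, not part of the commit):

```python
# Minimal sketch of the new order-1 mode (model and data are stand-ins).
import numpy as np
from sklearn.tree import DecisionTreeRegressor

from shapiq.explainer.tabular import TabularExplainer

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X[:, 0] + X[:, 1] * X[:, 2]
model = DecisionTreeRegressor().fit(X, y)

# max_order=1 pairs with index="SV": plain Shapley values computed via KernelSHAP.
# Omitting index="SV" here would trigger the new warning and switch the index automatically.
explainer = TabularExplainer(model=model.predict, data=X, index="SV", max_order=1)
shapley_values = explainer.explain(X[0])  # an order-1 InteractionValues object
```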
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -1 +1,3 @@
include pyproject.toml
include README.md
include CHANGELOG.md
4 changes: 2 additions & 2 deletions docs/source/references.rst
@@ -17,15 +17,15 @@ Algorithms
- Yu et al. `Linear tree shap <https://doi.org/10.48550/arXiv.2209.08192>`_. NeurIPS 2022
- Harris et al. `Joint Shapley values: a measure of joint feature importance <https://doi.org/10.48550/arXiv.2107.11357>`_. ICLR 2022
- Covert et al. `Improving KernelSHAP: Practical Shapley value estimation using linear regression <https://doi.org/10.48550/arXiv.2012.01536>`_. AISTATS 2021
- Sundararajan et al. `The Shapley Taylor Interaction Index <https://doi.org/10.48550/arXiv.1902.05622>`_. ICML 2020
- Sundararajan et al. `The Shapley Taylor interaction index <https://doi.org/10.48550/arXiv.1902.05622>`_. ICML 2020
- Lundberg et al. `From local explanations to global understanding with explainable AI for trees <https://doi.org/10.1038/s42256-019-0138-9>`_. NMI 2020
- Lundberg et al. `A unified approach to interpreting model predictions <https://doi.org/10.48550/arXiv.1705.07874>`_. NeurIPS 2017

Related software tools and benchmarks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Olsen et al. `A comparative study of methods for estimating model-agnostic Shapley value explanations <https://doi.org/10.1007/s10618-024-01016-z>`_. DAMI 2024
- Li et al. `M4: A unified xai benchmark for faithfulness evaluation of feature attribution methods across metrics, modalities and models <https://openreview.net/forum?id=6zcfrSz98y>`_. NeurIPS 2023
- Li et al. `M4: A unified XAI benchmark for faithfulness evaluation of feature attribution methods across metrics, modalities and models <https://openreview.net/forum?id=6zcfrSz98y>`_. NeurIPS 2023
- Jiang et al. `OpenDataVal: A unified benchmark for data valuation <https://doi.org/10.48550/arXiv.2306.10577>`_. NeurIPS 2023
- Hedström et al. `Quantus: An explainable AI toolkit for responsible evaluation of neural network explanations and beyond <https://www.jmlr.org/papers/v24/22-0142.html>`_. JMLR 2023
- Agarwal et al. `OpenXAI: Towards a transparent evaluation of model explanations <https://doi.org/10.48550/arXiv.2206.11104>`_. NeurIPS 2022
2 changes: 1 addition & 1 deletion shapiq/__init__.py
@@ -2,7 +2,7 @@
the well established Shapley value and its generalization to interaction.
"""

__version__ = "1.0.0"
__version__ = "1.0.0.9000"

# approximator classes
from .approximator import (
52 changes: 42 additions & 10 deletions shapiq/explainer/tabular.py
@@ -9,31 +9,42 @@
SHAPIQ,
SVARMIQ,
InconsistentKernelSHAPIQ,
KernelSHAP,
KernelSHAPIQ,
PermutationSamplingSII,
PermutationSamplingSTII,
PermutationSamplingSV,
RegressionFSII,
UnbiasedKernelSHAP,
)
from shapiq.approximator._base import Approximator
from shapiq.explainer._base import Explainer
from shapiq.games.imputer import ConditionalImputer, MarginalImputer
from shapiq.interaction_values import InteractionValues

APPROXIMATOR_CONFIGURATIONS = {
"Regression": {
"regression": {
"SII": InconsistentKernelSHAPIQ,
"FSII": RegressionFSII,
"k-SII": InconsistentKernelSHAPIQ,
"SV": KernelSHAP,
},
"Permutation": {
"permutation": {
"SII": PermutationSamplingSII,
"STII": PermutationSamplingSTII,
"k-SII": PermutationSamplingSII,
"SV": PermutationSamplingSV,
},
"montecarlo": {
"SII": SHAPIQ,
"STII": SHAPIQ,
"FSII": SHAPIQ,
"k-SII": SHAPIQ,
"SV": UnbiasedKernelSHAP,
},
"ShapIQ": {"SII": SHAPIQ, "STII": SHAPIQ, "FSII": SHAPIQ, "k-SII": SHAPIQ},
}

AVAILABLE_INDICES = {"SII", "k-SII", "STII", "FSII"}
AVAILABLE_INDICES = {"SII", "k-SII", "STII", "FSII", "SV"}


class TabularExplainer(Explainer):
@@ -48,12 +59,13 @@ class TabularExplainer(Explainer):
data: A background dataset to be used for imputation.
imputer: Either an object of class Imputer or a string from ``["marginal", "conditional"]``.
Defaults to ``"marginal"``, which initializes the default MarginalImputer.
approximator: An approximator to use for the explainer. Defaults to ``"auto"``, which will
approximator: An approximator object to use for the explainer. Defaults to ``"auto"``, which will
automatically choose the approximator based on the number of features and the number of
samples in the background data.
index: Type of Shapley interaction index to use. Must be one of ``"SII"`` (Shapley Interaction Index),
``"k-SII"`` (k-Shapley Interaction Index), ``"STII"`` (Shapley-Taylor Interaction Index), or
``"FSII"`` (Faithful Shapley Interaction Index). Defaults to ``"k-SII"``.
``"k-SII"`` (k-Shapley Interaction Index), ``"STII"`` (Shapley-Taylor Interaction Index),
``"FSII"`` (Faithful Shapley Interaction Index), or ``"SV"`` (Shapley Value) for ``max_order=1``.
Defaults to ``"k-SII"``.
max_order: The maximum interaction order to be computed. Defaults to ``2``.
random_state: The random state to initialize Imputer and Approximator with. Defaults to ``None``.
**kwargs: Additional keyword-only arguments passed to the imputer.
@@ -77,8 +89,6 @@ def __init__(
) -> None:
if index not in AVAILABLE_INDICES:
raise ValueError(f"Invalid index `{index}`. " f"Valid indices are {AVAILABLE_INDICES}.")
if max_order < 2:
raise ValueError("The maximum order must be at least 2.")

super().__init__(model, data)

@@ -145,10 +155,32 @@ def baseline_value(self) -> float:
def _init_approximator(
self, approximator: Union[Approximator, str], index: str, max_order: int
) -> Approximator:

if isinstance(approximator, Approximator): # if the approximator is already given
return approximator

if approximator == "auto":
if index == "FSII":
if max_order == 1:
if index != "SV":
warnings.warn(
"`max_order=1` but `index != 'SV'`, setting `index = 'SV'`. Using the KernelSHAP approximator."
)
self.index = "SV"
return KernelSHAP(
n=self._n_features,
random_state=self._random_state,
)
elif index == "SV":
if max_order != 1:
warnings.warn(
"`index='SV'` but `max_order != 1`, setting `max_order = 1`. Using the KernelSHAP approximator."
)
self._max_order = 1
return KernelSHAP(
n=self._n_features,
random_state=self._random_state,
)
elif index == "FSII":
return RegressionFSII(
n=self._n_features,
max_order=max_order,
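The interplay between ``index`` and ``max_order`` introduced above can be summarized on its own. The helper below is a standalone illustration of the new ``auto`` branch, not the shipped method:

```python
# Standalone illustration of the order/index reconciliation added to _init_approximator.
import warnings


def resolve_sv_settings(index: str, max_order: int) -> tuple[str, int]:
    """Order-1 explanations always use the 'SV' index, and vice versa."""
    if max_order == 1 and index != "SV":
        warnings.warn("`max_order=1` but `index != 'SV'`, setting `index = 'SV'`.")
        index = "SV"
    elif index == "SV" and max_order != 1:
        warnings.warn("`index='SV'` but `max_order != 1`, setting `max_order = 1`.")
        max_order = 1
    return index, max_order


print(resolve_sv_settings("k-SII", 1))  # ('SV', 1), with a warning
print(resolve_sv_settings("SV", 3))     # ('SV', 1), with a warning
```

Both branches of the real method then return a ``KernelSHAP`` approximator, so ``index="SV"`` and ``max_order=1`` are two spellings of the same order-1 configuration.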
26 changes: 20 additions & 6 deletions tests/tests_explainer/test_explainer_tabular.py
@@ -28,6 +28,7 @@ def data():
INDICES = ["SII", "k-SII", "STII", "FSII"]
MAX_ORDERS = [2, 3]
IMPUTER = ["marginal", "conditional"]
APPROXIMATOR = ["regression", "montecarlo", "permutation"]


@pytest.mark.parametrize("index", INDICES)
@@ -72,25 +73,27 @@ def test_auto_params(dt_model, data):
assert explainer._approximator.__class__.__name__ == "KernelSHAPIQ"


def test_init_params_error(dt_model, data):
def test_init_params_error_and_warning(dt_model, data):
"""Test the initialization of the interaction explainer."""
model_function = dt_model.predict
with pytest.raises(ValueError):
TabularExplainer(model=model_function, data=data, index="invalid", max_order=0)
with pytest.warns():
TabularExplainer(
model=model_function,
data=data,
index="invalid",
max_order=1,
)
with pytest.raises(ValueError):
with pytest.warns():
TabularExplainer(
model=model_function,
data=data,
max_order=0,
index="SV",
)


def test_init_params_approx(dt_model, data):
"""Test the initialization of the interaction explainer."""
"""Test the initialization of the tabular explainer."""
model_function = dt_model.predict
with pytest.raises(ValueError):
TabularExplainer(
@@ -99,7 +102,7 @@ def test_init_params_approx(dt_model, data):
approximator="invalid",
)
explainer = TabularExplainer(
approximator="Regression",
approximator="regression",
index="FSII",
model=model_function,
data=data,
@@ -117,6 +120,17 @@
assert explainer._approximator == approximator


@pytest.mark.parametrize("approximator", APPROXIMATOR)
@pytest.mark.parametrize("max_order", MAX_ORDERS + [1])
def test_init_params_approx_params(dt_model, data, approximator, max_order):
"""Test the initialization of the tabular explainer."""
explainer = TabularExplainer(
approximator=approximator, model=dt_model, data=data, max_order=max_order
)
iv = explainer.explain(data[0])
assert iv.__class__.__name__ == "InteractionValues"


BUDGETS = [2**5, 2**8, None]


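For context on the new ``APPROXIMATOR`` parametrization, the sketch below shows how the lowercase configuration names resolve to the order-1 ("SV") approximators. The dictionary entries mirror ``APPROXIMATOR_CONFIGURATIONS`` from the ``tabular.py`` diff above; the import path, the helper function, and the constructor keywords for the non-``KernelSHAP`` classes are assumptions:

```python
# Sketch: resolving an order-1 ("SV") approximator from a configuration name.
# Dict entries mirror APPROXIMATOR_CONFIGURATIONS above; the helper itself, and the
# constructor kwargs for the non-KernelSHAP classes, are illustrative assumptions.
from shapiq.approximator import KernelSHAP, PermutationSamplingSV, UnbiasedKernelSHAP

SV_CONFIGURATIONS = {
    "regression": KernelSHAP,
    "permutation": PermutationSamplingSV,
    "montecarlo": UnbiasedKernelSHAP,
}


def build_sv_approximator(name: str, n_features: int, random_state: int = 42):
    """Pick the order-1 approximator class for a configuration name and instantiate it."""
    cls = SV_CONFIGURATIONS[name]
    return cls(n=n_features, random_state=random_state)


approx = build_sv_approximator("regression", n_features=7)  # -> KernelSHAP instance
```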
