Skip to content

Commit

Permalink
sv bugfix in tree explainer (#167)
Browse files Browse the repository at this point in the history
* add SV to TreeSHAP-IQ available indices

* adds warning and proper test
  • Loading branch information
mmschlk authored Jun 5, 2024
1 parent 20d23d3 commit 5295a28
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
5 changes: 5 additions & 0 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
computing any-order Shapley Interactions for tree ensembles."""

import copy
import warnings
from typing import Any, Optional, Union

import numpy as np
Expand Down Expand Up @@ -48,6 +49,10 @@ def __init__(

super().__init__(model)

if index == "SV" and max_order > 1:
warnings.warn("For index='SV' the max_order is set to 1.")
max_order = 1

# validate and parse model
validated_model = validate_tree_model(model, class_label=class_label)
self._trees: list[TreeModel] = copy.deepcopy(validated_model)
Expand Down
8 changes: 4 additions & 4 deletions shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.Ns_id_store: dict = {}
self.Ns_store: dict = {}
self.n_interpolation_size = self._n_features_in_tree
if self._index in ("SII", "k-SII"): # SP is of order at most d_max
if self._index in ("SV", "SII", "k-SII"): # SP is of order at most d_max
self.n_interpolation_size = min(self._edge_tree.max_depth, self._n_features_in_tree)
self._init_summary_polynomials()

Expand Down Expand Up @@ -308,7 +308,7 @@ def _compute_shapley_interaction_values(
self._int_height[node_id][interaction_sets] == order
]
if len(interactions_seen) > 0:
if self._index not in ("SII", "k-SII"): # for CII
if self._index not in ("SV", "SII", "k-SII"): # for CII
D_power = self.D_powers[self._n_features_in_tree - current_height]
index_quotient = self._n_features_in_tree - order
else: # for SII and k-SII
Expand Down Expand Up @@ -343,7 +343,7 @@ def _compute_shapley_interaction_values(
ancestor_heights = self._edge_tree.edge_heights[
interactions_ancestors[cond_interaction_seen]
]
if self._index not in ("SII", "k-SII"): # for CII
if self._index not in ("SV", "SII", "k-SII"): # for CII
D_power = self.D_powers[self._n_features_in_tree - current_height]
index_quotient = self._n_features_in_tree - order
else: # for SII and k-SII
Expand Down Expand Up @@ -405,7 +405,7 @@ def _init_summary_polynomials(self):
self.subset_ancestors_store[order] = subset_ancestors
self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)
self.D_powers_store[order] = self._cache(self.D_store[order])
if self._index in ("SII", "k-SII"):
if self._index in ("SV", "SII", "k-SII"):
self.Ns_store[order] = self._get_N(self.D_store[order])
else:
self.Ns_store[order] = self._get_N_cii(self.D_store[order], order)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_against_shap_implementation():
values=values,
)

explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, index="SII")
explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, index="SV")
explanation = explainer.explain(x_explain)

assert explanation[(0,)] == pytest.approx(-0.09263158, abs=1e-4)
Expand All @@ -114,5 +114,10 @@ def test_against_shap_implementation():
explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, index="SII")
explanation = explainer.explain(x_explain)

explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, index="SII")
explanation = explainer.explain(x_explain)
assert explanation[(0,)] == pytest.approx(-0.09263158, abs=1e-4)
assert explanation[(1,)] == pytest.approx(-0.12100478, abs=1e-4)
assert explanation[(2,)] == pytest.approx(0.02727273, abs=1e-4)
assert explanation[(3,)] == pytest.approx(0.0, abs=1e-4)

with pytest.warns(UserWarning):
_ = TreeExplainer(model=tree_model, max_order=2, min_order=1, index="SV")
10 changes: 6 additions & 4 deletions tests/tests_games/test_treeshapiq_xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_adult():
"""Test the AdultCensus TreeSHAP-IQ explanation game."""
max_order = 2
min_order = 1
index = "SII"
index = "k-SII"

# benchmark game
game = AdultCensusTreeSHAPIQXAI(
Expand All @@ -107,11 +107,13 @@ def test_adult():
assert np.isclose(exact_values[interaction], gt_interaction_values[interaction])


def test_california():
@pytest.mark.parametrize("index_order", [("k-SII", 2), ("SV", 1)])
def test_california(index_order):
"""Test the CaliforniaHousing TreeSHAP-IQ explanation game."""
max_order = 2
max_order = int(index_order[1])
index = str(index_order[0])

min_order = 1
index = "SII"

# benchmark game
game = CaliforniaHousingTreeSHAPIQXAI(
Expand Down

0 comments on commit 5295a28

Please sign in to comment.