Skip to content

Commit

Permalink
updated experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed May 24, 2024
1 parent 184de48 commit 40f4967
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 134 deletions.
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
black
build
pre-commit
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ruff==0.3.5
scikit-image==0.22.0
scikit-learn==1.4.1.post1
scipy==1.13.0
tqdm==4.66.2
torch==2.2.2
torchvision==0.17.2
transformers==4.39.3
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# https://packaging.python.org/guides/single-sourcing-package-version/
def read(rel_path):
    """Return the full text content of the file at ``rel_path``.

    Args:
        rel_path: Path relative to the module-level ``work_directory``.

    Returns:
        The file's contents as a string.
    """
    # str() guards against path-like inputs; the duplicated legacy `with` line
    # (a diff-merge artifact) is removed — only one context manager is opened.
    with codecs.open(str(os.path.join(work_directory, rel_path)), "r") as fp:
        return fp.read()


Expand Down
8 changes: 0 additions & 8 deletions shapiq/games/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@

# metrics
from .metrics import (
compute_kendall_tau,
compute_mae,
compute_mse,
compute_precision_at_k,
get_all_metrics,
)

Expand Down Expand Up @@ -93,10 +89,6 @@
"get_game_files",
# all metrics
"get_all_metrics",
"compute_mae",
"compute_mse",
"compute_kendall_tau",
"compute_precision_at_k",
# local_xai games
"LocalExplanation",
"AdultCensusLocalXAI",
Expand Down
69 changes: 36 additions & 33 deletions shapiq/games/benchmark/metrics.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
"""Metrics for evaluating the performance of interaction values."""

import copy
from typing import Optional

import numpy as np
from scipy.stats import kendalltau

from ...interaction_values import InteractionValues
from ...utils.sets import powerset
from ...utils.sets import count_interactions, powerset

__all__ = [
"compute_mse",
"compute_mae",
"compute_kendall_tau",
"compute_precision_at_k",
"get_all_metrics",
]
__all__ = ["get_all_metrics"]


def _remove_empty_value(interaction: InteractionValues) -> InteractionValues:
    """Manually sets the empty value (e.g. baseline value) to zero in the values array.

    Args:
        interaction: The interaction values to remove the empty value from.

    Returns:
        A deep copy of the interaction values with the empty-coalition value set to
        zero, or the input object unchanged if it has no empty coalition.
    """
    try:
        # Probe the lookup first so we only pay for the deep copy when the
        # empty coalition is actually present.
        empty_index = interaction.interaction_lookup[()]
    except KeyError:
        return interaction
    # Work on a copy so the caller's InteractionValues is never mutated in place
    # (the removed legacy code mutated `interaction.values` directly).
    new_interaction = copy.deepcopy(interaction)
    new_interaction.values[empty_index] = 0
    return new_interaction


def compute_mse(ground_truth: InteractionValues, estimated: InteractionValues) -> float:
    """Compute the mean squared error between two interaction values.

    Args:
        ground_truth: The ground truth interaction values.
        estimated: The estimated interaction values.

    Returns:
        The mean squared error between the ground truth and estimated interaction values.
    """
    # Element-wise residuals, with the empty-coalition entry zeroed out.
    residuals = _remove_empty_value(ground_truth - estimated).values
    return float(np.mean(np.square(residuals)))

def compute_diff_metrics(ground_truth: InteractionValues, estimated: InteractionValues) -> dict:
    """Computes metrics via the difference between the ground truth and estimated interaction
    values.

    Computes the following metrics:
        - Mean Squared Error (MSE)
        - Mean Absolute Error (MAE)
        - Sum of Squared Errors (SSE)
        - Sum of Absolute Errors (SAE)

    Args:
        ground_truth: The ground truth interaction values.
        estimated: The estimated interaction values.

    Returns:
        The metrics between the ground truth and estimated interaction values.
    """
    # NOTE(review): the interleaved remnants of the deleted `compute_mae`
    # (its def line, docstring, and body) were fused into this span by the
    # diff rendering; they are removed here so the function parses cleanly.
    difference = ground_truth - estimated
    diff_values = _remove_empty_value(difference).values
    # Normalize the summed errors by the number of interactions between
    # min_order and max_order to obtain the mean-error variants.
    n_values = count_interactions(
        ground_truth.n_players, ground_truth.max_order, ground_truth.min_order
    )
    sq_errors = diff_values**2
    abs_errors = np.abs(diff_values)
    metrics = {
        "MSE": np.sum(sq_errors) / n_values,
        "MAE": np.sum(abs_errors) / n_values,
        "SSE": np.sum(sq_errors),
        "SAE": np.sum(abs_errors),
    }
    return metrics


def compute_kendall_tau(
Expand Down Expand Up @@ -129,7 +127,7 @@ def get_all_metrics(
Args:
ground_truth: The ground truth interaction values.
estimated: The estimated interaction values.
order_indicator: The order indicator for the metrics. Defaults to None.
order_indicator: An optional order indicator to prepend to the metrics. Defaults to `None`.
Returns:
The metrics as a dictionary.
Expand All @@ -140,12 +138,17 @@ def get_all_metrics(
order_indicator += "_"

metrics = {
order_indicator + "MSE": compute_mse(ground_truth, estimated),
order_indicator + "MAE": compute_mae(ground_truth, estimated),
order_indicator + "Precision@10": compute_precision_at_k(ground_truth, estimated, k=10),
order_indicator + "Precision@5": compute_precision_at_k(ground_truth, estimated, k=5),
order_indicator + "KendallTau": compute_kendall_tau(ground_truth, estimated),
order_indicator + "KendallTau@10": compute_kendall_tau(ground_truth, estimated, k=10),
order_indicator + "KendallTau@50": compute_kendall_tau(ground_truth, estimated, k=50),
}

# get diff metrics
metrics_diff = compute_diff_metrics(ground_truth, estimated)
if order_indicator != "": # add the order indicator to the diff metrics
metrics_diff = {order_indicator + key: value for key, value in metrics_diff.items()}

metrics.update(metrics_diff)
return metrics
8 changes: 6 additions & 2 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class InteractionValues:
Attributes:
values: The interaction values of the model in vectorized form.
index: The interaction index estimated. Available indices are 'SII', 'kSII', 'STII', and
'FSII'.
index: The interaction index estimated. All available indices are defined in
`ALL_AVAILABLE_INDICES`.
max_order: The order of the approximation.
n_players: The number of players.
min_order: The minimum order of the approximation. Defaults to 0.
Expand All @@ -34,6 +34,10 @@ class InteractionValues:
'empty value' since it denotes the value of the empty coalition (empty set). If not
provided it is searched for in the values vector (raising an Error if not found).
Defaults to `None`.
Raises:
UserWarning: If the index is not a valid index as defined in `ALL_AVAILABLE_INDICES`.
TypeError: If the baseline value is not a number.
"""

values: np.ndarray
Expand Down
Loading

0 comments on commit 40f4967

Please sign in to comment.