Add loss value metric based on optimal performance definition #66

Merged: 40 commits, Oct 31, 2024 (changes shown are from 10 commits)

Commits (40):
6c77820  add loss value metric (jteijema, Aug 27, 2024)
760125e  Update reqs (jteijema, Aug 27, 2024)
567ca37  Merge branch 'main' into loss-metric (jteijema, Sep 4, 2024)
11589a6  sort imports (jteijema, Sep 4, 2024)
8db2ce5  Merge branch 'main' into loss-metric (J535D165, Sep 26, 2024)
cd1d84f  remove useless function (jteijema, Oct 24, 2024)
62c295f  Normalize loss function between worst and best (jteijema, Oct 24, 2024)
9e51c7b  Add loss tests (jteijema, Oct 24, 2024)
5fa3410  remove metrics import (jteijema, Oct 24, 2024)
18bb2e2  remove sklearn from deps (jteijema, Oct 24, 2024)
c6b18f4  Add line between imports and first func (jteijema, Oct 24, 2024)
0454e5d  Ruff! (jteijema, Oct 24, 2024)
dda0bcf  add breaks in loss value function (jteijema, Oct 24, 2024)
019cb58  Remove prints from loss test (jteijema, Oct 24, 2024)
7b69304  Add comments to algorithm for loss (jteijema, Oct 24, 2024)
9ccf0cd  Add new tests (jteijema, Oct 30, 2024)
71c6bdd  Update algorithm (jteijema, Oct 30, 2024)
66c3445  Remove leftover debugging message (jteijema, Oct 30, 2024)
a6527bf  Remove decimal args from tests (jteijema, Oct 30, 2024)
d4c5cbd  Add new value error for invalid set (jteijema, Oct 31, 2024)
868a3bf  Refactor loss tests (jteijema, Oct 31, 2024)
b4f17c4  Change api usage with cumsum (jteijema, Oct 31, 2024)
07653f3  Merge branch 'loss-metric' of https://github.com/jteijema/asreview-in… (jteijema, Oct 31, 2024)
ceaf9b3  Simplify denominator (jteijema, Oct 31, 2024)
ebdfd83  Return formula instead of deriving in code. (jteijema, Oct 31, 2024)
0e3116f  Simplify formula and update the comments (jteijema, Oct 31, 2024)
b042b5b  update metrics loss docstring (jteijema, Oct 31, 2024)
f75c499  change best to optimal (jteijema, Oct 31, 2024)
74a1e84  Update tests file (jteijema, Oct 31, 2024)
1621f8c  Refactor the docstring for loss function (jteijema, Oct 31, 2024)
85dfd97  Value error update (jteijema, Oct 31, 2024)
867c6bf  Add instance type check (jteijema, Oct 31, 2024)
8a3d4ef  Merge branch 'loss-metric' of https://github.com/jteijema/asreview-in… (jteijema, Oct 31, 2024)
0884240  Linter (jteijema, Oct 31, 2024)
da331ad  Add loss to readme (jteijema, Oct 31, 2024)
5dc96f9  Add loss to output metrics (jteijema, Oct 31, 2024)
017afba  Add to inline explanation (jteijema, Oct 31, 2024)
511503a  Update formulas in readme (jteijema, Oct 31, 2024)
cc9c569  Update explanation in readme (jteijema, Oct 31, 2024)
08b70cc  Change readme explanation (jteijema, Oct 31, 2024)
34 changes: 28 additions & 6 deletions asreviewcontrib/insights/algorithms.py
@@ -1,5 +1,4 @@
 import numpy as np
-from sklearn import metrics
 
 
 def _recall_values(labels, x_absolute=False, y_absolute=False):
@@ -21,11 +20,34 @@ def _recall_values(labels, x_absolute=False, y_absolute=False):
 
 
 def _loss_value(labels):
-    positive_doc_ratio = sum(labels) / len(labels)
-    triangle_before_perfect_recall = positive_doc_ratio * 0.5
-    aera_under_recall_curve = metrics.auc(*_recall_values(labels))
-
-    return 1 - (triangle_before_perfect_recall + aera_under_recall_curve)
+    Ny = sum(labels)
+    Nx = len(labels)
+
+    # The best AUC represents the entire area under the perfect curve, which is
+    # the total area Nx * Ny, minus the area above the perfect curve (which is
+    # the sum of a series with a formula (Ny * Ny) / 2) plus 0.5 to account for
+    # the boundary.
+    best_auc = Nx * Ny - (((Ny * Ny) / 2) + 0.5)
+
+    # Compute recall values (y) based on the provided labels. We don't need x
+    # values because the points are uniformly spaced.
+    y = np.array(_recall_values(labels, x_absolute=True, y_absolute=True)[1])
+
+    # The actual AUC is calculated by approximating the area under the curve
+    # using the trapezoidal rule. (y[1:] + y[:-1]) / 2 takes the average height
+    # between consecutive y values, and we sum them up.
+    actual_auc = np.sum((y[1:] + y[:-1]) / 2)
+
+    # The worst AUC represents the area under the worst-case step curve, which
+    # is simply the area under the recall curve where all positive labels are
+    # clumped at the end, calculated as (Ny * Ny) / 2.
+    worst_auc = ((Ny * Ny) / 2)
+
+    # The normalized loss is the difference between the best AUC and the actual
+    # AUC, normalized by the range between the best and worst AUCs.
+    normalized_loss = (best_auc - actual_auc) / (best_auc - worst_auc) if best_auc != worst_auc else 0  # noqa: E501
+
+    return normalized_loss
 
 
 def _wss_values(labels, x_absolute=False, y_absolute=False):
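For intuition, here is a small self-contained sketch of the normalization above (not part of the diff). It assumes _recall_values(labels, x_absolute=True, y_absolute=True)[1] is equivalent to np.cumsum(labels), and it omits the best_auc != worst_auc guard:

import numpy as np

def loss_sketch(labels):
    # Mirrors _loss_value above; np.cumsum stands in for _recall_values
    # (an assumption made so this snippet runs on its own).
    Ny, Nx = sum(labels), len(labels)
    y = np.cumsum(labels)
    best_auc = Nx * Ny - ((Ny * Ny) / 2 + 0.5)
    worst_auc = (Ny * Ny) / 2
    actual_auc = np.sum((y[1:] + y[:-1]) / 2)
    return (best_auc - actual_auc) / (best_auc - worst_auc)

print(loss_sketch([1, 1, 0]))  # 0.0: positives first, actual_auc = best_auc = 3.5
print(loss_sketch([0, 1, 1]))  # 1.0: positives last, actual_auc = worst_auc = 2.0

The two extreme orderings recover the bounds of the metric: screening all relevant records first gives a loss of 0, screening them all last gives a loss of 1.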
4 changes: 0 additions & 4 deletions asreviewcontrib/insights/metrics.py
@@ -183,12 +183,8 @@ def loss(state_obj, priors=False):
     """
     labels = _pad_simulation_labels(state_obj, priors=priors)
 
-    return _loss(labels)
-
-def _loss(labels):
     return _loss_value(labels)
-
 
 def get_metrics(
     state_obj,
     recall=None,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
     "Programming Language :: Python :: 3.11"
 ]
 license = {text = "Apache-2.0"}
-dependencies = ["numpy", "matplotlib", "asreview>=1,<2", "scikit-learn"]
+dependencies = ["numpy", "matplotlib", "asreview>=1,<2"]
 dynamic = ["version"]
 requires-python = ">=3.7"

36 changes: 36 additions & 0 deletions tests/test_metrics.py
@@ -4,9 +4,11 @@
 from numpy import array_equal
 from numpy.testing import assert_almost_equal
 
+from asreviewcontrib.insights.algorithms import _loss_value
 from asreviewcontrib.insights.metrics import _recall
 from asreviewcontrib.insights.metrics import _time_to_discovery
 from asreviewcontrib.insights.metrics import get_metrics
+from asreviewcontrib.insights.metrics import loss
 from asreviewcontrib.insights.metrics import recall
 
 TEST_ASREVIEW_FILES = Path(Path(__file__).parent, "asreview_files")
@@ -111,3 +113,37 @@ def test_label_padding():
         stop_if_full = get_metrics(s)
 
     assert stop_if_min == stop_if_full
+
+def test_loss():
+    with open_state(
+        Path(TEST_ASREVIEW_FILES, "sim_van_de_schoot_2017_stop_if_min.asreview")
+    ) as s:
+        loss_value = loss(s)
+        assert_almost_equal(loss_value, 0.011590940352087164, decimal=6)
+
+def test_loss_value_function():
+    labels = [1, 0]
+    loss_value = _loss_value(labels)
+    assert_almost_equal(loss_value, 0, decimal=6)
+
+    labels = [0, 1]
+    loss_value = _loss_value(labels)
+    assert_almost_equal(loss_value, 1, decimal=6)
+
+    labels = [1, 1, 0, 0, 0]
+    loss_value = _loss_value(labels)
+    assert_almost_equal(loss_value, 0, decimal=6)
+
+    labels = [0, 0, 0, 1, 1]
+    loss_value = _loss_value(labels)
+    assert_almost_equal(loss_value, 1, decimal=6)
+
+    import random
+    for i in range(100):
+        length = random.randint(2, 100)
+        labels = [random.randint(0, 1) for _ in range(length)]
+        loss_value = _loss_value(labels)
+        if not (0 <= loss_value <= 1):
+            print(f"Test {i+1}: Labels: {labels}, Loss: {loss_value}")
+        assert 0 <= loss_value <= 1, f"Loss value {loss_value} not between 0 and 1 for \
+            labels {labels}"
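For reference, a minimal sketch of how the new metric is used outside the test suite; the file path is illustrative, while open_state and loss are the same APIs exercised in test_loss above:

from asreview import open_state
from asreviewcontrib.insights.metrics import loss

# Illustrative path; any completed simulation .asreview file works here.
with open_state("sim_van_de_schoot_2017_stop_if_min.asreview") as state:
    print(loss(state))  # 0.0 is an optimal ranking, 1.0 the worst possible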