Add loss value metric based on optimal performance definition #66

Merged
40 commits merged on Oct 31, 2024

Commits (40)

6c77820
add loss value metric
jteijema Aug 27, 2024
760125e
Update reqs
jteijema Aug 27, 2024
567ca37
Merge branch 'main' into loss-metric
jteijema Sep 4, 2024
11589a6
sort imports
jteijema Sep 4, 2024
8db2ce5
Merge branch 'main' into loss-metric
J535D165 Sep 26, 2024
cd1d84f
remove useless function
jteijema Oct 24, 2024
62c295f
Normalize loss function between worst and best
jteijema Oct 24, 2024
9e51c7b
Add loss tests
jteijema Oct 24, 2024
5fa3410
remove metrics import
jteijema Oct 24, 2024
18bb2e2
remove sklearn from deps
jteijema Oct 24, 2024
c6b18f4
Add line between imports and first func
jteijema Oct 24, 2024
0454e5d
Ruff!
jteijema Oct 24, 2024
dda0bcf
add breaks in loss value function
jteijema Oct 24, 2024
019cb58
Remove prints from loss test
jteijema Oct 24, 2024
7b69304
Add comments to algorithm for loss
jteijema Oct 24, 2024
9ccf0cd
Add new tests
jteijema Oct 30, 2024
71c6bdd
Update algorithm
jteijema Oct 30, 2024
66c3445
Remove leftover debugging message
jteijema Oct 30, 2024
a6527bf
Remove decimal args from tests
jteijema Oct 30, 2024
d4c5cbd
Add new value error for invalid set
jteijema Oct 31, 2024
868a3bf
Refactor loss tests
jteijema Oct 31, 2024
b4f17c4
Change api usage with cumsum
jteijema Oct 31, 2024
07653f3
Merge branch 'loss-metric' of https://github.com/jteijema/asreview-in…
jteijema Oct 31, 2024
ceaf9b3
Simplify denominator
jteijema Oct 31, 2024
ebdfd83
Return formula instead of deriving in code.
jteijema Oct 31, 2024
0e3116f
Simplify formula and update the comments
jteijema Oct 31, 2024
b042b5b
update metrics loss docstring
jteijema Oct 31, 2024
f75c499
change best to optimal
jteijema Oct 31, 2024
74a1e84
Update tests file
jteijema Oct 31, 2024
1621f8c
Refactor the docstring for loss function
jteijema Oct 31, 2024
85dfd97
Value error update
jteijema Oct 31, 2024
867c6bf
Add instance type check
jteijema Oct 31, 2024
8a3d4ef
Merge branch 'loss-metric' of https://github.com/jteijema/asreview-in…
jteijema Oct 31, 2024
0884240
Linter
jteijema Oct 31, 2024
da331ad
Add loss to readme
jteijema Oct 31, 2024
5dc96f9
Add loss to output metrics
jteijema Oct 31, 2024
017afba
Add to inline explanation
jteijema Oct 31, 2024
511503a
Update formulas in readme
jteijema Oct 31, 2024
cc9c569
Update explanation in readme
jteijema Oct 31, 2024
08b70cc
Change readme explanation
jteijema Oct 31, 2024
42 changes: 42 additions & 0 deletions README.md
@@ -133,7 +133,44 @@ pinpoint hard-to-find papers. The ATD, on the other hand, measures performance
throughout the entire screening process, eliminating reliance on arbitrary
cut-off values, and can be used to compare different models.

### Loss
The Loss metric evaluates the performance of an active learning model by
quantifying how closely it approximates the ideal screening process. The
result is normalized between the ideal recall curve and the worst possible
recall curve.

While metrics like WSS, Recall, and ERF evaluate the performance at specific
points on the recall curve, the Loss metric provides an overall measure of
performance.

To compute the loss, we start with three key concepts:

1. **Optimal AUC**: This is the area under a "perfect recall curve," where
relevant records are identified as early as possible. Mathematically, it is
computed as $Nx \times Ny - \frac{Ny \times (Ny - 1)}{2}$, where $Nx$ is the
total number of records, and $Ny$ is the number of relevant records.

2. **Worst AUC**: This represents the area under a worst-case recall curve,
where all relevant records appear at the end of the screening process. This
is calculated as $\frac{Ny \times (Ny + 1)}{2}$.

3. **Actual AUC**: This is the area under the recall curve produced by the model
during the screening process. It can be obtained by summing up the cumulative
recall values for the labeled records.

The normalized loss is calculated by taking the difference between the optimal
AUC and the actual AUC, divided by the difference between the optimal AUC and
the worst AUC.

$$\text{Normalized Loss} = \frac{Ny \times \left(Nx - \frac{Ny - 1}{2}\right) -
\sum \text{Cumulative Recall}}{Ny \times (Nx - Ny)}$$
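
The denominator follows from subtracting the worst AUC from the optimal AUC:

$$\left(Nx \times Ny - \frac{Ny \times (Ny - 1)}{2}\right) - \frac{Ny \times (Ny + 1)}{2} = Ny \times (Nx - Ny)$$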

The lower the loss, the closer the model is to the perfect recall curve,
indicating higher performance.
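
As an illustrative sketch (not part of this pull request), the formula can be
evaluated directly with NumPy; the helper name `normalized_loss` is
hypothetical, while the package's own implementation lives in
`asreviewcontrib.insights.algorithms`:

```python
import numpy as np

def normalized_loss(labels):
    """Hypothetical helper mirroring the README formula."""
    Ny, Nx = sum(labels), len(labels)
    optimal = Ny * (Nx - (Ny - 1) / 2)  # area under the perfect recall curve
    actual = np.cumsum(labels).sum()  # area under the observed recall curve
    return (optimal - actual) / (Ny * (Nx - Ny))  # denominator: optimal - worst

print(normalized_loss([1, 0, 1]))  # 0.5: the second relevant record is found last
```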

![Recall plot illustrating loss metric](https://github.com/jteijema/asreview-insights/blob/loss-metric/figures/loss_metric_example.png?raw=true)

In this figure, the green area between the recall curve and the perfect recall line is the lost performance, which is normalized by the total area (green and red combined).

## Basic usage

@@ -467,6 +504,11 @@ which results in
            ]
        ]
    },
    {
        "id": "loss",
        "title": "Loss",
        "value": 0.01707543880041846
    },
    {
        "id": "erf",
        "title": "Extra Relevant record Found",
30 changes: 30 additions & 0 deletions asreviewcontrib/insights/algorithms.py
@@ -19,6 +19,36 @@ def _recall_values(labels, x_absolute=False, y_absolute=False):
    return x.tolist(), y.tolist()


def _loss_value(labels):
    Ny = sum(labels)
    Nx = len(labels)

    if Ny == 0 or Nx == Ny:
        raise ValueError("Need both 0 and 1 labels")

    # The normalized loss is computed based on:
    #
    # 1. The "optimal" AUC, representing the area under an optimal recall
    #    curve, is the total area, Nx * Ny, minus the area above the stepwise
    #    curve, (Ny * (Ny - 1)) / 2. Combined: Ny * (Nx - (Ny - 1) / 2).
    #
    # 2. The "actual" AUC is the cumulative recall sum, calculated with
    #    np.cumsum(labels).sum().
    #
    # 3. The "worst" AUC, where all positive labels are clustered at the end,
    #    is calculated as (Ny * (Ny + 1)) / 2. To normalize, we need the
    #    difference between the optimal and worst AUCs:
    #
    #        (Nx * Ny - ((Ny * (Ny - 1)) / 2)) - ((Ny * (Ny + 1)) / 2)
    #
    #    which simplifies to Ny * (Nx - Ny), the denominator of the
    #    normalized loss.
    #
    # Finally, we compute the normalized loss as:
    # (optimal - actual) / (optimal - worst).
    return float((Ny * (Nx - (Ny - 1) / 2) - np.cumsum(labels).sum()) / (Ny * (Nx - Ny)))  # noqa: E501
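
# A quick worked check of the formula above (illustrative, not part of the
# module): for labels = [1, 0, 1], Ny = 2 and Nx = 3. The optimal AUC is
# 2 * (3 - 1/2) = 5, the actual AUC is np.cumsum([1, 0, 1]).sum() = 1 + 1 + 2
# = 4, and the denominator is 2 * (3 - 2) = 2, so the loss is (5 - 4) / 2 = 0.5.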


def _wss_values(labels, x_absolute=False, y_absolute=False):
    n_docs = len(labels)
    n_pos_docs = sum(labels)
20 changes: 20 additions & 0 deletions asreviewcontrib/insights/metrics.py
@@ -6,6 +6,7 @@
from asreviewcontrib.insights.algorithms import _erf_values
from asreviewcontrib.insights.algorithms import _fn_values
from asreviewcontrib.insights.algorithms import _fp_values
from asreviewcontrib.insights.algorithms import _loss_value
from asreviewcontrib.insights.algorithms import _recall_values
from asreviewcontrib.insights.algorithms import _tn_values
from asreviewcontrib.insights.algorithms import _tp_values
@@ -169,6 +170,20 @@ def _tnr(labels, intercept, x_absolute=False):

    return _slice_metric(x, y, intercept)


def loss(state_obj, priors=False):
    """Compute the loss for an active learning problem.

    Computes the loss for an active learning problem in which all relevant
    records have to be seen by a human.

    See the inline documentation for a detailed description of the loss
    calculation.

    Returns:
        float: The loss value.
    """
    labels = _pad_simulation_labels(state_obj, priors=priors)

    return _loss_value(labels)
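
# Example usage (an illustrative sketch; the file path is hypothetical):
#
#     from asreview import open_state
#     from asreviewcontrib.insights.metrics import loss
#
#     with open_state("simulation.asreview") as state:
#         print(loss(state))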

def get_metrics(
    state_obj,
@@ -225,6 +240,11 @@ def get_metrics(
"title": "Work Saved over Sampling",
"value": [(i, v) for i, v in zip(wss, wss_values)],
},
{
"id": "loss",
"title": "Loss",
"value": _loss_value(labels),
},
{
"id": "erf",
"title": "Extra Relevant record Found",
Binary file added figures/loss_metric_example.png
51 changes: 51 additions & 0 deletions tests/test_metrics.py
@@ -1,17 +1,24 @@
from pathlib import Path

import numpy as np
from asreview import open_state
from numpy import array_equal
from numpy.testing import assert_almost_equal
from numpy.testing import assert_raises

from asreviewcontrib.insights.algorithms import _loss_value
from asreviewcontrib.insights.metrics import _erf
from asreviewcontrib.insights.metrics import _recall
from asreviewcontrib.insights.metrics import _time_to_discovery
from asreviewcontrib.insights.metrics import _wss
from asreviewcontrib.insights.metrics import get_metrics
from asreviewcontrib.insights.metrics import loss
from asreviewcontrib.insights.metrics import recall

TEST_ASREVIEW_FILES = Path(Path(__file__).parent, "asreview_files")



def test_metric_recall_small_data():
    labels = [1, 1, 1, 0]
    r = _recall(labels, 0.5)
@@ -111,3 +118,47 @@ def test_label_padding():
        stop_if_full = get_metrics(s)

    assert stop_if_min == stop_if_full


def test_loss():
    with open_state(
        Path(TEST_ASREVIEW_FILES, "sim_van_de_schoot_2017_stop_if_min.asreview")
    ) as s:
        loss_value = loss(s)
        assert_almost_equal(loss_value, 0.011592855205548452)


def test_loss_value_function(seed=None):
    test_cases = [
        ([1, 0], 0),
        ([0, 1], 1),
        ([1, 1, 0, 0, 0], 0),
        ([0, 0, 0, 1, 1], 1),
        ([1, 0, 1], 0.5),
    ]

    for labels, expected_value in test_cases:
        loss_value = _loss_value(labels)
        assert_almost_equal(loss_value, expected_value)

    error_cases = [[0, 0, 0], [0], [1]]
    for labels in error_cases:
        with assert_raises(ValueError):
            _loss_value(labels)

    if seed is not None:
        np.random.seed(seed)

    for _ in range(100):
        length = np.random.randint(2, 100)
        labels = np.random.randint(0, 2, length)

        # Ensure labels are not all 0 or all 1
        if np.all(labels == 0) or np.all(labels == 1):
            labels[np.random.randint(0, length)] = 1 - labels[0]

        loss_value = _loss_value(labels)
        assert 0 <= loss_value <= 1


def test_single_value_formats():
    assert isinstance(_wss([1, 1, 0, 0], 0.5), float)
    assert isinstance(_loss_value([1, 1, 0, 0]), float)
    assert isinstance(_erf([1, 1, 0, 0], 0.5), float)