Updated average precision calculation (#66)
jorisvandenbossche authored Oct 25, 2017
1 parent dc6ec6c commit 96e366f
Showing 5 changed files with 162 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -22,7 +22,7 @@ install:
- pip install tensorflow
- pip install .
script:
- flake8 rampwf --ignore=F401,E211,E265
- flake8 rampwf --ignore=F401,E211,E265,W503
- pytest -s -v --cov=rampwf rampwf
after_success:
- codecov
4 changes: 2 additions & 2 deletions rampwf/score_types/__init__.py
@@ -15,11 +15,10 @@
from .roc_auc import ROCAUC
from .detection import (
    OSPA, SCP, DetectionPrecision, DetectionRecall, MADCenter, MADRadius,
    AverageDetectionPrecision)
    AverageDetectionPrecision, DetectionAveragePrecision)

__all__ = [
    'Accuracy',
    'AverageDetectionPrecision',
    'BalancedAccuracy',
    'BrierScore',
    'BrierScoreReliability',
@@ -30,6 +29,7 @@
    'Combined',
    'DetectionPrecision',
    'DetectionRecall',
    'DetectionAveragePrecision',
    'F1Above',
    'MacroAveragedRecall',
    'MakeCombined',
3 changes: 2 additions & 1 deletion rampwf/score_types/detection/__init__.py
@@ -2,4 +2,5 @@
from .ospa import OSPA
from .precision_recall import (
    DetectionPrecision, DetectionRecall, MADCenter, MADRadius)
from .average_precision import AverageDetectionPrecision
from .average_precision import (
    AverageDetectionPrecision, DetectionAveragePrecision)
131 changes: 121 additions & 10 deletions rampwf/score_types/detection/average_precision.py
@@ -4,10 +4,11 @@

from .base import BaseScoreType
from .util import _filter_y_pred
from .iou import cc_iou
from .precision_recall import precision, recall


class AverageDetectionPrecision(BaseScoreType):
class DetectionAveragePrecision(BaseScoreType):
    is_lower_the_better = False
    minimum = 0.0
    maximum = 1.0
@@ -23,14 +24,12 @@ def __init__(self, name=_name, precision=3, iou_threshold=0.5):
        self.iou_threshold = iou_threshold

    def __call__(self, y_true, y_pred):
        y_pred_conf = [detected_object[0]
                       for single_detection in y_pred
                       for detected_object in single_detection]
        min_conf, max_conf = np.min(y_pred_conf), np.max(y_pred_conf)
        conf_thresholds = np.linspace(min_conf, max_conf, 20)
        ps, rs = precision_recall_curve(y_true, y_pred, conf_thresholds,
                                        iou_threshold=self.iou_threshold)
        return average_precision_interpolated(ps, rs)
        _, ps, rs = precision_recall_curve_greedy(
            y_true, y_pred, iou_threshold=self.iou_threshold)
        return average_precision_exact(ps, rs)


AverageDetectionPrecision = DetectionAveragePrecision
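
The class was renamed; the module keeps the old name as an alias so existing
imports still work. A minimal usage sketch with values mirroring the tests
below (each inner list holds one image's objects; ground truths are
(x, y, r) tuples, predictions are (confidence, x, y, r) tuples):

ap = DetectionAveragePrecision()
y_true = [[(1, 1, 1), (3, 3, 1)]]        # two true circles in one image
y_pred = [[(1, 1, 1, 1), (1, 3, 3, 1)]]  # both found with confidence 1
ap(y_true, y_pred)  # perfect match -> 1.0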


def precision_recall_curve(y_true, y_pred, conf_thresholds, iou_threshold=0.5):
@@ -67,6 +66,96 @@ def precision_recall_curve(y_true, y_pred, conf_thresholds, iou_threshold=0.5):
    return np.array(ps), np.array(rs)


def _add_id(y):
    """
    Flatten a list of per-image lists into a single array, prepending an
    image-id column. Because the lists are flattened into one array, empty
    lists leave no entry in the final array.
    """
    y_new = []

    for i, y_patch in enumerate(y):
        if len(y_patch):
            tmp = np.asarray(y_patch)
            tmp = np.insert(tmp, 0, i, axis=1)
            y_new.append(tmp)

    return np.vstack(y_new)
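
On illustrative values (two images, the second with two (conf, x, y, r)
predictions), _add_id gives:

_add_id([[(0.9, 1, 1, 1)],
         [(0.8, 3, 3, 1), (0.4, 5, 5, 1)]])
# -> array([[0. , 0.9, 1. , 1. , 1. ],
#           [1. , 0.8, 3. , 3. , 1. ],
#           [1. , 0.4, 5. , 5. , 1. ]])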


def precision_recall_curve_greedy(y_true, y_pred, iou_threshold=0.5):
    """
    Calculate precision and recall incrementally, following the predictions
    sorted by decreasing confidence (a greedy match, rather than an exact
    optimal match for each confidence level).

    Parameters
    ----------
    y_true : list of lists of (x, y, r) tuples
    y_pred : list of lists of (conf, x, y, r) tuples
    iou_threshold : float [0 - 1], default 0.5

    Returns
    -------
    Three arrays:

    confidence_values
        The flattened and sorted confidence values of y_pred
    precision
        The corresponding precision values
    recall
        The corresponding recall values
    """
    # flatten y_pred into a single array and add a column with the img id
    y_pred2 = _add_id(y_pred)

    # sort the predicted objects by decreasing confidence
    y_pred2_sorted = y_pred2[np.argsort(y_pred2[:, 1])[::-1], :]

    # array to store whether a match is observed or not for each prediction
    res = np.empty(len(y_pred2_sorted), dtype='bool')

    # object to keep track of matches: (img id, object index) pairs
    matched = set([])
    confs = []

    for i, pred in enumerate(y_pred2_sorted):
        patch_id, conf, row, col, rad = pred

        y_patch = y_true[int(patch_id)]
        n_true = len(y_patch)

        if n_true > 0:
            ious = np.empty(n_true)

            for j in range(n_true):
                ious[j] = cc_iou(y_patch[j], (row, col, rad))

            i_max = np.argmax(ious)
            if ((ious[i_max] > iou_threshold)
                    and (patch_id, i_max) not in matched):
                res[i] = True
                # add the match identifier to the set so we can later
                # check that we don't count a duplicate match
                matched.add((patch_id, i_max))
            else:
                res[i] = False
        else:
            res[i] = False

        confs.append(conf)

    n_true_total = np.sum([len(x) for x in y_true])

    recall = np.cumsum(res) / n_true_total
    precision = np.cumsum(res) / np.arange(1, len(res) + 1)

    return np.array(confs), precision, recall
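
One property of this greedy pass, on illustrative numbers: a true object is
matched at most once, so a second, lower-confidence prediction of an
already-matched object counts as a false positive:

conf, ps, rs = precision_recall_curve_greedy(
    [[(1, 1, 1)]],                       # one true object
    [[(0.9, 1, 1, 1), (0.5, 1, 1, 1)]])  # predicted twice
# conf -> [0.9, 0.5]; ps -> [1.0, 0.5]; rs -> [1.0, 1.0]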


def average_precision_interpolated(ps, rs):
    """
    The Average Precision (AP) score.
@@ -78,7 +167,7 @@ def average_precision_interpolated(ps, rs):
    TODO: they changed this in later versions:
    http://homepages.inf.ed.ac.uk/ckiw/postscript/ijcv_voc09.pdf
    https://stackoverflow.com/questions/36274638/map-metric-in-object-detection-and-computer-vision
    https://stackoverflow.com/questions/36274638/map-metric-in-object-detection-and-computer-vision # noqa

    Parameters
    ----------
@@ -104,3 +193,25 @@

    ap = np.mean(p_at_r)
    return ap
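
The middle of this function is collapsed in the diff; assuming it follows
the 11-point PASCAL VOC recipe described in the docstring (only p_at_r and
the final mean are confirmed by the visible lines), the collapsed part
computes something like this sketch:

ps, rs = np.array([1.0, 0.5]), np.array([0.5, 0.5])
p_at_r = [np.max(ps[rs >= r]) if np.any(rs >= r) else 0.0
          for r in np.linspace(0, 1, 11)]  # recall levels 0.0, 0.1, ..., 1.0
np.mean(p_at_r)  # six reachable levels at precision 1.0 -> 6 / 11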


def average_precision_exact(ps, rs):
    # from https://github.com/amdegroot/ssd.pytorch/blob/ce4c994db0ee11f82aabb4fdb3499dc970156db5/eval.py#L182-L213 # noqa

    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rs, [1.]))
    mpre = np.concatenate(([0.], ps, [0.]))

    # compute the precision envelope
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])

    return ap
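
A worked check of the envelope-and-step computation, using the precision and
recall arrays from the "bigger example" in the tests below:

average_precision_exact(np.array([1, 1, 2/3, 3/4, 3/5, 3/6]),
                        np.array([1/4, 2/4, 2/4, 3/4, 3/4, 3/4]))
# recall steps at 0.25, 0.5, 0.75 and 1.0 contribute
# 0.25 * 1 + 0.25 * 1 + 0.25 * 3/4 + 0.25 * 0 = 11/16 = 0.6875
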
60 changes: 36 additions & 24 deletions rampwf/score_types/tests/test_detection.py
@@ -6,14 +6,16 @@

import pytest

# from rampwf.score_types import AverageDetectionPrecision
from rampwf.score_types.detection.ospa import ospa, ospa_single
from rampwf.score_types.detection.scp import scp_single
from rampwf.score_types.detection.ospa import ospa, ospa_single
from rampwf.score_types import DetectionAveragePrecision
from rampwf.score_types.detection.precision_recall import precision, recall
from rampwf.score_types.detection.precision_recall import mad_center
from rampwf.score_types.detection.precision_recall import mad_radius
from rampwf.score_types.detection.iou import cc_iou, cc_intersection
from rampwf.score_types.detection.scp import project_circle, circle_maps
from rampwf.score_types.detection.average_precision import (
    precision_recall_curve_greedy)


x = [(1, 1, 1)]
@@ -138,28 +140,38 @@ def test_precision_recall():
    assert math.isnan(mad_center(y_true, y_pred))


# def test_average_precision():
# ap = AverageDetectionPrecision()
# # perfect match
# y_true = [[(1, 1, 1), (3, 3, 1)]]
# y_pred = [[(1, 1, 1, 1), (1, 3, 3, 1)]]
# assert ap(y_true, y_pred) == 1

# # imperfect match
# y_true = [[(1, 1, 1), (3, 3, 1), (7, 7, 1), (9, 9, 1)]]
# y_pred = [[(1, 1, 1, 1), (1, 5, 5, 1)]]
# assert ap(y_true, y_pred) == pytest.approx(3. / 2 / 11, rel=1e-5)
# # would be 0.125 (1 / 8) exact method

# y_true = [[(1, 1, 1), (3, 3, 1), (7, 7, 1), (9, 9, 1)]]
# y_pred = [[(1, 1, 1.2, 1.2), (1, 3, 3, 1)]]
# assert ap(y_true, y_pred) == pytest.approx(6. / 11, rel=1e-5)
# # would be 0.5 with exact method

# # no match
# y_true = [[(1, 1, 1)]]
# y_pred = [[(1, 3, 3, 1)]]
# assert ap(y_true, y_pred) == 0
def test_average_precision():
    ap = DetectionAveragePrecision()

    # perfect match
    y_true = [[(1, 1, 1), (3, 3, 1)]]
    y_pred = [[(1, 1, 1, 1), (1, 3, 3, 1)]]
    assert ap(y_true, y_pred) == 1

    # imperfect match
    y_true = [[(1, 1, 1), (3, 3, 1), (7, 7, 1), (9, 9, 1)]]
    y_pred = [[(1, 1, 1, 1), (1, 5, 5, 1)]]
    assert ap(y_true, y_pred) == 0.125

    y_true = [[(1, 1, 1), (3, 3, 1), (7, 7, 1), (9, 9, 1)]]
    y_pred = [[(1, 1, 1.2, 1.2), (1, 3, 3, 1)]]
    assert ap(y_true, y_pred) == 0.5

    # no match
    y_true = [[(1, 1, 1)]]
    y_pred = [[(1, 3, 3, 1)]]
    assert ap(y_true, y_pred) == 0

    # bigger example
    y_true = [[(1, 1, 1), (3, 3, 1)], [(1, 1, 1), (3, 3, 1)]]
    y_pred = [[(0.9, 1, 1, 1), (0.7, 5, 5, 1), (0.5, 8, 8, 1)],
              [(0.8, 1, 1, 1), (0.6, 3, 3, 1), (0.4, 5, 5, 1)]]

    conf, ps, rs = precision_recall_curve_greedy(y_true, y_pred)
    assert conf.tolist() == [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
    assert ps.tolist() == [1, 1, 2/3, 3/4, 3/5, 3/6]  # noqa
    assert rs.tolist() == [1/4, 2/4, 2/4, 3/4, 3/4, 3/4]  # noqa
    assert ap(y_true, y_pred) == 11 / 16  # 0.5 * 1 + 0.25 * 3/4 + 0.25 * 0


# # test circles