Skip to content

Commit

Permalink
add/update bootstrapped metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Janek Ebbers committed Jul 9, 2024
1 parent afcf533 commit 3a81c0f
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 37 deletions.
16 changes: 8 additions & 8 deletions sed_scores_eval/base_modules/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from collections import defaultdict
from sed_scores_eval.base_modules.io import parse_ground_truth, parse_audio_durations
from sed_scores_eval.utils import parallel


Expand Down Expand Up @@ -38,10 +38,10 @@ def bootstrap_from_deltas(
data_orig = {'deltas': deltas}
data_samples = {'deltas': []}
if ground_truth is not None:
data_orig['ground_truth'] = ground_truth
data_orig['ground_truth'] = parse_ground_truth(ground_truth)
data_samples['ground_truth'] = []
if audio_durations is not None:
data_orig['audio_durations'] = audio_durations
data_orig['audio_durations'] = parse_audio_durations(audio_durations)
data_samples['audio_durations'] = []

audio_ids_repeated = n_bootstrap_samples * audio_ids
Expand Down Expand Up @@ -87,11 +87,11 @@ def confidence_interval(bootstrapped_outputs, confidence=.9, axis=None):
}
return mean_low_high

mean = np.mean(bootstrapped_outputs, axis=axis)
low = np.percentile(
mean = float(np.mean(bootstrapped_outputs, axis=axis))
low = float(np.percentile(
bootstrapped_outputs, ((1 - confidence) / 2) * 100, axis=axis,
)
high = np.percentile(
))
high = float(np.percentile(
bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100, axis=axis,
)
))
return mean, low, high
4 changes: 2 additions & 2 deletions sed_scores_eval/base_modules/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ def validate_score_dataframe(scores, timestamps=None, event_classes=None):
)
onset_times = scores['onset'].to_numpy()
offset_times = scores['offset'].to_numpy()
timestamps_from_df = np.concatenate((onset_times, offset_times[-1:]))
if (offset_times == onset_times).any():
raise ValueError('Some frames have zero length.')
raise ValueError(f'Some frames have zero length: {timestamps_from_df}')
if not (offset_times[:-1] == onset_times[1:]).all():
raise ValueError(
f'onset times must match offset times of the previous frame.'
)
timestamps_from_df = np.concatenate((onset_times, offset_times[-1:]))
if timestamps is not None and not np.allclose(timestamps_from_df, timestamps):
raise ValueError(
f'timestamps from file {timestamps_from_df} do not match provided timestamps {timestamps}.'
Expand Down
2 changes: 1 addition & 1 deletion sed_scores_eval/collar_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas
from .precision_recall import precision_recall_curve, fscore_curve, fscore, bootstrapped_fscore, best_fscore
from .precision_recall import precision_recall_curve, fscore_curve, fscore, bootstrapped_fscore, best_fscore, bootstrapped_best_fscore
from .error_rate import error_rate_curve, error_rate, best_error_rate
from . import reference
24 changes: 8 additions & 16 deletions sed_scores_eval/collar_based/precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
from sed_scores_eval.base_modules.io import parse_inputs
from sed_scores_eval.base_modules.bootstrap import bootstrap
from sed_scores_eval.base_modules.precision_recall import (
Expand Down Expand Up @@ -226,7 +225,8 @@ def bootstrapped_fscore(
Returns:
"""
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
if scores is not None:
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
return bootstrap(
fscore, scores=scores, deltas=deltas,
deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
Expand Down Expand Up @@ -314,11 +314,10 @@ def best_fscore(
)


def bootstrapped_fscore_curve(
def bootstrapped_best_fscore(
scores, ground_truth, *, deltas=None,
onset_collar, offset_collar, offset_collar_rate=0.,
beta=1., time_decimals=6,
n_bootstrap_samples=100, num_jobs=1,
onset_collar, offset_collar, offset_collar_rate=0., beta=1.,
time_decimals=6, n_bootstrap_samples=100, num_jobs=1,
):
"""
Expand All @@ -337,9 +336,10 @@ def bootstrapped_fscore_curve(
Returns:
"""
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
if scores is not None:
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
return bootstrap(
fscore_curve, scores=scores, deltas=deltas,
best_fscore, scores=scores, deltas=deltas,
deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
deltas_fn_kwargs=dict(
ground_truth=ground_truth,
Expand All @@ -352,11 +352,3 @@ def bootstrapped_fscore_curve(
),
n_bootstrap_samples=n_bootstrap_samples,
)


def _recursive_get_item(stats, idx):
if isinstance(stats, dict):
return {key: _recursive_get_item(stats[key], idx) for key in stats}
if np.isscalar(stats):
return stats
return stats[idx]
2 changes: 1 addition & 1 deletion sed_scores_eval/intersection_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
bootstrapped_postprocessing_independent_psds_from_postprocessed_scores, deltas_postprocessing,
median_filter_independent_psds, bootstrapped_median_filter_independent_psds,
)
from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore
from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore, bootstrapped_best_fscore
from .error_rate import error_rate_curve, error_rate, best_error_rate
from . import reference
41 changes: 40 additions & 1 deletion sed_scores_eval/intersection_based/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def bootstrapped_fscore(
Returns:
"""
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
if scores is not None:
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
return bootstrap(
fscore, scores=scores, deltas=deltas,
deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
Expand Down Expand Up @@ -274,3 +275,41 @@ def best_fscore(
intermediate_stats, beta=beta,
min_precision=min_precision, min_recall=min_recall,
)


def bootstrapped_best_fscore(
        scores, ground_truth, *, deltas=None,
        dtc_threshold, gtc_threshold, beta=1., time_decimals=6,
        n_bootstrap_samples=100, num_jobs=1,
):
    """Bootstrap estimate of the intersection-based best F-score.

    Delegates to :func:`bootstrap`, which re-samples the audio files and
    evaluates :func:`best_fscore` on each bootstrap sample.

    Args:
        scores: SED scores input; may be None when ``deltas`` are given directly.
        ground_truth: ground truth events input.
        deltas: pre-computed intermediate-statistics deltas (optional).
        dtc_threshold: detection tolerance criterion threshold.
        gtc_threshold: ground truth intersection criterion threshold.
        beta: F-score beta parameter.
        time_decimals: decimals for rounding event times.
        n_bootstrap_samples: number of bootstrap samples to draw.
        num_jobs: number of parallel jobs.

    Returns:
        bootstrapped best F-score outputs as returned by :func:`bootstrap`.
    """
    if scores is not None:
        # Normalize the scores/ground-truth inputs; the audio-id list is unused here.
        scores, ground_truth, _ = parse_inputs(scores, ground_truth)
    deltas_kwargs = dict(
        ground_truth=ground_truth,
        dtc_threshold=dtc_threshold,
        gtc_threshold=gtc_threshold,
        time_decimals=time_decimals,
    )
    return bootstrap(
        best_fscore, scores=scores, deltas=deltas,
        deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
        deltas_fn_kwargs=deltas_kwargs,
        eval_fn_kwargs={'beta': beta},
        n_bootstrap_samples=n_bootstrap_samples,
    )
8 changes: 4 additions & 4 deletions sed_scores_eval/intersection_based/psds.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
from scipy.interpolate import interp1d
from sed_scores_eval.base_modules.io import parse_inputs
from sed_scores_eval.base_modules.io import parse_inputs, parse_audio_durations, parse_ground_truth
from sed_scores_eval.utils.array_ops import cummax, get_first_index_where
from sed_scores_eval.base_modules.curves import xsort, staircase_auc
from sed_scores_eval.base_modules.bootstrap import bootstrap
from sed_scores_eval.base_modules.io import parse_audio_durations
from sed_scores_eval.intersection_based.intermediate_statistics import intermediate_statistics_deltas, accumulated_intermediate_statistics

seconds_per_unit_of_time = {
Expand Down Expand Up @@ -243,7 +242,8 @@ def bootstrapped_psds(
Returns:
"""
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
if scores is not None:
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
return bootstrap(
psds, scores=scores, deltas=deltas,
deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
Expand All @@ -253,7 +253,7 @@ def bootstrapped_psds(
cttc_threshold=cttc_threshold, time_decimals=time_decimals,
),
eval_fn_kwargs=dict(
audio_durations=audio_durations,alpha_ct=alpha_ct,
audio_durations=audio_durations, alpha_ct=alpha_ct,
alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr,
non_oracle=non_oracle,
),
Expand Down
6 changes: 3 additions & 3 deletions sed_scores_eval/segment_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .intermediate_statistics import accumulated_intermediate_statistics
from .intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas
from .precision_recall import (
precision_recall_curve,
fscore_curve, fscore, best_fscore,
fscore_curve, fscore, bootstrapped_fscore, best_fscore, bootstrapped_best_fscore,
average_precision,
)
from .error_rate import error_rate_curve, error_rate, best_error_rate
from .roc import roc_curve, auroc
from .roc import roc_curve, auroc, bootstrapped_auroc
from . import reference
48 changes: 47 additions & 1 deletion sed_scores_eval/segment_based/precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from sed_scores_eval.base_modules.io import parse_inputs
from sed_scores_eval.base_modules.precision_recall import (
precision_recall_curve_from_intermediate_statistics,
fscore_curve_from_intermediate_statistics,
single_fscore_from_intermediate_statistics,
best_fscore_from_intermediate_statistics,
average_precision_from_intermediate_statistics,
)
from sed_scores_eval.segment_based.intermediate_statistics import accumulated_intermediate_statistics
from sed_scores_eval.segment_based.intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas
from sed_scores_eval.base_modules.bootstrap import bootstrap


def precision_recall_curve(
Expand Down Expand Up @@ -170,6 +172,28 @@ def fscore(
)


def bootstrapped_fscore(
        scores, ground_truth, audio_durations, threshold, *, deltas=None,
        segment_length=1., beta=1., time_decimals=6, n_bootstrap_samples=100,
        num_jobs=1,
):
    """Bootstrap estimate of the segment-based F-score at a fixed threshold.

    Delegates to :func:`bootstrap`, which re-samples the audio files and
    evaluates :func:`fscore` on each bootstrap sample.

    Args:
        scores: SED scores input; may be None when ``deltas`` are given directly.
        ground_truth: ground truth events input.
        audio_durations: audio durations input.
        threshold: decision threshold applied to the scores.
        deltas: pre-computed intermediate-statistics deltas (optional).
        segment_length: segment length in seconds.
        beta: F-score beta parameter.
        time_decimals: decimals for rounding event times.
        n_bootstrap_samples: number of bootstrap samples to draw.
        num_jobs: number of parallel jobs.

    Returns:
        bootstrapped F-score outputs as returned by :func:`bootstrap`.
    """
    if scores is not None:
        # Normalize the scores/ground-truth inputs; the audio-id list is unused here.
        scores, ground_truth, _ = parse_inputs(scores, ground_truth)
    deltas_kwargs = dict(
        ground_truth=ground_truth,
        audio_durations=audio_durations,
        segment_length=segment_length,
        time_decimals=time_decimals,
    )
    return bootstrap(
        fscore, scores=scores, deltas=deltas,
        deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
        deltas_fn_kwargs=deltas_kwargs,
        eval_fn_kwargs={'threshold': threshold, 'beta': beta},
        n_bootstrap_samples=n_bootstrap_samples,
    )


def best_fscore(
scores, ground_truth, audio_durations, *, deltas=None,
segment_length=1., min_precision=0., min_recall=0., beta=1.,
Expand Down Expand Up @@ -234,6 +258,28 @@ def best_fscore(
)


def bootstrapped_best_fscore(
        scores, ground_truth, audio_durations, *, deltas=None,
        segment_length=1., min_precision=0., min_recall=0., beta=1.,
        time_decimals=6, n_bootstrap_samples=100, num_jobs=1,
):
    """Bootstrap estimate of the segment-based best F-score.

    Delegates to :func:`bootstrap`, which re-samples the audio files and
    evaluates :func:`best_fscore` on each bootstrap sample.

    Args:
        scores: SED scores input; may be None when ``deltas`` are given directly.
        ground_truth: ground truth events input.
        audio_durations: audio durations input.
        deltas: pre-computed intermediate-statistics deltas (optional).
        segment_length: segment length in seconds.
        min_precision: minimum precision constraint for the best F-score.
        min_recall: minimum recall constraint for the best F-score.
        beta: F-score beta parameter.
        time_decimals: decimals for rounding event times.
        n_bootstrap_samples: number of bootstrap samples to draw.
        num_jobs: number of parallel jobs.

    Returns:
        bootstrapped best F-score outputs as returned by :func:`bootstrap`.
    """
    if scores is not None:
        # Normalize the scores/ground-truth inputs; the audio-id list is unused here.
        scores, ground_truth, _ = parse_inputs(scores, ground_truth)
    deltas_kwargs = dict(
        ground_truth=ground_truth,
        audio_durations=audio_durations,
        segment_length=segment_length,
        time_decimals=time_decimals,
    )
    eval_kwargs = dict(
        beta=beta,
        min_precision=min_precision,
        min_recall=min_recall,
    )
    return bootstrap(
        best_fscore, scores=scores, deltas=deltas,
        deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
        deltas_fn_kwargs=deltas_kwargs,
        eval_fn_kwargs=eval_kwargs,
        n_bootstrap_samples=n_bootstrap_samples,
    )


def average_precision(
scores, ground_truth, audio_durations, *, deltas=None,
segment_length, time_decimals=6, num_jobs=1,
Expand Down
23 changes: 23 additions & 0 deletions sed_scores_eval/segment_based/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,26 @@ def auroc(
return auroc_from_intermediate_statistics(
intermediate_stats, max_fpr=max_fpr, mcclish_correction=mcclish_correction
)


def bootstrapped_auroc(
        scores, ground_truth, audio_durations, *, deltas=None,
        segment_length, max_fpr=None, time_decimals=6,
        n_bootstrap_samples=100, num_jobs=1,
):
    """Bootstrap estimate of the segment-based AUROC.

    Delegates to :func:`bootstrap`, which re-samples the audio files and
    evaluates :func:`auroc` on each bootstrap sample.

    Args:
        scores: SED scores input; may be None when ``deltas`` are given directly.
        ground_truth: ground truth events input.
        audio_durations: audio durations input.
        deltas: pre-computed intermediate-statistics deltas (optional).
        segment_length: segment length in seconds.
        max_fpr: optional maximum false-positive rate for partial AUROC.
        time_decimals: decimals for rounding event times.
        n_bootstrap_samples: number of bootstrap samples to draw.
        num_jobs: number of parallel jobs.

    Returns:
        bootstrapped AUROC outputs as returned by :func:`bootstrap`.
    """
    if scores is not None:
        # Normalize the scores/ground-truth inputs; the audio-id list is unused here.
        scores, ground_truth, _ = parse_inputs(scores, ground_truth)
    deltas_kwargs = dict(
        ground_truth=ground_truth,
        audio_durations=audio_durations,
        segment_length=segment_length,
        time_decimals=time_decimals,
    )
    return bootstrap(
        auroc, scores=scores, deltas=deltas,
        deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
        deltas_fn_kwargs=deltas_kwargs,
        eval_fn_kwargs={'max_fpr': max_fpr},
        n_bootstrap_samples=n_bootstrap_samples,
    )

0 comments on commit 3a81c0f

Please sign in to comment.