From 3a81c0f075db33f44360820953c63a508cd9f11b Mon Sep 17 00:00:00 2001 From: Janek Ebbers Date: Tue, 9 Jul 2024 10:28:13 -0400 Subject: [PATCH] add/update bootstrapped metrics --- sed_scores_eval/base_modules/bootstrap.py | 16 +++---- sed_scores_eval/base_modules/scores.py | 4 +- sed_scores_eval/collar_based/__init__.py | 2 +- .../collar_based/precision_recall.py | 24 ++++------ .../intersection_based/__init__.py | 2 +- .../intersection_based/precision_recall.py | 41 +++++++++++++++- sed_scores_eval/intersection_based/psds.py | 8 ++-- sed_scores_eval/segment_based/__init__.py | 6 +-- .../segment_based/precision_recall.py | 48 ++++++++++++++++++- sed_scores_eval/segment_based/roc.py | 23 +++++++++ 10 files changed, 137 insertions(+), 37 deletions(-) diff --git a/sed_scores_eval/base_modules/bootstrap.py b/sed_scores_eval/base_modules/bootstrap.py index 1f29059..b7a3f3a 100644 --- a/sed_scores_eval/base_modules/bootstrap.py +++ b/sed_scores_eval/base_modules/bootstrap.py @@ -1,5 +1,5 @@ import numpy as np -from collections import defaultdict +from sed_scores_eval.base_modules.io import parse_ground_truth, parse_audio_durations from sed_scores_eval.utils import parallel @@ -38,10 +38,10 @@ def bootstrap_from_deltas( data_orig = {'deltas': deltas} data_samples = {'deltas': []} if ground_truth is not None: - data_orig['ground_truth'] = ground_truth + data_orig['ground_truth'] = parse_ground_truth(ground_truth) data_samples['ground_truth'] = [] if audio_durations is not None: - data_orig['audio_durations'] = audio_durations + data_orig['audio_durations'] = parse_audio_durations(audio_durations) data_samples['audio_durations'] = [] audio_ids_repeated = n_bootstrap_samples * audio_ids @@ -87,11 +87,11 @@ def confidence_interval(bootstrapped_outputs, confidence=.9, axis=None): } return mean_low_high - mean = np.mean(bootstrapped_outputs, axis=axis) - low = np.percentile( + mean = float(np.mean(bootstrapped_outputs, axis=axis)) + low = float(np.percentile( bootstrapped_outputs, ((1 - confidence) / 2) * 100, axis=axis, - ) - high = np.percentile( + )) + high = float(np.percentile( bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100, axis=axis, - ) + )) return mean, low, high diff --git a/sed_scores_eval/base_modules/scores.py b/sed_scores_eval/base_modules/scores.py index 2ada3d6..c53f59c 100644 --- a/sed_scores_eval/base_modules/scores.py +++ b/sed_scores_eval/base_modules/scores.py @@ -102,13 +102,13 @@ def validate_score_dataframe(scores, timestamps=None, event_classes=None): ) onset_times = scores['onset'].to_numpy() offset_times = scores['offset'].to_numpy() + timestamps_from_df = np.concatenate((onset_times, offset_times[-1:])) if (offset_times == onset_times).any(): - raise ValueError('Some frames have zero length.') + raise ValueError(f'Some frames have zero length: {timestamps_from_df}') if not (offset_times[:-1] == onset_times[1:]).all(): raise ValueError( f'onset times must match offset times of the previous frame.' ) - timestamps_from_df = np.concatenate((onset_times, offset_times[-1:])) if timestamps is not None and not np.allclose(timestamps_from_df, timestamps): raise ValueError( f'timestamps from file {timestamps_from_df} do not match provided timestamps {timestamps}.' diff --git a/sed_scores_eval/collar_based/__init__.py b/sed_scores_eval/collar_based/__init__.py index 98e715c..52b4979 100644 --- a/sed_scores_eval/collar_based/__init__.py +++ b/sed_scores_eval/collar_based/__init__.py @@ -1,4 +1,4 @@ from .intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas -from .precision_recall import precision_recall_curve, fscore_curve, fscore, bootstrapped_fscore, best_fscore +from .precision_recall import precision_recall_curve, fscore_curve, fscore, bootstrapped_fscore, best_fscore, bootstrapped_best_fscore from .error_rate import error_rate_curve, error_rate, best_error_rate from . import reference diff --git a/sed_scores_eval/collar_based/precision_recall.py b/sed_scores_eval/collar_based/precision_recall.py index 0102ee9..32538d2 100644 --- a/sed_scores_eval/collar_based/precision_recall.py +++ b/sed_scores_eval/collar_based/precision_recall.py @@ -1,4 +1,3 @@ -import numpy as np from sed_scores_eval.base_modules.io import parse_inputs from sed_scores_eval.base_modules.bootstrap import bootstrap from sed_scores_eval.base_modules.precision_recall import ( @@ -226,7 +225,8 @@ def bootstrapped_fscore( Returns: """ - scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) return bootstrap( fscore, scores=scores, deltas=deltas, deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, @@ -314,11 +314,10 @@ def best_fscore( ) -def bootstrapped_fscore_curve( +def bootstrapped_best_fscore( scores, ground_truth, *, deltas=None, - onset_collar, offset_collar, offset_collar_rate=0., - beta=1., time_decimals=6, - n_bootstrap_samples=100, num_jobs=1, + onset_collar, offset_collar, offset_collar_rate=0., beta=1., + time_decimals=6, n_bootstrap_samples=100, num_jobs=1, ): """ @@ -337,9 +336,10 @@ def bootstrapped_fscore_curve( Returns: """ - scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) return bootstrap( - fscore_curve, scores=scores, deltas=deltas, + best_fscore, scores=scores, deltas=deltas, deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, deltas_fn_kwargs=dict( ground_truth=ground_truth, @@ -352,11 +352,3 @@ def bootstrapped_fscore_curve( ), n_bootstrap_samples=n_bootstrap_samples, ) - - -def _recursive_get_item(stats, idx): - if isinstance(stats, dict): - return {key: _recursive_get_item(stats[key], idx) for key in stats} - if np.isscalar(stats): - return stats - return stats[idx] diff --git a/sed_scores_eval/intersection_based/__init__.py b/sed_scores_eval/intersection_based/__init__.py index e94c694..c4e2cb6 100644 --- a/sed_scores_eval/intersection_based/__init__.py +++ b/sed_scores_eval/intersection_based/__init__.py @@ -8,6 +8,6 @@ bootstrapped_postprocessing_independent_psds_from_postprocessed_scores, deltas_postprocessing, median_filter_independent_psds, bootstrapped_median_filter_independent_psds, ) -from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore +from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore, bootstrapped_best_fscore from .error_rate import error_rate_curve, error_rate, best_error_rate from . import reference diff --git a/sed_scores_eval/intersection_based/precision_recall.py b/sed_scores_eval/intersection_based/precision_recall.py index 32ce5cb..133b6b5 100644 --- a/sed_scores_eval/intersection_based/precision_recall.py +++ b/sed_scores_eval/intersection_based/precision_recall.py @@ -197,7 +197,8 @@ def bootstrapped_fscore( Returns: """ - scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) return bootstrap( fscore, scores=scores, deltas=deltas, deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, @@ -274,3 +275,41 @@ def best_fscore( intermediate_stats, beta=beta, min_precision=min_precision, min_recall=min_recall, ) + + +def bootstrapped_best_fscore( + scores, ground_truth, *, deltas=None, + dtc_threshold, gtc_threshold, beta=1., time_decimals=6, + n_bootstrap_samples=100, num_jobs=1, +): + """ + + Args: + scores: + ground_truth: + deltas: + dtc_threshold: + gtc_threshold: + beta: + time_decimals: + n_bootstrap_samples: + num_jobs: + + Returns: + + """ + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + best_fscore, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, + dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + beta=beta, + ), + n_bootstrap_samples=n_bootstrap_samples, + ) diff --git a/sed_scores_eval/intersection_based/psds.py b/sed_scores_eval/intersection_based/psds.py index 397b278..e3ce693 100644 --- a/sed_scores_eval/intersection_based/psds.py +++ b/sed_scores_eval/intersection_based/psds.py @@ -1,10 +1,9 @@ import numpy as np from scipy.interpolate import interp1d -from sed_scores_eval.base_modules.io import parse_inputs +from sed_scores_eval.base_modules.io import parse_inputs, parse_audio_durations, parse_ground_truth from sed_scores_eval.utils.array_ops import cummax, get_first_index_where from sed_scores_eval.base_modules.curves import xsort, staircase_auc from sed_scores_eval.base_modules.bootstrap import bootstrap -from sed_scores_eval.base_modules.io import parse_audio_durations from sed_scores_eval.intersection_based.intermediate_statistics import intermediate_statistics_deltas, accumulated_intermediate_statistics seconds_per_unit_of_time = { @@ -243,7 +242,8 @@ def bootstrapped_psds( Returns: """ - scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) return bootstrap( psds, scores=scores, deltas=deltas, deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, @@ -253,7 +253,7 @@ def bootstrapped_psds( cttc_threshold=cttc_threshold, time_decimals=time_decimals, ), eval_fn_kwargs=dict( - audio_durations=audio_durations,alpha_ct=alpha_ct, + audio_durations=audio_durations, alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, non_oracle=non_oracle, ), diff --git a/sed_scores_eval/segment_based/__init__.py b/sed_scores_eval/segment_based/__init__.py index ccd0da2..36a9e9d 100644 --- a/sed_scores_eval/segment_based/__init__.py +++ b/sed_scores_eval/segment_based/__init__.py @@ -1,9 +1,9 @@ -from .intermediate_statistics import accumulated_intermediate_statistics +from .intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas from .precision_recall import ( precision_recall_curve, - fscore_curve, fscore, best_fscore, + fscore_curve, fscore, bootstrapped_fscore, best_fscore, bootstrapped_best_fscore, average_precision, ) from .error_rate import error_rate_curve, error_rate, best_error_rate -from .roc import roc_curve, auroc +from .roc import roc_curve, auroc, bootstrapped_auroc from . import reference diff --git a/sed_scores_eval/segment_based/precision_recall.py b/sed_scores_eval/segment_based/precision_recall.py index c1e0d3e..5a89ebf 100644 --- a/sed_scores_eval/segment_based/precision_recall.py +++ b/sed_scores_eval/segment_based/precision_recall.py @@ -1,3 +1,4 @@ +from sed_scores_eval.base_modules.io import parse_inputs from sed_scores_eval.base_modules.precision_recall import ( precision_recall_curve_from_intermediate_statistics, fscore_curve_from_intermediate_statistics, @@ -5,7 +6,8 @@ best_fscore_from_intermediate_statistics, average_precision_from_intermediate_statistics, ) -from sed_scores_eval.segment_based.intermediate_statistics import accumulated_intermediate_statistics +from sed_scores_eval.segment_based.intermediate_statistics import accumulated_intermediate_statistics, intermediate_statistics_deltas +from sed_scores_eval.base_modules.bootstrap import bootstrap def precision_recall_curve( @@ -170,6 +172,28 @@ def fscore( ) +def bootstrapped_fscore( + scores, ground_truth, audio_durations, threshold, *, deltas=None, + segment_length=1., beta=1., time_decimals=6, n_bootstrap_samples=100, + num_jobs=1, +): + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + fscore, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, audio_durations=audio_durations, + segment_length=segment_length, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + threshold=threshold, beta=beta, + ), + n_bootstrap_samples=n_bootstrap_samples, + ) + + def best_fscore( scores, ground_truth, audio_durations, *, deltas=None, segment_length=1., min_precision=0., min_recall=0., beta=1., @@ -234,6 +258,28 @@ def best_fscore( ) +def bootstrapped_best_fscore( + scores, ground_truth, audio_durations, *, deltas=None, + segment_length=1., min_precision=0., min_recall=0., beta=1., + time_decimals=6, n_bootstrap_samples=100, num_jobs=1, +): + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + best_fscore, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, audio_durations=audio_durations, + segment_length=segment_length, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + beta=beta, min_precision=min_precision, min_recall=min_recall, + ), + n_bootstrap_samples=n_bootstrap_samples, + ) + + def average_precision( scores, ground_truth, audio_durations, *, deltas=None, segment_length, time_decimals=6, num_jobs=1, diff --git a/sed_scores_eval/segment_based/roc.py b/sed_scores_eval/segment_based/roc.py index 453a6e4..01f8785 100644 --- a/sed_scores_eval/segment_based/roc.py +++ b/sed_scores_eval/segment_based/roc.py @@ -108,3 +108,26 @@ def auroc( return auroc_from_intermediate_statistics( intermediate_stats, max_fpr=max_fpr, mcclish_correction=mcclish_correction ) + + +def bootstrapped_auroc( + scores, ground_truth, audio_durations, *, deltas=None, + segment_length, max_fpr=None, time_decimals=6, + n_bootstrap_samples=100, num_jobs=1, +): + if scores is not None: + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + auroc, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, + audio_durations=audio_durations, + segment_length=segment_length, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + max_fpr=max_fpr, + ), + n_bootstrap_samples=n_bootstrap_samples, + )