diff --git a/sed_scores_eval/base_modules/bootstrap.py b/sed_scores_eval/base_modules/bootstrap.py index 87de65e..1f29059 100644 --- a/sed_scores_eval/base_modules/bootstrap.py +++ b/sed_scores_eval/base_modules/bootstrap.py @@ -1,46 +1,83 @@ import numpy as np +from collections import defaultdict from sed_scores_eval.utils import parallel +def bootstrap( + eval_fn, scores=None, deltas=None, deltas_fn=None, + n_bootstrap_samples=100, num_jobs=1, + deltas_fn_kwargs=None, eval_fn_kwargs=None, +): + if deltas_fn_kwargs is None: + deltas_fn_kwargs = {} + if eval_fn_kwargs is None: + eval_fn_kwargs = {} + if deltas is None: + assert scores is not None + assert deltas_fn is not None + deltas = deltas_fn( + scores=scores, num_jobs=num_jobs, **deltas_fn_kwargs, + ) + return bootstrap_from_deltas( + eval_fn, deltas, + n_bootstrap_samples=n_bootstrap_samples, num_jobs=num_jobs, + scores=None, **deltas_fn_kwargs, **eval_fn_kwargs, + ) + + def bootstrap_from_deltas( - metric_fn, deltas, *, - n_folds=5, n_iterations=20, num_jobs=1, - **metric_fn_kwargs, + eval_fn, deltas, *, + n_bootstrap_samples=100, num_jobs=1, + ground_truth=None, audio_durations=None, + **eval_fn_kwargs, ): if isinstance(deltas, (list, tuple)): audio_ids = sorted(deltas[0].keys()) else: audio_ids = sorted(deltas.keys()) - split_indices = np.linspace(0, len(audio_ids), n_folds+1).astype(int) - audio_id_fractions = [] - for i in range(n_iterations): - np.random.RandomState(i).shuffle(audio_ids) - for j in range(n_folds): - audio_id_fractions.append(list( - audio_ids[:split_indices[j]] + audio_ids[split_indices[j+1]:] - )) + data_orig = {'deltas': deltas} + data_samples = {'deltas': []} + if ground_truth is not None: + data_orig['ground_truth'] = ground_truth + data_samples['ground_truth'] = [] + if audio_durations is not None: + data_orig['audio_durations'] = audio_durations + data_samples['audio_durations'] = [] - if isinstance(deltas, (list, tuple)): - deltas_fractions = [ - [ - {audio_id: delts[audio_id] for audio_id in audio_id_fraction} - for delts in deltas - ] - for audio_id_fraction in audio_id_fractions - ] - else: - deltas_fractions = [ - {audio_id: deltas[audio_id] for audio_id in audio_id_fraction} - for audio_id_fraction in audio_id_fractions - ] - return list(zip(*parallel.map( - deltas_fractions, arg_keys='deltas', - func=metric_fn, max_jobs=num_jobs, - **metric_fn_kwargs, - ))) - - -def confidence_interval(bootstrapped_outputs, confidence=.9): + audio_ids_repeated = n_bootstrap_samples * audio_ids + np.random.RandomState(0).shuffle(audio_ids_repeated) + for i in range(n_bootstrap_samples): + for key in data_samples: + data_samples[key].append({}) + for j, audio_id in enumerate(audio_ids_repeated[i*len(audio_ids):(i+1)*len(audio_ids)]): + for key in data_samples: + if isinstance(data_orig[key], (list, tuple)): + if isinstance(data_samples[key][-1], dict): + data_samples[key][-1] = [{} for _ in range(len(data_orig[key]))] + for k in range(len(data_orig[key])): + data_samples[key][-1][k][f'{audio_id}_bootstrap{i}_clip{j}'] = data_orig[key][k][audio_id] + else: + data_samples[key][-1][f'{audio_id}_bootstrap{i}_clip{j}'] = data_orig[key][audio_id] + + arg_keys = sorted(data_samples.keys()) + args = [data_samples[key] for key in arg_keys] + ret = parallel.map( + args, arg_keys=arg_keys, + func=eval_fn, max_jobs=num_jobs, + **eval_fn_kwargs, + ) + if isinstance(ret[0], tuple): + return list(zip(*ret)) + return ret + + +def _recursive_multiply(deltas, factor): + if isinstance(deltas, dict): + return {key: _recursive_multiply(deltas[key], factor) for key in deltas.keys()} + return deltas * factor + + +def confidence_interval(bootstrapped_outputs, confidence=.9, axis=None): if isinstance(bootstrapped_outputs[0], dict): mean_low_high = { class_name: confidence_interval([ @@ -50,12 +87,11 @@ def confidence_interval(bootstrapped_outputs, confidence=.9): } return mean_low_high - mean = np.mean(bootstrapped_outputs) + mean = np.mean(bootstrapped_outputs, axis=axis) low = np.percentile( - bootstrapped_outputs, ((1 - confidence) / 2) * 100 + bootstrapped_outputs, ((1 - confidence) / 2) * 100, axis=axis, ) high = np.percentile( - bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100 + bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100, axis=axis, ) - return float(mean), float(low), float(high) - + return mean, low, high diff --git a/sed_scores_eval/base_modules/io.py b/sed_scores_eval/base_modules/io.py index 6549350..9ea25a6 100644 --- a/sed_scores_eval/base_modules/io.py +++ b/sed_scores_eval/base_modules/io.py @@ -50,8 +50,7 @@ def parse_scores(scores): def parse_ground_truth( - ground_truth, *, - tagging=False, audio_ids=None, additional_ids_ok=False + ground_truth, *, tagging=False, audio_ids=None, ): if not isinstance(ground_truth, (dict, str, Path)): raise ValueError( @@ -66,15 +65,12 @@ def parse_ground_truth( ground_truth = read_ground_truth_events(ground_truth) if not tagging: assert_non_connected_events(ground_truth) - if audio_ids is not None: - if additional_ids_ok: - ground_truth = {key: ground_truth[key] for key in audio_ids} - elif not (ground_truth.keys() == set(audio_ids)): - raise ValueError( - f'ground_truth audio ids do not match audio_ids. ' - f'Missing ids: {set(audio_ids) - ground_truth.keys()}. ' - f'Additional ids: {ground_truth.keys() - set(audio_ids)}.' - ) + if audio_ids is not None and not (ground_truth.keys() == set(audio_ids)): + raise ValueError( + f'ground_truth audio ids do not match audio_ids. ' + f'Missing ids: {set(audio_ids) - ground_truth.keys()}. ' + f'Additional ids: {ground_truth.keys() - set(audio_ids)}.' + ) return ground_truth @@ -91,7 +87,7 @@ def assert_non_connected_events(ground_truth): current_offset = event[1] -def parse_audio_durations(audio_durations, *, audio_ids=None, additional_ids_ok=False): +def parse_audio_durations(audio_durations, *, audio_ids=None): if not isinstance(audio_durations, (dict, str, Path)): raise ValueError( f'audio_durations must be dict, str or Path but ' @@ -101,15 +97,12 @@ def parse_audio_durations(audio_durations, *, audio_ids=None, additional_ids_ok= audio_durations = Path(audio_durations) assert audio_durations.is_file(), audio_durations audio_durations = read_audio_durations(audio_durations) - if audio_ids is not None: - if additional_ids_ok: - audio_durations = {key: audio_durations[key] for key in audio_ids} - elif not (audio_durations.keys() == set(audio_ids)): - raise ValueError( - f'audio_durations audio ids do not match audio_ids. ' - f'Missing ids: {set(audio_ids) - audio_durations.keys()}. ' - f'Additional ids: {audio_durations.keys() - set(audio_ids)}.' - ) + if audio_ids is not None and not (audio_durations.keys() == set(audio_ids)): + raise ValueError( + f'audio_durations audio ids do not match audio_ids. ' + f'Missing ids: {set(audio_ids) - audio_durations.keys()}. ' + f'Additional ids: {audio_durations.keys() - set(audio_ids)}.' + ) return audio_durations diff --git a/sed_scores_eval/collar_based/intermediate_statistics.py b/sed_scores_eval/collar_based/intermediate_statistics.py index 8eecde1..1203c72 100644 --- a/sed_scores_eval/collar_based/intermediate_statistics.py +++ b/sed_scores_eval/collar_based/intermediate_statistics.py @@ -159,8 +159,7 @@ def accumulated_intermediate_statistics( ) else: audio_ids = list(deltas.keys()) - ground_truth = parse_ground_truth( - ground_truth, audio_ids=audio_ids, additional_ids_ok=True) + ground_truth = parse_ground_truth(ground_truth, audio_ids=audio_ids) return accumulated_intermediate_statistics_from_deltas(deltas, ground_truth), audio_ids diff --git a/sed_scores_eval/collar_based/precision_recall.py b/sed_scores_eval/collar_based/precision_recall.py index 4655358..0102ee9 100644 --- a/sed_scores_eval/collar_based/precision_recall.py +++ b/sed_scores_eval/collar_based/precision_recall.py @@ -1,10 +1,11 @@ -from sed_scores_eval.base_modules.bootstrap import bootstrap_from_deltas -from sed_scores_eval.utils import parallel +import numpy as np +from sed_scores_eval.base_modules.io import parse_inputs +from sed_scores_eval.base_modules.bootstrap import bootstrap from sed_scores_eval.base_modules.precision_recall import ( single_fscore_from_intermediate_statistics, best_fscore_from_intermediate_statistics, precision_recall_curve_from_intermediate_statistics, - fscore_curve_from_intermediate_statistics + fscore_curve_from_intermediate_statistics, ) from sed_scores_eval.collar_based.intermediate_statistics import ( accumulated_intermediate_statistics, intermediate_statistics_deltas, @@ -188,20 +189,6 @@ def fscore( 'n_ref' (int): number of ground truth events """ - if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)): - # batch input - batch_size = [len(v) for v in [scores, deltas] if v is not None][0] - f, p, r, stats = list(zip(*parallel.map( - (scores, deltas), arg_keys=('scores', 'deltas'), - func=fscore, max_jobs=num_jobs, - ground_truth=ground_truth, threshold=threshold, - onset_collar=onset_collar, offset_collar=offset_collar, - offset_collar_rate=offset_collar_rate, beta=beta, - return_onset_offset_dist_sum=return_onset_offset_dist_sum, - time_decimals=time_decimals, - num_jobs=num_jobs//batch_size, - ))) - return f, p, r, stats intermediate_stats, _ = accumulated_intermediate_statistics( scores=scores, ground_truth=ground_truth, deltas=deltas, onset_collar=onset_collar, offset_collar=offset_collar, @@ -218,7 +205,7 @@ def bootstrapped_fscore( scores, ground_truth, threshold, *, deltas=None, onset_collar, offset_collar, offset_collar_rate=0., beta=1., return_onset_offset_dist_sum=False, time_decimals=6, - n_folds=5, n_iterations=20, num_jobs=1, + n_bootstrap_samples=100, num_jobs=1, ): """ @@ -233,44 +220,27 @@ def bootstrapped_fscore( beta: return_onset_offset_dist_sum: time_decimals: - n_folds: - n_iterations: + n_bootstrap_samples: num_jobs: Returns: """ - if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)): - # batch input - batch_size = [len(v) for v in [scores, deltas] if v is not None][0] - f, p, r, stats = list(zip(*parallel.map( - (scores, deltas), arg_keys=('scores', 'deltas'), - func=bootstrapped_fscore, max_jobs=num_jobs, - ground_truth=ground_truth, threshold=threshold, - onset_collar=onset_collar, offset_collar=offset_collar, - offset_collar_rate=offset_collar_rate, beta=beta, - return_onset_offset_dist_sum=return_onset_offset_dist_sum, - time_decimals=time_decimals, - n_folds=n_folds, n_iterations=n_iterations, - num_jobs=num_jobs//batch_size, - ))) - return f, p, r, stats - if deltas is None: - deltas = intermediate_statistics_deltas( - scores=scores, ground_truth=ground_truth, + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + fscore, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, onset_collar=onset_collar, offset_collar=offset_collar, offset_collar_rate=offset_collar_rate, return_onset_offset_dist_sum=return_onset_offset_dist_sum, - time_decimals=time_decimals, num_jobs=num_jobs, - ) - return bootstrap_from_deltas( - fscore, deltas, - n_folds=n_folds, n_iterations=n_iterations, num_jobs=num_jobs, - threshold=threshold, scores=None, ground_truth=ground_truth, - onset_collar=onset_collar, offset_collar=offset_collar, - offset_collar_rate=offset_collar_rate, beta=beta, - return_onset_offset_dist_sum=return_onset_offset_dist_sum, - time_decimals=time_decimals, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + threshold=threshold, beta=beta, + ), + n_bootstrap_samples=n_bootstrap_samples, ) @@ -342,3 +312,51 @@ def best_fscore( intermediate_stats, beta=beta, min_precision=min_precision, min_recall=min_recall, ) + + +def bootstrapped_fscore_curve( + scores, ground_truth, *, deltas=None, + onset_collar, offset_collar, offset_collar_rate=0., + beta=1., time_decimals=6, + n_bootstrap_samples=100, num_jobs=1, +): + """ + + Args: + scores: + ground_truth: + deltas: + onset_collar: + offset_collar: + offset_collar_rate: + beta: + time_decimals: + n_bootstrap_samples: + num_jobs: + + Returns: + + """ + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + fscore_curve, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, + onset_collar=onset_collar, offset_collar=offset_collar, + offset_collar_rate=offset_collar_rate, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + beta=beta, + ), + n_bootstrap_samples=n_bootstrap_samples, + ) + + +def _recursive_get_item(stats, idx): + if isinstance(stats, dict): + return {key: _recursive_get_item(stats[key], idx) for key in stats} + if np.isscalar(stats): + return stats + return stats[idx] diff --git a/sed_scores_eval/intersection_based/__init__.py b/sed_scores_eval/intersection_based/__init__.py index c5f11f6..e94c694 100644 --- a/sed_scores_eval/intersection_based/__init__.py +++ b/sed_scores_eval/intersection_based/__init__.py @@ -4,10 +4,10 @@ psds_from_psd_roc, multi_class_psd_roc_from_single_class_psd_rocs, ) from .pipsds import ( - postprocessing_independent_psd_roc, postprocessing_independent_psds, - bootstrapped_postprocessing_independent_psds, deltas_postprocessing, + postprocessing_independent_psd_roc_from_postprocessed_scores, postprocessing_independent_psds_from_postprocessed_scores, + bootstrapped_postprocessing_independent_psds_from_postprocessed_scores, deltas_postprocessing, median_filter_independent_psds, bootstrapped_median_filter_independent_psds, ) -from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore +from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore from .error_rate import error_rate_curve, error_rate, best_error_rate from . import reference diff --git a/sed_scores_eval/intersection_based/intermediate_statistics.py b/sed_scores_eval/intersection_based/intermediate_statistics.py index 4c1a971..3552c77 100644 --- a/sed_scores_eval/intersection_based/intermediate_statistics.py +++ b/sed_scores_eval/intersection_based/intermediate_statistics.py @@ -174,8 +174,7 @@ def accumulated_intermediate_statistics( ) else: audio_ids = list(deltas.keys()) - ground_truth = parse_ground_truth( - ground_truth, audio_ids=audio_ids, additional_ids_ok=True) + ground_truth = parse_ground_truth(ground_truth, audio_ids=audio_ids) return accumulated_intermediate_statistics_from_deltas(deltas, ground_truth), audio_ids diff --git a/sed_scores_eval/intersection_based/pipsds.py b/sed_scores_eval/intersection_based/pipsds.py index cde053c..cc7b102 100644 --- a/sed_scores_eval/intersection_based/pipsds.py +++ b/sed_scores_eval/intersection_based/pipsds.py @@ -1,7 +1,8 @@ + import numpy as np -from functools import partial from sed_scores_eval.utils import parallel -from sed_scores_eval.base_modules.bootstrap import bootstrap_from_deltas +from sed_scores_eval.base_modules.io import parse_inputs, parse_ground_truth, parse_audio_durations +from sed_scores_eval.base_modules.bootstrap import bootstrap from sed_scores_eval.base_modules.postprocessing import medfilt from sed_scores_eval.intersection_based.intermediate_statistics import intermediate_statistics_deltas from sed_scores_eval.intersection_based.psds import _unique_cummax_sort, psd_roc, psds_from_psd_roc, multi_class_psd_roc_from_single_class_psd_rocs @@ -89,9 +90,31 @@ def merge_individual_rocs_into_overall_roc(rocs): return _unique_cummax_sort(tprs, efprs, scores, filter_lengths) -def postprocessing_independent_psds( - scores, ground_truth, audio_durations, *, postprocessing_functions, - scores_processed=None, deltas=None, +def postprocessing_independent_psd_roc_from_postprocessing_dependent_psd_rocs(single_class_psd_rocs, alpha_st, max_efpr): + + single_class_pi_psd_rocs = { + class_name: merge_individual_rocs_into_overall_roc([ + ( + *single_class_psd_rocs[m][class_name], + np.full( + len(single_class_psd_rocs[m][class_name][-1]), + m, + dtype=int, + ) + ) + for m in range(len(single_class_psd_rocs)) + ]) + for class_name in single_class_psd_rocs[0] + } + + pi_psd_roc = multi_class_psd_roc_from_single_class_psd_rocs( + single_class_pi_psd_rocs, alpha_st=alpha_st, max_efpr=max_efpr + ) + return pi_psd_roc, single_class_pi_psd_rocs + + +def postprocessing_independent_psds_from_postprocessed_scores( + scores, ground_truth, audio_durations, *, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., time_decimals=6, num_jobs=1, @@ -120,16 +143,14 @@ def postprocessing_independent_psds( 2023 Args: - scores (dict, str, pathlib.Path): dict of SED score DataFrames + scores (list of dict, str or pathlib.Path): SED score DataFrames (cf. sed_scores_eval.utils.scores.create_score_dataframe) - or a directory path (as str or pathlib.Path) from where the SED - scores can be loaded. + or a directory paths (as str or pathlib.Path) from where the SED + scores can be loaded for different post-processings. ground_truth (dict, str or pathlib.Path): dict of lists of ground truth event tuples (onset, offset, event label) for each audio clip or a file path from where the ground truth can be loaded. audio_durations: The duration of each audio file in the evaluation set. - postprocessing_functions: - scores_processed: deltas (dict of dicts of tuples): Must be deltas as returned by `accumulated_intermediate_statistics_from_deltas`. If not provided, deltas are computed within this function. Providing deltas is useful @@ -169,27 +190,13 @@ def postprocessing_independent_psds( for each event class. """ - if ( - isinstance(scores, (list, tuple)) - or (isinstance(scores_processed, (list, tuple)) and isinstance(scores_processed[0], (list, tuple))) - or (isinstance(deltas, (list, tuple)) and isinstance(deltas[0], (list, tuple))) - ): - batch_size = [len(v) for v in [scores, scores_processed, deltas] if v is not None][0] - return list(zip(*parallel.map( - (scores, scores_processed, deltas), ('scores', 'scores_processed', 'deltas'), - func=postprocessing_independent_psds, max_jobs=num_jobs, - postprocessing_functions=postprocessing_functions, - ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, - alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - num_jobs=num_jobs//batch_size, - ))) - (pi_effective_tp_rate, pi_effective_fp_rate), single_class_pi_psd_rocs, psd_rocs, single_class_psd_rocs = postprocessing_independent_psd_roc( - scores=scores, ground_truth=ground_truth, audio_durations=audio_durations, - postprocessing_functions=postprocessing_functions, - scores_processed=scores_processed, deltas=deltas, + ( + (pi_effective_tp_rate, pi_effective_fp_rate), + single_class_pi_psd_rocs, + psd_rocs, single_class_psd_rocs + ) = postprocessing_independent_psd_roc_from_postprocessed_scores( + scores=scores, ground_truth=ground_truth, + audio_durations=audio_durations, deltas=deltas, dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, @@ -211,9 +218,8 @@ def postprocessing_independent_psds( ) -def postprocessing_independent_psd_roc( - scores, ground_truth, audio_durations, *, postprocessing_functions, - scores_processed=None, deltas=None, +def postprocessing_independent_psd_roc_from_postprocessed_scores( + scores, ground_truth, audio_durations, *, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., time_decimals=6, num_jobs=1, @@ -240,16 +246,14 @@ def postprocessing_independent_psd_roc( 2023 Args: - scores (dict, str, pathlib.Path): dict of SED score DataFrames + scores (list of dict, str or pathlib.Path): SED score DataFrames (cf. sed_scores_eval.utils.scores.create_score_dataframe) - or a directory path (as str or pathlib.Path) from where the SED - scores can be loaded. + or a directory paths (as str or pathlib.Path) from where the SED + scores can be loaded for different post-processings. ground_truth (dict, str or pathlib.Path): dict of lists of ground truth event tuples (onset, offset, event label) for each audio clip or a file path from where the ground truth can be loaded. audio_durations: The duration of each audio file in the evaluation set. - postprocessing_functions: - scores_processed: deltas (dict of dicts of tuples): Must be deltas as returned by `accumulated_intermediate_statistics_from_deltas`. If not provided, deltas are computed within this function. Providing deltas is useful @@ -285,86 +289,33 @@ def postprocessing_independent_psd_roc( single_class_psd_rocs (dict of tuples of 1d np.ndarrays): tuple of MFI True Positive Rates and effective False Positive Rates for each event class. - """ - if ( - isinstance(scores, (list, tuple)) - or (isinstance(scores_processed, (list, tuple)) and isinstance(scores_processed[0], (list, tuple))) - or (isinstance(deltas, (list, tuple)) and isinstance(deltas[0], (list, tuple))) - ): - # batch input - batch_size = [len(v) for v in [scores, scores_processed, deltas] if v is not None][0] - return list(zip(*parallel.map( - (scores, scores_processed, deltas), ('scores', 'scores_processed', 'deltas'), - func=postprocessing_independent_psd_roc, max_jobs=num_jobs, - postprocessing_functions=postprocessing_functions, - ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, - alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - num_jobs=num_jobs//batch_size, - ))) if deltas is not None: assert isinstance(deltas, (list, tuple)), type(deltas) - if scores_processed is None: - scores_processed = len(deltas) * [None] - if postprocessing_functions is not None: - assert len(deltas) == len(postprocessing_functions), (len(deltas), len(postprocessing_functions)) - if scores_processed is not None: - assert isinstance(scores_processed, (list, tuple)), type(scores_processed) - if postprocessing_functions is not None: - assert len(scores_processed) == len(postprocessing_functions), (len(scores_processed), len(postprocessing_functions)) - psd_rocs, single_class_psd_rocs = list(zip(*parallel.map( - (scores_processed, deltas), arg_keys=('scores', 'deltas'), - func=psd_roc, max_jobs=num_jobs, - ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, - alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - num_jobs=num_jobs//len(scores_processed), - ))) - else: - assert isinstance(postprocessing_functions, (list, tuple)), type(postprocessing_functions) - psd_rocs, single_class_psd_rocs = list(zip(*parallel.map( - postprocessing_functions, arg_keys='postprocessing_fn', - func=psd_roc_postprocessing, max_jobs=num_jobs, - scores=scores, ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, - alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - num_jobs=1, - ))) - - single_class_pi_psd_rocs = { - class_name: merge_individual_rocs_into_overall_roc([ - ( - *single_class_psd_rocs[m][class_name], - np.full( - len(single_class_psd_rocs[m][class_name][-1]), - m, - dtype=int, - ) - ) - for m in range(len(single_class_psd_rocs)) - ]) - for class_name in single_class_psd_rocs[0] - } - - pi_psd_roc = multi_class_psd_roc_from_single_class_psd_rocs( - single_class_pi_psd_rocs, alpha_st=alpha_st, max_efpr=max_efpr - ) + if scores is None: + scores = len(deltas) * [None] + else: + assert len(deltas) == len(scores), (len(deltas), len(scores)) + assert isinstance(scores, (list, tuple)), type(scores) + psd_rocs, single_class_psd_rocs = list(zip(*parallel.map( + (scores, deltas), arg_keys=('scores', 'deltas'), + func=psd_roc, max_jobs=num_jobs, + ground_truth=ground_truth, audio_durations=audio_durations, + dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, + cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, + alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, + time_decimals=time_decimals, + num_jobs=max(num_jobs//len(scores), 1), + ))) + pi_psd_roc, single_class_pi_psd_rocs = postprocessing_independent_psd_roc_from_postprocessing_dependent_psd_rocs(single_class_psd_rocs, alpha_st, max_efpr) return pi_psd_roc, single_class_pi_psd_rocs, psd_rocs, single_class_psd_rocs -def bootstrapped_postprocessing_independent_psds( - scores, ground_truth, audio_durations, *, postprocessing_functions, - scores_processed=None, deltas=None, +def bootstrapped_postprocessing_independent_psds_from_postprocessed_scores( + scores, ground_truth, audio_durations, *, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., - time_decimals=6, n_folds=5, n_iterations=4, num_jobs=1, + time_decimals=6, n_bootstrap_samples=100, num_jobs=1, ): """ @@ -372,8 +323,6 @@ def bootstrapped_postprocessing_independent_psds( scores: ground_truth: audio_durations: - postprocessing_functions: - scores_processed: deltas: dtc_threshold: gtc_threshold: @@ -383,81 +332,55 @@ def bootstrapped_postprocessing_independent_psds( unit_of_time: max_efpr: time_decimals: - n_folds: - n_iterations: + n_bootstrap_samples: num_jobs: Returns: """ - if ( - isinstance(scores, (list, tuple)) - or (isinstance(scores_processed, (list, tuple)) and isinstance(scores_processed[0], (list, tuple))) - or (isinstance(deltas, (list, tuple)) and isinstance(deltas[0], (list, tuple))) - ): - # batch input - batch_size = [len(v) for v in [scores, scores_processed, deltas] if v is not None][0] - return list(zip(*parallel.map( - (scores, scores_processed, deltas), - arg_keys=('scores', 'scores_processed', 'deltas'), - func=bootstrapped_postprocessing_independent_psds, - postprocessing_functions=postprocessing_functions, - max_jobs=num_jobs, ground_truth=ground_truth, - audio_durations=audio_durations, + def deltas_fn(scores, num_jobs, **kwargs): + return list(parallel.map( + scores, arg_keys='scores', + func=intermediate_statistics_deltas, max_jobs=num_jobs, + **kwargs, num_jobs=1, + )) + ground_truth = parse_ground_truth(ground_truth) + audio_durations = parse_audio_durations(audio_durations) + return bootstrap( + postprocessing_independent_psds_from_postprocessed_scores, + scores=scores, deltas=deltas, deltas_fn=deltas_fn, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, + cttc_threshold=cttc_threshold, time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + audio_durations=audio_durations,alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - n_folds=n_folds, n_iterations=n_iterations, - num_jobs=num_jobs//batch_size, - ))) - if deltas is None: - assert isinstance(postprocessing_functions, (list, tuple)), type(postprocessing_functions) - if scores_processed is None: - deltas = list(parallel.map( - postprocessing_functions, arg_keys='postprocessing_fn', - func=deltas_postprocessing, max_jobs=num_jobs, - scores=scores, ground_truth=ground_truth, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, time_decimals=time_decimals, - num_jobs=1, - )) - else: - deltas = list(parallel.map( - scores_processed, arg_keys='scores', - func=intermediate_statistics_deltas, max_jobs=num_jobs, - ground_truth=ground_truth, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, time_decimals=time_decimals, - num_jobs=1, - )) - return bootstrap_from_deltas( - postprocessing_independent_psds, deltas, - n_folds=n_folds, n_iterations=n_iterations, num_jobs=num_jobs, - scores=None, ground_truth=ground_truth, audio_durations=audio_durations, - postprocessing_functions=postprocessing_functions, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, alpha_st=alpha_st, - unit_of_time=unit_of_time, max_efpr=max_efpr + ), + n_bootstrap_samples=n_bootstrap_samples, ) def median_filter_independent_psds( - scores, ground_truth, audio_durations, *, median_filter_lengths_in_sec, - scores_processed=None, deltas=None, + scores, ground_truth, audio_durations, *, + median_filter_lengths_in_sec, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., time_decimals=6, num_jobs=1, ): - postprocessing_functions = [ - partial(medfilt, filter_length_in_sec=filter_length_in_sec) - for filter_length_in_sec in median_filter_lengths_in_sec - ] - return postprocessing_independent_psds( - scores, ground_truth, audio_durations, - postprocessing_functions=postprocessing_functions, - scores_processed=scores_processed, deltas=deltas, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, + if deltas is None: + scores_postprocessed = parallel.map( + median_filter_lengths_in_sec, arg_keys='filter_length_in_sec', + func=medfilt, max_jobs=num_jobs, + scores=scores, time_decimals=time_decimals, + ) + else: + assert len(deltas) == len(median_filter_lengths_in_sec) + scores_postprocessed = None + return postprocessing_independent_psds_from_postprocessed_scores( + scores_postprocessed, ground_truth, audio_durations, + deltas=deltas, dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, time_decimals=time_decimals, num_jobs=num_jobs, @@ -465,23 +388,26 @@ def median_filter_independent_psds( def bootstrapped_median_filter_independent_psds( - scores, ground_truth, audio_durations, *, median_filter_lengths_in_sec, - scores_processed=None, deltas=None, + scores, ground_truth, audio_durations, *, + median_filter_lengths_in_sec, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., - time_decimals=6, n_folds=5, n_iterations=4, num_jobs=1, + time_decimals=6, n_bootstrap_samples=100, num_jobs=1, ): - postprocessing_functions = [ - partial(medfilt, filter_length_in_sec=filter_length_in_sec) - for filter_length_in_sec in median_filter_lengths_in_sec - ] - return bootstrapped_postprocessing_independent_psds( - scores, ground_truth, audio_durations, - postprocessing_functions=postprocessing_functions, - scores_processed=scores_processed, deltas=deltas, + if deltas is None: + scores_postprocessed = parallel.map( + median_filter_lengths_in_sec, arg_keys='filter_length_in_sec', + func=medfilt, max_jobs=num_jobs, + scores=scores, time_decimals=time_decimals, + ) + else: + assert len(deltas) == len(median_filter_lengths_in_sec) + scores_postprocessed = None + return bootstrapped_postprocessing_independent_psds_from_postprocessed_scores( + scores_postprocessed, ground_truth, audio_durations, deltas=deltas, dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, time_decimals=time_decimals, num_jobs=num_jobs, - n_folds=n_folds, n_iterations=n_iterations, + n_bootstrap_samples=n_bootstrap_samples, ) diff --git a/sed_scores_eval/intersection_based/precision_recall.py b/sed_scores_eval/intersection_based/precision_recall.py index f2764e2..32ce5cb 100644 --- a/sed_scores_eval/intersection_based/precision_recall.py +++ b/sed_scores_eval/intersection_based/precision_recall.py @@ -1,10 +1,14 @@ +from sed_scores_eval.base_modules.io import parse_inputs +from sed_scores_eval.base_modules.bootstrap import bootstrap from sed_scores_eval.base_modules.precision_recall import ( precision_recall_curve_from_intermediate_statistics, fscore_curve_from_intermediate_statistics, single_fscore_from_intermediate_statistics, - best_fscore_from_intermediate_statistics + best_fscore_from_intermediate_statistics, +) +from sed_scores_eval.intersection_based.intermediate_statistics import ( + accumulated_intermediate_statistics, intermediate_statistics_deltas, ) -from sed_scores_eval.intersection_based.intermediate_statistics import accumulated_intermediate_statistics def precision_recall_curve( @@ -171,6 +175,44 @@ def fscore( ) +def bootstrapped_fscore( + scores, ground_truth, threshold, *, deltas=None, + dtc_threshold, gtc_threshold, beta=1., time_decimals=6, + n_bootstrap_samples=100, num_jobs=1, +): + """ + + Args: + scores: + ground_truth: + threshold: + deltas: + dtc_threshold: + gtc_threshold: + beta: + time_decimals: + n_bootstrap_samples: + num_jobs: + + Returns: + + """ + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + fscore, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, + dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, + time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + threshold=threshold, beta=beta, + ), + n_bootstrap_samples=n_bootstrap_samples, + ) + + def best_fscore( scores, ground_truth, *, deltas=None, dtc_threshold, gtc_threshold, diff --git a/sed_scores_eval/intersection_based/psds.py b/sed_scores_eval/intersection_based/psds.py index 87f1d51..19717e4 100644 --- a/sed_scores_eval/intersection_based/psds.py +++ b/sed_scores_eval/intersection_based/psds.py @@ -1,9 +1,9 @@ import numpy as np from scipy.interpolate import interp1d +from sed_scores_eval.base_modules.io import parse_inputs from sed_scores_eval.utils.array_ops import cummax, get_first_index_where -from sed_scores_eval.utils import parallel from sed_scores_eval.base_modules.curves import xsort, staircase_auc -from sed_scores_eval.base_modules.bootstrap import bootstrap_from_deltas +from sed_scores_eval.base_modules.bootstrap import bootstrap from sed_scores_eval.base_modules.io import parse_audio_durations from sed_scores_eval.intersection_based.intermediate_statistics import intermediate_statistics_deltas, accumulated_intermediate_statistics @@ -82,19 +82,6 @@ def psds( for each event class. """ - if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)): - # batch input - batch_size = [len(v) for v in [scores, deltas] if v is not None][0] - return list(zip(*parallel.map( - (scores, deltas), arg_keys=('scores', 'deltas'), - func=psds, max_jobs=num_jobs, - ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, - alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - num_jobs=num_jobs//batch_size, - ))) (effective_tp_rate, effective_fp_rate), single_class_psd_rocs = psd_roc( scores=scores, ground_truth=ground_truth, audio_durations=audio_durations, deltas=deltas, @@ -199,8 +186,7 @@ def psd_roc( time_decimals=time_decimals, num_jobs=num_jobs, ) - audio_durations = parse_audio_durations( - audio_durations, audio_ids=audio_ids, additional_ids_ok=(deltas is not None)) + audio_durations = parse_audio_durations(audio_durations, audio_ids=audio_ids) dataset_duration = sum(audio_durations.values()) single_class_psd_rocs = _single_class_roc_from_intermediate_statistics( @@ -220,7 +206,7 @@ def bootstrapped_psds( scores, ground_truth, audio_durations, *, deltas=None, dtc_threshold, gtc_threshold, cttc_threshold=None, alpha_ct=.0, alpha_st=.0, unit_of_time='hour', max_efpr=100., - time_decimals=6, n_folds=5, n_iterations=20, num_jobs=1, + time_decimals=6, n_bootstrap_samples=100, num_jobs=1, ): """ @@ -237,41 +223,26 @@ def bootstrapped_psds( unit_of_time: max_efpr: time_decimals: - n_folds: - n_iterations: + n_bootstrap_samples: num_jobs: Returns: """ - if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)): - # batch input - batch_size = [len(v) for v in [scores, deltas] if v is not None][0] - return list(zip(*parallel.map( - (scores, deltas), arg_keys=('scores', 'deltas'), - func=bootstrapped_psds, max_jobs=num_jobs, - ground_truth=ground_truth, audio_durations=audio_durations, + scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth) + return bootstrap( + psds, scores=scores, deltas=deltas, + deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs, + deltas_fn_kwargs=dict( + ground_truth=ground_truth, dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, + cttc_threshold=cttc_threshold, time_decimals=time_decimals, + ), + eval_fn_kwargs=dict( + audio_durations=audio_durations,alpha_ct=alpha_ct, alpha_st=alpha_st, unit_of_time=unit_of_time, max_efpr=max_efpr, - time_decimals=time_decimals, - n_folds=n_folds, n_iterations=n_iterations, - num_jobs=num_jobs//batch_size, - ))) - if deltas is None: - deltas = intermediate_statistics_deltas( - scores=scores, ground_truth=ground_truth, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, - time_decimals=time_decimals, num_jobs=num_jobs, - ) - return bootstrap_from_deltas( - psds, deltas, - n_folds=n_folds, n_iterations=n_iterations, num_jobs=num_jobs, - scores=None, ground_truth=ground_truth, audio_durations=audio_durations, - dtc_threshold=dtc_threshold, gtc_threshold=gtc_threshold, - cttc_threshold=cttc_threshold, alpha_ct=alpha_ct, alpha_st=alpha_st, - unit_of_time=unit_of_time, max_efpr=max_efpr + ), + n_bootstrap_samples=n_bootstrap_samples, ) @@ -384,6 +355,8 @@ def multi_class_psd_roc_from_single_class_psd_rocs(single_class_psd_rocs, alpha_ def _unique_cummax_sort(tp_ratio, effective_fp_rate, *other, max_efpr=None): + # make cummax choose higher threshold when two ops have same fpr&tpr + tp_ratio, effective_fp_rate, *other = [values[::-1] for values in [tp_ratio, effective_fp_rate, *other]] tp_ratio, effective_fp_rate, *other = xsort(tp_ratio, effective_fp_rate, *other) cummax_indices = cummax(tp_ratio)[1] tp_ratio, effective_fp_rate, *other = [values[cummax_indices] for values in [tp_ratio, effective_fp_rate, *other]] diff --git a/tests/test_collar_based/test_collar_based_fscore.py b/tests/test_collar_based/test_collar_based_fscore.py index 64d441d..14a7245 100644 --- a/tests/test_collar_based/test_collar_based_fscore.py +++ b/tests/test_collar_based/test_collar_based_fscore.py @@ -57,7 +57,6 @@ def test_collar_based_fscore_vs_sed_eval(dataset, threshold, collar, num_jobs): np.testing.assert_almost_equal(r[key], r_sed_eval[key]) - @pytest.mark.parametrize("dataset", ["validation", "eval"]) @pytest.mark.parametrize( "threshold", @@ -81,7 +80,6 @@ def test_collar_based_fscore_vs_sed_eval(dataset, threshold, collar, num_jobs): @pytest.mark.parametrize("num_jobs", [1, 2]) def test_bootstrapped_collar_based_fscore(dataset, threshold, collar, num_jobs): offset_collar_rate = collar - time_decimals = 30 test_data_dir = package_dir / 'tests' / 'data' if not test_data_dir.exists(): io.download_test_data() @@ -92,8 +90,7 @@ def test_bootstrapped_collar_based_fscore(dataset, threshold, collar, num_jobs): threshold=threshold, onset_collar=collar, offset_collar=collar, offset_collar_rate=offset_collar_rate, - time_decimals=time_decimals, - num_jobs=num_jobs, n_folds=5, n_iterations=4, + num_jobs=num_jobs, n_bootstrap_samples=20, ) f_intervals = confidence_interval(f) for class_name, (f_mean, f_low, f_high) in f_intervals.items(): @@ -105,8 +102,35 @@ def test_bootstrapped_collar_based_fscore(dataset, threshold, collar, num_jobs): threshold=threshold, onset_collar=collar, offset_collar=collar, offset_collar_rate=offset_collar_rate, - time_decimals=time_decimals, num_jobs=num_jobs, ) for class_name, (f_mean, f_low, f_high) in f_intervals.items(): assert f_low < f[class_name] < f_high, (f_low, f_mean, f_high) + + +@pytest.mark.parametrize("dataset", ["validation", "eval"]) +@pytest.mark.parametrize("collar", [.2, .5]) +@pytest.mark.parametrize("num_jobs", [1, 2]) +def test_collar_based_best_fscore(dataset, collar, num_jobs): + offset_collar_rate = collar + test_data_dir = package_dir / 'tests' / 'data' + if not test_data_dir.exists(): + io.download_test_data() + + best_f, _, _, best_thresholds, _ = collar_based.best_fscore( + scores=test_data_dir / dataset / "scores", + ground_truth=test_data_dir / dataset / "ground_truth.tsv", + onset_collar=collar, offset_collar=collar, + offset_collar_rate=offset_collar_rate, + num_jobs=num_jobs, + ) + f_ref, *_ = collar_based.fscore( + scores=test_data_dir / dataset / "scores", + ground_truth=test_data_dir / dataset / "ground_truth.tsv", + threshold=best_thresholds, + onset_collar=collar, offset_collar=collar, + offset_collar_rate=offset_collar_rate, + num_jobs=num_jobs, + ) + for key in f_ref.keys(): + assert abs(best_f[key] - f_ref[key]) < 1e-6, key diff --git a/tests/test_intersection_based/test_intersection_based_fscore.py b/tests/test_intersection_based/test_intersection_based_fscore.py index 882cdea..1ea4d28 100644 --- a/tests/test_intersection_based/test_intersection_based_fscore.py +++ b/tests/test_intersection_based/test_intersection_based_fscore.py @@ -45,3 +45,29 @@ def test_intersection_based_fscore_vs_psds_eval(dataset, threshold, dtc_gtc_thre ) for key in f_ref.keys(): np.testing.assert_almost_equal(f[key], f_ref[key]) + + +@pytest.mark.parametrize("dataset", ["validation", "eval"]) +@pytest.mark.parametrize("dtc_gtc_threshold", [(.1, .1), (.7, .7)]) +@pytest.mark.parametrize("num_jobs", [1, 2]) +def test_intersection_based_best_fscore(dataset, dtc_gtc_threshold, num_jobs): + dtc_threshold, gtc_threshold = dtc_gtc_threshold + test_data_dir = package_dir / 'tests' / 'data' + if not test_data_dir.exists(): + io.download_test_data() + + best_f, _, _, best_thresholds, _ = intersection_based.best_fscore( + scores=test_data_dir / dataset / "scores", + ground_truth=test_data_dir / dataset / "ground_truth.tsv", + gtc_threshold=gtc_threshold, dtc_threshold=dtc_threshold, + num_jobs=num_jobs, + ) + f_ref, *_ = intersection_based.fscore( + scores=test_data_dir / dataset / "scores", + ground_truth=test_data_dir / dataset / "ground_truth.tsv", + threshold=best_thresholds, + gtc_threshold=gtc_threshold, dtc_threshold=dtc_threshold, + num_jobs=num_jobs, + ) + for key in f_ref.keys(): + assert abs(best_f[key] - f_ref[key]) < 1e-6, key diff --git a/tests/test_intersection_based/test_pipsds.py b/tests/test_intersection_based/test_pipsds.py index 749ed6d..7960aa2 100644 --- a/tests/test_intersection_based/test_pipsds.py +++ b/tests/test_intersection_based/test_pipsds.py @@ -132,7 +132,7 @@ def test_bootstrapped_median_filter_independent_psds(dataset, params, num_jobs): cttc_threshold=params['cttc_threshold'], alpha_ct=params['alpha_ct'], alpha_st=params['alpha_st'], unit_of_time='hour', max_efpr=100., time_decimals=6, - num_jobs=num_jobs, n_folds=5, n_iterations=4, + num_jobs=num_jobs, n_bootstrap_samples=20, ) (psds_mean, psds_low, psds_high) = confidence_interval(pipsds) assert psds_low < psds_mean < psds_high, (psds_low, psds_mean, psds_high) @@ -211,7 +211,7 @@ def test_bootstrapped_median_filter_independent_psds_prefiltered(dataset, params cttc_threshold=params['cttc_threshold'], alpha_ct=params['alpha_ct'], alpha_st=params['alpha_st'], unit_of_time='hour', max_efpr=100., time_decimals=6, - num_jobs=num_jobs, n_folds=5, n_iterations=4, + num_jobs=num_jobs, n_bootstrap_samples=20, ) postprocessing_functions = [ @@ -224,26 +224,6 @@ def test_bootstrapped_median_filter_independent_psds_prefiltered(dataset, params psds_rocs_prefiltered, single_class_psd_rocs_prefiltered, ) = intersection_based.bootstrapped_median_filter_independent_psds( scores=None, - scores_processed=[postprocessing_fn(scores) for postprocessing_fn in postprocessing_functions], - ground_truth=ground_truth, - audio_durations=audio_durations, - median_filter_lengths_in_sec=median_filter_lengths_in_sec, - dtc_threshold=params['dtc_threshold'], - gtc_threshold=params['gtc_threshold'], - cttc_threshold=params['cttc_threshold'], - alpha_ct=params['alpha_ct'], alpha_st=params['alpha_st'], - unit_of_time='hour', max_efpr=100., time_decimals=6, - num_jobs=num_jobs, n_folds=5, n_iterations=4, - ) - assert (np.array(pipsds_prefiltered) == np.array(pipsds)).all() - - ( - pipsds_prefiltered, single_class_pipsds_prefiltered, - pi_psd_roc_prefiltered, single_class_pi_psd_rocs_prefiltered, - psds_rocs_prefiltered, single_class_psd_rocs_prefiltered, - ) = intersection_based.bootstrapped_median_filter_independent_psds( - scores=None, - scores_processed=None, deltas=[ intersection_based.deltas_postprocessing( scores, @@ -263,6 +243,6 @@ def test_bootstrapped_median_filter_independent_psds_prefiltered(dataset, params cttc_threshold=params['cttc_threshold'], alpha_ct=params['alpha_ct'], alpha_st=params['alpha_st'], unit_of_time='hour', max_efpr=100., time_decimals=6, - num_jobs=num_jobs, n_folds=5, n_iterations=4, + num_jobs=num_jobs, n_bootstrap_samples=20, ) assert (np.array(pipsds_prefiltered) == np.array(pipsds)).all() diff --git a/tests/test_intersection_based/test_psds.py b/tests/test_intersection_based/test_psds.py index 32bef42..db9ffed 100644 --- a/tests/test_intersection_based/test_psds.py +++ b/tests/test_intersection_based/test_psds.py @@ -191,7 +191,7 @@ def test_bootstrapped_psds(dataset, params, num_jobs): cttc_threshold=params['cttc_threshold'], alpha_ct=params['alpha_ct'], alpha_st=params['alpha_st'], unit_of_time='hour', max_efpr=100., time_decimals=6, - num_jobs=num_jobs, n_folds=5, n_iterations=4, + num_jobs=num_jobs, n_bootstrap_samples=20, ) (psds_mean, psds_low, psds_high) = confidence_interval(psds) assert psds_low < psds_mean < psds_high, (psds_low, psds_mean, psds_high)