
bootstrap updates
JanekEbb committed Jun 20, 2024
1 parent 83b501e commit e659dea
Showing 13 changed files with 391 additions and 375 deletions.
110 changes: 73 additions & 37 deletions sed_scores_eval/base_modules/bootstrap.py
@@ -1,46 +1,83 @@
import numpy as np
from collections import defaultdict
from sed_scores_eval.utils import parallel


+def bootstrap(
+        eval_fn, scores=None, deltas=None, deltas_fn=None,
+        n_bootstrap_samples=100, num_jobs=1,
+        deltas_fn_kwargs=None, eval_fn_kwargs=None,
+):
+    if deltas_fn_kwargs is None:
+        deltas_fn_kwargs = {}
+    if eval_fn_kwargs is None:
+        eval_fn_kwargs = {}
+    if deltas is None:
+        assert scores is not None
+        assert deltas_fn is not None
+        deltas = deltas_fn(
+            scores=scores, num_jobs=num_jobs, **deltas_fn_kwargs,
+        )
+    return bootstrap_from_deltas(
+        eval_fn, deltas,
+        n_bootstrap_samples=n_bootstrap_samples, num_jobs=num_jobs,
+        scores=None, **deltas_fn_kwargs, **eval_fn_kwargs,
+    )


def bootstrap_from_deltas(
-        metric_fn, deltas, *,
-        n_folds=5, n_iterations=20, num_jobs=1,
-        **metric_fn_kwargs,
+        eval_fn, deltas, *,
+        n_bootstrap_samples=100, num_jobs=1,
+        ground_truth=None, audio_durations=None,
+        **eval_fn_kwargs,
):
    if isinstance(deltas, (list, tuple)):
        audio_ids = sorted(deltas[0].keys())
    else:
        audio_ids = sorted(deltas.keys())
-    split_indices = np.linspace(0, len(audio_ids), n_folds+1).astype(int)
-    audio_id_fractions = []
-    for i in range(n_iterations):
-        np.random.RandomState(i).shuffle(audio_ids)
-        for j in range(n_folds):
-            audio_id_fractions.append(list(
-                audio_ids[:split_indices[j]] + audio_ids[split_indices[j+1]:]
-            ))
+    data_orig = {'deltas': deltas}
+    data_samples = {'deltas': []}
+    if ground_truth is not None:
+        data_orig['ground_truth'] = ground_truth
+        data_samples['ground_truth'] = []
+    if audio_durations is not None:
+        data_orig['audio_durations'] = audio_durations
+        data_samples['audio_durations'] = []

-    if isinstance(deltas, (list, tuple)):
-        deltas_fractions = [
-            [
-                {audio_id: delts[audio_id] for audio_id in audio_id_fraction}
-                for delts in deltas
-            ]
-            for audio_id_fraction in audio_id_fractions
-        ]
-    else:
-        deltas_fractions = [
-            {audio_id: deltas[audio_id] for audio_id in audio_id_fraction}
-            for audio_id_fraction in audio_id_fractions
-        ]
-    return list(zip(*parallel.map(
-        deltas_fractions, arg_keys='deltas',
-        func=metric_fn, max_jobs=num_jobs,
-        **metric_fn_kwargs,
-    )))


-def confidence_interval(bootstrapped_outputs, confidence=.9):
+    audio_ids_repeated = n_bootstrap_samples * audio_ids
+    np.random.RandomState(0).shuffle(audio_ids_repeated)
+    for i in range(n_bootstrap_samples):
+        for key in data_samples:
+            data_samples[key].append({})
+        for j, audio_id in enumerate(audio_ids_repeated[i*len(audio_ids):(i+1)*len(audio_ids)]):
+            for key in data_samples:
+                if isinstance(data_orig[key], (list, tuple)):
+                    if isinstance(data_samples[key][-1], dict):
+                        data_samples[key][-1] = [{} for _ in range(len(data_orig[key]))]
+                    for k in range(len(data_orig[key])):
+                        data_samples[key][-1][k][f'{audio_id}_bootstrap{i}_clip{j}'] = data_orig[key][k][audio_id]
+                else:
+                    data_samples[key][-1][f'{audio_id}_bootstrap{i}_clip{j}'] = data_orig[key][audio_id]

+    arg_keys = sorted(data_samples.keys())
+    args = [data_samples[key] for key in arg_keys]
+    ret = parallel.map(
+        args, arg_keys=arg_keys,
+        func=eval_fn, max_jobs=num_jobs,
+        **eval_fn_kwargs,
+    )
+    if isinstance(ret[0], tuple):
+        return list(zip(*ret))
+    return ret


+def _recursive_multiply(deltas, factor):
+    if isinstance(deltas, dict):
+        return {key: _recursive_multiply(deltas[key], factor) for key in deltas.keys()}
+    return deltas * factor


+def confidence_interval(bootstrapped_outputs, confidence=.9, axis=None):
    if isinstance(bootstrapped_outputs[0], dict):
        mean_low_high = {
            class_name: confidence_interval([
@@ -50,12 +87,11 @@ def confidence_interval(bootstrapped_outputs, confidence=.9):
        }
        return mean_low_high

-    mean = np.mean(bootstrapped_outputs)
+    mean = np.mean(bootstrapped_outputs, axis=axis)
    low = np.percentile(
-        bootstrapped_outputs, ((1 - confidence) / 2) * 100
+        bootstrapped_outputs, ((1 - confidence) / 2) * 100, axis=axis,
    )
    high = np.percentile(
-        bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100
+        bootstrapped_outputs, (confidence + ((1 - confidence) / 2)) * 100, axis=axis,
    )
-    return float(mean), float(low), float(high)
+    return mean, low, high
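
For orientation: the old bootstrap_from_deltas evaluated the metric on n_folds * n_iterations leave-one-fold-out subsets, while the new code evaluates eval_fn on n_bootstrap_samples datasets built by drawing clips with replacement, keying each draw as f'{audio_id}_bootstrap{i}_clip{j}' so repeated clips do not collide, and confidence_interval then summarizes the bootstrapped outputs with a mean and percentile bounds. The sketch below is not the library code; it is a minimal illustration of that idea, assuming a scalar-valued eval_fn and a plain per-clip deltas dict, and it uses np.random.choice instead of the shuffled repetition used in the diff.

import numpy as np

def bootstrap_sketch(eval_fn, deltas, n_bootstrap_samples=100, seed=0):
    """Evaluate eval_fn on datasets obtained by resampling clips with replacement."""
    audio_ids = sorted(deltas.keys())
    rng = np.random.RandomState(seed)
    outputs = []
    for i in range(n_bootstrap_samples):
        # draw as many clips as the original dataset has, with replacement,
        # and give every draw a unique key so repeated clips do not collide
        sampled = rng.choice(audio_ids, size=len(audio_ids), replace=True)
        deltas_i = {
            f'{audio_id}_bootstrap{i}_clip{j}': deltas[audio_id]
            for j, audio_id in enumerate(sampled)
        }
        outputs.append(eval_fn(deltas_i))
    return outputs

def confidence_interval_sketch(bootstrapped_outputs, confidence=.9):
    """Mean plus lower/upper percentile bounds over bootstrapped scalar outputs."""
    low = np.percentile(bootstrapped_outputs, ((1 - confidence) / 2) * 100)
    high = np.percentile(bootstrapped_outputs, (confidence + (1 - confidence) / 2) * 100)
    return float(np.mean(bootstrapped_outputs)), float(low), float(high)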
35 changes: 14 additions & 21 deletions sed_scores_eval/base_modules/io.py
@@ -50,8 +50,7 @@ def parse_scores(scores):


def parse_ground_truth(
-        ground_truth, *,
-        tagging=False, audio_ids=None, additional_ids_ok=False
+        ground_truth, *, tagging=False, audio_ids=None,
):
    if not isinstance(ground_truth, (dict, str, Path)):
        raise ValueError(
@@ -66,15 +65,12 @@ def parse_ground_truth(
        ground_truth = read_ground_truth_events(ground_truth)
    if not tagging:
        assert_non_connected_events(ground_truth)
-    if audio_ids is not None:
-        if additional_ids_ok:
-            ground_truth = {key: ground_truth[key] for key in audio_ids}
-        elif not (ground_truth.keys() == set(audio_ids)):
-            raise ValueError(
-                f'ground_truth audio ids do not match audio_ids. '
-                f'Missing ids: {set(audio_ids) - ground_truth.keys()}. '
-                f'Additional ids: {ground_truth.keys() - set(audio_ids)}.'
-            )
+    if audio_ids is not None and not (ground_truth.keys() == set(audio_ids)):
+        raise ValueError(
+            f'ground_truth audio ids do not match audio_ids. '
+            f'Missing ids: {set(audio_ids) - ground_truth.keys()}. '
+            f'Additional ids: {ground_truth.keys() - set(audio_ids)}.'
+        )
    return ground_truth


@@ -91,7 +87,7 @@ def assert_non_connected_events(ground_truth):
            current_offset = event[1]


-def parse_audio_durations(audio_durations, *, audio_ids=None, additional_ids_ok=False):
+def parse_audio_durations(audio_durations, *, audio_ids=None):
    if not isinstance(audio_durations, (dict, str, Path)):
        raise ValueError(
            f'audio_durations must be dict, str or Path but '
@@ -101,15 +97,12 @@ def parse_audio_durations(audio_durations, *, audio_ids=None, additional_ids_ok=False):
    audio_durations = Path(audio_durations)
    assert audio_durations.is_file(), audio_durations
    audio_durations = read_audio_durations(audio_durations)
-    if audio_ids is not None:
-        if additional_ids_ok:
-            audio_durations = {key: audio_durations[key] for key in audio_ids}
-        elif not (audio_durations.keys() == set(audio_ids)):
-            raise ValueError(
-                f'audio_durations audio ids do not match audio_ids. '
-                f'Missing ids: {set(audio_ids) - audio_durations.keys()}. '
-                f'Additional ids: {audio_durations.keys() - set(audio_ids)}.'
-            )
+    if audio_ids is not None and not (audio_durations.keys() == set(audio_ids)):
+        raise ValueError(
+            f'audio_durations audio ids do not match audio_ids. '
+            f'Missing ids: {set(audio_ids) - audio_durations.keys()}. '
+            f'Additional ids: {audio_durations.keys() - set(audio_ids)}.'
+        )
    return audio_durations


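Since additional_ids_ok is removed, parse_ground_truth and parse_audio_durations now require the provided keys to match audio_ids exactly, and callers that used to rely on additional_ids_ok=True must align the ids themselves (as the new bootstrap code does by building per-sample dicts). A small illustration with made-up clip ids, not taken from this commit:

# Illustrative only: the clip ids and event tuples below are invented.
ground_truth = {
    'clip_a': [(0.5, 2.0, 'dog')],
    'clip_b': [(1.0, 3.0, 'cat')],
    'clip_c': [],
}
audio_ids = ['clip_a', 'clip_b']

# Before: parse_ground_truth(ground_truth, audio_ids=audio_ids, additional_ids_ok=True)
# silently dropped 'clip_c'. Now the caller subsets explicitly so that the
# exact-match check passes:
ground_truth = {audio_id: ground_truth[audio_id] for audio_id in audio_ids}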
3 changes: 1 addition & 2 deletions sed_scores_eval/collar_based/intermediate_statistics.py
@@ -159,8 +159,7 @@ def accumulated_intermediate_statistics(
        )
    else:
        audio_ids = list(deltas.keys())
-    ground_truth = parse_ground_truth(
-        ground_truth, audio_ids=audio_ids, additional_ids_ok=True)
+    ground_truth = parse_ground_truth(ground_truth, audio_ids=audio_ids)
    return accumulated_intermediate_statistics_from_deltas(deltas, ground_truth), audio_ids


114 changes: 66 additions & 48 deletions sed_scores_eval/collar_based/precision_recall.py
@@ -1,10 +1,11 @@
-from sed_scores_eval.base_modules.bootstrap import bootstrap_from_deltas
-from sed_scores_eval.utils import parallel
+import numpy as np
+from sed_scores_eval.base_modules.io import parse_inputs
+from sed_scores_eval.base_modules.bootstrap import bootstrap
from sed_scores_eval.base_modules.precision_recall import (
    single_fscore_from_intermediate_statistics,
    best_fscore_from_intermediate_statistics,
    precision_recall_curve_from_intermediate_statistics,
-    fscore_curve_from_intermediate_statistics
+    fscore_curve_from_intermediate_statistics,
)
from sed_scores_eval.collar_based.intermediate_statistics import (
    accumulated_intermediate_statistics, intermediate_statistics_deltas,
@@ -188,20 +189,6 @@ def fscore(
        'n_ref' (int): number of ground truth events
    """
-    if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)):
-        # batch input
-        batch_size = [len(v) for v in [scores, deltas] if v is not None][0]
-        f, p, r, stats = list(zip(*parallel.map(
-            (scores, deltas), arg_keys=('scores', 'deltas'),
-            func=fscore, max_jobs=num_jobs,
-            ground_truth=ground_truth, threshold=threshold,
-            onset_collar=onset_collar, offset_collar=offset_collar,
-            offset_collar_rate=offset_collar_rate, beta=beta,
-            return_onset_offset_dist_sum=return_onset_offset_dist_sum,
-            time_decimals=time_decimals,
-            num_jobs=num_jobs//batch_size,
-        )))
-        return f, p, r, stats
    intermediate_stats, _ = accumulated_intermediate_statistics(
        scores=scores, ground_truth=ground_truth, deltas=deltas,
        onset_collar=onset_collar, offset_collar=offset_collar,
Expand All @@ -218,7 +205,7 @@ def bootstrapped_fscore(
scores, ground_truth, threshold, *, deltas=None,
onset_collar, offset_collar, offset_collar_rate=0., beta=1.,
return_onset_offset_dist_sum=False, time_decimals=6,
n_folds=5, n_iterations=20, num_jobs=1,
n_bootstrap_samples=100, num_jobs=1,
):
"""
Expand All @@ -233,44 +220,27 @@ def bootstrapped_fscore(
beta:
return_onset_offset_dist_sum:
time_decimals:
n_folds:
n_iterations:
n_bootstrap_samples:
num_jobs:
Returns:
"""
if isinstance(scores, (list, tuple)) or isinstance(deltas, (list, tuple)):
# batch input
batch_size = [len(v) for v in [scores, deltas] if v is not None][0]
f, p, r, stats = list(zip(*parallel.map(
(scores, deltas), arg_keys=('scores', 'deltas'),
func=bootstrapped_fscore, max_jobs=num_jobs,
ground_truth=ground_truth, threshold=threshold,
onset_collar=onset_collar, offset_collar=offset_collar,
offset_collar_rate=offset_collar_rate, beta=beta,
return_onset_offset_dist_sum=return_onset_offset_dist_sum,
time_decimals=time_decimals,
n_folds=n_folds, n_iterations=n_iterations,
num_jobs=num_jobs//batch_size,
)))
return f, p, r, stats
if deltas is None:
deltas = intermediate_statistics_deltas(
scores=scores, ground_truth=ground_truth,
scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
return bootstrap(
fscore, scores=scores, deltas=deltas,
deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
deltas_fn_kwargs=dict(
ground_truth=ground_truth,
onset_collar=onset_collar, offset_collar=offset_collar,
offset_collar_rate=offset_collar_rate,
return_onset_offset_dist_sum=return_onset_offset_dist_sum,
time_decimals=time_decimals, num_jobs=num_jobs,
)
return bootstrap_from_deltas(
fscore, deltas,
n_folds=n_folds, n_iterations=n_iterations, num_jobs=num_jobs,
threshold=threshold, scores=None, ground_truth=ground_truth,
onset_collar=onset_collar, offset_collar=offset_collar,
offset_collar_rate=offset_collar_rate, beta=beta,
return_onset_offset_dist_sum=return_onset_offset_dist_sum,
time_decimals=time_decimals,
time_decimals=time_decimals,
),
eval_fn_kwargs=dict(
threshold=threshold, beta=beta,
),
n_bootstrap_samples=n_bootstrap_samples,
)


@@ -342,3 +312,51 @@ def best_fscore(
        intermediate_stats, beta=beta,
        min_precision=min_precision, min_recall=min_recall,
    )


+def bootstrapped_fscore_curve(
+        scores, ground_truth, *, deltas=None,
+        onset_collar, offset_collar, offset_collar_rate=0.,
+        beta=1., time_decimals=6,
+        n_bootstrap_samples=100, num_jobs=1,
+):
+    """
+    Args:
+        scores:
+        ground_truth:
+        deltas:
+        onset_collar:
+        offset_collar:
+        offset_collar_rate:
+        beta:
+        time_decimals:
+        n_bootstrap_samples:
+        num_jobs:
+    Returns:
+    """
+    scores, ground_truth, audio_ids = parse_inputs(scores, ground_truth)
+    return bootstrap(
+        fscore_curve, scores=scores, deltas=deltas,
+        deltas_fn=intermediate_statistics_deltas, num_jobs=num_jobs,
+        deltas_fn_kwargs=dict(
+            ground_truth=ground_truth,
+            onset_collar=onset_collar, offset_collar=offset_collar,
+            offset_collar_rate=offset_collar_rate,
+            time_decimals=time_decimals,
+        ),
+        eval_fn_kwargs=dict(
+            beta=beta,
+        ),
+        n_bootstrap_samples=n_bootstrap_samples,
+    )


+def _recursive_get_item(stats, idx):
+    if isinstance(stats, dict):
+        return {key: _recursive_get_item(stats[key], idx) for key in stats}
+    if np.isscalar(stats):
+        return stats
+    return stats[idx]
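
A plausible way to use the new bootstrapped_fscore together with confidence_interval, based only on the signatures visible in this diff; the scores object, the ground-truth path and the collar settings below are placeholders rather than part of the commit:

from sed_scores_eval.collar_based.precision_recall import bootstrapped_fscore
from sed_scores_eval.base_modules.bootstrap import confidence_interval

scores = ...  # placeholder: dict audio_id -> SED scores, as produced elsewhere in the package
f, p, r, stats = bootstrapped_fscore(
    scores, 'ground_truth.tsv', 0.5,  # scores, ground_truth, threshold (placeholders)
    onset_collar=.2, offset_collar=.2, offset_collar_rate=.2,
    n_bootstrap_samples=100, num_jobs=4,
)
# f, p and r each hold one result per bootstrap sample; confidence_interval
# recurses over the per-class dicts and returns (mean, low, high) per class.
f_mean_low_high = confidence_interval(f, confidence=.9)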
6 changes: 3 additions & 3 deletions sed_scores_eval/intersection_based/__init__.py
@@ -4,10 +4,10 @@
    psds_from_psd_roc, multi_class_psd_roc_from_single_class_psd_rocs,
)
from .pipsds import (
-    postprocessing_independent_psd_roc, postprocessing_independent_psds,
-    bootstrapped_postprocessing_independent_psds, deltas_postprocessing,
+    postprocessing_independent_psd_roc_from_postprocessed_scores, postprocessing_independent_psds_from_postprocessed_scores,
+    bootstrapped_postprocessing_independent_psds_from_postprocessed_scores, deltas_postprocessing,
    median_filter_independent_psds, bootstrapped_median_filter_independent_psds,
)
-from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore
+from .precision_recall import precision_recall_curve, fscore_curve, fscore, best_fscore, bootstrapped_fscore
from .error_rate import error_rate_curve, error_rate, best_error_rate
from . import reference
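
With the updated exports, the bootstrapped intersection-based F-score can be imported directly from the subpackage; a minimal import sketch (call sites assumed, not shown in this commit):

from sed_scores_eval.intersection_based import fscore, bootstrapped_fscore
from sed_scores_eval.intersection_based import bootstrapped_median_filter_independent_psds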
@@ -174,8 +174,7 @@ def accumulated_intermediate_statistics(
        )
    else:
        audio_ids = list(deltas.keys())
-    ground_truth = parse_ground_truth(
-        ground_truth, audio_ids=audio_ids, additional_ids_ok=True)
+    ground_truth = parse_ground_truth(ground_truth, audio_ids=audio_ids)
    return accumulated_intermediate_statistics_from_deltas(deltas, ground_truth), audio_ids


Diffs for the remaining changed files are not shown.
