Allow BatchBALD to not consider some samples #62

Merged · 4 commits · Aug 26, 2024
125 changes: 77 additions & 48 deletions al_bench/factory.py
@@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Sequence
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast

import numpy as np
import scipy.stats
@@ -33,7 +33,12 @@ class ComputeCertainty:
all_certainty_types: List[str]
all_certainty_types = [confidence, margin, negative_entropy, batchbald]

def __init__(self, certainty_type, percentiles, cutoffs) -> None:
def __init__(
self,
certainty_type: Optional[Union[str, Sequence[str]]],
percentiles: Optional[Sequence[float]],
cutoffs: Optional[Union[Sequence[float], Mapping[str, Sequence[float]]]],
) -> None:
"""
certainty_type can be "confidence", "margin", "negative_entropy", "batchbald" or
a list (or tuple) of one or more of these. A value of None means all certainty
@@ -72,22 +77,29 @@ def __init__(self, certainty_type, percentiles, cutoffs) -> None:
if len(certainty_type) == 1 and isinstance(cutoffs, (list, tuple)):
cutoffs = {certainty_type[0]: cutoffs}
# If we have no information for a certainty type, default to no cutoffs.
cutoffs = {**{k: [] for k in certainty_type}, **cutoffs}
cutoffs = {
**{k: [] for k in certainty_type},
**cast(Mapping[str, Sequence[float]], cutoffs),
}
if not all(cut in certainty_type for cut in cutoffs.keys()):
raise ValueError(f"Something wrong with {cutoffs = }")
cutoffs = {key: [float(c) for c in value] for key, value in cutoffs.items()}
self.cutoffs: Mapping[str, Sequence[float]] = cutoffs

# Defaults. See
# Defaults. These can be overridden with setter methods.
self.batchbald_batch_size: int = 100
self.batchbald_num_samples: int = 10000
self.batchbald_excluded_samples: NDArray[np.int_] = np.array([], dtype=np.int_)

def set_batchbald_batch_size(self, batchbald_batch_size: int) -> None:
self.batchbald_batch_size = batchbald_batch_size

def set_batchbald_num_samples(self, batchbald_num_samples: int) -> None:
self.batchbald_num_samples = batchbald_num_samples

def set_batchbald_excluded_samples(self, exc_samples: NDArray[np.int_]) -> None:
self.batchbald_excluded_samples = exc_samples

def from_numpy_array(
self, predictions: NDArray[np.float_]
) -> Mapping[str, Mapping[str, Any]]:
@@ -116,6 +128,7 @@ def from_numpy_array(
# Normalize rows to sum to 1.0
predictions = predictions / np.sum(predictions, axis=-1, keepdims=True)
# Find the two largest values within each row.
partitioned: NDArray[np.float_]
partitioned = np.partition(predictions, -2, axis=-1)[..., -2:]

scores: Dict[str, NDArray[np.float_]] = dict()
@@ -138,50 +151,7 @@ def from_numpy_array(
# When certainty is determined by batchbald, we let batchbald_redux do our
# calculations.
if "batchbald" in self.certainty_type:
if len(predictions.shape) != 3:
raise ValueError(
"To compute statistics for batchbald,"
" `predictions` must be 3-dimensional,"
f" but its {len(predictions.shape)} dimensions"
f" are {predictions.shape}."
)

import torch
import batchbald_redux as bbald
import batchbald_redux.batchbald

# Indicate how many predictions we want batchbald to rate as uncertain. All
# the remaining samples will be rated as more certain via a constant.
batch_size: int = min(self.batchbald_batch_size, predictions.shape[0])
num_samples: int = self.batchbald_num_samples
epsilon = 7.8886090522101180541e-31 # 2**-100
predictions_copy = predictions.copy()
predictions_copy[predictions_copy < epsilon] = epsilon
log_predictions = np.log(predictions_copy)
with torch.no_grad():
bald: bbald.batchbald.CandidateBatch
bald = bbald.batchbald.get_batchbald_batch(
torch.from_numpy(log_predictions),
batch_size,
num_samples,
dtype=torch.double,
)
bald_indices = np.array(bald.indices)
bald_scores = np.array(bald.scores)
# For samples that are not ranked by batchbald, we will fall back to using
# (possibly shifted) confidence scores. Predictions will be averaged over
# Bayesian samples and then the most likely mean prediction score will be
# the confidence. These confidence scores are shifted, if needed, to ensure
# that the confidence scores are interpreted as more certain than the
# batchbald scores. Because the user labels the most uncertain (lowest)
# scores first, this puts the batchbald selections first, followed by the
# remaining selections once the batchbald selections are exhausted.
max_bald_scores = max(0.0, max(bald_scores))
scores["batchbald"] = np.full(predictions.shape[:-1], max_bald_scores)
scores["batchbald"] += np.max(
np.mean(predictions, axis=-2, keepdims=True), axis=-1
)
scores["batchbald"][bald_indices, :] = bald_scores[:, np.newaxis]
scores["batchbald"] = self.batchbald_scores(predictions)

# Report scores, percentile scores, and cutoff percentiles
response: Mapping[str, Mapping[str, Any]]
@@ -203,3 +173,62 @@
}

return response

def batchbald_scores(self, predictions: NDArray[np.float_]) -> NDArray[np.float_]:
if len(predictions.shape) != 3:
raise ValueError(
"To compute statistics for batchbald,"
" `predictions` must be 3-dimensional,"
f" but its {len(predictions.shape)} dimensions"
f" are {predictions.shape}."
)

# Exclude samples we've been asked to exclude
included_samples: NDArray[np.int_] = np.setdiff1d(
np.arange(len(predictions)),
self.batchbald_excluded_samples,
assume_unique=False,
)
included_predictions: NDArray[np.float_] = predictions[included_samples, ...]

# Convert prediction values to logarithms
epsilon: float = 7.8886090522101180541e-31 # 2**-100
included_predictions[included_predictions < epsilon] = epsilon
log_predictions: NDArray[np.float_] = np.log(included_predictions)

# Indicate how many predictions we want batchbald to rate as uncertain. All
# the remaining samples will be rated as more certain via a fallback approach.
batch_size: int = min(self.batchbald_batch_size, len(log_predictions))
num_samples: int = self.batchbald_num_samples

# Do the BatchBALD calculation
import torch
import batchbald_redux as bbald
import batchbald_redux.batchbald

with torch.no_grad():
bald: bbald.batchbald.CandidateBatch
bald = bbald.batchbald.get_batchbald_batch(
torch.from_numpy(log_predictions),
batch_size,
num_samples,
dtype=torch.double,
)
bald_indices: NDArray[np.int_] = np.array(bald.indices)
bald_scores: NDArray[np.float_] = np.array(bald.scores)

# For samples that are not ranked by batchbald, we will fall back to using
# (possibly shifted) confidence scores. Predictions will be averaged over
# Bayesian samples and then the most likely mean prediction score will be
# the confidence. These confidence scores are shifted, if needed, to ensure
# that the confidence scores are interpreted as more certain than the
# batchbald scores. Because the user labels the most uncertain (lowest)
# scores first, this puts the batchbald selections first, followed by the
# remaining selections once the batchbald selections are exhausted.
max_bald_scores: float = max(0.0, max(bald_scores))
response: NDArray[np.float_] = np.full(predictions.shape[:-1], max_bald_scores)
response += np.max(np.mean(predictions, axis=-2, keepdims=True), axis=-1)

# Overwrite the fallback scores for the samples that are selected by batchbald.
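# Note that bald.indices are positions within included_predictions, so they
# are mapped back to the original sample numbering via included_samples.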
response[included_samples[bald_indices], :] = bald_scores[:, np.newaxis]
return response
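
A tiny numeric illustration of the shift described in the comment above (the values are hypothetical): with BatchBALD scores of 1.3 and 0.9 for the selected samples, every unselected sample receives max(0.0, 1.3) plus its mean confidence, so all fallback scores land strictly above 1.3 and the BatchBALD picks are labeled first.

import numpy as np

# Hypothetical scores: 2 samples selected by BatchBALD, 3 left over.
bald_scores = np.array([1.3, 0.9])
mean_confidence = np.array([0.6, 0.8, 0.7])  # most likely mean predictions

shift = max(0.0, bald_scores.max())  # 1.3
fallback = shift + mean_confidence   # [1.9, 2.1, 2.0]

# Every fallback score exceeds every BatchBALD score, so the lowest (most
# uncertain) scores -- the BatchBALD picks -- are labeled first.
assert fallback.min() > bald_scores.max()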
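
For reference, a minimal usage sketch of the exclusion hook this PR adds. The class name, setter names, and array shape (samples × Bayesian draws × classes) follow the diff above; the data values, the percentiles and cutoffs arguments, and the response lookup at the end are illustrative assumptions.

import numpy as np
from al_bench.factory import ComputeCertainty

# Hypothetical Bayesian predictions: 50 samples, 8 Monte Carlo draws, 4 classes.
rng = np.random.default_rng(0)
predictions = rng.dirichlet(np.ones(4), size=(50, 8))

certainty = ComputeCertainty(
    "batchbald", percentiles=[25, 50, 75], cutoffs={"batchbald": []}
)
certainty.set_batchbald_batch_size(10)
# Samples that are already labeled; BatchBALD should not rank them again.
certainty.set_batchbald_excluded_samples(np.array([0, 3, 7], dtype=np.int_))

response = certainty.from_numpy_array(predictions)
batchbald_scores = response["batchbald"]["scores"]  # assumed response layout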