Skip to content

Commit

Permalink
Merge pull request #59 from Leengit/Bald_to_BALD
Browse files Browse the repository at this point in the history
STYLE: Use BALD instead of Bald, per arXiv:1906.08158
  • Loading branch information
Leengit authored May 9, 2024
2 parents 9dd6b77 + db6643a commit 8500803
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion al_bench/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ComputeCertainty:
negative_entropy: str = "negative_entropy"
batchbald: str = "batchbald"
all_certainty_types: List[str]
all_certainty_types = ["confidence", "margin", "negative_entropy", "batchbald"]
all_certainty_types = [confidence, margin, negative_entropy, batchbald]

def __init__(self, certainty_type, percentiles, cutoffs) -> None:
"""
Expand Down
19 changes: 12 additions & 7 deletions al_bench/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,9 @@ def select_next_indices(
return predict_order


class BaldStrategyHandler(GenericStrategyHandler):
class BALDStrategyHandler(GenericStrategyHandler):
def __init__(self) -> None:
super(BaldStrategyHandler, self).__init__()
super(BALDStrategyHandler, self).__init__()

def select_next_indices(
self,
Expand All @@ -654,17 +654,17 @@ def select_next_indices(
) -> NDArray[np.int_]:
"""
Select new examples to be labeled by the expert. This choses the unlabeled
examples based upon the BALD criterion. (See also BatchBaldStrategyHandler.)
examples based upon the BALD criterion. (See also BatchBALDStrategyHandler.)
"""
print(f"self.predictions.shape = {self.predictions.shape}")
raise NotImplementedError(
"BaldStrategyHandler::select_next_indices is not yet implemented."
"BALDStrategyHandler::select_next_indices is not yet implemented."
)


class BatchBaldStrategyHandler(GenericStrategyHandler):
class BatchBALDStrategyHandler(GenericStrategyHandler):
def __init__(self) -> None:
super(BatchBaldStrategyHandler, self).__init__()
super(BatchBALDStrategyHandler, self).__init__()

def select_next_indices(
self,
Expand All @@ -673,7 +673,7 @@ def select_next_indices(
) -> NDArray[np.int_]:
"""
Select new examples to be labeled by the expert. This chooses the unlabeled
examples based upon the Batch-BALD criterion. (See also BaldStrategyHandler.)
examples based upon the BatchBALD criterion. (See also BALDStrategyHandler.)
"""
if validation_indices is None:
validation_indices = np.array((), dtype=np.int64)
Expand Down Expand Up @@ -714,3 +714,8 @@ def select_next_indices(
dtype=torch.double,
)
return available_indices[candidates.indices]


# Support legacy names for now
BaldStrategyHandler = BALDStrategyHandler
BatchBaldStrategyHandler = BatchBALDStrategyHandler
4 changes: 2 additions & 2 deletions example/BayesianExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"\n",
"<p>Each ALBench simulation requires the user to specifiy a DatasetHandler, a ModelHandler, and a StrategyHandler. A <b>DatasetHandler</b> provides an interface to the other handlers that is agnostic to the format of the dataset. A <b>ModelHandler</b> provides an interface that is agnostic to such things as whether the model is implemented in TensorFlow or PyTorch. A <b>StrategyHandler</b> implements a specific active learning strategy: for non-Bayesian models this includes random selection, least confidence, least margin, and maximum entropy; for Bayesian models this includes BatchBALD.\n",
" \n",
"The present implementation of the BatchBALD active learning strategy leans heavily upon use of the <a href=\"https://github.com/BlackHC/batchbald_redux/blob/master/README.md\">batchbald_redux Python package</a>. Programmers who wish to add additional strategies to ALBench may benefit from looking at the implementation of the BatchBaldStrategyHandler within <a href=\"https://github.com/DigitalSlideArchive/ALBench/blob/main/al_bench/strategy.py\">al_bench/strategy.py</a>."
"The present implementation of the BatchBALD active learning strategy leans heavily upon use of the <a href=\"https://github.com/BlackHC/batchbald_redux/blob/master/README.md\">batchbald_redux Python package</a>. Programmers who wish to add additional strategies to ALBench may benefit from looking at the implementation of the BatchBALDStrategyHandler within <a href=\"https://github.com/DigitalSlideArchive/ALBench/blob/main/al_bench/strategy.py\">al_bench/strategy.py</a>."
]
},
{
Expand Down Expand Up @@ -1754,7 +1754,7 @@
"\n",
"for name, my_strategy_handler in (\n",
" # (\"BALD\", alb.strategy.BaldStrategyHandler()),\n",
" (\"BatchBALD\", alb.strategy.BatchBaldStrategyHandler()),\n",
" (\"BatchBALD\", alb.strategy.BatchBALDStrategyHandler()),\n",
"):\n",
" print(f\"=== Begin Strategy {repr(name)} at {datetime.now()} ===\")\n",
" my_strategy_handler.set_dataset_handler(my_dataset_handler)\n",
Expand Down
4 changes: 3 additions & 1 deletion test/test_0120_bayesian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def test_0120_bayesian_model() -> None:
name: str
my_strategy_handler: alb.strategy.AbstractStrategyHandler
for name, my_strategy_handler in (
("BatchBALD", alb.strategy.BatchBaldStrategyHandler()),
# ("BatchBALD", alb.strategy.BatchBALDStrategyHandler()),
# Deliberately use legacy name, to test that it is still supported.
("BatchBald", alb.strategy.BatchBaldStrategyHandler()),
):
my_strategy_handler.set_dataset_handler(my_dataset_handler)
my_strategy_handler.set_model_handler(my_model_handler)
Expand Down

0 comments on commit 8500803

Please sign in to comment.