Skip to content

Commit 54e9fe1

Browse files
committed
Raise ValueError if acquisition fn doesn't set min_or_max attr
1 parent 1a813a5 commit 54e9fe1

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

deepsensor/active_learning/algorithms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ def __call__(
414414
)
415415

416416
self.min_or_max = acquisition_fn.min_or_max
417+
if self.min_or_max not in ["min", "max"]:
418+
raise ValueError(
419+
f"min_or_max must be either 'min' or 'max', got {self.min_or_max}."
420+
)
417421

418422
if isinstance(tasks, Task):
419423
tasks = [tasks]

tests/test_active_learning.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
OracleMAE,
1616
OracleRMSE,
1717
OracleMarginalNLL,
18-
OracleJointNLL,
18+
OracleJointNLL, AcquisitionFunction,
1919
)
2020
from deepsensor.active_learning.algorithms import GreedyAlgorithm
2121

@@ -258,3 +258,23 @@ def test_greedy_alg_with_oracle_acquisition_fn_without_task_loader_raises_value_
258258

259259
with self.assertRaises(ValueError):
260260
_ = alg(acquisition_fn, task)
261+
262+
def assert_acquisition_fn_without_min_or_max_raises_error(
263+
self,
264+
):
265+
class DummyAcquisitionFn(AcquisitionFunction):
266+
"""Dummy acquisition function that doesn't set min or max"""
267+
def __call__(self, **kwargs):
268+
return np.zeros(1)
269+
270+
acquisition_fn = DummyAcquisitionFn(self.model)
271+
272+
X_s = self.ds_raw.air
273+
274+
with self.assertRaises(ValueError):
275+
alg = GreedyAlgorithm(
276+
model=self.model,
277+
X_t=X_s,
278+
X_s=X_s,
279+
N_new_context=2,
280+
)

0 commit comments

Comments
 (0)