File tree Expand file tree Collapse file tree 2 files changed +25
-1
lines changed
deepsensor/active_learning Expand file tree Collapse file tree 2 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -414,6 +414,10 @@ def __call__(
414414 )
415415
416416 self .min_or_max = acquisition_fn .min_or_max
417+ if self .min_or_max not in ["min" , "max" ]:
418+ raise ValueError (
419+ f"min_or_max must be either 'min' or 'max', got { self .min_or_max } ."
420+ )
417421
418422 if isinstance (tasks , Task ):
419423 tasks = [tasks ]
Original file line number Diff line number Diff line change 1515 OracleMAE ,
1616 OracleRMSE ,
1717 OracleMarginalNLL ,
18- OracleJointNLL ,
18+ OracleJointNLL , AcquisitionFunction ,
1919)
2020from deepsensor .active_learning .algorithms import GreedyAlgorithm
2121
@@ -258,3 +258,23 @@ def test_greedy_alg_with_oracle_acquisition_fn_without_task_loader_raises_value_
258258
259259 with self .assertRaises (ValueError ):
260260 _ = alg (acquisition_fn , task )
261+
262+ def assert_acquisition_fn_without_min_or_max_raises_error (
263+ self ,
264+ ):
265+ class DummyAcquisitionFn (AcquisitionFunction ):
266+ """Dummy acquisition function that doesn't set min or max"""
267+ def __call__ (self , ** kwargs ):
268+ return np .zeros (1 )
269+
270+ acquisition_fn = DummyAcquisitionFn (self .model )
271+
272+ X_s = self .ds_raw .air
273+
274+ with self .assertRaises (ValueError ):
275+ alg = GreedyAlgorithm (
276+ model = self .model ,
277+ X_t = X_s ,
278+ X_s = X_s ,
279+ N_new_context = 2 ,
280+ )
You can’t perform that action at this time.
0 commit comments