diff --git a/boiling_learning/automl/tuners.py b/boiling_learning/automl/tuners.py index c6eeab8a..7240179e 100644 --- a/boiling_learning/automl/tuners.py +++ b/boiling_learning/automl/tuners.py @@ -41,7 +41,7 @@ def set_state(self, state: Dict[str, Any]) -> None: self.stop_search = state['stop_search'] -class EarlyStoppingHyperbandOracle(ak.tuners.hyperband.HyperbandOracle): +class EarlyStoppingHyperbandOracle(kt.oracles.HyperbandOracle): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.stop_search = False @@ -227,6 +227,7 @@ def __init__( goal: Any, objective: str = 'val_loss', max_epochs: int = 100, + max_trials: int = 1000, factor: int = 3, seed: Optional[int] = None, hyperparameters: Optional[kt.HyperParameters] = None, @@ -244,6 +245,7 @@ def __init__( tune_new_entries=tune_new_entries, allow_new_entries=allow_new_entries, ) + oracle.max_trials = max_trials super().__init__(oracle=oracle, **kwargs) def on_epoch_end(