From e53d23d54781bb76fc727c0aca921a9ff5f3051e Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Wed, 16 Aug 2023 09:12:27 -0700 Subject: [PATCH] feat(sweeps): optuna scheduler supports more distribution commands (#29) --- .../optuna_scheduler/optuna_scheduler.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 514a41d..27286a4 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -650,26 +650,40 @@ def _make_trial(self) -> Tuple[Dict[str, Any], optuna.Trial]: config[param]["value"] = trial.suggest_categorical( param, [extras["value"]] ) - elif type(extras.get("min")) == float: + elif extras.get("distribution"): + raise SchedulerError( + "Distributions are deprecated. Please provide 'step' or 'log' with 'min' and 'max'." + ) + elif isinstance(extras.get("min"), float): if not extras.get("max"): raise SchedulerError( "Error converting config. 'min' requires 'max'" ) - log = "log" in param + log = extras.get("log", False) + step = extras.get("step") config[param]["value"] = trial.suggest_float( - param, extras["min"], extras["max"], log=log + param, extras["min"], extras["max"], log=log, step=step ) - elif type(extras.get("min")) == int: + elif isinstance(extras.get("min"), int): if not extras.get("max"): raise SchedulerError( "Error converting config. 'min' requires 'max'" ) - log = "log" in param - config[param]["value"] = trial.suggest_int( - param, extras["min"], extras["max"], log=log - ) + log = extras.get("log", False) + step = extras.get("step") + if step: + config[param]["value"] = trial.suggest_int( + param, extras["min"], extras["max"], log=log, step=step + ) + else: + config[param]["value"] = trial.suggest_int( + param, extras["min"], extras["max"], log=log + ) else: logger.debug(f"Unknown parameter type: param={param}, val={extras}") + raise SchedulerError( + f"Error converting config. Unknown parameter type: param={param}, val={extras}" + ) return config, trial def _make_trial_from_objective(self) -> Tuple[Dict[str, Any], optuna.Trial]: