diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 514a41d..0821672 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -650,26 +650,40 @@ def _make_trial(self) -> Tuple[Dict[str, Any], optuna.Trial]: config[param]["value"] = trial.suggest_categorical( param, [extras["value"]] ) + elif extras.get("distribution"): + raise SchedulerError( + "Distributions are deprecated. Please provide 'step' or 'log' with 'min' and 'max'." + ) elif type(extras.get("min")) == float: if not extras.get("max"): raise SchedulerError( "Error converting config. 'min' requires 'max'" ) - log = "log" in param + log = extras.get("log", False) + step = extras.get("step") config[param]["value"] = trial.suggest_float( - param, extras["min"], extras["max"], log=log + param, extras["min"], extras["max"], log=log, step=step ) elif type(extras.get("min")) == int: if not extras.get("max"): raise SchedulerError( "Error converting config. 'min' requires 'max'" ) - log = "log" in param - config[param]["value"] = trial.suggest_int( - param, extras["min"], extras["max"], log=log - ) + log = extras.get("log", False) + step = extras.get("step") + if step: + config[param]["value"] = trial.suggest_int( + param, extras["min"], extras["max"], log=log, step=step + ) + else: + config[param]["value"] = trial.suggest_int( + param, extras["min"], extras["max"], log=log + ) else: logger.debug(f"Unknown parameter type: param={param}, val={extras}") + raise SchedulerError( + f"Error converting config. Unknown parameter type: param={param}, val={extras}" + ) return config, trial def _make_trial_from_objective(self) -> Tuple[Dict[str, Any], optuna.Trial]: