Skip to content

Commit

Permalink
feat(sweeps): optuna scheduler supports more distribution commands
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Aug 16, 2023
1 parent b8cdabb commit feb117d
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,26 +650,40 @@ def _make_trial(self) -> Tuple[Dict[str, Any], optuna.Trial]:
config[param]["value"] = trial.suggest_categorical(
param, [extras["value"]]
)
elif extras.get("distribution"):
raise SchedulerError(
"Distributions are deprecated. Please provide 'step' or 'log' with 'min' and 'max'."
)
elif type(extras.get("min")) == float:

Check failure on line 657 in jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E721)

jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py:657:18: E721 Do not compare types, use `isinstance()`

Check failure on line 657 in jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E721)

jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py:657:18: E721 Do not compare types, use `isinstance()`
if not extras.get("max"):
raise SchedulerError(
"Error converting config. 'min' requires 'max'"
)
log = "log" in param
log = extras.get("log", False)
step = extras.get("step")
config[param]["value"] = trial.suggest_float(
param, extras["min"], extras["max"], log=log
param, extras["min"], extras["max"], log=log, step=step
)
elif type(extras.get("min")) == int:

Check failure on line 667 in jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E721)

jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py:667:18: E721 Do not compare types, use `isinstance()`

Check failure on line 667 in jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E721)

jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py:667:18: E721 Do not compare types, use `isinstance()`
if not extras.get("max"):
raise SchedulerError(
"Error converting config. 'min' requires 'max'"
)
log = "log" in param
config[param]["value"] = trial.suggest_int(
param, extras["min"], extras["max"], log=log
)
log = extras.get("log", False)
step = extras.get("step")
if step:
config[param]["value"] = trial.suggest_int(
param, extras["min"], extras["max"], log=log, step=step
)
else:
config[param]["value"] = trial.suggest_int(
param, extras["min"], extras["max"], log=log
)
else:
logger.debug(f"Unknown parameter type: param={param}, val={extras}")
raise SchedulerError(
f"Error converting config. Unknown parameter type: param={param}, val={extras}"
)
return config, trial

def _make_trial_from_objective(self) -> Tuple[Dict[str, Any], optuna.Trial]:
Expand Down

0 comments on commit feb117d

Please sign in to comment.