Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sweeps): optuna supports multi-objective optimization #28

Merged
merged 19 commits into from
Aug 16, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,26 @@ def setup_scheduler(scheduler: Scheduler, **kwargs):
parser.add_argument("--entity", type=str, default=kwargs.get("entity"))
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--name", type=str, default=f"job-{scheduler.__name__}")
parser.add_argument("--enable_git", action="store_true", default=False)
cli_args = parser.parse_args()

settings = {"job_name": cli_args.name}
if cli_args.enable_git:
settings.update({"disable_git": True})

run = wandb.init(
settings=settings,
project=cli_args.project,
entity=cli_args.entity,
)
config = run.config
args = config.get("sweep_args", {})

if not config.get("sweep_args", {}).get("sweep_id"):
# not a sweep, just finish the run and return
if not args or not args.get("sweep_id"):
# when the config has no sweep args, this is being run directly from CLI
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the user is running the wandb launch -j <job> here? Is this specifically done to create the job?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now this would be triggered by running python optuna_scheduler.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And a user would do that to create the job. Sg

# and not in a sweep. Just log the code and return
if not os.getenv("WANDB_DOCKER"):
# if not docker, log the code to a git or code artifact
run.log_code(root=os.path.dirname(__file__))
run.finish()
return

args = config.get("sweep_args", {})
if cli_args.num_workers: # override
kwargs.update({"num_workers": cli_args.num_workers})

Expand Down Expand Up @@ -161,6 +161,9 @@ def formatted_trials(self) -> str:

trial_strs = []
for trial in self.study.trials:
if not trial.values:
continue

run_id = trial.user_attrs["run_id"]
best: str = ""
if not self.is_multi_objective:
Expand All @@ -175,14 +178,12 @@ def formatted_trials(self) -> str:
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
else: # multi-objective optimization, only 1 metric logged in study
if not trial.values:
continue

if len(trial.values) != len(self._metric_defs):
wandb.termwarn(
gtarpenning marked this conversation as resolved.
Show resolved Hide resolved
f"{LOG_PREFIX}Number of trial metrics ({trial.values})"
f"{LOG_PREFIX}Number of logged metrics ({trial.values})"
" does not match number of metrics defined "
f"({self._metric_defs})"
f"({self._metric_defs}). Specify metrics for optimization"
" in the scheduler.settings.metrics portion of the sweep config"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would a user have done to create this case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not sure its possible, i included this error message in case. If we had a testing suite.... I'll play around with it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg

)
continue

Expand All @@ -191,10 +192,11 @@ def formatted_trials(self) -> str:
best += f"{metric.name} ({direction}):"
best += f"{round(val, 5)}, "

# trim trailing comma and space
best = best[:-2]
gtarpenning marked this conversation as resolved.
Show resolved Hide resolved
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, best: {best or 'None'}"
f"{trial.state.name}, best: {best}"
]

return "\n".join(trial_strs[-10:]) # only print out last 10
Expand Down Expand Up @@ -233,7 +235,10 @@ def _get_metric_names_and_directions(self) -> List[Metric]:
metric_defs += [Metric(name=metric["name"], direction=direction)]

if len(metric_defs) == 0:
raise SchedulerError("Optuna sweep missing metric")
raise SchedulerError(
"Zero metrics found in the top level 'metric' section "
"and multi-objective metric section scheduler.settings.metrics"
)

return metric_defs

Expand Down Expand Up @@ -276,7 +281,7 @@ def _load_optuna_classes(
mod, err = _get_module("optuna", filepath)
if not mod:
raise SchedulerError(
f"Failed to load optuna from path {filepath} " f" with error: {err}"
f"Failed to load optuna from path {filepath} with error: {err}"
)

# Set custom optuna trial creation method
Expand Down