feat(sweeps): optuna supports multi-objective optimization
gtarpenning committed Jun 26, 2023
1 parent 7104414 commit e7449ca
Showing 1 changed file with 107 additions and 52 deletions:
jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
@@ -1,3 +1,4 @@
import argparse
import base64
import logging
import os
@@ -10,7 +11,6 @@
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

import argparse
import click
import optuna
import wandb
@@ -19,8 +19,7 @@
from wandb.apis.public import QueuedRun, Run
from wandb.sdk.artifacts.public_artifact import Artifact
from wandb.sdk.launch.sweeps import SchedulerError
from wandb.sdk.launch.sweeps.scheduler import Scheduler, SweepRun, RunState

from wandb.sdk.launch.sweeps.scheduler import RunState, Scheduler, SweepRun

logger = logging.getLogger(__name__)
optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -130,6 +129,16 @@ def __init__(
if not self._optuna_config:
self._optuna_config = self._wandb_run.config.get("settings", {})

if self._sweep_config.get("metric"):
self.metric_names = [self._sweep_config["metric"]["name"]]
else: # multi-objective optimization
for metric in self._optuna_config.get("metrics", []):
if not metric.get("name"):
raise SchedulerError("Optuna metric missing name")
if not metric.get("goal"):
raise SchedulerError("Optuna metric missing goal")
self.metric_names += [metric["name"]]

        # incremented when a run's metric is misconfigured; stop the sweep after 3 consecutive failures
self._num_misconfigured_runs = 0
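
The settings surface that feeds self._optuna_config is not shown in this diff, but a hypothetical block consistent with the parsing above could look like the following sketch (the metric names are invented for illustration; the "metrics", "name", and "goal" keys come from the code):

optuna_settings = {
    "metrics": [
        {"name": "val_loss", "goal": "minimize"},
        {"name": "val_accuracy", "goal": "maximize"},
    ],
}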

@@ -146,12 +155,22 @@ def study_name(self) -> str:
optuna_study_name: str = self.study.study_name
return optuna_study_name

@property
def is_multi_objective(self) -> bool:
return len(self.metric_names) > 1

@property
def study_string(self) -> str:
msg = f"{LOG_PREFIX}{'Loading' if self._wandb_run.resumed else 'Creating'}"
msg += f" optuna study: {self.study_name} "
msg += f"[storage:{self.study._storage.__class__.__name__}"
msg += f", direction:{self.study.direction.name.capitalize()}"
if not self.is_multi_objective:
msg += f", direction: {self.study.direction.name.capitalize()}"
else:
msg += ", directions: "
for i, metric in enumerate(self.metric_names):
msg += f"{metric}:{self.study.directions[i].name.capitalize()}, "
msg = msg[:-2]
msg += f", pruner:{self.study.pruner.__class__.__name__}"
msg += f", sampler:{self.study.sampler.__class__.__name__}]"
return msg
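
For illustration, with the two hypothetical metrics above, the multi-objective branch of this property would render a line roughly like the following (the storage, pruner, and sampler names depend on the study's actual configuration):

Creating optuna study: my-study [storage:RDBStorage, directions: val_loss:Minimize, val_accuracy:Maximize, pruner:NopPruner, sampler:TPESampler]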
@@ -169,17 +188,30 @@ def formatted_trials(self) -> str:
trial_strs = []
for trial in self.study.trials:
run_id = trial.user_attrs["run_id"]
vals = list(trial.intermediate_values.values())
best = None
if len(vals) > 0:
if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
best = round(min(vals), 5)
elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
best = round(max(vals), 5)
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
best: str = ""
if not self.is_multi_objective:
vals = list(trial.intermediate_values.values())
if len(vals) > 0:
if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
best = f"{round(min(vals), 5)}"
elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
best = f"{round(max(vals), 5)}"
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
            else:  # multi-objective: only final values, not intermediates, are stored in the study
vals = trial.values
if vals:
for idx in range(len(self.metric_names)):
direction = self.study.directions[idx].name.capitalize()
best += f"{self.metric_names[idx]} ({direction}):"
best += f"{round(vals[idx], 5)}, "
best = best[:-2]
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, best: {best or 'None'}"
]

return "\n".join(trial_strs[-10:]) # only print out last 10

@@ -222,8 +254,7 @@ def _load_optuna_classes(
mod, err = _get_module("optuna", filepath)
if not mod:
raise SchedulerError(
f"Failed to load optuna from path {filepath} "
f" with error: {err}"
f"Failed to load optuna from path {filepath} " f" with error: {err}"
)

# Set custom optuna trial creation method
@@ -286,9 +317,7 @@ def _load_file_from_artifact(self, artifact_name: str) -> str:
# load user-set optuna class definition file
artifact = self._wandb_run.use_artifact(artifact_name, type="optuna")
if not artifact:
raise SchedulerError(
f"Failed to load artifact: {artifact_name}"
)
raise SchedulerError(f"Failed to load artifact: {artifact_name}")

path = artifact.download()
optuna_filepath = self._optuna_config.get(
@@ -369,16 +398,30 @@ def _load_optuna(self) -> None:
else:
wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler")

direction = self._sweep_config.get("metric", {}).get("goal")
if len(self.metric_names) == 1:
directions = [self._sweep_config.get("metric", {}).get("goal")]
else:
directions = [x["goal"] for x in self._optuna_config["metrics"]]

self._storage_path = existing_storage or OptunaComponents.storage.value
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
direction=direction,
)
if len(self.metric_names) == 1:
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
direction=directions[0],
)
else: # multi-objective optimization
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
directions=directions,
)
wandb.termlog(self.study_string)

if existing_storage:
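
Because _get_run_history below scans for every configured metric name, the training script launched for each trial has to log all of the objectives. A minimal sketch, assuming the two hypothetical metric names from the config example above:

import wandb

run = wandb.init()
for step in range(3):
    val_loss = 1.0 / (step + 1)  # placeholder objective values
    val_accuracy = 1.0 - val_loss
    # Log both objectives in the same call so each history row
    # carries a complete (val_loss, val_accuracy) pair.
    wandb.log({"val_loss": val_loss, "val_accuracy": val_accuracy})
run.finish()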
@@ -465,33 +508,39 @@ def _get_run_history(self, run_id: str) -> List[int]:
logger.debug(f"Failed to poll run from public api: {str(e)}")
return []

metric_name = self._sweep_config["metric"]["name"]
history = api_run.scan_history(keys=["_step", metric_name])
metrics = [x[metric_name] for x in history]
history = api_run.scan_history(keys=self.metric_names + ["_step"])
metrics = []
for log in history:
metrics += [tuple(log.get(key) for key in self.metric_names)]

if len(metrics) == 0 and api_run.lastHistoryStep > -1:
logger.debug("No metrics, but lastHistoryStep exists")
wandb.termwarn(
f"{LOG_PREFIX}Detected logged metrics, but none matching " +
f"provided metric name: '{metric_name}'"
f"{LOG_PREFIX}Detected logged metrics, but none matching "
+ f"provided metric name(s): '{self.metric_names}'"
)

return metrics

def _poll_run(self, orun: OptunaRun) -> bool:
"""Polls metrics for a run, returns true if finished."""
metrics = self._get_run_history(orun.sweep_run.id)
for i, metric in enumerate(metrics[orun.num_metrics :]):
logger.debug(f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}")
prev = orun.trial._cached_frozen_trial.intermediate_values
if orun.num_metrics + i not in prev:
orun.trial.report(metric, orun.num_metrics + i)

if orun.trial.should_prune():
wandb.termlog(f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}")
self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
self._stop_run(orun.sweep_run.id)
return True
if not self.is_multi_objective: # can't report to trial when multi
for i, metric in enumerate(metrics[orun.num_metrics :]):
logger.debug(
f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}"
)
prev = orun.trial._cached_frozen_trial.intermediate_values
if orun.num_metrics + i not in prev:
orun.trial.report(metric, orun.num_metrics + i)

if orun.trial.should_prune():
wandb.termlog(
f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}"
)
self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
self._stop_run(orun.sweep_run.id)
return True

orun.num_metrics = len(metrics)
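
The is_multi_objective gate above reflects an Optuna limitation: intermediate-value reporting and pruning are supported only for single-objective studies, so a multi-objective trial receives exactly one final value per objective through study.tell. A standalone sketch of that ask/tell flow (directions and values are illustrative):

import optuna

study = optuna.create_study(directions=["minimize", "maximize"])
trial = study.ask()
# ... launch the run and wait for its final metric values ...
study.tell(trial, values=(0.12, 0.91))  # one value per objective
print(study.best_trials)  # the Pareto front, not a single best trial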

@@ -504,37 +553,43 @@ def _poll_run(self, orun: OptunaRun) -> bool:
if (
self._runs[orun.sweep_run.id].state == RunState.FINISHED
and len(prev_metrics) == 0
and not self.is_multi_objective
):
# run finished correctly, but never logged a metric
wandb.termwarn(
f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " +
f"'{self._sweep_config['metric']['name']}'. Check your sweep " +
"config and training script."
f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: "
+ f"'{self._sweep_config['metric']['name']}'. Check your sweep "
+ "config and training script."
)
self._num_misconfigured_runs += 1
self.study.tell(orun.trial, state=optuna.trial.TrialState.FAIL)

if self._num_misconfigured_runs >= self.MAX_MISCONFIGURED_RUNS:
raise SchedulerError(
f"Too many misconfigured runs ({self._num_misconfigured_runs}), stopping sweep early"
f"Too many misconfigured runs ({self._num_misconfigured_runs}),"
" stopping sweep early"
)

# Delete run in Scheduler memory, freeing up worker
del self._runs[orun.sweep_run.id]

return True

last_value = prev_metrics[orun.num_metrics - 1]
self._num_misconfigured_runs = 0 # only count consecutive
if self.is_multi_objective:
last_value = tuple(metrics[-1])
else:
last_value = prev_metrics[orun.num_metrics - 1]

self._num_misconfigured_runs = 0 # only count consecutive
self.study.tell(
trial=orun.trial,
state=optuna.trial.TrialState.COMPLETE,
values=last_value,
)
wandb.termlog(
f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) "
f"[last metric: {last_value}, total: {orun.num_metrics}]"
f"[last metric{'s' if len(self.metric_names) > 1 else ''}: {last_value}"
f", total: {orun.num_metrics}]"
)

# Delete run in Scheduler memory, freeing up worker
