feat(sweeps): optuna supports multi-objective optimization #28

Merged · 19 commits · Aug 16, 2023
Changes from 3 commits
232 changes: 168 additions & 64 deletions jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
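For context while reading the diff: the new multi-objective path is driven by a list of metrics in the scheduler's optuna settings (parsed by _get_metric_names_and_directions below), while single-objective sweeps continue to use the top-level metric key of the sweep config. A minimal sketch of what such settings might look like, with hypothetical metric names and an illustrative dict layout:

    # Hypothetical scheduler settings for a two-objective sweep; the key names
    # mirror what _get_metric_names_and_directions reads ("metrics", "name", "goal").
    optuna_settings = {
        "metrics": [
            {"name": "val_loss", "goal": "minimize"},
            {"name": "val_accuracy", "goal": "maximize"},
        ],
    }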
@@ -1,3 +1,4 @@
import argparse
import base64
import logging
import os
@@ -10,17 +11,16 @@
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

import argparse
import click
import joblib
import optuna
import wandb
from wandb.apis.internal import Api
from wandb.apis.public import Api as PublicApi
from wandb.apis.public import QueuedRun, Run
from wandb.sdk.artifacts.public_artifact import Artifact
from wandb.sdk.launch.sweeps import SchedulerError
from wandb.sdk.launch.sweeps.scheduler import Scheduler, SweepRun, RunState

from wandb.sdk.launch.sweeps.scheduler import RunState, Scheduler, SweepRun

logger = logging.getLogger(__name__)
optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -44,6 +44,12 @@ class OptunaRun:
sweep_run: SweepRun


@dataclass
class Metric:
name: str
direction: optuna.study.StudyDirection


def setup_scheduler(scheduler: Scheduler, **kwargs):
"""Setup a run to log a scheduler job.

@@ -92,18 +98,20 @@ def _handle_job_logic(run, name, enable_git=False) -> None:
)
tag = os.environ.get("WANDB_DOCKER", "").split(":")
if len(tag) == 2:
jobstr += f"-{tag[0].replace('/', '_')}_{tag[-1]}:latest"
else:
jobstr += f"-{tag[0].replace('/', '_')}:{tag[-1]}"
elif len(tag) == 1:
jobstr += f"-{tag[0].replace('/', '_')}:latest"
else: # unknown format
jobstr = f"found here: https://wandb.ai/{jobstr}s/"
wandb.termlog(f"Creating image job {click.style(jobstr, fg='yellow')}\n")
elif not enable_git:
jobstr += f"-{name}:latest"
wandb.termlog(
f"Creating code-artifact job: {click.style(jobstr, fg='yellow')}\n"
f"Creating artifact job: {click.style(jobstr, fg='yellow')}\n"
)
else:
_s = click.style(f"https://wandb.ai/{jobstr}s/", fg="yellow")
wandb.termlog(f"Creating repo-artifact job found here: {_s}\n")
wandb.termlog(f"Creating repo job found here: {_s}\n")
run.log_code(name=name, exclude_fn=lambda x: x.startswith("_"))
return

@@ -130,6 +138,8 @@ def __init__(
if not self._optuna_config:
self._optuna_config = self._wandb_run.config.get("settings", {})

self._metric_defs = self._get_metric_names_and_directions()

# if a metric is misconfigured, increment a counter; stop the sweep after 3 consecutive failures
self._num_misconfigured_runs = 0

@@ -146,12 +156,22 @@ def study_name(self) -> str:
optuna_study_name: str = self.study.study_name
return optuna_study_name

@property
def is_multi_objective(self) -> bool:
return len(self._metric_defs) > 1

@property
def study_string(self) -> str:
msg = f"{LOG_PREFIX}{'Loading' if self._wandb_run.resumed else 'Creating'}"
msg += f" optuna study: {self.study_name} "
msg += f"[storage:{self.study._storage.__class__.__name__}"
msg += f", direction:{self.study.direction.name.capitalize()}"
if not self.is_multi_objective:
msg += f", direction: {self._metric_defs[0].direction.name.capitalize()}"
else:
msg += ", directions: "
for metric in self._metric_defs:
msg += f"{metric.name}:{metric.direction.name.capitalize()}, "
msg = msg[:-2]
msg += f", pruner:{self.study.pruner.__class__.__name__}"
msg += f", sampler:{self.study.sampler.__class__.__name__}]"
return msg
@@ -169,20 +189,81 @@ def formatted_trials(self) -> str:
trial_strs = []
for trial in self.study.trials:
run_id = trial.user_attrs["run_id"]
vals = list(trial.intermediate_values.values())
best = None
if len(vals) > 0:
if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
best = round(min(vals), 5)
elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
best = round(max(vals), 5)
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
best: str = ""
if not self.is_multi_objective:
vals = list(trial.intermediate_values.values())
if len(vals) > 0:
if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
best = f"{round(min(vals), 5)}"
elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
best = f"{round(max(vals), 5)}"
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
else: # multi-objective optimization: only final trial values, not intermediate values, are recorded in the study
if not trial.values:
continue

if len(trial.values) != len(self._metric_defs):
wandb.termwarn(
f"{LOG_PREFIX}Number of trial metrics ({trial.values})"
" does not match number of metrics defined "
f"({self._metric_defs})"
)
continue

for val, metric in zip(trial.values, self._metric_defs):
direction = metric.direction.name.capitalize()
best += f"{metric.name} ({direction}):"
best += f"{round(val, 5)}, "

best = best[:-2]
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, best: {best or 'None'}"
]

return "\n".join(trial_strs[-10:]) # only print out last 10

def _get_metric_names_and_directions(self) -> List[Metric]:
"""Helper to configure dict of at least one metric.

Dict contains the metric names as keys, with the optimization
direction (or goal) as the value (type: optuna.study.StudyDirection)
"""
# if single-objective, just top level metric is set
if self._sweep_config.get("metric"):
direction = (
optuna.study.StudyDirection.MINIMIZE
if self._sweep_config["metric"]["goal"] == "minimize"
else optuna.study.StudyDirection.MAXIMIZE
)
metric = Metric(
name=self._sweep_config["metric"]["name"], direction=direction
)
return [metric]

# multi-objective optimization
metric_defs = []
for metric in self._optuna_config.get("metrics", []):
if not metric.get("name"):
raise SchedulerError("Optuna metric missing name")
if not metric.get("goal"):
raise SchedulerError("Optuna metric missing goal")

direction = (
optuna.study.StudyDirection.MINIMIZE
if metric["goal"] == "minimize"
else optuna.study.StudyDirection.MAXIMIZE
)
metric_defs += [Metric(name=metric["name"], direction=direction)]

if len(metric_defs) == 0:
raise SchedulerError("Optuna sweep missing metric")

return metric_defs

def _validate_optuna_study(self, study: optuna.Study) -> Optional[str]:
"""Accepts an optuna study, runs validation.

@@ -222,8 +303,7 @@ def _load_optuna_classes(
mod, err = _get_module("optuna", filepath)
if not mod:
raise SchedulerError(
f"Failed to load optuna from path {filepath} "
f" with error: {err}"
f"Failed to load optuna from path {filepath} " f" with error: {err}"
)

# Set custom optuna trial creation method
@@ -286,9 +366,7 @@ def _load_file_from_artifact(self, artifact_name: str) -> str:
# load user-set optuna class definition file
artifact = self._wandb_run.use_artifact(artifact_name, type="optuna")
if not artifact:
raise SchedulerError(
f"Failed to load artifact: {artifact_name}"
)
raise SchedulerError(f"Failed to load artifact: {artifact_name}")

path = artifact.download()
optuna_filepath = self._optuna_config.get(
@@ -369,16 +447,26 @@ def _load_optuna(self) -> None:
else:
wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler")

direction = self._sweep_config.get("metric", {}).get("goal")
self._storage_path = existing_storage or OptunaComponents.storage.value
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
direction=direction,
)
directions = [metric.direction for metric in self._metric_defs]
if len(directions) == 1:
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
direction=directions[0],
)
else: # multi-objective optimization
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
pruner=pruner,
sampler=sampler,
load_if_exists=True,
directions=directions,
)
wandb.termlog(self.study_string)

if existing_storage:
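As a standalone illustration of the branch above (not the scheduler's actual storage, pruner, or sampler setup), optuna's create_study accepts either a single direction or a list of directions:

    import optuna

    # Single-objective study: one optimization direction.
    single = optuna.create_study(direction="minimize")
    # Multi-objective study: one direction per metric.
    multi = optuna.create_study(directions=["minimize", "maximize"])
    print(multi.directions)  # [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]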
@@ -399,17 +487,19 @@ def _load_state(self) -> None:
def _save_state(self) -> None:
"""Called when Scheduler class invokes exit().

Save optuna study sqlite data to an artifact in the controller run
Save the optuna study, or its sqlite storage, to an artifact in the scheduler run
"""
artifact = wandb.Artifact(
f"{OptunaComponents.storage.name}-{self._sweep_id}", type="optuna"
)
if not self._study: # nothing to save
return None

artifact_name = f"{OptunaComponents.storage.name}-{self._sweep_id}"
if not self._storage_path:
wandb.termwarn(
f"{LOG_PREFIX}No db storage path found, saving to default path"
)
self._storage_path = OptunaComponents.storage.value
wandb.termwarn(f"{LOG_PREFIX}No db storage path found, saving full model")
artifact_name = f"optuna-study-{self._sweep_id}"
joblib.dump(self.study, f"study-{self._sweep_id}.pkl")
self._storage_path = f"study-{self._sweep_id}.pkl"

artifact = wandb.Artifact(artifact_name, type="optuna")
artifact.add_file(self._storage_path)
self._wandb_run.log_artifact(artifact)
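When the pickle fallback above is taken, the saved study could later be restored with joblib. A hedged sketch, assuming the study-<sweep_id>.pkl filename used in the code and a hypothetical sweep id (in practice the file would first be downloaded from the logged artifact):

    import joblib

    sweep_id = "abc123"  # hypothetical sweep id
    study = joblib.load(f"study-{sweep_id}.pkl")
    print(study.study_name, study.directions)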

@@ -420,6 +510,7 @@ def _save_state(self) -> None:
return

def _get_next_sweep_run(self, worker_id: int) -> Optional[SweepRun]:
"""Called repeatedly in the polling loop, whenever a worker is available."""
config, trial = self._trial_func()
run: dict = self._api.upsert_run(
project=self._project,
@@ -465,33 +556,40 @@ def _get_run_history(self, run_id: str) -> List[int]:
logger.debug(f"Failed to poll run from public api: {str(e)}")
return []

metric_name = self._sweep_config["metric"]["name"]
history = api_run.scan_history(keys=["_step", metric_name])
metrics = [x[metric_name] for x in history]
names = [metric.name for metric in self._metric_defs]
history = api_run.scan_history(keys=names + ["_step"])
metrics = []
for log in history:
metrics += [tuple(log.get(key) for key in names)]

if len(metrics) == 0 and api_run.lastHistoryStep > -1:
logger.debug("No metrics, but lastHistoryStep exists")
wandb.termwarn(
f"{LOG_PREFIX}Detected logged metrics, but none matching " +
f"provided metric name: '{metric_name}'"
f"{LOG_PREFIX}Detected logged metrics, but none matching "
+ f"provided metric name(s): '{names}'"
)

return metrics

def _poll_run(self, orun: OptunaRun) -> bool:
"""Polls metrics for a run, returns true if finished."""
metrics = self._get_run_history(orun.sweep_run.id)
for i, metric in enumerate(metrics[orun.num_metrics :]):
logger.debug(f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}")
prev = orun.trial._cached_frozen_trial.intermediate_values
if orun.num_metrics + i not in prev:
orun.trial.report(metric, orun.num_metrics + i)

if orun.trial.should_prune():
wandb.termlog(f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}")
self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
self._stop_run(orun.sweep_run.id)
return True
if not self.is_multi_objective: # can't report to trial when multi
for i, metric_val in enumerate(metrics[orun.num_metrics :]):
logger.debug(
f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}"
)
prev = orun.trial._cached_frozen_trial.intermediate_values
if orun.num_metrics + i not in prev:
orun.trial.report(metric_val, orun.num_metrics + i)

if orun.trial.should_prune():
wandb.termlog(
f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}"
)
self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
self._stop_run(orun.sweep_run.id)
return True

orun.num_metrics = len(metrics)
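The is_multi_objective guard above reflects that optuna's intermediate-value reporting and pruning interface is single-objective; for a study with multiple directions, Trial.report is documented as unsupported. A minimal sketch of that constraint (behavior assumed from optuna's documentation, not from this PR):

    import optuna

    study = optuna.create_study(directions=["minimize", "maximize"])
    trial = study.ask()
    try:
        # Intermediate values only make sense with a single objective.
        trial.report(0.5, step=0)
    except NotImplementedError:
        print("report()/should_prune() are unavailable for multi-objective studies")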

@@ -504,37 +602,43 @@ def _poll_run(self, orun: OptunaRun) -> bool:
if (
self._runs[orun.sweep_run.id].state == RunState.FINISHED
and len(prev_metrics) == 0
and not self.is_multi_objective
):
# run finished correctly, but never logged a metric
wandb.termwarn(
f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " +
f"'{self._sweep_config['metric']['name']}'. Check your sweep " +
"config and training script."
f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: "
+ f"'{self._metric_defs[0].name}'. Check your sweep "
+ "config and training script."
)
self._num_misconfigured_runs += 1
self.study.tell(orun.trial, state=optuna.trial.TrialState.FAIL)

if self._num_misconfigured_runs >= self.MAX_MISCONFIGURED_RUNS:
raise SchedulerError(
f"Too many misconfigured runs ({self._num_misconfigured_runs}), stopping sweep early"
f"Too many misconfigured runs ({self._num_misconfigured_runs}),"
" stopping sweep early"
)

# Delete run in Scheduler memory, freeing up worker
del self._runs[orun.sweep_run.id]

return True

last_value = prev_metrics[orun.num_metrics - 1]
self._num_misconfigured_runs = 0 # only count consecutive
if self.is_multi_objective:
last_value = tuple(metrics[-1])
else:
last_value = prev_metrics[orun.num_metrics - 1]

self._num_misconfigured_runs = 0 # only count consecutive
self.study.tell(
trial=orun.trial,
state=optuna.trial.TrialState.COMPLETE,
values=last_value,
)
wandb.termlog(
f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) "
f"[last metric: {last_value}, total: {orun.num_metrics}]"
f"[last metric{'s' if self.is_multi_objective else ''}: {last_value}"
f", total: {orun.num_metrics}]"
)

# Delete run in Scheduler memory, freeing up worker
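For reference, the completion call above maps onto optuna's Study.tell, which for a multi-objective study expects one value per direction. A small sketch with assumed values:

    import optuna

    study = optuna.create_study(directions=["minimize", "maximize"])
    trial = study.ask()
    # One value per study direction, in the same order as directions.
    study.tell(trial, state=optuna.trial.TrialState.COMPLETE, values=(0.123, 0.987))
    print(study.best_trials)  # Pareto-optimal trials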