add metric class
gtarpenning committed Jun 26, 2023
1 parent e7449ca commit a8ff3e2
Showing 1 changed file with 76 additions and 33 deletions.
109 changes: 76 additions & 33 deletions jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
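
For orientation, and not part of the diff itself: the commit replaces the scheduler's flat list of metric names with a small Metric dataclass that pairs each metric name with an optuna.study.StudyDirection. Below is a minimal, self-contained sketch of that goal-to-direction mapping; the settings dict and build_metric_defs name are made up for illustration and are not the exact sweep schema or helper name.

from dataclasses import dataclass
from typing import List

import optuna


@dataclass
class Metric:
    name: str
    direction: optuna.study.StudyDirection


def build_metric_defs(settings: dict) -> List[Metric]:
    # Mirror the goal -> StudyDirection mapping the new helper performs
    # for each entry under "metrics".
    defs = []
    for m in settings.get("metrics", []):
        direction = (
            optuna.study.StudyDirection.MINIMIZE
            if m["goal"] == "minimize"
            else optuna.study.StudyDirection.MAXIMIZE
        )
        defs.append(Metric(name=m["name"], direction=direction))
    return defs


# Hypothetical config, for illustration only.
settings = {
    "metrics": [
        {"name": "loss", "goal": "minimize"},
        {"name": "accuracy", "goal": "maximize"},
    ]
}
print(build_metric_defs(settings))
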
@@ -43,6 +43,12 @@ class OptunaRun:
sweep_run: SweepRun


@dataclass
class Metric:
name: str
direction: optuna.study.StudyDirection


def setup_scheduler(scheduler: Scheduler, **kwargs):
"""Setup a run to log a scheduler job.
@@ -129,15 +135,7 @@ def __init__(
if not self._optuna_config:
self._optuna_config = self._wandb_run.config.get("settings", {})

if self._sweep_config.get("metric"):
self.metric_names = [self._sweep_config["metric"]["name"]]
else: # multi-objective optimization
for metric in self._optuna_config.get("metrics", []):
if not metric.get("name"):
raise SchedulerError("Optuna metric missing name")
if not metric.get("goal"):
raise SchedulerError("Optuna metric missing goal")
self.metric_names += [metric["name"]]
self._metric_defs = self._get_metric_names_and_directions()

# if metric is misconfigured, increment, stop sweep if 3 consecutive fails
self._num_misconfigured_runs = 0
@@ -157,19 +155,19 @@ def study_name(self) -> str:

@property
def is_multi_objective(self) -> bool:
return len(self.metric_names) > 1
return len(self._metric_defs) > 1

@property
def study_string(self) -> str:
msg = f"{LOG_PREFIX}{'Loading' if self._wandb_run.resumed else 'Creating'}"
msg += f" optuna study: {self.study_name} "
msg += f"[storage:{self.study._storage.__class__.__name__}"
if not self.is_multi_objective:
msg += f", direction: {self.study.direction.name.capitalize()}"
msg += f", direction: {self._metric_defs[0].direction.name.capitalize()}"
else:
msg += ", directions: "
for i, metric in enumerate(self.metric_names):
msg += f"{metric}:{self.study.directions[i].name.capitalize()}, "
for metric in self._metric_defs:
msg += f"{metric.name}:{metric.direction.name.capitalize()}, "
msg = msg[:-2]
msg += f", pruner:{self.study.pruner.__class__.__name__}"
msg += f", sampler:{self.study.sampler.__class__.__name__}]"
@@ -201,20 +199,68 @@ def formatted_trials(self) -> str:
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
]
else: # multi-objective optimization, only 1 metric logged in study
vals = trial.values
if vals:
for idx in range(len(self.metric_names)):
direction = self.study.directions[idx].name.capitalize()
best += f"{self.metric_names[idx]} ({direction}):"
best += f"{round(vals[idx], 5)}, "
best = best[:-2]
if not trial.values:
continue

if len(trial.values) != len(self._metric_defs):
wandb.termwarn(
f"{LOG_PREFIX}Number of trial metrics ({trial.values})"
" does not match number of metrics defined "
f"({self._metric_defs})"
)
continue

for val, metric in zip(trial.values, self._metric_defs):
direction = metric.direction.name.capitalize()
best += f"{metric.name} ({direction}):"
best += f"{round(val, 5)}, "

best = best[:-2]
trial_strs += [
f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
f"{trial.state.name}, best: {best or 'None'}"
]

return "\n".join(trial_strs[-10:]) # only print out last 10

def _get_metric_names_and_directions(self) -> List[Metric]:
"""Helper to configure dict of at least one metric.
Dict contains the metric names as keys, with the optimization
direction (or goal) as the value
"""
# if single-objective, just top level metric is set
if self._sweep_config.get("metric"):
direction = (
optuna.study.StudyDirection.MINIMIZE
if self._sweep_config["metric"]["goal"] == "minimize"
else optuna.study.StudyDirection.MAXIMIZE
)
metric = Metric(
name=self._sweep_config["metric"]["name"], direction=direction
)
return [metric]

# multi-objective optimization
metric_defs = []
for metric in self._optuna_config.get("metrics", []):
if not metric.get("name"):
raise SchedulerError("Optuna metric missing name")
if not metric.get("goal"):
raise SchedulerError("Optuna metric missing goal")

direction = (
optuna.study.StudyDirection.MINIMIZE
if metric["goal"] == "minimize"
else optuna.study.StudyDirection.MAXIMIZE
)
metric_defs += [Metric(name=metric["name"], direction=direction)]

if len(metric_defs) == 0:
raise SchedulerError("Optuna sweep missing metric")

return metric_defs

def _validate_optuna_study(self, study: optuna.Study) -> Optional[str]:
"""Accepts an optuna study, runs validation.
@@ -398,13 +444,9 @@ def _load_optuna(self) -> None:
else:
wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler")

if len(self.metric_names) == 1:
directions = [self._sweep_config.get("metric", {}).get("goal")]
else:
directions = [x["goal"] for x in self._optuna_config["metrics"]]

self._storage_path = existing_storage or OptunaComponents.storage.value
if len(self.metric_names) == 1:
directions = [metric.direction for metric in self._metric_defs]
if len(directions) == 1:
self._study = optuna.create_study(
study_name=self.study_name,
storage=f"sqlite:///{self._storage_path}",
@@ -508,16 +550,17 @@ def _get_run_history(self, run_id: str) -> List[int]:
logger.debug(f"Failed to poll run from public api: {str(e)}")
return []

history = api_run.scan_history(keys=self.metric_names + ["_step"])
names = [metric.name for metric in self._metric_defs]
history = api_run.scan_history(keys=names + ["_step"])
metrics = []
for log in history:
metrics += [tuple(log.get(key) for key in self.metric_names)]
metrics += [tuple(log.get(key) for key in names)]

if len(metrics) == 0 and api_run.lastHistoryStep > -1:
logger.debug("No metrics, but lastHistoryStep exists")
wandb.termwarn(
f"{LOG_PREFIX}Detected logged metrics, but none matching "
+ f"provided metric name(s): '{self.metric_names}'"
+ f"provided metric name(s): '{names}'"
)

return metrics
@@ -526,13 +569,13 @@ def _poll_run(self, orun: OptunaRun) -> bool:
"""Polls metrics for a run, returns true if finished."""
metrics = self._get_run_history(orun.sweep_run.id)
if not self.is_multi_objective: # can't report to trial when multi
for i, metric in enumerate(metrics[orun.num_metrics :]):
for i, metric_val in enumerate(metrics[orun.num_metrics :]):
logger.debug(
f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}"
)
prev = orun.trial._cached_frozen_trial.intermediate_values
if orun.num_metrics + i not in prev:
orun.trial.report(metric, orun.num_metrics + i)
orun.trial.report(metric_val, orun.num_metrics + i)

if orun.trial.should_prune():
wandb.termlog(
@@ -558,7 +601,7 @@ def _poll_run(self, orun: OptunaRun) -> bool:
# run finished correctly, but never logged a metric
wandb.termwarn(
f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: "
+ f"'{self._sweep_config['metric']['name']}'. Check your sweep "
+ f"'{self._metric_defs[0].name}'. Check your sweep "
+ "config and training script."
)
self._num_misconfigured_runs += 1
@@ -588,7 +631,7 @@ def _poll_run(self, orun: OptunaRun) -> bool:
)
wandb.termlog(
f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) "
f"[last metric{'s' if len(self.metric_names) > 1 else ''}: {last_value}"
f"[last metric{'s' if self.is_multi_objective else ''}: {last_value}"
f", total: {orun.num_metrics}]"
)

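As a follow-up illustration, also not part of the commit: once the Metric objects exist, their StudyDirection values can be handed straight to Optuna when the study is created, which is what the _load_optuna change above does with self._metric_defs. A minimal sketch of that pattern, with placeholder names and default in-memory storage:

import optuna

# Hypothetical directions, as would be collected from the Metric objects.
directions = [
    optuna.study.StudyDirection.MINIMIZE,  # e.g. a "loss" metric
    optuna.study.StudyDirection.MAXIMIZE,  # e.g. an "accuracy" metric
]

if len(directions) == 1:
    # Single-objective study: pass one direction.
    study = optuna.create_study(study_name="example", direction=directions[0])
else:
    # Multi-objective study: pass the full list of directions.
    study = optuna.create_study(study_name="example", directions=directions)

print(study.directions)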
