diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
index 51fd845..1747fe8 100644
--- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
+++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
@@ -43,6 +43,12 @@ class OptunaRun:
     sweep_run: SweepRun
 
 
+@dataclass
+class Metric:
+    name: str
+    direction: optuna.study.StudyDirection
+
+
 def setup_scheduler(scheduler: Scheduler, **kwargs):
     """Setup a run to log a scheduler job.
 
@@ -129,15 +135,7 @@ def __init__(
         if not self._optuna_config:
             self._optuna_config = self._wandb_run.config.get("settings", {})
 
-        if self._sweep_config.get("metric"):
-            self.metric_names = [self._sweep_config["metric"]["name"]]
-        else:  # multi-objective optimization
-            for metric in self._optuna_config.get("metrics", []):
-                if not metric.get("name"):
-                    raise SchedulerError("Optuna metric missing name")
-                if not metric.get("goal"):
-                    raise SchedulerError("Optuna metric missing goal")
-                self.metric_names += [metric["name"]]
+        self._metric_defs = self._get_metric_names_and_directions()
 
         # if metric is misconfigured, increment, stop sweep if 3 consecutive fails
         self._num_misconfigured_runs = 0
@@ -157,7 +155,7 @@ def study_name(self) -> str:
 
     @property
     def is_multi_objective(self) -> bool:
-        return len(self.metric_names) > 1
+        return len(self._metric_defs) > 1
 
     @property
     def study_string(self) -> str:
@@ -165,11 +163,11 @@ def study_string(self) -> str:
         msg += f" optuna study: {self.study_name} "
         msg += f"[storage:{self.study._storage.__class__.__name__}"
         if not self.is_multi_objective:
-            msg += f", direction: {self.study.direction.name.capitalize()}"
+            msg += f", direction: {self._metric_defs[0].direction.name.capitalize()}"
         else:
             msg += ", directions: "
-            for i, metric in enumerate(self.metric_names):
-                msg += f"{metric}:{self.study.directions[i].name.capitalize()}, "
+            for metric in self._metric_defs:
+                msg += f"{metric.name}:{metric.direction.name.capitalize()}, "
             msg = msg[:-2]
         msg += f", pruner:{self.study.pruner.__class__.__name__}"
         msg += f", sampler:{self.study.sampler.__class__.__name__}]"
@@ -201,13 +199,23 @@ def formatted_trials(self) -> str:
                     f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
                 ]
             else:  # multi-objective optimization, only 1 metric logged in study
-                vals = trial.values
-                if vals:
-                    for idx in range(len(self.metric_names)):
-                        direction = self.study.directions[idx].name.capitalize()
-                        best += f"{self.metric_names[idx]} ({direction}):"
-                        best += f"{round(vals[idx], 5)}, "
-                    best = best[:-2]
+                if not trial.values:
+                    continue
+
+                if len(trial.values) != len(self._metric_defs):
+                    wandb.termwarn(
+                        f"{LOG_PREFIX}Number of trial metrics ({len(trial.values)})"
+                        " does not match number of metrics defined "
+                        f"({len(self._metric_defs)})"
+                    )
+                    continue
+
+                for val, metric in zip(trial.values, self._metric_defs):
+                    direction = metric.direction.name.capitalize()
+                    best += f"{metric.name} ({direction}):"
+                    best += f"{round(val, 5)}, "
+
+                best = best[:-2]
                 trial_strs += [
                     f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
                     f"{trial.state.name}, best: {best or 'None'}"
@@ -215,6 +223,44 @@ def formatted_trials(self) -> str:
 
         return "\n".join(trial_strs[-10:])  # only print out last 10
 
+    def _get_metric_names_and_directions(self) -> List[Metric]:
+        """Build a list of at least one Metric definition.
+
+        Each Metric pairs a metric name with its optuna optimization
+        direction, derived from the config's "goal" field.
+        """
+        # if single-objective, just top level metric is set
+        if self._sweep_config.get("metric"):
+            direction = (
+                optuna.study.StudyDirection.MINIMIZE
+                if self._sweep_config["metric"]["goal"] == "minimize"
+                else optuna.study.StudyDirection.MAXIMIZE
+            )
+            metric = Metric(
+                name=self._sweep_config["metric"]["name"], direction=direction
+            )
+            return [metric]
+
+        # multi-objective optimization
+        metric_defs = []
+        for metric in self._optuna_config.get("metrics", []):
+            if not metric.get("name"):
+                raise SchedulerError("Optuna metric missing name")
+            if not metric.get("goal"):
+                raise SchedulerError("Optuna metric missing goal")
+
+            direction = (
+                optuna.study.StudyDirection.MINIMIZE
+                if metric["goal"] == "minimize"
+                else optuna.study.StudyDirection.MAXIMIZE
+            )
+            metric_defs += [Metric(name=metric["name"], direction=direction)]
+
+        if len(metric_defs) == 0:
+            raise SchedulerError("Optuna sweep missing metric")
+
+        return metric_defs
+
     def _validate_optuna_study(self, study: optuna.Study) -> Optional[str]:
         """Accepts an optuna study, runs validation.
 
@@ -398,13 +444,9 @@ def _load_optuna(self) -> None:
         else:
             wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler")
 
-        if len(self.metric_names) == 1:
-            directions = [self._sweep_config.get("metric", {}).get("goal")]
-        else:
-            directions = [x["goal"] for x in self._optuna_config["metrics"]]
-
         self._storage_path = existing_storage or OptunaComponents.storage.value
-        if len(self.metric_names) == 1:
+        directions = [metric.direction for metric in self._metric_defs]
+        if len(directions) == 1:
             self._study = optuna.create_study(
                 study_name=self.study_name,
                 storage=f"sqlite:///{self._storage_path}",
@@ -508,16 +550,17 @@ def _get_run_history(self, run_id: str) -> List[int]:
             logger.debug(f"Failed to poll run from public api: {str(e)}")
             return []
 
-        history = api_run.scan_history(keys=self.metric_names + ["_step"])
+        names = [metric.name for metric in self._metric_defs]
+        history = api_run.scan_history(keys=names + ["_step"])
         metrics = []
         for log in history:
-            metrics += [tuple(log.get(key) for key in self.metric_names)]
+            metrics += [tuple(log.get(key) for key in names)]
 
         if len(metrics) == 0 and api_run.lastHistoryStep > -1:
            logger.debug("No metrics, but lastHistoryStep exists")
             wandb.termwarn(
                 f"{LOG_PREFIX}Detected logged metrics, but none matching "
-                + f"provided metric name(s): '{self.metric_names}'"
+                + f"provided metric name(s): '{names}'"
             )
             return metrics
 
@@ -526,13 +569,13 @@ def _poll_run(self, orun: OptunaRun) -> bool:
         """Polls metrics for a run, returns true if finished."""
         metrics = self._get_run_history(orun.sweep_run.id)
         if not self.is_multi_objective:  # can't report to trial when multi
-            for i, metric in enumerate(metrics[orun.num_metrics :]):
+            for i, metric_val in enumerate(metrics[orun.num_metrics :]):
                 logger.debug(
                     f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}"
                 )
                 prev = orun.trial._cached_frozen_trial.intermediate_values
                 if orun.num_metrics + i not in prev:
-                    orun.trial.report(metric, orun.num_metrics + i)
+                    orun.trial.report(metric_val, orun.num_metrics + i)
 
                 if orun.trial.should_prune():
                     wandb.termlog(
@@ -558,7 +601,7 @@ def _poll_run(self, orun: OptunaRun) -> bool:
             # run finished correctly, but never logged a metric
             wandb.termwarn(
                 f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: "
-                + f"'{self._sweep_config['metric']['name']}'. Check your sweep "
+                + f"'{self._metric_defs[0].name}'. Check your sweep "
                 + "config and training script."
             )
             self._num_misconfigured_runs += 1
@@ -588,7 +631,7 @@ def _poll_run(self, orun: OptunaRun) -> bool:
             )
             wandb.termlog(
                 f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) "
-                f"[last metric{'s' if len(self.metric_names) > 1 else ''}: {last_value}"
+                f"[last metric{'s' if self.is_multi_objective else ''}: {last_value}"
                 f", total: {orun.num_metrics}]"
             )
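
For context, here is a minimal standalone sketch (not part of the diff) of how the goal-to-direction mapping introduced above behaves, assuming only that optuna is installed; to_metric is a hypothetical helper that mirrors the diff's logic:

    from dataclasses import dataclass

    import optuna


    @dataclass
    class Metric:
        name: str
        direction: optuna.study.StudyDirection


    def to_metric(cfg: dict) -> Metric:
        # Mirrors the diff: any goal other than "minimize" maps to MAXIMIZE.
        direction = (
            optuna.study.StudyDirection.MINIMIZE
            if cfg["goal"] == "minimize"
            else optuna.study.StudyDirection.MAXIMIZE
        )
        return Metric(name=cfg["name"], direction=direction)


    print(to_metric({"name": "val_loss", "goal": "minimize"}).direction.name)
    # MINIMIZE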
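And a sketch of how the rebuilt directions list feeds optuna.create_study in _load_optuna, assuming an in-memory study for brevity (the real scheduler passes sqlite storage); optuna accepts StudyDirection values directly, so no "minimize"/"maximize" strings are needed:

    import optuna

    directions = [
        optuna.study.StudyDirection.MINIMIZE,
        optuna.study.StudyDirection.MAXIMIZE,
    ]

    if len(directions) == 1:
        study = optuna.create_study(direction=directions[0])
    else:  # multi-objective study, as in the diff's _load_optuna
        study = optuna.create_study(directions=directions)

    print([d.name for d in study.directions])  # ['MINIMIZE', 'MAXIMIZE']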