diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
index 1932ed9..514a41d 100644
--- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
+++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
@@ -1,3 +1,4 @@
+import argparse
 import base64
 import logging
 import os
@@ -10,17 +11,15 @@
 from types import ModuleType
 from typing import Any, Dict, List, Optional, Tuple
 
-import argparse
 import click
 import optuna
 import wandb
 from wandb.apis.internal import Api
 from wandb.apis.public import Api as PublicApi
 from wandb.apis.public import QueuedRun, Run
-from wandb.sdk.artifacts.public_artifact import Artifact
+from wandb.sdk.artifacts.artifact import Artifact
 from wandb.sdk.launch.sweeps import SchedulerError
-from wandb.sdk.launch.sweeps.scheduler import Scheduler, SweepRun, RunState
-
+from wandb.sdk.launch.sweeps.scheduler import RunState, Scheduler, SweepRun
 
 logger = logging.getLogger(__name__)
 optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -44,6 +43,12 @@ class OptunaRun:
     sweep_run: SweepRun
 
 
+@dataclass
+class Metric:
+    name: str
+    direction: optuna.study.StudyDirection
+
+
 def setup_scheduler(scheduler: Scheduler, **kwargs):
     """Setup a run to log a scheduler job.
 
@@ -56,23 +61,27 @@ def setup_scheduler(scheduler: Scheduler, **kwargs):
     parser.add_argument("--project", type=str, default=kwargs.get("project"))
     parser.add_argument("--entity", type=str, default=kwargs.get("entity"))
     parser.add_argument("--num_workers", type=int, default=1)
-    parser.add_argument("--name", type=str, default=None)
-    parser.add_argument("--enable_git", action="store_true", default=False)
+    parser.add_argument("--name", type=str, default=f"job-{scheduler.__name__}")
     cli_args = parser.parse_args()
-    name = cli_args.name or scheduler.__name__
+    settings = {"job_name": cli_args.name}
 
     run = wandb.init(
-        settings={"disable_git": True} if not cli_args.enable_git else {},
+        settings=settings,
         project=cli_args.project,
         entity=cli_args.entity,
     )
     config = run.config
+    args = config.get("sweep_args", {})
 
-    if not config.get("sweep_args", {}).get("sweep_id"):
-        _handle_job_logic(run, name, cli_args.enable_git)
+    if not args or not args.get("sweep_id"):
+        # when the config has no sweep args, this is being run directly from CLI
+        # and not in a sweep. Just log the code and return
+        if not os.getenv("WANDB_DOCKER"):
+            # if not docker, log the code to a git or code artifact
+            run.log_code(root=os.path.dirname(__file__))
+        run.finish()
         return
 
-    args = config.get("sweep_args", {})
     if cli_args.num_workers:  # override
         kwargs.update({"num_workers": cli_args.num_workers})
 
@@ -80,34 +89,6 @@ def setup_scheduler(scheduler: Scheduler, **kwargs):
     _scheduler.start()
 
 
-def _handle_job_logic(run, name, enable_git=False) -> None:
-    wandb.termlog(
-        "\nJob not configured to run a sweep, logging code and returning early."
-    )
-    jobstr = f"{run.entity}/{run.project}/job"
-
-    if os.environ.get("WANDB_DOCKER"):
-        wandb.termlog(
-            "Identified 'WANDB_DOCKER' environment var, creating image job..."
-        )
-        tag = os.environ.get("WANDB_DOCKER", "").split(":")
-        if len(tag) == 2:
-            jobstr += f"-{tag[0].replace('/', '_')}_{tag[-1]}:latest"
-        else:
-            jobstr = f"found here: https://wandb.ai/{jobstr}s/"
-        wandb.termlog(f"Creating image job {click.style(jobstr, fg='yellow')}\n")
-    elif not enable_git:
-        jobstr += f"-{name}:latest"
-        wandb.termlog(
-            f"Creating code-artifact job: {click.style(jobstr, fg='yellow')}\n"
-        )
-    else:
-        _s = click.style(f"https://wandb.ai/{jobstr}s/", fg="yellow")
-        wandb.termlog(f"Creating repo-artifact job found here: {_s}\n")
-    run.log_code(name=name, exclude_fn=lambda x: x.startswith("_"))
-    return
-
-
 class OptunaScheduler(Scheduler):
     OPT_TIMEOUT = 2
     MAX_MISCONFIGURED_RUNS = 3
@@ -130,6 +111,8 @@ def __init__(
         if not self._optuna_config:
             self._optuna_config = self._wandb_run.config.get("settings", {})
 
+        self._metric_defs = self._get_metric_names_and_directions()
+
         # if metric is misconfigured, increment, stop sweep if 3 consecutive fails
         self._num_misconfigured_runs = 0
 
@@ -146,12 +129,22 @@ def study_name(self) -> str:
         optuna_study_name: str = self.study.study_name
         return optuna_study_name
 
+    @property
+    def is_multi_objective(self) -> bool:
+        return len(self._metric_defs) > 1
+
     @property
     def study_string(self) -> str:
         msg = f"{LOG_PREFIX}{'Loading' if self._wandb_run.resumed else 'Creating'}"
         msg += f" optuna study: {self.study_name} "
         msg += f"[storage:{self.study._storage.__class__.__name__}"
-        msg += f", direction:{self.study.direction.name.capitalize()}"
+        if not self.is_multi_objective:
+            msg += f", direction: {self._metric_defs[0].direction.name.capitalize()}"
+        else:
+            msg += ", directions: "
+            for metric in self._metric_defs:
+                msg += f"{metric.name}:{metric.direction.name.capitalize()}, "
+            msg = msg[:-2]
         msg += f", pruner:{self.study.pruner.__class__.__name__}"
         msg += f", sampler:{self.study.sampler.__class__.__name__}]"
         return msg
@@ -168,21 +161,87 @@ def formatted_trials(self) -> str:
         trial_strs = []
         for trial in self.study.trials:
+            if not trial.values:
+                continue
+
             run_id = trial.user_attrs["run_id"]
-            vals = list(trial.intermediate_values.values())
-            best = None
-            if len(vals) > 0:
-                if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
-                    best = round(min(vals), 5)
-                elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
-                    best = round(max(vals), 5)
-            trial_strs += [
-                f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
-                f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
-            ]
+            best: str = ""
+            if not self.is_multi_objective:
+                vals = list(trial.intermediate_values.values())
+                if len(vals) > 0:
+                    if self.study.direction == optuna.study.StudyDirection.MINIMIZE:
+                        best = f"{round(min(vals), 5)}"
+                    elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE:
+                        best = f"{round(max(vals), 5)}"
+                trial_strs += [
+                    f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
+                    f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
+                ]
+            else:  # multi-objective optimization, only final values logged in study
+                if len(trial.values) != len(self._metric_defs):
+                    wandb.termwarn(
+                        f"{LOG_PREFIX}Number of logged metrics ({trial.values})"
+                        " does not match number of metrics defined "
+                        f"({self._metric_defs}). Specify metrics for optimization"
+                        " in the scheduler.settings.metrics portion of the sweep config"
+                    )
+                    continue
+
+                for val, metric in zip(trial.values, self._metric_defs):
+                    direction = metric.direction.name.capitalize()
+                    best += f"{metric.name} ({direction}):"
+                    best += f"{round(val, 5)}, "
+
+                # trim trailing comma and space
+                best = best[:-2]
+                trial_strs += [
+                    f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
+                    f"{trial.state.name}, best: {best}"
+                ]
 
         return "\n".join(trial_strs[-10:])  # only print out last 10
 
+    def _get_metric_names_and_directions(self) -> List[Metric]:
+        """Helper to build a list of at least one Metric.
+
+        Each Metric holds a metric name and its optimization direction
+        (or goal) as an optuna.study.StudyDirection.
+        """
+        # if single-objective, just the top-level metric is set
+        if self._sweep_config.get("metric"):
+            direction = (
+                optuna.study.StudyDirection.MINIMIZE
+                if self._sweep_config["metric"]["goal"] == "minimize"
+                else optuna.study.StudyDirection.MAXIMIZE
+            )
+            metric = Metric(
+                name=self._sweep_config["metric"]["name"], direction=direction
+            )
+            return [metric]
+
+        # multi-objective optimization
+        metric_defs = []
+        for metric in self._optuna_config.get("metrics", []):
+            if not metric.get("name"):
+                raise SchedulerError("Optuna metric missing name")
+            if not metric.get("goal"):
+                raise SchedulerError("Optuna metric missing goal")
+
+            direction = (
+                optuna.study.StudyDirection.MINIMIZE
+                if metric["goal"] == "minimize"
+                else optuna.study.StudyDirection.MAXIMIZE
+            )
+            metric_defs += [Metric(name=metric["name"], direction=direction)]
+
+        if len(metric_defs) == 0:
+            raise SchedulerError(
+                "Zero metrics found in both the top-level 'metric' section "
+                "and the multi-objective scheduler.settings.metrics section"
+            )
+
+        return metric_defs
+
     def _validate_optuna_study(self, study: optuna.Study) -> Optional[str]:
         """Accepts an optuna study, runs validation.
@@ -222,8 +281,7 @@ def _load_optuna_classes(
         mod, err = _get_module("optuna", filepath)
         if not mod:
             raise SchedulerError(
-                f"Failed to load optuna from path {filepath} "
-                f" with error: {err}"
+                f"Failed to load optuna from path {filepath} with error: {err}"
             )
 
         # Set custom optuna trial creation method
@@ -286,9 +344,7 @@ def _load_file_from_artifact(self, artifact_name: str) -> str:
         # load user-set optuna class definition file
         artifact = self._wandb_run.use_artifact(artifact_name, type="optuna")
         if not artifact:
-            raise SchedulerError(
-                f"Failed to load artifact: {artifact_name}"
-            )
+            raise SchedulerError(f"Failed to load artifact: {artifact_name}")
 
         path = artifact.download()
         optuna_filepath = self._optuna_config.get(
@@ -369,16 +425,26 @@ def _load_optuna(self) -> None:
         else:
             wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler")
 
-        direction = self._sweep_config.get("metric", {}).get("goal")
         self._storage_path = existing_storage or OptunaComponents.storage.value
-        self._study = optuna.create_study(
-            study_name=self.study_name,
-            storage=f"sqlite:///{self._storage_path}",
-            pruner=pruner,
-            sampler=sampler,
-            load_if_exists=True,
-            direction=direction,
-        )
+        directions = [metric.direction for metric in self._metric_defs]
+        if len(directions) == 1:
+            self._study = optuna.create_study(
+                study_name=self.study_name,
+                storage=f"sqlite:///{self._storage_path}",
+                pruner=pruner,
+                sampler=sampler,
+                load_if_exists=True,
+                direction=directions[0],
+            )
+        else:  # multi-objective optimization
+            self._study = optuna.create_study(
+                study_name=self.study_name,
+                storage=f"sqlite:///{self._storage_path}",
+                pruner=pruner,
+                sampler=sampler,
+                load_if_exists=True,
+                directions=directions,
+            )
 
         wandb.termlog(self.study_string)
         if existing_storage:
@@ -399,17 +465,13 @@ def _load_state(self) -> None:
     def _save_state(self) -> None:
         """Called when Scheduler class invokes exit().
 
-        Save optuna study sqlite data to an artifact in the controller run
+        Save the optuna study's sqlite data to an artifact in the scheduler run
         """
-        artifact = wandb.Artifact(
-            f"{OptunaComponents.storage.name}-{self._sweep_id}", type="optuna"
-        )
-        if not self._storage_path:
-            wandb.termwarn(
-                f"{LOG_PREFIX}No db storage path found, saving to default path"
-            )
-            self._storage_path = OptunaComponents.storage.value
+        if not self._study or not self._storage_path:  # nothing to save
+            return None
 
+        artifact_name = f"{OptunaComponents.storage.name}-{self._sweep_id}"
+        artifact = wandb.Artifact(artifact_name, type="optuna")
         artifact.add_file(self._storage_path)
         self._wandb_run.log_artifact(artifact)
 
@@ -420,6 +482,7 @@ def _save_state(self) -> None:
         return
 
     def _get_next_sweep_run(self, worker_id: int) -> Optional[SweepRun]:
+        """Called repeatedly in the polling loop, whenever a worker is available."""
        config, trial = self._trial_func()
         run: dict = self._api.upsert_run(
             project=self._project,
@@ -465,15 +528,20 @@ def _get_run_history(self, run_id: str) -> List[int]:
             logger.debug(f"Failed to poll run from public api: {str(e)}")
             return []
 
-        metric_name = self._sweep_config["metric"]["name"]
-        history = api_run.scan_history(keys=["_step", metric_name])
-        metrics = [x[metric_name] for x in history]
+        names = [metric.name for metric in self._metric_defs]
+        history = api_run.scan_history(keys=names + ["_step"])
+        metrics = []
+        for log in history:
+            if self.is_multi_objective:
+                metrics += [tuple(log.get(key) for key in names)]
+            else:
+                metrics += [log.get(names[0])]
 
         if len(metrics) == 0 and api_run.lastHistoryStep > -1:
             logger.debug("No metrics, but lastHistoryStep exists")
             wandb.termwarn(
-                f"{LOG_PREFIX}Detected logged metrics, but none matching " +
-                f"provided metric name: '{metric_name}'"
+                f"{LOG_PREFIX}Detected logged metrics, but none matching "
+                + f"provided metric name(s): '{names}'"
             )
 
         return metrics
@@ -481,17 +549,22 @@ def _get_run_history(self, run_id: str) -> List[int]:
     def _poll_run(self, orun: OptunaRun) -> bool:
         """Polls metrics for a run, returns true if finished."""
         metrics = self._get_run_history(orun.sweep_run.id)
-        for i, metric in enumerate(metrics[orun.num_metrics :]):
-            logger.debug(f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}")
-            prev = orun.trial._cached_frozen_trial.intermediate_values
-            if orun.num_metrics + i not in prev:
-                orun.trial.report(metric, orun.num_metrics + i)
-
-            if orun.trial.should_prune():
-                wandb.termlog(f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}")
-                self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
-                self._stop_run(orun.sweep_run.id)
-                return True
+        if not self.is_multi_objective:  # can't report to trial when multi
+            for i, metric_val in enumerate(metrics[orun.num_metrics :]):
+                logger.debug(
+                    f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}"
+                )
+                prev = orun.trial._cached_frozen_trial.intermediate_values
+                if orun.num_metrics + i not in prev:
+                    orun.trial.report(metric_val, orun.num_metrics + i)
+
+                if orun.trial.should_prune():
+                    wandb.termlog(
+                        f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}"
+                    )
+                    self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED)
+                    self._stop_run(orun.sweep_run.id)
+                    return True
 
         orun.num_metrics = len(metrics)
 
@@ -504,19 +577,21 @@ def _poll_run(self, orun: OptunaRun) -> bool:
         if (
             self._runs[orun.sweep_run.id].state == RunState.FINISHED
             and len(prev_metrics) == 0
+            and not self.is_multi_objective
         ):
             # run finished correctly, but never logged a metric
             wandb.termwarn(
-                f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " +
-                f"'{self._sweep_config['metric']['name']}'. Check your sweep " +
-                "config and training script."
+                f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: "
+                + f"'{self._metric_defs[0].name}'. Check your sweep "
+                + "config and training script."
             )
             self._num_misconfigured_runs += 1
             self.study.tell(orun.trial, state=optuna.trial.TrialState.FAIL)
 
             if self._num_misconfigured_runs >= self.MAX_MISCONFIGURED_RUNS:
                 raise SchedulerError(
-                    f"Too many misconfigured runs ({self._num_misconfigured_runs}), stopping sweep early"
+                    f"Too many misconfigured runs ({self._num_misconfigured_runs}),"
+                    " stopping sweep early"
                 )
 
             # Delete run in Scheduler memory, freeing up worker
@@ -524,9 +599,12 @@ def _poll_run(self, orun: OptunaRun) -> bool:
             return True
 
-        last_value = prev_metrics[orun.num_metrics - 1]
-        self._num_misconfigured_runs = 0  # only count consecutive
+        if self.is_multi_objective:
+            last_value = tuple(metrics[-1])
+        else:
+            last_value = prev_metrics[orun.num_metrics - 1]
+        self._num_misconfigured_runs = 0  # only count consecutive
 
         self.study.tell(
             trial=orun.trial,
             state=optuna.trial.TrialState.COMPLETE,
@@ -534,7 +612,8 @@ def _poll_run(self, orun: OptunaRun) -> bool:
         )
         wandb.termlog(
             f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) "
-            f"[last metric: {last_value}, total: {orun.num_metrics}]"
+            f"[last metric{'s' if self.is_multi_objective else ''}: {last_value}"
+            f", total: {orun.num_metrics}]"
         )
 
         # Delete run in Scheduler memory, freeing up worker
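
Note on using the multi-objective support added above: below is a minimal sketch of the two sweep-config shapes that _get_metric_names_and_directions() accepts. The metric and parameter names/values are illustrative assumptions, not part of this patch; the only keys the scheduler reads for objectives are the top-level "metric" block and the "scheduler.settings.metrics" list.

# single-objective: the usual top-level metric block is found first
single_objective_config = {
    "metric": {"name": "val_loss", "goal": "minimize"},  # illustrative name
    "parameters": {"lr": {"min": 1e-5, "max": 1e-1}},  # illustrative
}

# multi-objective: omit the top-level metric and list one {name, goal} pair
# per objective under scheduler.settings.metrics; each entry becomes a
# Metric with an optuna.study.StudyDirection, and optuna.create_study()
# is then called with directions=[...] instead of direction=...
multi_objective_config = {
    "scheduler": {
        "settings": {
            "metrics": [
                {"name": "val_loss", "goal": "minimize"},  # illustrative
                {"name": "val_accuracy", "goal": "maximize"},  # illustrative
            ],
        },
    },
    "parameters": {"lr": {"min": 1e-5, "max": 1e-1}},  # illustrative
}

Note that in the multi-objective path the scheduler skips intermediate reporting and pruning (_poll_run only reports to the trial when single-objective), so only each run's final logged values feed the study.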