From 5d616ce6e203e7b2c0595a57e05dde701e65d4fe Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Fri, 16 Jun 2023 10:18:21 -0700 Subject: [PATCH 01/17] resuming is working --- jobs/fashion_mnist_train/job.py | 43 ++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index e035a52..015e232 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -6,18 +6,22 @@ import matplotlib.pyplot as plt import numpy as np import wandb + # To load the mnist data from keras.datasets import fashion_mnist +from keras.models import load_model + # importing various types of hidden layers from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D from tensorflow.keras.models import Sequential + # Adam legacy for m1/m2 macs from tensorflow.keras.optimizers.legacy import Adam -from wandb.keras import WandbMetricsLogger +from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint def train(project: Optional[str], entity: Optional[str], **kwargs: Any): - run = wandb.init(project=project, entity=entity) + run = wandb.init(project=project, entity=entity, resume=True) # get config, could be set from sweep scheduler train_config = run.config @@ -51,13 +55,17 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): train_X = np.expand_dims(train_X, -1).astype(np.float32) test_X = np.expand_dims(test_X, -1) - # load model - model = model_arch() - model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss="sparse_categorical_crossentropy", - metrics=["sparse_categorical_accuracy"], - ) + # load model from checkpoint or create new model + ckpt = get_checkpoint() + if ckpt: + model = ckpt + else: + model = model_arch() + model.compile( + optimizer=Adam(learning_rate=learning_rate), + loss="sparse_categorical_crossentropy", + metrics=["sparse_categorical_accuracy"], + ) model.summary() model.fit( @@ -66,9 +74,7 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): epochs=epochs, steps_per_epoch=steps_per_epoch, validation_split=0.33, - callbacks=[ - WandbMetricsLogger(), - ], + callbacks=[WandbMetricsLogger(), WandbModelCheckpoint(filepath="model.h5")], ) # do some manual testing @@ -135,6 +141,19 @@ def model_arch(): return models +def get_checkpoint(): + assert wandb.run + api = wandb.Api() + run = api.run(wandb.run.path) + for artifact in run.logged_artifacts(): + if artifact.type == "model": + name = artifact.source_qualified_name + name = name.split(":")[0] + ":latest" + artifact = api.artifact(name) + path = artifact.download(root=".") + return load_model(path + "/model.h5") + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--entity", "-e", type=str, default=None) From dca31e8e5fd55e8d493590a317e6a8df73b0ad73 Mon Sep 17 00:00:00 2001 From: Tim Hays Date: Wed, 21 Jun 2023 16:16:31 -0700 Subject: [PATCH 02/17] Add more args --- jobs/fashion_mnist_train/job.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index 015e232..c892e2e 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -20,16 +20,20 @@ from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint -def train(project: Optional[str], entity: Optional[str], **kwargs: Any): - run = wandb.init(project=project, entity=entity, resume=True) - - # get config, could be set from sweep scheduler - 
train_config = run.config - - # get training parameters from config - epochs = train_config.get("epochs", 10) - learning_rate = train_config.get("learning_rate", 0.001) - steps_per_epoch = train_config.get("steps_per_epoch", 100) +def train( + project: Optional[str], + entity: Optional[str], + epochs: Optional[int], + learning_rate: Optional[float], + steps_per_epoch: Optional[int], + **kwargs: Any +): + config = { + "epochs": epochs, + "learning_rate": learning_rate, + "steps_per_epoch": steps_per_epoch, + } + wandb.init(project=project, entity=entity, config=config, resume=True) # load data (train_X, train_y), (test_X, test_y) = fashion_mnist.load_data() @@ -158,6 +162,9 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--entity", "-e", type=str, default=None) parser.add_argument("--project", "-p", type=str, default=None) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--learning_rate", type=float, default=0.001) + parser.add_argument("--steps_per_epoch", type=int, default=100) args = parser.parse_args() train(**vars(args)) From a4f16ebbdb067974adbe6c3b63300d0679c235d9 Mon Sep 17 00:00:00 2001 From: Tim Hays Date: Thu, 22 Jun 2023 09:36:26 -0700 Subject: [PATCH 03/17] Use config values --- jobs/fashion_mnist_train/job.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index c892e2e..2a112c3 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -33,7 +33,15 @@ def train( "learning_rate": learning_rate, "steps_per_epoch": steps_per_epoch, } - wandb.init(project=project, entity=entity, config=config, resume=True) + run = wandb.init(project=project, entity=entity, config=config, resume=True) + + # get config, could be set from sweep scheduler + train_config = run.config + + # get training parameters from config + epochs = train_config.get("epochs", 10) + learning_rate = train_config.get("learning_rate", 0.001) + steps_per_epoch = train_config.get("steps_per_epoch", 100) # load data (train_X, train_y), (test_X, test_y) = fashion_mnist.load_data() From e7449ca5b0d864ba89f3b481b1ae827403d0a985 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 26 Jun 2023 11:34:38 -0700 Subject: [PATCH 04/17] feat(sweeps): optuna supports multi-objective optimization --- .../optuna_scheduler/optuna_scheduler.py | 159 ++++++++++++------ 1 file changed, 107 insertions(+), 52 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 1932ed9..51fd845 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -1,3 +1,4 @@ +import argparse import base64 import logging import os @@ -10,7 +11,6 @@ from types import ModuleType from typing import Any, Dict, List, Optional, Tuple -import argparse import click import optuna import wandb @@ -19,8 +19,7 @@ from wandb.apis.public import QueuedRun, Run from wandb.sdk.artifacts.public_artifact import Artifact from wandb.sdk.launch.sweeps import SchedulerError -from wandb.sdk.launch.sweeps.scheduler import Scheduler, SweepRun, RunState - +from wandb.sdk.launch.sweeps.scheduler import RunState, Scheduler, SweepRun logger = logging.getLogger(__name__) optuna.logging.set_verbosity(optuna.logging.WARNING) @@ -130,6 +129,16 @@ def __init__( if not self._optuna_config: self._optuna_config = self._wandb_run.config.get("settings", {}) + 
if self._sweep_config.get("metric"): + self.metric_names = [self._sweep_config["metric"]["name"]] + else: # multi-objective optimization + for metric in self._optuna_config.get("metrics", []): + if not metric.get("name"): + raise SchedulerError("Optuna metric missing name") + if not metric.get("goal"): + raise SchedulerError("Optuna metric missing goal") + self.metric_names += [metric["name"]] + # if metric is misconfigured, increment, stop sweep if 3 consecutive fails self._num_misconfigured_runs = 0 @@ -146,12 +155,22 @@ def study_name(self) -> str: optuna_study_name: str = self.study.study_name return optuna_study_name + @property + def is_multi_objective(self) -> bool: + return len(self.metric_names) > 1 + @property def study_string(self) -> str: msg = f"{LOG_PREFIX}{'Loading' if self._wandb_run.resumed else 'Creating'}" msg += f" optuna study: {self.study_name} " msg += f"[storage:{self.study._storage.__class__.__name__}" - msg += f", direction:{self.study.direction.name.capitalize()}" + if not self.is_multi_objective: + msg += f", direction: {self.study.direction.name.capitalize()}" + else: + msg += ", directions: " + for i, metric in enumerate(self.metric_names): + msg += f"{metric}:{self.study.directions[i].name.capitalize()}, " + msg = msg[:-2] msg += f", pruner:{self.study.pruner.__class__.__name__}" msg += f", sampler:{self.study.sampler.__class__.__name__}]" return msg @@ -169,17 +188,30 @@ def formatted_trials(self) -> str: trial_strs = [] for trial in self.study.trials: run_id = trial.user_attrs["run_id"] - vals = list(trial.intermediate_values.values()) - best = None - if len(vals) > 0: - if self.study.direction == optuna.study.StudyDirection.MINIMIZE: - best = round(min(vals), 5) - elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE: - best = round(max(vals), 5) - trial_strs += [ - f"\t[trial-{trial.number + 1}] run: {run_id}, state: " - f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}" - ] + best: str = "" + if not self.is_multi_objective: + vals = list(trial.intermediate_values.values()) + if len(vals) > 0: + if self.study.direction == optuna.study.StudyDirection.MINIMIZE: + best = f"{round(min(vals), 5)}" + elif self.study.direction == optuna.study.StudyDirection.MAXIMIZE: + best = f"{round(max(vals), 5)}" + trial_strs += [ + f"\t[trial-{trial.number + 1}] run: {run_id}, state: " + f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}" + ] + else: # multi-objective optimization, only 1 metric logged in study + vals = trial.values + if vals: + for idx in range(len(self.metric_names)): + direction = self.study.directions[idx].name.capitalize() + best += f"{self.metric_names[idx]} ({direction}):" + best += f"{round(vals[idx], 5)}, " + best = best[:-2] + trial_strs += [ + f"\t[trial-{trial.number + 1}] run: {run_id}, state: " + f"{trial.state.name}, best: {best or 'None'}" + ] return "\n".join(trial_strs[-10:]) # only print out last 10 @@ -222,8 +254,7 @@ def _load_optuna_classes( mod, err = _get_module("optuna", filepath) if not mod: raise SchedulerError( - f"Failed to load optuna from path {filepath} " - f" with error: {err}" + f"Failed to load optuna from path {filepath} " f" with error: {err}" ) # Set custom optuna trial creation method @@ -286,9 +317,7 @@ def _load_file_from_artifact(self, artifact_name: str) -> str: # load user-set optuna class definition file artifact = self._wandb_run.use_artifact(artifact_name, type="optuna") if not artifact: - raise SchedulerError( - f"Failed to load artifact: {artifact_name}" - ) + raise 
SchedulerError(f"Failed to load artifact: {artifact_name}") path = artifact.download() optuna_filepath = self._optuna_config.get( @@ -369,16 +398,30 @@ def _load_optuna(self) -> None: else: wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler") - direction = self._sweep_config.get("metric", {}).get("goal") + if len(self.metric_names) == 1: + directions = [self._sweep_config.get("metric", {}).get("goal")] + else: + directions = [x["goal"] for x in self._optuna_config["metrics"]] + self._storage_path = existing_storage or OptunaComponents.storage.value - self._study = optuna.create_study( - study_name=self.study_name, - storage=f"sqlite:///{self._storage_path}", - pruner=pruner, - sampler=sampler, - load_if_exists=True, - direction=direction, - ) + if len(self.metric_names) == 1: + self._study = optuna.create_study( + study_name=self.study_name, + storage=f"sqlite:///{self._storage_path}", + pruner=pruner, + sampler=sampler, + load_if_exists=True, + direction=directions[0], + ) + else: # multi-objective optimization + self._study = optuna.create_study( + study_name=self.study_name, + storage=f"sqlite:///{self._storage_path}", + pruner=pruner, + sampler=sampler, + load_if_exists=True, + directions=directions, + ) wandb.termlog(self.study_string) if existing_storage: @@ -465,15 +508,16 @@ def _get_run_history(self, run_id: str) -> List[int]: logger.debug(f"Failed to poll run from public api: {str(e)}") return [] - metric_name = self._sweep_config["metric"]["name"] - history = api_run.scan_history(keys=["_step", metric_name]) - metrics = [x[metric_name] for x in history] + history = api_run.scan_history(keys=self.metric_names + ["_step"]) + metrics = [] + for log in history: + metrics += [tuple(log.get(key) for key in self.metric_names)] if len(metrics) == 0 and api_run.lastHistoryStep > -1: logger.debug("No metrics, but lastHistoryStep exists") wandb.termwarn( - f"{LOG_PREFIX}Detected logged metrics, but none matching " + - f"provided metric name: '{metric_name}'" + f"{LOG_PREFIX}Detected logged metrics, but none matching " + + f"provided metric name(s): '{self.metric_names}'" ) return metrics @@ -481,17 +525,22 @@ def _get_run_history(self, run_id: str) -> List[int]: def _poll_run(self, orun: OptunaRun) -> bool: """Polls metrics for a run, returns true if finished.""" metrics = self._get_run_history(orun.sweep_run.id) - for i, metric in enumerate(metrics[orun.num_metrics :]): - logger.debug(f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}") - prev = orun.trial._cached_frozen_trial.intermediate_values - if orun.num_metrics + i not in prev: - orun.trial.report(metric, orun.num_metrics + i) - - if orun.trial.should_prune(): - wandb.termlog(f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}") - self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED) - self._stop_run(orun.sweep_run.id) - return True + if not self.is_multi_objective: # can't report to trial when multi + for i, metric in enumerate(metrics[orun.num_metrics :]): + logger.debug( + f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}" + ) + prev = orun.trial._cached_frozen_trial.intermediate_values + if orun.num_metrics + i not in prev: + orun.trial.report(metric, orun.num_metrics + i) + + if orun.trial.should_prune(): + wandb.termlog( + f"{LOG_PREFIX}Optuna pruning run: {orun.sweep_run.id}" + ) + self.study.tell(orun.trial, state=optuna.trial.TrialState.PRUNED) + self._stop_run(orun.sweep_run.id) + return True orun.num_metrics = len(metrics) @@ -504,19 +553,21 @@ def _poll_run(self, orun: 
OptunaRun) -> bool: if ( self._runs[orun.sweep_run.id].state == RunState.FINISHED and len(prev_metrics) == 0 + and not self.is_multi_objective ): # run finished correctly, but never logged a metric wandb.termwarn( - f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " + - f"'{self._sweep_config['metric']['name']}'. Check your sweep " + - "config and training script." + f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " + + f"'{self._sweep_config['metric']['name']}'. Check your sweep " + + "config and training script." ) self._num_misconfigured_runs += 1 self.study.tell(orun.trial, state=optuna.trial.TrialState.FAIL) if self._num_misconfigured_runs >= self.MAX_MISCONFIGURED_RUNS: raise SchedulerError( - f"Too many misconfigured runs ({self._num_misconfigured_runs}), stopping sweep early" + f"Too many misconfigured runs ({self._num_misconfigured_runs})," + " stopping sweep early" ) # Delete run in Scheduler memory, freeing up worker @@ -524,9 +575,12 @@ def _poll_run(self, orun: OptunaRun) -> bool: return True - last_value = prev_metrics[orun.num_metrics - 1] - self._num_misconfigured_runs = 0 # only count consecutive + if self.is_multi_objective: + last_value = tuple(metrics[-1]) + else: + last_value = prev_metrics[orun.num_metrics - 1] + self._num_misconfigured_runs = 0 # only count consecutive self.study.tell( trial=orun.trial, state=optuna.trial.TrialState.COMPLETE, @@ -534,7 +588,8 @@ def _poll_run(self, orun: OptunaRun) -> bool: ) wandb.termlog( f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) " - f"[last metric: {last_value}, total: {orun.num_metrics}]" + f"[last metric{'s' if len(self.metric_names) > 1 else ''}: {last_value}" + f", total: {orun.num_metrics}]" ) # Delete run in Scheduler memory, freeing up worker From a8ff3e2776e57499d963cf209e6131960c84b491 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 26 Jun 2023 15:47:44 -0700 Subject: [PATCH 05/17] add metric class --- .../optuna_scheduler/optuna_scheduler.py | 109 ++++++++++++------ 1 file changed, 76 insertions(+), 33 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 51fd845..1747fe8 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -43,6 +43,12 @@ class OptunaRun: sweep_run: SweepRun +@dataclass +class Metric: + name: str + direction: optuna.study.StudyDirection + + def setup_scheduler(scheduler: Scheduler, **kwargs): """Setup a run to log a scheduler job. 
@@ -129,15 +135,7 @@ def __init__( if not self._optuna_config: self._optuna_config = self._wandb_run.config.get("settings", {}) - if self._sweep_config.get("metric"): - self.metric_names = [self._sweep_config["metric"]["name"]] - else: # multi-objective optimization - for metric in self._optuna_config.get("metrics", []): - if not metric.get("name"): - raise SchedulerError("Optuna metric missing name") - if not metric.get("goal"): - raise SchedulerError("Optuna metric missing goal") - self.metric_names += [metric["name"]] + self._metric_defs = self._get_metric_names_and_directions() # if metric is misconfigured, increment, stop sweep if 3 consecutive fails self._num_misconfigured_runs = 0 @@ -157,7 +155,7 @@ def study_name(self) -> str: @property def is_multi_objective(self) -> bool: - return len(self.metric_names) > 1 + return len(self._metric_defs) > 1 @property def study_string(self) -> str: @@ -165,11 +163,11 @@ def study_string(self) -> str: msg += f" optuna study: {self.study_name} " msg += f"[storage:{self.study._storage.__class__.__name__}" if not self.is_multi_objective: - msg += f", direction: {self.study.direction.name.capitalize()}" + msg += f", direction: {self._metric_defs[0].direction.name.capitalize()}" else: msg += ", directions: " - for i, metric in enumerate(self.metric_names): - msg += f"{metric}:{self.study.directions[i].name.capitalize()}, " + for metric in self._metric_defs: + msg += f"{metric.name}:{metric.direction.name.capitalize()}, " msg = msg[:-2] msg += f", pruner:{self.study.pruner.__class__.__name__}" msg += f", sampler:{self.study.sampler.__class__.__name__}]" @@ -201,13 +199,23 @@ def formatted_trials(self) -> str: f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}" ] else: # multi-objective optimization, only 1 metric logged in study - vals = trial.values - if vals: - for idx in range(len(self.metric_names)): - direction = self.study.directions[idx].name.capitalize() - best += f"{self.metric_names[idx]} ({direction}):" - best += f"{round(vals[idx], 5)}, " - best = best[:-2] + if not trial.values: + continue + + if len(trial.values) != len(self._metric_defs): + wandb.termwarn( + f"{LOG_PREFIX}Number of trial metrics ({trial.values})" + " does not match number of metrics defined " + f"({self._metric_defs})" + ) + continue + + for val, metric in zip(trial.values, self._metric_defs): + direction = metric.direction.name.capitalize() + best += f"{metric.name} ({direction}):" + best += f"{round(val, 5)}, " + + best = best[:-2] trial_strs += [ f"\t[trial-{trial.number + 1}] run: {run_id}, state: " f"{trial.state.name}, best: {best or 'None'}" @@ -215,6 +223,44 @@ def formatted_trials(self) -> str: return "\n".join(trial_strs[-10:]) # only print out last 10 + def _get_metric_names_and_directions(self) -> List[Metric]: + """Helper to configure dict of at least one metric. 
+ + Dict contains the metric names as keys, with the optimization + direction (or goal) as the value + """ + # if single-objective, just top level metric is set + if self._sweep_config.get("metric"): + direction = ( + optuna.study.StudyDirection.MINIMIZE + if self._sweep_config["metric"]["goal"] == "minimize" + else optuna.study.StudyDirection.MAXIMIZE + ) + metric = Metric( + name=self._sweep_config["metric"]["name"], direction=direction + ) + return [metric] + + # multi-objective optimization + metric_defs = [] + for metric in self._optuna_config.get("metrics", []): + if not metric.get("name"): + raise SchedulerError("Optuna metric missing name") + if not metric.get("goal"): + raise SchedulerError("Optuna metric missing goal") + + direction = ( + optuna.study.StudyDirection.MINIMIZE + if metric["goal"] == "minimize" + else optuna.study.StudyDirection.MAXIMIZE + ) + metric_defs += [Metric(name=metric["name"], direction=direction)] + + if len(metric_defs) == 0: + raise SchedulerError("Optuna sweep missing metric") + + return metric_defs + def _validate_optuna_study(self, study: optuna.Study) -> Optional[str]: """Accepts an optuna study, runs validation. @@ -398,13 +444,9 @@ def _load_optuna(self) -> None: else: wandb.termlog(f"{LOG_PREFIX}No sampler args, defaulting to TPESampler") - if len(self.metric_names) == 1: - directions = [self._sweep_config.get("metric", {}).get("goal")] - else: - directions = [x["goal"] for x in self._optuna_config["metrics"]] - self._storage_path = existing_storage or OptunaComponents.storage.value - if len(self.metric_names) == 1: + directions = [metric.direction for metric in self._metric_defs] + if len(directions) == 1: self._study = optuna.create_study( study_name=self.study_name, storage=f"sqlite:///{self._storage_path}", @@ -508,16 +550,17 @@ def _get_run_history(self, run_id: str) -> List[int]: logger.debug(f"Failed to poll run from public api: {str(e)}") return [] - history = api_run.scan_history(keys=self.metric_names + ["_step"]) + names = [metric.name for metric in self._metric_defs] + history = api_run.scan_history(keys=names + ["_step"]) metrics = [] for log in history: - metrics += [tuple(log.get(key) for key in self.metric_names)] + metrics += [tuple(log.get(key) for key in names)] if len(metrics) == 0 and api_run.lastHistoryStep > -1: logger.debug("No metrics, but lastHistoryStep exists") wandb.termwarn( f"{LOG_PREFIX}Detected logged metrics, but none matching " - + f"provided metric name(s): '{self.metric_names}'" + + f"provided metric name(s): '{names}'" ) return metrics @@ -526,13 +569,13 @@ def _poll_run(self, orun: OptunaRun) -> bool: """Polls metrics for a run, returns true if finished.""" metrics = self._get_run_history(orun.sweep_run.id) if not self.is_multi_objective: # can't report to trial when multi - for i, metric in enumerate(metrics[orun.num_metrics :]): + for i, metric_val in enumerate(metrics[orun.num_metrics :]): logger.debug( f"{orun.sweep_run.id} (step:{i+orun.num_metrics}) {metrics}" ) prev = orun.trial._cached_frozen_trial.intermediate_values if orun.num_metrics + i not in prev: - orun.trial.report(metric, orun.num_metrics + i) + orun.trial.report(metric_val, orun.num_metrics + i) if orun.trial.should_prune(): wandb.termlog( @@ -558,7 +601,7 @@ def _poll_run(self, orun: OptunaRun) -> bool: # run finished correctly, but never logged a metric wandb.termwarn( f"{LOG_PREFIX}Run ({orun.sweep_run.id}) never logged metric: " - + f"'{self._sweep_config['metric']['name']}'. Check your sweep " + + f"'{self._metric_defs[0].name}'. 
Check your sweep " + "config and training script." ) self._num_misconfigured_runs += 1 @@ -588,7 +631,7 @@ def _poll_run(self, orun: OptunaRun) -> bool: ) wandb.termlog( f"{LOG_PREFIX}Completing trial for run ({orun.sweep_run.id}) " - f"[last metric{'s' if len(self.metric_names) > 1 else ''}: {last_value}" + f"[last metric{'s' if self.is_multi_objective else ''}: {last_value}" f", total: {orun.num_metrics}]" ) From b90c26d4ae6d67b282629fb9e273e8c39c167716 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 27 Jun 2023 10:26:17 -0700 Subject: [PATCH 06/17] small refactor --- .../optuna_scheduler/optuna_scheduler.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 1747fe8..92eeb2d 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple import click +import joblib import optuna import wandb from wandb.apis.internal import Api @@ -97,18 +98,20 @@ def _handle_job_logic(run, name, enable_git=False) -> None: ) tag = os.environ.get("WANDB_DOCKER", "").split(":") if len(tag) == 2: - jobstr += f"-{tag[0].replace('/', '_')}_{tag[-1]}:latest" - else: + jobstr += f"-{tag[0].replace('/', '_')}:{tag[-1]}" + elif len(tag) == 1: + jobstr += f"-{tag[0].replace('/', '_')}:latest" + else: # unknown format jobstr = f"found here: https://wandb.ai/{jobstr}s/" wandb.termlog(f"Creating image job {click.style(jobstr, fg='yellow')}\n") elif not enable_git: jobstr += f"-{name}:latest" wandb.termlog( - f"Creating code-artifact job: {click.style(jobstr, fg='yellow')}\n" + f"Creating artifact job: {click.style(jobstr, fg='yellow')}\n" ) else: _s = click.style(f"https://wandb.ai/{jobstr}s/", fg="yellow") - wandb.termlog(f"Creating repo-artifact job found here: {_s}\n") + wandb.termlog(f"Creating repo job found here: {_s}\n") run.log_code(name=name, exclude_fn=lambda x: x.startswith("_")) return @@ -227,7 +230,7 @@ def _get_metric_names_and_directions(self) -> List[Metric]: """Helper to configure dict of at least one metric. Dict contains the metric names as keys, with the optimization - direction (or goal) as the value + direction (or goal) as the value (type: optuna.study.StudyDirection) """ # if single-objective, just top level metric is set if self._sweep_config.get("metric"): @@ -484,17 +487,19 @@ def _load_state(self) -> None: def _save_state(self) -> None: """Called when Scheduler class invokes exit(). 
- Save optuna study sqlite data to an artifact in the controller run + Save optuna study, or sqlite data to an artifact in the scheduler run """ - artifact = wandb.Artifact( - f"{OptunaComponents.storage.name}-{self._sweep_id}", type="optuna" - ) + if not self._study: # nothing to save + return None + + artifact_name = f"{OptunaComponents.storage.name}-{self._sweep_id}" if not self._storage_path: - wandb.termwarn( - f"{LOG_PREFIX}No db storage path found, saving to default path" - ) - self._storage_path = OptunaComponents.storage.value + wandb.termwarn(f"{LOG_PREFIX}No db storage path found, saving full model") + artifact_name = f"optuna-study-{self._sweep_id}" + joblib.dump(self.study, f"study-{self._sweep_id}.pkl") + self._storage_path = f"study-{self._sweep_id}.pkl" + artifact = wandb.Artifact(artifact_name, type="optuna") artifact.add_file(self._storage_path) self._wandb_run.log_artifact(artifact) @@ -505,6 +510,7 @@ def _save_state(self) -> None: return def _get_next_sweep_run(self, worker_id: int) -> Optional[SweepRun]: + """Called repeatedly in the polling loop, whenever a worker is available.""" config, trial = self._trial_func() run: dict = self._api.upsert_run( project=self._project,

From 66d9c7d0a7b5150d8d6687a6ec35067988dae3e8 Mon Sep 17 00:00:00 2001 From: Tim Hays Date: Tue, 27 Jun 2023 11:02:18 -0700 Subject: [PATCH 07/17] Set initial epoch --- jobs/fashion_mnist_train/job.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index 2a112c3..3f35815 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -87,6 +87,7 @@ def train( steps_per_epoch=steps_per_epoch, validation_split=0.33, callbacks=[WandbMetricsLogger(), WandbModelCheckpoint(filepath="model.h5")], + initial_epoch=wandb.run.step, ) # do some manual testing

From 5b021e0c0791a7030710aa0d70245cec27d15b9a Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Wed, 5 Jul 2023 11:21:22 -0700 Subject: [PATCH 08/17] add joblib to reqs --- jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py | 2 +- jobs/sweep_schedulers/optuna_scheduler/requirements.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 92eeb2d..32894e2 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -111,7 +111,7 @@ def _handle_job_logic(run, name, enable_git=False) -> None: ) else: _s = click.style(f"https://wandb.ai/{jobstr}s/", fg="yellow") - wandb.termlog(f"Creating repo job found here: {_s}\n") + wandb.termlog(f"Creating git repo job found here: {_s}\n") run.log_code(name=name, exclude_fn=lambda x: x.startswith("_")) return diff --git a/jobs/sweep_schedulers/optuna_scheduler/requirements.txt b/jobs/sweep_schedulers/optuna_scheduler/requirements.txt index a312fc9..04f1524 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/requirements.txt +++ b/jobs/sweep_schedulers/optuna_scheduler/requirements.txt @@ -1,3 +1,4 @@ wandb optuna scipy +joblib

From 193a1872778beeec1a10f0a3db8d662890e857e6 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 1 Aug 2023 08:32:32 -0700 Subject: [PATCH 13/17] updated for new sdk version --- .../optuna_scheduler/optuna_scheduler.py | 44 ++++--------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 32894e2..708bdfe 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -18,7 +18,7 @@ from wandb.apis.internal import Api from wandb.apis.public import Api as PublicApi from wandb.apis.public import QueuedRun, Run -from wandb.sdk.artifacts.public_artifact import Artifact +from wandb.sdk.artifacts.artifact import Artifact from wandb.sdk.launch.sweeps import SchedulerError from wandb.sdk.launch.sweeps.scheduler import RunState, Scheduler, SweepRun @@ -62,20 +62,24 @@ def setup_scheduler(scheduler: Scheduler, 
**kwargs): parser.add_argument("--project", type=str, default=kwargs.get("project")) parser.add_argument("--entity", type=str, default=kwargs.get("entity")) parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--name", type=str, default=None) + parser.add_argument("--name", type=str, default=f"job-{scheduler.__name__}") parser.add_argument("--enable_git", action="store_true", default=False) cli_args = parser.parse_args() - name = cli_args.name or scheduler.__name__ + settings = {"job_name": cli_args.name} + if cli_args.enable_git: + settings.update({"disable_git": True}) + run = wandb.init( - settings={"disable_git": True} if not cli_args.enable_git else {}, + settings=settings, project=cli_args.project, entity=cli_args.entity, ) config = run.config if not config.get("sweep_args", {}).get("sweep_id"): - _handle_job_logic(run, name, cli_args.enable_git) + # not a sweep, just finish the run and return + run.finish() return args = config.get("sweep_args", {}) @@ -86,36 +90,6 @@ def setup_scheduler(scheduler: Scheduler, **kwargs): _scheduler.start() -def _handle_job_logic(run, name, enable_git=False) -> None: - wandb.termlog( - "\nJob not configured to run a sweep, logging code and returning early." - ) - jobstr = f"{run.entity}/{run.project}/job" - - if os.environ.get("WANDB_DOCKER"): - wandb.termlog( - "Identified 'WANDB_DOCKER' environment var, creating image job..." - ) - tag = os.environ.get("WANDB_DOCKER", "").split(":") - if len(tag) == 2: - jobstr += f"-{tag[0].replace('/', '_')}:{tag[-1]}" - elif len(tag) == 1: - jobstr += f"-{tag[0].replace('/', '_')}:latest" - else: # unknown format - jobstr = f"found here: https://wandb.ai/{jobstr}s/" - wandb.termlog(f"Creating image job {click.style(jobstr, fg='yellow')}\n") - elif not enable_git: - jobstr += f"-{name}:latest" - wandb.termlog( - f"Creating artifact job: {click.style(jobstr, fg='yellow')}\n" - ) - else: - _s = click.style(f"https://wandb.ai/{jobstr}s/", fg="yellow") - wandb.termlog(f"Creating git repo job found here: {_s}\n") - run.log_code(name=name, exclude_fn=lambda x: x.startswith("_")) - return - - class OptunaScheduler(Scheduler): OPT_TIMEOUT = 2 MAX_MISCONFIGURED_RUNS = 3 From 2487be2312ac24ca1baa416e70a2c502788db1b3 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 1 Aug 2023 08:50:19 -0700 Subject: [PATCH 14/17] no job changes --- jobs/fashion_mnist_train/job.py | 65 ++++++++------------------------- 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index 3f35815..e035a52 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -6,34 +6,18 @@ import matplotlib.pyplot as plt import numpy as np import wandb - # To load the mnist data from keras.datasets import fashion_mnist -from keras.models import load_model - # importing various types of hidden layers from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D from tensorflow.keras.models import Sequential - # Adam legacy for m1/m2 macs from tensorflow.keras.optimizers.legacy import Adam -from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint - - -def train( - project: Optional[str], - entity: Optional[str], - epochs: Optional[int], - learning_rate: Optional[float], - steps_per_epoch: Optional[int], - **kwargs: Any -): - config = { - "epochs": epochs, - "learning_rate": learning_rate, - "steps_per_epoch": steps_per_epoch, - } - run = wandb.init(project=project, entity=entity, config=config, resume=True) +from 
wandb.keras import WandbMetricsLogger + + +def train(project: Optional[str], entity: Optional[str], **kwargs: Any): + run = wandb.init(project=project, entity=entity) # get config, could be set from sweep scheduler train_config = run.config @@ -67,17 +51,13 @@ def train( train_X = np.expand_dims(train_X, -1).astype(np.float32) test_X = np.expand_dims(test_X, -1) - # load model from checkpoint or create new model - ckpt = get_checkpoint() - if ckpt: - model = ckpt - else: - model = model_arch() - model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss="sparse_categorical_crossentropy", - metrics=["sparse_categorical_accuracy"], - ) + # load model + model = model_arch() + model.compile( + optimizer=Adam(learning_rate=learning_rate), + loss="sparse_categorical_crossentropy", + metrics=["sparse_categorical_accuracy"], + ) model.summary() model.fit( @@ -86,8 +66,9 @@ def train( epochs=epochs, steps_per_epoch=steps_per_epoch, validation_split=0.33, - callbacks=[WandbMetricsLogger(), WandbModelCheckpoint(filepath="model.h5")], - initial_epoch=wandb.run.step, + callbacks=[ + WandbMetricsLogger(), + ], ) # do some manual testing @@ -154,26 +135,10 @@ def model_arch(): return models -def get_checkpoint(): - assert wandb.run - api = wandb.Api() - run = api.run(wandb.run.path) - for artifact in run.logged_artifacts(): - if artifact.type == "model": - name = artifact.source_qualified_name - name = name.split(":")[0] + ":latest" - artifact = api.artifact(name) - path = artifact.download(root=".") - return load_model(path + "/model.h5") - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--entity", "-e", type=str, default=None) parser.add_argument("--project", "-p", type=str, default=None) - parser.add_argument("--epochs", type=int, default=10) - parser.add_argument("--learning_rate", type=float, default=0.001) - parser.add_argument("--steps_per_epoch", type=int, default=100) args = parser.parse_args() train(**vars(args)) From 312de6ef9bbb30fb1f43b0b4ee566f6b22bc1332 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 1 Aug 2023 09:12:59 -0700 Subject: [PATCH 15/17] metric conditioned on multi --- jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 708bdfe..144d847 100644 --- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py +++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py @@ -534,7 +534,10 @@ def _get_run_history(self, run_id: str) -> List[int]: history = api_run.scan_history(keys=names + ["_step"]) metrics = [] for log in history: - metrics += [tuple(log.get(key) for key in names)] + if self.is_multi_objective: + metrics += [tuple(log.get(key) for key in names)] + else: + metrics += [log.get(names[0])] if len(metrics) == 0 and api_run.lastHistoryStep > -1: logger.debug("No metrics, but lastHistoryStep exists") From de032a5faf80ce7dde84b5aca591e08067587a3f Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 1 Aug 2023 13:29:41 -0700 Subject: [PATCH 16/17] remove joblib --- .../optuna_scheduler/optuna_scheduler.py | 9 +-------- jobs/sweep_schedulers/optuna_scheduler/requirements.txt | 1 - 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py index 144d847..ef45319 100644 --- 
From de032a5faf80ce7dde84b5aca591e08067587a3f Mon Sep 17 00:00:00 2001
From: gtarpenning
Date: Tue, 1 Aug 2023 13:29:41 -0700
Subject: [PATCH 16/17] remove joblib

---
 .../optuna_scheduler/optuna_scheduler.py                | 9 +--------
 jobs/sweep_schedulers/optuna_scheduler/requirements.txt | 1 -
 2 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
index 144d847..ef45319 100644
--- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
+++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
@@ -12,7 +12,6 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import click
-import joblib
 import optuna
 import wandb
 from wandb.apis.internal import Api
@@ -463,16 +462,10 @@ def _save_state(self) -> None:
         Save optuna study, or sqlite data to an artifact in the scheduler run
         """
-        if not self._study:  # nothing to save
+        if not self._study or self._storage_path:  # nothing to save
             return None
 
         artifact_name = f"{OptunaComponents.storage.name}-{self._sweep_id}"
-        if not self._storage_path:
-            wandb.termwarn(f"{LOG_PREFIX}No db storage path found, saving full model")
-            artifact_name = f"optuna-study-{self._sweep_id}"
-            joblib.dump(self.study, f"study-{self._sweep_id}.pkl")
-            self._storage_path = f"study-{self._sweep_id}.pkl"
-
         artifact = wandb.Artifact(artifact_name, type="optuna")
         artifact.add_file(self._storage_path)
         self._wandb_run.log_artifact(artifact)
diff --git a/jobs/sweep_schedulers/optuna_scheduler/requirements.txt b/jobs/sweep_schedulers/optuna_scheduler/requirements.txt
index 04f1524..a312fc9 100644
--- a/jobs/sweep_schedulers/optuna_scheduler/requirements.txt
+++ b/jobs/sweep_schedulers/optuna_scheduler/requirements.txt
@@ -1,4 +1,3 @@
 wandb
 optuna
 scipy
-joblib
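With the joblib pickling fallback removed, scheduler state persists only through the sqlite storage file logged as an artifact in _save_state. A hedged sketch of pulling such an artifact back into an optuna study — the artifact path, alias, and db filename below are assumptions, not values from the patch:

    # Resume an optuna study from a previously logged storage artifact.
    import optuna
    import wandb

    api = wandb.Api()
    # naming scheme and alias are guesses based on the artifact_name above
    artifact = api.artifact("my-team/my-project/optuna-scheduler-storage-abc123:latest")
    root = artifact.download()

    study = optuna.load_study(
        study_name=None,                        # valid when the db holds one study
        storage=f"sqlite:///{root}/optuna.db",  # db filename is an assumption
    )
    print(f"loaded {len(study.trials)} trials")
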
From d547ef12fd3faa2cae6771d85209c8c6d9071b34 Mon Sep 17 00:00:00 2001
From: gtarpenning
Date: Thu, 3 Aug 2023 14:47:50 -0700
Subject: [PATCH 17/17] simplifications and review comments

---
 .../optuna_scheduler/optuna_scheduler.py | 35 +++++++++++++----------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
index ef45319..514a41d 100644
--- a/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
+++ b/jobs/sweep_schedulers/optuna_scheduler/optuna_scheduler.py
@@ -62,26 +62,26 @@ def setup_scheduler(scheduler: Scheduler, **kwargs):
     parser.add_argument("--entity", type=str, default=kwargs.get("entity"))
     parser.add_argument("--num_workers", type=int, default=1)
     parser.add_argument("--name", type=str, default=f"job-{scheduler.__name__}")
-    parser.add_argument("--enable_git", action="store_true", default=False)
     cli_args = parser.parse_args()
 
     settings = {"job_name": cli_args.name}
-    if cli_args.enable_git:
-        settings.update({"disable_git": True})
-
     run = wandb.init(
         settings=settings,
         project=cli_args.project,
         entity=cli_args.entity,
     )
     config = run.config
+    args = config.get("sweep_args", {})
 
-    if not config.get("sweep_args", {}).get("sweep_id"):
-        # not a sweep, just finish the run and return
+    if not args or not args.get("sweep_id"):
+        # when the config has no sweep args, this is being run directly from CLI
+        # and not in a sweep. Just log the code and return
+        if not os.getenv("WANDB_DOCKER"):
+            # if not docker, log the code to a git or code artifact
+            run.log_code(root=os.path.dirname(__file__))
         run.finish()
         return
 
-    args = config.get("sweep_args", {})
     if cli_args.num_workers:  # override
         kwargs.update({"num_workers": cli_args.num_workers})
@@ -161,6 +161,9 @@ def formatted_trials(self) -> str:
 
         trial_strs = []
         for trial in self.study.trials:
+            if not trial.values:
+                continue
+
             run_id = trial.user_attrs["run_id"]
             best: str = ""
             if not self.is_multi_objective:
@@ -175,14 +178,12 @@ def formatted_trials(self) -> str:
                     f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}"
                 ]
             else:  # multi-objective optimization, only 1 metric logged in study
-                if not trial.values:
-                    continue
-
                 if len(trial.values) != len(self._metric_defs):
                     wandb.termwarn(
-                        f"{LOG_PREFIX}Number of trial metrics ({trial.values})"
+                        f"{LOG_PREFIX}Number of logged metrics ({trial.values})"
                         " does not match number of metrics defined "
-                        f"({self._metric_defs})"
+                        f"({self._metric_defs}). Specify metrics for optimization"
+                        " in the scheduler.settings.metrics portion of the sweep config"
                     )
                     continue
 
@@ -191,10 +192,11 @@ def formatted_trials(self) -> str:
                     best += f"{metric.name} ({direction}):"
                     best += f"{round(val, 5)}, "
 
+                # trim trailing comma and space
                 best = best[:-2]
                 trial_strs += [
                     f"\t[trial-{trial.number + 1}] run: {run_id}, state: "
-                    f"{trial.state.name}, best: {best or 'None'}"
+                    f"{trial.state.name}, best: {best}"
                 ]
 
         return "\n".join(trial_strs[-10:])  # only print out last 10
@@ -233,7 +235,10 @@ def _get_metric_names_and_directions(self) -> List[Metric]:
             metric_defs += [Metric(name=metric["name"], direction=direction)]
 
         if len(metric_defs) == 0:
-            raise SchedulerError("Optuna sweep missing metric")
+            raise SchedulerError(
+                "Zero metrics found in the top level 'metric' section "
+                "and multi-objective metric section scheduler.settings.metrics"
+            )
 
         return metric_defs
 
@@ -276,7 +281,7 @@ def _load_optuna_classes(
         mod, err = _get_module("optuna", filepath)
         if not mod:
             raise SchedulerError(
-                f"Failed to load optuna from path {filepath} " f" with error: {err}"
+                f"Failed to load optuna from path {filepath} with error: {err}"
             )
 
     # Set custom optuna trial creation method
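The two reworded errors in this last patch point at the same config surface: single-objective sweeps read the top-level 'metric' section, while multi-objective sweeps list their metrics under scheduler.settings.metrics. A hedged sketch of a sweep config consistent with those strings — the exact schema keys are inferred from the error messages, not confirmed by the diff:

    # Illustrative multi-objective sweep config (Python dict form; keys are
    # assumptions). A single-objective sweep would instead set a top-level
    # "metric" section, e.g. {"name": "val_loss", "goal": "minimize"}.
    sweep_config = {
        "scheduler": {
            "settings": {
                # each goal maps to an optuna study direction
                "metrics": [
                    {"name": "val_loss", "goal": "minimize"},
                    {"name": "val_accuracy", "goal": "maximize"},
                ],
            },
        },
        "parameters": {"learning_rate": {"min": 0.0001, "max": 0.1}},
    }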