-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(sweeps): optuna supports multi-objective optimization #28
Changes from 1 commit
5d616ce
dca31e8
a4f16eb
e7449ca
a8ff3e2
b90c26d
66d9c7d
afa5aec
5b021e0
6831a80
a601def
106e108
9999126
193a187
e92d10e
2487be2
312de6e
de032a5
d547ef1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,26 +62,26 @@ def setup_scheduler(scheduler: Scheduler, **kwargs): | |
parser.add_argument("--entity", type=str, default=kwargs.get("entity")) | ||
parser.add_argument("--num_workers", type=int, default=1) | ||
parser.add_argument("--name", type=str, default=f"job-{scheduler.__name__}") | ||
parser.add_argument("--enable_git", action="store_true", default=False) | ||
cli_args = parser.parse_args() | ||
|
||
settings = {"job_name": cli_args.name} | ||
if cli_args.enable_git: | ||
settings.update({"disable_git": True}) | ||
|
||
run = wandb.init( | ||
settings=settings, | ||
project=cli_args.project, | ||
entity=cli_args.entity, | ||
) | ||
config = run.config | ||
args = config.get("sweep_args", {}) | ||
|
||
if not config.get("sweep_args", {}).get("sweep_id"): | ||
# not a sweep, just finish the run and return | ||
if not args or not args.get("sweep_id"): | ||
# when the config has no sweep args, this is being run directly from CLI | ||
# and not in a sweep. Just log the code and return | ||
if not os.getenv("WANDB_DOCKER"): | ||
# if not docker, log the code to a git or code artifact | ||
run.log_code(root=os.path.dirname(__file__)) | ||
run.finish() | ||
return | ||
|
||
args = config.get("sweep_args", {}) | ||
if cli_args.num_workers: # override | ||
kwargs.update({"num_workers": cli_args.num_workers}) | ||
|
||
|
@@ -161,6 +161,9 @@ def formatted_trials(self) -> str: | |
|
||
trial_strs = [] | ||
for trial in self.study.trials: | ||
if not trial.values: | ||
continue | ||
|
||
run_id = trial.user_attrs["run_id"] | ||
best: str = "" | ||
if not self.is_multi_objective: | ||
|
@@ -175,14 +178,12 @@ def formatted_trials(self) -> str: | |
f"{trial.state.name}, num-metrics: {len(vals)}, best: {best}" | ||
] | ||
else: # multi-objective optimization, only 1 metric logged in study | ||
if not trial.values: | ||
continue | ||
|
||
if len(trial.values) != len(self._metric_defs): | ||
wandb.termwarn( | ||
gtarpenning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f"{LOG_PREFIX}Number of trial metrics ({trial.values})" | ||
f"{LOG_PREFIX}Number of logged metrics ({trial.values})" | ||
" does not match number of metrics defined " | ||
f"({self._metric_defs})" | ||
f"({self._metric_defs}). Specify metrics for optimization" | ||
" in the scheduler.settings.metrics portion of the sweep config" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would a user have done to create this case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Im not sure its possible, i included this error message in case. If we had a testing suite.... I'll play around with it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sg |
||
) | ||
continue | ||
|
||
|
@@ -191,10 +192,11 @@ def formatted_trials(self) -> str: | |
best += f"{metric.name} ({direction}):" | ||
best += f"{round(val, 5)}, " | ||
|
||
# trim trailing comma and space | ||
best = best[:-2] | ||
gtarpenning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
trial_strs += [ | ||
f"\t[trial-{trial.number + 1}] run: {run_id}, state: " | ||
f"{trial.state.name}, best: {best or 'None'}" | ||
f"{trial.state.name}, best: {best}" | ||
] | ||
|
||
return "\n".join(trial_strs[-10:]) # only print out last 10 | ||
|
@@ -233,7 +235,10 @@ def _get_metric_names_and_directions(self) -> List[Metric]: | |
metric_defs += [Metric(name=metric["name"], direction=direction)] | ||
|
||
if len(metric_defs) == 0: | ||
raise SchedulerError("Optuna sweep missing metric") | ||
raise SchedulerError( | ||
"Zero metrics found in the top level 'metric' section " | ||
"and multi-objective metric section scheduler.settings.metrics" | ||
) | ||
|
||
return metric_defs | ||
|
||
|
@@ -276,7 +281,7 @@ def _load_optuna_classes( | |
mod, err = _get_module("optuna", filepath) | ||
if not mod: | ||
raise SchedulerError( | ||
f"Failed to load optuna from path {filepath} " f" with error: {err}" | ||
f"Failed to load optuna from path {filepath} with error: {err}" | ||
) | ||
|
||
# Set custom optuna trial creation method | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the user is running the
wandb launch -j <job>
here? Is this specifically done to create the job?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now this would be triggered by running
python optuna_scheduler.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And a user would do that to create the job. Sg