Remove custom optimizer_argparse case for qKG (#2997)
Summary:
Pull Request resolved: #2997

Context: It appears that using `qKnowledgeGradient` with MBM doesn't work, since [this line](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/acquisition.py#L339) passes the argument `optimizer` to `_argparse_kg`, which then errors [here](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/optimizer_argparse.py#L169) because the argument `optimizer` ends up being passed twice (see the sketch below).
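For illustration, a minimal, hypothetical sketch of the failure pattern (the `_like` helper names are made up; this is not the actual Ax code): when the caller forwards options that already contain `optimizer` while the helper also passes `optimizer` explicitly, Python raises a `TypeError`.

```python
from typing import Any


def _argparse_base_like(
    acqf: Any, *, optimizer: str = "optimize_acqf", **kwargs: Any
) -> dict[str, Any]:
    # Stand-in for the base argument constructor.
    return {"optimizer": optimizer, **kwargs}


def _argparse_kg_like(acqf: Any, **kwargs: Any) -> dict[str, Any]:
    # The qKG special case forces optimizer="optimize_acqf" while also
    # forwarding **kwargs -- which now already contains "optimizer".
    return _argparse_base_like(acqf, optimizer="optimize_acqf", **kwargs)


try:
    _argparse_kg_like(None, optimizer="optimize_acqf", num_restarts=20)
except TypeError as e:
    print(e)  # ... got multiple values for keyword argument 'optimizer'
```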

We don't really need the `optimizer_argparse` special case for qKG anymore. It existed for two reasons:
- to construct initial conditions, which `optimize_acqf` can handle on its own (see the sketch after this list), and
- to ensure that the optimizer is `optimize_acqf`, since other optimizers are not supported for qKG.
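On the first point, a rough standalone sketch in plain BoTorch (toy data and illustrative settings, not taken from this diff) showing that `optimize_acqf` generates the one-shot qKG initial conditions itself when none are supplied:

```python
import torch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data on the unit square.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qKnowledgeGradient(model=model, num_fantasies=8)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

# No batch_initial_conditions are passed; optimize_acqf constructs the
# one-shot initial conditions for qKG internally.
candidate, value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=4,
    raw_samples=32,
)
```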

This diff:
* Modifies the `optimizer_argparse` case for qKG so that it does nothing except ensure the optimizer is `optimize_acqf` (erroring otherwise) and then call the base case (a usage sketch follows below)
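For context, a hedged sketch of how qKG might now be configured through MBM. The generation-strategy wiring below follows Ax's usual conventions but is not part of this diff, so treat the exact entry points as assumptions:

```python
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.models import SingleTaskGP

# Sobol for initialization, then MBM with a SingleTaskGP surrogate and qKG;
# acquisition optimization is dispatched to optimize_acqf by the argparse case.
gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=5),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "surrogate": Surrogate(botorch_model_class=SingleTaskGP),
                "botorch_acqf_class": qKnowledgeGradient,
            },
        ),
    ]
)
```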

Implementation notes:
* Isn't it nonintuitive to set the optimizer then override it? Yes, a little, but the user can't choose the optimizer, so we're not overriding a user-specified choice. Also, lots of arguments to the `optimizer_argparse` functions get ignored. The "right" thing might be to put the choice of optimizer inside a dispatcher so that it can depend on the acquisition class, but that would be a bigger change.
* Do we really need this dispatcher anymore, if it is doing so little? Maybe. Third parties may wish to use it to accommodate acquisition functions that are not in Ax (see the sketch after this list). On the other hand, this dispatcher currently isn't doing much of anything.
* Do the `optimizer_argparse` functions still need to support so many arguments, given that some of them seem to have been there only for constructing initial conditions? Probably not; I plan to look into that in a follow-up.
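To make the third-party point concrete, here is a hypothetical registration for an acquisition class that is not in Ax (`MyAcquisition` and `_argparse_my_acquisition` are made up); it mirrors the pattern the updated `_argparse_kg` uses:

```python
from typing import Any

from ax.models.torch.botorch_modular.optimizer_argparse import (
    _argparse_base,
    optimizer_argparse,
)
from botorch.acquisition.monte_carlo import qSimpleRegret


class MyAcquisition(qSimpleRegret):
    """Hypothetical third-party acquisition function."""


@optimizer_argparse.register(MyAcquisition)
def _argparse_my_acquisition(
    acqf: MyAcquisition,
    *,
    optimizer: str = "optimize_acqf",
    **kwargs: Any,
) -> dict[str, Any]:
    # Pin the optimizer for this class, then defer to the base case,
    # just as the updated qKG case does.
    return _argparse_base(acqf, optimizer="optimize_acqf", **kwargs)
```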

Reviewed By: saitcakmak

Differential Revision: D65227420

fbshipit-source-id: a211996fa3e41575d11a995625ef9ebd0d1002ab
esantorella authored and facebook-github-bot committed Oct 31, 2024
1 parent 46fa5a5 commit 56c91ea
Showing 3 changed files with 54 additions and 56 deletions.
12 changes: 12 additions & 0 deletions ax/benchmark/tests/test_benchmark.py
@@ -42,6 +42,7 @@
)
from ax.utils.testing.core_stubs import get_experiment
from ax.utils.testing.mock import mock_botorch_optimize
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
@@ -323,6 +324,17 @@ def test_replication_mbm(self) -> None:
mnist_problem,
"MBM::SingleTaskGP_qLogNEI",
),
(
get_sobol_botorch_modular_acquisition(
model_cls=SingleTaskGP,
acquisition_cls=qKnowledgeGradient,
distribute_replications=False,
),
get_single_objective_benchmark_problem(
observe_noise_sd=False, num_trials=6
),
"MBM::SingleTaskGP_qKnowledgeGradient",
),
]:
with self.subTest(method=method, problem=problem):
res = benchmark_replication(problem=problem, method=method, seed=0)
55 changes: 22 additions & 33 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -10,13 +10,10 @@

from typing import Any, TypeVar, Union

import torch
from ax.exceptions.core import UnsupportedError
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.optim.initializers import gen_one_shot_kg_initial_conditions
from botorch.utils.dispatcher import Dispatcher

T = TypeVar("T")
@@ -146,44 +143,36 @@ def _argparse_base(
@optimizer_argparse.register(qKnowledgeGradient)
def _argparse_kg(
acqf: qKnowledgeGradient,
q: int,
bounds: torch.Tensor,
*,
optimizer: str = "optimize_acqf",
sequential: bool = True,
num_restarts: int = NUM_RESTARTS,
raw_samples: int = RAW_SAMPLES,
frac_random: float = 0.1,
init_batch_limit: int = INIT_BATCH_LIMIT,
batch_limit: int = BATCH_LIMIT,
optimizer_options: dict[str, Any] | None = None,
**kwargs: Any,
**ignore: Any,
) -> dict[str, Any]:
"""
Argument constructor for optimization with qKG, differing from the
base case in that it computes and returns initial conditions.
To do so, it requires specifying additional arguments `q` and `bounds` and
allows for specifying `frac_random`.
base case in that it errors if the optimizer is not `optimize_acqf`.
"""
base_options = _argparse_base(
acqf,
num_restarts=num_restarts,
raw_samples=raw_samples,
optimizer_options=optimizer_options,
if optimizer != "optimize_acqf":
raise RuntimeError(
"Ax is attempting to use a discrete or mixed optimizer, "
f"`{optimizer}`, but this is not compatible with "
"`qKnowledgeGradient` or its subclasses. To address this, please "
"either use a different acquisition class or make parameters "
"continuous using the transform "
"`ax.modelbridge.registry.Cont_X_trans`."
)
return _argparse_base(
acqf=acqf,
optimizer="optimize_acqf",
**kwargs,
)

initial_conditions = gen_one_shot_kg_initial_conditions(
acq_function=acqf,
bounds=bounds,
q=q,
sequential=sequential,
num_restarts=num_restarts,
raw_samples=raw_samples,
options={
Keys.FRAC_RANDOM: frac_random,
Keys.NUM_INNER_RESTARTS: num_restarts,
Keys.RAW_INNER_SAMPLES: raw_samples,
},
init_batch_limit=init_batch_limit,
batch_limit=batch_limit,
optimizer_options=optimizer_options,
)

return {
**base_options,
Keys.BATCH_INIT_CONDITIONS: initial_conditions,
}
43 changes: 20 additions & 23 deletions ax/models/torch/tests/test_optimizer_argparse.py
@@ -8,12 +8,10 @@

from __future__ import annotations

from importlib import reload
from itertools import product
from unittest.mock import patch

from ax.exceptions.core import UnsupportedError
from ax.models.torch.botorch_modular import optimizer_argparse as Argparse
from ax.models.torch.botorch_modular.optimizer_argparse import (
_argparse_base,
BATCH_LIMIT,
@@ -23,7 +21,6 @@
optimizer_argparse,
RAW_SAMPLES,
)
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import LogExpectedImprovement
@@ -182,23 +179,23 @@ def test_optimizer_options(self) -> None:
func(None, sequential=False, optimizer=optimizer)

def test_kg(self) -> None:
with patch(
"botorch.optim.initializers.gen_one_shot_kg_initial_conditions"
) as mock_gen_initial_conditions:
mock_gen_initial_conditions.return_value = "TEST"
reload(Argparse)

user_options = {"foo": "bar", "num_restarts": 114}
generic_options = _argparse_base(
None, optimizer_options=user_options, optimizer="optimize_acqf"
)
for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient):
with self.subTest(acqf=acqf):
options = optimizer_argparse(
acqf,
q=None,
bounds=None,
optimizer_options=user_options,
)
self.assertEqual(options.pop(Keys.BATCH_INIT_CONDITIONS), "TEST")
self.assertEqual(options, generic_options)
user_options = {"foo": "bar", "num_restarts": 114}
generic_options = _argparse_base(
None, optimizer_options=user_options, optimizer="optimize_acqf"
)
for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient):
with self.subTest(acqf=acqf):
options = optimizer_argparse(
acqf,
q=None,
bounds=None,
optimizer_options=user_options,
)
self.assertEqual(options, generic_options)

with self.assertRaisesRegex(
RuntimeError,
"Ax is attempting to use a discrete or mixed optimizer, "
"`optimize_acqf_mixed`, ",
):
optimizer_argparse(acqf, optimizer="optimize_acqf_mixed")
