Remove custom optimizer_argparse case for qKG
Summary:
Context: It appears that using `qKnowledgeGradient` with MBM doesn't work: [this line](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/acquisition.py#L339) passes the argument `optimizer` to `_argparse_kg`, which then errors [here](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/optimizer_argparse.py#L169), because `_argparse_kg` forwards its `**kwargs` (which already contain `optimizer`) to `_argparse_base` alongside an explicit `optimizer="optimize_acqf"`, so the argument "optimizer" is received twice.
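
A minimal sketch of that failure mode, with stand-in functions rather than Ax's real signatures:

```python
from typing import Any


def _base(acqf: Any, optimizer: str = "optimize_acqf", **ignore: Any) -> dict[str, Any]:
    # Stand-in for `_argparse_base`: returns the options it would construct.
    return {"optimizer": optimizer}


def _kg(acqf: Any, **kwargs: Any) -> dict[str, Any]:
    # Stand-in for the old `_argparse_kg`. If the caller passed `optimizer`,
    # it is sitting in `kwargs`, so this call supplies the keyword twice:
    # TypeError: _base() got multiple values for keyword argument 'optimizer'
    return _base(acqf, optimizer="optimize_acqf", **kwargs)


_kg(None, optimizer="optimize_acqf_homotopy")  # raises TypeError
```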

We don't really need the `optimizer_argparse` special case for qKG anymore. This existed for two reasons:
- to construct initial conditions, which `optimize_acqf` can handle on its own (see the sketch after this list), and
- to ensure that the optimizer is `optimize_acqf`, because no other optimizer is supported for qKG
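
On the first point, a rough sketch (the toy `SingleTaskGP` data here is assumed for illustration, not taken from this diff): BoTorch's `optimize_acqf` detects one-shot KG and builds its starting points internally, so Ax no longer has to precompute them.

```python
import torch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

# Toy data on the 2d unit cube; the model is a stand-in.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
acqf = qKnowledgeGradient(model=model, num_fantasies=16)

# No hand-built initial conditions: optimize_acqf recognizes the one-shot
# qKnowledgeGradient and generates starting points with
# gen_one_shot_kg_initial_conditions internally.
candidate, value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=4,
    raw_samples=32,
)
```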

This diff:
* Modifies the `optimizer_argparse` case for qKG to do nothing except force the optimizer to `optimize_acqf` and then call the base case (sketched below)
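
Behaviorally, that amounts to something like this condensed sketch (`_argparse_base` is the existing base-case constructor; the committed version absorbs the stray keyword into a `**ignore` catch-all rather than popping it):

```python
from typing import Any


def _argparse_kg(acqf: Any, **kwargs: Any) -> dict[str, Any]:
    # Drop any caller-supplied `optimizer` rather than forwarding it, then
    # delegate to the base case with `optimize_acqf` pinned.
    kwargs.pop("optimizer", None)
    return _argparse_base(acqf, optimizer="optimize_acqf", **kwargs)
```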

Implementation notes:
* Isn't it unintuitive to set the optimizer and then override it? Yes, a little, but the user can't choose the optimizer, so we're not overriding a user-specified choice. Also, many arguments to the `optimizer_argparse` functions get ignored. The "right" thing might be to put the choice of optimizer inside a dispatcher so that it can depend on the acquisition class, but that would be a bigger change.
* Do we really need this dispatcher anymore, if it is doing so little? Yes: third parties may wish to use it to accommodate acquisition functions that are not in Ax (see the registration sketch after this list).
* Do the `optimizer_argparse` functions still need to support so many arguments, given that some of them seemed to exist only for constructing initial conditions? Probably not; I mean to look into that in a follow-up.
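
For instance, a third-party package might register its own case roughly like this (`MyAcquisition` and its tweaked default are hypothetical; the registration pattern mirrors the `@optimizer_argparse.register(qKnowledgeGradient)` usage in the diff below):

```python
from typing import Any

from ax.models.torch.botorch_modular.optimizer_argparse import (
    _argparse_base,
    optimizer_argparse,
)
from botorch.acquisition.acquisition import AcquisitionFunction


class MyAcquisition(AcquisitionFunction):
    """Hypothetical acquisition function living outside Ax."""

    def forward(self, X):
        # Minimal stub to satisfy the abstract interface.
        return X.sum(dim=-1)


@optimizer_argparse.register(MyAcquisition)
def _argparse_my_acquisition(acqf: MyAcquisition, **kwargs: Any) -> dict[str, Any]:
    # Start from the generic options, then apply acquisition-specific tweaks.
    options = _argparse_base(acqf, optimizer="optimize_acqf", **kwargs)
    options["num_restarts"] = 4  # hypothetical acquisition-specific default
    return options
```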

Differential Revision: D65227420
esantorella authored and facebook-github-bot committed Oct 30, 2024
1 parent 535af4e commit 5192a2d
Showing 3 changed files with 49 additions and 51 deletions.
ax/benchmark/tests/test_benchmark.py (12 changes: 12 additions & 0 deletions)
@@ -42,6 +42,7 @@
 )
 from ax.utils.testing.core_stubs import get_experiment
 from ax.utils.testing.mock import mock_botorch_optimize
+from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.logei import qLogNoisyExpectedImprovement
 from botorch.acquisition.multi_objective.logei import (
     qLogNoisyExpectedHypervolumeImprovement,
@@ -323,6 +324,17 @@ def test_replication_mbm(self) -> None:
                 mnist_problem,
                 "MBM::SingleTaskGP_qLogNEI",
             ),
+            (
+                get_sobol_botorch_modular_acquisition(
+                    model_cls=SingleTaskGP,
+                    acquisition_cls=qKnowledgeGradient,
+                    distribute_replications=False,
+                ),
+                get_single_objective_benchmark_problem(
+                    observe_noise_sd=False, num_trials=6
+                ),
+                "MBM::SingleTaskGP_qKnowledgeGradient",
+            ),
         ]:
             with self.subTest(method=method, problem=problem):
                 res = benchmark_replication(problem=problem, method=method, seed=0)
ax/models/torch/botorch_modular/optimizer_argparse.py (45 changes: 14 additions & 31 deletions)
@@ -12,11 +12,9 @@

 import torch
 from ax.exceptions.core import UnsupportedError
-from ax.utils.common.constants import Keys
 from ax.utils.common.typeutils import _argparse_type_encoder
 from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
-from botorch.optim.initializers import gen_one_shot_kg_initial_conditions
 from botorch.utils.dispatcher import Dispatcher
 
 T = TypeVar("T")
@@ -146,44 +144,29 @@ def _argparse_base(
 @optimizer_argparse.register(qKnowledgeGradient)
 def _argparse_kg(
     acqf: qKnowledgeGradient,
-    q: int,
-    bounds: torch.Tensor,
+    sequential: bool = True,
     num_restarts: int = NUM_RESTARTS,
     raw_samples: int = RAW_SAMPLES,
-    frac_random: float = 0.1,
+    init_batch_limit: int = INIT_BATCH_LIMIT,
+    batch_limit: int = BATCH_LIMIT,
     optimizer_options: dict[str, Any] | None = None,
-    **kwargs: Any,
+    **ignore: Any,
 ) -> dict[str, Any]:
     """
     Argument constructor for optimization with qKG, differing from the
-    base case in that it computes and returns initial conditions.
+    base case in that it enforces that the optimizer
+    is always `optimize_acqf`.
 
-    To do so, it requires specifying additional arguments `q` and `bounds` and
-    allows for specifying `frac_random`.
+    Arguments include those passed to `_argparse_base`, except for `optimizer`,
+    which will be ignored if passed.
     """
-    base_options = _argparse_base(
-        acqf,
-        num_restarts=num_restarts,
-        raw_samples=raw_samples,
-        optimizer_options=optimizer_options,
+    return _argparse_base(
+        acqf=acqf,
         optimizer="optimize_acqf",
-        **kwargs,
-    )
-
-    initial_conditions = gen_one_shot_kg_initial_conditions(
-        acq_function=acqf,
-        bounds=bounds,
-        q=q,
+        sequential=sequential,
         num_restarts=num_restarts,
         raw_samples=raw_samples,
-        options={
-            Keys.FRAC_RANDOM: frac_random,
-            Keys.NUM_INNER_RESTARTS: num_restarts,
-            Keys.RAW_INNER_SAMPLES: raw_samples,
-        },
+        init_batch_limit=init_batch_limit,
+        batch_limit=batch_limit,
+        optimizer_options=optimizer_options,
     )
-
-    return {
-        **base_options,
-        Keys.BATCH_INIT_CONDITIONS: initial_conditions,
-    }
ax/models/torch/tests/test_optimizer_argparse.py (43 changes: 23 additions & 20 deletions)
@@ -182,23 +182,26 @@ def test_optimizer_options(self) -> None:
             func(None, sequential=False, optimizer=optimizer)
 
     def test_kg(self) -> None:
-        with patch(
-            "botorch.optim.initializers.gen_one_shot_kg_initial_conditions"
-        ) as mock_gen_initial_conditions:
-            mock_gen_initial_conditions.return_value = "TEST"
-            reload(Argparse)
-
-            user_options = {"foo": "bar", "num_restarts": 114}
-            generic_options = _argparse_base(
-                None, optimizer_options=user_options, optimizer="optimize_acqf"
-            )
-            for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient):
-                with self.subTest(acqf=acqf):
-                    options = optimizer_argparse(
-                        acqf,
-                        q=None,
-                        bounds=None,
-                        optimizer_options=user_options,
-                    )
-                    self.assertEqual(options.pop(Keys.BATCH_INIT_CONDITIONS), "TEST")
-                    self.assertEqual(options, generic_options)
+        user_options = {"foo": "bar", "num_restarts": 114}
+        generic_options = _argparse_base(
+            None, optimizer_options=user_options, optimizer="optimize_acqf"
+        )
+        for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient):
+            with self.subTest(acqf=acqf):
+                options = optimizer_argparse(
+                    acqf,
+                    q=None,
+                    bounds=None,
+                    optimizer_options=user_options,
+                )
+                self.assertEqual(options, generic_options)
+
+                # check that optimizer other than `optimize_acqf` is overridden
+                options = optimizer_argparse(
+                    acqf,
+                    q=None,
+                    bounds=None,
+                    optimizer_options=user_options,
+                    optimizer="happy birthday",
+                )
+                self.assertEqual(options, generic_options)
