From 5192a2d5500c879bcf444e6e7da906fc7968e03e Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Wed, 30 Oct 2024 12:34:01 -0700 Subject: [PATCH] Remove custom `optimizer_argparse` case for qKG Summary: Context: It appears that using `qKnowledgeGradient` with MBM doesn't work, since [this line](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/acquisition.py#L339) passes the argument `optimizer` to `_argparse_kg`, which errors [here](https://github.com/facebook/Ax/blob/535af4edff70cbf20a49c676377f5c8945560d03/ax/models/torch/botorch_modular/optimizer_argparse.py#L169) because it has now received the argument "optimizer" twice. We don't really need the `optimizer_argparse` special case for qKG anymore. This existed for two reasons: - in order to construct initial conditions, which can be handled by `optimize_acqf`, and - to ensure that the optimizer is `optimize_acqf`, because others are not supported This diff: * Modifies the `optimize_argparse` case for qKG to do nothing except update the optimizer to `optimize_acqf` and then call the base case Implementation notes: * Isn't it nonintuitive to set the optimizer then override it? Yes, a little, but the user can't choose the optimizer, so we're not overriding a user-specified choice. Also, lots of arguments to the `optimizer_argparse` functions get ignored. The "right" thing might be to put the choice of optimizer inside a dispatcher so that it can depend on the acquisition class, but that would be a bigger change. * Do we really need this dispatcher anymore, if it is doing so little? Yes, third parties may wish to use it to accommodate acquisition functions that are not in Ax. * Do the `optimize_argparse` functions still need to support so many arguments, given that some of them seemed to just be there for constructing initial conditions? Probably not; I mean to look into that in a follow-up. Differential Revision: D65227420 --- ax/benchmark/tests/test_benchmark.py | 12 +++++ .../botorch_modular/optimizer_argparse.py | 45 ++++++------------- .../torch/tests/test_optimizer_argparse.py | 43 +++++++++--------- 3 files changed, 49 insertions(+), 51 deletions(-) diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 80347724bcb..224160782cc 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -42,6 +42,7 @@ ) from ax.utils.testing.core_stubs import get_experiment from ax.utils.testing.mock import mock_botorch_optimize +from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.logei import qLogNoisyExpectedImprovement from botorch.acquisition.multi_objective.logei import ( qLogNoisyExpectedHypervolumeImprovement, @@ -323,6 +324,17 @@ def test_replication_mbm(self) -> None: mnist_problem, "MBM::SingleTaskGP_qLogNEI", ), + ( + get_sobol_botorch_modular_acquisition( + model_cls=SingleTaskGP, + acquisition_cls=qKnowledgeGradient, + distribute_replications=False, + ), + get_single_objective_benchmark_problem( + observe_noise_sd=False, num_trials=6 + ), + "MBM::SingleTaskGP_qKnowledgeGradient", + ), ]: with self.subTest(method=method, problem=problem): res = benchmark_replication(problem=problem, method=method, seed=0) diff --git a/ax/models/torch/botorch_modular/optimizer_argparse.py b/ax/models/torch/botorch_modular/optimizer_argparse.py index 03813a32576..239cdd3b1fb 100644 --- a/ax/models/torch/botorch_modular/optimizer_argparse.py +++ b/ax/models/torch/botorch_modular/optimizer_argparse.py @@ -12,11 +12,9 @@ import torch from ax.exceptions.core import UnsupportedError -from ax.utils.common.constants import Keys from ax.utils.common.typeutils import _argparse_type_encoder from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.knowledge_gradient import qKnowledgeGradient -from botorch.optim.initializers import gen_one_shot_kg_initial_conditions from botorch.utils.dispatcher import Dispatcher T = TypeVar("T") @@ -146,44 +144,29 @@ def _argparse_base( @optimizer_argparse.register(qKnowledgeGradient) def _argparse_kg( acqf: qKnowledgeGradient, - q: int, - bounds: torch.Tensor, + sequential: bool = True, num_restarts: int = NUM_RESTARTS, raw_samples: int = RAW_SAMPLES, - frac_random: float = 0.1, + init_batch_limit: int = INIT_BATCH_LIMIT, + batch_limit: int = BATCH_LIMIT, optimizer_options: dict[str, Any] | None = None, - **kwargs: Any, + **ignore: Any, ) -> dict[str, Any]: """ Argument constructor for optimization with qKG, differing from the - base case in that it computes and returns initial conditions. + base case in that it enforces that the optimizer is always `optimize_acqf` + is always `optimize_acqf`. - To do so, it requires specifying additional arguments `q` and `bounds` and - allows for specifying `frac_random`. + Arguments include those passed to `_argparse_base`, except for `optimizer`, + which will be ignored if passed. """ - base_options = _argparse_base( - acqf, - num_restarts=num_restarts, - raw_samples=raw_samples, - optimizer_options=optimizer_options, + return _argparse_base( + acqf=acqf, optimizer="optimize_acqf", - **kwargs, - ) - - initial_conditions = gen_one_shot_kg_initial_conditions( - acq_function=acqf, - bounds=bounds, - q=q, + sequential=sequential, num_restarts=num_restarts, raw_samples=raw_samples, - options={ - Keys.FRAC_RANDOM: frac_random, - Keys.NUM_INNER_RESTARTS: num_restarts, - Keys.RAW_INNER_SAMPLES: raw_samples, - }, + init_batch_limit=init_batch_limit, + batch_limit=batch_limit, + optimizer_options=optimizer_options, ) - - return { - **base_options, - Keys.BATCH_INIT_CONDITIONS: initial_conditions, - } diff --git a/ax/models/torch/tests/test_optimizer_argparse.py b/ax/models/torch/tests/test_optimizer_argparse.py index 7c01cdad4cb..22c2296a726 100644 --- a/ax/models/torch/tests/test_optimizer_argparse.py +++ b/ax/models/torch/tests/test_optimizer_argparse.py @@ -182,23 +182,26 @@ def test_optimizer_options(self) -> None: func(None, sequential=False, optimizer=optimizer) def test_kg(self) -> None: - with patch( - "botorch.optim.initializers.gen_one_shot_kg_initial_conditions" - ) as mock_gen_initial_conditions: - mock_gen_initial_conditions.return_value = "TEST" - reload(Argparse) - - user_options = {"foo": "bar", "num_restarts": 114} - generic_options = _argparse_base( - None, optimizer_options=user_options, optimizer="optimize_acqf" - ) - for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient): - with self.subTest(acqf=acqf): - options = optimizer_argparse( - acqf, - q=None, - bounds=None, - optimizer_options=user_options, - ) - self.assertEqual(options.pop(Keys.BATCH_INIT_CONDITIONS), "TEST") - self.assertEqual(options, generic_options) + user_options = {"foo": "bar", "num_restarts": 114} + generic_options = _argparse_base( + None, optimizer_options=user_options, optimizer="optimize_acqf" + ) + for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient): + with self.subTest(acqf=acqf): + options = optimizer_argparse( + acqf, + q=None, + bounds=None, + optimizer_options=user_options, + ) + self.assertEqual(options, generic_options) + + # check that optimizer other than `optimize_acqf` is overridden + options = optimizer_argparse( + acqf, + q=None, + bounds=None, + optimizer_options=user_options, + optimizer="happy birthday", + ) + self.assertEqual(options, generic_options)