From 535af4edff70cbf20a49c676377f5c8945560d03 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 30 Oct 2024 10:33:29 -0700 Subject: [PATCH] Add `optimize_acqf_mixed_alternating` to `mock_botorch_optimize_context_manager` & reduce duplication with `mock_optimize_context_manager` (#2973) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2973 A previous diff added mixed optimizer to MBM. This diff adds it to optimizer mocks. `mock_botorch_optimize_context_manager` had a good bit of overlap with BoTorch's `mock_optimize_context_manager`, which is also cleaned up in this diff. It now uses `mock_optimize_context_manager` and adds additional mocks on top of that. Reviewed By: paschai Differential Revision: D65067691 fbshipit-source-id: 47185e63e6e462c843d55f29d031be35583d8b05 --- ax/utils/testing/core_stubs.py | 2 +- ax/utils/testing/mock.py | 111 +++++++++++++--------------- ax/utils/testing/tests/test_mock.py | 79 ++++++++++++++++++++ 3 files changed, 132 insertions(+), 60 deletions(-) create mode 100644 ax/utils/testing/tests/test_mock.py diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 79cf5c13996..694410c6678 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -273,7 +273,7 @@ def get_branin_experiment( status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}) if with_status_quo else None, ) - if with_batch: + if with_batch or with_completed_batch: for _ in range(num_batch_trial): sobol_generator = get_sobol(search_space=exp.search_space) sobol_run = sobol_generator.gen(n=15) diff --git a/ax/utils/testing/mock.py b/ax/utils/testing/mock.py index 24288b7f197..e10cda54297 100644 --- a/ax/utils/testing/mock.py +++ b/ax/utils/testing/mock.py @@ -12,12 +12,8 @@ from unittest import mock from botorch.fit import fit_fully_bayesian_model_nuts -from botorch.generation.gen import minimize_with_timeout -from botorch.optim.initializers import ( - gen_batch_initial_conditions, - gen_one_shot_kg_initial_conditions, -) -from scipy.optimize import OptimizeResult +from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating +from botorch.test_utils.mock import mock_optimize_context_manager from torch import Tensor @@ -29,80 +25,77 @@ def mock_botorch_optimize_context_manager( Currently, the primary tactic is to force the underlying scipy methods to stop after just one iteration. + This context manager uses BoTorch's `mock_optimize_context_manager`, and + adds some additional mocks that are not possible to cover in BoTorch due to + the need to mock the functions where they are used. + Args: force: If True will not raise an AssertionError if no mocks are called. USE RESPONSIBLY. """ - def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult: - if kwargs["options"] is None: - kwargs["options"] = {} - - kwargs["options"]["maxiter"] = 1 - return minimize_with_timeout(*args, **kwargs) - - def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor: - kwargs["num_restarts"] = 2 - kwargs["raw_samples"] = 4 - - return gen_batch_initial_conditions(*args, **kwargs) - - def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None: - kwargs["num_restarts"] = 2 - kwargs["raw_samples"] = 4 - - return gen_one_shot_kg_initial_conditions(*args, **kwargs) - def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None: fit_fully_bayesian_model_nuts(*args, **_get_minimal_mcmc_kwargs(**kwargs)) - with ExitStack() as es: - mock_generation = es.enter_context( - mock.patch( - "botorch.generation.gen.minimize_with_timeout", - wraps=one_iteration_minimize, - ) + def minimal_mixed_optimizer(*args: Any, **kwargs: Any) -> tuple[Tensor, Tensor]: + # BoTorch's `mock_optimize_context_manager` also has some mocks for this, + # but the full set of mocks applied here cannot be covered by that. + kwargs["raw_samples"] = 2 + kwargs["num_restarts"] = 1 + kwargs["options"].update( + { + "maxiter_alternating": 1, + "maxiter_continuous": 1, + "maxiter_init": 1, + "maxiter_discrete": 1, + } ) + return optimize_acqf_mixed_alternating(*args, **kwargs) - mock_fit = es.enter_context( - mock.patch( - "botorch.optim.core.minimize_with_timeout", - wraps=one_iteration_minimize, - ) - ) - - mock_gen_ics = es.enter_context( + with ExitStack() as es: + mock_mcmc_mbm = es.enter_context( mock.patch( - "botorch.optim.optimize.gen_batch_initial_conditions", - wraps=minimal_gen_ics, + "ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts", + wraps=minimal_fit_fully_bayesian, ) ) - mock_gen_os_ics = es.enter_context( + mock_mixed_optimizer = es.enter_context( mock.patch( - "botorch.optim.optimize.gen_one_shot_kg_initial_conditions", - wraps=minimal_gen_os_ics, + "ax.models.torch.botorch_modular.acquisition." + "optimize_acqf_mixed_alternating", + wraps=minimal_mixed_optimizer, ) ) - mock_mcmc_mbm = es.enter_context( - mock.patch( - "ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts", - wraps=minimal_fit_fully_bayesian, - ) - ) + es.enter_context(mock_optimize_context_manager()) yield - if (not force) and all( - mock_.call_count < 1 - for mock_ in [ - mock_generation, - mock_fit, - mock_gen_ics, - mock_gen_os_ics, - mock_mcmc_mbm, - ] + # Only raise if none of the BoTorch or Ax side mocks were called. + # We do this by catching the error that could be raised by the BoTorch + # context manager, and combining it with the signals from Ax side mocks. + try: + es.close() + except AssertionError as e: + # Check if the error is due to no BoTorch mocks being called. + if "No mocks were called" in str(e): + botorch_mocks_called = False + else: + raise + else: + botorch_mocks_called = True + + if ( + not force + and all( + mock_.call_count < 1 + for mock_ in [ + mock_mcmc_mbm, + mock_mixed_optimizer, + ] + ) + and botorch_mocks_called is False ): raise AssertionError( "No mocks were called in the context manager. Please remove unused " diff --git a/ax/utils/testing/tests/test_mock.py b/ax/utils/testing/tests/test_mock.py new file mode 100644 index 00000000000..2c84e03375e --- /dev/null +++ b/ax/utils/testing/tests/test_mock.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from unittest.mock import patch + +import torch +from ax.modelbridge.registry import Models +from ax.modelbridge.transforms.choice_encode import OrderedChoiceToIntegerRange +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import mock_botorch_optimize_context_manager +from botorch.generation.gen import gen_candidates_scipy +from botorch.optim.optimize_mixed import generate_starting_points +from botorch.utils.testing import MockAcquisitionFunction +from pyro.infer import MCMC + + +class TestMock(TestCase): + def test_no_mocks_called(self) -> None: + # Should raise by default if no mocks are called. + with self.assertRaisesRegex(AssertionError, "No mocks were called"): + with mock_botorch_optimize_context_manager(): + pass + # Doesn't raise when force=True. + with mock_botorch_optimize_context_manager(force=True): + pass + + def test_botorch_mocks(self) -> None: + # Should not raise when BoTorch mocks are called. + with mock_botorch_optimize_context_manager(): + gen_candidates_scipy( + initial_conditions=torch.tensor([[0.0]]), + acquisition_function=MockAcquisitionFunction(), # pyre-ignore [6] + ) + + def test_fully_bayesian_mocks(self) -> None: + experiment = get_branin_experiment(with_completed_batch=True) + with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc: + with mock_botorch_optimize_context_manager(): + Models.SAASBO(experiment=experiment, data=experiment.lookup_data()) + mock_mcmc.assert_called_once() + kwargs = mock_mcmc.call_args.kwargs + self.assertEqual(kwargs["num_samples"], 16) + self.assertEqual(kwargs["warmup_steps"], 0) + + def test_mixed_optimizer_mocks(self) -> None: + experiment = get_branin_experiment( + with_completed_batch=True, with_choice_parameter=True + ) + with patch( + "botorch.optim.optimize_mixed.generate_starting_points", + wraps=generate_starting_points, + ) as mock_gen: + with mock_botorch_optimize_context_manager(): + Models.BOTORCH_MODULAR( + experiment=experiment, + data=experiment.lookup_data(), + transforms=[OrderedChoiceToIntegerRange], + ).gen(n=1) + mock_gen.assert_called_once() + opt_inputs = mock_gen.call_args.kwargs["opt_inputs"] + self.assertEqual(opt_inputs.raw_samples, 2) + self.assertEqual(opt_inputs.num_restarts, 1) + self.assertEqual( + opt_inputs.options, + { + "init_batch_limit": 32, + "batch_limit": 5, + "maxiter_alternating": 1, + "maxiter_continuous": 1, + "maxiter_init": 1, + "maxiter_discrete": 1, + }, + )