diff --git a/emukit/core/loop/outer_loop.py b/emukit/core/loop/outer_loop.py index 5402f6f2..40cf2e9c 100644 --- a/emukit/core/loop/outer_loop.py +++ b/emukit/core/loop/outer_loop.py @@ -101,6 +101,7 @@ def run_loop( _log.info("Iteration {}".format(self.loop_state.iteration)) self._update_models() + self._update_loop_state() new_x = self.candidate_point_calculator.compute_next_points(self.loop_state, context) _log.debug("Next suggested point(s): {}".format(new_x)) results = user_function.evaluate(new_x) @@ -109,8 +110,14 @@ def run_loop( self.iteration_end_event(self, self.loop_state) self._update_models() + self._update_loop_state() _log.info("Finished outer loop") + def _update_loop_state(self) -> None: + """This method is called after the models are updated. Override this function to store additional statistics + other than the collected points and function values in the loop state.""" + pass + def _update_models(self): for model_updater in self.model_updaters: model_updater.update(self.loop_state) diff --git a/emukit/quadrature/loop/__init__.py b/emukit/quadrature/loop/__init__.py index ef650520..f49dabee 100644 --- a/emukit/quadrature/loop/__init__.py +++ b/emukit/quadrature/loop/__init__.py @@ -12,12 +12,18 @@ from .bayesian_monte_carlo_loop import BayesianMonteCarlo # noqa: F401 +from .bq_loop_state import QuadratureLoopState +from .bq_outer_loop import QuadratureOuterLoop +from .bq_stopping_conditions import CoefficientOfVariationStoppingCondition from .vanilla_bq_loop import VanillaBayesianQuadratureLoop # noqa: F401 from .wsabil_loop import WSABILLoop # noqa: F401 __all__ = [ + "QuadratureOuterLoop", "BayesianMonteCarlo", "VanillaBayesianQuadratureLoop", "WSABILLoop", + "QuadratureLoopState", "point_calculators", + "CoefficientOfVariationStoppingCondition", ] diff --git a/emukit/quadrature/loop/bayesian_monte_carlo_loop.py b/emukit/quadrature/loop/bayesian_monte_carlo_loop.py index 289461a6..ca6d5067 100644 --- 
a/emukit/quadrature/loop/bayesian_monte_carlo_loop.py +++ b/emukit/quadrature/loop/bayesian_monte_carlo_loop.py @@ -8,14 +8,15 @@ # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop -from ...core.loop.loop_state import create_loop_state +from ...core.loop import FixedIntervalUpdater, ModelUpdater from ...core.parameter_space import ParameterSpace from ..loop.point_calculators import BayesianMonteCarloPointCalculator from ..methods import WarpedBayesianQuadratureModel +from .bq_loop_state import create_bq_loop_state +from .bq_outer_loop import QuadratureOuterLoop -class BayesianMonteCarlo(OuterLoop): +class BayesianMonteCarlo(QuadratureOuterLoop): """The loop for Bayesian Monte Carlo (BMC). @@ -61,7 +62,7 @@ def __init__(self, model: WarpedBayesianQuadratureModel, model_updater: ModelUpd space = ParameterSpace(model.reasonable_box_bounds.convert_to_list_of_continuous_parameters()) candidate_point_calculator = BayesianMonteCarloPointCalculator(model, space) - loop_state = create_loop_state(model.X, model.Y) + loop_state = create_bq_loop_state(model.X, model.Y) super().__init__(candidate_point_calculator, model_updater, loop_state) diff --git a/emukit/quadrature/loop/bq_loop_state.py b/emukit/quadrature/loop/bq_loop_state.py new file mode 100644 index 00000000..c555365d --- /dev/null +++ b/emukit/quadrature/loop/bq_loop_state.py @@ -0,0 +1,48 @@ +# Copyright 2020-2024 The Emukit Authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from typing import List, Optional + +import numpy as np + +from ...core.loop.loop_state import LoopState, create_loop_state +from ...core.loop.user_function_result import UserFunctionResult + + +class QuadratureLoopState(LoopState): + """Contains the state of the BQ loop, which includes a history of all integrand evaluations and integral mean and + variance estimates. 
+ + :param initial_results: The results from previous integrand evaluations. + + """ + + def __init__(self, initial_results: List[UserFunctionResult]) -> None: + + super().__init__(initial_results) + + self.integral_means = [] + self.integral_vars = [] + + def update_integral_stats(self, integral_mean: float, integral_var: float) -> None: + """Adds the latest integral mean and variance to the loop state. + + :param integral_mean: The latest integral mean estimate. + :param integral_var: The latest integral variance. + """ + self.integral_means.append(integral_mean) + self.integral_vars.append(integral_var) + + +def create_bq_loop_state(x_init: np.ndarray, y_init: np.ndarray, **kwargs) -> QuadratureLoopState: + """Creates a BQ loop state object using the provided data. + + :param x_init: x values for initial function evaluations. Shape: (n_initial_points x n_input_dims) + :param y_init: y values for initial function evaluations. Shape: (n_initial_points x n_output_dims) + :param kwargs: extra outputs observed from a function evaluation. Shape: (n_initial_points x n_dims) + :return: The BQ loop state. + """ + + loop_state = create_loop_state(x_init, y_init, **kwargs) + return QuadratureLoopState(loop_state.results) diff --git a/emukit/quadrature/loop/bq_outer_loop.py b/emukit/quadrature/loop/bq_outer_loop.py new file mode 100644 index 00000000..49b30707 --- /dev/null +++ b/emukit/quadrature/loop/bq_outer_loop.py @@ -0,0 +1,39 @@ +# Copyright 2020-2024 The Emukit Authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from typing import List, Union + +from ...core.loop import OuterLoop +from ...core.loop.candidate_point_calculators import CandidatePointCalculator +from ...core.loop.model_updaters import ModelUpdater +from .bq_loop_state import QuadratureLoopState + + +class QuadratureOuterLoop(OuterLoop): + """Base class for a Bayesian quadrature loop. + + :param candidate_point_calculator: Finds next point(s) to evaluate. 
+ :param model_updaters: Updates the model with the new data and fits the model hyper-parameters. + :param loop_state: Object that keeps track of the history of the BQ loop. Default is None, resulting in empty + initial state. + + :raises ValueError: If `model_updaters` is given as a list; the BQ loop supports a single model updater only. + + """ + + def __init__( + self, + candidate_point_calculator: CandidatePointCalculator, + model_updaters: Union[ModelUpdater, List[ModelUpdater]], + loop_state: QuadratureLoopState = None, + ): + if isinstance(model_updaters, list): + raise ValueError("The BQ loop only supports a single model.") + + super().__init__(candidate_point_calculator, model_updaters, loop_state) + + def _update_loop_state(self) -> None: + model = self.model_updaters[0].model # only works if there is one model, but for BQ nothing else makes sense + integral_mean, integral_var = model.integrate() + self.loop_state.update_integral_stats(integral_mean, integral_var) diff --git a/emukit/quadrature/loop/bq_stopping_conditions.py b/emukit/quadrature/loop/bq_stopping_conditions.py new file mode 100644 index 00000000..e4798a59 --- /dev/null +++ b/emukit/quadrature/loop/bq_stopping_conditions.py @@ -0,0 +1,63 @@ +# Copyright 2020-2024 The Emukit Authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import logging + +import numpy as np + +from ...core.loop.stopping_conditions import StoppingCondition +from .bq_loop_state import QuadratureLoopState + +_log = logging.getLogger(__name__) + + +class CoefficientOfVariationStoppingCondition(StoppingCondition): + r"""Stops once the coefficient of variation (COV) falls below a threshold. + + The COV is given by + + .. math:: + COV = \frac{\sigma}{\mu} + + where :math:`\mu` and :math:`\sigma^2` are the current mean and variance respectively of the integral according to + the BQ posterior model. + + :param eps: Threshold under which the COV must fall. + :param delay: Number of times the stopping condition needs to be true in a row in order to stop. 
Defaults to 1. + + :raises ValueError: If `delay` is smaller than 1. + :raises ValueError: If `eps` is non-positive. + + """ + + def __init__(self, eps: float, delay: int = 1) -> None: + + if delay < 1: + raise ValueError(f"delay ({delay}) must be and integer greater than zero.") + + if eps <= 0.0: + raise ValueError(f"eps ({eps}) must be positive.") + + self.eps = eps + self.delay = delay + self.times_true = 0 # counts how many times stopping has been triggered in a row + + def should_stop(self, loop_state: QuadratureLoopState) -> bool: + if len(loop_state.integral_means) < 1: + return False + + m = loop_state.integral_means[-1] + v = loop_state.integral_vars[-1] + should_stop = (np.sqrt(v) / m) < self.eps + + if should_stop: + self.times_true += 1 + else: + self.times_true = 0 + + should_stop = should_stop and (self.times_true >= self.delay) + + if should_stop: + _log.info(f"Stopped as coefficient of variation is below threshold of {self.eps}.") + return should_stop diff --git a/emukit/quadrature/loop/vanilla_bq_loop.py b/emukit/quadrature/loop/vanilla_bq_loop.py index 743a3a18..e0f82171 100644 --- a/emukit/quadrature/loop/vanilla_bq_loop.py +++ b/emukit/quadrature/loop/vanilla_bq_loop.py @@ -6,15 +6,16 @@ from ...core.acquisition import Acquisition -from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop, SequentialPointCalculator -from ...core.loop.loop_state import create_loop_state +from ...core.loop import FixedIntervalUpdater, ModelUpdater, SequentialPointCalculator from ...core.optimization import AcquisitionOptimizerBase, GradientAcquisitionOptimizer from ...core.parameter_space import ParameterSpace from ..acquisitions import IntegralVarianceReduction from ..methods import VanillaBayesianQuadrature +from .bq_loop_state import create_bq_loop_state +from .bq_outer_loop import QuadratureOuterLoop -class VanillaBayesianQuadratureLoop(OuterLoop): +class VanillaBayesianQuadratureLoop(QuadratureOuterLoop): """The loop for standard ('vanilla') 
Bayesian Quadrature. .. seealso:: @@ -46,7 +47,7 @@ def __init__( if acquisition_optimizer is None: acquisition_optimizer = GradientAcquisitionOptimizer(space) candidate_point_calculator = SequentialPointCalculator(acquisition, acquisition_optimizer) - loop_state = create_loop_state(model.X, model.Y) + loop_state = create_bq_loop_state(model.X, model.Y) super().__init__(candidate_point_calculator, model_updater, loop_state) diff --git a/emukit/quadrature/loop/wsabil_loop.py b/emukit/quadrature/loop/wsabil_loop.py index 16dbed36..5264c0f6 100644 --- a/emukit/quadrature/loop/wsabil_loop.py +++ b/emukit/quadrature/loop/wsabil_loop.py @@ -5,15 +5,16 @@ """The WSABI-L loop""" -from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop, SequentialPointCalculator -from ...core.loop.loop_state import create_loop_state +from ...core.loop import FixedIntervalUpdater, ModelUpdater, SequentialPointCalculator from ...core.optimization import AcquisitionOptimizerBase, GradientAcquisitionOptimizer from ...core.parameter_space import ParameterSpace from ..acquisitions import UncertaintySampling from ..methods import WSABIL +from .bq_loop_state import create_bq_loop_state +from .bq_outer_loop import QuadratureOuterLoop -class WSABILLoop(OuterLoop): +class WSABILLoop(QuadratureOuterLoop): """The loop for WSABI-L. .. rubric:: References @@ -44,7 +45,7 @@ def __init__( if acquisition_optimizer is None: acquisition_optimizer = GradientAcquisitionOptimizer(space) candidate_point_calculator = SequentialPointCalculator(acquisition, acquisition_optimizer) - loop_state = create_loop_state(model.X, model.Y) + loop_state = create_bq_loop_state(model.X, model.Y) super().__init__(candidate_point_calculator, model_updater, loop_state)