Skip to content

Commit

Permalink
Multi-Armed Bandit (#343)
Browse files Browse the repository at this point in the history
This PR enables bandit models by bringing the following new features:
* a class for modeling binary targets
* a Beta-Bernoulli multi-armed bandit surrogate class
* a Thompson sampling acquisition function (should be turned into a
recommender in the future)
* a beta prior
* a bandit optimization example
* the necessary torch infrastructure, such sampler for torch's beta
distribution
  • Loading branch information
AdrianSosic authored Sep 6, 2024
2 parents 7e874db + 8eae563 commit 99f5e9e
Show file tree
Hide file tree
Showing 30 changed files with 701 additions and 45 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and `SubspaceContinuous` classes
- New mechanisms for surrogate input/output scaling configurable per class
- `SurrogateProtocol` as an interface for user-defined surrogate architectures
- Support for binary targets via `BinaryTarget` class
- Support for bandit optimization via `BetaBernoulliMultiArmedBanditSurrogate` class
- Bandit optimization example
- `qThompsonSampling` acquisition function
- `BetaPrior` class

### Changed
- The transition from experimental to computational representation no longer happens
Expand Down
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@
Human readable output for search spaces
- Di Jin (Merck Life Science KGaA, Darmstadt, Germany):\
Cardinality constraints
- Julian Streibel (Merck Life Science KGaA, Darmstadt, Germany):\
Bernoulli multi-armed bandit and Thompson sampling
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Besides functionality to perform a typical recommend-measure loop, BayBE's highl
- ⚙️ Custom surrogate models: For specialized problems or active learning
- 🎭 Hybrid (mixed continuous and discrete) spaces
- 🚀 Transfer learning: Mix data from multiple campaigns and accelerate optimization
- 🎰 Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing)
- 📈 Comprehensive backtest, simulation and imputation utilities: Benchmark and find your best settings
- 📝 Fully typed and hypothesis-tested: Robust code base
- 🔄 All objects are fully de-/serializable: Useful for storing results in databases or use in wrappers like APIs
Expand Down
6 changes: 6 additions & 0 deletions baybe/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
qNoisyExpectedImprovement,
qProbabilityOfImprovement,
qSimpleRegret,
qThompsonSampling,
qUpperConfidenceBound,
)

Expand All @@ -33,6 +34,7 @@
qPI = qProbabilityOfImprovement
UCB = UpperConfidenceBound
qUCB = qUpperConfidenceBound
qTS = qThompsonSampling

__all__ = [
######################### Acquisition functions
Expand All @@ -57,6 +59,8 @@
# Upper Confidence Bound
"UpperConfidenceBound",
"qUpperConfidenceBound",
# Thompson Sampling
"qThompsonSampling",
######################### Abbreviations
# Knowledge Gradient
"qKG",
Expand All @@ -79,4 +83,6 @@
# Upper Confidence Bound
"UCB",
"qUCB",
# Thompson Sampling
"qTS",
]
26 changes: 26 additions & 0 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,29 @@ class qUpperConfidenceBound(AcquisitionFunction):
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
"""


@define(frozen=True)
class qThompsonSampling(qSimpleRegret):
"""Thomson sampling, implemented via simple regret. Inherently Monte Carlo based.
This implementation exploits the fact that one-sample-based Thompson sampling
(i.e. where the action probability is approximated using a single posterior sample)
is equivalent to optimizing the Monte Carlo approximated posterior mean with
sample size one. The latter can be achieved via `qSimpleRegret` and controlling
its sample shape attribute.
"""

abbreviation: ClassVar[str] = "qTS"

n_mc_samples: int = field(default=1, init=False)
"""Number of Monte Carlo samples drawn from the posterior at each design point.
Restring the the sample size to one allows us to emulate (one-sample based)
Thompson sampling using the regular acquisition function machinery.
"""

@classproperty
def _non_botorch_attrs(cls) -> tuple[str, ...]:
flds = fields(qThompsonSampling)
return (flds.n_mc_samples.name,)
37 changes: 34 additions & 3 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import warnings
from abc import ABC
from inspect import signature
from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar

import pandas as pd
from attrs import define

from baybe.exceptions import UnidentifiedSubclassError
from baybe.objectives.base import Objective
from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.single import SingleTargetObjective
Expand All @@ -27,6 +28,9 @@
from baybe.utils.boolean import is_abstract
from baybe.utils.dataframe import to_tensor

if TYPE_CHECKING:
from botorch.acquisition import AcquisitionFunction as BotorchAcquisitionFunction


@define(frozen=True)
class AcquisitionFunction(ABC, SerialMixin):
Expand Down Expand Up @@ -61,8 +65,10 @@ def to_botorch(
import torch
from botorch.acquisition.objective import LinearMCObjective

from baybe.acquisition.acqfs import qThompsonSampling

# Retrieve botorch acquisition function class and match attributes
acqf_cls = getattr(bo_acqf, self.__class__.__name__)
acqf_cls = _get_botorch_acqf_class(type(self))
params_dict = match_attributes(
self, acqf_cls.__init__, ignore=self._non_botorch_attrs
)[0]
Expand Down Expand Up @@ -111,7 +117,32 @@ def to_botorch(
raise ValueError(f"Unsupported objective type: {objective}")

params_dict.update(additional_params)
return acqf_cls(**params_dict)

acqf = acqf_cls(**params_dict)

if isinstance(self, qThompsonSampling):
assert hasattr(acqf, "_default_sample_shape")
acqf._default_sample_shape = torch.Size([self.n_mc_samples])

return acqf


def _get_botorch_acqf_class(
baybe_acqf_cls: type[AcquisitionFunction], /
) -> type[BotorchAcquisitionFunction]:
"""Extract the BoTorch acquisition class for the given BayBE acquisition class."""
import botorch

for cls in baybe_acqf_cls.mro():
if acqf_cls := getattr(botorch.acquisition, cls.__name__, False):
if is_abstract(acqf_cls):
continue
return acqf_cls # type: ignore

raise UnidentifiedSubclassError(
f"No BoTorch acquisition function class match found for "
f"'{baybe_acqf_cls.__name__}'."
)


# Register de-/serialization hooks
Expand Down
24 changes: 17 additions & 7 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ class UnusedObjectWarning(UserWarning):


##### Exceptions #####


class IncompatibilityError(Exception):
"""Incompatible components are used together."""


class IncompatibleSearchSpaceError(IncompatibilityError):
"""
A recommender is used with a search space that contains incompatible parts,
e.g. a discrete recommender is used with a hybrid or continuous search space.
"""


class NotEnoughPointsLeftError(Exception):
"""
More recommendations are requested than there are viable parameter configurations
Expand All @@ -24,13 +37,6 @@ class NoMCAcquisitionFunctionError(Exception):
"""


class IncompatibleSearchSpaceError(Exception):
"""
A recommender is used with a search space that contains incompatible parts,
e.g. a discrete recommender is used with a hybrid or continuous search space.
"""


class EmptySearchSpaceError(Exception):
"""The created search space contains no parameters."""

Expand Down Expand Up @@ -71,3 +77,7 @@ class UnmatchedAttributeError(Exception):

class InvalidSurrogateModelError(Exception):
"""An invalid surrogate model was chosen."""


class InvalidTargetValueError(Exception):
"""A target value was entered that is not in the target space."""
5 changes: 1 addition & 4 deletions baybe/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ def to_subspace(self) -> SubspaceContinuous:


# Register (un-)structure hooks
_overrides = {
"_values": override(rename="values"),
"decorrelate": override(struct_hook=lambda x, _: x),
}
_overrides = {"_values": override(rename="values")}
# FIXME[typing]: https://github.com/python/mypy/issues/4717
converter.register_structure_hook(
Parameter,
Expand Down
2 changes: 2 additions & 0 deletions baybe/priors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from baybe.priors.basic import (
BetaPrior,
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
Expand All @@ -14,6 +15,7 @@
)

__all__ = [
"BetaPrior",
"GammaPrior",
"HalfCauchyPrior",
"HalfNormalPrior",
Expand Down
18 changes: 18 additions & 0 deletions baybe/priors/basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A collection of common prior distributions."""

from __future__ import annotations

from typing import Any

from attrs import define, field
Expand Down Expand Up @@ -85,3 +87,19 @@ def _validate_order(self, _: Any, b: float) -> None: # noqa: DOC101, DOC103
f"For {self.__class__.__name__}, the upper bound `b` (provided: {b}) "
f"must be larger than the lower bound `a` (provided: {self.a})."
)


@define(frozen=True)
class BetaPrior(Prior):
"""A beta prior parameterized by alpha and beta."""

alpha: float = field(converter=float, validator=gt(0.0))
"""Alpha concentration parameter. Controls mass accumulated toward zero."""

beta: float = field(converter=float, validator=gt(0.0))
"""Beta concentration parameter. Controls mass accumulated toward one."""

def to_gpytorch(self, *args, **kwargs): # noqa: D102
raise NotImplementedError(
f"'{self.__class__.__name__}' does not have a gpytorch analog."
)
6 changes: 1 addition & 5 deletions baybe/recommenders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def recommend(
lambda x: unstructure_base(
x,
# TODO: Remove once deprecation got expired:
overrides=dict(
allow_repeated_recommendations=cattrs.override(omit=True),
allow_recommending_already_measured=cattrs.override(omit=True),
acquisition_function_cls=cattrs.override(omit=True),
),
overrides=dict(acquisition_function_cls=cattrs.override(omit=True)),
),
)
converter.register_structure_hook(
Expand Down
2 changes: 2 additions & 0 deletions baybe/recommenders/pure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def _recommend_with_discrete_parts(
or self.allow_recommending_already_measured,
)

# TODO: Introduce new flag to recommend batches larger than the search space

# Check if enough candidates are left
# TODO [15917]: This check is not perfectly correct.
if (not is_hybrid_space) and (len(candidates_exp) < batch_size):
Expand Down
8 changes: 6 additions & 2 deletions baybe/recommenders/pure/bayesian/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from attr.converters import optional
from attrs import define, field

from baybe.exceptions import NoMCAcquisitionFunctionError
from baybe.acquisition.acqfs import qThompsonSampling
from baybe.exceptions import IncompatibilityError, NoMCAcquisitionFunctionError
from baybe.recommenders.pure.bayesian.base import BayesianRecommender
from baybe.searchspace import (
SearchSpace,
Expand Down Expand Up @@ -95,12 +96,15 @@ def _recommend_discrete(
The dataframe indices of the recommended points in the provided
experimental representation.
"""
# For batch size > 1, this optimizer needs a MC acquisition function
if batch_size > 1 and not self.acquisition_function.is_mc:
raise NoMCAcquisitionFunctionError(
f"The '{self.__class__.__name__}' only works with Monte Carlo "
f"acquisition functions for batch sizes > 1."
)
if batch_size > 1 and isinstance(self.acquisition_function, qThompsonSampling):
raise IncompatibilityError(
"Thompson sampling currently only supports a batch size of 1."
)

from botorch.optim import optimize_acqf_discrete

Expand Down
3 changes: 3 additions & 0 deletions baybe/serialization/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import cattrs
import pandas as pd
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn
from cattrs.strategies import configure_union_passthrough

from baybe.utils.basic import find_subclass, refers_to
from baybe.utils.boolean import is_abstract
Expand All @@ -20,6 +21,8 @@
converter = cattrs.Converter()
"""The default converter for (de-)serializing BayBE-related objects."""

configure_union_passthrough(bool | int | float | str, converter)


def unstructure_base(base: Any, overrides: dict | None = None) -> dict:
"""Unstructure an object into a dictionary and adds an entry for the class name.
Expand Down
2 changes: 2 additions & 0 deletions baybe/surrogates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""BayBE surrogates."""

from baybe.surrogates.bandit import BetaBernoulliMultiArmedBanditSurrogate
from baybe.surrogates.custom import CustomONNXSurrogate, register_custom_architecture
from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
from baybe.surrogates.linear import BayesianLinearSurrogate
Expand All @@ -10,6 +11,7 @@
__all__ = [
"register_custom_architecture",
"BayesianLinearSurrogate",
"BetaBernoulliMultiArmedBanditSurrogate",
"CustomONNXSurrogate",
"GaussianProcessSurrogate",
"MeanPredictionSurrogate",
Expand Down
Loading

0 comments on commit 99f5e9e

Please sign in to comment.