Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support per-metric model specification in MBM #3009

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import (
Expand Down Expand Up @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None:
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP
saasbo.model.surrogate.model_configs[0].botorch_model_class,
SaasFullyBayesianSingleTaskGP,
)

@mock_botorch_optimize
Expand Down Expand Up @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, BoTorchModel)
self.assertEqual(mtgp.model.acquisition_class, Acquisition)
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
is_moo = isinstance(
exp.optimization_config, MultiObjectiveOptimizationConfig
)
if is_moo:
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
models = mtgp.model.surrogate.model.models
else:
models = [mtgp.model.surrogate.model]

for model in mtgp.model.surrogate.model.models:
for model in models:
self.assertIsInstance(
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
Expand Down
11 changes: 10 additions & 1 deletion ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
Expand Down Expand Up @@ -79,6 +80,8 @@ class SurrogateSpec:

allow_batched_models: bool = True

model_configs: list[ModelConfig] = field(default_factory=list)
metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
outcomes: list[str] = field(default_factory=list)


Expand Down Expand Up @@ -241,13 +244,19 @@ def fit(
input_transform_options=spec.input_transform_options,
outcome_transform_classes=spec.outcome_transform_classes,
outcome_transform_options=spec.outcome_transform_options,
model_configs=spec.model_configs,
metric_to_model_configs=spec.metric_to_model_configs,
allow_batched_models=spec.allow_batched_models,
)
else:
self._surrogate = Surrogate()

# Fit the surrogate.
self.surrogate.model_options.update(additional_model_inputs)
for config in self.surrogate.model_configs:
config.model_options.update(additional_model_inputs)
for config_list in self.surrogate.metric_to_model_configs.values():
for config in config_list:
config.model_options.update(additional_model_inputs)
self.surrogate.fit(
datasets=datasets,
search_space_digest=search_space_digest,
Expand Down
222 changes: 160 additions & 62 deletions ax/models/torch/botorch_modular/surrogate.py

Large diffs are not rendered by default.

118 changes: 114 additions & 4 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from dataclasses import dataclass, field
from logging import Logger
from typing import Any

Expand All @@ -34,29 +35,138 @@
from botorch.models.model import Model, ModelList
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor

MIN_OBSERVED_NOISE_LEVEL = 1e-7
logger: Logger = get_logger(__name__)


@dataclass
class ModelConfig:
    """Configuration for one BoTorch ``Model`` used in ``Surrogate``.

    A ``ModelConfig`` fully specifies how to construct and fit a single
    BoTorch model: the model class, its constructor options, the marginal
    log-likelihood used for fitting, and the input/outcome transforms,
    covariance module, and likelihood passed to the model constructor.

    Args:
        botorch_model_class: ``Model`` class to be used as the underlying
            BoTorch model. If None is provided, a model class will be selected
            automatically (either one model for all outcomes, or a model list
            with separate models per outcome) based on the datasets at
            ``construct`` time.
        model_options: Dictionary of options / kwargs for the BoTorch
            ``Model`` constructed during ``Surrogate.fit``.
            Note that the corresponding attribute will later be updated to
            include any additional kwargs passed into ``BoTorchModel.fit``.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
        mll_options: Dictionary of options / kwargs for the MLL.
        input_transform_classes: List of BoTorch input transform classes.
            Passed down to the BoTorch ``Model``. Multiple input transforms
            will be chained together using ``ChainedInputTransform``.
        input_transform_options: Input transform classes kwargs. The keys are
            class string names and the values are dictionaries of input
            transform kwargs. For example,
            `
            input_transform_classes = [Normalize, Round]
            input_transform_options = {
                "Normalize": {"d": 3},
                "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
            }
            `
            For more input options see `botorch/models/transforms/input.py`.
        outcome_transform_classes: List of BoTorch outcome transform classes.
            Passed down to the BoTorch ``Model``. Multiple outcome transforms
            can be chained together using ``ChainedOutcomeTransform``.
        outcome_transform_options: Outcome transform classes kwargs. The keys
            are class string names and the values are dictionaries of outcome
            transform kwargs. For example,
            `
            outcome_transform_classes = [Standardize]
            outcome_transform_options = {
                "Standardize": {"m": 1},
            }
            `
            For more options see `botorch/models/transforms/outcome.py`.
        covar_module_class: Covariance module class. This gets initialized after
            parsing the ``covar_module_options`` in ``covar_module_argparse``,
            and gets passed to the model constructor as ``covar_module``.
        covar_module_options: Covariance module kwargs.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_options`` and gets passed to the model constructor.
        likelihood_options: Likelihood options / kwargs.
    """

    botorch_model_class: type[Model] | None = None
    model_options: dict[str, Any] = field(default_factory=dict)
    mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
    mll_options: dict[str, Any] = field(default_factory=dict)
    input_transform_classes: list[type[InputTransform]] | None = None
    # NOTE(review): annotation allows None but the default is an empty dict —
    # confirm whether None carries meaning distinct from {} (the sibling
    # outcome_transform_options field is non-optional).
    input_transform_options: dict[str, dict[str, Any]] | None = field(
        default_factory=dict
    )
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None
    outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict)
    covar_module_class: type[Kernel] | None = None
    covar_module_options: dict[str, Any] = field(default_factory=dict)
    likelihood_class: type[Likelihood] | None = None
    likelihood_options: dict[str, Any] = field(default_factory=dict)


def use_model_list(
datasets: Sequence[SupervisedDataset],
botorch_model_class: type[Model],
model_configs: list[ModelConfig] | None = None,
metric_to_model_configs: dict[str, list[ModelConfig]] | None = None,
allow_batched_models: bool = True,
) -> bool:
if issubclass(botorch_model_class, MultiTaskGP):
# We currently always wrap multi-task models into `ModelListGP`.
model_configs = model_configs or []
metric_to_model_configs = metric_to_model_configs or {}
if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1:
# There is only one outcome, so we can use a single model.
return False
elif (
len(model_configs) > 1
or len(metric_to_model_configs) > 0
or any(len(model_config) for model_config in metric_to_model_configs.values())
):
# There are multiple outcomes and outcomes might be modeled with different
# models
return True
elif issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP):
# Otherwise, the same model class is used for all outcomes.
# Determine what the model class is.
if len(model_configs) > 0:
botorch_model_class = (
model_configs[0].botorch_model_class or botorch_model_class
)
if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP):
# SAAS models do not support multiple outcomes.
# Use model list if there are multiple outcomes.
return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
elif issubclass(botorch_model_class, MultiTaskGP):
# We wrap multi-task models into `ModelListGP` when there are
# multiple outcomes.
return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
elif len(datasets) == 1:
# Just one outcome, can use single model.
# This method is called before multiple datasets are merged into
# one if using a batched model. If there is one dataset here,
# there should be a reason that a single model should be used:
# e.g. a contextual model, where we want to jointly model the metric
# each context (and context-level metrics are different outcomes).
return False
elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all(
torch.equal(datasets[0].X, ds.X) for ds in datasets[1:]
Expand Down
21 changes: 20 additions & 1 deletion ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
)
from ax.models.torch.utils import _filter_X_observed
from ax.models.torch_base import TorchOptConfig
Expand Down Expand Up @@ -327,7 +328,21 @@ def test_fit(self, mock_fit: Mock) -> None:
mock_fit.assert_called_with(
dataset=self.block_design_training_data[0],
search_space_digest=self.mf_search_space_digest,
botorch_model_class=SingleTaskMultiFidelityGP,
model_config=ModelConfig(
botorch_model_class=None,
model_options={},
mll_class=ExactMarginalLogLikelihood,
mll_options={},
input_transform_classes=None,
input_transform_options={},
outcome_transform_classes=None,
outcome_transform_options={},
covar_module_class=None,
covar_module_options={},
likelihood_class=None,
likelihood_options={},
),
default_botorch_model_class=SingleTaskMultiFidelityGP,
state_dict=None,
refit=True,
)
Expand Down Expand Up @@ -727,6 +742,8 @@ def test_surrogate_model_options_propagation(
input_transform_options=None,
outcome_transform_classes=None,
outcome_transform_options=None,
model_configs=[],
metric_to_model_configs={},
allow_batched_models=True,
)

Expand Down Expand Up @@ -755,6 +772,8 @@ def test_surrogate_options_propagation(
input_transform_options=None,
outcome_transform_classes=None,
outcome_transform_options=None,
model_configs=[],
metric_to_model_configs={},
allow_batched_models=False,
)

Expand Down
Loading
Loading