Remove dispatching functionality from optimizer_argparse (#2998)
Summary:

Context: The only use of this dispatcher is to raise an exception when the acquisition function is qKG, so there is no need for it to be a dispatcher.

This diff:
* Makes `optimizer_argparse` no longer a dispatcher (illustrated in the sketch below)
* Moves the error for qKG into the body of the single remaining `optimizer_argparse` function
* Removes the special-case function for qKG
* Changes type annotations so that the first argument is always an `AcquisitionFunction`; it was always used that way, with other types appearing only in tests
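
As an illustration only (not part of the diff), here is a minimal sketch of the change in shape, with the argument list simplified:

```python
from typing import Any

from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient

# Before (sketch): `optimizer_argparse` was a `Dispatcher`; a base handler was
# registered for `AcquisitionFunction`, plus a qKG-specific handler whose only
# job was to reject optimizers other than `optimize_acqf`.

# After (sketch): a single plain function; the qKG restriction is an inline check.
def optimizer_argparse(acqf: AcquisitionFunction, *, optimizer: str) -> dict[str, Any]:
    if optimizer != "optimize_acqf" and isinstance(acqf, qKnowledgeGradient):
        raise RuntimeError(
            f"`{optimizer}` is not compatible with `qKnowledgeGradient`."
        )
    # The real function also validates the optimizer name and assembles
    # num_restarts, raw_samples, and batch-limit options; omitted here.
    return {"optimizer": optimizer}
```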

Reviewed By: saitcakmak, Balandat

Differential Revision: D65231763
esantorella authored and facebook-github-bot committed Oct 30, 2024
1 parent 3ff0ad8 commit 37e434c
Showing 2 changed files with 83 additions and 150 deletions.
64 changes: 12 additions & 52 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -8,16 +8,11 @@

from __future__ import annotations

from typing import Any, TypeVar, Union
from typing import Any

from ax.exceptions.core import UnsupportedError
from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.utils.dispatcher import Dispatcher

T = TypeVar("T")
MaybeType = Union[T, type[T]] # Annotation for a type or instance thereof

# Acquisition defaults
NUM_RESTARTS = 20
@@ -26,14 +21,8 @@
BATCH_LIMIT = 5


optimizer_argparse = Dispatcher(
name="optimizer_argparse", encoder=_argparse_type_encoder
)


@optimizer_argparse.register(AcquisitionFunction)
def _argparse_base(
acqf: MaybeType[AcquisitionFunction],
def optimizer_argparse(
acqf: AcquisitionFunction,
*,
optimizer: str,
sequential: bool = True,
@@ -102,6 +91,15 @@ def _argparse_base(
f"optimizer=`{optimizer}` is not supported. Accepted options are "
f"{supported_optimizers}"
)
if (optimizer != "optimize_acqf") and isinstance(acqf, qKnowledgeGradient):
raise RuntimeError(
"Ax is attempting to use a discrete or mixed optimizer, "
f"`{optimizer}`, but this is not compatible with "
"`qKnowledgeGradient` or its subclasses. To address this, please "
"either use a different acquisition class or make parameters "
"continuous using the transform "
"`ax.modelbridge.registry.Cont_X_trans`."
)
provided_options = optimizer_options if optimizer_options is not None else {}

# Construct arguments from options that are not `provided_options`.
@@ -138,41 +136,3 @@

options.update(**{k: v for k, v in provided_options.items() if k != "options"})
return options


@optimizer_argparse.register(qKnowledgeGradient)
def _argparse_kg(
acqf: qKnowledgeGradient,
*,
optimizer: str = "optimize_acqf",
sequential: bool = True,
num_restarts: int = NUM_RESTARTS,
raw_samples: int = RAW_SAMPLES,
init_batch_limit: int = INIT_BATCH_LIMIT,
batch_limit: int = BATCH_LIMIT,
optimizer_options: dict[str, Any] | None = None,
**ignore: Any,
) -> dict[str, Any]:
"""
Argument constructor for optimization with qKG, differing from the
base case in that it errors if the optimizer is not `optimize_acqf`.
"""
if optimizer != "optimize_acqf":
raise RuntimeError(
"Ax is attempting to use a discrete or mixed optimizer, "
f"`{optimizer}`, but this is not compatible with "
"`qKnowledgeGradient` or its subclasses. To address this, please "
"either use a different acquisition class or make parameters "
"continuous using the transform "
"`ax.modelbridge.registry.Cont_X_trans`."
)
return _argparse_base(
acqf=acqf,
optimizer="optimize_acqf",
sequential=sequential,
num_restarts=num_restarts,
raw_samples=raw_samples,
init_batch_limit=init_batch_limit,
batch_limit=batch_limit,
optimizer_options=optimizer_options,
)
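
For reference (an editorial illustration, not part of the commit), the new inline check can be exercised roughly as in the updated test, which constructs `qKnowledgeGradient` from `MagicMock` arguments:

```python
from unittest.mock import MagicMock

from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient

# Mirrors the updated test: build a qKG instance without a fitted model.
acqf = qKnowledgeGradient(model=MagicMock(), posterior_transform=MagicMock())

# A discrete or mixed optimizer with qKG now raises directly from optimizer_argparse.
try:
    optimizer_argparse(acqf, optimizer="optimize_acqf_mixed")
except RuntimeError as err:
    print(err)

# The supported path still returns the usual options dict for `optimize_acqf`.
options = optimizer_argparse(acqf, optimizer="optimize_acqf")
print(options)
```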
169 changes: 71 additions & 98 deletions ax/models/torch/tests/test_optimizer_argparse.py
@@ -8,35 +8,38 @@

from __future__ import annotations

from itertools import product
from unittest.mock import patch
from unittest.mock import MagicMock

from ax.exceptions.core import UnsupportedError
from ax.models.torch.botorch_modular.optimizer_argparse import (
_argparse_base,
BATCH_LIMIT,
INIT_BATCH_LIMIT,
MaybeType,
NUM_RESTARTS,
optimizer_argparse,
RAW_SAMPLES,
)
from ax.utils.common.testutils import TestCase
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.acquisition.knowledge_gradient import (
qKnowledgeGradient,
qMultiFidelityKnowledgeGradient,
)


class DummyAcquisitionFunction(AcquisitionFunction):
pass
def __init__(self) -> None:
return

# pyre-fixme[14]: Inconsistent override
# pyre-fixme[15]: Inconsistent override
def forward(self) -> int:
return 0


class OptimizerArgparseTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.acqf = DummyAcquisitionFunction()
self.default_expected_options = {
"optimize_acqf": {
"num_restarts": NUM_RESTARTS,
@@ -70,57 +73,24 @@ def setUp(self) -> None:
},
}

def test_notImplemented(self) -> None:
with self.assertRaisesRegex(
NotImplementedError, "Could not find signature for"
):
optimizer_argparse[type(None)] # passing `None` produces a different error

def test_unsupported_optimizer(self) -> None:
with self.assertRaisesRegex(
ValueError, "optimizer=`wishful thinking` is not supported"
):
optimizer_argparse(LogExpectedImprovement, optimizer="wishful thinking")

def test_register(self) -> None:
with patch.dict(optimizer_argparse.funcs, {}):

@optimizer_argparse.register(DummyAcquisitionFunction)
def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None:
pass

self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse)

def test_fallback(self) -> None:
with patch.dict(optimizer_argparse.funcs, {}):

@optimizer_argparse.register(AcquisitionFunction)
def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None:
pass

self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse)
optimizer_argparse(self.acqf, optimizer="wishful thinking")

def test_optimizer_options(self) -> None:
# qKG should have a bespoke test
# currently there is only one function in fns_to_test
fns_to_test = [
elt
for elt in optimizer_argparse.funcs.values()
if elt is not optimizer_argparse[qKnowledgeGradient]
]
user_options = {"foo": "bar", "num_restarts": 13}
for func, optimizer in product(
fns_to_test,
[
"optimize_acqf",
"optimize_acqf_discrete",
"optimize_acqf_mixed",
"optimize_acqf_discrete_local_search",
],
):
with self.subTest(func=func, optimizer=optimizer):
parsed_options = func(
None, optimizer_options=user_options, optimizer=optimizer
for optimizer in [
"optimize_acqf",
"optimize_acqf_discrete",
"optimize_acqf_mixed",
"optimize_acqf_discrete_local_search",
]:
with self.subTest(optimizer=optimizer):
parsed_options = optimizer_argparse(
self.acqf, optimizer_options=user_options, optimizer=optimizer
)
self.assertDictEqual(
{**self.default_expected_options[optimizer], **user_options},
@@ -130,66 +100,69 @@ def test_optimizer_options(self) -> None:
# Also test sub-options.
inner_options = {"batch_limit": 10, "maxiter": 20}
options = {"options": inner_options}
for func in fns_to_test:
for optimizer in [
"optimize_acqf",
"optimize_acqf_mixed",
"optimize_acqf_mixed_alternating",
]:
default = self.default_expected_options[optimizer]
parsed_options = func(
None, optimizer_options=options, optimizer=optimizer
for optimizer in [
"optimize_acqf",
"optimize_acqf_mixed",
"optimize_acqf_mixed_alternating",
]:
default = self.default_expected_options[optimizer]
parsed_options = optimizer_argparse(
self.acqf, optimizer_options=options, optimizer=optimizer
)
expected_options = {k: v for k, v in default.items() if k != "options"}
if "options" in default:
expected_options["options"] = {
**default["options"],
**inner_options,
}
else:
expected_options["options"] = inner_options
self.assertDictEqual(expected_options, parsed_options)

# Error out if options is specified for an optimizer that does
# not support the arg.
for optimizer in [
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
]:
with self.assertRaisesRegex(UnsupportedError, "`options` argument"):
optimizer_argparse(
self.acqf,
optimizer_options={"options": {"batch_limit": 10, "maxiter": 20}},
optimizer=optimizer,
)
expected_options = {k: v for k, v in default.items() if k != "options"}
if "options" in default:
expected_options["options"] = {
**default["options"],
**inner_options,
}
else:
expected_options["options"] = inner_options
self.assertDictEqual(expected_options, parsed_options)

# Error out if options is specified for an optimizer that does
# not support the arg.
for optimizer in [
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
]:
with self.assertRaisesRegex(UnsupportedError, "`options` argument"):
func(
None,
optimizer_options={
"options": {"batch_limit": 10, "maxiter": 20}
},
optimizer=optimizer,
)

# `sequential=False` with optimizers other than `optimize_acqf`.
for optimizer in [
"optimize_acqf_homotopy",
"optimize_acqf_mixed",
"optimize_acqf_mixed_alternating",
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
]:
with self.assertRaisesRegex(
UnsupportedError, "does not support `sequential=False`"
):
func(None, sequential=False, optimizer=optimizer)

# `sequential=False` with optimizers other than `optimize_acqf`.
for optimizer in [
"optimize_acqf_homotopy",
"optimize_acqf_mixed",
"optimize_acqf_mixed_alternating",
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
]:
with self.assertRaisesRegex(
UnsupportedError, "does not support `sequential=False`"
):
optimizer_argparse(self.acqf, sequential=False, optimizer=optimizer)

def test_kg(self) -> None:
user_options = {"foo": "bar", "num_restarts": 114}
generic_options = _argparse_base(
None, optimizer_options=user_options, optimizer="optimize_acqf"
generic_options = optimizer_argparse(
self.acqf, optimizer_options=user_options, optimizer="optimize_acqf"
)
for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient):
for acqf in (
qKnowledgeGradient(model=MagicMock(), posterior_transform=MagicMock()),
qMultiFidelityKnowledgeGradient(
model=MagicMock(), posterior_transform=MagicMock()
),
):
with self.subTest(acqf=acqf):
options = optimizer_argparse(
acqf,
q=None,
bounds=None,
optimizer_options=user_options,
optimizer="optimize_acqf",
)
self.assertEqual(options, generic_options)
