Don't pass unused arguments to optimizer_argparse (#2999)
Summary:
Pull Request resolved: #2999

Context:
- The only arguments ever passed to `optimizer_argparse` are `acqf`, `optimizer_options`, `optimizer`, `q`, and `bounds`; the latter two are always ignored.
- `optimizer_argparse` accepts a bunch of arguments that are never passed to it... and never should be, because, as the docstring explains, `optimizer_options` is the right place to pass those (see the sketch below).
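
A minimal sketch of the calling convention after this change, for illustration only. Here `acqf` stands in for any already-constructed acquisition function, and the option values are made up; anything not set in `optimizer_options` falls back to the defaults (`NUM_RESTARTS`, `RAW_SAMPLES`, `INIT_BATCH_LIMIT`, `BATCH_LIMIT`):

    from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse

    # `acqf` is assumed to be an already-built BoTorch AcquisitionFunction.
    optimizer_options_with_defaults = optimizer_argparse(
        acqf,
        optimizer="optimize_acqf",
        optimizer_options={
            "num_restarts": 20,  # overrides the NUM_RESTARTS default
            "options": {"batch_limit": 5},  # sub-options nest under "options"
        },
    )
    # With `**ignore` removed, passing `q=...`, `bounds=...`, or `num_restarts=...`
    # directly now raises a TypeError; `optimizer_options` is the one entry point.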

Reviewed By: saitcakmak, Balandat

Differential Revision: D65233328

fbshipit-source-id: d0455d5b0429353ac1543384cc462fdcce2a49d5
esantorella authored and facebook-github-bot committed Oct 31, 2024
1 parent 2b14d3f commit a9a9a7c
Showing 5 changed files with 10 additions and 66 deletions.
2 changes: 0 additions & 2 deletions ax/models/torch/botorch_modular/acquisition.py
@@ -338,8 +338,6 @@ def optimize(
         # Prepare arguments for optimizer
         optimizer_options_with_defaults = optimizer_argparse(
             self.acqf,
-            bounds=bounds,
-            q=n,
             optimizer_options=optimizer_options,
             optimizer=optimizer,
         )
45 changes: 9 additions & 36 deletions ax/models/torch/botorch_modular/optimizer_argparse.py
@@ -25,21 +25,10 @@ def optimizer_argparse(
     acqf: AcquisitionFunction,
     *,
     optimizer: str,
-    sequential: bool = True,
-    num_restarts: int = NUM_RESTARTS,
-    raw_samples: int = RAW_SAMPLES,
-    init_batch_limit: int = INIT_BATCH_LIMIT,
-    batch_limit: int = BATCH_LIMIT,
     optimizer_options: dict[str, Any] | None = None,
-    **ignore: Any,
 ) -> dict[str, Any]:
     """Extract the kwargs to be passed to a BoTorch optimizer.
 
-    NOTE: Since `optimizer_options` is how the user would typically pass in these
-    options, it takes precedence over other arguments. E.g., if both `num_restarts`
-    and `optimizer_options["num_restarts"]` are provided, this will use
-    `num_restarts` from `optimizer_options`.
-
     Args:
         acqf: The acquisition function being optimized.
         optimizer: The optimizer to parse args for. Typically chosen by
@@ -50,24 +39,10 @@
             - "optimize_acqf_homotopy",
             - "optimize_acqf_mixed",
             - "optimize_acqf_mixed_alternating".
-        sequential: Whether we choose one candidate at a time in a sequential
-            manner. `sequential=False` is not supported by optimizers other than
-            `optimize_acqf` and will lead to an error.
-        num_restarts: The number of starting points for multistart acquisition
-            function optimization. Ignored if the optimizer is
-            `optimize_acqf_discrete`.
-        raw_samples: The number of samples for initialization. Ignored if the
-            optimizer is `optimize_acqf_discrete`.
-        init_batch_limit: The size of mini-batches used to evaluate the `raw_samples`.
-            This helps reduce peak memory usage. Ignored if the optimizer is
-            `optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
-        batch_limit: The size of mini-batches used while optimizing the `acqf`.
-            This helps reduce peak memory usage. Ignored if the optimizer is
-            `optimize_acqf_discrete` or `optimize_acqf_discrete_local_search`.
-        optimizer_options: An optional dictionary of optimizer options. This may
-            include overrides for the above options (some of these under an `options`
-            dictionary) or any other option that is accepted by the optimizer. See
-            the docstrings in `botorch/optim/optimize.py` for supported options.
+        optimizer_options: An optional dictionary of optimizer options (some of
+            these under an `options` dictionary); default values will be used
+            where not specified. See the docstrings in
+            `botorch/optim/optimize.py` for supported options.
     Example:
         >>> optimizer_options = {
         >>>     "num_restarts": 20,
@@ -108,8 +83,8 @@
         options = {}
     else:
         options = {
-            "num_restarts": num_restarts,
-            "raw_samples": raw_samples,
+            "num_restarts": NUM_RESTARTS,
+            "raw_samples": RAW_SAMPLES,
         }
 
     if optimizer in [
@@ -119,8 +94,8 @@
         "optimize_acqf_mixed_alternating",
     ]:
         options["options"] = {
-            "init_batch_limit": init_batch_limit,
-            "batch_limit": batch_limit,
+            "init_batch_limit": INIT_BATCH_LIMIT,
+            "batch_limit": BATCH_LIMIT,
             **provided_options.get("options", {}),
         }
     # Error out if options are specified for an optimizer that does not support the arg.
@@ -130,9 +105,7 @@
         )
 
     if optimizer == "optimize_acqf":
-        options["sequential"] = sequential
-    elif sequential is False:
-        raise UnsupportedError(f"`{optimizer=}` does not support `sequential=False`.")
+        options["sequential"] = True
 
     options.update(**{k: v for k, v in provided_options.items() if k != "options"})
     return options
2 changes: 0 additions & 2 deletions ax/models/torch/botorch_modular/sebo.py
@@ -282,8 +282,6 @@ def _optimize_with_homotopy(
         # Prepare arguments for optimizer
         optimizer_options_with_defaults = optimizer_argparse(
             self.acqf,
-            bounds=bounds,
-            q=n,
             optimizer_options=optimizer_options,
             optimizer="optimize_acqf_homotopy",
         )
8 changes: 0 additions & 8 deletions ax/models/torch/tests/test_acquisition.py
@@ -315,8 +315,6 @@ def test_optimize(self, mock_optimize_acqf: Mock) -> None:
         )
         mock_optimizer_argparse.assert_called_once_with(
             acquisition.acqf,
-            bounds=mock.ANY,
-            q=n,
             optimizer_options=self.optimizer_options,
             optimizer="optimize_acqf",
         )
@@ -410,8 +408,6 @@ def test_optimize_discrete(self) -> None:
         )
         mock_optimizer_argparse.assert_called_once_with(
             acquisition.acqf,
-            bounds=mock.ANY,
-            q=n,
             optimizer_options=None,
             optimizer="optimize_acqf_discrete",
         )
@@ -459,8 +455,6 @@ def test_optimize_discrete(self) -> None:
         )
         mock_optimizer_argparse.assert_called_once_with(
             acquisition.acqf,
-            bounds=mock.ANY,
-            q=3,
             optimizer_options=None,
             optimizer="optimize_acqf_discrete",
         )
@@ -551,8 +545,6 @@ def test_optimize_acqf_discrete_local_search(
         )
         mock_optimizer_argparse.assert_called_once_with(
             acquisition.acqf,
-            bounds=mock.ANY,
-            q=3,
             optimizer_options=self.optimizer_options,
             optimizer="optimize_acqf_discrete_local_search",
         )
19 changes: 1 addition & 18 deletions ax/models/torch/tests/test_optimizer_argparse.py
@@ -132,19 +132,6 @@ def test_optimizer_options(self) -> None:
                 optimizer=optimizer,
             )
 
-        # `sequential=False` with optimizers other than `optimize_acqf`.
-        for optimizer in [
-            "optimize_acqf_homotopy",
-            "optimize_acqf_mixed",
-            "optimize_acqf_mixed_alternating",
-            "optimize_acqf_discrete",
-            "optimize_acqf_discrete_local_search",
-        ]:
-            with self.assertRaisesRegex(
-                UnsupportedError, "does not support `sequential=False`"
-            ):
-                optimizer_argparse(self.acqf, sequential=False, optimizer=optimizer)
-
     def test_kg(self) -> None:
         user_options = {"foo": "bar", "num_restarts": 114}
         generic_options = optimizer_argparse(
@@ -158,11 +145,7 @@ def test_kg(self) -> None:
         ):
             with self.subTest(acqf=acqf):
                 options = optimizer_argparse(
-                    acqf,
-                    q=None,
-                    bounds=None,
-                    optimizer_options=user_options,
-                    optimizer="optimize_acqf",
+                    acqf, optimizer_options=user_options, optimizer="optimize_acqf"
                 )
                 self.assertEqual(options, generic_options)
