Skip to content

Commit

Permalink
Rename and reorganize classes (facebook#2977)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2977

** Context**
The current structure of the code is such:
* Each `BenchmarkProblem` has a `BenchmarkRunner`
* ` BenchmarkRunner` is the only runner
* A` BenchmarkRunner` has a `ParamBasedTestProblem`, which is either a `BoTorchTestProblem`, `SurrogateTestFunction`, or a special subclass such as `Jenatton`.
The directory structure and names have gotten quite out of touch with the code.

**New class names**
* `ParamBasedTestProblem` -> `TestFunction` (maybe we should call this `BenchmarkTestFunction`?)
* `BoTorchTestProblem` -> `BoTorchTestFunction`

**New directory structure**
| benchmark_problem.py
| problems/
|    | synthetic/hss/jenatton.py
|    | ...
| benchmark_runner.py
| test_function.py
| test_function.py
| test_functions/
|    | botorch_test.py
|    | surrogate.py

Future diffs:
* rename `BenchmarkRunner.test_problem` to `BenchmarkRunner.test_function` (D65088791)

Differential Revision: D64969707

Reviewed By: saitcakmak, Balandat
  • Loading branch information
esantorella authored and facebook-github-bot committed Oct 30, 2024
1 parent 852629a commit 69e877f
Show file tree
Hide file tree
Showing 15 changed files with 112 additions and 99 deletions.
8 changes: 4 additions & 4 deletions ax/benchmark/benchmark_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import pandas as pd

from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, Objective
Expand Down Expand Up @@ -309,7 +309,7 @@ def create_problem_from_botorch(
Create a `BenchmarkProblem` from a BoTorch `BaseTestProblem`.
Uses specialized Metrics and Runners for benchmarking. The test problem's
result will be computed by the Runner, `BoTorchTestProblemRunner`, and
result will be computed by the Runner, `BenchmarkRunner`, and
retrieved by the Metric(s), which are `BenchmarkMetric`s.
Args:
Expand Down Expand Up @@ -378,7 +378,7 @@ def create_problem_from_botorch(
search_space=search_space,
optimization_config=optimization_config,
runner=BenchmarkRunner(
test_problem=BoTorchTestProblem(botorch_problem=test_problem),
test_problem=BoTorchTestFunction(botorch_problem=test_problem),
outcome_names=outcome_names,
search_space_digest=extract_search_space_digest(
search_space=search_space,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy.typing as npt

import torch
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.runner import Runner
Expand Down Expand Up @@ -48,15 +48,15 @@ class BenchmarkRunner(Runner):
Args:
outcome_names: The names of the outcomes returned by the problem.
test_problem: A ``ParamBasedTestProblem`` from which to generate
test_problem: A ``BenchmarkTestFunction`` from which to generate
deterministic data before adding noise.
noise_std: The standard deviation of the noise added to the data. Can be
a list or dict to be per-metric.
search_space_digest: Used to extract target fidelity and task.
"""

outcome_names: list[str]
test_problem: ParamBasedTestProblem
test_problem: BenchmarkTestFunction
noise_std: float | list[float] | dict[str, float] = 0.0
# pyre-fixme[16]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
Expand Down
32 changes: 32 additions & 0 deletions ax/benchmark/benchmark_test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass

from ax.core.types import TParamValue
from torch import Tensor


@dataclass(kw_only=True)
class BenchmarkTestFunction(ABC):
"""
The basic Ax class for generating deterministic data to benchmark against.
(Noise - if desired - is added by the runner.)
"""

@abstractmethod
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
Evaluate noiselessly.
Returns:
1d tensor of shape (num_outcomes,).
"""
...
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,18 @@

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from itertools import islice

import torch
from ax.core.types import TParamValue
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
from torch import Tensor


@dataclass(kw_only=True)
class ParamBasedTestProblem(ABC):
"""
The basic Ax class for generating deterministic data to benchmark against.
(Noise - if desired - is added by the runner.)
"""

@abstractmethod
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
Evaluate noiselessly.
Returns:
1d tensor of shape (num_outcomes,).
"""
...


@dataclass(kw_only=True)
class BoTorchTestProblem(ParamBasedTestProblem):
class BoTorchTestFunction(BenchmarkTestFunction):
"""
Class for generating data from a BoTorch ``BaseTestProblem``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass

import torch
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.observation import ObservationFeatures
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
Expand All @@ -21,7 +21,7 @@


@dataclass(kw_only=True)
class SurrogateTestFunction(ParamBasedTestProblem):
class SurrogateTestFunction(BenchmarkTestFunction):
"""
Data-generating function for surrogate benchmark problems.
Expand Down
14 changes: 7 additions & 7 deletions ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
BenchmarkProblem,
get_soo_config_and_outcome_names,
)
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
Expand Down Expand Up @@ -113,7 +113,7 @@ def train_and_evaluate(


@dataclass(kw_only=True)
class PyTorchCNNTorchvisionParamBasedProblem(ParamBasedTestProblem):
class PyTorchCNNTorchvisionBenchmarkTestFunction(BenchmarkTestFunction):
name: str # The name of the dataset to load -- MNIST or FashionMNIST
device: torch.device = field(
default_factory=lambda: torch.device(
Expand Down Expand Up @@ -151,7 +151,7 @@ def __post_init__(self, train_loader: None, test_loader: None) -> None:
transform=transforms.ToTensor(),
)
# pyre-fixme: Undefined attribute [16]:
# `PyTorchCNNTorchvisionParamBasedProblem` has no attribute
# `PyTorchCNNTorchvisionBenchmarkTestFunction` has no attribute
# `train_loader`.
self.train_loader = DataLoader(train_set, num_workers=1)
# pyre-fixme
Expand All @@ -163,10 +163,10 @@ def evaluate_true(self, params: Mapping[str, int | float]) -> Tensor:
frac_correct = train_and_evaluate(
**params,
device=self.device,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# pyre-fixme[16]: `PyTorchCNNTorchvisionBenchmarkTestFunction` has no
# attribute `train_loader`.
train_loader=self.train_loader,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# pyre-fixme[16]: `PyTorchCNNTorchvisionBenchmarkTestFunction` has no
# attribute `test_loader`.
test_loader=self.test_loader,
)
Expand Down Expand Up @@ -215,7 +215,7 @@ def get_pytorch_cnn_torchvision_benchmark_problem(
objective_name="accuracy",
)
runner = BenchmarkRunner(
test_problem=PyTorchCNNTorchvisionParamBasedProblem(name=name),
test_problem=PyTorchCNNTorchvisionBenchmarkTestFunction(name=name),
outcome_names=outcome_names,
)
return BenchmarkProblem(
Expand Down
8 changes: 4 additions & 4 deletions ax/benchmark/problems/synthetic/discretized/mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric

from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
Expand All @@ -47,7 +47,7 @@ def _get_problem_from_common_inputs(
Args:
bounds: The parameter bounds. These will be passed to
`BotorchTestProblemRunner` as `modified_bounds`, and the parameters
`BotorchTestFunction` as `modified_bounds`, and the parameters
will be renormalized from these bounds to the bounds of the original
problem. For example, if `bounds` are [(0, 3)] and the test
problem's original bounds are [(0, 2)], then the original problem
Expand Down Expand Up @@ -103,7 +103,7 @@ def _get_problem_from_common_inputs(
else:
test_problem = test_problem_class(dim=dim, bounds=test_problem_bounds)
runner = BenchmarkRunner(
test_problem=BoTorchTestProblem(
test_problem=BoTorchTestFunction(
botorch_problem=test_problem, modified_bounds=bounds
),
outcome_names=[metric_name],
Expand Down
6 changes: 3 additions & 3 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import torch
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
Expand Down Expand Up @@ -50,7 +50,7 @@ def jenatton_test_function(


@dataclass(kw_only=True)
class Jenatton(ParamBasedTestProblem):
class Jenatton(BenchmarkTestFunction):
"""Jenatton test function for hierarchical search spaces."""

# pyre-fixme[14]: Inconsistent override
Expand Down
6 changes: 3 additions & 3 deletions ax/benchmark/tests/problems/test_mixed_integer_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import torch
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction

from ax.benchmark.problems.synthetic.discretized.mixed_integer import (
get_discrete_ackley,
get_discrete_hartmann,
get_discrete_rosenbrock,
)
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.arm import Arm
from ax.core.parameter import ParameterType
from ax.core.trial import Trial
Expand All @@ -35,7 +35,7 @@ def test_problems(self) -> None:
problem = constructor()
self.assertEqual(f"Discrete {name}", problem.name)
runner = problem.runner
test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem)
test_problem = assert_is_instance(runner.test_problem, BoTorchTestFunction)
botorch_problem = test_problem.botorch_problem
self.assertIsInstance(botorch_problem, problem_cls)
self.assertEqual(len(problem.search_space.parameters), dim)
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_problems(self) -> None:

for problem, params, expected_arg in cases:
runner = problem.runner
test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem)
test_problem = assert_is_instance(runner.test_problem, BoTorchTestFunction)
trial = Trial(experiment=MagicMock())
# pyre-fixme: Incompatible parameter type [6]: In call
# `Arm.__init__`, for argument `parameters`, expected `Dict[str,
Expand Down
Loading

0 comments on commit 69e877f

Please sign in to comment.