Support MultiTypeExperiment in Instantiation (#2939)
Summary:

1. **InstantiationBase:** Add support for returning a MultiTypeExperiment from InstantiationBase._make_experiment.
2. **MultiTypeExperiment:** Add an add_tracking_metrics method to MultiTypeExperiment for adding metrics in batch, and accept a tracking_metrics argument when constructing a MultiTypeExperiment.
3. **AxClient:** Add support for creating a MultiTypeExperiment, together with the add_trial_type and add_tracking_metrics workflows (see the usage sketch below).
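
As a rough sketch of how these pieces fit together through AxClient, based on the diff below — the runner choice plus the trial-type and metric names ("online", "offline", etc.) are illustrative assumptions, not part of this commit:

```python
from ax.runners.synthetic import SyntheticRunner
from ax.service.ax_client import AxClient, ObjectiveProperties

ax_client = AxClient()

# Passing default_trial_type (together with default_runner) makes
# create_experiment build a MultiTypeExperiment instead of a single-type
# Experiment.
ax_client.create_experiment(
    name="multi_type_demo",
    parameters=[{"name": "x", "type": "range", "bounds": [0.0, 1.0]}],
    objectives={"objective": ObjectiveProperties(minimize=True)},
    default_trial_type="online",
    default_runner=SyntheticRunner(),
)

# Additional trial types are registered on the experiment object itself.
ax_client.experiment.add_trial_type(trial_type="offline", runner=SyntheticRunner())

# Batch-add tracking metrics; metrics not listed in metrics_to_trial_types
# fall back to the default trial type.
ax_client.add_tracking_metrics(
    metric_names=["online_latency", "offline_quality"],
    metrics_to_trial_types={"offline_quality": "offline"},
)
```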

Reviewed By: sdaulton

Differential Revision: D64612495
andycylmeta authored and facebook-github-bot committed Nov 3, 2024
1 parent d983045 commit 8f54e69
Showing 6 changed files with 273 additions and 9 deletions.
42 changes: 42 additions & 0 deletions ax/core/multi_type_experiment.py
@@ -49,6 +49,7 @@ def __init__(
default_trial_type: str,
default_runner: Runner,
optimization_config: OptimizationConfig | None = None,
tracking_metrics: list[Metric] | None = None,
status_quo: Arm | None = None,
description: str | None = None,
is_test: bool = False,
@@ -65,6 +66,7 @@ def __init__(
default_runner: Default runner for trials of the default type.
optimization_config: Optimization config of the experiment.
tracking_metrics: Additional tracking metrics not used for optimization.
These are associated with the default trial type.
runner: Default runner used for trials on this experiment.
status_quo: Arm representing existing "control" arm.
description: Description of the experiment.
@@ -101,6 +103,7 @@ def __init__(
experiment_type=experiment_type,
properties=properties,
default_data_type=default_data_type,
tracking_metrics=tracking_metrics,
)

def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
@@ -163,6 +166,45 @@ def add_tracking_metric(
self._metric_to_canonical_name[metric.name] = canonical_name
return self

def add_tracking_metrics(
self,
metrics: list[Metric],
metrics_to_trial_types: dict[str, str] | None = None,
canonical_names: dict[str, str] | None = None,
) -> Experiment:
"""Add a list of new metrics to the experiment.
If any of the metrics are already defined on the experiment,
we raise an error and don't add any of them to the experiment
Args:
metrics: Metrics to be added.
metrics_to_trial_types: The mapping from metric names to corresponding
trial types for each metric. If provided, the metrics will be
added to their trial types. If not provided, then the default
trial type will be used.
canonical_names: A mapping of metric names to their
canonical names(The default metrics for which the metrics are
proxies.)
Returns:
The experiment with the added metrics.
"""
metrics_to_trial_types = metrics_to_trial_types or {}
canonical_name = None
for metric in metrics:
if canonical_names is not None:
canonical_name = none_throws(canonical_names).get(metric.name, None)

self.add_tracking_metric(
metric=metric,
trial_type=metrics_to_trial_types.get(
metric.name, self._default_trial_type
),
canonical_name=canonical_name,
)
return self

# pyre-fixme[14]: `update_tracking_metric` overrides method defined in
# `Experiment` inconsistently.
def update_tracking_metric(
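
For reference, a minimal sketch of the new MultiTypeExperiment surface added above, assuming the usual name/search_space constructor arguments; the runner, metric, and trial-type names are illustrative:

```python
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.runners.synthetic import SyntheticRunner

experiment = MultiTypeExperiment(
    name="multi_type_demo",
    search_space=SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
            )
        ]
    ),
    default_trial_type="online",
    default_runner=SyntheticRunner(),
    # New: tracking metrics passed at construction are attached to the
    # default trial type.
    tracking_metrics=[Metric(name="online_latency")],
)
experiment.add_trial_type(trial_type="offline", runner=SyntheticRunner())

# New: batch-add metrics. Metrics missing from metrics_to_trial_types go to
# the default trial type; canonical_names maps a metric to the default-type
# metric it proxies.
experiment.add_tracking_metrics(
    metrics=[Metric(name="offline_latency"), Metric(name="online_errors")],
    metrics_to_trial_types={"offline_latency": "offline"},
    canonical_names={"offline_latency": "online_latency"},
)
```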
34 changes: 34 additions & 0 deletions ax/core/tests/test_multi_type_experiment.py
@@ -171,6 +171,40 @@ def test_runner_for_trial_type(self) -> None:
):
self.experiment.runner_for_trial_type(trial_type="invalid")

def test_add_tracking_metrics(self) -> None:
type1_metrics = [
BraninMetric("m3_type1", ["x1", "x2"]),
BraninMetric("m4_type1", ["x1", "x2"]),
]
type2_metrics = [
BraninMetric("m3_type2", ["x1", "x2"]),
BraninMetric("m4_type2", ["x1", "x2"]),
]
default_type_metrics = [
BraninMetric("m5_default_type", ["x1", "x2"]),
]
self.experiment.add_tracking_metrics(
metrics=type1_metrics + type2_metrics + default_type_metrics,
metrics_to_trial_types={
"m3_type1": "type1",
"m4_type1": "type1",
"m3_type2": "type2",
"m4_type2": "type2",
},
)
self.assertDictEqual(
self.experiment._metric_to_trial_type,
{
"m1": "type1",
"m2": "type2",
"m3_type1": "type1",
"m4_type1": "type1",
"m3_type2": "type2",
"m4_type2": "type2",
"m5_default_type": "type1",
},
)


class MultiTypeExperimentUtilsTest(TestCase):
def setUp(self) -> None:
49 changes: 41 additions & 8 deletions ax/service/ax_client.py
@@ -27,19 +27,22 @@
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.types import (
TEvaluationOutcome,
TModelPredictArm,
TParameterization,
TParamValue,
)

from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.early_stopping.utils import estimate_early_stopping_savings
@@ -90,6 +93,7 @@
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import assert_is_instance, none_throws


logger: Logger = get_logger(__name__)


@@ -251,6 +255,8 @@ def create_experiment(
immutable_search_space_and_opt_config: bool = True,
is_test: bool = False,
metric_definitions: dict[str, dict[str, Any]] | None = None,
default_trial_type: str | None = None,
default_runner: Runner | None = None,
) -> None:
"""Create a new experiment and save it if DBSettings available.
@@ -316,6 +322,15 @@
to that metric. Note these are modified in-place. Each
Metric must have its own dictionary (metrics cannot share a
single dictionary object).
default_trial_type: The default trial type if multiple trial types
are intended to be used in the experiment. If specified, a
MultiTypeExperiment will be created; otherwise a single-type
Experiment will be created.
default_runner: The default runner for trials of this experiment.
This only applies to a MultiTypeExperiment and must be specified
together with default_trial_type; it is ignored for a single-type
Experiment (when default_trial_type is not specified).
"""
self._validate_early_stopping_strategy(support_intermediate_data)

@@ -344,6 +359,8 @@
support_intermediate_data=support_intermediate_data,
immutable_search_space_and_opt_config=immutable_search_space_and_opt_config,
is_test=is_test,
default_trial_type=default_trial_type,
default_runner=default_runner,
**objective_kwargs,
)
self._set_runner(experiment=experiment)
@@ -416,6 +433,8 @@ def add_tracking_metrics(
self,
metric_names: list[str],
metric_definitions: dict[str, dict[str, Any]] | None = None,
metrics_to_trial_types: dict[str, str] | None = None,
canonical_names: dict[str, str] | None = None,
) -> None:
"""Add a list of new metrics to the experiment.
@@ -428,20 +447,34 @@
to that metric. Note these are modified in-place. Each
Metric must have its own dictionary (metrics cannot share a
single dictionary object).
metrics_to_trial_types: Only applicable to MultiTypeExperiment.
The mapping from metric names to their corresponding trial types.
If provided, each metric is added to its mapped trial type;
otherwise the default trial type is used.
canonical_names: A mapping from a metric name (of a particular trial
type) to the name of the default-trial-type metric it proxies.
Only applicable to MultiTypeExperiment.
"""
metric_definitions = (
self.metric_definitions
if metric_definitions is None
else metric_definitions
)
self.experiment.add_tracking_metrics(
metrics=[
self._make_metric(
name=metric_name, metric_definitions=metric_definitions
)
for metric_name in metric_names
]
)
metric_objects = [
self._make_metric(name=metric_name, metric_definitions=metric_definitions)
for metric_name in metric_names
]

if isinstance(self.experiment, MultiTypeExperiment):
experiment = assert_is_instance(self.experiment, MultiTypeExperiment)
experiment.add_tracking_metrics(
metrics=metric_objects,
metrics_to_trial_types=metrics_to_trial_types,
canonical_names=canonical_names,
)
else:
self.experiment.add_tracking_metrics(metrics=metric_objects)

@copy_doc(Experiment.remove_tracking_metric)
def remove_tracking_metric(self, metric_name: str) -> None:
106 changes: 105 additions & 1 deletion ax/service/tests/test_ax_client.py
@@ -21,6 +21,7 @@
from ax.core.arm import Arm
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint
from ax.core.parameter import (
@@ -57,6 +58,7 @@
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Models
from ax.runners.synthetic import SyntheticRunner

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.best_point import (
@@ -83,7 +85,7 @@
from ax.utils.testing.mock import mock_botorch_optimize
from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans
from botorch.test_functions.multi_objective import BraninCurrin
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
from ax.core.types import TTrialEvaluation
@@ -821,6 +823,7 @@ def test_create_experiment(self) -> None:
is_test=True,
)
assert ax_client._experiment is not None
self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment")
self.assertEqual(ax_client._experiment, ax_client.experiment)
self.assertEqual(
# pyre-fixme[16]: `Optional` has no attribute `search_space`.
@@ -903,6 +906,107 @@ def test_create_experiment(self) -> None:
{"test_objective", "some_metric", "test_tracking_metric"},
)

def test_create_multitype_experiment(self) -> None:
"""
Test creating a multi-type experiment, adding a trial type, and adding
metrics to different trial types.
"""
ax_client = AxClient(
GenerationStrategy(
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
)
)
ax_client.create_experiment(
name="test_experiment",
parameters=[
{
"name": "x",
"type": "range",
"bounds": [0.001, 0.1],
"value_type": "float",
"log_scale": True,
"digits": 6,
},
{
"name": "y",
"type": "choice",
"values": [1, 2, 3],
"value_type": "int",
"is_ordered": True,
},
{"name": "x3", "type": "fixed", "value": 2, "value_type": "int"},
{
"name": "x4",
"type": "range",
"bounds": [1.0, 3.0],
"value_type": "int",
},
{
"name": "x5",
"type": "choice",
"values": ["one", "two", "three"],
"value_type": "str",
},
{
"name": "x6",
"type": "range",
"bounds": [1.0, 3.0],
"value_type": "int",
},
],
objectives={"test_objective": ObjectiveProperties(minimize=True)},
outcome_constraints=["some_metric >= 3", "some_metric <= 4.0"],
parameter_constraints=["x4 <= x6"],
tracking_metric_names=["test_tracking_metric"],
is_test=True,
default_trial_type="test_trial_type",
default_runner=SyntheticRunner(),
)

self.assertEqual(ax_client.experiment.__class__.__name__, "MultiTypeExperiment")
experiment = assert_is_instance(ax_client.experiment, MultiTypeExperiment)
self.assertEqual(
experiment._trial_type_to_runner["test_trial_type"].__class__.__name__,
"SyntheticRunner",
)
self.assertEqual(
experiment._metric_to_trial_type,
{
"test_tracking_metric": "test_trial_type",
"test_objective": "test_trial_type",
"some_metric": "test_trial_type",
},
)
experiment.add_trial_type(
trial_type="test_trial_type_2",
runner=SyntheticRunner(),
)
ax_client.add_tracking_metrics(
metric_names=[
"some_metric2_type1",
"some_metric3_type1",
"some_metric4_type2",
"some_metric5_type2",
],
metrics_to_trial_types={
"some_metric2_type1": "test_trial_type",
"some_metric4_type2": "test_trial_type_2",
"some_metric5_type2": "test_trial_type_2",
},
)
self.assertEqual(
experiment._metric_to_trial_type,
{
"test_tracking_metric": "test_trial_type",
"test_objective": "test_trial_type",
"some_metric": "test_trial_type",
"some_metric2_type1": "test_trial_type",
"some_metric3_type1": "test_trial_type",
"some_metric4_type2": "test_trial_type_2",
"some_metric5_type2": "test_trial_type_2",
},
)

def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
ax_client = AxClient(
GenerationStrategy(
19 changes: 19 additions & 0 deletions ax/service/tests/test_instantiation_utils.py
@@ -17,6 +17,7 @@
RangeParameter,
)
from ax.core.search_space import HierarchicalSearchSpace
from ax.runners.synthetic import SyntheticRunner
from ax.service.utils.instantiation import InstantiationBase
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
@@ -431,3 +432,21 @@ def test_hss(self) -> None:
self.assertIsInstance(search_space, HierarchicalSearchSpace)
# pyre-fixme[16]: `SearchSpace` has no attribute `_root`.
self.assertEqual(search_space._root.name, "root")

def test_make_multitype_experiment_with_default_trial_type(self) -> None:
experiment = InstantiationBase.make_experiment(
name="test_make_experiment",
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
tracking_metric_names=None,
default_trial_type="test_trial_type",
default_runner=SyntheticRunner(),
)
self.assertEqual(experiment.__class__.__name__, "MultiTypeExperiment")

def test_make_single_type_experiment_with_no_default_trial_type(self) -> None:
experiment = InstantiationBase.make_experiment(
name="test_make_experiment",
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
tracking_metric_names=None,
)
self.assertEqual(experiment.__class__.__name__, "Experiment")