Skip to content

Commit

Permalink
Add checks that only Range parameters can have ParameterConstraints i…
Browse files Browse the repository at this point in the history
…nstantiated (#2936)

Summary:
Pull Request resolved: #2936

As titled. This is a little janky but will do for now, a clearner validation scheme will be set up with via the Ax API workstream.

Reviewed By: saitcakmak

Differential Revision: D64784254

fbshipit-source-id: 3c421059f86298c9bed43d305b7532b6c4c0649a
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Oct 24, 2024
1 parent f37111c commit 322ea96
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_OnlyRangeParameterConstraints(self) -> None:

# Try (and fail) to create an experiment with constraints on choice
# paramaters
with self.assertRaises(UnsupportedError):
with self.assertRaises(ValueError):
ax_client.create_experiment(
name="experiment",
parameters=[
Expand All @@ -231,7 +231,7 @@ def test_OnlyRangeParameterConstraints(self) -> None:

# Try (and fail) to create an experiment with constraints on fixed
# parameters
with self.assertRaises(UnsupportedError):
with self.assertRaises(ValueError):
ax_client.create_experiment(
name="experiment",
parameters=[
Expand Down
135 changes: 128 additions & 7 deletions ax/service/tests/test_instantiation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,20 @@ def test_constraint_from_str(self) -> None:
"x1 + x2 <= not_numerical_bound",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
},
)
with self.assertRaisesRegex(ValueError, "Outcome constraint bound"):
InstantiationBase.outcome_constraint_from_str("m1 <= not_numerical_bound")
Expand Down Expand Up @@ -92,15 +105,29 @@ def test_constraint_from_str(self) -> None:
"x1 <= 0",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
"x2": RangeParameter(
name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
},
)
self.assertEqual(one_val_constraint.bound, 0.0)
self.assertEqual(one_val_constraint.constraint_dict, {"x1": 1.0})
one_val_constraint = InstantiationBase.constraint_from_str(
"-0.5*x1 >= -0.1",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
"x2": RangeParameter(
name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
},
)
self.assertEqual(one_val_constraint.bound, 0.1)
self.assertEqual(one_val_constraint.constraint_dict, {"x1": 0.5})
Expand Down Expand Up @@ -128,28 +155,122 @@ def test_constraint_from_str(self) -> None:
"x1 - e*x2 + x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2 *x2 + 3 *x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2* x2 + 3* x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2 * x2 + 3*x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)

with self.assertRaisesRegex(
ValueError, "Parameter constraints not supported for ChoiceParameter"
):
InstantiationBase.constraint_from_str(
"x1 + x2 <= 3",
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
"x2": ChoiceParameter(
name="x2", parameter_type=ParameterType.FLOAT, values=[0, 1, 2]
),
},
)

def test_add_tracking_metrics(self) -> None:
Expand Down
12 changes: 11 additions & 1 deletion ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
RangeParameter,
TParameterType,
)
from ax.core.parameter_constraint import OrderConstraint, ParameterConstraint
from ax.core.parameter_constraint import (
OrderConstraint,
ParameterConstraint,
validate_constraint_parameters,
)
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.core.types import ComparisonOp, TParameterization, TParamValue
from ax.exceptions.core import UnsupportedError
Expand Down Expand Up @@ -403,6 +407,9 @@ def constraint_from_str(
assert (
right in parameter_names
), f"Parameter {right} not in {parameter_names}."
validate_constraint_parameters(
parameters=[parameters[left], parameters[right]]
)
return (
OrderConstraint(
lower_parameter=parameters[left], upper_parameter=parameters[right]
Expand Down Expand Up @@ -451,9 +458,12 @@ def constraint_from_str(
multiplier = -1.0
else:
multiplier = 1.0

assert (
parameter in parameter_names
), f"Parameter {parameter} not in {parameter_names}."
validate_constraint_parameters(parameters=[parameters[parameter]])

parameter_weight[parameter] = operator_sign * multiplier
# for operators
else:
Expand Down

0 comments on commit 322ea96

Please sign in to comment.