From 322ea9627007e0702d6f87fe57b625ee43153e14 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Thu, 24 Oct 2024 08:55:28 -0700 Subject: [PATCH] Add checks that only Range parameters can have ParameterConstraints instantiated (#2936) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2936 As titled. This is a little janky but will do for now, a clearner validation scheme will be set up with via the Ax API workstream. Reviewed By: saitcakmak Differential Revision: D64784254 fbshipit-source-id: 3c421059f86298c9bed43d305b7532b6c4c0649a --- ax/core/tests/test_experiment.py | 4 +- ax/service/tests/test_instantiation_utils.py | 135 ++++++++++++++++++- ax/service/utils/instantiation.py | 12 +- 3 files changed, 141 insertions(+), 10 deletions(-) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 25b97c0b1bf..414667bebfa 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -210,7 +210,7 @@ def test_OnlyRangeParameterConstraints(self) -> None: # Try (and fail) to create an experiment with constraints on choice # paramaters - with self.assertRaises(UnsupportedError): + with self.assertRaises(ValueError): ax_client.create_experiment( name="experiment", parameters=[ @@ -231,7 +231,7 @@ def test_OnlyRangeParameterConstraints(self) -> None: # Try (and fail) to create an experiment with constraints on fixed # parameters - with self.assertRaises(UnsupportedError): + with self.assertRaises(ValueError): ax_client.create_experiment( name="experiment", parameters=[ diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index bfa75db3f79..4cbf347897a 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -51,7 +51,20 @@ def test_constraint_from_str(self) -> None: "x1 + x2 <= not_numerical_bound", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None}, + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=2.0, + ), + "x2": RangeParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=2.0, + ), + }, ) with self.assertRaisesRegex(ValueError, "Outcome constraint bound"): InstantiationBase.outcome_constraint_from_str("m1 <= not_numerical_bound") @@ -92,7 +105,14 @@ def test_constraint_from_str(self) -> None: "x1 <= 0", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None}, + { + "x1": RangeParameter( + name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0 + ), + "x2": RangeParameter( + name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0 + ), + }, ) self.assertEqual(one_val_constraint.bound, 0.0) self.assertEqual(one_val_constraint.constraint_dict, {"x1": 1.0}) @@ -100,7 +120,14 @@ def test_constraint_from_str(self) -> None: "-0.5*x1 >= -0.1", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None}, + { + "x1": RangeParameter( + name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0 + ), + "x2": RangeParameter( + name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0 + ), + }, ) self.assertEqual(one_val_constraint.bound, 0.1) self.assertEqual(one_val_constraint.constraint_dict, {"x1": 0.5}) @@ -128,28 +155,122 @@ def test_constraint_from_str(self) -> None: "x1 - e*x2 + x3 <= 3", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None, "x3": None}, + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x2": RangeParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x3": RangeParameter( + name="x3", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + }, ) with self.assertRaisesRegex(ValueError, "A linear constraint should be"): InstantiationBase.constraint_from_str( "x1 - 2 *x2 + 3 *x3 <= 3", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None, "x3": None}, + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x2": RangeParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x3": RangeParameter( + name="x3", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + }, ) with self.assertRaisesRegex(ValueError, "A linear constraint should be"): InstantiationBase.constraint_from_str( "x1 - 2* x2 + 3* x3 <= 3", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None, "x3": None}, + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x2": RangeParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x3": RangeParameter( + name="x3", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + }, ) with self.assertRaisesRegex(ValueError, "A linear constraint should be"): InstantiationBase.constraint_from_str( "x1 - 2 * x2 + 3*x3 <= 3", # pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but # got `Dict[str, None]`. - {"x1": None, "x2": None, "x3": None}, + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x2": RangeParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + "x3": RangeParameter( + name="x3", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=4.0, + ), + }, + ) + + with self.assertRaisesRegex( + ValueError, "Parameter constraints not supported for ChoiceParameter" + ): + InstantiationBase.constraint_from_str( + "x1 + x2 <= 3", + { + "x1": RangeParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + lower=0.1, + upper=2.0, + ), + "x2": ChoiceParameter( + name="x2", parameter_type=ParameterType.FLOAT, values=[0, 1, 2] + ), + }, ) def test_add_tracking_metrics(self) -> None: diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index b51ff0eae7a..5f96d8cabf0 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -35,7 +35,11 @@ RangeParameter, TParameterType, ) -from ax.core.parameter_constraint import OrderConstraint, ParameterConstraint +from ax.core.parameter_constraint import ( + OrderConstraint, + ParameterConstraint, + validate_constraint_parameters, +) from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.core.types import ComparisonOp, TParameterization, TParamValue from ax.exceptions.core import UnsupportedError @@ -403,6 +407,9 @@ def constraint_from_str( assert ( right in parameter_names ), f"Parameter {right} not in {parameter_names}." + validate_constraint_parameters( + parameters=[parameters[left], parameters[right]] + ) return ( OrderConstraint( lower_parameter=parameters[left], upper_parameter=parameters[right] @@ -451,9 +458,12 @@ def constraint_from_str( multiplier = -1.0 else: multiplier = 1.0 + assert ( parameter in parameter_names ), f"Parameter {parameter} not in {parameter_names}." + validate_constraint_parameters(parameters=[parameters[parameter]]) + parameter_weight[parameter] = operator_sign * multiplier # for operators else: