Skip to content

Commit

Permalink
HSS: only check for dependents if the root parameter is present when …
Browse files Browse the repository at this point in the history
…casting parameterization (#2361)

Summary:
Pull Request resolved: #2361

When `check_all_parameters_present=True`, we error out if any parameter is missing. When it is False, we need to tolerate any missing parameters and ignore that branch of the HSS tree -- not doing so would be ignoring the argument. The previous implementation would error out while trying to look up the value of the missing parameter if it had dependents.

Reviewed By: Balandat

Differential Revision: D56084081

fbshipit-source-id: a2cb39219fd0c14c06caa1835b9e0422713947a1
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 13, 2024
1 parent c09963b commit 97d2a3a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
17 changes: 11 additions & 6 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,11 @@ def _cast_parameterization(
Args:
parameters: Parameterization to cast to hierarchical structure.
check_all_parameters_present: Whether to raise an error if a paramete
that is expected to be present (according to values of other
parameters and the hierarchical structure of the search space)
is not specified.
check_all_parameters_present: Whether to raise an error if a parameter
that is expected to be present (according to values of other
parameters and the hierarchical structure of the search space)
is not specified. When this is False, if a parameter is missing,
its dependents will not be included in the returned parameterization.
"""
error_msg_prefix: str = (
f"Parameterization {parameters} violates the hierarchical structure "
Expand All @@ -682,11 +683,15 @@ def _find_applicable_parameters(root: Parameter) -> Set[str]:
+ f"Parameter '{root.name}' not in parameterization to cast."
)

if not root.is_hierarchical:
# Return if the root parameter is not hierarchical or if it is not
# in the parameterization to cast.
if not root.is_hierarchical or root.name not in parameters:
return applicable

# Find the dependents of the current root parameter.
root_val = parameters[root.name]
for val, deps in root.dependents.items():
if parameters[root.name] == val:
if root_val == val:
for dep in deps:
applicable.update(_find_applicable_parameters(root=self[dep]))

Expand Down
45 changes: 38 additions & 7 deletions ax/core/tests/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
SearchSpace,
SearchSpaceDigest,
)
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -621,13 +622,12 @@ def setUp(self) -> None:
"num_boost_rounds": 12,
}
)
self.hss_1_arm_missing_param = Arm(
parameters={
"model": "Linear",
"l2_reg_weight": 0.0001,
"num_boost_rounds": 12,
}
)
self.hss_1_missing_params: TParameterization = {
"model": "Linear",
"l2_reg_weight": 0.0001,
"num_boost_rounds": 12,
}
self.hss_1_arm_missing_param = Arm(parameters=self.hss_1_missing_params)
self.hss_1_arm_1_cast = Arm(
parameters={
"model": "Linear",
Expand Down Expand Up @@ -759,6 +759,7 @@ def test_flatten(self) -> None:
self.assertTrue(str(flattened_hss_with_constraints).startswith("SearchSpace"))

def test_cast_arm(self) -> None:
# This uses _cast_parameterization with check_all_parameters_present=True.
self.assertEqual( # Check one subtree.
self.hss_1._cast_arm(arm=self.hss_1_arm_1_flat),
self.hss_1_arm_1_cast,
Expand All @@ -775,6 +776,7 @@ def test_cast_arm(self) -> None:
self.hss_1._cast_arm(arm=self.hss_1_arm_missing_param)

def test_cast_observation_features(self) -> None:
# This uses _cast_parameterization with check_all_parameters_present=False.
# Ensure that during casting, full parameterization is saved
# in metadata and actual parameterization is cast to HSS.
hss_1_obs_feats_1 = ObservationFeatures.from_arm(arm=self.hss_1_arm_1_flat)
Expand All @@ -798,6 +800,35 @@ def test_cast_observation_features(self) -> None:
ObservationFeatures.from_arm(arm=self.hss_1_arm_1_cast),
)

def test_cast_parameterization(self) -> None:
# NOTE: This is also tested in test_cast_arm & test_cast_observation_features.
with self.assertRaisesRegex(RuntimeError, "not in parameterization to cast"):
self.hss_1._cast_parameterization(
parameters=self.hss_1_missing_params,
check_all_parameters_present=True,
)
# An active leaf param is missing, it'll get ignored. There's an inactive
# leaf param, that'll just get filtered out.
self.assertEqual(
self.hss_1._cast_parameterization(
parameters=self.hss_1_missing_params,
check_all_parameters_present=False,
),
{"l2_reg_weight": 0.0001, "model": "Linear"},
)
# A hierarchical param is missing, all its dependents will be ignored.
# In this case, it is the root param, so we'll have empty parameterization.
self.assertEqual(
self.hss_1._cast_parameterization(
parameters={
"l2_reg_weight": 0.0001,
"num_boost_rounds": 12,
},
check_all_parameters_present=False,
),
{},
)

def test_flatten_observation_features(self) -> None:
# Ensure that during casting, full parameterization is saved
# in metadata and actual parameterization is cast to HSS; during
Expand Down

0 comments on commit 97d2a3a

Please sign in to comment.