Skip to content

Commit

Permalink
Supports more multitask types in GP-based designers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715844547
  • Loading branch information
vizier-team authored and copybara-github committed Jan 15, 2025
1 parent eaff660 commit 5ab515f
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 7 deletions.
6 changes: 6 additions & 0 deletions vizier/_src/algorithms/designers/gp/gp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vizier._src.algorithms.designers.gp import transfer_learning as vtl
from vizier._src.jax import stochastic_process_model as sp
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier._src.jax.models import tuned_gp_models
from vizier.jax import optimizers
from vizier.utils import profiler
def get_vizier_gp_coroutine(
    data: types.ModelData,
    *,
    linear_coef: Optional[float] = None,
    multitask_type: multitask_tuned_gp_models.MultiTaskType = (
        multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
    ),
) -> sp.ModelCoroutine:
  """Gets a GP model coroutine.

  Args:
    data: The data used to train the GP model.
    linear_coef: If non-zero, uses a linear kernel with `linear_coef`
      hyperparameter.
    multitask_type: The type of multitask kernel to use for multimetric GP.
      Only takes effect when `data` has more than one metric; defaults to
      modeling each metric with an independent GP.

  Returns:
    The model coroutine.
  """
  # Delegates to the canonical Vizier GP builder; the coroutine is the
  # stochastic-process generator consumed by sp.CoroutineWithData / train_gp.
  return tuned_gp_models.VizierGaussianProcess.build_model(
      data,
      linear_coef=linear_coef,
      multitask_type=multitask_type,
  ).coroutine


Expand Down
15 changes: 13 additions & 2 deletions vizier/_src/algorithms/designers/gp/gp_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
from vizier._src.algorithms.designers.gp import acquisitions
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier import converters

from absl.testing import absltest
from absl.testing import parameterized

mt_type = multitask_tuned_gp_models.MultiTaskType


def _setup_lambda_search(
f: Callable[[float], float],
Expand Down Expand Up @@ -363,7 +366,13 @@ def test_single_list_same_as_singleton(self):

self.assertAlmostEqual(list_gp_mse, singleton_gp_mse)

def test_multi_task(self):
@parameterized.parameters(
dict(multitask_type=mt_type.INDEPENDENT),
dict(multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR),
dict(multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR),
dict(multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR),
)
def test_multi_task(self, multitask_type: mt_type):
search_space = vz.SearchSpace()
search_space.root.add_float_param('x0', -5.0, 5.0)
problem = vz.ProblemStatement(
Expand Down Expand Up @@ -400,7 +409,9 @@ def test_multi_task(self):
train_spec = gp_models.GPTrainingSpec(
ard_optimizer=optimizers.default_optimizer(),
ard_rng=jax.random.PRNGKey(0),
coroutine=gp_models.get_vizier_gp_coroutine(data=model_data),
coroutine=gp_models.get_vizier_gp_coroutine(
data=model_data, multitask_type=multitask_type
),
)
gp = gp_models.train_gp(train_spec, model_data)

Expand Down
13 changes: 11 additions & 2 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier._src.jax import stochastic_process_model as sp
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier import converters
from vizier.pyvizier.converters import padding
Expand Down Expand Up @@ -146,6 +147,10 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
# Multi-objective parameters.
_num_scalarizations: int = attr.field(default=1000, kw_only=True)
_ref_scaling: float = attr.field(default=0.01, kw_only=True)
_multitask_type: multitask_tuned_gp_models.MultiTaskType = attr.field(
default=multitask_tuned_gp_models.MultiTaskType.INDEPENDENT,
kw_only=True,
)

# ------------------------------------------------------------------
# Internal attributes which should not be set by callers.
Expand Down Expand Up @@ -230,7 +235,9 @@ def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
self._use_trust_region = False

# Additional validations
coroutine = gp_models.get_vizier_gp_coroutine(empty_data)
coroutine = gp_models.get_vizier_gp_coroutine(
empty_data, multitask_type=self._multitask_type
)
params = sp.CoroutineWithData(coroutine, empty_data).setup(self._rng)
model = sp.StochasticProcessWithCoroutine(coroutine, params)
predictive = sp.UniformEnsemblePredictive(
Expand Down Expand Up @@ -389,7 +396,9 @@ def _create_gp_spec(
ard_optimizer=self._ard_optimizer,
ard_rng=ard_rng,
coroutine=gp_models.get_vizier_gp_coroutine(
data=data, linear_coef=self._linear_coef
data=data,
linear_coef=self._linear_coef,
multitask_type=self._multitask_type,
),
ensemble_size=self._ensemble_size,
ard_random_restarts=self._ard_random_restarts,
Expand Down
12 changes: 10 additions & 2 deletions vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from vizier._src.algorithms.testing import test_runners
from vizier._src.benchmarks.experimenters.synthetic import simplekd
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier import converters
from vizier.pyvizier.converters import padding
Expand All @@ -43,6 +44,7 @@
from absl.testing import absltest
from absl.testing import parameterized

mt_type = multitask_tuned_gp_models.MultiTaskType

ard_optimizer = optimizers.default_optimizer()
vectorized_optimizer_factory = vb.VectorizedOptimizerFactory(
Expand Down Expand Up @@ -471,7 +473,13 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction:
iters * n_parallel,
)

def test_multi_metrics(self):
@parameterized.parameters(
dict(multitask_type=mt_type.INDEPENDENT),
dict(multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR),
dict(multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR),
dict(multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR),
)
def test_multi_metrics(self, multitask_type: mt_type):
search_space = vz.SearchSpace()
search_space.root.add_float_param('x0', -5.0, 5.0)
problem = vz.ProblemStatement(
Expand All @@ -489,7 +497,7 @@ def test_multi_metrics(self):
)

iters = 2
designer = gp_bandit.VizierGPBandit(problem)
designer = gp_bandit.VizierGPBandit(problem, multitask_type=multitask_type)
self.assertLen(
test_runners.RandomMetricsRunner(
problem,
Expand Down
23 changes: 22 additions & 1 deletion vizier/_src/algorithms/designers/gp_ucb_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier._src.jax import stochastic_process_model as sp
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier._src.jax.models import tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier import converters
Expand Down Expand Up @@ -119,6 +120,11 @@ class UCBPEConfig(eqx.Module):
default=MultimetricPromisingRegionPenaltyType.AVERAGE, static=True
)

# The type of multitask kernel to use for multimetric problems.
multitask_type: multitask_tuned_gp_models.MultiTaskType = eqx.field(
default=multitask_tuned_gp_models.MultiTaskType.INDEPENDENT, static=True
)

def __repr__(self):
    """Returns a readable multi-line rendering of all config fields.

    Uses equinox's tree pretty-printer; `short_arrays=False` prints array
    contents in full rather than abbreviating them.
    """
    return eqx.tree_pformat(self, short_arrays=False)

Expand Down Expand Up @@ -740,7 +746,8 @@ def _build_gp_model_and_optimize_parameters(
# TODO: Creates a new abstract base class for GP models with a
# `build_model` API to avoid disabling the pytype attribute-error.
coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error
data
data,
multitask_type=self._config.multitask_type,
).coroutine
model = sp.CoroutineWithData(coroutine, data)

Expand Down Expand Up @@ -776,6 +783,20 @@ def _build_gp_model_and_optimize_parameters(
[[1.0] * data.features.categorical.padded_array.shape[-1]]
),
}
# Multitask GP models whose multitask type is not `INDEPENDENT` require
# extra parameters for the task kernel priors, which are randomly sampled
# and added to the fixed initialization parameters.
if (
data.labels.shape[-1] > 1
and self._config.multitask_type
!= multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
):
rng, extra_params_rng = jax.random.split(rng, 2)
extra_random_init_params = eqx.filter_jit(model.setup)(extra_params_rng)
for p_name, p_value in extra_random_init_params.items():
if p_name not in fixed_init_params:
fixed_init_params[p_name] = jnp.array([p_value])

best_n = self._ensemble_size or 1
optimal_params, metrics = self._ard_optimizer(
init_params=jax.tree.map(
Expand Down
29 changes: 29 additions & 0 deletions vizier/_src/algorithms/designers/gp_ucb_pe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.optimizers import eagle_strategy as es
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier.converters import padding
from vizier.testing import test_studies
Expand All @@ -37,6 +38,8 @@

ensemble_ard_optimizer = optimizers.default_optimizer()

mt_type = multitask_tuned_gp_models.MultiTaskType


def _extract_predictions(
metadata: Any,
Expand Down Expand Up @@ -102,6 +105,30 @@ class GpUcbPeTest(parameterized.TestCase):
gp_ucb_pe.MultimetricPromisingRegionPenaltyType.INTERSECTION
),
),
dict(
iters=3,
batch_size=5,
num_seed_trials=5,
num_metrics=2,
multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR,
applies_padding=True,
),
dict(
iters=3,
batch_size=5,
num_seed_trials=5,
num_metrics=2,
multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
applies_padding=True,
),
dict(
iters=3,
batch_size=5,
num_seed_trials=5,
num_metrics=2,
multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR,
applies_padding=True,
),
)
def test_on_flat_space(
self,
Expand All @@ -121,6 +148,7 @@ def test_on_flat_space(
multimetric_promising_region_penalty_type: (
gp_ucb_pe.MultimetricPromisingRegionPenaltyType
) = gp_ucb_pe.MultimetricPromisingRegionPenaltyType.AVERAGE,
multitask_type: mt_type = mt_type.INDEPENDENT,
):
# We use string names so that test case names are readable. Convert them
# to objects.
Expand Down Expand Up @@ -166,6 +194,7 @@ def test_on_flat_space(
multimetric_promising_region_penalty_type=(
multimetric_promising_region_penalty_type
),
multitask_type=multitask_type,
),
ensemble_size=ensemble_size,
padding_schedule=padding.PaddingSchedule(
Expand Down

0 comments on commit 5ab515f

Please sign in to comment.