diff --git a/vizier/_src/jax/models/multitask_tuned_gp_models.py b/vizier/_src/jax/models/multitask_tuned_gp_models.py
index beb8c99cd..70110622a 100644
--- a/vizier/_src/jax/models/multitask_tuned_gp_models.py
+++ b/vizier/_src/jax/models/multitask_tuned_gp_models.py
@@ -59,6 +59,120 @@ class MultiTaskType(enum.Enum):
   SEPARABLE_DIAG_TASK_KERNEL_PRIOR = 'separable_diag_task_kernel_prior'
 
 
+def build_task_kernel_scale_linop(
+    num_tasks: int,
+    multitask_type: MultiTaskType,
+) -> Generator[sp.ModelParameter, jax.Array, tfp.tf2jax.linalg.LinearOperator]:
+  """Builds a Separable MultiTask GP's task kernel scale LinearOperator.
+
+  Args:
+    num_tasks: The number of tasks.
+    multitask_type: The type of MultiTask GP.
+
+  Yields:
+    Model parameters for the task kernel scale and a LinearOperator representing
+    the task kernel scale.
+  """
+  if multitask_type == MultiTaskType.SEPARABLE_DIAG_TASK_KERNEL_PRIOR:
+    correlation_diag = yield sp.ModelParameter.from_prior(
+        tfd.Sample(
+            tfd.Uniform(low=jnp.float64(1e-6), high=1.0),
+            sample_shape=num_tasks,
+            name='correlation_diag',
+        ),
+        constraint=sp.Constraint(
+            bounds=(1e-6, 1.0),
+            bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
+        ),
+    )
+    task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorDiag(
+        correlation_diag
+    )
+  elif multitask_type == MultiTaskType.SEPARABLE_LKJ_TASK_KERNEL_PRIOR:
+    # Generate parameters for the Cholesky of the task kernel matrix,
+    # which accounts for correlations between tasks.
+    num_task_kernel_entries = tfb.CorrelationCholesky().inverse_event_shape(
+        [num_tasks, num_tasks]
+    )
+    correlation_cholesky_vec = yield sp.ModelParameter(
+        init_fn=lambda key: tfd.Sample(  # pylint: disable=g-long-lambda
+            tfd.Normal(jnp.float64(0.0), 1.0), num_task_kernel_entries
+        ).sample(seed=key),
+        # Use `jnp.copy` to prevent tracers leaking from bijector cache.
+        regularizer=lambda x: -tfd.CholeskyLKJ(  # pylint: disable=g-long-lambda
+            dimension=num_tasks, concentration=1.0
+        ).log_prob(tfb.CorrelationCholesky()(jnp.copy(x))),
+        name='task_kernel_correlation_cholesky_vec',
+    )
+
+    task_kernel_correlation_cholesky = tfb.CorrelationCholesky()(
+        jnp.copy(correlation_cholesky_vec)
+    )
+
+    task_kernel_scale_vec = yield sp.ModelParameter(
+        init_fn=functools.partial(
+            jax.random.uniform,
+            shape=(num_tasks,),
+            dtype=jnp.float64,
+            minval=1e-6,
+            maxval=1.0,
+        ),
+        constraint=sp.Constraint(
+            bounds=(1e-6, 1.0),
+            bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
+        ),
+        name='task_kernel_sqrt_diagonal',
+    )
+    task_kernel_cholesky = (
+        task_kernel_correlation_cholesky * task_kernel_scale_vec[:, jnp.newaxis]
+    )
+
+    # Build the `LinearOperator` object representing the task kernel matrix,
+    # to parameterize the Separable kernel.
+    task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
+        task_kernel_cholesky
+    )
+  elif multitask_type == MultiTaskType.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR:
+    # Generate parameters for the Cholesky of the task kernel matrix;
+    # accounts for correlations between tasks. The task kernel matrix must
+    # be positive definite, so we construct it via a Cholesky factor.
+    # Define the prior of the kernel task matrix to be centered at the
+    # identity.
+    prior_mean = jnp.eye(num_tasks, dtype=jnp.float64)
+    prior_mean_vec = tfb.FillTriangular().inverse(prior_mean)
+    prior_mean_batched = jnp.broadcast_to(prior_mean_vec, prior_mean_vec.shape)
+
+    task_kernel_cholesky_entries = yield sp.ModelParameter.from_prior(
+        tfd.Independent(
+            tfd.Normal(prior_mean_batched, 1.0),
+            reinterpreted_batch_ndims=1,
+            name='task_kernel_cholesky_entries',
+        )
+    )
+
+    # Apply a bijector to pack the task kernel entries into a lower
+    # triangular matrix and ensure the diagonal is positive.
+    task_kernel_bijector = tfb.Chain([
+        tfb.TransformDiagonal(
+            tfb.Chain([tfb.Shift(jnp.float64(1e-3)), tfb.Softplus()])
+        ),
+        tfb.FillTriangular(),
+    ])
+    task_kernel_cholesky = task_kernel_bijector(
+        jnp.copy(task_kernel_cholesky_entries)
+    )
+
+    # Build the `LinearOperator` object representing the task kernel
+    # matrix, to parameterize the Separable kernel.
+    task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
+        task_kernel_cholesky
+    )
+  else:
+    raise ValueError(f'Unsupported multitask type: {multitask_type}')
+
+  return task_kernel_scale_linop
+
+
 @struct.dataclass
 class VizierMultitaskGaussianProcess(
     sp.ModelCoroutine[Union[tfd.GaussianProcess, tfde.MultiTaskGaussianProcess]]
@@ -101,115 +215,6 @@ def sample(key: Any) -> jnp.ndarray:
       return sample
 
-  def _build_task_kernel_scale_linop(
-      self,
-  ) -> Generator[
-      sp.ModelParameter, jax.Array, tfp.tf2jax.linalg.LinearOperator
-  ]:
-    if self._multitask_type == MultiTaskType.SEPARABLE_DIAG_TASK_KERNEL_PRIOR:
-      correlation_diag = yield sp.ModelParameter.from_prior(
-          tfd.Sample(
-              tfd.Uniform(low=jnp.float64(1e-6), high=1.0),
-              sample_shape=self._num_tasks,
-              name='correlation_diag',
-          ),
-          constraint=sp.Constraint(
-              bounds=(1e-6, 1.0),
-              bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
-          ),
-      )
-      task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorDiag(
-          correlation_diag
-      )
-    elif self._multitask_type == MultiTaskType.SEPARABLE_LKJ_TASK_KERNEL_PRIOR:
-      # Generate parameters for the Cholesky of the task kernel matrix,
-      # which accounts for correlations between tasks.
-      num_task_kernel_entries = tfb.CorrelationCholesky().inverse_event_shape(
-          [self._num_tasks, self._num_tasks]
-      )
-      correlation_cholesky_vec = yield sp.ModelParameter(
-          init_fn=lambda key: tfd.Sample(  # pylint: disable=g-long-lambda
-              tfd.Normal(jnp.float64(0.0), 1.0), num_task_kernel_entries
-          ).sample(seed=key),
-          # Use `jnp.copy` to prevent tracers leaking from bijector cache.
-          regularizer=lambda x: -tfd.CholeskyLKJ(  # pylint: disable=g-long-lambda
-              dimension=self._num_tasks, concentration=1.0
-          ).log_prob(tfb.CorrelationCholesky()(jnp.copy(x))),
-          name='task_kernel_correlation_cholesky_vec',
-      )
-
-      task_kernel_correlation_cholesky = tfb.CorrelationCholesky()(
-          jnp.copy(correlation_cholesky_vec)
-      )
-
-      task_kernel_scale_vec = yield sp.ModelParameter(
-          init_fn=functools.partial(
-              jax.random.uniform,
-              shape=(self._num_tasks,),
-              dtype=jnp.float64,
-              minval=1e-6,
-              maxval=1.0,
-          ),
-          constraint=sp.Constraint(
-              bounds=(1e-6, 1.0),
-              bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
-          ),
-          name='task_kernel_sqrt_diagonal',
-      )
-      task_kernel_cholesky = (
-          task_kernel_correlation_cholesky
-          * task_kernel_scale_vec[:, jnp.newaxis]
-      )
-
-      # Build the `LinearOperator` object representing the task kernel matrix,
-      # to parameterize the Separable kernel.
-      task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
-          task_kernel_cholesky
-      )
-    elif (
-        self._multitask_type == MultiTaskType.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR
-    ):
-      # Generate parameters for the Cholesky of the task kernel matrix;
-      # accounts for correlations between tasks. The task kernel matrix must
-      # be positive definite, so we construct it via a Cholesky factor.
-      # Define the prior of the kernel task matrix to be centered at the
-      # identity.
-      prior_mean = jnp.eye(self._num_tasks, dtype=jnp.float64)
-      prior_mean_vec = tfb.FillTriangular().inverse(prior_mean)
-      prior_mean_batched = jnp.broadcast_to(
-          prior_mean_vec, prior_mean_vec.shape
-      )
-
-      task_kernel_cholesky_entries = yield sp.ModelParameter.from_prior(
-          tfd.Independent(
-              tfd.Normal(prior_mean_batched, 1.0),
-              reinterpreted_batch_ndims=1,
-              name='task_kernel_cholesky_entries',
-          )
-      )
-
-      # Apply a bijector to pack the task kernel entries into a lower
-      # triangular matrix and ensure the diagonal is positive.
-      task_kernel_bijector = tfb.Chain([
-          tfb.TransformDiagonal(
-              tfb.Chain([tfb.Shift(jnp.float64(1e-6)), tfb.Softplus()])
-          ),
-          tfb.FillTriangular(),
-      ])
-      task_kernel_cholesky = task_kernel_bijector(
-          jnp.copy(task_kernel_cholesky_entries)
-      )
-
-      # Build the `LinearOperator` object representing the task kernel
-      # matrix, to parameterize the Separable kernel.
-      task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
-          task_kernel_cholesky
-      )
-    else:
-      raise ValueError(f'Unsupported multitask type: {self._multitask_type}')
-
-    return task_kernel_scale_linop
-
   def __call__(
       self, inputs: Optional[types.ModelInput] = None
   ) -> Generator[
@@ -356,7 +361,9 @@ def __call__(
       if self._multitask_type == MultiTaskType.INDEPENDENT:
         multitask_kernel = tfpke.Independent(self._num_tasks, kernel)
       else:
-        task_kernel_scale_linop = yield from self._build_task_kernel_scale_linop()
+        task_kernel_scale_linop = yield from build_task_kernel_scale_linop(
+            self._num_tasks, self._multitask_type
+        )
         multitask_kernel = tfpke.Separable(
             self._num_tasks,
             base_kernel=kernel,
diff --git a/vizier/_src/jax/models/tuned_gp_models.py b/vizier/_src/jax/models/tuned_gp_models.py
index 278319f8e..c831073b9 100644
--- a/vizier/_src/jax/models/tuned_gp_models.py
+++ b/vizier/_src/jax/models/tuned_gp_models.py
@@ -30,6 +30,7 @@
 from vizier._src.jax import types
 from vizier._src.jax.models import continuous_only_kernel
 from vizier._src.jax.models import mask_features
+from vizier._src.jax.models import multitask_tuned_gp_models
 
 tfb = tfp.bijectors
 tfd = tfp.distributions
@@ -93,6 +94,10 @@ class VizierGaussianProcess(sp.ModelCoroutine[tfd.GaussianProcess]):
   )
   _boundary_epsilon: float = struct.field(default=1e-12, kw_only=True)
   _linear_coef: Optional[float] = struct.field(default=None, kw_only=True)
+  _multitask_type: multitask_tuned_gp_models.MultiTaskType = struct.field(
+      default=multitask_tuned_gp_models.MultiTaskType.INDEPENDENT,
+      kw_only=True,
+  )
 
   def __attrs_post_init__(self):
     if self._num_metrics < 1:
@@ -107,6 +112,9 @@ def build_model(
       *,
       use_retrying_cholesky: bool = True,
       linear_coef: Optional[float] = None,
+      multitask_type: multitask_tuned_gp_models.MultiTaskType = (
+          multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
+      ),
   ) -> sp.StochasticProcessModel:
     """Returns a StochasticProcessModel for the GP."""
     gp_coroutine = VizierGaussianProcess(
@@ -117,6 +125,7 @@
         _num_metrics=data.labels.shape[-1],
         _use_retrying_cholesky=use_retrying_cholesky,
         _linear_coef=linear_coef,
+        _multitask_type=multitask_type,
     )
     return sp.StochasticProcessModel(gp_coroutine)
 
@@ -271,8 +280,24 @@ def __call__(
       cholesky_fn = lambda matrix: retrying_cholesky(matrix)[0]
 
     if self._num_metrics > 1:
+      if (
+          self._multitask_type
+          == multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
+      ):
+        multitask_kernel = tfpke.Independent(self._num_metrics, kernel)
+      else:
+        task_kernel_scale_linop = (
+            yield from multitask_tuned_gp_models.build_task_kernel_scale_linop(
+                self._num_metrics, self._multitask_type
+            )
+        )
+        multitask_kernel = tfpke.Separable(
+            self._num_metrics,
+            base_kernel=kernel,
+            task_kernel_scale_linop=task_kernel_scale_linop,
+        )
       return tfde.MultiTaskGaussianProcess(
-          tfpke.Independent(self._num_metrics, kernel),
+          multitask_kernel,
          index_points=inputs,
          observation_noise_variance=observation_noise_variance,
          cholesky_fn=cholesky_fn,
diff --git a/vizier/_src/jax/models/tuned_gp_models_test.py b/vizier/_src/jax/models/tuned_gp_models_test.py
index 3b8a3e3bc..bdcf57131 100644
--- a/vizier/_src/jax/models/tuned_gp_models_test.py
+++ b/vizier/_src/jax/models/tuned_gp_models_test.py
@@ -23,6 +23,7 @@
 from tensorflow_probability.substrates import jax as tfp
 from vizier._src.jax import stochastic_process_model as sp
 from vizier._src.jax import types
+from vizier._src.jax.models import multitask_tuned_gp_models
 from vizier._src.jax.models import tuned_gp_models
 from vizier.jax import optimizers
 
@@ -30,6 +31,7 @@
 from absl.testing import parameterized
 
 tfb = tfp.bijectors
+mt_type = multitask_tuned_gp_models.MultiTaskType
 
 
 class VizierGpTest(parameterized.TestCase):
@@ -150,8 +152,28 @@ def _generate_xys(self, num_metrics: int):
       # No observations are padded because multimetric GP does not support
       # observation padding.
       dict(num_metrics=2, num_obs=10),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR,
+      ),
   )
-  def test_masking_works(self, num_metrics: int, num_obs: int):
+  def test_masking_works(
+      self,
+      num_metrics: int,
+      num_obs: int,
+      multitask_type: mt_type = mt_type.INDEPENDENT,
+  ):
     x_obs, y_obs = self._generate_xys(num_metrics)
     data = types.ModelData(
         features=types.ModelInput(
@@ -171,7 +193,9 @@ def test_masking_works(self, num_metrics: int, num_obs: int):
     )
     model1 = sp.CoroutineWithData(
         tuned_gp_models.VizierGaussianProcess(
-            types.ContinuousAndCategorical[int](9, 2), num_metrics
+            types.ContinuousAndCategorical[int](9, 2),
+            num_metrics,
+            _multitask_type=multitask_type,
         ),
         data=data,
     )
@@ -185,7 +209,9 @@ def test_masking_works(self, num_metrics: int, num_obs: int):
     )
     model2 = sp.CoroutineWithData(
         tuned_gp_models.VizierGaussianProcess(
-            types.ContinuousAndCategorical[int](9, 2), num_metrics
+            types.ContinuousAndCategorical[int](9, 2),
+            num_metrics,
+            _multitask_type=multitask_type,
         ),
         data=modified_data,
     )
@@ -223,8 +249,28 @@ def test_masking_works(self, num_metrics: int, num_obs: int):
       # No observations are padded because multimetric GP does not support
       # observation padding.
       dict(num_metrics=2, num_obs=10),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=3,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR,
+      ),
   )
-  def test_good_log_likelihood(self, num_metrics: int, num_obs: int):
+  def test_good_log_likelihood(
+      self,
+      num_metrics: int,
+      num_obs: int,
+      multitask_type: mt_type = mt_type.INDEPENDENT,
+  ):
     # We use a fixed random seed for sampling categorical data (and continuous
     # data from `_generate_xys`, above) so that the same data is used for every
     # test run.
@@ -254,7 +300,9 @@ def test_good_log_likelihood(self, num_metrics: int, num_obs: int):
     target_loss = -0.2
     model = sp.CoroutineWithData(
         tuned_gp_models.VizierGaussianProcess(
-            types.ContinuousAndCategorical[int](9, 5), num_metrics
+            types.ContinuousAndCategorical[int](9, 5),
+            num_metrics,
+            _multitask_type=multitask_type,
         ),
         data=data,
     )
@@ -276,8 +324,28 @@ def test_good_log_likelihood(self, num_metrics: int, num_obs: int):
       # No observations are padded because multimetric GP does not support
       # observation padding.
       dict(num_metrics=2, num_obs=10),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
+      ),
+      dict(
+          num_metrics=2,
+          num_obs=10,
+          multitask_type=mt_type.SEPARABLE_DIAG_TASK_KERNEL_PRIOR,
+      ),
   )
-  def test_good_log_likelihood_linear(self, num_metrics: int, num_obs: int):
+  def test_good_log_likelihood_linear(
+      self,
+      num_metrics: int,
+      num_obs: int,
+      multitask_type: mt_type = mt_type.INDEPENDENT,
+  ):
     """Tests that the GP with linear coef after ARD has good log likelihood.
 
     The tests use a fixed random seed for sampling categorical data (and
@@ -287,6 +355,7 @@ def test_good_log_likelihood_linear(self, num_metrics: int, num_obs: int):
     Args:
       num_metrics: Number of metrics.
      num_obs: Number of observations.
+      multitask_type: The type of multitask GP to test.
     """
     rng, init_rng, cat_rng = jax.random.split(jax.random.PRNGKey(2), 3)
     x_cont_obs, y_obs = self._generate_xys(num_metrics)
@@ -317,6 +386,7 @@ def test_good_log_likelihood_linear(self, num_metrics: int, num_obs: int):
             types.ContinuousAndCategorical[int](9, 5),
             num_metrics,
             _linear_coef=1.0,
+            _multitask_type=multitask_type,
         ),
         data=data,
     )
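
Note: a minimal usage sketch of the new option, for reference only (not part of the patch). It mirrors the constructor call used in the tests above and the `sp.StochasticProcessModel` wrapping done by `build_model`; the feature counts, metric count, and the choice of the LKJ prior are arbitrary placeholder assumptions.

    # Build a Vizier GP coroutine whose two metrics share a Separable
    # multitask kernel with an LKJ task-kernel prior, then wrap it the same
    # way `VizierGaussianProcess.build_model` does.
    from vizier._src.jax import stochastic_process_model as sp
    from vizier._src.jax import types
    from vizier._src.jax.models import multitask_tuned_gp_models
    from vizier._src.jax.models import tuned_gp_models

    mt_type = multitask_tuned_gp_models.MultiTaskType
    gp_coroutine = tuned_gp_models.VizierGaussianProcess(
        types.ContinuousAndCategorical[int](9, 2),  # 9 continuous, 2 categorical features.
        2,  # Number of metrics (tasks).
        _multitask_type=mt_type.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
    )
    model = sp.StochasticProcessModel(gp_coroutine)

When going through `build_model` instead, the same choice is passed via the new `multitask_type=` keyword argument added in the diff above.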