diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py index be1bc465c6..fc7ec6a5de 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py @@ -613,48 +613,46 @@ def precompute_regression_model( if _precomputed_divisor_matrix_cholesky is not None: observation_scale = _scale_from_precomputed( _precomputed_divisor_matrix_cholesky, kernel) - elif observations_is_missing is not None: - # If observations are missing, there's nothing we can do to preserve the - # operator structure, so densify. - - observation_covariance = kernel.matrix_over_all_tasks( - observation_index_points, observation_index_points).to_dense() - - if observation_noise_variance is not None: - broadcast_shape = distribution_util.get_broadcast_shape( - observation_covariance, observation_noise_variance[ - ..., tf.newaxis, tf.newaxis]) - observation_covariance = tf.broadcast_to(observation_covariance, - broadcast_shape) - observation_covariance = _add_diagonal_shift( - observation_covariance, observation_noise_variance) - vec_observations_is_missing = _vec(observations_is_missing) - observation_covariance = tf.linalg.LinearOperatorFullMatrix( - psd_kernels_util.mask_matrix( - observation_covariance, - is_missing=vec_observations_is_missing), - is_non_singular=True, - is_positive_definite=True) - observation_scale = cholesky_util.cholesky_from_fn( - observation_covariance, cholesky_fn) + solve_on_observations = _precomputed_solve_on_observation else: - observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access - kernel=kernel, - index_points=observation_index_points, - cholesky_fn=cholesky_fn, - observation_noise_variance=observation_noise_variance) - - # Note that the conditional mean is - # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter - # term since it won't change per iteration. - vec_diff = _vec(observations - mean_fn(observation_index_points)) - - if observations_is_missing is not None: - vec_diff = tf.where(vec_observations_is_missing, - tf.zeros([], dtype=vec_diff.dtype), - vec_diff) - solve_on_observations = _precomputed_solve_on_observation - if solve_on_observations is None: + # Note that the conditional mean is + # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter + # term since it won't change per iteration. + vec_diff = _vec(observations - mean_fn(observation_index_points)) + + if observations_is_missing is not None: + # If observations are missing, there's nothing we can do to preserve + # the operator structure, so densify. + vec_observations_is_missing = _vec(observations_is_missing) + vec_diff = tf.where(vec_observations_is_missing, + tf.zeros([], dtype=vec_diff.dtype), + vec_diff) + + observation_covariance = kernel.matrix_over_all_tasks( + observation_index_points, observation_index_points).to_dense() + + if observation_noise_variance is not None: + broadcast_shape = distribution_util.get_broadcast_shape( + observation_covariance, observation_noise_variance[ + ..., tf.newaxis, tf.newaxis]) + observation_covariance = tf.broadcast_to(observation_covariance, + broadcast_shape) + observation_covariance = _add_diagonal_shift( + observation_covariance, observation_noise_variance) + observation_covariance = tf.linalg.LinearOperatorFullMatrix( + psd_kernels_util.mask_matrix( + observation_covariance, + is_missing=vec_observations_is_missing), + is_non_singular=True, + is_positive_definite=True) + observation_scale = cholesky_util.cholesky_from_fn( + observation_covariance, cholesky_fn) + else: + observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access + kernel=kernel, + index_points=observation_index_points, + cholesky_fn=cholesky_fn, + observation_noise_variance=observation_noise_variance) solve_on_observations = observation_scale.solvevec( observation_scale.solvevec(vec_diff), adjoint=True) @@ -678,6 +676,7 @@ def flattened_conditional_mean_fn(x): observation_noise_variance=observation_noise_variance, predictive_noise_variance=predictive_noise_variance, cholesky_fn=cholesky_fn, + observations_is_missing=observations_is_missing, _flattened_conditional_mean_fn=flattened_conditional_mean_fn, _observation_scale=observation_scale, validate_args=validate_args, diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py index 66258acc99..2680bf6038 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py @@ -474,16 +474,26 @@ def testMeanVarianceJit(self): tf.function(jit_compile=True)(mtgprm.mean)() tf.function(jit_compile=True)(mtgprm.variance)() - def testMeanVarianceAndCovariancePrecomputed(self): + @parameterized.parameters(True, False) + def testMeanVarianceAndCovariancePrecomputed(self, has_missing_observations): num_tasks = 3 + num_obs = 7 amplitude = np.array([1., 2.], np.float64).reshape([2, 1]) length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3]) observation_noise_variance = np.array([1e-9], np.float64) observation_index_points = ( - np.random.uniform(-1., 1., (1, 1, 7, 2)).astype(np.float64)) + np.random.uniform(-1., 1., (1, 1, num_obs, 2)).astype(np.float64)) observations = np.linspace( - -20., 20., 7 * num_tasks).reshape(7, num_tasks).astype(np.float64) + -20., 20., num_obs * num_tasks).reshape( + num_obs, num_tasks).astype(np.float64) + + if has_missing_observations: + observations_is_missing = np.stack( + [np.random.randint(2, size=(num_obs,))] * num_tasks, axis=-1 + ).astype(np.bool_) + else: + observations_is_missing = None index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64) @@ -497,6 +507,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + observations_is_missing=observations_is_missing, validate_args=True) precomputed_mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model( @@ -505,6 +516,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + observations_is_missing=observations_is_missing, validate_args=True) mock_cholesky_fn = mock.Mock(return_value=None) @@ -514,6 +526,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + observations_is_missing=observations_is_missing, _precomputed_divisor_matrix_cholesky=precomputed_mtgprm._precomputed_divisor_matrix_cholesky, _precomputed_solve_on_observation=precomputed_mtgprm._precomputed_solve_on_observation, cholesky_fn=mock_cholesky_fn,