diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index 42537a2164..f94fb0b349 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -68,12 +68,15 @@ class (e.g. an `ExactGP`) and this `GPyTorchModel`. See e.g. `SingleTaskGP`.
     def _validate_tensor_args(
         X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None, strict: bool = True
     ) -> None:
-        r"""Checks that `Y` and `Yvar` have an explicit output dimension if strict.
-        Checks that the dtypes of the inputs match, and warns if using float.
+        r"""Check the input tensors to verify that they are compatible with
+        BoTorch conventions. This checks that
 
-        This also checks that `Yvar` has the same trailing dimensions as `Y`. Note
-        we only infer that an explicit output dimension exists when `X` and `Y` have
-        the same `batch_shape`.
+        - `Y` and `Yvar` have an explicit output dimension if strict.
+        - The dtypes of the inputs match and warns if using float.
+        - `Yvar` has the same trailing dimensions as `Y`.
+            Note: We only infer that an explicit output dimension exists when
+            `X` and `Y` have the same `batch_shape`.
+        - The input tensors do not require gradients.
 
         Args:
             X: A `batch_shape x n x d`-dim Tensor, where `d` is the dimension of
@@ -131,6 +134,17 @@ def _validate_tensor_args(
                 InputDataWarning,
                 stacklevel=3,  # Warn at model constructor call.
             )
+        if (
+            X.requires_grad
+            or Y.requires_grad
+            or (Yvar is not None and Yvar.requires_grad)
+        ):
+            raise InputDataError(
+                "The BoTorch model inputs should not require gradients. This leads to "
+                f"errors during model fitting. Got {X.requires_grad=}, "
+                f"{Y.requires_grad=}"
+                + ("." if Yvar is None else f", and {Yvar.requires_grad=}.")
+            )
 
     @property
     def batch_shape(self) -> torch.Size:
diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py
index f586ba7df0..a25a4e2a21 100644
--- a/test/models/test_gpytorch.py
+++ b/test/models/test_gpytorch.py
@@ -255,6 +255,20 @@ def test_validate_tensor_args(self) -> None:
                     ):
                         GPyTorchModel._validate_tensor_args(X, Y, Yvar, strict=strict)
 
+    def test_validate_tensor_args_with_grad(self) -> None:
+        with self.assertRaisesRegex(
+            InputDataError, "inputs should not require gradients"
+        ):
+            GPyTorchModel._validate_tensor_args(
+                X=torch.randn(1, 1, requires_grad=True), Y=torch.randn(1, 1)
+            )
+        with self.assertRaisesRegex(InputDataError, "Yvar.requires_grad=False"):
+            GPyTorchModel._validate_tensor_args(
+                X=torch.randn(1, 1, requires_grad=True),
+                Y=torch.randn(1, 1),
+                Yvar=torch.randn(1, 1),
+            )
+
     def test_condition_on_observations_tensor_validation(self) -> None:
         model = SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1))
         model.posterior(torch.rand(2, 1))  # evaluate the model to form caches.