diff --git a/deepxde/gradients/gradients.py b/deepxde/gradients/gradients.py index 31f674d98..11cef257d 100644 --- a/deepxde/gradients/gradients.py +++ b/deepxde/gradients/gradients.py @@ -4,6 +4,7 @@ from . import gradients_forward from . import gradients_reverse +from .. import backend as bkd from .. import config @@ -33,10 +34,13 @@ def jacobian(ys, xs, i=None, j=None): (`i`, `j`)th entry J[`i`, `j`], `i`th row J[`i`, :], or `j`th column J[:, `j`]. When `ys` has shape (batch_size, dim_y), the output shape is (batch_size, 1). When `ys` has shape (batch_size_out, batch_size, dim_y), the output shape is - (batch_size_out, batch_size, 1) if forward-mode autodiff is used or - (batch_size, 1) if reverse-mode autodiff is used. + (batch_size_out, batch_size, 1). """ if config.autodiff == "reverse": + if bkd.ndim(ys) == 3: + raise NotImplementedError( + "Reverse-mode autodiff doesn't support 3D output" + ) return gradients_reverse.jacobian(ys, xs, i=i, j=j) if config.autodiff == "forward": return gradients_forward.jacobian(ys, xs, i=i, j=j) @@ -65,10 +69,13 @@ def hessian(ys, xs, component=0, i=0, j=0): Returns: H[`i`, `j`]. When `ys` has shape (batch_size, dim_y), the output shape is (batch_size, 1). When `ys` has shape (batch_size_out, batch_size, dim_y), - the output shape is (batch_size_out, batch_size, 1) if forward-mode - autodiff is used or (batch_size, 1) if reverse-mode autodiff is used. + the output shape is (batch_size_out, batch_size, 1). """ if config.autodiff == "reverse": + if bkd.ndim(ys) == 3: + raise NotImplementedError( + "Reverse-mode autodiff doesn't support 3D output" + ) return gradients_reverse.hessian(ys, xs, component=component, i=i, j=j) if config.autodiff == "forward": return gradients_forward.hessian(ys, xs, component=component, i=i, j=j)