Skip to content

Commit

Permalink
only forward-mode support 3D output
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Jan 28, 2025
1 parent bb257e9 commit cd06627
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions deepxde/gradients/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import gradients_forward
from . import gradients_reverse
from .. import backend as bkd
from .. import config


Expand Down Expand Up @@ -33,10 +34,13 @@ def jacobian(ys, xs, i=None, j=None):
(`i`, `j`)th entry J[`i`, `j`], `i`th row J[`i`, :], or `j`th column J[:, `j`].
When `ys` has shape (batch_size, dim_y), the output shape is (batch_size, 1).
When `ys` has shape (batch_size_out, batch_size, dim_y), the output shape is
(batch_size_out, batch_size, 1) if forward-mode autodiff is used or
(batch_size, 1) if reverse-mode autodiff is used.
(batch_size_out, batch_size, 1).
"""
if config.autodiff == "reverse":
if bkd.ndim(ys) == 3:
raise NotImplementedError(
"Reverse-mode autodiff doesn't support 3D output"
)
return gradients_reverse.jacobian(ys, xs, i=i, j=j)
if config.autodiff == "forward":
return gradients_forward.jacobian(ys, xs, i=i, j=j)
Expand Down Expand Up @@ -65,10 +69,13 @@ def hessian(ys, xs, component=0, i=0, j=0):
Returns:
H[`i`, `j`]. When `ys` has shape (batch_size, dim_y), the output shape is
(batch_size, 1). When `ys` has shape (batch_size_out, batch_size, dim_y),
the output shape is (batch_size_out, batch_size, 1) if forward-mode
autodiff is used or (batch_size, 1) if reverse-mode autodiff is used.
the output shape is (batch_size_out, batch_size, 1).
"""
if config.autodiff == "reverse":
if bkd.ndim(ys) == 3:
raise NotImplementedError(
"Reverse-mode autodiff doesn't support 3D output"
)
return gradients_reverse.hessian(ys, xs, component=component, i=i, j=j)
if config.autodiff == "forward":
return gradients_forward.hessian(ys, xs, component=component, i=i, j=j)
Expand Down

0 comments on commit cd06627

Please sign in to comment.