Skip to content

Commit

Permalink
Merge pull request #159 from jrmaddison/jrmaddison/numpy_2.0
Browse files Browse the repository at this point in the history
NumPy 2.0 fix
  • Loading branch information
dham committed Jul 8, 2024
2 parents 271343e + 4485fe6 commit 92121af
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion numpy_adjoint/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, *args, **kwargs):

@classmethod
def _ad_init_object(cls, obj):
return cls(obj.shape, numpy.float_, buffer=obj)
return cls(obj.shape, obj.dtype, buffer=obj)

def _ad_create_checkpoint(self):
return self.copy()
Expand Down
10 changes: 10 additions & 0 deletions tests/pyadjoint/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np
from pyadjoint import *
from numpy_adjoint import *


def test_ndarray_getitem_single():
a = create_overloaded_object(np.array([-2.0]))
J = ReducedFunctional(a[0], Control(a))
dJ = J.derivative()
assert dJ == 1.0

0 comments on commit 92121af

Please sign in to comment.