diff --git a/numpy_adjoint/array.py b/numpy_adjoint/array.py index 7b824498..282101b3 100644 --- a/numpy_adjoint/array.py +++ b/numpy_adjoint/array.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kwargs): @classmethod def _ad_init_object(cls, obj): - return cls(obj.shape, numpy.float_, buffer=obj) + return cls(obj.shape, obj.dtype, buffer=obj) def _ad_create_checkpoint(self): return self.copy() diff --git a/tests/pyadjoint/test_numpy.py b/tests/pyadjoint/test_numpy.py new file mode 100644 index 00000000..dc582cdc --- /dev/null +++ b/tests/pyadjoint/test_numpy.py @@ -0,0 +1,10 @@ +import numpy as np +from pyadjoint import * +from numpy_adjoint import * + + +def test_ndarray_getitem_single(): + a = create_overloaded_object(np.array([-2.0])) + J = ReducedFunctional(a[0], Control(a)) + dJ = J.derivative() + assert dJ == 1.0