Skip to content

Commit

Permalink
Fix numpy v2 breaking changes (#618)
Browse files Browse the repository at this point in the history
* [fix] Fix breaking changes with numpuy>=v2.0.0rc1

* [fix] Update version check to v2.0.0.

* [fix] Disable numerical jvp checks for complex sign function.

* Lift numpy<2 dependency.

---------

Co-authored-by: Agriya Khetarpal <[email protected]>
  • Loading branch information
fjosw and agriyakhetarpal authored Aug 14, 2024
1 parent 7832b4b commit 84c42fb
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 9 deletions.
8 changes: 7 additions & 1 deletion autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def __hash__(self): return id(self)
# Flatten has no function, only a method.
setattr(ArrayBox, 'flatten', anp.__dict__['ravel'])

if np.__version__ >= '1.25':
if np.lib.NumpyVersion(np.__version__) >= '2.0.0':
SequenceBox.register(np.linalg._linalg.EigResult)
SequenceBox.register(np.linalg._linalg.EighResult)
SequenceBox.register(np.linalg._linalg.QRResult)
SequenceBox.register(np.linalg._linalg.SlogdetResult)
SequenceBox.register(np.linalg._linalg.SVDResult)
elif np.__version__ >= '1.25':
SequenceBox.register(np.linalg.linalg.EigResult)
SequenceBox.register(np.linalg.linalg.EighResult)
SequenceBox.register(np.linalg.linalg.QRResult)
Expand Down
4 changes: 3 additions & 1 deletion autograd/numpy/numpy_jvps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as onp
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
Expand Down Expand Up @@ -210,7 +211,8 @@ def fwd_grad_sort(g, ans, x, axis=-1, kind='quicksort', order=None):
sort_perm = anp.argsort(x, axis, kind, order)
return g[sort_perm]
defjvp(anp.sort, fwd_grad_sort)
defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0))
if onp.lib.NumpyVersion(onp.__version__) < '2.0.0':
defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0))

def fwd_grad_partition(g, ans, x, kth, axis=-1, kind='introselect', order=None):
partition_perm = anp.argpartition(x, kth, axis, kind, order)
Expand Down
3 changes: 2 additions & 1 deletion autograd/numpy/numpy_vjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ def grad_sort(ans, x, axis=-1, kind='quicksort', order=None):
sort_perm = anp.argsort(x, axis, kind, order)
return lambda g: unpermuter(g, sort_perm)
defvjp(anp.sort, grad_sort)
defvjp(anp.msort, grad_sort) # Until multi-D is allowed, these are the same.
if onp.lib.NumpyVersion(onp.__version__) < '2.0.0':
defvjp(anp.msort, grad_sort) # Until multi-D is allowed, these are the same.

def grad_partition(ans, x, kth, axis=-1, kind='introselect', order=None):
#TODO: Cast input with np.asanyarray()
Expand Down
16 changes: 14 additions & 2 deletions autograd/numpy/numpy_vspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class ArrayVSpace(VSpace):
def __init__(self, value):
value = np.array(value, copy=False)
value = np.asarray(value)
self.shape = value.shape
self.dtype = value.dtype

Expand Down Expand Up @@ -66,7 +66,19 @@ def _covector(self, x):
ComplexArrayVSpace.register(type_)


if np.__version__ >= '1.25':
if np.lib.NumpyVersion(np.__version__) >= '2.0.0':
class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.EigResult
class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.EighResult
class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.QRResult
class SlogdetResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.SlogdetResult
class SVDResultVSpace(NamedTupleVSpace): seq_type = np.linalg._linalg.SVDResult

EigResultVSpace.register(np.linalg._linalg.EigResult)
EighResultVSpace.register(np.linalg._linalg.EighResult)
QRResultVSpace.register(np.linalg._linalg.QRResult)
SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult)
SVDResultVSpace.register(np.linalg._linalg.SVDResult)
elif np.__version__ >= '1.25':
class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EigResult
class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EighResult
class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.QRResult
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ keywords = [
"SciPy",
]
dependencies = [
"numpy<2",
"numpy",
]
# dynamic = ["version"]

Expand Down
6 changes: 5 additions & 1 deletion tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from builtins import range
import itertools
import numpy as onp
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.test_util import check_grads
Expand Down Expand Up @@ -94,7 +95,10 @@ def test_solve_arg1_3d():
D = 4
A = npr.randn(D+1, D, D) + 5*np.eye(D)
B = npr.randn(D+1, D)
fun = lambda A: np.linalg.solve(A, B)
if onp.lib.NumpyVersion(onp.__version__) < '2.0.0':
fun = lambda A: np.linalg.solve(A, B)
else:
fun = lambda A: np.linalg.solve(A, B[..., None])[..., 0]
check_grads(fun)(A)

def test_solve_arg1_3d_3d():
Expand Down
6 changes: 4 additions & 2 deletions tests/test_systematic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import absolute_import
import numpy as onp
import autograd.numpy.random as npr
import autograd.numpy as np
import operator as op
Expand Down Expand Up @@ -43,7 +44,7 @@ def test_log1p(): unary_ufunc_check(np.log1p, lims=[0.2, 2.0])
def test_log2(): unary_ufunc_check(np.log2, lims=[0.2, 2.0])
def test_rad2deg(): unary_ufunc_check(lambda x : np.rad2deg(x)/50.0, test_complex=False)
def test_radians(): unary_ufunc_check(np.radians, test_complex=False)
def test_sign(): unary_ufunc_check(np.sign)
def test_sign(): unary_ufunc_check(np.sign, test_complex=False)
def test_sin(): unary_ufunc_check(np.sin)
def test_sinh(): unary_ufunc_check(np.sinh)
def test_sqrt(): unary_ufunc_check(np.sqrt, lims=[1.0, 3.0])
Expand Down Expand Up @@ -155,7 +156,8 @@ def test_fmin(): combo_check(np.fmin, [0, 1])(
[R(1), R(1,4), R(3, 4)])

def test_sort(): combo_check(np.sort, [0])([R(1), R(7)])
def test_msort(): combo_check(np.msort, [0])([R(1), R(7)])
if onp.lib.NumpyVersion(onp.__version__) < '2.0.0':
def test_msort(): combo_check(np.msort, [0])([R(1), R(7)])
def test_partition(): combo_check(np.partition, [0])(
[R(7), R(14)], kth=[0, 3, 6])

Expand Down

0 comments on commit 84c42fb

Please sign in to comment.