From 3e652805bc2b57189c18df0228e3749f528fe21a Mon Sep 17 00:00:00 2001 From: vanderplas Date: Tue, 7 May 2024 01:01:58 -0700 Subject: [PATCH] Replace use of deprecated ``kind`` keyword in jax.numpy.sort ``kind`` is being replaced by ``stable`` in NumPy 2.0, and jax.numpy is in the process of deprecating the old argument. PiperOrigin-RevId: 631326978 --- .../python/internal/backend/numpy/misc.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/misc.py b/tensorflow_probability/python/internal/backend/numpy/misc.py index 9d47fb82a2..c6977864d0 100644 --- a/tensorflow_probability/python/internal/backend/numpy/misc.py +++ b/tensorflow_probability/python/internal/backend/numpy/misc.py @@ -60,8 +60,12 @@ def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): values = np.negative(values) else: raise ValueError('Unrecognized direction: {}.'.format(direction)) - return np.argsort( - values, axis, kind='stable' if stable else 'quicksort').astype(np.int32) + try: + # stable keyword introduced in NumPy 2.0. + return np.argsort(values, axis, stable=stable).astype(np.int32) + except TypeError: + return np.argsort( + values, axis, kind='stable' if stable else 'quicksort').astype(np.int32) def _histogram_fixed_width(values, value_range, nbins=100, dtype=np.int32, @@ -103,7 +107,11 @@ def _sort(values, axis=-1, direction='ASCENDING', name=None): # pylint: disable values = np.negative(values) else: raise ValueError('Unrecognized direction: {}.'.format(direction)) - result = np.sort(values, axis, kind='stable') + try: + # NumPy 2.0 + result = np.sort(values, axis, stable=True) + except TypeError: + result = np.sort(values, axis, kind='stable') if direction == 'DESCENDING': return np.negative(result) return result