Skip to content

Commit 4ebbf21

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent ca1d95b commit 4ebbf21

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

jax/_src/numpy/reductions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,8 +2497,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24972497
if weights is None:
24982498
a, = promote_dtypes_inexact(a)
24992499
else:
2500-
a, weights = promote_dtypes_inexact(a, weights)
2501-
weights = lax.convert_element_type(weights, a.dtype)
2500+
a, q, weights = promote_dtypes_inexact(a, q, weights)
2501+
#weights = lax.convert_element_type(weights, a.dtype)
25022502
a_shape = a.shape
25032503
w_shape = np.shape(weights)
25042504
if np.ndim(weights) == 0:
@@ -2583,8 +2583,6 @@ def _weighted_quantile(qi):
25832583
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
25842584
result = result.squeeze(axis=0)
25852585
return result
2586-
return result
2587-
25882586

25892587
if squash_nans:
25902588
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.

tests/lax_numpy_reducers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim
789789
if axis is None:
790790
weights_shape = a_shape
791791
elif isinstance(axis, tuple):
792-
weights_shape = tuple(a_shape[i] for i in axis)
792+
weights_shape = a_shape
793793
else:
794794
weights_shape = (a_shape[axis],)
795795
weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3

0 commit comments

Comments
 (0)