Skip to content

Commit ca1d95b

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 7d50a32 commit ca1d95b

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

jax/_src/numpy/reductions.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323

2424
import numpy as np
2525

26-
import jax
27-
from jax import lax
28-
import jax._src.numpy as jnp
26+
from jax._src.lax import lax
2927
from jax._src import api
3028
from jax._src import core
3129
from jax._src import deprecations
@@ -2486,7 +2484,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24862484

24872485
q, = promote_dtypes_inexact(q)
24882486
q = lax_internal.asarray(q)
2489-
if getattr(q, "ndim", 0) == 0:
2487+
q_was_scalar = getattr(q, "ndim", 0) == 0
2488+
if q_was_scalar:
24902489
q = lax.expand_dims(q, (0,))
24912490
q_shape = q.shape
24922491
q_ndim = q.ndim
@@ -2534,7 +2533,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25342533
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
25352534

25362535
total_weight = sum(weights, axis=axis, keepdims=True)
2537-
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
2536+
a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis)
25382537
cum_weights = cumsum(weights_sorted, axis=axis)
25392538
cum_weights_norm = lax.div(cum_weights, total_weight)
25402539

@@ -2576,17 +2575,16 @@ def _weighted_quantile(qi):
25762575
else:
25772576
raise ValueError(f"{method=!r} not recognized")
25782577
return out
2578+
result = api.vmap(_weighted_quantile)(q)
2579+
keepdim_out = list(keepdim)
2580+
if not q_was_scalar:
2581+
keepdim_out = [q_shape[0], *keepdim_out]
2582+
result = result.reshape(tuple(keepdim_out))
2583+
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
2584+
result = result.squeeze(axis=0)
2585+
return result
2586+
return result
25792587

2580-
result = jax.vmap(_weighted_quantile)(q)
2581-
if keepdims and keepdim:
2582-
if q_ndim > 0:
2583-
keepdim = [q_shape[0], *keepdim]
2584-
result = result.reshape(tuple(keepdim))
2585-
else:
2586-
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2587-
if result.ndim > 0 and result.shape[0] == 1:
2588-
result = lax.squeeze(result, (0,))
2589-
return lax.convert_element_type(result, a.dtype)
25902588

25912589
if squash_nans:
25922590
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
@@ -2666,15 +2664,13 @@ def _weighted_quantile(qi):
26662664
result = high_value
26672665
else:
26682666
raise ValueError(f"{method=!r} not recognized")
2669-
if keepdims and keepdim:
2670-
if q_ndim > 0:
2671-
keepdim = [np.shape(q)[0], *keepdim]
2672-
result = result.reshape(keepdim)
2673-
else:
2674-
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2675-
if result.ndim > 0 and result.shape[0] == 1:
2676-
result = lax.squeeze(result, (0,))
2677-
return lax.convert_element_type(result, a.dtype)
2667+
keepdim_out = list(keepdim)
2668+
if not q_was_scalar:
2669+
keepdim_out = [q_shape[0], *keepdim_out]
2670+
result = result.reshape(tuple(keepdim_out))
2671+
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
2672+
result = result.squeeze(axis=0)
2673+
return result
26782674

26792675

26802676
# TODO(jakevdp): interpolation argument deprecated 2024-05-16

0 commit comments

Comments
 (0)