Skip to content

Commit 43cdb5c

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 4ebbf21 commit 43cdb5c

File tree

2 files changed

+44
-70
lines changed

2 files changed

+44
-70
lines changed

jax/_src/numpy/reductions.py

Lines changed: 35 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323

2424
import numpy as np
2525

26-
from jax._src.lax import lax
26+
from jax._src import config
2727
from jax._src import api
2828
from jax._src import core
2929
from jax._src import deprecations
@@ -2483,8 +2483,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24832483
axis = _canonicalize_axis(axis, a.ndim)
24842484

24852485
q, = promote_dtypes_inexact(q)
2486-
q = lax_internal.asarray(q)
2487-
q_was_scalar = getattr(q, "ndim", 0) == 0
2486+
q_was_scalar = q.ndim == 0
24882487
if q_was_scalar:
24892488
q = lax.expand_dims(q, (0,))
24902489
q_shape = q.shape
@@ -2497,40 +2496,37 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24972496
if weights is None:
24982497
a, = promote_dtypes_inexact(a)
24992498
else:
2499+
if method != "inverted_cdf":
2500+
raise ValueError("Weighted quantiles are only supported for method='inverted_cdf'")
2501+
if axis is None:
2502+
raise TypeError("Axis must be specified when shapes of a and weights differ.")
2503+
axis_tuple = canonicalize_axis_tuple(axis, a.ndim)
2504+
25002505
a, q, weights = promote_dtypes_inexact(a, q, weights)
2501-
#weights = lax.convert_element_type(weights, a.dtype)
25022506
a_shape = a.shape
25032507
w_shape = np.shape(weights)
25042508
if np.ndim(weights) == 0:
25052509
weights = lax.broadcast_in_dim(weights, a_shape, ())
25062510
w_shape = a_shape
2507-
else:
2508-
w_shape = np.shape(weights)
25092511
if w_shape != a_shape:
2510-
if axis is None:
2511-
raise TypeError("Axis must be specified when shapes of a and weights differ.")
2512-
if isinstance(axis, tuple):
2513-
if w_shape != tuple(a_shape[i] for i in axis):
2514-
raise ValueError("Shape of weights must match the shape of the axes being reduced.")
2515-
weights = lax.broadcast_in_dim(
2516-
weights,
2517-
shape=a_shape,
2518-
broadcast_dimensions=axis
2519-
)
2520-
w_shape = a_shape
2521-
else:
2522-
if len(w_shape) != 1 or w_shape[0] != a_shape[axis]:
2523-
raise ValueError("Length of weights not compatible with specified axis.")
2524-
weights = lax.expand_dims(weights, (axis,))
2525-
weights = _broadcast_to(weights, a.shape)
2526-
w_shape = a_shape
2512+
expected_shape = tuple(a_shape[i] for i in axis_tuple)
2513+
if w_shape != expected_shape:
2514+
raise ValueError(f"Shape of weights must match the shape of the axes being reduced. "
2515+
f"Expected {expected_shape}, got {w_shape}")
2516+
weights = lax.broadcast_in_dim(
2517+
weights,
2518+
shape=a_shape,
2519+
broadcast_dimensions=axis_tuple
2520+
)
25272521

25282522
if squash_nans:
25292523
nan_mask = ~lax_internal._isnan(a)
25302524
weights = _where(nan_mask, weights, 0)
25312525
else:
2532-
with jax.debug_nans(False):
2533-
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
2526+
with config.debug_nans(False):
2527+
has_nan_data = any(lax_internal._isnan(a), axis=axis, keepdims=True)
2528+
has_nan_weights = any(lax_internal._isnan(weights), axis=axis, keepdims=True)
2529+
a = _where(has_nan_data | has_nan_weights, np.nan, a)
25342530

25352531
total_weight = sum(weights, axis=axis, keepdims=True)
25362532
a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis)
@@ -2539,49 +2535,23 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25392535

25402536
def _weighted_quantile(qi):
25412537
qi = lax.convert_element_type(qi, cum_weights_norm.dtype)
2542-
index_dtype = dtypes.canonicalize_dtype(dtypes.int_)
2543-
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims)
2538+
index_dtype = dtypes.default_int_dtype()
2539+
idx = _reduce_sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims)
25442540
idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1))
2545-
idx_prev = lax.clamp(idx - 1, _lax_const(idx, 0), _lax_const(idx, a_sorted.shape[axis] - 1))
2546-
2547-
slice_sizes = list(a_shape)
2548-
slice_sizes[axis] = 1
2549-
offset_start = q_ndim
2550-
total_offset_dims = len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1
2551-
dnums = lax.GatherDimensionNumbers(
2552-
offset_dims=tuple(range(offset_start, total_offset_dims)),
2553-
collapsed_slice_dims=(axis,),
2554-
start_index_map=(axis,)
2555-
)
2556-
val = lax.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2557-
val_prev = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2558-
cw_prev = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2559-
cw_next = lax.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2560-
if method == "linear":
2561-
denom = cw_next - cw_prev
2562-
denom = _where(denom == 0, 1, denom)
2563-
weight = (qi - cw_prev) / denom
2564-
out = val_prev * (1 - weight) + val * weight
2565-
elif method == "lower":
2566-
out = val_prev
2567-
elif method == "higher":
2568-
out = val
2569-
elif method == "nearest":
2570-
out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val)
2571-
elif method == "midpoint":
2572-
out = (val_prev + val) / 2
2573-
elif method == "inverted_cdf":
2574-
out = val
2575-
else:
2576-
raise ValueError(f"{method=!r} not recognized")
2577-
return out
2541+
2542+
idx_expanded = lax.expand_dims(idx, (axis,)) if not keepdims else idx
2543+
return jnp.take_along_axis(a_sorted, idx_expanded, axis=axis).squeeze(axis=axis)
25782544
result = api.vmap(_weighted_quantile)(q)
2579-
keepdim_out = list(keepdim)
2545+
shape_after = list(a_shape)
2546+
if keepdims:
2547+
shape_after[axis] = 1
2548+
else:
2549+
del shape_after[axis]
25802550
if not q_was_scalar:
2581-
keepdim_out = [q_shape[0], *keepdim_out]
2582-
result = result.reshape(tuple(keepdim_out))
2583-
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
2584-
result = result.squeeze(axis=0)
2551+
result = result.reshape((q_shape[0], *shape_after))
2552+
else:
2553+
if result.ndim > 0 and result.shape[0] == 1:
2554+
result = result.reshape(tuple(shape_after))
25852555
return result
25862556

25872557
if squash_nans:

tests/lax_numpy_reducers_test.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -803,28 +803,32 @@ def jnp_fun(a, q, weights):
803803
rng(q_shape, q_dtype),
804804
np.abs(rng(weights_shape, a_dtype)) + 1e-3
805805
]
806-
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6)
807-
self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6)
806+
if method == "inverted_cdf":
807+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6)
808+
self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6)
809+
else:
810+
with self.assertRaisesRegex(ValueError, "Weighted quantiles are only supported for method='inverted_cdf'"):
811+
jnp_fun(*args_maker())
808812

809813
def test_weighted_quantile_negative_weights(self):
810814
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
811815
weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
812816
q = jnp.array([0.5])
813817
with self.assertRaisesRegex(ValueError, "Weights must be non-negative"):
814-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
818+
jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
815819

816820
def test_weighted_quantile_all_weights_zero(self):
817821
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
818822
weights = jnp.zeros_like(a)
819823
q = jnp.array([0.5])
820824
with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"):
821-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
825+
jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
822826

823827
def test_weighted_quantile_weights_with_nan(self):
824828
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
825829
weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float)
826830
q = jnp.array([0.5])
827-
result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
831+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
828832
assert np.isnan(np.array(result)).all()
829833

830834
def test_weighted_quantile_scalar_q(self):

0 commit comments

Comments
 (0)