Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 105 additions & 23 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax._src import core
from jax._src import dtypes
from jax._src.numpy.util import (
_broadcast_to, ensure_arraylike,
_broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where)
from jax._src.lax import control_flow
from jax._src.lax import lax as lax
Expand Down Expand Up @@ -2376,7 +2376,8 @@ def cumulative_prod(
@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, *, weights: ArrayLike | None = None,
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis.

JAX implementation of :func:`numpy.quantile`.
Expand Down Expand Up @@ -2414,22 +2415,26 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
"""
a, q = ensure_arraylike("quantile", a, q)
if weights is None:
a, q = ensure_arraylike("quantile", a, q)
else:
a, q, weights = ensure_arraylike("quantile", a, q, weights)
if overwrite_input or out is not None:
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
"or out != None")
# TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
if not isinstance(interpolation, DeprecatedArg):
raise TypeError("quantile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False)
return _quantile(a, q, axis, method, keepdims, False, weights)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, *, weights: ArrayLike | None = None,
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis, ignoring NaNs.

JAX implementation of :func:`numpy.nanquantile`.
Expand Down Expand Up @@ -2468,7 +2473,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
>>> jnp.nanquantile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
a, q = ensure_arraylike("nanquantile", a, q)
if weights is None:
a, q = ensure_arraylike("nanquantile", a, q)
else:
a, q, weights = ensure_arraylike("nanquantile", a, q, weights)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
Expand All @@ -2477,13 +2485,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
if not isinstance(interpolation, DeprecatedArg):
raise TypeError("nanquantile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True)
return _quantile(a, q, axis, method, keepdims, True, weights)

def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
method: str, keepdims: bool, squash_nans: bool) -> Array:
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'")
keepdim = []
if dtypes.issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
Expand Down Expand Up @@ -2513,12 +2520,77 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
else:
axis = canonicalize_axis(axis, a.ndim)

q, = promote_dtypes_inexact(q)
q_was_scalar = q.ndim == 0
if q_was_scalar:
q = lax.expand_dims(q, (0,))
q_shape = q.shape
q_ndim = q.ndim
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")

a_shape = a.shape
# Handle weights
if weights is None:
a, = promote_dtypes_inexact(a)
else:
if method != "inverted_cdf":
raise ValueError("Weighted quantiles are only supported for method='inverted_cdf'")
if axis is None:
raise TypeError("Axis must be specified when shapes of a and weights differ.")
axis_tuple = canonicalize_axis_tuple(axis, a.ndim)

a, q, weights = promote_dtypes_inexact(a, q, weights)
a_shape = a.shape
w_shape = np.shape(weights)
if np.ndim(weights) == 0:
weights = lax.broadcast_in_dim(weights, a_shape, ())
w_shape = a_shape
if w_shape != a_shape:
expected_shape = tuple(a_shape[i] for i in axis_tuple)
if w_shape != expected_shape:
raise ValueError(f"Shape of weights must match the shape of the axes being reduced. "
f"Expected {expected_shape}, got {w_shape}")
weights = lax.broadcast_in_dim(
weights,
shape=a_shape,
broadcast_dimensions=axis_tuple
)

if squash_nans:
nan_mask = ~lax_internal._isnan(a)
weights = _where(nan_mask, weights, 0)
else:
with config.debug_nans(False):
has_nan_data = any(lax_internal._isnan(a), axis=axis, keepdims=True)
has_nan_weights = any(lax_internal._isnan(weights), axis=axis, keepdims=True)
a = _where(has_nan_data | has_nan_weights, np.nan, a)

total_weight = sum(weights, axis=axis, keepdims=True)
a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis)
cum_weights = cumsum(weights_sorted, axis=axis)
cum_weights_norm = lax.div(cum_weights, total_weight)

def _weighted_quantile(qi):
qi = lax.convert_element_type(qi, cum_weights_norm.dtype)
index_dtype = dtypes.default_int_dtype()
idx = _reduce_sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims)
idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1))

idx_expanded = lax.expand_dims(idx, (axis,)) if not keepdims else idx
return jnp.take_along_axis(a_sorted, idx_expanded, axis=axis).squeeze(axis=axis)
result = api.vmap(_weighted_quantile)(q)
shape_after = list(a_shape)
if keepdims:
shape_after[axis] = 1
else:
del shape_after[axis]
if not q_was_scalar:
result = result.reshape((q_shape[0], *shape_after))
else:
if result.ndim > 0 and result.shape[0] == 1:
result = result.reshape(tuple(shape_after))
return result

if squash_nans:
a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
Expand Down Expand Up @@ -2593,14 +2665,18 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
pred = lax.le(high_weight, lax._const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif method == "midpoint":
result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
elif method == "inverted_cdf":
result = high_value

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: this is not consistent with numpy.

else:
raise ValueError(f"{method=!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
keepdim_out = list(keepdim)
if not q_was_scalar:
keepdim_out = [q_shape[0], *keepdim_out]
result = result.reshape(tuple(keepdim_out))
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
result = result.squeeze(axis=0)
return result


# TODO(jakevdp): interpolation argument deprecated 2024-05-16
Expand All @@ -2609,7 +2685,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis.

JAX implementation of :func:`numpy.percentile`.
Expand Down Expand Up @@ -2647,14 +2723,17 @@ def percentile(a: ArrayLike, q: ArrayLike,
>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)
"""
a, q = ensure_arraylike("percentile", a, q)
if weights is None:
a, q = ensure_arraylike("percentile", a, q)
else:
a, q, weights = ensure_arraylike("percentile", a, q, weights)
q, = promote_dtypes_inexact(q)
# TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
if not isinstance(interpolation, DeprecatedArg):
raise TypeError("percentile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
method=method, keepdims=keepdims, weights=weights)


# TODO(jakevdp): interpolation argument deprecated 2024-05-16
Expand All @@ -2663,7 +2742,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis, ignoring NaN values.

JAX implementation of :func:`numpy.nanpercentile`.
Expand Down Expand Up @@ -2703,15 +2782,18 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
>>> jnp.nanpercentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
a, q = ensure_arraylike("nanpercentile", a, q)
if weights is None:
a, q = ensure_arraylike("nanpercentile", a, q)
else:
a, q, weights = ensure_arraylike("nanpercentile", a, q, weights)
q, = promote_dtypes_inexact(q)
q = q / 100
# TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
if not isinstance(interpolation, DeprecatedArg):
raise TypeError("nanpercentile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
method=method, keepdims=keepdims, weights=weights)


@export
Expand Down
76 changes: 76 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,82 @@ def testPercentilePrecision(self):
x = jnp.float64([1, 2, 3, 4, 7, 10])
self.assertEqual(jnp.percentile(x, 50), 3.5)

@jtu.sample_product(
  [dict(a_shape=a_shape, axis=axis)
   for a_shape, axis in (
     ((7,), None),
     ((6, 7,), None),
     ((47, 7), 0),
     ((47, 7), ()),
     ((4, 101), 1),
     ((4, 47, 7), (1, 2)),
     ((4, 47, 7), (0, 2)),
     ((4, 47, 7), (1, 0, 2)),
   )
  ],
  a_dtype=default_dtypes,
  q_dtype=[np.float32],
  q_shape=scalar_shapes + [(1,), (4,)],
  keepdims=[False, True],
  method=['linear', 'lower', 'higher', 'nearest', 'midpoint', 'inverted_cdf'],
)
def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims, method):
  """Compare weighted ``jnp.quantile`` with ``np.quantile``.

  For ``method='inverted_cdf'`` the JAX result is checked against NumPy and
  under compilation; for every other method the call is expected to raise a
  ``ValueError`` because weights are only accepted with ``'inverted_cdf'``.
  """
  rng = jtu.rand_default(self.rng())
  # Full-shape weights when axis is None or a tuple of axes; otherwise a 1D
  # weight vector matching the single reduced axis.
  if axis is None or isinstance(axis, tuple):
    weights_shape = a_shape
  else:
    weights_shape = (a_shape[axis],)

  def np_fun(a, q, weights):
    return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)

  def jnp_fun(a, q, weights):
    return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims)

  # Inputs are drawn here only; the small offset keeps every weight strictly
  # positive.
  args_maker = lambda: [
    rng(a_shape, a_dtype),
    rng(q_shape, q_dtype),
    np.abs(rng(weights_shape, a_dtype)) + 1e-3
  ]
  if method == "inverted_cdf":
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6)
  else:
    with self.assertRaisesRegex(ValueError, "Weighted quantiles are only supported for method='inverted_cdf'"):
      jnp_fun(*args_maker())

def test_weighted_quantile_negative_weights(self):
  """A negative entry in ``weights`` is expected to raise a ValueError."""
  # NOTE(review): no check producing this message is visible in the weighted
  # branch of `_quantile` in this change, and `weights` is not a static
  # argument of the jitted `quantile` -- confirm this error is really raised.
  values = jnp.array([1, 2, 3, 4, 5], dtype=float)
  signed_weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
  probs = jnp.array([0.5])
  call_kwargs = dict(axis=0, method="inverted_cdf", keepdims=False,
                     weights=signed_weights)
  with self.assertRaisesRegex(ValueError, "Weights must be non-negative"):
    jnp.quantile(values, probs, **call_kwargs)

def test_weighted_quantile_all_weights_zero(self):
  """All-zero ``weights`` are expected to raise a ValueError."""
  # NOTE(review): no check producing this message is visible in the weighted
  # branch of `_quantile` in this change, and `weights` is not a static
  # argument of the jitted `quantile` -- confirm this error is really raised.
  values = jnp.array([1, 2, 3, 4, 5], dtype=float)
  zero_weights = jnp.zeros_like(values)
  probs = jnp.array([0.5])
  call_kwargs = dict(axis=0, method="inverted_cdf", keepdims=False,
                     weights=zero_weights)
  with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"):
    jnp.quantile(values, probs, **call_kwargs)

def test_weighted_quantile_weights_with_nan(self):
  """A NaN entry in ``weights`` should make every returned quantile NaN."""
  a = jnp.array([1, 2, 3, 4, 5], dtype=float)
  weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float)
  q = jnp.array([0.5])
  result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
  # Use the TestCase assertion rather than a bare `assert`, which is stripped
  # under `python -O` and gives no failure message.
  self.assertTrue(np.isnan(np.array(result)).all())

def test_weighted_quantile_scalar_q(self):
  """A Python-scalar ``q`` should give a floating-point, zero-dimensional result."""
  a = jnp.array([1, 2, 3, 4, 5], dtype=float)
  weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
  q = 0.5
  result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
  # TestCase assertions instead of bare `assert`, so the checks survive
  # `python -O` and report the mismatching values on failure.
  self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))
  self.assertEqual(result.shape, ())

@jtu.sample_product(
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
Expand Down
Loading